diff --git a/agent/src/models.ts b/agent/src/models.ts index 13134026..186ac47b 100644 --- a/agent/src/models.ts +++ b/agent/src/models.ts @@ -51,6 +51,10 @@ export const HeartbeatModel = z.object({ interval: z.number().int().positive().default(30), }) +export const LoopDetectionModel = z.object({ + enabled: z.boolean().default(false), +}).optional() + export const AttachmentModel = z.object({ contentHash: z.string().optional(), type: z.string().min(1, 'Attachment type is required'), @@ -93,4 +97,4 @@ export const InboxItemModel = z.object({ header: z.record(z.string(), z.unknown()).default({}), content: z.string().default(''), createdAt: z.string(), -}) \ No newline at end of file +}) diff --git a/agent/src/modules/chat.ts b/agent/src/modules/chat.ts index b714d0ae..636b444c 100644 --- a/agent/src/modules/chat.ts +++ b/agent/src/modules/chat.ts @@ -3,7 +3,7 @@ import z from 'zod' import { createAgent, ModelConfig, allActions } from '@memoh/agent' import { createAuthFetcher, getBaseUrl } from '../index' import { bearerMiddleware } from '../middlewares/bearer' -import { AgentSkillModel, AllowedActionModel, AttachmentModel, HeartbeatModel, IdentityContextModel, InboxItemModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models' +import { AgentSkillModel, AllowedActionModel, AttachmentModel, HeartbeatModel, IdentityContextModel, InboxItemModel, LoopDetectionModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models' import { sseChunked } from '../utils/sse' const AgentModel = z.object({ @@ -19,6 +19,7 @@ const AgentModel = z.object({ attachments: z.array(AttachmentModel).optional().default([]), mcpConnections: z.array(MCPConnectionModel).optional().default([]), inbox: z.array(InboxItemModel).optional().default([]), + loopDetection: LoopDetectionModel, }) export const chatModule = new Elysia({ prefix: '/chat' }) @@ -41,6 +42,7 @@ export const chatModule = new Elysia({ prefix: '/chat' }) skills: body.usableSkills, mcpConnections: body.mcpConnections, inbox: body.inbox, + loopDetection: body.loopDetection, }, authFetcher) return ask({ query: body.query, @@ -72,6 +74,7 @@ export const chatModule = new Elysia({ prefix: '/chat' }) skills: body.usableSkills, mcpConnections: body.mcpConnections, inbox: body.inbox, + loopDetection: body.loopDetection, }, authFetcher) for await (const action of stream({ query: body.query, @@ -113,6 +116,7 @@ export const chatModule = new Elysia({ prefix: '/chat' }) skills: body.usableSkills, mcpConnections: body.mcpConnections, inbox: body.inbox, + loopDetection: body.loopDetection, }, authFetcher) return triggerSchedule({ schedule: body.schedule, @@ -126,20 +130,22 @@ export const chatModule = new Elysia({ prefix: '/chat' }) }) .post('/trigger-heartbeat', async ({ body, bearer }) => { console.log('trigger-heartbeat', body) - const authFetcher = createAuthFetcher(bearer) + const auth = { + bearer: bearer!, + baseUrl: getBaseUrl(), + } + const authFetcher = createAuthFetcher(auth) const { triggerHeartbeat } = createAgent({ model: body.model as ModelConfig, activeContextTime: body.activeContextTime, channels: body.channels, currentChannel: body.currentChannel, identity: body.identity, - auth: { - bearer: bearer!, - baseUrl: getBaseUrl(), - }, + auth, skills: body.usableSkills, mcpConnections: body.mcpConnections, inbox: body.inbox, + loopDetection: body.loopDetection, }, authFetcher) return triggerHeartbeat({ heartbeat: body.heartbeat, diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index cd0fd367..92b66264 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -180,6 +180,10 @@ type gatewayInboxItem struct { CreatedAt string `json:"createdAt"` } +type gatewayLoopDetectionConfig struct { + Enabled bool `json:"enabled"` +} + type gatewayRequest struct { Model gatewayModelConfig `json:"model"` ActiveContextTime int `json:"activeContextTime"` @@ -193,6 +197,7 @@ type gatewayRequest struct { Identity gatewayIdentity `json:"identity"` Attachments []any `json:"attachments"` Inbox []gatewayInboxItem `json:"inbox,omitempty"` + LoopDetection *gatewayLoopDetectionConfig `json:"loopDetection,omitempty"` } type gatewayResponse struct { @@ -298,6 +303,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r if err != nil { return resolvedContext{}, err } + loopDetectionEnabled := r.loadBotLoopDetectionEnabled(ctx, req.BotID) // Check chat-level model override. var chatSettings conversation.Settings @@ -466,8 +472,9 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r ConversationType: strings.TrimSpace(req.ConversationType), SessionToken: req.ChatToken, }, - Attachments: attachments, - Inbox: inboxGatewayItems, + Attachments: attachments, + Inbox: inboxGatewayItems, + LoopDetection: &gatewayLoopDetectionConfig{Enabled: loopDetectionEnabled}, } return resolvedContext{payload: payload, model: chatModel, provider: provider, inboxItemIDs: inboxItemIDs}, nil @@ -1835,6 +1842,48 @@ func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings. return r.settingsService.GetBot(ctx, botID) } +func (r *Resolver) loadBotLoopDetectionEnabled(ctx context.Context, botID string) bool { + if r.queries == nil { + return false + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return false + } + row, err := r.queries.GetBotByID(ctx, botUUID) + if err != nil { + r.logger.Debug("failed to load bot metadata for loop detection", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + return false + } + return parseLoopDetectionEnabledFromMetadata(row.Metadata) +} + +func parseLoopDetectionEnabledFromMetadata(payload []byte) bool { + if len(payload) == 0 { + return false + } + var metadata map[string]any + if err := json.Unmarshal(payload, &metadata); err != nil || metadata == nil { + return false + } + features, ok := metadata["features"].(map[string]any) + if !ok { + return false + } + loopDetection, ok := features["loop_detection"].(map[string]any) + if !ok { + return false + } + enabled, ok := loopDetection["enabled"].(bool) + if !ok { + return false + } + return enabled +} + // --- utility --- func normalizeClientType(clientType string) (string, error) { diff --git a/internal/conversation/flow/resolver_loop_detection_test.go b/internal/conversation/flow/resolver_loop_detection_test.go new file mode 100644 index 00000000..73fc0eb1 --- /dev/null +++ b/internal/conversation/flow/resolver_loop_detection_test.go @@ -0,0 +1,51 @@ +package flow + +import "testing" + +func TestParseLoopDetectionEnabledFromMetadata(t *testing.T) { + tests := []struct { + name string + payload []byte + expected bool + }{ + { + name: "empty payload defaults to false", + payload: nil, + expected: false, + }, + { + name: "invalid json defaults to false", + payload: []byte("{"), + expected: false, + }, + { + name: "missing nested path defaults to false", + payload: []byte(`{"features":{}}`), + expected: false, + }, + { + name: "explicit false", + payload: []byte(`{"features":{"loop_detection":{"enabled":false}}}`), + expected: false, + }, + { + name: "explicit true", + payload: []byte(`{"features":{"loop_detection":{"enabled":true}}}`), + expected: true, + }, + { + name: "non-boolean value defaults to false", + payload: []byte(`{"features":{"loop_detection":{"enabled":"true"}}}`), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseLoopDetectionEnabledFromMetadata(tt.payload) + if got != tt.expected { + t.Fatalf("expected %v, got %v", tt.expected, got) + } + }) + } +} diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index be8fa4c5..db8223da 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -33,9 +33,20 @@ import { getMCPTools } from './tools/mcp' import { getTools } from './tools' import { buildIdentityHeaders } from './utils/headers' import { createFS } from './utils' +import { createTextLoopGuard, createTextLoopProbeBuffer } from './sential' +import { createToolLoopGuardedTools } from './tool-loop' const ANTHROPIC_BUDGET: Record = { low: 5000, medium: 16000, high: 50000 } const GOOGLE_BUDGET: Record = { low: 5000, medium: 16000, high: 50000 } +const LOOP_DETECTED_ABORT_MESSAGE = 'loop detected, stream aborted' +const LOOP_DETECTED_STREAK_THRESHOLD = 3 +const LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK = 8 +const LOOP_DETECTED_PROBE_CHARS = 256 +const TOOL_LOOP_DETECTED_ABORT_MESSAGE = 'tool loop detected, stream aborted' +const TOOL_LOOP_REPEAT_THRESHOLD = 5 +const TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1 +const TOOL_LOOP_WARNING_KEY = '__memoh_tool_loop_warning' +const TOOL_LOOP_WARNING_TEXT = '[MEMOH_TOOL_LOOP_WARNING] Repeated identical tool invocation (same tool + arguments) was detected more than 5 times. Stop looping this tool and either summarize current results or change strategy.' const buildProviderOptions = (config: ModelConfig): Record> | undefined => { if (!config.reasoning?.enabled) return undefined @@ -93,12 +104,14 @@ export const createAgent = ( }, auth, inbox = [], + loopDetection = { enabled: false }, }: AgentParams, fetch: AuthFetcher, ) => { const model = createModel(modelConfig) // eslint-disable-next-line @typescript-eslint/no-explicit-any const providerOptions = buildProviderOptions(modelConfig) as any + const loopDetectionEnabled = loopDetection?.enabled === true const enabledSkills: AgentSkill[] = [] const fs = createFS({ fetch, botId: identity.botId }) @@ -194,27 +207,102 @@ export const createAgent = ( return userMessage } + const createNonStreamTextLoopInspector = () => { + if (!loopDetectionEnabled) { + return null + } + const textLoopGuard = createTextLoopGuard({ + consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD, + minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK, + }) + return (text: string) => { + const result = textLoopGuard.inspect(text) + if (result.abort) { + throw new Error(LOOP_DETECTED_ABORT_MESSAGE) + } + } + } + + const buildGuardedTools = ( + tools: ToolSet, + onAbortToolCall: (toolCallId: string) => void = () => {}, + ): ToolSet => { + if (!loopDetectionEnabled) { + return tools + } + return createToolLoopGuardedTools(tools, { + repeatThreshold: TOOL_LOOP_REPEAT_THRESHOLD, + warningsBeforeAbort: TOOL_LOOP_WARNINGS_BEFORE_ABORT, + onAbortToolCall, + warningKey: TOOL_LOOP_WARNING_KEY, + warningText: TOOL_LOOP_WARNING_TEXT, + }) + } + + const runTextGeneration = async ({ + messages, + systemPrompt, + prepareStep, + }: { + messages: ModelMessage[] + systemPrompt: string + prepareStep?: () => { system: string } + }) => { + const { tools, close } = await getAgentTools() + let shouldAbortForToolLoop = false + const guardedTools = buildGuardedTools(tools, () => { + shouldAbortForToolLoop = true + }) + const inspectTextLoop = createNonStreamTextLoopInspector() + let runError: unknown = null + try { + return await generateText({ + model, + messages, + system: systemPrompt, + ...(providerOptions && { providerOptions }), + stopWhen: stepCountIs(Infinity), + ...(prepareStep && { prepareStep }), + ...(loopDetectionEnabled && { + onStepFinish: ({ text }: { text: string }) => { + if (shouldAbortForToolLoop) { + throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE) + } + if (inspectTextLoop) { + inspectTextLoop(text) + } + }, + }), + tools: guardedTools, + }) + } catch (error) { + runError = error + throw error + } finally { + try { + await close() + } catch (closeError) { + if (runError == null) { + throw closeError + } + console.error(closeError) + } + } + } + const ask = async (input: AgentInput) => { const userPrompt = generateUserPrompt(input) const messages = [...input.messages, userPrompt] input.skills.forEach((skill) => enableSkill(skill)) const systemPrompt = await generateSystemPrompt() - const { tools, close } = await getAgentTools() - const { response, reasoning, text, usage, steps } = await generateText({ - model, + const { response, reasoning, text, usage, steps } = await runTextGeneration({ messages, - system: systemPrompt, - ...(providerOptions && { providerOptions }), - stopWhen: stepCountIs(Infinity), + systemPrompt, prepareStep: () => { return { system: systemPrompt, } }, - onFinish: async () => { - await close() - }, - tools, }) const stepUsages = buildStepUsages(steps) const { cleanedText, attachments: textAttachments } = @@ -257,22 +345,14 @@ export const createAgent = ( }) } const messages = [...params.messages, userPrompt] - const { tools, close } = await getAgentTools() - const { response, reasoning, text, usage, steps } = await generateText({ - model, + const { response, reasoning, text, usage, steps } = await runTextGeneration({ messages, - system: generateSubagentSystemPrompt(), - ...(providerOptions && { providerOptions }), - stopWhen: stepCountIs(Infinity), + systemPrompt: generateSubagentSystemPrompt(), prepareStep: () => { return { system: generateSubagentSystemPrompt(), } }, - onFinish: async () => { - await close() - }, - tools, }) const stepUsages = buildStepUsages(steps) return { @@ -301,17 +381,9 @@ export const createAgent = ( } const messages = [...params.messages, scheduleMessage] params.skills.forEach((skill) => enableSkill(skill)) - const { tools, close } = await getAgentTools() - const { response, reasoning, text, usage, steps } = await generateText({ - model, + const { response, reasoning, text, usage, steps } = await runTextGeneration({ messages, - system: await generateSystemPrompt(), - ...(providerOptions && { providerOptions }), - stopWhen: stepCountIs(Infinity), - onFinish: async () => { - await close() - }, - tools, + systemPrompt: await generateSystemPrompt(), }) const stepUsages = buildStepUsages(steps) return { @@ -341,17 +413,9 @@ export const createAgent = ( } const messages = [...params.messages, heartbeatMessage] params.skills.forEach((skill) => enableSkill(skill)) - const { tools, close } = await getAgentTools() - const { response, reasoning, text, usage, steps } = await generateText({ - model, + const { response, reasoning, text, usage, steps } = await runTextGeneration({ messages, - system: await generateSystemPrompt(), - ...(providerOptions && { providerOptions }), - stopWhen: stepCountIs(Infinity), - onFinish: async () => { - await close() - }, - tools, + systemPrompt: await generateSystemPrompt(), }) const stepUsages = buildStepUsages(steps) return { @@ -392,6 +456,27 @@ export const createAgent = ( input.skills.forEach((skill) => enableSkill(skill)) const systemPrompt = await generateSystemPrompt() const attachmentsExtractor = new AttachmentsStreamExtractor() + const textLoopGuard = loopDetectionEnabled + ? createTextLoopGuard({ + consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD, + minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK, + }) + : null + const guardLoopOutput = (text: string) => { + if (!textLoopGuard) { + return + } + const result = textLoopGuard.inspect(text) + if (result.abort) { + throw new Error(LOOP_DETECTED_ABORT_MESSAGE) + } + } + const textLoopProbeBuffer = textLoopGuard + ? createTextLoopProbeBuffer( + LOOP_DETECTED_PROBE_CHARS, + guardLoopOutput, + ) + : null const result: { messages: ModelMessage[]; reasoning: string[]; @@ -403,7 +488,20 @@ export const createAgent = ( usage: null, usages: [], } + const toolLoopAbortCallIds = new Set() const { tools, close } = await getAgentTools() + // Stream path needs deferred abort to keep tool_call_start/tool_call_end event pairing. + const guardedTools = buildGuardedTools(tools, (toolCallId) => { + toolLoopAbortCallIds.add(toolCallId) + }) + let closePromise: Promise | null = null + const closeTools = async () => { + if (!closePromise) { + closePromise = Promise.resolve().then(() => close()) + } + await closePromise + } + let streamError: unknown = null try { const { fullStream } = streamText({ model, @@ -416,9 +514,9 @@ export const createAgent = ( system: systemPrompt, } }, - tools, + tools: guardedTools, onFinish: async ({ usage, reasoning, response, steps }) => { - await close() + await closeTools() result.usage = usage as never result.reasoning = reasoning.map((part) => part.text) result.messages = response.messages @@ -464,6 +562,9 @@ export const createAgent = ( chunk.text, ) if (visibleText) { + if (textLoopProbeBuffer) { + textLoopProbeBuffer.push(visibleText) + } yield { type: 'text_delta', delta: visibleText, @@ -481,11 +582,17 @@ export const createAgent = ( // Flush any remaining buffered content before ending the text stream. const remainder = attachmentsExtractor.flushRemainder() if (remainder.visibleText) { + if (textLoopProbeBuffer) { + textLoopProbeBuffer.push(remainder.visibleText) + } yield { type: 'text_delta', delta: remainder.visibleText, } } + if (textLoopProbeBuffer) { + textLoopProbeBuffer.flush() + } if (remainder.attachments.length) { yield { type: 'attachment_delta', @@ -502,11 +609,17 @@ export const createAgent = ( // Flush any remaining buffered content before ending the text stream. const remainder = attachmentsExtractor.flushRemainder() if (remainder.visibleText) { + if (textLoopProbeBuffer) { + textLoopProbeBuffer.push(remainder.visibleText) + } yield { type: 'text_delta', delta: remainder.visibleText, } } + if (textLoopProbeBuffer) { + textLoopProbeBuffer.flush() + } if (remainder.attachments.length) { yield { type: 'attachment_delta', @@ -522,6 +635,9 @@ export const createAgent = ( } break case 'tool-result': + // Always emit the terminal tool event first so downstream reducers + // can close the in-flight tool block before the stream aborts. + const shouldAbortForToolLoop = toolLoopAbortCallIds.delete(chunk.toolCallId) yield { type: 'tool_call_end', toolName: chunk.toolName, @@ -530,6 +646,9 @@ export const createAgent = ( result: chunk.output, metadata: chunk, } + if (shouldAbortForToolLoop) { + throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE) + } break case 'file': yield { @@ -544,6 +663,9 @@ export const createAgent = ( } } } + if (textLoopProbeBuffer) { + textLoopProbeBuffer.flush() + } const { messages: strippedMessages } = stripAttachmentsFromMessages( result.messages, @@ -560,8 +682,18 @@ export const createAgent = ( skills: getEnabledSkills(), } } catch (error) { + streamError = error console.error(error) throw error + } finally { + try { + await closeTools() + } catch (closeError) { + if (streamError == null) { + throw closeError + } + console.error(closeError) + } } } diff --git a/packages/agent/src/sential.test.ts b/packages/agent/src/sential.test.ts new file mode 100644 index 00000000..603a183e --- /dev/null +++ b/packages/agent/src/sential.test.ts @@ -0,0 +1,265 @@ +import { describe, expect, it } from 'vitest' +import { + createSential, + createTextLoopGuard, + createToolLoopGuard, +} from './sential' + +describe('sential', () => { + it('does not hit when overlap stays low', () => { + const sential = createSential() + sential.inspect('ABCDEFGHIJKLMNO') + + const result = sential.inspect('qrstuvwxyz12345') + + expect(result.hit).toBe(false) + expect(result.overlap).toBe(0) + }) + + it('hits when overlap is above threshold', () => { + const sential = createSential() + sential.inspect('0123456789abcdefghij0123456789abcdefghij') + + const result = sential.inspect('0123456789abcdefghij') + + expect(result.hit).toBe(true) + expect(result.overlap).toBeGreaterThan(0.75) + }) + + it('does not hit when overlap is exactly threshold', () => { + const sential = createSential({ + ngramSize: 1, + overlapThreshold: 0.75, + }) + sential.inspect('aaaaaaaaaa') + + const result = sential.inspect(`${'a'.repeat(15)}bcdef`) + + expect(result.newGrams).toBe(20) + expect(result.matchedGrams).toBe(15) + expect(result.overlap).toBeCloseTo(0.75, 10) + expect(result.hit).toBe(false) + }) + + it('evicts old grams with sliding window', () => { + const sential = createSential({ + windowSize: 20, + }) + sential.inspect('abcdefghijabcdefghij') + sential.inspect('KLMNOPQRST') + sential.inspect('UVWXYZ1234') + + const result = sential.inspect('abcdefghij') + + expect(result.hit).toBe(false) + expect(result.matchedGrams).toBe(0) + }) + + it('aborts only after 10 consecutive hits', () => { + const guard = createTextLoopGuard({ + ngramSize: 1, + overlapThreshold: 0.5, + consecutiveHitsToAbort: 10, + }) + + const seeded = guard.inspect('aaaaaaaaaa') + expect(seeded.hit).toBe(false) + expect(seeded.streak).toBe(0) + expect(seeded.abort).toBe(false) + + for (let i = 1; i <= 9; i += 1) { + const result = guard.inspect('aaaaaaaaaa') + expect(result.hit).toBe(true) + expect(result.streak).toBe(i) + expect(result.abort).toBe(false) + } + + const tenth = guard.inspect('aaaaaaaaaa') + expect(tenth.hit).toBe(true) + expect(tenth.streak).toBe(10) + expect(tenth.abort).toBe(true) + }) + + it('resets streak when a non-hit chunk appears', () => { + const guard = createTextLoopGuard({ + ngramSize: 1, + overlapThreshold: 0.5, + consecutiveHitsToAbort: 10, + }) + + guard.inspect('aaaaaaaaaa') + for (let i = 0; i < 5; i += 1) { + guard.inspect('aaaaaaaaaa') + } + + const miss = guard.inspect('bcdefghijk') + expect(miss.hit).toBe(false) + expect(miss.streak).toBe(0) + expect(miss.abort).toBe(false) + + const hitAgain = guard.inspect('aaaaaaaaaa') + expect(hitAgain.hit).toBe(true) + expect(hitAgain.streak).toBe(1) + expect(hitAgain.abort).toBe(false) + }) + + it('only updates streak when chunk has enough new grams', () => { + const guard = createTextLoopGuard({ + ngramSize: 1, + overlapThreshold: 0.5, + minNewGramsPerChunk: 5, + }) + + guard.inspect('aaaaaaaaaa') + + const smallHit = guard.inspect('aaaa') + expect(smallHit.hit).toBe(true) + expect(smallHit.newGrams).toBe(4) + expect(smallHit.streak).toBe(0) + expect(smallHit.abort).toBe(false) + + const countedHit = guard.inspect('aaaaa') + expect(countedHit.hit).toBe(true) + expect(countedHit.newGrams).toBe(5) + expect(countedHit.streak).toBe(1) + expect(countedHit.abort).toBe(false) + + const smallMiss = guard.inspect('cccc') + expect(smallMiss.hit).toBe(false) + expect(smallMiss.newGrams).toBe(4) + expect(smallMiss.streak).toBe(1) + + const countedMiss = guard.inspect('ddddd') + expect(countedMiss.hit).toBe(false) + expect(countedMiss.newGrams).toBe(5) + expect(countedMiss.streak).toBe(0) + }) + + it('warns on first tool-loop breach and aborts on second breach', () => { + const guard = createToolLoopGuard({ + repeatThreshold: 5, + warningsBeforeAbort: 1, + }) + const payload = { + toolName: 'web_fetch', + input: { url: 'https://example.com', requestId: 'r-1' }, + } + + for (let i = 1; i <= 5; i += 1) { + const result = guard.inspect(payload) + expect(result.warn).toBe(false) + expect(result.abort).toBe(false) + expect(result.repeatCount).toBe(i) + } + + const firstBreach = guard.inspect(payload) + expect(firstBreach.warn).toBe(true) + expect(firstBreach.abort).toBe(false) + expect(firstBreach.breachCount).toBe(1) + expect(firstBreach.repeatCount).toBe(0) + + for (let i = 1; i <= 5; i += 1) { + const result = guard.inspect(payload) + expect(result.warn).toBe(false) + expect(result.abort).toBe(false) + expect(result.repeatCount).toBe(i) + } + + const secondBreach = guard.inspect(payload) + expect(secondBreach.warn).toBe(false) + expect(secondBreach.abort).toBe(true) + expect(secondBreach.breachCount).toBe(2) + }) + + it('resets tool-loop repeat count when hash changes', () => { + const guard = createToolLoopGuard({ + repeatThreshold: 5, + warningsBeforeAbort: 1, + }) + + const first = guard.inspect({ + toolName: 'web_fetch', + input: { url: 'https://a.example.com' }, + }) + expect(first.repeatCount).toBe(1) + + const second = guard.inspect({ + toolName: 'web_fetch', + input: { url: 'https://a.example.com' }, + }) + expect(second.repeatCount).toBe(2) + + const changed = guard.inspect({ + toolName: 'web_fetch', + input: { url: 'https://b.example.com' }, + }) + expect(changed.repeatCount).toBe(1) + expect(changed.warn).toBe(false) + expect(changed.abort).toBe(false) + }) + + it('resets tool-loop breach count when hash changes', () => { + const guard = createToolLoopGuard({ + repeatThreshold: 1, + warningsBeforeAbort: 1, + }) + + // First fingerprint reaches warning phase. + guard.inspect({ + toolName: 'web_fetch', + input: { url: 'https://a.example.com' }, + }) + const warned = guard.inspect({ + toolName: 'web_fetch', + input: { url: 'https://a.example.com' }, + }) + expect(warned.warn).toBe(true) + expect(warned.breachCount).toBe(1) + + // Switching fingerprint should restart warning/abort phase. + const switched = guard.inspect({ + toolName: 'web_fetch', + input: { url: 'https://b.example.com' }, + }) + expect(switched.warn).toBe(false) + expect(switched.abort).toBe(false) + expect(switched.breachCount).toBe(0) + + const warnedAgain = guard.inspect({ + toolName: 'web_fetch', + input: { url: 'https://b.example.com' }, + }) + expect(warnedAgain.warn).toBe(true) + expect(warnedAgain.abort).toBe(false) + expect(warnedAgain.breachCount).toBe(1) + }) + + it('ignores volatile keys when computing tool-loop hash', () => { + const guard = createToolLoopGuard({ + repeatThreshold: 1, + warningsBeforeAbort: 1, + }) + + const first = guard.inspect({ + toolName: 'web_fetch', + input: { + url: 'https://example.com', + request_id: 'req-1', + updatedAt: '2026-02-28T00:00:00.000Z', + }, + }) + expect(first.warn).toBe(false) + expect(first.abort).toBe(false) + + const second = guard.inspect({ + toolName: 'web_fetch', + input: { + url: 'https://example.com', + request_id: 'req-2', + updatedAt: '2026-02-28T00:01:00.000Z', + }, + }) + expect(second.warn).toBe(true) + expect(second.abort).toBe(false) + }) +}) diff --git a/packages/agent/src/sential.ts b/packages/agent/src/sential.ts new file mode 100644 index 00000000..1a5e5738 --- /dev/null +++ b/packages/agent/src/sential.ts @@ -0,0 +1,506 @@ +import { createHash } from 'node:crypto' + +export interface SentialOptions { + ngramSize?: number; + windowSize?: number; + overlapThreshold?: number; +} + +export interface TextLoopGuardOptions extends SentialOptions { + consecutiveHitsToAbort?: number; + minNewGramsPerChunk?: number; +} + +export interface SentialInspectResult { + hit: boolean; + overlap: number; + matchedGrams: number; + newGrams: number; +} + +export interface TextLoopGuardInspectResult extends SentialInspectResult { + streak: number; + abort: boolean; +} + +export interface ToolLoopInspectInput { + toolName: string; + input: unknown; +} + +export interface ToolLoopInspectResult { + hash: string; + repeatCount: number; + breachCount: number; + warn: boolean; + abort: boolean; +} + +export interface ToolLoopGuardOptions { + repeatThreshold?: number; + warningsBeforeAbort?: number; + volatileKeys?: string[]; +} + +export interface Sential { + inspect(text: string): SentialInspectResult; + reset(): void; +} + +export interface TextLoopGuard { + inspect(text: string): TextLoopGuardInspectResult; + reset(): void; +} + +export interface TextLoopProbeBuffer { + push(text: string): void; + flush(): void; +} + +export interface ToolLoopGuard { + inspect(input: ToolLoopInspectInput): ToolLoopInspectResult; + reset(): void; +} + +const DEFAULT_NGRAM_SIZE = 10 +const DEFAULT_WINDOW_SIZE = 1000 +const DEFAULT_OVERLAP_THRESHOLD = 0.75 +const DEFAULT_CONSECUTIVE_HITS_TO_ABORT = 10 +const DEFAULT_MIN_NEW_GRAMS_PER_CHUNK = 1 +const DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD = 5 +const DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1 +const DEFAULT_VOLATILE_KEYS = [ + 'toolCallId', + 'tool_call_id', + 'requestId', + 'request_id', + 'traceId', + 'trace_id', + 'spanId', + 'span_id', + 'sessionId', + 'session_id', + 'timestamp', + 'createdAt', + 'created_at', + 'updatedAt', + 'updated_at', + 'expiresAt', + 'expires_at', + 'nonce', +] +const VOLATILE_KEY_SUFFIXES = [ + 'requestid', + 'traceid', + 'sessionid', + 'toolcallid', + 'timestamp', + 'createdat', + 'updatedat', + 'expiresat', +] + +type NormalizedValue = + | null + | string + | number + | boolean + | NormalizedValue[] + | { [key: string]: NormalizedValue }; + +function validatePositiveInt(name: string, value: number): number { + if (!Number.isFinite(value) || value <= 0 || !Number.isInteger(value)) { + throw new Error(`${name} must be a positive integer`) + } + return value +} + +function validateThreshold(value: number): number { + if (!Number.isFinite(value) || value < 0 || value > 1) { + throw new Error('overlapThreshold must be between 0 and 1') + } + return value +} + +function normalizeChars(text: string): string[] { + if (!text) return [] + return Array.from(text.normalize('NFC')) +} + +function buildNgram(chars: string[], start: number, size: number): string { + return chars.slice(start, start + size).join('') +} + +function normalizeKeyName(key: string): string { + return key + .trim() + .toLowerCase() + .replace(/[^a-z0-9]/g, '') +} + +function isVolatileKey(key: string, volatileKeySet: Set): boolean { + const normalized = normalizeKeyName(key) + if (!normalized) return false + if (volatileKeySet.has(normalized)) return true + return VOLATILE_KEY_SUFFIXES.some((suffix) => normalized.endsWith(suffix)) +} + +function isPlainObject(value: unknown): value is Record { + if (value === null || typeof value !== 'object') return false + const prototype = Object.getPrototypeOf(value as object) + return prototype === Object.prototype || prototype === null +} + +function normalizeToolLoopValue( + value: unknown, + volatileKeySet: Set, + seen: WeakSet, +): NormalizedValue | undefined { + if (value === null) return null + + if (typeof value === 'string') return value.normalize('NFC') + if (typeof value === 'boolean') return value + if (typeof value === 'number') { + return Number.isFinite(value) ? value : (String(value) as NormalizedValue) + } + if (typeof value === 'bigint') return value.toString() + if ( + typeof value === 'undefined' || + typeof value === 'function' || + typeof value === 'symbol' + ) { + return undefined + } + + if (value instanceof Date) { + return value.toISOString() + } + + if (Array.isArray(value)) { + return value.map( + (item) => normalizeToolLoopValue(item, volatileKeySet, seen) ?? null, + ) + } + + if (!isPlainObject(value)) { + const maybeRecord = value as { toJSON?: () => unknown } + if (typeof maybeRecord.toJSON === 'function') { + return normalizeToolLoopValue(maybeRecord.toJSON(), volatileKeySet, seen) + } + return String(value) + } + + if (seen.has(value)) { + return '[Circular]' + } + seen.add(value) + + const normalizedObject: { [key: string]: NormalizedValue } = {} + const keys = Object.keys(value).sort() + for (const key of keys) { + if (isVolatileKey(key, volatileKeySet)) { + continue + } + const normalized = normalizeToolLoopValue(value[key], volatileKeySet, seen) + if (normalized !== undefined) { + normalizedObject[key] = normalized + } + } + + seen.delete(value) + return normalizedObject +} + +function computeToolLoopHash( + input: ToolLoopInspectInput, + volatileKeySet: Set, +): string { + const payload = { + toolName: input.toolName.trim(), + input: + normalizeToolLoopValue(input.input, volatileKeySet, new WeakSet()) ?? + null, + } + const serialized = JSON.stringify(payload) + return createHash('sha256').update(serialized).digest('hex') +} + +export function createSential(options: SentialOptions = {}): Sential { + const ngramSize = validatePositiveInt( + 'ngramSize', + options.ngramSize ?? DEFAULT_NGRAM_SIZE, + ) + const windowSize = validatePositiveInt( + 'windowSize', + options.windowSize ?? DEFAULT_WINDOW_SIZE, + ) + const overlapThreshold = validateThreshold( + options.overlapThreshold ?? DEFAULT_OVERLAP_THRESHOLD, + ) + if (windowSize < ngramSize) { + throw new Error('windowSize must be greater than or equal to ngramSize') + } + + const windowChars: string[] = [] + const windowNgramQueue: string[] = [] + const historySet = new Set() + const historyCounts = new Map() + + const addHistoryGram = (gram: string) => { + const nextCount = (historyCounts.get(gram) ?? 0) + 1 + historyCounts.set(gram, nextCount) + if (nextCount === 1) { + historySet.add(gram) + } + } + + const removeHistoryGram = (gram: string) => { + const prevCount = historyCounts.get(gram) + if (!prevCount) return + if (prevCount <= 1) { + historyCounts.delete(gram) + historySet.delete(gram) + return + } + historyCounts.set(gram, prevCount - 1) + } + + const pushWindowChar = (char: string) => { + windowChars.push(char) + + if (windowChars.length >= ngramSize) { + const gram = buildNgram( + windowChars, + windowChars.length - ngramSize, + ngramSize, + ) + windowNgramQueue.push(gram) + addHistoryGram(gram) + } + + if (windowChars.length <= windowSize) { + return + } + + windowChars.shift() + const removedGram = windowNgramQueue.shift() + if (removedGram) { + removeHistoryGram(removedGram) + } + } + + return { + inspect(text: string): SentialInspectResult { + const incomingChars = normalizeChars(text) + if (incomingChars.length === 0) { + return { + hit: false, + overlap: 0, + matchedGrams: 0, + newGrams: 0, + } + } + + const contextSize = Math.max(ngramSize - 1, 0) + const contextChars = + contextSize > 0 ? windowChars.slice(-contextSize) : [] + const candidateChars = [...contextChars, ...incomingChars] + + let matchedGrams = 0 + let newGrams = 0 + const contextLength = contextChars.length + + if (candidateChars.length >= ngramSize) { + for (let i = 0; i <= candidateChars.length - ngramSize; i += 1) { + const gramEndIndex = i + ngramSize - 1 + if (gramEndIndex < contextLength) { + continue + } + const gram = buildNgram(candidateChars, i, ngramSize) + newGrams += 1 + if (historySet.has(gram)) { + matchedGrams += 1 + } + } + } + + const overlap = newGrams === 0 ? 0 : matchedGrams / newGrams + const hit = overlap > overlapThreshold + + for (const char of incomingChars) { + pushWindowChar(char) + } + + return { + hit, + overlap, + matchedGrams, + newGrams, + } + }, + reset(): void { + windowChars.length = 0 + windowNgramQueue.length = 0 + historySet.clear() + historyCounts.clear() + }, + } +} + +export function createTextLoopGuard( + options: TextLoopGuardOptions = {}, +): TextLoopGuard { + const consecutiveHitsToAbort = validatePositiveInt( + 'consecutiveHitsToAbort', + options.consecutiveHitsToAbort ?? DEFAULT_CONSECUTIVE_HITS_TO_ABORT, + ) + const minNewGramsPerChunk = validatePositiveInt( + 'minNewGramsPerChunk', + options.minNewGramsPerChunk ?? DEFAULT_MIN_NEW_GRAMS_PER_CHUNK, + ) + const sential = createSential(options) + let streak = 0 + + return { + inspect(text: string): TextLoopGuardInspectResult { + const result = sential.inspect(text) + if (result.newGrams >= minNewGramsPerChunk) { + if (result.hit) { + streak += 1 + } else { + streak = 0 + } + } + return { + ...result, + streak, + abort: streak >= consecutiveHitsToAbort, + } + }, + reset(): void { + sential.reset() + streak = 0 + }, + } +} + +export function createTextLoopProbeBuffer( + chunkSize: number, + inspect: (text: string) => void, +): TextLoopProbeBuffer { + validatePositiveInt('chunkSize', chunkSize) + let chars: string[] = [] + let offset = 0 + + const compact = () => { + if (offset > 0) { + chars = chars.slice(offset) + offset = 0 + } + } + + const inspectChunk = (text: string) => { + if (text.length > 0) { + inspect(text) + } + } + + return { + push(text: string): void { + if (!text) return + chars.push(...normalizeChars(text)) + + while (chars.length - offset >= chunkSize) { + const chunk = chars.slice(offset, offset + chunkSize).join('') + offset += chunkSize + inspectChunk(chunk) + } + + // Prevent unbounded front-gaps after many chunks. + if (offset >= chunkSize) { + compact() + } + }, + flush(): void { + if (chars.length - offset > 0) { + const remainder = chars.slice(offset).join('') + inspectChunk(remainder) + } + chars = [] + offset = 0 + }, + } +} + +export function createToolLoopGuard( + options: ToolLoopGuardOptions = {}, +): ToolLoopGuard { + const repeatThreshold = validatePositiveInt( + 'repeatThreshold', + options.repeatThreshold ?? DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD, + ) + const warningsBeforeAbort = validatePositiveInt( + 'warningsBeforeAbort', + options.warningsBeforeAbort ?? DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT, + ) + const volatileKeySet = new Set(DEFAULT_VOLATILE_KEYS.map(normalizeKeyName)) + for (const key of options.volatileKeys ?? []) { + const normalizedKey = normalizeKeyName(key) + if (normalizedKey) { + volatileKeySet.add(normalizedKey) + } + } + + let lastHash = '' + let repeatCount = 0 + let breachCount = 0 + let breachHash = '' + + return { + inspect(input: ToolLoopInspectInput): ToolLoopInspectResult { + const hash = computeToolLoopHash(input, volatileKeySet) + + if (hash === lastHash) { + repeatCount += 1 + } else { + lastHash = hash + repeatCount = 1 + } + + // Breach phase is fingerprint-specific: switching tool signature restarts it. + if (breachHash !== hash) { + breachHash = hash + breachCount = 0 + } + + let warn = false + let abort = false + if (repeatCount > repeatThreshold) { + if (breachCount < warningsBeforeAbort) { + breachCount += 1 + warn = true + // Reset consecutive accumulation after first warning. + lastHash = '' + repeatCount = 0 + } else { + breachCount += 1 + abort = true + } + } + + return { + hash, + repeatCount, + breachCount, + warn, + abort, + } + }, + reset(): void { + lastHash = '' + repeatCount = 0 + breachCount = 0 + breachHash = '' + }, + } +} diff --git a/packages/agent/src/tool-loop.test.ts b/packages/agent/src/tool-loop.test.ts new file mode 100644 index 00000000..8d481db6 --- /dev/null +++ b/packages/agent/src/tool-loop.test.ts @@ -0,0 +1,93 @@ +import { describe, expect, it, vi } from 'vitest' +import type { ToolSet } from 'ai' +import { createToolLoopGuardedTools } from './tool-loop' + +describe('tool loop guarded tools', () => { + it('preserves promised async-iterable tool outputs', async () => { + const onAbortToolCall = vi.fn() + const streamedChunks = ['chunk-1', 'chunk-2'] + const stream = { + async *[Symbol.asyncIterator]() { + for (const chunk of streamedChunks) { + yield chunk + } + }, + } + + const baseTools = { + streamy: { + execute: async () => stream, + }, + } as unknown as ToolSet + + const tools = createToolLoopGuardedTools(baseTools, { + repeatThreshold: 1, + warningsBeforeAbort: 1, + onAbortToolCall, + warningKey: '__warn', + warningText: 'loop warning', + }) + + const output = await tools.streamy.execute?.({ value: 'same' } as never, { toolCallId: 't-stream' } as never) + + expect(output).toBe(stream) + const received: string[] = [] + for await (const chunk of output as AsyncIterable) { + received.push(chunk) + } + expect(received).toEqual(streamedChunks) + expect(onAbortToolCall).not.toHaveBeenCalled() + }) + + it('defers abort to stream layer when onAbortToolCall is provided', async () => { + const onAbortToolCall = vi.fn() + const baseTools = { + echo: { + execute: async (input: unknown) => ({ result: input }), + }, + } as unknown as ToolSet + const tools = createToolLoopGuardedTools(baseTools, { + repeatThreshold: 1, + warningsBeforeAbort: 1, + onAbortToolCall, + warningKey: '__warn', + warningText: 'loop warning', + }) + + await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never) + const warned = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never) + expect(warned).toMatchObject({ + __warn: { + marker: 'MEMOH_TOOL_LOOP_WARNING', + }, + }) + await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never) + const abortedOutput = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never) + + expect(onAbortToolCall).toHaveBeenCalledWith('t-1') + expect(abortedOutput).toEqual({ result: { value: 'same' } }) + }) + + it('reports abort via callback without throwing inside tool execution', async () => { + const onAbortToolCall = vi.fn() + const baseTools = { + echo: { + execute: async (input: unknown) => ({ result: input }), + }, + } as unknown as ToolSet + const tools = createToolLoopGuardedTools(baseTools, { + repeatThreshold: 1, + warningsBeforeAbort: 1, + onAbortToolCall, + warningKey: '__warn', + warningText: 'loop warning', + }) + + await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never) + await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never) + await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never) + const abortedOutput = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never) + expect(onAbortToolCall).toHaveBeenCalledWith('t-2') + expect(abortedOutput).toEqual({ result: { value: 'same' } }) + }) +}) diff --git a/packages/agent/src/tool-loop.ts b/packages/agent/src/tool-loop.ts new file mode 100644 index 00000000..ab8dfc08 --- /dev/null +++ b/packages/agent/src/tool-loop.ts @@ -0,0 +1,122 @@ +import { ToolExecutionOptions, ToolSet } from 'ai' +import { createToolLoopGuard, type ToolLoopInspectResult } from './sential' + +export interface CreateToolLoopGuardedToolsOptions { + repeatThreshold: number + warningsBeforeAbort: number + onAbortToolCall: (toolCallId: string) => void + warningKey: string + warningText: string +} + +const isRecord = (value: unknown): value is Record => { + return value !== null && typeof value === 'object' && !Array.isArray(value) +} + +const isAsyncIterable = (value: unknown): value is AsyncIterable => { + return ( + value !== null && + typeof value === 'object' && + Symbol.asyncIterator in value + ) +} + +const injectToolLoopWarning = ( + output: unknown, + inspectResult: ToolLoopInspectResult, + warningKey: string, + warningText: string, +): unknown => { + // Keep warning payload structured so UI/consumers can detect and render it. + const warningPayload = { + marker: 'MEMOH_TOOL_LOOP_WARNING', + message: warningText, + fingerprint: inspectResult.hash, + breachCount: inspectResult.breachCount, + } + if (isRecord(output)) { + return { + ...output, + [warningKey]: warningPayload, + } + } + return { + [warningKey]: warningPayload, + result: output, + } +} + +export function createToolLoopGuardedTools( + tools: ToolSet, + { + repeatThreshold, + warningsBeforeAbort, + onAbortToolCall, + warningKey, + warningText, + }: CreateToolLoopGuardedToolsOptions, +): ToolSet { + const guard = createToolLoopGuard({ + repeatThreshold, + warningsBeforeAbort, + }) + + // Wrap each executable tool to inspect (toolName + input) after execution. + // First breach injects a warning into this tool result; second breach signals abort. + return Object.fromEntries( + Object.entries(tools).map(([toolName, toolDefinition]) => { + const execute = toolDefinition.execute + if (typeof execute !== 'function') { + return [toolName, toolDefinition] + } + + const wrappedTool = { + ...toolDefinition, + execute: ( + toolInput: unknown, + options: ToolExecutionOptions, + ) => { + const directOutput = execute( + toolInput as never, + options as never, + ) as unknown + + // Streamed tool outputs are passed through unchanged to preserve streaming semantics. + if (isAsyncIterable(directOutput)) { + return directOutput as never + } + + return (async () => { + const resolvedOutput = await directOutput + + // Tools may return Promise; keep that stream untouched too. + if (isAsyncIterable(resolvedOutput)) { + return resolvedOutput as never + } + + const inspectResult = guard.inspect({ + toolName, + input: toolInput, + }) + if (inspectResult.abort) { + // Report loop abort to generation layer; it decides when/how to stop. + onAbortToolCall(options.toolCallId) + return resolvedOutput as never + } + if (inspectResult.warn) { + return injectToolLoopWarning( + resolvedOutput, + inspectResult, + warningKey, + warningText, + ) as never + } + return resolvedOutput as never + })() + }, + } + + return [toolName, wrappedTool] + }), + ) as ToolSet +} diff --git a/packages/agent/src/types/agent.ts b/packages/agent/src/types/agent.ts index 1c84f17f..56159d64 100644 --- a/packages/agent/src/types/agent.ts +++ b/packages/agent/src/types/agent.ts @@ -38,6 +38,10 @@ export interface InboxItem { createdAt: string } +export interface LoopDetectionConfig { + enabled?: boolean +} + export interface AgentParams { model: ModelConfig language?: string @@ -50,6 +54,7 @@ export interface AgentParams { auth: AgentAuthContext skills?: AgentSkill[] inbox?: InboxItem[] + loopDetection?: LoopDetectionConfig } export interface AgentInput { diff --git a/packages/web/src/i18n/locales/en.json b/packages/web/src/i18n/locales/en.json index 8d9b7d3c..2c300b79 100644 --- a/packages/web/src/i18n/locales/en.json +++ b/packages/web/src/i18n/locales/en.json @@ -487,6 +487,7 @@ "heartbeatModelPlaceholder": "Use chat model (default)", "allowGuest": "Allow Guest Access", "allowGuestPersonalHint": "Personal bots do not support guest access. Use a public bot instead.", + "loopDetectionTitle": "Detect and auto-block output loops", "searchModel": "Search models…", "noModel": "No models available", "saveSuccess": "Settings saved", diff --git a/packages/web/src/i18n/locales/zh.json b/packages/web/src/i18n/locales/zh.json index cb54ebc3..7392e400 100644 --- a/packages/web/src/i18n/locales/zh.json +++ b/packages/web/src/i18n/locales/zh.json @@ -483,6 +483,7 @@ "heartbeatModelPlaceholder": "使用聊天模型(默认)", "allowGuest": "允许游客访问", "allowGuestPersonalHint": "个人 Bot 不支持游客访问,请使用公开 Bot。", + "loopDetectionTitle": "自动检测并阻止模型循环输出", "searchModel": "搜索模型…", "noModel": "暂无可选模型", "saveSuccess": "设置已保存", diff --git a/packages/web/src/pages/bots/components/bot-settings.vue b/packages/web/src/pages/bots/components/bot-settings.vue index b146ff8d..3036053c 100644 --- a/packages/web/src/pages/bots/components/bot-settings.vue +++ b/packages/web/src/pages/bots/components/bot-settings.vue @@ -134,6 +134,17 @@ + + +
+ + +
+