feat(agent): loop detection (#152)

* feat(loop-detection): add configurable text and tool loop guards

* style(web): remove duplicate separator in bot settings
This commit is contained in:
Ringo.Typowriter
2026-03-02 15:00:09 +08:00
committed by GitHub
parent 04bce702b7
commit d3edd17d90
13 changed files with 1381 additions and 53 deletions
+174 -42
View File
@@ -33,9 +33,20 @@ import { getMCPTools } from './tools/mcp'
import { getTools } from './tools'
import { buildIdentityHeaders } from './utils/headers'
import { createFS } from './utils'
import { createTextLoopGuard, createTextLoopProbeBuffer } from './sential'
import { createToolLoopGuardedTools } from './tool-loop'
const ANTHROPIC_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
const GOOGLE_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
// Error message thrown when the text loop guard reports an abort.
const LOOP_DETECTED_ABORT_MESSAGE = 'loop detected, stream aborted'
// Passed as `consecutiveHitsToAbort`: counted hits in a row needed to abort.
const LOOP_DETECTED_STREAK_THRESHOLD = 3
// Passed as `minNewGramsPerChunk`: smaller chunks leave the streak untouched.
const LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK = 8
// Chunk size (in characters) handed to the text loop probe buffer.
const LOOP_DETECTED_PROBE_CHARS = 256
// Error message thrown when the tool loop guard reports an abort.
const TOOL_LOOP_DETECTED_ABORT_MESSAGE = 'tool loop detected, stream aborted'
// Passed as `repeatThreshold`: a breach happens once the repeat count exceeds this.
const TOOL_LOOP_REPEAT_THRESHOLD = 5
// Passed as `warningsBeforeAbort`: breaches answered with a warning before aborting.
const TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1
// Property name under which the warning payload is attached to a tool result.
const TOOL_LOOP_WARNING_KEY = '__memoh_tool_loop_warning'
// Warning text shown to the model. The count is derived from the threshold so the
// message cannot drift from the configured value.
const TOOL_LOOP_WARNING_TEXT = `[MEMOH_TOOL_LOOP_WARNING] Repeated identical tool invocation (same tool + arguments) was detected more than ${TOOL_LOOP_REPEAT_THRESHOLD} times. Stop looping this tool and either summarize current results or change strategy.`
const buildProviderOptions = (config: ModelConfig): Record<string, Record<string, unknown>> | undefined => {
if (!config.reasoning?.enabled) return undefined
@@ -93,12 +104,14 @@ export const createAgent = (
},
auth,
inbox = [],
loopDetection = { enabled: false },
}: AgentParams,
fetch: AuthFetcher,
) => {
const model = createModel(modelConfig)
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const providerOptions = buildProviderOptions(modelConfig) as any
const loopDetectionEnabled = loopDetection?.enabled === true
const enabledSkills: AgentSkill[] = []
const fs = createFS({ fetch, botId: identity.botId })
@@ -194,27 +207,102 @@ export const createAgent = (
return userMessage
}
// Returns a per-run checker for non-streamed text, or null when loop detection
// is disabled. The checker throws LOOP_DETECTED_ABORT_MESSAGE once the guard
// reports an abort.
const createNonStreamTextLoopInspector = () => {
  if (loopDetectionEnabled) {
    const guard = createTextLoopGuard({
      consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD,
      minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK,
    })
    return (text: string) => {
      const { abort } = guard.inspect(text)
      if (abort) {
        throw new Error(LOOP_DETECTED_ABORT_MESSAGE)
      }
    }
  }
  return null
}
// Wraps the tool set with the tool loop guard when loop detection is enabled;
// otherwise hands the tools back untouched. `onAbortToolCall` receives the id
// of the tool call on which the guard decided to abort.
const buildGuardedTools = (
  tools: ToolSet,
  onAbortToolCall: (toolCallId: string) => void = () => {},
): ToolSet =>
  loopDetectionEnabled
    ? createToolLoopGuardedTools(tools, {
      repeatThreshold: TOOL_LOOP_REPEAT_THRESHOLD,
      warningsBeforeAbort: TOOL_LOOP_WARNINGS_BEFORE_ABORT,
      onAbortToolCall,
      warningKey: TOOL_LOOP_WARNING_KEY,
      warningText: TOOL_LOOP_WARNING_TEXT,
    })
    : tools
// Shared non-streaming generation path: acquires the agent tools, wraps them
// with the tool loop guard, runs generateText and always closes the tools.
// Throws TOOL_LOOP_DETECTED_ABORT_MESSAGE / LOOP_DETECTED_ABORT_MESSAGE from
// onStepFinish when loop detection is enabled and a guard trips.
const runTextGeneration = async ({
messages,
systemPrompt,
prepareStep,
}: {
messages: ModelMessage[]
systemPrompt: string
prepareStep?: () => { system: string }
}) => {
const { tools, close } = await getAgentTools()
// Set by the guarded tools' abort callback; checked after each step finishes.
let shouldAbortForToolLoop = false
const guardedTools = buildGuardedTools(tools, () => {
shouldAbortForToolLoop = true
})
// null when loop detection is disabled.
const inspectTextLoop = createNonStreamTextLoopInspector()
// Remembers a generation failure so a later close() error does not replace it.
let runError: unknown = null
try {
return await generateText({
model,
messages,
system: systemPrompt,
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
...(prepareStep && { prepareStep }),
// The step hook is only installed when loop detection is enabled.
...(loopDetectionEnabled && {
onStepFinish: ({ text }: { text: string }) => {
// The tool loop flag is checked before the text of the step is inspected.
if (shouldAbortForToolLoop) {
throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE)
}
if (inspectTextLoop) {
inspectTextLoop(text)
}
},
}),
tools: guardedTools,
})
} catch (error) {
runError = error
throw error
} finally {
try {
await close()
} catch (closeError) {
// Surface the close failure only if generation itself succeeded;
// otherwise log it and let the original error propagate.
if (runError == null) {
throw closeError
}
console.error(closeError)
}
}
}
const ask = async (input: AgentInput) => {
const userPrompt = generateUserPrompt(input)
const messages = [...input.messages, userPrompt]
input.skills.forEach((skill) => enableSkill(skill))
const systemPrompt = await generateSystemPrompt()
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: systemPrompt,
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
systemPrompt,
prepareStep: () => {
return {
system: systemPrompt,
}
},
onFinish: async () => {
await close()
},
tools,
})
const stepUsages = buildStepUsages(steps)
const { cleanedText, attachments: textAttachments } =
@@ -257,22 +345,14 @@ export const createAgent = (
})
}
const messages = [...params.messages, userPrompt]
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: generateSubagentSystemPrompt(),
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
systemPrompt: generateSubagentSystemPrompt(),
prepareStep: () => {
return {
system: generateSubagentSystemPrompt(),
}
},
onFinish: async () => {
await close()
},
tools,
})
const stepUsages = buildStepUsages(steps)
return {
@@ -301,17 +381,9 @@ export const createAgent = (
}
const messages = [...params.messages, scheduleMessage]
params.skills.forEach((skill) => enableSkill(skill))
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: await generateSystemPrompt(),
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
onFinish: async () => {
await close()
},
tools,
systemPrompt: await generateSystemPrompt(),
})
const stepUsages = buildStepUsages(steps)
return {
@@ -341,17 +413,9 @@ export const createAgent = (
}
const messages = [...params.messages, heartbeatMessage]
params.skills.forEach((skill) => enableSkill(skill))
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: await generateSystemPrompt(),
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
onFinish: async () => {
await close()
},
tools,
systemPrompt: await generateSystemPrompt(),
})
const stepUsages = buildStepUsages(steps)
return {
@@ -392,6 +456,27 @@ export const createAgent = (
input.skills.forEach((skill) => enableSkill(skill))
const systemPrompt = await generateSystemPrompt()
const attachmentsExtractor = new AttachmentsStreamExtractor()
// Text loop guard for the streaming path; null when loop detection is disabled.
const textLoopGuard = loopDetectionEnabled
? createTextLoopGuard({
consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD,
minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK,
})
: null
// Feeds one probe chunk to the guard and throws once the guard reports an abort.
const guardLoopOutput = (text: string) => {
if (!textLoopGuard) {
return
}
const result = textLoopGuard.inspect(text)
if (result.abort) {
throw new Error(LOOP_DETECTED_ABORT_MESSAGE)
}
}
// Buffers streamed text and hands it to guardLoopOutput in chunks of
// LOOP_DETECTED_PROBE_CHARS characters; null when there is no guard.
const textLoopProbeBuffer = textLoopGuard
? createTextLoopProbeBuffer(
LOOP_DETECTED_PROBE_CHARS,
guardLoopOutput,
)
: null
const result: {
messages: ModelMessage[];
reasoning: string[];
@@ -403,7 +488,20 @@ export const createAgent = (
usage: null,
usages: [],
}
// Ids of tool calls on which the tool loop guard requested an abort.
const toolLoopAbortCallIds = new Set<string>()
const { tools, close } = await getAgentTools()
// Stream path needs deferred abort to keep tool_call_start/tool_call_end event pairing.
const guardedTools = buildGuardedTools(tools, (toolCallId) => {
toolLoopAbortCallIds.add(toolCallId)
})
// Memoized close: the first call starts close(), every later call awaits the
// same promise, so close() runs at most once.
let closePromise: Promise<void> | null = null
const closeTools = async () => {
if (!closePromise) {
closePromise = Promise.resolve().then(() => close())
}
await closePromise
}
// Remembers a stream failure so a later close error is logged instead of rethrown.
let streamError: unknown = null
try {
const { fullStream } = streamText({
model,
@@ -416,9 +514,9 @@ export const createAgent = (
system: systemPrompt,
}
},
tools,
tools: guardedTools,
onFinish: async ({ usage, reasoning, response, steps }) => {
await close()
await closeTools()
result.usage = usage as never
result.reasoning = reasoning.map((part) => part.text)
result.messages = response.messages
@@ -464,6 +562,9 @@ export const createAgent = (
chunk.text,
)
if (visibleText) {
if (textLoopProbeBuffer) {
textLoopProbeBuffer.push(visibleText)
}
yield {
type: 'text_delta',
delta: visibleText,
@@ -481,11 +582,17 @@ export const createAgent = (
// Flush any remaining buffered content before ending the text stream.
const remainder = attachmentsExtractor.flushRemainder()
if (remainder.visibleText) {
if (textLoopProbeBuffer) {
textLoopProbeBuffer.push(remainder.visibleText)
}
yield {
type: 'text_delta',
delta: remainder.visibleText,
}
}
if (textLoopProbeBuffer) {
textLoopProbeBuffer.flush()
}
if (remainder.attachments.length) {
yield {
type: 'attachment_delta',
@@ -502,11 +609,17 @@ export const createAgent = (
// Flush any remaining buffered content before ending the text stream.
const remainder = attachmentsExtractor.flushRemainder()
if (remainder.visibleText) {
if (textLoopProbeBuffer) {
textLoopProbeBuffer.push(remainder.visibleText)
}
yield {
type: 'text_delta',
delta: remainder.visibleText,
}
}
if (textLoopProbeBuffer) {
textLoopProbeBuffer.flush()
}
if (remainder.attachments.length) {
yield {
type: 'attachment_delta',
@@ -522,6 +635,9 @@ export const createAgent = (
}
break
case 'tool-result':
// Always emit the terminal tool event first so downstream reducers
// can close the in-flight tool block before the stream aborts.
const shouldAbortForToolLoop = toolLoopAbortCallIds.delete(chunk.toolCallId)
yield {
type: 'tool_call_end',
toolName: chunk.toolName,
@@ -530,6 +646,9 @@ export const createAgent = (
result: chunk.output,
metadata: chunk,
}
if (shouldAbortForToolLoop) {
throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE)
}
break
case 'file':
yield {
@@ -544,6 +663,9 @@ export const createAgent = (
}
}
}
if (textLoopProbeBuffer) {
textLoopProbeBuffer.flush()
}
const { messages: strippedMessages } = stripAttachmentsFromMessages(
result.messages,
@@ -560,8 +682,18 @@ export const createAgent = (
skills: getEnabledSkills(),
}
} catch (error) {
streamError = error
console.error(error)
throw error
} finally {
try {
await closeTools()
} catch (closeError) {
if (streamError == null) {
throw closeError
}
console.error(closeError)
}
}
}
+265
View File
@@ -0,0 +1,265 @@
import { describe, expect, it } from 'vitest'
import {
createSential,
createTextLoopGuard,
createToolLoopGuard,
} from './sential'
describe('sential', () => {
// Two chunks that share no characters: nothing from the second chunk is in history.
it('does not hit when overlap stays low', () => {
const sential = createSential()
sential.inspect('ABCDEFGHIJKLMNO')
const result = sential.inspect('qrstuvwxyz12345')
expect(result.hit).toBe(false)
expect(result.overlap).toBe(0)
})
// Replaying text that is already in history must exceed the default 0.75 threshold.
it('hits when overlap is above threshold', () => {
const sential = createSential()
sential.inspect('0123456789abcdefghij0123456789abcdefghij')
const result = sential.inspect('0123456789abcdefghij')
expect(result.hit).toBe(true)
expect(result.overlap).toBeGreaterThan(0.75)
})
// The comparison is strict: 15 of 20 unigrams matching (exactly 0.75) is not a hit.
it('does not hit when overlap is exactly threshold', () => {
const sential = createSential({
ngramSize: 1,
overlapThreshold: 0.75,
})
sential.inspect('aaaaaaaaaa')
const result = sential.inspect(`${'a'.repeat(15)}bcdef`)
expect(result.newGrams).toBe(20)
expect(result.matchedGrams).toBe(15)
expect(result.overlap).toBeCloseTo(0.75, 10)
expect(result.hit).toBe(false)
})
// With a 20-character window, 20 fresh characters push the original text out of
// history, so repeating it later matches nothing.
it('evicts old grams with sliding window', () => {
const sential = createSential({
windowSize: 20,
})
sential.inspect('abcdefghijabcdefghij')
sential.inspect('KLMNOPQRST')
sential.inspect('UVWXYZ1234')
const result = sential.inspect('abcdefghij')
expect(result.hit).toBe(false)
expect(result.matchedGrams).toBe(0)
})
// The first chunk only seeds history (no hit). Hits 1-9 grow the streak without
// aborting; the tenth consecutive hit flips `abort` to true.
it('aborts only after 10 consecutive hits', () => {
const guard = createTextLoopGuard({
ngramSize: 1,
overlapThreshold: 0.5,
consecutiveHitsToAbort: 10,
})
const seeded = guard.inspect('aaaaaaaaaa')
expect(seeded.hit).toBe(false)
expect(seeded.streak).toBe(0)
expect(seeded.abort).toBe(false)
for (let i = 1; i <= 9; i += 1) {
const result = guard.inspect('aaaaaaaaaa')
expect(result.hit).toBe(true)
expect(result.streak).toBe(i)
expect(result.abort).toBe(false)
}
const tenth = guard.inspect('aaaaaaaaaa')
expect(tenth.hit).toBe(true)
expect(tenth.streak).toBe(10)
expect(tenth.abort).toBe(true)
})
// A single non-hit chunk zeroes the streak; the next hit starts again from 1.
it('resets streak when a non-hit chunk appears', () => {
const guard = createTextLoopGuard({
ngramSize: 1,
overlapThreshold: 0.5,
consecutiveHitsToAbort: 10,
})
guard.inspect('aaaaaaaaaa')
for (let i = 0; i < 5; i += 1) {
guard.inspect('aaaaaaaaaa')
}
const miss = guard.inspect('bcdefghijk')
expect(miss.hit).toBe(false)
expect(miss.streak).toBe(0)
expect(miss.abort).toBe(false)
const hitAgain = guard.inspect('aaaaaaaaaa')
expect(hitAgain.hit).toBe(true)
expect(hitAgain.streak).toBe(1)
expect(hitAgain.abort).toBe(false)
})
// Chunks with fewer than `minNewGramsPerChunk` new grams neither raise nor reset
// the streak; chunks at or above the minimum do both.
it('only updates streak when chunk has enough new grams', () => {
const guard = createTextLoopGuard({
ngramSize: 1,
overlapThreshold: 0.5,
minNewGramsPerChunk: 5,
})
guard.inspect('aaaaaaaaaa')
// Below the minimum: reported as a hit but not counted.
const smallHit = guard.inspect('aaaa')
expect(smallHit.hit).toBe(true)
expect(smallHit.newGrams).toBe(4)
expect(smallHit.streak).toBe(0)
expect(smallHit.abort).toBe(false)
const countedHit = guard.inspect('aaaaa')
expect(countedHit.hit).toBe(true)
expect(countedHit.newGrams).toBe(5)
expect(countedHit.streak).toBe(1)
expect(countedHit.abort).toBe(false)
// Below the minimum: a miss that leaves the streak at 1.
const smallMiss = guard.inspect('cccc')
expect(smallMiss.hit).toBe(false)
expect(smallMiss.newGrams).toBe(4)
expect(smallMiss.streak).toBe(1)
const countedMiss = guard.inspect('ddddd')
expect(countedMiss.hit).toBe(false)
expect(countedMiss.newGrams).toBe(5)
expect(countedMiss.streak).toBe(0)
})
// With repeatThreshold 5, calls 1-5 pass silently; the sixth identical call is the
// first breach (warn, repeat count reset to 0). After five more, the next call is
// the second breach and reports abort instead of warn.
it('warns on first tool-loop breach and aborts on second breach', () => {
const guard = createToolLoopGuard({
repeatThreshold: 5,
warningsBeforeAbort: 1,
})
const payload = {
toolName: 'web_fetch',
input: { url: 'https://example.com', requestId: 'r-1' },
}
for (let i = 1; i <= 5; i += 1) {
const result = guard.inspect(payload)
expect(result.warn).toBe(false)
expect(result.abort).toBe(false)
expect(result.repeatCount).toBe(i)
}
const firstBreach = guard.inspect(payload)
expect(firstBreach.warn).toBe(true)
expect(firstBreach.abort).toBe(false)
expect(firstBreach.breachCount).toBe(1)
expect(firstBreach.repeatCount).toBe(0)
for (let i = 1; i <= 5; i += 1) {
const result = guard.inspect(payload)
expect(result.warn).toBe(false)
expect(result.abort).toBe(false)
expect(result.repeatCount).toBe(i)
}
const secondBreach = guard.inspect(payload)
expect(secondBreach.warn).toBe(false)
expect(secondBreach.abort).toBe(true)
expect(secondBreach.breachCount).toBe(2)
})
// Only consecutive identical calls accumulate: a different input restarts the
// repeat count at 1.
it('resets tool-loop repeat count when hash changes', () => {
const guard = createToolLoopGuard({
repeatThreshold: 5,
warningsBeforeAbort: 1,
})
const first = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
expect(first.repeatCount).toBe(1)
const second = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
expect(second.repeatCount).toBe(2)
const changed = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://b.example.com' },
})
expect(changed.repeatCount).toBe(1)
expect(changed.warn).toBe(false)
expect(changed.abort).toBe(false)
})
// The warning/abort phase is tracked per fingerprint: after a warning for one
// input, a different input starts again with breachCount 0 and gets its own
// warning rather than an immediate abort.
it('resets tool-loop breach count when hash changes', () => {
const guard = createToolLoopGuard({
repeatThreshold: 1,
warningsBeforeAbort: 1,
})
// First fingerprint reaches warning phase.
guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
const warned = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
expect(warned.warn).toBe(true)
expect(warned.breachCount).toBe(1)
// Switching fingerprint should restart warning/abort phase.
const switched = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://b.example.com' },
})
expect(switched.warn).toBe(false)
expect(switched.abort).toBe(false)
expect(switched.breachCount).toBe(0)
const warnedAgain = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://b.example.com' },
})
expect(warnedAgain.warn).toBe(true)
expect(warnedAgain.abort).toBe(false)
expect(warnedAgain.breachCount).toBe(1)
})
// `request_id` and `updatedAt` differ between the two calls but are volatile keys,
// so both calls share a fingerprint and the second one triggers the warning.
it('ignores volatile keys when computing tool-loop hash', () => {
const guard = createToolLoopGuard({
repeatThreshold: 1,
warningsBeforeAbort: 1,
})
const first = guard.inspect({
toolName: 'web_fetch',
input: {
url: 'https://example.com',
request_id: 'req-1',
updatedAt: '2026-02-28T00:00:00.000Z',
},
})
expect(first.warn).toBe(false)
expect(first.abort).toBe(false)
const second = guard.inspect({
toolName: 'web_fetch',
input: {
url: 'https://example.com',
request_id: 'req-2',
updatedAt: '2026-02-28T00:01:00.000Z',
},
})
expect(second.warn).toBe(true)
expect(second.abort).toBe(false)
})
})
+506
View File
@@ -0,0 +1,506 @@
import { createHash } from 'node:crypto'
/** Tuning options for the n-gram overlap detector built by `createSential`. */
export interface SentialOptions {
/** Characters per n-gram; must be a positive integer. Defaults to 10. */
ngramSize?: number;
/** Number of most recent characters kept as history; must be >= ngramSize. Defaults to 1000. */
windowSize?: number;
/** Overlap ratio in [0, 1] that must be strictly exceeded to count as a hit. Defaults to 0.75. */
overlapThreshold?: number;
}
/** Options for `createTextLoopGuard`: detector tuning plus streak handling. */
export interface TextLoopGuardOptions extends SentialOptions {
/** Counted hits in a row required before `abort` is reported. Defaults to 10. */
consecutiveHitsToAbort?: number;
/** Minimum new n-grams a chunk must contribute to update the streak. Defaults to 1. */
minNewGramsPerChunk?: number;
}
/** Outcome of inspecting one text chunk against the detector's history. */
export interface SentialInspectResult {
/** True when `overlap` is strictly greater than the configured threshold. */
hit: boolean;
/** `matchedGrams / newGrams`, or 0 when the chunk produced no n-grams. */
overlap: number;
/** Number of the chunk's n-grams that were already present in history. */
matchedGrams: number;
/** Number of n-grams the chunk produced (including ones spanning the previous tail). */
newGrams: number;
}
/** Detector result extended with the guard's streak state. */
export interface TextLoopGuardInspectResult extends SentialInspectResult {
/** Current run of counted hits after processing this chunk. */
streak: number;
/** True once `streak` has reached `consecutiveHitsToAbort`. */
abort: boolean;
}
/** One tool invocation to be fingerprinted by the tool loop guard. */
export interface ToolLoopInspectInput {
/** Tool name; trimmed before hashing. */
toolName: string;
/** Tool arguments; normalized (sorted keys, volatile keys removed) before hashing. */
input: unknown;
}
/** Verdict returned by the tool loop guard for one invocation. */
export interface ToolLoopInspectResult {
/** SHA-256 hex fingerprint of the tool name and normalized input. */
hash: string;
/** Consecutive identical invocations counted so far; reset to 0 when a warning is issued. */
repeatCount: number;
/** Number of threshold breaches recorded for the current fingerprint. */
breachCount: number;
/** True when this invocation breached the threshold and warnings are still available. */
warn: boolean;
/** True when this invocation breached the threshold after all warnings were used. */
abort: boolean;
}
/** Options for `createToolLoopGuard`. */
export interface ToolLoopGuardOptions {
/** A breach occurs when the repeat count exceeds this value. Defaults to 5. */
repeatThreshold?: number;
/** Breaches answered with a warning before a breach reports abort. Defaults to 1. */
warningsBeforeAbort?: number;
/** Extra input keys to ignore when fingerprinting, in addition to the built-in list. */
volatileKeys?: string[];
}
/** Sliding-window n-gram repetition detector. */
export interface Sential {
/** Scores `text` against the current history, then appends it to the history. */
inspect(text: string): SentialInspectResult;
/** Clears all accumulated history. */
reset(): void;
}
/** Repetition detector that also tracks consecutive hits and signals abort. */
export interface TextLoopGuard {
/** Inspects `text`, updates the streak and reports whether to abort. */
inspect(text: string): TextLoopGuardInspectResult;
/** Clears the detector history and resets the streak to 0. */
reset(): void;
}
/** Accumulates streamed text and forwards it to an inspector in fixed-size chunks. */
export interface TextLoopProbeBuffer {
/** Appends `text` and forwards every complete chunk to the inspector. */
push(text: string): void;
/** Forwards any buffered remainder to the inspector and empties the buffer. */
flush(): void;
}
/** Detector for repeated identical tool invocations. */
export interface ToolLoopGuard {
/** Records one invocation and reports whether to warn or abort. */
inspect(input: ToolLoopInspectInput): ToolLoopInspectResult;
/** Forgets the last fingerprint and all repeat/breach counters. */
reset(): void;
}
// Fallbacks used when the corresponding option is omitted.
// Text detector: characters per n-gram.
const DEFAULT_NGRAM_SIZE = 10
// Text detector: number of recent characters kept as history.
const DEFAULT_WINDOW_SIZE = 1000
// Text detector: overlap ratio that must be strictly exceeded for a hit.
const DEFAULT_OVERLAP_THRESHOLD = 0.75
// Text guard: counted hits in a row before abort.
const DEFAULT_CONSECUTIVE_HITS_TO_ABORT = 10
// Text guard: minimum new n-grams for a chunk to affect the streak.
const DEFAULT_MIN_NEW_GRAMS_PER_CHUNK = 1
// Tool guard: repeat count that must be exceeded for a breach.
const DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD = 5
// Tool guard: breaches answered with a warning before abort.
const DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1
// Input keys that are dropped before fingerprinting a tool call. Entries are
// passed through normalizeKeyName when the guard is built, so the camelCase and
// snake_case spellings collapse to the same normalized key.
const DEFAULT_VOLATILE_KEYS = [
'toolCallId',
'tool_call_id',
'requestId',
'request_id',
'traceId',
'trace_id',
'spanId',
'span_id',
'sessionId',
'session_id',
'timestamp',
'createdAt',
'created_at',
'updatedAt',
'updated_at',
'expiresAt',
'expires_at',
'nonce',
]
// Normalized (lowercase, alphanumeric-only) suffixes: any key whose normalized
// name ends with one of these is treated as volatile, e.g. a prefixed request id.
const VOLATILE_KEY_SUFFIXES = [
'requestid',
'traceid',
'sessionid',
'toolcallid',
'timestamp',
'createdat',
'updatedat',
'expiresat',
]
// JSON-serializable shape produced by normalizeToolLoopValue and fed to
// JSON.stringify when computing a tool call fingerprint.
type NormalizedValue =
| null
| string
| number
| boolean
| NormalizedValue[]
| { [key: string]: NormalizedValue };
/**
 * Returns `value` unchanged when it is an integer greater than zero.
 *
 * @param name Option name used in the error message.
 * @param value Candidate value.
 * @throws Error when `value` is not a positive integer (NaN and infinities included).
 */
function validatePositiveInt(name: string, value: number): number {
  const isPositiveInteger = Number.isInteger(value) && value > 0
  if (isPositiveInteger) {
    return value
  }
  throw new Error(`${name} must be a positive integer`)
}
/**
 * Returns `value` unchanged when it is a finite number within [0, 1].
 *
 * @throws Error when `value` is NaN, infinite or outside the closed range.
 */
function validateThreshold(value: number): number {
  const withinRange = Number.isFinite(value) && value >= 0 && value <= 1
  if (withinRange) {
    return value
  }
  throw new Error('overlapThreshold must be between 0 and 1')
}
/**
 * Splits `text` into an array of code points after NFC normalization.
 * An empty (or otherwise falsy) input yields an empty array.
 */
function normalizeChars(text: string): string[] {
  if (!text) {
    return []
  }
  const composed = text.normalize('NFC')
  // Spreading a string iterates by code point, so surrogate pairs stay together.
  return [...composed]
}
/**
 * Concatenates up to `size` characters of `chars` starting at `start`.
 * The result is shorter than `size` when the array ends early.
 */
function buildNgram(chars: string[], start: number, size: number): string {
  let gram = ''
  for (const char of chars.slice(start, start + size)) {
    gram += char
  }
  return gram
}
/**
 * Canonical form of an object key for volatile-key matching: trimmed,
 * lowercased, with every character outside a-z / 0-9 removed.
 */
function normalizeKeyName(key: string): string {
  const lowered = key.trim().toLowerCase()
  return lowered.replace(/[^a-z0-9]/g, '')
}
/**
 * Decides whether `key` should be ignored when fingerprinting tool input.
 * A key is volatile when its normalized name is in `volatileKeySet` or ends
 * with one of VOLATILE_KEY_SUFFIXES. Keys that normalize to '' are never volatile.
 */
function isVolatileKey(key: string, volatileKeySet: Set<string>): boolean {
  const canonical = normalizeKeyName(key)
  if (canonical.length === 0) {
    return false
  }
  if (volatileKeySet.has(canonical)) {
    return true
  }
  for (const suffix of VOLATILE_KEY_SUFFIXES) {
    if (canonical.endsWith(suffix)) {
      return true
    }
  }
  return false
}
/**
 * True for objects whose prototype is `Object.prototype` or `null`
 * (object literals and `Object.create(null)`); false for everything else,
 * including arrays, dates and class instances.
 */
function isPlainObject(value: unknown): value is Record<string, unknown> {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  const proto: unknown = Object.getPrototypeOf(value)
  return proto === null || proto === Object.prototype
}
/**
 * Converts an arbitrary tool input into a deterministic, JSON-serializable
 * value so that equivalent inputs produce the same fingerprint.
 *
 * - strings are NFC-normalized; non-finite numbers and bigints become strings
 * - `undefined`, functions and symbols yield `undefined` (dropped from objects,
 *   turned into `null` inside arrays)
 * - dates become ISO strings; an invalid date becomes 'Invalid Date' instead of
 *   letting `toISOString()` throw
 * - plain objects are rebuilt with sorted keys and without volatile keys
 * - other objects use `toJSON()` when available, else `String(value)`
 * - any object or array already on the current traversal path becomes
 *   '[Circular]', so self-referencing arrays and `toJSON()` implementations
 *   that return `this` terminate
 *
 * @param value Value to normalize.
 * @param volatileKeySet Normalized key names to omit from plain objects.
 * @param seen Objects on the current traversal path (entries are removed on the
 *   way back up, so shared non-cyclic references are normalized normally).
 */
function normalizeToolLoopValue(
  value: unknown,
  volatileKeySet: Set<string>,
  seen: WeakSet<object>,
): NormalizedValue | undefined {
  if (typeof value === 'string') return value.normalize('NFC')
  if (typeof value === 'boolean') return value
  if (typeof value === 'number') {
    // NaN / Infinity are not representable in JSON; keep them distinguishable.
    return Number.isFinite(value) ? value : String(value)
  }
  if (typeof value === 'bigint') return value.toString()
  // Remaining non-object types: undefined, function, symbol.
  if (typeof value !== 'object') return undefined
  if (value === null) return null

  if (value instanceof Date) {
    return Number.isNaN(value.getTime()) ? 'Invalid Date' : value.toISOString()
  }

  if (seen.has(value)) {
    return '[Circular]'
  }
  seen.add(value)
  try {
    if (Array.isArray(value)) {
      return value.map(
        (item) => normalizeToolLoopValue(item, volatileKeySet, seen) ?? null,
      )
    }

    if (!isPlainObject(value)) {
      const maybeRecord = value as { toJSON?: () => unknown }
      if (typeof maybeRecord.toJSON === 'function') {
        return normalizeToolLoopValue(maybeRecord.toJSON(), volatileKeySet, seen)
      }
      return String(value)
    }

    const normalizedObject: { [key: string]: NormalizedValue } = {}
    // Sorted keys make the serialized form independent of insertion order.
    for (const key of Object.keys(value).sort()) {
      if (isVolatileKey(key, volatileKeySet)) {
        continue
      }
      const normalized = normalizeToolLoopValue(value[key], volatileKeySet, seen)
      if (normalized !== undefined) {
        normalizedObject[key] = normalized
      }
    }
    return normalizedObject
  } finally {
    // Path-based tracking: leaving this node makes it visitable again elsewhere.
    seen.delete(value)
  }
}
/**
 * Fingerprints a tool invocation: SHA-256 (hex) over the JSON of the trimmed
 * tool name and the normalized input (`null` when the input normalizes to
 * `undefined`).
 */
function computeToolLoopHash(
  input: ToolLoopInspectInput,
  volatileKeySet: Set<string>,
): string {
  const toolName = input.toolName.trim()
  const normalizedInput = normalizeToolLoopValue(
    input.input,
    volatileKeySet,
    new WeakSet(),
  )
  const fingerprintSource = JSON.stringify({
    toolName,
    input: normalizedInput ?? null,
  })
  const hasher = createHash('sha256')
  hasher.update(fingerprintSource)
  return hasher.digest('hex')
}
/**
 * Creates a sliding-window n-gram detector.
 *
 * Each `inspect(text)` call measures what fraction of the n-grams introduced by
 * `text` already occur in the retained history, then appends `text` to that
 * history. History holds at most `windowSize` characters; older characters and
 * their n-grams are evicted first.
 *
 * @throws Error when an option is invalid or `windowSize < ngramSize`.
 */
export function createSential(options: SentialOptions = {}): Sential {
  const ngramSize = validatePositiveInt(
    'ngramSize',
    options.ngramSize ?? DEFAULT_NGRAM_SIZE,
  )
  const windowSize = validatePositiveInt(
    'windowSize',
    options.windowSize ?? DEFAULT_WINDOW_SIZE,
  )
  const overlapThreshold = validateThreshold(
    options.overlapThreshold ?? DEFAULT_OVERLAP_THRESHOLD,
  )
  if (windowSize < ngramSize) {
    throw new Error('windowSize must be greater than or equal to ngramSize')
  }

  // Retained characters, the n-grams they form (oldest first), and how often
  // each n-gram currently occurs inside the window.
  const recentChars: string[] = []
  const recentGrams: string[] = []
  const gramCounts = new Map<string, number>()

  const remember = (gram: string): void => {
    gramCounts.set(gram, (gramCounts.get(gram) ?? 0) + 1)
  }

  const forget = (gram: string): void => {
    const count = gramCounts.get(gram)
    if (!count) return
    if (count > 1) {
      gramCounts.set(gram, count - 1)
    } else {
      gramCounts.delete(gram)
    }
  }

  const appendChar = (char: string): void => {
    recentChars.push(char)
    if (recentChars.length >= ngramSize) {
      const gram = buildNgram(
        recentChars,
        recentChars.length - ngramSize,
        ngramSize,
      )
      recentGrams.push(gram)
      remember(gram)
    }
    if (recentChars.length > windowSize) {
      // Drop the oldest character together with the oldest n-gram.
      recentChars.shift()
      const evicted = recentGrams.shift()
      if (evicted) {
        forget(evicted)
      }
    }
  }

  const inspect = (text: string): SentialInspectResult => {
    const incoming = normalizeChars(text)
    if (incoming.length === 0) {
      return { hit: false, overlap: 0, matchedGrams: 0, newGrams: 0 }
    }

    // Prefix with the last (ngramSize - 1) history characters so n-grams that
    // straddle the chunk boundary are scored too. Every n-gram of the combined
    // sequence ends inside the incoming text, so all of them count as new.
    const carry = ngramSize > 1 ? recentChars.slice(-(ngramSize - 1)) : []
    const candidate = carry.concat(incoming)
    const newGrams = Math.max(candidate.length - ngramSize + 1, 0)

    let matchedGrams = 0
    for (let start = 0; start < newGrams; start += 1) {
      if (gramCounts.has(buildNgram(candidate, start, ngramSize))) {
        matchedGrams += 1
      }
    }

    const overlap = newGrams > 0 ? matchedGrams / newGrams : 0
    const hit = overlap > overlapThreshold

    // Scoring happens before the chunk joins the history.
    for (const char of incoming) {
      appendChar(char)
    }

    return { hit, overlap, matchedGrams, newGrams }
  }

  const reset = (): void => {
    recentChars.length = 0
    recentGrams.length = 0
    gramCounts.clear()
  }

  return { inspect, reset }
}
/**
 * Wraps the n-gram detector with a streak counter.
 *
 * A chunk only affects the streak when it contributes at least
 * `minNewGramsPerChunk` new n-grams: a hit increments the streak, a miss resets
 * it to 0. `abort` is reported while the streak is at or above
 * `consecutiveHitsToAbort`.
 *
 * @throws Error when an option is invalid.
 */
export function createTextLoopGuard(
  options: TextLoopGuardOptions = {},
): TextLoopGuard {
  const consecutiveHitsToAbort = validatePositiveInt(
    'consecutiveHitsToAbort',
    options.consecutiveHitsToAbort ?? DEFAULT_CONSECUTIVE_HITS_TO_ABORT,
  )
  const minNewGramsPerChunk = validatePositiveInt(
    'minNewGramsPerChunk',
    options.minNewGramsPerChunk ?? DEFAULT_MIN_NEW_GRAMS_PER_CHUNK,
  )
  const detector = createSential(options)
  let streak = 0

  const inspect = (text: string): TextLoopGuardInspectResult => {
    const verdict = detector.inspect(text)
    const countsTowardStreak = verdict.newGrams >= minNewGramsPerChunk
    if (countsTowardStreak) {
      streak = verdict.hit ? streak + 1 : 0
    }
    return {
      ...verdict,
      streak,
      abort: streak >= consecutiveHitsToAbort,
    }
  }

  const reset = (): void => {
    detector.reset()
    streak = 0
  }

  return { inspect, reset }
}
/**
 * Buffers streamed text and forwards it to `inspect` in chunks of exactly
 * `chunkSize` code points; `flush()` forwards whatever is left over.
 *
 * Buffered text is consumed before `inspect` runs, both in `push` and in
 * `flush`, so an inspector that throws never leaves text behind to be inspected
 * a second time.
 *
 * @param chunkSize Chunk length in code points; must be a positive integer.
 * @param inspect Callback invoked with each non-empty chunk; may throw.
 * @throws Error when `chunkSize` is not a positive integer.
 */
export function createTextLoopProbeBuffer(
  chunkSize: number,
  inspect: (text: string) => void,
): TextLoopProbeBuffer {
  validatePositiveInt('chunkSize', chunkSize)
  let chars: string[] = []
  let offset = 0

  const compact = () => {
    if (offset > 0) {
      chars = chars.slice(offset)
      offset = 0
    }
  }

  const inspectChunk = (text: string) => {
    if (text.length > 0) {
      inspect(text)
    }
  }

  return {
    push(text: string): void {
      if (!text) return
      // Append element by element: spreading a very large array into push()
      // can exceed the engine's argument limit and throw a RangeError.
      for (const char of normalizeChars(text)) {
        chars.push(char)
      }
      while (chars.length - offset >= chunkSize) {
        const chunk = chars.slice(offset, offset + chunkSize).join('')
        offset += chunkSize
        inspectChunk(chunk)
      }
      // Prevent unbounded front-gaps after many chunks.
      if (offset >= chunkSize) {
        compact()
      }
    },
    flush(): void {
      const remainder = chars.length - offset > 0
        ? chars.slice(offset).join('')
        : ''
      // Clear first so a throwing inspector cannot leave stale text behind.
      chars = []
      offset = 0
      inspectChunk(remainder)
    },
  }
}
/**
 * Creates a guard that detects the same tool being invoked with the same
 * (normalized) input over and over.
 *
 * Consecutive identical fingerprints increase `repeatCount`; a different
 * fingerprint restarts it at 1 and also restarts the breach phase. When
 * `repeatCount` exceeds `repeatThreshold` the call is a breach: the first
 * `warningsBeforeAbort` breaches report `warn` and restart the repeat count,
 * later breaches report `abort`.
 *
 * @throws Error when an option is not a positive integer.
 */
export function createToolLoopGuard(
  options: ToolLoopGuardOptions = {},
): ToolLoopGuard {
  const repeatThreshold = validatePositiveInt(
    'repeatThreshold',
    options.repeatThreshold ?? DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD,
  )
  const warningsBeforeAbort = validatePositiveInt(
    'warningsBeforeAbort',
    options.warningsBeforeAbort ?? DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT,
  )

  // Built-in and caller-supplied volatile keys, stored in normalized form.
  const volatileKeySet = new Set<string>()
  for (const key of [...DEFAULT_VOLATILE_KEYS, ...(options.volatileKeys ?? [])]) {
    const normalizedKey = normalizeKeyName(key)
    if (normalizedKey) {
      volatileKeySet.add(normalizedKey)
    }
  }

  let lastHash = ''
  let repeatCount = 0
  let breachCount = 0
  let breachHash = ''

  const inspect = (input: ToolLoopInspectInput): ToolLoopInspectResult => {
    const hash = computeToolLoopHash(input, volatileKeySet)

    repeatCount = hash === lastHash ? repeatCount + 1 : 1
    lastHash = hash

    // Breach phase is fingerprint-specific: switching tool signature restarts it.
    if (hash !== breachHash) {
      breachHash = hash
      breachCount = 0
    }

    const breached = repeatCount > repeatThreshold
    const warn = breached && breachCount < warningsBeforeAbort
    const abort = breached && !warn
    if (breached) {
      breachCount += 1
    }
    if (warn) {
      // Start counting from scratch after a warning so the model gets another
      // full round before the next breach.
      lastHash = ''
      repeatCount = 0
    }

    return { hash, repeatCount, breachCount, warn, abort }
  }

  const reset = (): void => {
    lastHash = ''
    repeatCount = 0
    breachCount = 0
    breachHash = ''
  }

  return { inspect, reset }
}
+93
View File
@@ -0,0 +1,93 @@
import { describe, expect, it, vi } from 'vitest'
import type { ToolSet } from 'ai'
import { createToolLoopGuardedTools } from './tool-loop'
describe('tool loop guarded tools', () => {
// A tool that resolves to an async iterable must get the very same object back
// (identity, not a copy), still be iterable, and never trigger the abort callback.
it('preserves promised async-iterable tool outputs', async () => {
const onAbortToolCall = vi.fn()
const streamedChunks = ['chunk-1', 'chunk-2']
const stream = {
async *[Symbol.asyncIterator]() {
for (const chunk of streamedChunks) {
yield chunk
}
},
}
const baseTools = {
streamy: {
execute: async () => stream,
},
} as unknown as ToolSet
const tools = createToolLoopGuardedTools(baseTools, {
repeatThreshold: 1,
warningsBeforeAbort: 1,
onAbortToolCall,
warningKey: '__warn',
warningText: 'loop warning',
})
const output = await tools.streamy.execute?.({ value: 'same' } as never, { toolCallId: 't-stream' } as never)
expect(output).toBe(stream)
const received: string[] = []
for await (const chunk of output as AsyncIterable<string>) {
received.push(chunk)
}
expect(received).toEqual(streamedChunks)
expect(onAbortToolCall).not.toHaveBeenCalled()
})
// With repeatThreshold 1: call 2 is the first breach and carries the warning
// payload; call 4 is the second breach, which reports the tool call id through
// the callback while the tool output itself is returned unchanged.
it('defers abort to stream layer when onAbortToolCall is provided', async () => {
const onAbortToolCall = vi.fn()
const baseTools = {
echo: {
execute: async (input: unknown) => ({ result: input }),
},
} as unknown as ToolSet
const tools = createToolLoopGuardedTools(baseTools, {
repeatThreshold: 1,
warningsBeforeAbort: 1,
onAbortToolCall,
warningKey: '__warn',
warningText: 'loop warning',
})
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
const warned = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
expect(warned).toMatchObject({
__warn: {
marker: 'MEMOH_TOOL_LOOP_WARNING',
},
})
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
const abortedOutput = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
expect(onAbortToolCall).toHaveBeenCalledWith('t-1')
expect(abortedOutput).toEqual({ result: { value: 'same' } })
})
// The aborting (fourth) call must resolve normally with the raw tool output;
// the abort is only signalled through onAbortToolCall.
it('reports abort via callback without throwing inside tool execution', async () => {
const onAbortToolCall = vi.fn()
const baseTools = {
echo: {
execute: async (input: unknown) => ({ result: input }),
},
} as unknown as ToolSet
const tools = createToolLoopGuardedTools(baseTools, {
repeatThreshold: 1,
warningsBeforeAbort: 1,
onAbortToolCall,
warningKey: '__warn',
warningText: 'loop warning',
})
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
const abortedOutput = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
expect(onAbortToolCall).toHaveBeenCalledWith('t-2')
expect(abortedOutput).toEqual({ result: { value: 'same' } })
})
})
+122
View File
@@ -0,0 +1,122 @@
import { ToolExecutionOptions, ToolSet } from 'ai'
import { createToolLoopGuard, type ToolLoopInspectResult } from './sential'
/** Configuration for `createToolLoopGuardedTools`. */
export interface CreateToolLoopGuardedToolsOptions {
/** Forwarded to the tool loop guard as its repeat threshold. */
repeatThreshold: number
/** Forwarded to the tool loop guard as the number of warnings before abort. */
warningsBeforeAbort: number
/** Called with the tool call id when the guard reports abort for that call. */
onAbortToolCall: (toolCallId: string) => void
/** Property name under which the warning payload is added to a tool result. */
warningKey: string
/** Human-readable message placed in the warning payload. */
warningText: string
}
/**
 * True only for plain objects (prototype `Object.prototype` or `null`).
 *
 * Arrays, dates, maps and class instances are rejected on purpose: callers
 * spread the value into a new object literal, which would silently drop the
 * content of such objects (e.g. `{ ...new Date() }` is `{}`).
 */
const isRecord = (value: unknown): value is Record<string, unknown> => {
  if (value === null || typeof value !== 'object') {
    return false
  }
  const prototype: unknown = Object.getPrototypeOf(value)
  return prototype === Object.prototype || prototype === null
}
/** True when `value` is a non-null object exposing `Symbol.asyncIterator`. */
const isAsyncIterable = (value: unknown): value is AsyncIterable<unknown> => {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  return Symbol.asyncIterator in value
}
/**
 * Attaches the loop warning to a tool result.
 *
 * Record-like outputs are shallow-copied with the warning stored under
 * `warningKey`; any other output is wrapped as `{ [warningKey]: …, result: output }`.
 */
const injectToolLoopWarning = (
  output: unknown,
  inspectResult: ToolLoopInspectResult,
  warningKey: string,
  warningText: string,
): unknown => {
  // Keep warning payload structured so UI/consumers can detect and render it.
  const { hash: fingerprint, breachCount } = inspectResult
  const warningPayload = {
    marker: 'MEMOH_TOOL_LOOP_WARNING',
    message: warningText,
    fingerprint,
    breachCount,
  }

  if (!isRecord(output)) {
    return {
      [warningKey]: warningPayload,
      result: output,
    }
  }
  return {
    ...output,
    [warningKey]: warningPayload,
  }
}
/**
 * Returns a copy of `tools` in which every executable tool is wrapped by a
 * shared tool loop guard.
 *
 * The guard inspects (tool name + input) after the tool has produced a
 * non-streaming result. On a warning verdict the result is returned with the
 * warning payload attached; on an abort verdict `onAbortToolCall` is invoked
 * with the tool call id and the result is returned unchanged. Async-iterable
 * outputs (returned directly or via a promise) bypass inspection entirely.
 * Tools without an `execute` function are passed through as-is.
 */
export function createToolLoopGuardedTools(
  tools: ToolSet,
  {
    repeatThreshold,
    warningsBeforeAbort,
    onAbortToolCall,
    warningKey,
    warningText,
  }: CreateToolLoopGuardedToolsOptions,
): ToolSet {
  // One guard for the whole tool set, so repeats are tracked across tools.
  const guard = createToolLoopGuard({
    repeatThreshold,
    warningsBeforeAbort,
  })

  const guardedEntries = Object.entries(tools).map(([toolName, toolDefinition]) => {
    const originalExecute = toolDefinition.execute
    if (typeof originalExecute !== 'function') {
      return [toolName, toolDefinition]
    }

    // Awaits the tool result, runs the guard and applies its verdict.
    const settle = async (
      pendingOutput: unknown,
      toolInput: unknown,
      options: ToolExecutionOptions,
    ) => {
      const resolvedOutput = await pendingOutput
      // Tools may return Promise<AsyncIterable>; keep that stream untouched too.
      if (isAsyncIterable(resolvedOutput)) {
        return resolvedOutput as never
      }

      const verdict = guard.inspect({ toolName, input: toolInput })
      if (verdict.abort) {
        // Report loop abort to generation layer; it decides when/how to stop.
        onAbortToolCall(options.toolCallId)
        return resolvedOutput as never
      }
      if (verdict.warn) {
        return injectToolLoopWarning(
          resolvedOutput,
          verdict,
          warningKey,
          warningText,
        ) as never
      }
      return resolvedOutput as never
    }

    const guardedExecute = (
      toolInput: unknown,
      options: ToolExecutionOptions,
    ) => {
      const directOutput = originalExecute(
        toolInput as never,
        options as never,
      ) as unknown
      // Streamed tool outputs are passed through unchanged to preserve streaming semantics.
      return isAsyncIterable(directOutput)
        ? (directOutput as never)
        : settle(directOutput, toolInput, options)
    }

    return [toolName, { ...toolDefinition, execute: guardedExecute }]
  })

  return Object.fromEntries(guardedEntries) as ToolSet
}
+5
View File
@@ -38,6 +38,10 @@ export interface InboxItem {
createdAt: string
}
/**
 * Loop-detection settings carried on `AgentParams.loopDetection`.
 */
export interface LoopDetectionConfig {
/** Optional switch; `createAgent` only turns loop detection on when this is exactly `true`. */
enabled?: boolean
}
export interface AgentParams {
model: ModelConfig
language?: string
@@ -50,6 +54,7 @@ export interface AgentParams {
auth: AgentAuthContext
skills?: AgentSkill[]
inbox?: InboxItem[]
loopDetection?: LoopDetectionConfig
}
export interface AgentInput {
+1
View File
@@ -487,6 +487,7 @@
"heartbeatModelPlaceholder": "Use chat model (default)",
"allowGuest": "Allow Guest Access",
"allowGuestPersonalHint": "Personal bots do not support guest access. Use a public bot instead.",
"loopDetectionTitle": "Detect and auto-block output loops",
"searchModel": "Search models…",
"noModel": "No models available",
"saveSuccess": "Settings saved",
+1
View File
@@ -483,6 +483,7 @@
"heartbeatModelPlaceholder": "使用聊天模型(默认)",
"allowGuest": "允许游客访问",
"allowGuestPersonalHint": "个人 Bot 不支持游客访问,请使用公开 Bot。",
"loopDetectionTitle": "自动检测并阻止模型循环输出",
"searchModel": "搜索模型…",
"noModel": "暂无可选模型",
"saveSuccess": "设置已保存",
@@ -134,6 +134,17 @@
<Separator />
</template>
<!-- Loop Detection -->
<div class="flex items-center justify-between">
<Label>{{ $t('bots.settings.loopDetectionTitle') }}</Label>
<Switch
:model-value="loopDetectionEnabled"
:disabled="isLoopDetectionToggleDisabled"
@update:model-value="handleLoopDetectionToggle"
/>
</div>
<!-- Save -->
<div class="flex justify-end">
<Button
@@ -195,7 +206,7 @@ import {
SelectTrigger,
SelectValue,
} from '@memoh/ui'
import { reactive, computed, watch } from 'vue'
import { reactive, computed, watch, ref } from 'vue'
import { useRouter } from 'vue-router'
import { toast } from 'vue-sonner'
import { useI18n } from 'vue-i18n'
@@ -203,7 +214,7 @@ import ConfirmPopover from '@/components/confirm-popover/index.vue'
import ModelSelect from './model-select.vue'
import SearchProviderSelect from './search-provider-select.vue'
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
import { getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders } from '@memoh/sdk'
import { getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getBotsById, putBotsById } from '@memoh/sdk'
import type { SettingsSettings } from '@memoh/sdk'
import type { Ref } from 'vue'
import { resolveApiErrorMessage } from '@/utils/api-error'
@@ -232,6 +243,15 @@ const { data: settings } = useQuery({
enabled: () => !!botIdRef.value,
})
// Loads the bot profile for the current bot id via getBotsById; its `metadata`
// is what the loop-detection toggle reads from.
const { data: botProfile, isLoading: isBotProfileLoading } = useQuery({
// Cache key is scoped to the current bot id.
key: () => ['bot', botIdRef.value],
query: async () => {
const { data } = await getBotsById({ path: { id: botIdRef.value }, throwOnError: true })
return data
},
// Only run once a bot id is available.
enabled: () => !!botIdRef.value,
})
const { data: modelData } = useQuery({
key: ['all-models'],
query: async () => {
@@ -278,6 +298,21 @@ const { mutateAsync: deleteBot, isLoading: deleteLoading } = useMutation({
},
})
// Mutation that sends a full `metadata` object for the current bot through putBotsById.
const { mutateAsync: updateBotProfile, isLoading: updateBotProfileLoading } = useMutation({
mutation: async (metadata: Record<string, unknown>) => {
const { data } = await putBotsById({
path: { id: botIdRef.value },
body: { metadata },
throwOnError: true,
})
return data
},
// After the mutation settles, invalidate both the ['bots'] entry and this
// bot's ['bot', id] entry in the query cache.
onSettled: () => {
queryCache.invalidateQueries({ key: ['bots'] })
queryCache.invalidateQueries({ key: ['bot', botIdRef.value] })
},
})
const models = computed(() => modelData.value ?? [])
const providers = computed(() => providerData.value ?? [])
const searchProviders = computed(() => searchProviderData.value ?? [])
@@ -302,6 +337,23 @@ const form = reactive({
reasoning_effort: 'medium',
})
// Local state backing the loop-detection switch.
const loopDetectionEnabled = ref(false)
// Last known bot metadata; `null` until the bot profile has been loaded, which
// keeps the toggle disabled (see isLoopDetectionToggleDisabled).
const loopDetectionMetadata = ref<Record<string, unknown> | null>(null)
/**
 * Returns `value` itself when it is a non-null, non-array object; otherwise
 * returns a fresh empty object.
 */
const asRecord = (value: unknown): Record<string, unknown> => {
  const isPlainObject =
    typeof value === 'object' && value !== null && !Array.isArray(value)
  return isPlainObject ? (value as Record<string, unknown>) : {}
}
/**
 * Reads `metadata.features.loop_detection.enabled`, tolerating missing or
 * non-object values at every level. Only a literal `true` counts as enabled.
 */
const readLoopDetectionEnabled = (metadata: unknown): boolean => {
  const loopDetection = ['features', 'loop_detection'].reduce(
    (node, key) => asRecord(node[key]),
    asRecord(metadata),
  )
  return loopDetection.enabled === true
}
watch(settings, (val) => {
if (val) {
form.chat_model_id = val.chat_model_id ?? ''
@@ -317,6 +369,13 @@ watch(settings, (val) => {
}
}, { immediate: true })
// Sync local loop-detection state from the loaded bot profile.
watch(botProfile, (val) => {
// Nothing to sync while the query has produced no data yet.
if (val === undefined) return
// Normalize metadata to a plain object so later spreads/lookups are safe.
const metadata = asRecord(val?.metadata)
loopDetectionMetadata.value = metadata
loopDetectionEnabled.value = readLoopDetectionEnabled(metadata)
}, { immediate: true })
const hasChanges = computed(() => {
if (!settings.value) return true
const s = settings.value
@@ -336,6 +395,40 @@ const hasChanges = computed(() => {
return changed
})
// The toggle is disabled while a profile update is in flight, while the bot
// profile is loading, or before any metadata has been stored locally.
const isLoopDetectionToggleDisabled = computed(() =>
updateBotProfileLoading.value || isBotProfileLoading.value || loopDetectionMetadata.value === null,
)
/**
 * Handles the loop-detection switch.
 *
 * Optimistically flips the local switch state, then persists
 * `metadata.features.loop_detection.enabled` through `updateBotProfile` while
 * preserving every other metadata/feature field. On failure the switch is
 * rolled back and an error toast is shown.
 *
 * @param value - The new switch state emitted by the Switch component.
 */
async function handleLoopDetectionToggle(value: boolean) {
  const baseMetadata = loopDetectionMetadata.value
  if (isLoopDetectionToggleDisabled.value || baseMetadata === null) return

  const previous = loopDetectionEnabled.value
  const requested = value === true
  if (requested === previous) return

  // Optimistic UI update; reverted in the catch branch below.
  loopDetectionEnabled.value = requested

  const features = asRecord(baseMetadata.features)
  const nextMetadata = {
    ...baseMetadata,
    features: {
      ...features,
      loop_detection: {
        ...asRecord(features.loop_detection),
        enabled: requested,
      },
    },
  }

  try {
    await updateBotProfile(nextMetadata)
    // Keep the local snapshot in step with what was just persisted.
    loopDetectionMetadata.value = nextMetadata
  } catch (error) {
    loopDetectionEnabled.value = previous
    toast.error(resolveApiErrorMessage(error, t('common.saveFailed')))
  }
}
async function handleSave() {
try {
await updateSettings({ ...form })