feat(agent): loop detection (#152)

* feat(loop-detection): add configurable text and tool loop guards

* style(web): remove duplicate separator in bot settings
This commit is contained in:
Ringo.Typowriter
2026-03-02 15:00:09 +08:00
committed by GitHub
parent 04bce702b7
commit d3edd17d90
13 changed files with 1381 additions and 53 deletions
+174 -42
View File
@@ -33,9 +33,20 @@ import { getMCPTools } from './tools/mcp'
import { getTools } from './tools'
import { buildIdentityHeaders } from './utils/headers'
import { createFS } from './utils'
import { createTextLoopGuard, createTextLoopProbeBuffer } from './sential'
import { createToolLoopGuardedTools } from './tool-loop'
const ANTHROPIC_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
const GOOGLE_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
// Error message thrown when the text-loop guard aborts a generation.
const LOOP_DETECTED_ABORT_MESSAGE = 'loop detected, stream aborted'
// Consecutive text-loop hits required before aborting.
const LOOP_DETECTED_STREAK_THRESHOLD = 3
// Chunks producing fewer new n-grams than this do not move the hit streak.
const LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK = 8
// Size (in characters) of the chunks the streaming probe buffer hands to the guard.
const LOOP_DETECTED_PROBE_CHARS = 256
// Error message thrown when the tool-loop guard aborts a generation.
const TOOL_LOOP_DETECTED_ABORT_MESSAGE = 'tool loop detected, stream aborted'
// Identical consecutive tool invocations tolerated before a breach is recorded.
const TOOL_LOOP_REPEAT_THRESHOLD = 5
// Breaches answered with a warning before the next breach aborts.
const TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1
// Property name under which the warning payload is attached to a tool result.
const TOOL_LOOP_WARNING_KEY = '__memoh_tool_loop_warning'
// The repeat count in the message is derived from the threshold so the two cannot drift apart.
const TOOL_LOOP_WARNING_TEXT = `[MEMOH_TOOL_LOOP_WARNING] Repeated identical tool invocation (same tool + arguments) was detected more than ${TOOL_LOOP_REPEAT_THRESHOLD} times. Stop looping this tool and either summarize current results or change strategy.`
const buildProviderOptions = (config: ModelConfig): Record<string, Record<string, unknown>> | undefined => {
if (!config.reasoning?.enabled) return undefined
@@ -93,12 +104,14 @@ export const createAgent = (
},
auth,
inbox = [],
loopDetection = { enabled: false },
}: AgentParams,
fetch: AuthFetcher,
) => {
const model = createModel(modelConfig)
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const providerOptions = buildProviderOptions(modelConfig) as any
const loopDetectionEnabled = loopDetection?.enabled === true
const enabledSkills: AgentSkill[] = []
const fs = createFS({ fetch, botId: identity.botId })
@@ -194,27 +207,102 @@ export const createAgent = (
return userMessage
}
// Builds the text-loop check used by the non-streaming generation path.
// Returns null when loop detection is disabled; otherwise returns a function
// that feeds text into a fresh guard and throws once the guard reports abort.
const createNonStreamTextLoopInspector = () => {
  if (loopDetectionEnabled) {
    const guard = createTextLoopGuard({
      consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD,
      minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK,
    })
    return (text: string) => {
      const { abort } = guard.inspect(text)
      if (abort) {
        throw new Error(LOOP_DETECTED_ABORT_MESSAGE)
      }
    }
  }
  return null
}
// Wraps the tool set with the tool-loop guard when loop detection is enabled;
// otherwise hands the tools back untouched. `onAbortToolCall` receives the id
// of the tool call on which the guard decided to abort (no-op by default).
const buildGuardedTools = (
  tools: ToolSet,
  onAbortToolCall: (toolCallId: string) => void = () => {},
): ToolSet =>
  loopDetectionEnabled
    ? createToolLoopGuardedTools(tools, {
      repeatThreshold: TOOL_LOOP_REPEAT_THRESHOLD,
      warningsBeforeAbort: TOOL_LOOP_WARNINGS_BEFORE_ABORT,
      onAbortToolCall,
      warningKey: TOOL_LOOP_WARNING_KEY,
      warningText: TOOL_LOOP_WARNING_TEXT,
    })
    : tools
// Shared non-streaming generation runner.
// Acquires the agent tools, wraps them with the loop guard, runs generateText
// and always closes the tools afterwards. When loop detection is enabled, each
// finished step is checked: a pending tool-loop abort or a text-loop hit throws.
const runTextGeneration = async ({
  messages,
  systemPrompt,
  prepareStep,
}: {
  messages: ModelMessage[]
  systemPrompt: string
  prepareStep?: () => { system: string }
}) => {
  const { tools, close } = await getAgentTools()

  // The guarded tools only raise this flag; the abort itself happens in onStepFinish.
  let toolLoopAbortRequested = false
  const guardedTools = buildGuardedTools(tools, () => {
    toolLoopAbortRequested = true
  })
  const inspectTextLoop = createNonStreamTextLoopInspector()

  const loopHooks = loopDetectionEnabled
    ? {
      onStepFinish: ({ text }: { text: string }) => {
        if (toolLoopAbortRequested) {
          throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE)
        }
        inspectTextLoop?.(text)
      },
    }
    : {}

  let runError: unknown = null
  try {
    return await generateText({
      model,
      messages,
      system: systemPrompt,
      ...(providerOptions && { providerOptions }),
      stopWhen: stepCountIs(Infinity),
      ...(prepareStep && { prepareStep }),
      ...loopHooks,
      tools: guardedTools,
    })
  } catch (error) {
    runError = error
    throw error
  } finally {
    try {
      await close()
    } catch (closeError) {
      // A close failure must not mask the generation error: log it in that
      // case, and surface it only when the run itself succeeded.
      if (runError != null) {
        console.error(closeError)
      } else {
        throw closeError
      }
    }
  }
}
const ask = async (input: AgentInput) => {
const userPrompt = generateUserPrompt(input)
const messages = [...input.messages, userPrompt]
input.skills.forEach((skill) => enableSkill(skill))
const systemPrompt = await generateSystemPrompt()
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: systemPrompt,
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
systemPrompt,
prepareStep: () => {
return {
system: systemPrompt,
}
},
onFinish: async () => {
await close()
},
tools,
})
const stepUsages = buildStepUsages(steps)
const { cleanedText, attachments: textAttachments } =
@@ -257,22 +345,14 @@ export const createAgent = (
})
}
const messages = [...params.messages, userPrompt]
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: generateSubagentSystemPrompt(),
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
systemPrompt: generateSubagentSystemPrompt(),
prepareStep: () => {
return {
system: generateSubagentSystemPrompt(),
}
},
onFinish: async () => {
await close()
},
tools,
})
const stepUsages = buildStepUsages(steps)
return {
@@ -301,17 +381,9 @@ export const createAgent = (
}
const messages = [...params.messages, scheduleMessage]
params.skills.forEach((skill) => enableSkill(skill))
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: await generateSystemPrompt(),
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
onFinish: async () => {
await close()
},
tools,
systemPrompt: await generateSystemPrompt(),
})
const stepUsages = buildStepUsages(steps)
return {
@@ -341,17 +413,9 @@ export const createAgent = (
}
const messages = [...params.messages, heartbeatMessage]
params.skills.forEach((skill) => enableSkill(skill))
const { tools, close } = await getAgentTools()
const { response, reasoning, text, usage, steps } = await generateText({
model,
const { response, reasoning, text, usage, steps } = await runTextGeneration({
messages,
system: await generateSystemPrompt(),
...(providerOptions && { providerOptions }),
stopWhen: stepCountIs(Infinity),
onFinish: async () => {
await close()
},
tools,
systemPrompt: await generateSystemPrompt(),
})
const stepUsages = buildStepUsages(steps)
return {
@@ -392,6 +456,27 @@ export const createAgent = (
input.skills.forEach((skill) => enableSkill(skill))
const systemPrompt = await generateSystemPrompt()
const attachmentsExtractor = new AttachmentsStreamExtractor()
const textLoopGuard = loopDetectionEnabled
? createTextLoopGuard({
consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD,
minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK,
})
: null
// Feeds visible text to the text-loop guard (if any) and throws on abort.
const guardLoopOutput = (text: string) => {
  const verdict = textLoopGuard?.inspect(text)
  if (verdict?.abort) {
    throw new Error(LOOP_DETECTED_ABORT_MESSAGE)
  }
}
const textLoopProbeBuffer = textLoopGuard
? createTextLoopProbeBuffer(
LOOP_DETECTED_PROBE_CHARS,
guardLoopOutput,
)
: null
const result: {
messages: ModelMessage[];
reasoning: string[];
@@ -403,7 +488,20 @@ export const createAgent = (
usage: null,
usages: [],
}
const toolLoopAbortCallIds = new Set<string>()
const { tools, close } = await getAgentTools()
// Stream path needs deferred abort to keep tool_call_start/tool_call_end event pairing.
const guardedTools = buildGuardedTools(tools, (toolCallId) => {
toolLoopAbortCallIds.add(toolCallId)
})
let closePromise: Promise<void> | null = null
// Closes the tools at most once: the first caller starts close(), every
// later caller awaits that same promise.
const closeTools = async () => {
  closePromise ??= Promise.resolve().then(() => close())
  await closePromise
}
let streamError: unknown = null
try {
const { fullStream } = streamText({
model,
@@ -416,9 +514,9 @@ export const createAgent = (
system: systemPrompt,
}
},
tools,
tools: guardedTools,
onFinish: async ({ usage, reasoning, response, steps }) => {
await close()
await closeTools()
result.usage = usage as never
result.reasoning = reasoning.map((part) => part.text)
result.messages = response.messages
@@ -464,6 +562,9 @@ export const createAgent = (
chunk.text,
)
if (visibleText) {
if (textLoopProbeBuffer) {
textLoopProbeBuffer.push(visibleText)
}
yield {
type: 'text_delta',
delta: visibleText,
@@ -481,11 +582,17 @@ export const createAgent = (
// Flush any remaining buffered content before ending the text stream.
const remainder = attachmentsExtractor.flushRemainder()
if (remainder.visibleText) {
if (textLoopProbeBuffer) {
textLoopProbeBuffer.push(remainder.visibleText)
}
yield {
type: 'text_delta',
delta: remainder.visibleText,
}
}
if (textLoopProbeBuffer) {
textLoopProbeBuffer.flush()
}
if (remainder.attachments.length) {
yield {
type: 'attachment_delta',
@@ -502,11 +609,17 @@ export const createAgent = (
// Flush any remaining buffered content before ending the text stream.
const remainder = attachmentsExtractor.flushRemainder()
if (remainder.visibleText) {
if (textLoopProbeBuffer) {
textLoopProbeBuffer.push(remainder.visibleText)
}
yield {
type: 'text_delta',
delta: remainder.visibleText,
}
}
if (textLoopProbeBuffer) {
textLoopProbeBuffer.flush()
}
if (remainder.attachments.length) {
yield {
type: 'attachment_delta',
@@ -522,6 +635,9 @@ export const createAgent = (
}
break
case 'tool-result':
// Always emit the terminal tool event first so downstream reducers
// can close the in-flight tool block before the stream aborts.
const shouldAbortForToolLoop = toolLoopAbortCallIds.delete(chunk.toolCallId)
yield {
type: 'tool_call_end',
toolName: chunk.toolName,
@@ -530,6 +646,9 @@ export const createAgent = (
result: chunk.output,
metadata: chunk,
}
if (shouldAbortForToolLoop) {
throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE)
}
break
case 'file':
yield {
@@ -544,6 +663,9 @@ export const createAgent = (
}
}
}
if (textLoopProbeBuffer) {
textLoopProbeBuffer.flush()
}
const { messages: strippedMessages } = stripAttachmentsFromMessages(
result.messages,
@@ -560,8 +682,18 @@ export const createAgent = (
skills: getEnabledSkills(),
}
} catch (error) {
streamError = error
console.error(error)
throw error
} finally {
try {
await closeTools()
} catch (closeError) {
if (streamError == null) {
throw closeError
}
console.error(closeError)
}
}
}
+265
View File
@@ -0,0 +1,265 @@
import { describe, expect, it } from 'vitest'
import {
createSential,
createTextLoopGuard,
createToolLoopGuard,
} from './sential'
// Unit tests for the n-gram overlap detector (createSential), the streak-based
// text loop guard (createTextLoopGuard) and the tool-call loop guard
// (createToolLoopGuard).
describe('sential', () => {
// Two chunks that share no characters produce zero overlap.
it('does not hit when overlap stays low', () => {
const sential = createSential()
sential.inspect('ABCDEFGHIJKLMNO')
const result = sential.inspect('qrstuvwxyz12345')
expect(result.hit).toBe(false)
expect(result.overlap).toBe(0)
})
// The second chunk repeats the 20-character pattern already seen twice.
it('hits when overlap is above threshold', () => {
const sential = createSential()
sential.inspect('0123456789abcdefghij0123456789abcdefghij')
const result = sential.inspect('0123456789abcdefghij')
expect(result.hit).toBe(true)
expect(result.overlap).toBeGreaterThan(0.75)
})
// With 1-grams, 15 of 20 new grams match ('a'), i.e. overlap is exactly 0.75;
// the comparison is strict, so this is not a hit.
it('does not hit when overlap is exactly threshold', () => {
const sential = createSential({
ngramSize: 1,
overlapThreshold: 0.75,
})
sential.inspect('aaaaaaaaaa')
const result = sential.inspect(`${'a'.repeat(15)}bcdef`)
expect(result.newGrams).toBe(20)
expect(result.matchedGrams).toBe(15)
expect(result.overlap).toBeCloseTo(0.75, 10)
expect(result.hit).toBe(false)
})
// After 20 fresh characters pass through a 20-character window, the
// original pattern is no longer in history.
it('evicts old grams with sliding window', () => {
const sential = createSential({
windowSize: 20,
})
sential.inspect('abcdefghijabcdefghij')
sential.inspect('KLMNOPQRST')
sential.inspect('UVWXYZ1234')
const result = sential.inspect('abcdefghij')
expect(result.hit).toBe(false)
expect(result.matchedGrams).toBe(0)
})
it('aborts only after 10 consecutive hits', () => {
const guard = createTextLoopGuard({
ngramSize: 1,
overlapThreshold: 0.5,
consecutiveHitsToAbort: 10,
})
// The first chunk only seeds history; there is nothing to overlap with yet.
const seeded = guard.inspect('aaaaaaaaaa')
expect(seeded.hit).toBe(false)
expect(seeded.streak).toBe(0)
expect(seeded.abort).toBe(false)
for (let i = 1; i <= 9; i += 1) {
const result = guard.inspect('aaaaaaaaaa')
expect(result.hit).toBe(true)
expect(result.streak).toBe(i)
expect(result.abort).toBe(false)
}
const tenth = guard.inspect('aaaaaaaaaa')
expect(tenth.hit).toBe(true)
expect(tenth.streak).toBe(10)
expect(tenth.abort).toBe(true)
})
it('resets streak when a non-hit chunk appears', () => {
const guard = createTextLoopGuard({
ngramSize: 1,
overlapThreshold: 0.5,
consecutiveHitsToAbort: 10,
})
guard.inspect('aaaaaaaaaa')
for (let i = 0; i < 5; i += 1) {
guard.inspect('aaaaaaaaaa')
}
// A chunk of unseen characters breaks the streak.
const miss = guard.inspect('bcdefghijk')
expect(miss.hit).toBe(false)
expect(miss.streak).toBe(0)
expect(miss.abort).toBe(false)
const hitAgain = guard.inspect('aaaaaaaaaa')
expect(hitAgain.hit).toBe(true)
expect(hitAgain.streak).toBe(1)
expect(hitAgain.abort).toBe(false)
})
// Chunks with fewer than minNewGramsPerChunk grams are still reported as
// hit/miss but leave the streak untouched in either direction.
it('only updates streak when chunk has enough new grams', () => {
const guard = createTextLoopGuard({
ngramSize: 1,
overlapThreshold: 0.5,
minNewGramsPerChunk: 5,
})
guard.inspect('aaaaaaaaaa')
const smallHit = guard.inspect('aaaa')
expect(smallHit.hit).toBe(true)
expect(smallHit.newGrams).toBe(4)
expect(smallHit.streak).toBe(0)
expect(smallHit.abort).toBe(false)
const countedHit = guard.inspect('aaaaa')
expect(countedHit.hit).toBe(true)
expect(countedHit.newGrams).toBe(5)
expect(countedHit.streak).toBe(1)
expect(countedHit.abort).toBe(false)
const smallMiss = guard.inspect('cccc')
expect(smallMiss.hit).toBe(false)
expect(smallMiss.newGrams).toBe(4)
expect(smallMiss.streak).toBe(1)
const countedMiss = guard.inspect('ddddd')
expect(countedMiss.hit).toBe(false)
expect(countedMiss.newGrams).toBe(5)
expect(countedMiss.streak).toBe(0)
})
it('warns on first tool-loop breach and aborts on second breach', () => {
const guard = createToolLoopGuard({
repeatThreshold: 5,
warningsBeforeAbort: 1,
})
const payload = {
toolName: 'web_fetch',
input: { url: 'https://example.com', requestId: 'r-1' },
}
// Up to repeatThreshold identical calls are tolerated silently.
for (let i = 1; i <= 5; i += 1) {
const result = guard.inspect(payload)
expect(result.warn).toBe(false)
expect(result.abort).toBe(false)
expect(result.repeatCount).toBe(i)
}
// The sixth call exceeds the threshold: warning, and the repeat counter restarts.
const firstBreach = guard.inspect(payload)
expect(firstBreach.warn).toBe(true)
expect(firstBreach.abort).toBe(false)
expect(firstBreach.breachCount).toBe(1)
expect(firstBreach.repeatCount).toBe(0)
for (let i = 1; i <= 5; i += 1) {
const result = guard.inspect(payload)
expect(result.warn).toBe(false)
expect(result.abort).toBe(false)
expect(result.repeatCount).toBe(i)
}
// The warning budget is used up, so the next breach aborts.
const secondBreach = guard.inspect(payload)
expect(secondBreach.warn).toBe(false)
expect(secondBreach.abort).toBe(true)
expect(secondBreach.breachCount).toBe(2)
})
it('resets tool-loop repeat count when hash changes', () => {
const guard = createToolLoopGuard({
repeatThreshold: 5,
warningsBeforeAbort: 1,
})
const first = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
expect(first.repeatCount).toBe(1)
const second = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
expect(second.repeatCount).toBe(2)
// A different URL yields a different fingerprint and restarts counting at 1.
const changed = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://b.example.com' },
})
expect(changed.repeatCount).toBe(1)
expect(changed.warn).toBe(false)
expect(changed.abort).toBe(false)
})
it('resets tool-loop breach count when hash changes', () => {
const guard = createToolLoopGuard({
repeatThreshold: 1,
warningsBeforeAbort: 1,
})
// First fingerprint reaches warning phase.
guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
const warned = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://a.example.com' },
})
expect(warned.warn).toBe(true)
expect(warned.breachCount).toBe(1)
// Switching fingerprint should restart warning/abort phase.
const switched = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://b.example.com' },
})
expect(switched.warn).toBe(false)
expect(switched.abort).toBe(false)
expect(switched.breachCount).toBe(0)
const warnedAgain = guard.inspect({
toolName: 'web_fetch',
input: { url: 'https://b.example.com' },
})
expect(warnedAgain.warn).toBe(true)
expect(warnedAgain.abort).toBe(false)
expect(warnedAgain.breachCount).toBe(1)
})
// request_id and updatedAt differ between the calls but are volatile keys,
// so both calls share one fingerprint and the second call breaches.
it('ignores volatile keys when computing tool-loop hash', () => {
const guard = createToolLoopGuard({
repeatThreshold: 1,
warningsBeforeAbort: 1,
})
const first = guard.inspect({
toolName: 'web_fetch',
input: {
url: 'https://example.com',
request_id: 'req-1',
updatedAt: '2026-02-28T00:00:00.000Z',
},
})
expect(first.warn).toBe(false)
expect(first.abort).toBe(false)
const second = guard.inspect({
toolName: 'web_fetch',
input: {
url: 'https://example.com',
request_id: 'req-2',
updatedAt: '2026-02-28T00:01:00.000Z',
},
})
expect(second.warn).toBe(true)
expect(second.abort).toBe(false)
})
})
+506
View File
@@ -0,0 +1,506 @@
import { createHash } from 'node:crypto'
/** Tuning knobs for the n-gram overlap detector created by `createSential`. */
export interface SentialOptions {
/** Characters per n-gram. Positive integer; defaults to 10. */
ngramSize?: number;
/** Maximum number of characters kept in the sliding history window. Must be >= ngramSize; defaults to 1000. */
windowSize?: number;
/** A chunk is a hit when its overlap is strictly greater than this value. Range 0..1; defaults to 0.75. */
overlapThreshold?: number;
}
/** Options for `createTextLoopGuard`: detector options plus streak handling. */
export interface TextLoopGuardOptions extends SentialOptions {
/** Streak length at which `abort` becomes true. Positive integer; defaults to 10. */
consecutiveHitsToAbort?: number;
/** Chunks with fewer new n-grams than this leave the streak unchanged. Positive integer; defaults to 1. */
minNewGramsPerChunk?: number;
}
/** Outcome of inspecting one chunk of text against the history window. */
export interface SentialInspectResult {
/** True when `overlap` exceeds the configured threshold. */
hit: boolean;
/** `matchedGrams / newGrams`, or 0 when the chunk produced no n-grams. */
overlap: number;
/** Number of the chunk's n-grams that were already present in history. */
matchedGrams: number;
/** Number of n-grams produced by the chunk. */
newGrams: number;
}
/** Detector result extended with the guard's streak state. */
export interface TextLoopGuardInspectResult extends SentialInspectResult {
/** Current number of consecutive counted hits. */
streak: number;
/** True once `streak` has reached `consecutiveHitsToAbort`. */
abort: boolean;
}
/** One tool invocation as seen by the tool-loop guard. */
export interface ToolLoopInspectInput {
toolName: string;
input: unknown;
}
/** Outcome of inspecting one tool invocation. */
export interface ToolLoopInspectResult {
/** SHA-256 hex fingerprint of the tool name plus normalized input. */
hash: string;
/** Consecutive invocations with this fingerprint; reported as 0 on a call that warns. */
repeatCount: number;
/** Threshold breaches recorded for the current fingerprint. */
breachCount: number;
/** True when this call breached the threshold and a warning is still allowed. */
warn: boolean;
/** True when this call breached the threshold after the warnings were used up. */
abort: boolean;
}
/** Options for `createToolLoopGuard`. */
export interface ToolLoopGuardOptions {
/** A breach occurs when the repeat count exceeds this value. Positive integer; defaults to 5. */
repeatThreshold?: number;
/** Number of breaches answered with a warning before a breach aborts. Positive integer; defaults to 1. */
warningsBeforeAbort?: number;
/** Extra input key names to ignore when fingerprinting, in addition to the built-in list. */
volatileKeys?: string[];
}
/** N-gram overlap detector over a sliding character window. */
export interface Sential {
/** Scores `text` against history, then appends it to history. */
inspect(text: string): SentialInspectResult;
/** Clears all history. */
reset(): void;
}
/** Streak-based loop guard built on top of a `Sential`. */
export interface TextLoopGuard {
inspect(text: string): TextLoopGuardInspectResult;
/** Clears history and the streak. */
reset(): void;
}
/** Buffer that re-chunks streamed text into fixed-size pieces for inspection. */
export interface TextLoopProbeBuffer {
/** Appends text and inspects every complete chunk now available. */
push(text: string): void;
/** Inspects whatever is left in the buffer and empties it. */
flush(): void;
}
/** Detector for repeated identical tool invocations. */
export interface ToolLoopGuard {
inspect(input: ToolLoopInspectInput): ToolLoopInspectResult;
/** Clears repeat and breach state. */
reset(): void;
}
// Defaults applied when the corresponding option is omitted.
const DEFAULT_NGRAM_SIZE = 10
const DEFAULT_WINDOW_SIZE = 1000
const DEFAULT_OVERLAP_THRESHOLD = 0.75
const DEFAULT_CONSECUTIVE_HITS_TO_ABORT = 10
const DEFAULT_MIN_NEW_GRAMS_PER_CHUNK = 1
const DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD = 5
const DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1
// Input keys that are skipped when fingerprinting a tool call. Names are
// compared after normalizeKeyName, so camelCase and snake_case spellings
// collapse to the same entry.
const DEFAULT_VOLATILE_KEYS = [
'toolCallId',
'tool_call_id',
'requestId',
'request_id',
'traceId',
'trace_id',
'spanId',
'span_id',
'sessionId',
'session_id',
'timestamp',
'createdAt',
'created_at',
'updatedAt',
'updated_at',
'expiresAt',
'expires_at',
'nonce',
]
// Any key whose normalized name ends with one of these is treated as volatile
// as well, regardless of the configured key set.
const VOLATILE_KEY_SUFFIXES = [
'requestid',
'traceid',
'sessionid',
'toolcallid',
'timestamp',
'createdat',
'updatedat',
'expiresat',
]
// JSON-serializable shape produced by normalizeToolLoopValue.
type NormalizedValue =
| null
| string
| number
| boolean
| NormalizedValue[]
| { [key: string]: NormalizedValue };
/**
 * Returns `value` unchanged when it is an integer greater than zero.
 *
 * @param name - Option name used in the error message.
 * @throws Error when `value` is not a positive integer (includes NaN and ±Infinity).
 */
function validatePositiveInt(name: string, value: number): number {
  const isPositiveInteger = Number.isInteger(value) && value > 0
  if (isPositiveInteger) {
    return value
  }
  throw new Error(`${name} must be a positive integer`)
}
/**
 * Returns `value` unchanged when it is a finite number in the closed range [0, 1].
 *
 * @throws Error when `value` is NaN, infinite, or outside [0, 1].
 */
function validateThreshold(value: number): number {
  const inUnitInterval = Number.isFinite(value) && value >= 0 && value <= 1
  if (inUnitInterval) {
    return value
  }
  throw new Error('overlapThreshold must be between 0 and 1')
}
/**
 * Splits text into an array of code points after NFC normalization.
 * Empty input yields an empty array.
 */
function normalizeChars(text: string): string[] {
  return text ? [...text.normalize('NFC')] : []
}
/** Joins up to `size` entries of `chars`, starting at index `start`, into one string. */
function buildNgram(chars: string[], start: number, size: number): string {
  const gramChars = chars.slice(start, start + size)
  return gramChars.join('')
}
/**
 * Canonicalizes a key name for volatile-key matching: trims, lowercases and
 * removes every character outside a-z / 0-9.
 */
function normalizeKeyName(key: string): string {
  const lowered = key.trim().toLowerCase()
  return lowered.replace(/[^a-z0-9]/g, '')
}
/**
 * Decides whether an input key should be ignored when fingerprinting.
 * A key is volatile when its normalized name is in `volatileKeySet` or ends
 * with one of VOLATILE_KEY_SUFFIXES. Keys that normalize to '' never match.
 */
function isVolatileKey(key: string, volatileKeySet: Set<string>): boolean {
  const canonical = normalizeKeyName(key)
  if (canonical === '') {
    return false
  }
  if (volatileKeySet.has(canonical)) {
    return true
  }
  for (const suffix of VOLATILE_KEY_SUFFIXES) {
    if (canonical.endsWith(suffix)) {
      return true
    }
  }
  return false
}
/**
 * True for objects whose prototype is Object.prototype or null
 * (object literals and Object.create(null)); false for everything else.
 */
function isPlainObject(value: unknown): value is Record<string, unknown> {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  const proto: unknown = Object.getPrototypeOf(value)
  return proto === null || proto === Object.prototype
}
/**
 * Converts an arbitrary tool input into a canonical JSON-serializable value
 * so that equivalent inputs serialize identically.
 *
 * - strings are NFC-normalized; non-finite numbers and bigints become strings
 * - undefined / functions / symbols yield `undefined` (dropped from objects,
 *   replaced by null inside arrays)
 * - valid Dates become ISO strings; invalid Dates become 'Invalid Date'
 * - plain objects are rebuilt with sorted keys, skipping volatile keys
 * - other objects use toJSON() when available, else String(value)
 * - a container that is reached again while it is still being normalized
 *   (a reference cycle) is replaced by '[Circular]'
 *
 * @param value - Value to normalize.
 * @param volatileKeySet - Normalized key names to skip in plain objects.
 * @param seen - Containers currently on the recursion path.
 * @returns The normalized value, or `undefined` when the value is not representable.
 */
function normalizeToolLoopValue(
  value: unknown,
  volatileKeySet: Set<string>,
  seen: WeakSet<object>,
): NormalizedValue | undefined {
  if (value === null) return null
  if (typeof value === 'string') return value.normalize('NFC')
  if (typeof value === 'boolean') return value
  if (typeof value === 'number') {
    return Number.isFinite(value) ? value : String(value)
  }
  if (typeof value === 'bigint') return value.toString()
  if (
    typeof value === 'undefined' ||
    typeof value === 'function' ||
    typeof value === 'symbol'
  ) {
    return undefined
  }

  if (value instanceof Date) {
    // toISOString() throws RangeError on an invalid Date; fall back to the
    // same text String(date) would give so fingerprinting never throws here.
    return Number.isNaN(value.getTime()) ? 'Invalid Date' : value.toISOString()
  }

  if (Array.isArray(value)) {
    // Arrays take part in cycle detection too; otherwise a self-referencing
    // array recurses until the stack overflows.
    if (seen.has(value)) {
      return '[Circular]'
    }
    seen.add(value)
    const normalizedItems = value.map(
      (item) => normalizeToolLoopValue(item, volatileKeySet, seen) ?? null,
    )
    seen.delete(value)
    return normalizedItems
  }

  if (!isPlainObject(value)) {
    // Every primitive and function was handled above, so this is an object.
    const opaque = value as object & { toJSON?: () => unknown }
    if (typeof opaque.toJSON !== 'function') {
      return String(value)
    }
    // Guard against toJSON() results that lead back to the same object.
    if (seen.has(opaque)) {
      return '[Circular]'
    }
    seen.add(opaque)
    try {
      return normalizeToolLoopValue(opaque.toJSON(), volatileKeySet, seen)
    } finally {
      seen.delete(opaque)
    }
  }

  if (seen.has(value)) {
    return '[Circular]'
  }
  seen.add(value)
  const normalizedObject: { [key: string]: NormalizedValue } = {}
  // Sorted keys make the serialization independent of insertion order.
  const keys = Object.keys(value).sort()
  for (const key of keys) {
    if (isVolatileKey(key, volatileKeySet)) {
      continue
    }
    const normalized = normalizeToolLoopValue(value[key], volatileKeySet, seen)
    if (normalized !== undefined) {
      normalizedObject[key] = normalized
    }
  }
  // Remove on the way out so the same object may appear in sibling positions.
  seen.delete(value)
  return normalizedObject
}
/**
 * Fingerprints a tool invocation: SHA-256 (hex) over the JSON of the trimmed
 * tool name and the normalized input (null when the input is not representable).
 */
function computeToolLoopHash(
  input: ToolLoopInspectInput,
  volatileKeySet: Set<string>,
): string {
  const toolName = input.toolName.trim()
  const normalizedInput =
    normalizeToolLoopValue(input.input, volatileKeySet, new WeakSet()) ?? null
  const fingerprintSource = JSON.stringify({ toolName, input: normalizedInput })
  const hasher = createHash('sha256')
  hasher.update(fingerprintSource)
  return hasher.digest('hex')
}
/**
 * Creates an n-gram overlap detector over a sliding window of characters.
 *
 * `inspect(text)` counts the n-grams that end inside `text` (n-grams that
 * straddle the boundary with previous text are included), reports what
 * fraction of them already exists in the window, and then appends `text` to
 * the window. The chunk is scored before it is added, so it never matches
 * against itself.
 *
 * @param options - ngramSize (default 10), windowSize (default 1000),
 *   overlapThreshold (default 0.75).
 * @throws Error when an option is invalid or windowSize < ngramSize.
 */
export function createSential(options: SentialOptions = {}): Sential {
  const ngramSize = validatePositiveInt(
    'ngramSize',
    options.ngramSize ?? DEFAULT_NGRAM_SIZE,
  )
  const windowSize = validatePositiveInt(
    'windowSize',
    options.windowSize ?? DEFAULT_WINDOW_SIZE,
  )
  const overlapThreshold = validateThreshold(
    options.overlapThreshold ?? DEFAULT_OVERLAP_THRESHOLD,
  )
  if (windowSize < ngramSize) {
    throw new Error('windowSize must be greater than or equal to ngramSize')
  }

  // Characters currently in the window, oldest first.
  const windowChars: string[] = []
  // N-grams currently in the window, in the order they were formed.
  const windowNgramQueue: string[] = []
  // Multiplicity of each n-gram in the window; a gram is "in history" exactly
  // when it has an entry here.
  const historyCounts = new Map<string, number>()

  const addHistoryGram = (gram: string) => {
    historyCounts.set(gram, (historyCounts.get(gram) ?? 0) + 1)
  }

  const removeHistoryGram = (gram: string) => {
    const prevCount = historyCounts.get(gram)
    if (prevCount === undefined) return
    if (prevCount <= 1) {
      historyCounts.delete(gram)
      return
    }
    historyCounts.set(gram, prevCount - 1)
  }

  const pushWindowChar = (char: string) => {
    windowChars.push(char)
    if (windowChars.length >= ngramSize) {
      const gram = buildNgram(
        windowChars,
        windowChars.length - ngramSize,
        ngramSize,
      )
      windowNgramQueue.push(gram)
      addHistoryGram(gram)
    }
    if (windowChars.length <= windowSize) {
      return
    }
    // Window overflow: drop the oldest character together with the oldest
    // n-gram (the one that started at that character).
    windowChars.shift()
    const removedGram = windowNgramQueue.shift()
    if (removedGram !== undefined) {
      removeHistoryGram(removedGram)
    }
  }

  return {
    inspect(text: string): SentialInspectResult {
      const incomingChars = normalizeChars(text)
      if (incomingChars.length === 0) {
        return {
          hit: false,
          overlap: 0,
          matchedGrams: 0,
          newGrams: 0,
        }
      }

      // Prepend the last (ngramSize - 1) window characters so n-grams that
      // straddle the chunk boundary are scored. Because the context is
      // shorter than one n-gram, every n-gram of the candidate ends inside
      // the incoming text; no filtering is needed.
      const contextChars =
        ngramSize > 1 ? windowChars.slice(-(ngramSize - 1)) : []
      const candidateChars = [...contextChars, ...incomingChars]

      let matchedGrams = 0
      let newGrams = 0
      for (let i = 0; i <= candidateChars.length - ngramSize; i += 1) {
        newGrams += 1
        if (historyCounts.has(buildNgram(candidateChars, i, ngramSize))) {
          matchedGrams += 1
        }
      }

      const overlap = newGrams === 0 ? 0 : matchedGrams / newGrams
      // Strict comparison: overlap equal to the threshold is not a hit.
      const hit = overlap > overlapThreshold

      for (const char of incomingChars) {
        pushWindowChar(char)
      }

      return {
        hit,
        overlap,
        matchedGrams,
        newGrams,
      }
    },
    reset(): void {
      windowChars.length = 0
      windowNgramQueue.length = 0
      historyCounts.clear()
    },
  }
}
/**
 * Wraps a Sential with a consecutive-hit counter.
 *
 * Each inspected chunk that produced at least `minNewGramsPerChunk` n-grams
 * either extends the streak (hit) or resets it to zero (miss); smaller chunks
 * leave the streak as it is. `abort` is true while the streak is at or above
 * `consecutiveHitsToAbort`.
 *
 * @throws Error when an option is invalid.
 */
export function createTextLoopGuard(
  options: TextLoopGuardOptions = {},
): TextLoopGuard {
  const consecutiveHitsToAbort = validatePositiveInt(
    'consecutiveHitsToAbort',
    options.consecutiveHitsToAbort ?? DEFAULT_CONSECUTIVE_HITS_TO_ABORT,
  )
  const minNewGramsPerChunk = validatePositiveInt(
    'minNewGramsPerChunk',
    options.minNewGramsPerChunk ?? DEFAULT_MIN_NEW_GRAMS_PER_CHUNK,
  )
  const sential = createSential(options)
  let streak = 0

  const inspect = (text: string): TextLoopGuardInspectResult => {
    const probe = sential.inspect(text)
    const countsTowardStreak = probe.newGrams >= minNewGramsPerChunk
    if (countsTowardStreak) {
      streak = probe.hit ? streak + 1 : 0
    }
    return { ...probe, streak, abort: streak >= consecutiveHitsToAbort }
  }

  const reset = (): void => {
    sential.reset()
    streak = 0
  }

  return { inspect, reset }
}
/**
 * Re-chunks streamed text into pieces of exactly `chunkSize` code points and
 * passes each piece to `inspect`.
 *
 * @param chunkSize - Code points per inspected chunk. Positive integer.
 * @param inspect - Called with each complete chunk, and with the remainder on
 *   flush. Exceptions it throws propagate to the caller of push/flush.
 * @throws Error when `chunkSize` is not a positive integer.
 */
export function createTextLoopProbeBuffer(
  chunkSize: number,
  inspect: (text: string) => void,
): TextLoopProbeBuffer {
  validatePositiveInt('chunkSize', chunkSize)
  let chars: string[] = []
  // Index of the first character that has not been inspected yet.
  let offset = 0

  return {
    push(text: string): void {
      if (!text) return
      // Append element by element: `chars.push(...array)` passes every element
      // as a call argument and throws RangeError for very large inputs.
      for (const char of normalizeChars(text)) {
        chars.push(char)
      }
      while (chars.length - offset >= chunkSize) {
        const chunk = chars.slice(offset, offset + chunkSize).join('')
        // Advance before inspecting so a throwing inspector never sees the
        // same chunk twice.
        offset += chunkSize
        inspect(chunk)
      }
      // Drop the consumed prefix so the buffer does not grow without bound.
      if (offset >= chunkSize) {
        chars = chars.slice(offset)
        offset = 0
      }
    },
    flush(): void {
      const remainder = chars.slice(offset).join('')
      // Clear first: if the inspector throws, the remainder must not stay
      // buffered and be inspected again by a later flush.
      chars = []
      offset = 0
      if (remainder.length > 0) {
        inspect(remainder)
      }
    },
  }
}
/**
 * Creates a guard that detects the same tool being invoked with the same
 * (normalized) input many times in a row.
 *
 * Every invocation is fingerprinted. When the consecutive repeat count
 * exceeds `repeatThreshold`, a breach is recorded: the first
 * `warningsBeforeAbort` breaches return `warn` and restart the repeat count,
 * later breaches return `abort`. Breaches are tracked per fingerprint; a
 * different fingerprint starts over.
 *
 * @throws Error when `repeatThreshold` or `warningsBeforeAbort` is not a positive integer.
 */
export function createToolLoopGuard(
  options: ToolLoopGuardOptions = {},
): ToolLoopGuard {
  const repeatThreshold = validatePositiveInt(
    'repeatThreshold',
    options.repeatThreshold ?? DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD,
  )
  const warningsBeforeAbort = validatePositiveInt(
    'warningsBeforeAbort',
    options.warningsBeforeAbort ?? DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT,
  )

  // Built-in volatile keys plus caller-supplied ones, all in normalized form.
  const volatileKeySet = new Set<string>()
  for (const key of [...DEFAULT_VOLATILE_KEYS, ...(options.volatileKeys ?? [])]) {
    const normalizedKey = normalizeKeyName(key)
    if (normalizedKey) {
      volatileKeySet.add(normalizedKey)
    }
  }

  let lastHash = ''
  let repeatCount = 0
  let breachCount = 0
  let breachHash = ''

  const inspect = (input: ToolLoopInspectInput): ToolLoopInspectResult => {
    const hash = computeToolLoopHash(input, volatileKeySet)

    const sameAsPrevious = hash === lastHash
    lastHash = hash
    repeatCount = sameAsPrevious ? repeatCount + 1 : 1

    // Breach bookkeeping belongs to one fingerprint; a new one restarts it.
    if (breachHash !== hash) {
      breachHash = hash
      breachCount = 0
    }

    const overThreshold = repeatCount > repeatThreshold
    const warn = overThreshold && breachCount < warningsBeforeAbort
    const abort = overThreshold && !warn
    if (overThreshold) {
      breachCount += 1
      if (warn) {
        // After a warning the repeat count starts from scratch.
        lastHash = ''
        repeatCount = 0
      }
    }

    return { hash, repeatCount, breachCount, warn, abort }
  }

  const reset = (): void => {
    lastHash = ''
    repeatCount = 0
    breachCount = 0
    breachHash = ''
  }

  return { inspect, reset }
}
+93
View File
@@ -0,0 +1,93 @@
import { describe, expect, it, vi } from 'vitest'
import type { ToolSet } from 'ai'
import { createToolLoopGuardedTools } from './tool-loop'
// Tests for createToolLoopGuardedTools: streaming pass-through, warning
// injection and abort reporting through the callback.
describe('tool loop guarded tools', () => {
// A tool that resolves to an async iterable must get the very same object
// back, unwrapped, and must not trigger the abort callback.
it('preserves promised async-iterable tool outputs', async () => {
const onAbortToolCall = vi.fn()
const streamedChunks = ['chunk-1', 'chunk-2']
const stream = {
async *[Symbol.asyncIterator]() {
for (const chunk of streamedChunks) {
yield chunk
}
},
}
const baseTools = {
streamy: {
execute: async () => stream,
},
} as unknown as ToolSet
const tools = createToolLoopGuardedTools(baseTools, {
repeatThreshold: 1,
warningsBeforeAbort: 1,
onAbortToolCall,
warningKey: '__warn',
warningText: 'loop warning',
})
const output = await tools.streamy.execute?.({ value: 'same' } as never, { toolCallId: 't-stream' } as never)
expect(output).toBe(stream)
const received: string[] = []
for await (const chunk of output as AsyncIterable<string>) {
received.push(chunk)
}
expect(received).toEqual(streamedChunks)
expect(onAbortToolCall).not.toHaveBeenCalled()
})
// With repeatThreshold 1 the second identical call warns and the fourth
// aborts; the aborted call still returns the tool's own output.
it('defers abort to stream layer when onAbortToolCall is provided', async () => {
const onAbortToolCall = vi.fn()
const baseTools = {
echo: {
execute: async (input: unknown) => ({ result: input }),
},
} as unknown as ToolSet
const tools = createToolLoopGuardedTools(baseTools, {
repeatThreshold: 1,
warningsBeforeAbort: 1,
onAbortToolCall,
warningKey: '__warn',
warningText: 'loop warning',
})
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
const warned = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
expect(warned).toMatchObject({
__warn: {
marker: 'MEMOH_TOOL_LOOP_WARNING',
},
})
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
const abortedOutput = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-1' } as never)
expect(onAbortToolCall).toHaveBeenCalledWith('t-1')
expect(abortedOutput).toEqual({ result: { value: 'same' } })
})
// Same call sequence as above; checks that the aborting call resolves
// normally and the abort is signalled only through the callback.
it('reports abort via callback without throwing inside tool execution', async () => {
const onAbortToolCall = vi.fn()
const baseTools = {
echo: {
execute: async (input: unknown) => ({ result: input }),
},
} as unknown as ToolSet
const tools = createToolLoopGuardedTools(baseTools, {
repeatThreshold: 1,
warningsBeforeAbort: 1,
onAbortToolCall,
warningKey: '__warn',
warningText: 'loop warning',
})
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
const abortedOutput = await tools.echo.execute?.({ value: 'same' } as never, { toolCallId: 't-2' } as never)
expect(onAbortToolCall).toHaveBeenCalledWith('t-2')
expect(abortedOutput).toEqual({ result: { value: 'same' } })
})
})
+122
View File
@@ -0,0 +1,122 @@
import { ToolExecutionOptions, ToolSet } from 'ai'
import { createToolLoopGuard, type ToolLoopInspectResult } from './sential'
/** Configuration for `createToolLoopGuardedTools`. */
export interface CreateToolLoopGuardedToolsOptions {
/** Passed to the tool-loop guard as its repeat threshold. */
repeatThreshold: number
/** Passed to the tool-loop guard as the number of warnings before abort. */
warningsBeforeAbort: number
/** Called with the tool call id when the guard decides to abort on that call. */
onAbortToolCall: (toolCallId: string) => void
/** Property name under which the warning payload is attached to a tool result. */
warningKey: string
/** Human-readable message placed in the warning payload. */
warningText: string
}
// True for any non-null object that is not an array.
const isRecord = (value: unknown): value is Record<string, unknown> => {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  return !Array.isArray(value)
}
// True for non-null objects that expose Symbol.asyncIterator.
const isAsyncIterable = (value: unknown): value is AsyncIterable<unknown> => {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  return Symbol.asyncIterator in value
}
// Attaches a structured loop warning to a tool result.
// Object results are shallow-copied with the warning stored under `warningKey`;
// any other result is wrapped as { [warningKey]: warning, result: output }.
const injectToolLoopWarning = (
  output: unknown,
  inspectResult: ToolLoopInspectResult,
  warningKey: string,
  warningText: string,
): unknown => {
  const { hash: fingerprint, breachCount } = inspectResult
  // Structured payload so UI/consumers can detect and render the warning.
  const warningPayload = {
    marker: 'MEMOH_TOOL_LOOP_WARNING',
    message: warningText,
    fingerprint,
    breachCount,
  }
  return isRecord(output)
    ? { ...output, [warningKey]: warningPayload }
    : { [warningKey]: warningPayload, result: output }
}
/**
 * Returns a copy of `tools` in which every executable tool is wrapped with a
 * shared tool-loop guard.
 *
 * The guard inspects (toolName + input) after the tool has produced its
 * result. A warning verdict injects a warning payload into that result; an
 * abort verdict calls `onAbortToolCall` with the tool call id and returns the
 * result unchanged. Async-iterable results, returned directly or via a
 * promise, are passed through untouched and are not inspected. Tools without
 * an `execute` function are returned as-is.
 */
export function createToolLoopGuardedTools(
  tools: ToolSet,
  {
    repeatThreshold,
    warningsBeforeAbort,
    onAbortToolCall,
    warningKey,
    warningText,
  }: CreateToolLoopGuardedToolsOptions,
): ToolSet {
  // One guard instance is shared by all tools in the set.
  const guard = createToolLoopGuard({
    repeatThreshold,
    warningsBeforeAbort,
  })

  const guardedEntries = Object.entries(tools).map(([toolName, toolDefinition]) => {
    const baseExecute = toolDefinition.execute
    if (typeof baseExecute !== 'function') {
      return [toolName, toolDefinition]
    }

    // Awaits the tool output, then applies the guard verdict to it.
    const settleWithGuard = async (
      pendingOutput: unknown,
      toolInput: unknown,
      options: ToolExecutionOptions,
    ) => {
      const settled = await pendingOutput
      // Tools may return Promise<AsyncIterable>; keep that stream untouched too.
      if (isAsyncIterable(settled)) {
        return settled as never
      }
      const verdict = guard.inspect({ toolName, input: toolInput })
      if (verdict.abort) {
        // Only report the abort; the generation layer decides when/how to stop.
        onAbortToolCall(options.toolCallId)
        return settled as never
      }
      const finalOutput = verdict.warn
        ? injectToolLoopWarning(settled, verdict, warningKey, warningText)
        : settled
      return finalOutput as never
    }

    const guardedExecute = (
      toolInput: unknown,
      options: ToolExecutionOptions,
    ) => {
      const immediate = baseExecute(
        toolInput as never,
        options as never,
      ) as unknown
      // Directly streamed outputs are returned synchronously and unchanged
      // to preserve streaming semantics.
      if (isAsyncIterable(immediate)) {
        return immediate as never
      }
      return settleWithGuard(immediate, toolInput, options)
    }

    return [toolName, { ...toolDefinition, execute: guardedExecute }]
  })

  return Object.fromEntries(guardedEntries) as ToolSet
}
+5
View File
@@ -38,6 +38,10 @@ export interface InboxItem {
createdAt: string
}
/** Per-agent switch for the text and tool loop guards. */
export interface LoopDetectionConfig {
/** Set to true to turn loop detection on. */
enabled?: boolean
}
export interface AgentParams {
model: ModelConfig
language?: string
@@ -50,6 +54,7 @@ export interface AgentParams {
auth: AgentAuthContext
skills?: AgentSkill[]
inbox?: InboxItem[]
loopDetection?: LoopDetectionConfig
}
export interface AgentInput {