mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(agent): loop detection (#152)
* feat(loop-detection): add configurable text and tool loop guards * style(web): remove duplicate separator in bot settings
This commit is contained in:
+5
-1
@@ -51,6 +51,10 @@ export const HeartbeatModel = z.object({
|
||||
interval: z.number().int().positive().default(30),
|
||||
})
|
||||
|
||||
/**
 * Optional loop-detection settings.
 * The whole object may be omitted; when it is present, `enabled` falls back
 * to `false` if the caller does not supply it.
 */
export const LoopDetectionModel = z
  .object({ enabled: z.boolean().default(false) })
  .optional()
|
||||
|
||||
export const AttachmentModel = z.object({
|
||||
contentHash: z.string().optional(),
|
||||
type: z.string().min(1, 'Attachment type is required'),
|
||||
@@ -93,4 +97,4 @@ export const InboxItemModel = z.object({
|
||||
header: z.record(z.string(), z.unknown()).default({}),
|
||||
content: z.string().default(''),
|
||||
createdAt: z.string(),
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,7 +3,7 @@ import z from 'zod'
|
||||
import { createAgent, ModelConfig, allActions } from '@memoh/agent'
|
||||
import { createAuthFetcher, getBaseUrl } from '../index'
|
||||
import { bearerMiddleware } from '../middlewares/bearer'
|
||||
import { AgentSkillModel, AllowedActionModel, AttachmentModel, HeartbeatModel, IdentityContextModel, InboxItemModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models'
|
||||
import { AgentSkillModel, AllowedActionModel, AttachmentModel, HeartbeatModel, IdentityContextModel, InboxItemModel, LoopDetectionModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models'
|
||||
import { sseChunked } from '../utils/sse'
|
||||
|
||||
const AgentModel = z.object({
|
||||
@@ -19,6 +19,7 @@ const AgentModel = z.object({
|
||||
attachments: z.array(AttachmentModel).optional().default([]),
|
||||
mcpConnections: z.array(MCPConnectionModel).optional().default([]),
|
||||
inbox: z.array(InboxItemModel).optional().default([]),
|
||||
loopDetection: LoopDetectionModel,
|
||||
})
|
||||
|
||||
export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
@@ -41,6 +42,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
skills: body.usableSkills,
|
||||
mcpConnections: body.mcpConnections,
|
||||
inbox: body.inbox,
|
||||
loopDetection: body.loopDetection,
|
||||
}, authFetcher)
|
||||
return ask({
|
||||
query: body.query,
|
||||
@@ -72,6 +74,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
skills: body.usableSkills,
|
||||
mcpConnections: body.mcpConnections,
|
||||
inbox: body.inbox,
|
||||
loopDetection: body.loopDetection,
|
||||
}, authFetcher)
|
||||
for await (const action of stream({
|
||||
query: body.query,
|
||||
@@ -113,6 +116,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
skills: body.usableSkills,
|
||||
mcpConnections: body.mcpConnections,
|
||||
inbox: body.inbox,
|
||||
loopDetection: body.loopDetection,
|
||||
}, authFetcher)
|
||||
return triggerSchedule({
|
||||
schedule: body.schedule,
|
||||
@@ -126,20 +130,22 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
})
|
||||
.post('/trigger-heartbeat', async ({ body, bearer }) => {
|
||||
console.log('trigger-heartbeat', body)
|
||||
const authFetcher = createAuthFetcher(bearer)
|
||||
const auth = {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
}
|
||||
const authFetcher = createAuthFetcher(auth)
|
||||
const { triggerHeartbeat } = createAgent({
|
||||
model: body.model as ModelConfig,
|
||||
activeContextTime: body.activeContextTime,
|
||||
channels: body.channels,
|
||||
currentChannel: body.currentChannel,
|
||||
identity: body.identity,
|
||||
auth: {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
},
|
||||
auth,
|
||||
skills: body.usableSkills,
|
||||
mcpConnections: body.mcpConnections,
|
||||
inbox: body.inbox,
|
||||
loopDetection: body.loopDetection,
|
||||
}, authFetcher)
|
||||
return triggerHeartbeat({
|
||||
heartbeat: body.heartbeat,
|
||||
|
||||
@@ -180,6 +180,10 @@ type gatewayInboxItem struct {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}
|
||||
|
||||
// gatewayLoopDetectionConfig is the loop-detection section of a gateway
// request. It serializes as {"enabled": <bool>}.
type gatewayLoopDetectionConfig struct {
	// Enabled is sent to the gateway under the "enabled" JSON key.
	Enabled bool `json:"enabled"`
}
|
||||
|
||||
type gatewayRequest struct {
|
||||
Model gatewayModelConfig `json:"model"`
|
||||
ActiveContextTime int `json:"activeContextTime"`
|
||||
@@ -193,6 +197,7 @@ type gatewayRequest struct {
|
||||
Identity gatewayIdentity `json:"identity"`
|
||||
Attachments []any `json:"attachments"`
|
||||
Inbox []gatewayInboxItem `json:"inbox,omitempty"`
|
||||
LoopDetection *gatewayLoopDetectionConfig `json:"loopDetection,omitempty"`
|
||||
}
|
||||
|
||||
type gatewayResponse struct {
|
||||
@@ -298,6 +303,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
loopDetectionEnabled := r.loadBotLoopDetectionEnabled(ctx, req.BotID)
|
||||
|
||||
// Check chat-level model override.
|
||||
var chatSettings conversation.Settings
|
||||
@@ -466,8 +472,9 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
ConversationType: strings.TrimSpace(req.ConversationType),
|
||||
SessionToken: req.ChatToken,
|
||||
},
|
||||
Attachments: attachments,
|
||||
Inbox: inboxGatewayItems,
|
||||
Attachments: attachments,
|
||||
Inbox: inboxGatewayItems,
|
||||
LoopDetection: &gatewayLoopDetectionConfig{Enabled: loopDetectionEnabled},
|
||||
}
|
||||
|
||||
return resolvedContext{payload: payload, model: chatModel, provider: provider, inboxItemIDs: inboxItemIDs}, nil
|
||||
@@ -1835,6 +1842,48 @@ func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.
|
||||
return r.settingsService.GetBot(ctx, botID)
|
||||
}
|
||||
|
||||
func (r *Resolver) loadBotLoopDetectionEnabled(ctx context.Context, botID string) bool {
|
||||
if r.queries == nil {
|
||||
return false
|
||||
}
|
||||
botUUID, err := db.ParseUUID(botID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
row, err := r.queries.GetBotByID(ctx, botUUID)
|
||||
if err != nil {
|
||||
r.logger.Debug("failed to load bot metadata for loop detection",
|
||||
slog.String("bot_id", botID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
return parseLoopDetectionEnabledFromMetadata(row.Metadata)
|
||||
}
|
||||
|
||||
// parseLoopDetectionEnabledFromMetadata extracts the boolean stored at
// features.loop_detection.enabled in a JSON metadata document.
//
// It returns false for an empty payload, invalid JSON, a JSON null, a missing
// or non-object intermediate node, or a value that is not a JSON boolean.
func parseLoopDetectionEnabledFromMetadata(payload []byte) bool {
	if len(payload) == 0 {
		return false
	}

	var root map[string]any
	if json.Unmarshal(payload, &root) != nil {
		return false
	}

	// Walk down the object path. Indexing a nil map (JSON null) yields nil,
	// which fails the type assertion just like a missing key does.
	node := root
	for _, key := range []string{"features", "loop_detection"} {
		child, isObject := node[key].(map[string]any)
		if !isObject {
			return false
		}
		node = child
	}

	// A failed assertion leaves the zero value (false), covering non-boolean values.
	flag, _ := node["enabled"].(bool)
	return flag
}
|
||||
|
||||
// --- utility ---
|
||||
|
||||
func normalizeClientType(clientType string) (string, error) {
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package flow
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseLoopDetectionEnabledFromMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty payload defaults to false",
|
||||
payload: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json defaults to false",
|
||||
payload: []byte("{"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "missing nested path defaults to false",
|
||||
payload: []byte(`{"features":{}}`),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "explicit false",
|
||||
payload: []byte(`{"features":{"loop_detection":{"enabled":false}}}`),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "explicit true",
|
||||
payload: []byte(`{"features":{"loop_detection":{"enabled":true}}}`),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "non-boolean value defaults to false",
|
||||
payload: []byte(`{"features":{"loop_detection":{"enabled":"true"}}}`),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := parseLoopDetectionEnabledFromMetadata(tt.payload)
|
||||
if got != tt.expected {
|
||||
t.Fatalf("expected %v, got %v", tt.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+174
-42
@@ -33,9 +33,20 @@ import { getMCPTools } from './tools/mcp'
|
||||
import { getTools } from './tools'
|
||||
import { buildIdentityHeaders } from './utils/headers'
|
||||
import { createFS } from './utils'
|
||||
import { createTextLoopGuard, createTextLoopProbeBuffer } from './sential'
|
||||
import { createToolLoopGuardedTools } from './tool-loop'
|
||||
|
||||
const ANTHROPIC_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
|
||||
const GOOGLE_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
|
||||
// --- Text loop guard ---
// Message of the error thrown when repeated text output aborts a run.
const LOOP_DETECTED_ABORT_MESSAGE = 'loop detected, stream aborted'
// Passed to the text loop guard as `consecutiveHitsToAbort`.
const LOOP_DETECTED_STREAK_THRESHOLD = 3
// Passed to the text loop guard as `minNewGramsPerChunk`.
const LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK = 8
// First argument handed to `createTextLoopProbeBuffer` on the streaming path.
const LOOP_DETECTED_PROBE_CHARS = 256

// --- Tool loop guard ---
// Message of the error thrown when repeated tool calls abort a run.
const TOOL_LOOP_DETECTED_ABORT_MESSAGE = 'tool loop detected, stream aborted'
// Passed to the tool loop guard as `repeatThreshold`.
const TOOL_LOOP_REPEAT_THRESHOLD = 5
// Passed to the tool loop guard as `warningsBeforeAbort`.
const TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1
// Passed to the tool loop guard as `warningKey`.
const TOOL_LOOP_WARNING_KEY = '__memoh_tool_loop_warning'
// Passed to the tool loop guard as `warningText`. The repeat count is derived
// from TOOL_LOOP_REPEAT_THRESHOLD so the wording cannot drift from the limit.
const TOOL_LOOP_WARNING_TEXT = `[MEMOH_TOOL_LOOP_WARNING] Repeated identical tool invocation (same tool + arguments) was detected more than ${TOOL_LOOP_REPEAT_THRESHOLD} times. Stop looping this tool and either summarize current results or change strategy.`
|
||||
|
||||
const buildProviderOptions = (config: ModelConfig): Record<string, Record<string, unknown>> | undefined => {
|
||||
if (!config.reasoning?.enabled) return undefined
|
||||
@@ -93,12 +104,14 @@ export const createAgent = (
|
||||
},
|
||||
auth,
|
||||
inbox = [],
|
||||
loopDetection = { enabled: false },
|
||||
}: AgentParams,
|
||||
fetch: AuthFetcher,
|
||||
) => {
|
||||
const model = createModel(modelConfig)
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const providerOptions = buildProviderOptions(modelConfig) as any
|
||||
const loopDetectionEnabled = loopDetection?.enabled === true
|
||||
const enabledSkills: AgentSkill[] = []
|
||||
const fs = createFS({ fetch, botId: identity.botId })
|
||||
|
||||
@@ -194,27 +207,102 @@ export const createAgent = (
|
||||
return userMessage
|
||||
}
|
||||
|
||||
/**
 * Builds the text-loop check used by the non-streaming generation path.
 *
 * @returns `null` when loop detection is disabled; otherwise a function that
 * feeds the given text to a fresh text loop guard and throws
 * `LOOP_DETECTED_ABORT_MESSAGE` once the guard reports an abort.
 */
const createNonStreamTextLoopInspector = () => {
  if (loopDetectionEnabled) {
    const guard = createTextLoopGuard({
      consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD,
      minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK,
    })
    return (text: string) => {
      const { abort } = guard.inspect(text)
      if (abort) {
        throw new Error(LOOP_DETECTED_ABORT_MESSAGE)
      }
    }
  }
  return null
}
|
||||
|
||||
/**
 * Wraps a tool set with the tool-loop guard when loop detection is enabled.
 *
 * @param tools - Tool set to protect.
 * @param onAbortToolCall - Forwarded to the guard as `onAbortToolCall`;
 * defaults to a no-op.
 * @returns The guarded tool set, or `tools` unchanged when loop detection is
 * disabled.
 */
const buildGuardedTools = (
  tools: ToolSet,
  onAbortToolCall: (toolCallId: string) => void = () => {},
): ToolSet =>
  loopDetectionEnabled
    ? createToolLoopGuardedTools(tools, {
      repeatThreshold: TOOL_LOOP_REPEAT_THRESHOLD,
      warningsBeforeAbort: TOOL_LOOP_WARNINGS_BEFORE_ABORT,
      onAbortToolCall,
      warningKey: TOOL_LOOP_WARNING_KEY,
      warningText: TOOL_LOOP_WARNING_TEXT,
    })
    : tools
|
||||
|
||||
/**
 * Runs one non-streaming `generateText` call with the agent's tools and
 * always closes those tools afterwards.
 *
 * When loop detection is enabled, every finished step first checks whether
 * the tool-loop guard requested an abort and then feeds the step text to the
 * text-loop inspector; either check may throw and end the run.
 *
 * @param messages - Conversation passed to the model.
 * @param systemPrompt - System prompt for the call.
 * @param prepareStep - Optional per-step hook forwarded to `generateText`.
 * @returns The `generateText` result.
 * @throws The generation error, a loop-abort error, or — only when the run
 * itself succeeded — the error raised while closing the tools.
 */
const runTextGeneration = async ({
  messages,
  systemPrompt,
  prepareStep,
}: {
  messages: ModelMessage[]
  systemPrompt: string
  prepareStep?: () => { system: string }
}) => {
  const { tools, close } = await getAgentTools()
  let shouldAbortForToolLoop = false
  let runError: unknown = null
  try {
    // The guards are built inside the try block so that a failure while
    // wiring them up still reaches the `finally` below and closes the tools.
    const guardedTools = buildGuardedTools(tools, () => {
      shouldAbortForToolLoop = true
    })
    const inspectTextLoop = createNonStreamTextLoopInspector()
    return await generateText({
      model,
      messages,
      system: systemPrompt,
      ...(providerOptions && { providerOptions }),
      stopWhen: stepCountIs(Infinity),
      ...(prepareStep && { prepareStep }),
      ...(loopDetectionEnabled && {
        onStepFinish: ({ text }: { text: string }) => {
          // The tool-loop abort flag is checked before the text inspection.
          if (shouldAbortForToolLoop) {
            throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE)
          }
          if (inspectTextLoop) {
            inspectTextLoop(text)
          }
        },
      }),
      tools: guardedTools,
    })
  } catch (error) {
    runError = error
    throw error
  } finally {
    try {
      await close()
    } catch (closeError) {
      // A close failure must not mask the original run error: rethrow it only
      // when the run itself succeeded, otherwise just log it.
      if (runError == null) {
        throw closeError
      }
      console.error(closeError)
    }
  }
}
|
||||
|
||||
const ask = async (input: AgentInput) => {
|
||||
const userPrompt = generateUserPrompt(input)
|
||||
const messages = [...input.messages, userPrompt]
|
||||
input.skills.forEach((skill) => enableSkill(skill))
|
||||
const systemPrompt = await generateSystemPrompt()
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { response, reasoning, text, usage, steps } = await generateText({
|
||||
model,
|
||||
const { response, reasoning, text, usage, steps } = await runTextGeneration({
|
||||
messages,
|
||||
system: systemPrompt,
|
||||
...(providerOptions && { providerOptions }),
|
||||
stopWhen: stepCountIs(Infinity),
|
||||
systemPrompt,
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: systemPrompt,
|
||||
}
|
||||
},
|
||||
onFinish: async () => {
|
||||
await close()
|
||||
},
|
||||
tools,
|
||||
})
|
||||
const stepUsages = buildStepUsages(steps)
|
||||
const { cleanedText, attachments: textAttachments } =
|
||||
@@ -257,22 +345,14 @@ export const createAgent = (
|
||||
})
|
||||
}
|
||||
const messages = [...params.messages, userPrompt]
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { response, reasoning, text, usage, steps } = await generateText({
|
||||
model,
|
||||
const { response, reasoning, text, usage, steps } = await runTextGeneration({
|
||||
messages,
|
||||
system: generateSubagentSystemPrompt(),
|
||||
...(providerOptions && { providerOptions }),
|
||||
stopWhen: stepCountIs(Infinity),
|
||||
systemPrompt: generateSubagentSystemPrompt(),
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: generateSubagentSystemPrompt(),
|
||||
}
|
||||
},
|
||||
onFinish: async () => {
|
||||
await close()
|
||||
},
|
||||
tools,
|
||||
})
|
||||
const stepUsages = buildStepUsages(steps)
|
||||
return {
|
||||
@@ -301,17 +381,9 @@ export const createAgent = (
|
||||
}
|
||||
const messages = [...params.messages, scheduleMessage]
|
||||
params.skills.forEach((skill) => enableSkill(skill))
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { response, reasoning, text, usage, steps } = await generateText({
|
||||
model,
|
||||
const { response, reasoning, text, usage, steps } = await runTextGeneration({
|
||||
messages,
|
||||
system: await generateSystemPrompt(),
|
||||
...(providerOptions && { providerOptions }),
|
||||
stopWhen: stepCountIs(Infinity),
|
||||
onFinish: async () => {
|
||||
await close()
|
||||
},
|
||||
tools,
|
||||
systemPrompt: await generateSystemPrompt(),
|
||||
})
|
||||
const stepUsages = buildStepUsages(steps)
|
||||
return {
|
||||
@@ -341,17 +413,9 @@ export const createAgent = (
|
||||
}
|
||||
const messages = [...params.messages, heartbeatMessage]
|
||||
params.skills.forEach((skill) => enableSkill(skill))
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { response, reasoning, text, usage, steps } = await generateText({
|
||||
model,
|
||||
const { response, reasoning, text, usage, steps } = await runTextGeneration({
|
||||
messages,
|
||||
system: await generateSystemPrompt(),
|
||||
...(providerOptions && { providerOptions }),
|
||||
stopWhen: stepCountIs(Infinity),
|
||||
onFinish: async () => {
|
||||
await close()
|
||||
},
|
||||
tools,
|
||||
systemPrompt: await generateSystemPrompt(),
|
||||
})
|
||||
const stepUsages = buildStepUsages(steps)
|
||||
return {
|
||||
@@ -392,6 +456,27 @@ export const createAgent = (
|
||||
input.skills.forEach((skill) => enableSkill(skill))
|
||||
const systemPrompt = await generateSystemPrompt()
|
||||
const attachmentsExtractor = new AttachmentsStreamExtractor()
|
||||
const textLoopGuard = loopDetectionEnabled
|
||||
? createTextLoopGuard({
|
||||
consecutiveHitsToAbort: LOOP_DETECTED_STREAK_THRESHOLD,
|
||||
minNewGramsPerChunk: LOOP_DETECTED_MIN_NEW_GRAMS_PER_CHUNK,
|
||||
})
|
||||
: null
|
||||
const guardLoopOutput = (text: string) => {
|
||||
if (!textLoopGuard) {
|
||||
return
|
||||
}
|
||||
const result = textLoopGuard.inspect(text)
|
||||
if (result.abort) {
|
||||
throw new Error(LOOP_DETECTED_ABORT_MESSAGE)
|
||||
}
|
||||
}
|
||||
const textLoopProbeBuffer = textLoopGuard
|
||||
? createTextLoopProbeBuffer(
|
||||
LOOP_DETECTED_PROBE_CHARS,
|
||||
guardLoopOutput,
|
||||
)
|
||||
: null
|
||||
const result: {
|
||||
messages: ModelMessage[];
|
||||
reasoning: string[];
|
||||
@@ -403,7 +488,20 @@ export const createAgent = (
|
||||
usage: null,
|
||||
usages: [],
|
||||
}
|
||||
const toolLoopAbortCallIds = new Set<string>()
|
||||
const { tools, close } = await getAgentTools()
|
||||
// Stream path needs deferred abort to keep tool_call_start/tool_call_end event pairing.
|
||||
const guardedTools = buildGuardedTools(tools, (toolCallId) => {
|
||||
toolLoopAbortCallIds.add(toolCallId)
|
||||
})
|
||||
let closePromise: Promise<void> | null = null
|
||||
const closeTools = async () => {
|
||||
if (!closePromise) {
|
||||
closePromise = Promise.resolve().then(() => close())
|
||||
}
|
||||
await closePromise
|
||||
}
|
||||
let streamError: unknown = null
|
||||
try {
|
||||
const { fullStream } = streamText({
|
||||
model,
|
||||
@@ -416,9 +514,9 @@ export const createAgent = (
|
||||
system: systemPrompt,
|
||||
}
|
||||
},
|
||||
tools,
|
||||
tools: guardedTools,
|
||||
onFinish: async ({ usage, reasoning, response, steps }) => {
|
||||
await close()
|
||||
await closeTools()
|
||||
result.usage = usage as never
|
||||
result.reasoning = reasoning.map((part) => part.text)
|
||||
result.messages = response.messages
|
||||
@@ -464,6 +562,9 @@ export const createAgent = (
|
||||
chunk.text,
|
||||
)
|
||||
if (visibleText) {
|
||||
if (textLoopProbeBuffer) {
|
||||
textLoopProbeBuffer.push(visibleText)
|
||||
}
|
||||
yield {
|
||||
type: 'text_delta',
|
||||
delta: visibleText,
|
||||
@@ -481,11 +582,17 @@ export const createAgent = (
|
||||
// Flush any remaining buffered content before ending the text stream.
|
||||
const remainder = attachmentsExtractor.flushRemainder()
|
||||
if (remainder.visibleText) {
|
||||
if (textLoopProbeBuffer) {
|
||||
textLoopProbeBuffer.push(remainder.visibleText)
|
||||
}
|
||||
yield {
|
||||
type: 'text_delta',
|
||||
delta: remainder.visibleText,
|
||||
}
|
||||
}
|
||||
if (textLoopProbeBuffer) {
|
||||
textLoopProbeBuffer.flush()
|
||||
}
|
||||
if (remainder.attachments.length) {
|
||||
yield {
|
||||
type: 'attachment_delta',
|
||||
@@ -502,11 +609,17 @@ export const createAgent = (
|
||||
// Flush any remaining buffered content before ending the text stream.
|
||||
const remainder = attachmentsExtractor.flushRemainder()
|
||||
if (remainder.visibleText) {
|
||||
if (textLoopProbeBuffer) {
|
||||
textLoopProbeBuffer.push(remainder.visibleText)
|
||||
}
|
||||
yield {
|
||||
type: 'text_delta',
|
||||
delta: remainder.visibleText,
|
||||
}
|
||||
}
|
||||
if (textLoopProbeBuffer) {
|
||||
textLoopProbeBuffer.flush()
|
||||
}
|
||||
if (remainder.attachments.length) {
|
||||
yield {
|
||||
type: 'attachment_delta',
|
||||
@@ -522,6 +635,9 @@ export const createAgent = (
|
||||
}
|
||||
break
|
||||
case 'tool-result':
|
||||
// Always emit the terminal tool event first so downstream reducers
|
||||
// can close the in-flight tool block before the stream aborts.
|
||||
const shouldAbortForToolLoop = toolLoopAbortCallIds.delete(chunk.toolCallId)
|
||||
yield {
|
||||
type: 'tool_call_end',
|
||||
toolName: chunk.toolName,
|
||||
@@ -530,6 +646,9 @@ export const createAgent = (
|
||||
result: chunk.output,
|
||||
metadata: chunk,
|
||||
}
|
||||
if (shouldAbortForToolLoop) {
|
||||
throw new Error(TOOL_LOOP_DETECTED_ABORT_MESSAGE)
|
||||
}
|
||||
break
|
||||
case 'file':
|
||||
yield {
|
||||
@@ -544,6 +663,9 @@ export const createAgent = (
|
||||
}
|
||||
}
|
||||
}
|
||||
if (textLoopProbeBuffer) {
|
||||
textLoopProbeBuffer.flush()
|
||||
}
|
||||
|
||||
const { messages: strippedMessages } = stripAttachmentsFromMessages(
|
||||
result.messages,
|
||||
@@ -560,8 +682,18 @@ export const createAgent = (
|
||||
skills: getEnabledSkills(),
|
||||
}
|
||||
} catch (error) {
|
||||
streamError = error
|
||||
console.error(error)
|
||||
throw error
|
||||
} finally {
|
||||
try {
|
||||
await closeTools()
|
||||
} catch (closeError) {
|
||||
if (streamError == null) {
|
||||
throw closeError
|
||||
}
|
||||
console.error(closeError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import {
|
||||
createSential,
|
||||
createTextLoopGuard,
|
||||
createToolLoopGuard,
|
||||
} from './sential'
|
||||
|
||||
describe('sential', () => {
  // Ten identical characters, reused as seed and as repeated chunk.
  const repeatedChunk = 'a'.repeat(10)

  // Wraps arbitrary arguments in a `web_fetch` tool-call payload.
  const webFetchCall = (input: Record<string, unknown>) => ({
    toolName: 'web_fetch',
    input,
  })

  it('does not hit when overlap stays low', () => {
    const detector = createSential()
    detector.inspect('ABCDEFGHIJKLMNO')

    const { hit, overlap } = detector.inspect('qrstuvwxyz12345')

    expect(hit).toBe(false)
    expect(overlap).toBe(0)
  })

  it('hits when overlap is above threshold', () => {
    const unit = '0123456789abcdefghij'
    const detector = createSential()
    detector.inspect(unit + unit)

    const { hit, overlap } = detector.inspect(unit)

    expect(hit).toBe(true)
    expect(overlap).toBeGreaterThan(0.75)
  })

  it('does not hit when overlap is exactly threshold', () => {
    const detector = createSential({ ngramSize: 1, overlapThreshold: 0.75 })
    detector.inspect(repeatedChunk)

    // 15 of the 20 unigrams were seen before, i.e. an overlap of exactly 0.75.
    const outcome = detector.inspect('a'.repeat(15) + 'bcdef')

    expect(outcome.newGrams).toBe(20)
    expect(outcome.matchedGrams).toBe(15)
    expect(outcome.overlap).toBeCloseTo(0.75, 10)
    expect(outcome.hit).toBe(false)
  })

  it('evicts old grams with sliding window', () => {
    const detector = createSential({ windowSize: 20 })
    const history = ['abcdefghijabcdefghij', 'KLMNOPQRST', 'UVWXYZ1234']
    for (const chunk of history) {
      detector.inspect(chunk)
    }

    const { hit, matchedGrams } = detector.inspect('abcdefghij')

    expect(hit).toBe(false)
    expect(matchedGrams).toBe(0)
  })

  it('aborts only after 10 consecutive hits', () => {
    const guard = createTextLoopGuard({
      ngramSize: 1,
      overlapThreshold: 0.5,
      consecutiveHitsToAbort: 10,
    })

    // The very first chunk only seeds the history.
    const seed = guard.inspect(repeatedChunk)
    expect(seed.hit).toBe(false)
    expect(seed.streak).toBe(0)
    expect(seed.abort).toBe(false)

    for (let expectedStreak = 1; expectedStreak < 10; expectedStreak += 1) {
      const repeated = guard.inspect(repeatedChunk)
      expect(repeated.hit).toBe(true)
      expect(repeated.streak).toBe(expectedStreak)
      expect(repeated.abort).toBe(false)
    }

    const final = guard.inspect(repeatedChunk)
    expect(final.hit).toBe(true)
    expect(final.streak).toBe(10)
    expect(final.abort).toBe(true)
  })

  it('resets streak when a non-hit chunk appears', () => {
    const guard = createTextLoopGuard({
      ngramSize: 1,
      overlapThreshold: 0.5,
      consecutiveHitsToAbort: 10,
    })

    // One seeding call followed by five repeated chunks.
    guard.inspect(repeatedChunk)
    let remaining = 5
    while (remaining > 0) {
      guard.inspect(repeatedChunk)
      remaining -= 1
    }

    const fresh = guard.inspect('bcdefghijk')
    expect(fresh.hit).toBe(false)
    expect(fresh.streak).toBe(0)
    expect(fresh.abort).toBe(false)

    const repeatedAgain = guard.inspect(repeatedChunk)
    expect(repeatedAgain.hit).toBe(true)
    expect(repeatedAgain.streak).toBe(1)
    expect(repeatedAgain.abort).toBe(false)
  })

  it('only updates streak when chunk has enough new grams', () => {
    const guard = createTextLoopGuard({
      ngramSize: 1,
      overlapThreshold: 0.5,
      minNewGramsPerChunk: 5,
    })

    guard.inspect(repeatedChunk)

    // Below the minimum: reported as a hit, but the streak stays untouched.
    const shortHit = guard.inspect('aaaa')
    expect(shortHit.hit).toBe(true)
    expect(shortHit.newGrams).toBe(4)
    expect(shortHit.streak).toBe(0)
    expect(shortHit.abort).toBe(false)

    const longHit = guard.inspect('aaaaa')
    expect(longHit.hit).toBe(true)
    expect(longHit.newGrams).toBe(5)
    expect(longHit.streak).toBe(1)
    expect(longHit.abort).toBe(false)

    // Below the minimum: a miss does not reset the streak either.
    const shortMiss = guard.inspect('cccc')
    expect(shortMiss.hit).toBe(false)
    expect(shortMiss.newGrams).toBe(4)
    expect(shortMiss.streak).toBe(1)

    const longMiss = guard.inspect('ddddd')
    expect(longMiss.hit).toBe(false)
    expect(longMiss.newGrams).toBe(5)
    expect(longMiss.streak).toBe(0)
  })

  it('warns on first tool-loop breach and aborts on second breach', () => {
    const guard = createToolLoopGuard({
      repeatThreshold: 5,
      warningsBeforeAbort: 1,
    })
    const call = webFetchCall({ url: 'https://example.com', requestId: 'r-1' })

    // Calls one to five stay within the threshold.
    const expectQuietRound = () => {
      for (let expectedRepeats = 1; expectedRepeats <= 5; expectedRepeats += 1) {
        const quiet = guard.inspect(call)
        expect(quiet.warn).toBe(false)
        expect(quiet.abort).toBe(false)
        expect(quiet.repeatCount).toBe(expectedRepeats)
      }
    }

    expectQuietRound()

    const warning = guard.inspect(call)
    expect(warning.warn).toBe(true)
    expect(warning.abort).toBe(false)
    expect(warning.breachCount).toBe(1)
    expect(warning.repeatCount).toBe(0)

    expectQuietRound()

    const aborted = guard.inspect(call)
    expect(aborted.warn).toBe(false)
    expect(aborted.abort).toBe(true)
    expect(aborted.breachCount).toBe(2)
  })

  it('resets tool-loop repeat count when hash changes', () => {
    const guard = createToolLoopGuard({
      repeatThreshold: 5,
      warningsBeforeAbort: 1,
    })

    const firstA = guard.inspect(webFetchCall({ url: 'https://a.example.com' }))
    expect(firstA.repeatCount).toBe(1)

    const secondA = guard.inspect(webFetchCall({ url: 'https://a.example.com' }))
    expect(secondA.repeatCount).toBe(2)

    const firstB = guard.inspect(webFetchCall({ url: 'https://b.example.com' }))
    expect(firstB.repeatCount).toBe(1)
    expect(firstB.warn).toBe(false)
    expect(firstB.abort).toBe(false)
  })

  it('resets tool-loop breach count when hash changes', () => {
    const guard = createToolLoopGuard({
      repeatThreshold: 1,
      warningsBeforeAbort: 1,
    })

    // The first fingerprint reaches the warning phase.
    guard.inspect(webFetchCall({ url: 'https://a.example.com' }))
    const warnedA = guard.inspect(webFetchCall({ url: 'https://a.example.com' }))
    expect(warnedA.warn).toBe(true)
    expect(warnedA.breachCount).toBe(1)

    // A different fingerprint restarts the warning/abort phase.
    const firstB = guard.inspect(webFetchCall({ url: 'https://b.example.com' }))
    expect(firstB.warn).toBe(false)
    expect(firstB.abort).toBe(false)
    expect(firstB.breachCount).toBe(0)

    const warnedB = guard.inspect(webFetchCall({ url: 'https://b.example.com' }))
    expect(warnedB.warn).toBe(true)
    expect(warnedB.abort).toBe(false)
    expect(warnedB.breachCount).toBe(1)
  })

  it('ignores volatile keys when computing tool-loop hash', () => {
    const guard = createToolLoopGuard({
      repeatThreshold: 1,
      warningsBeforeAbort: 1,
    })

    const initial = guard.inspect(
      webFetchCall({
        url: 'https://example.com',
        request_id: 'req-1',
        updatedAt: '2026-02-28T00:00:00.000Z',
      }),
    )
    expect(initial.warn).toBe(false)
    expect(initial.abort).toBe(false)

    // Only request_id and updatedAt differ, yet the call counts as a repeat.
    const repeat = guard.inspect(
      webFetchCall({
        url: 'https://example.com',
        request_id: 'req-2',
        updatedAt: '2026-02-28T00:01:00.000Z',
      }),
    )
    expect(repeat.warn).toBe(true)
    expect(repeat.abort).toBe(false)
  })
})
|
||||
@@ -0,0 +1,506 @@
|
||||
import { createHash } from 'node:crypto'
|
||||
|
||||
/** Tuning options for the n-gram overlap detector; every field is optional. */
export interface SentialOptions {
  /** N-gram length. Presumably counted in characters — confirm in `createSential`. */
  ngramSize?: number;
  /** Size of the sliding window of remembered grams (unit not visible here — confirm). */
  windowSize?: number;
  /** Overlap ratio between 0 and 1 that a chunk must exceed to count as a hit. */
  overlapThreshold?: number;
}

/** Detector options plus the streak settings of the text loop guard. */
export interface TextLoopGuardOptions extends SentialOptions {
  /** Streak length at which the guard reports `abort`. */
  consecutiveHitsToAbort?: number;
  /** Minimum number of new grams a chunk needs before it may change the streak. */
  minNewGramsPerChunk?: number;
}

/** Outcome of inspecting one chunk of text with the overlap detector. */
export interface SentialInspectResult {
  /** Whether the chunk counted as a hit. */
  hit: boolean;
  /** Overlap ratio measured for the chunk. */
  overlap: number;
  /** Number of the chunk's grams that matched. */
  matchedGrams: number;
  /** Number of grams contributed by the chunk. */
  newGrams: number;
}

/** Detector outcome extended with the guard's streak state. */
export interface TextLoopGuardInspectResult extends SentialInspectResult {
  /** Current streak counter of the guard. */
  streak: number;
  /** Whether the caller should abort. */
  abort: boolean;
}

/** One tool invocation handed to the tool loop guard. */
export interface ToolLoopInspectInput {
  /** Name of the invoked tool. */
  toolName: string;
  /** Arguments of the invocation; the shape is tool-specific. */
  input: unknown;
}

/** Outcome of inspecting one tool invocation. */
export interface ToolLoopInspectResult {
  /** Fingerprint computed for the invocation. */
  hash: string;
  /** Repeat counter reported for the current fingerprint. */
  repeatCount: number;
  /** Breach counter reported for the current fingerprint. */
  breachCount: number;
  /** Whether a warning should be issued for this invocation. */
  warn: boolean;
  /** Whether the caller should abort. */
  abort: boolean;
}

/** Tuning options for the tool loop guard; every field is optional. */
export interface ToolLoopGuardOptions {
  /** Repeat limit for identical invocations. */
  repeatThreshold?: number;
  /** Number of warnings issued before the guard reports `abort`. */
  warningsBeforeAbort?: number;
  /** Key names treated as volatile. Presumably merged with the defaults — confirm. */
  volatileKeys?: string[];
}

/** N-gram overlap detector. */
export interface Sential {
  /** Inspects one chunk of text and reports its overlap statistics. */
  inspect(text: string): SentialInspectResult;
  /** Resets the detector. Presumably clears remembered grams — confirm. */
  reset(): void;
}

/** Text loop guard: overlap detection plus streak tracking. */
export interface TextLoopGuard {
  /** Inspects one chunk of text and reports overlap and streak state. */
  inspect(text: string): TextLoopGuardInspectResult;
  /** Resets the guard. Presumably clears history and streak — confirm. */
  reset(): void;
}

/** Buffer placed in front of a text loop check on the streaming path. */
export interface TextLoopProbeBuffer {
  /** Adds streamed text to the buffer. */
  push(text: string): void;
  /** Flushes the buffer. Presumably hands any remaining text to the check — confirm. */
  flush(): void;
}

/** Guard that tracks repeated tool invocations. */
export interface ToolLoopGuard {
  /** Inspects one tool invocation and reports repeat, warning and abort state. */
  inspect(input: ToolLoopInspectInput): ToolLoopInspectResult;
  /** Resets the guard. Presumably clears its counters — confirm. */
  reset(): void;
}
|
||||
|
||||
const DEFAULT_NGRAM_SIZE = 10
|
||||
const DEFAULT_WINDOW_SIZE = 1000
|
||||
const DEFAULT_OVERLAP_THRESHOLD = 0.75
|
||||
const DEFAULT_CONSECUTIVE_HITS_TO_ABORT = 10
|
||||
const DEFAULT_MIN_NEW_GRAMS_PER_CHUNK = 1
|
||||
const DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD = 5
|
||||
const DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT = 1
|
||||
const DEFAULT_VOLATILE_KEYS = [
|
||||
'toolCallId',
|
||||
'tool_call_id',
|
||||
'requestId',
|
||||
'request_id',
|
||||
'traceId',
|
||||
'trace_id',
|
||||
'spanId',
|
||||
'span_id',
|
||||
'sessionId',
|
||||
'session_id',
|
||||
'timestamp',
|
||||
'createdAt',
|
||||
'created_at',
|
||||
'updatedAt',
|
||||
'updated_at',
|
||||
'expiresAt',
|
||||
'expires_at',
|
||||
'nonce',
|
||||
]
|
||||
/**
 * Normalised (lower-case, alphanumeric-only) suffixes: any key whose normalised
 * name ends with one of these is treated as volatile.
 */
const VOLATILE_KEY_SUFFIXES: readonly string[] = [
  // Identifier-like suffixes
  'requestid',
  'traceid',
  'sessionid',
  'toolcallid',
  // Time-like suffixes
  'timestamp',
  'createdat',
  'updatedat',
  'expiresat',
]
|
||||
|
||||
/**
 * JSON-serialisable shape produced by `normalizeToolLoopValue`:
 * a primitive, `null`, or arrays/objects built from those recursively.
 */
type NormalizedValue =
  | null
  | string
  | number
  | boolean
  | NormalizedValue[]
  | { [key: string]: NormalizedValue };
|
||||
|
||||
/**
 * Ensures `value` is an integer greater than zero.
 *
 * @param name - Option name used in the error message.
 * @param value - Candidate value.
 * @returns `value` unchanged when it is valid.
 * @throws Error `"<name> must be a positive integer"` for NaN, infinities, fractions, zero and negatives.
 */
function validatePositiveInt(name: string, value: number): number {
  // Number.isInteger already rejects NaN and ±Infinity.
  const isPositiveInteger = Number.isInteger(value) && value > 0
  if (isPositiveInteger) {
    return value
  }
  throw new Error(`${name} must be a positive integer`)
}
|
||||
|
||||
/**
 * Ensures an overlap threshold lies in the closed interval [0, 1].
 *
 * @param value - Candidate threshold.
 * @returns `value` unchanged when it is valid.
 * @throws Error `"overlapThreshold must be between 0 and 1"` for NaN, infinities or out-of-range values.
 */
function validateThreshold(value: number): number {
  const withinUnitInterval = Number.isFinite(value) && value >= 0 && value <= 1
  if (withinUnitInterval) {
    return value
  }
  throw new Error('overlapThreshold must be between 0 and 1')
}
|
||||
|
||||
/**
 * Splits text into an array of Unicode code points after NFC normalisation.
 *
 * @param text - Input text; an empty string yields an empty array.
 * @returns One array element per code point (surrogate pairs stay together).
 */
function normalizeChars(text: string): string[] {
  return text ? [...text.normalize('NFC')] : []
}
|
||||
|
||||
/**
 * Joins `size` consecutive characters, beginning at index `start`, into one string.
 *
 * @param chars - Source characters.
 * @param start - Index of the first character of the gram.
 * @param size - Number of characters to take; fewer are used if the array ends first.
 */
function buildNgram(chars: string[], start: number, size: number): string {
  const end = start + size
  const gramChars = chars.slice(start, end)
  return gramChars.join('')
}
|
||||
|
||||
/**
 * Canonicalises an object key for volatile-key matching: trims it, lower-cases it
 * and removes every character that is not `a`–`z` or `0`–`9`.
 */
function normalizeKeyName(key: string): string {
  const lowered = key.trim().toLowerCase()
  return lowered.replace(/[^a-z0-9]/g, '')
}
|
||||
|
||||
/**
 * Decides whether an input key should be ignored when fingerprinting a tool call.
 *
 * @param key - Raw object key.
 * @param volatileKeySet - Already-normalised key names to ignore outright.
 * @returns True when the normalised key is in the set or ends with a volatile suffix;
 *          false when the key normalises to an empty string.
 */
function isVolatileKey(key: string, volatileKeySet: Set<string>): boolean {
  const normalized = normalizeKeyName(key)
  if (normalized === '') {
    return false
  }
  if (volatileKeySet.has(normalized)) {
    return true
  }
  for (const suffix of VOLATILE_KEY_SUFFIXES) {
    if (normalized.endsWith(suffix)) {
      return true
    }
  }
  return false
}
|
||||
|
||||
/**
 * Type guard for "plain" objects: those whose prototype is `Object.prototype` or `null`.
 * Arrays, dates, maps and class instances are rejected.
 */
function isPlainObject(value: unknown): value is Record<string, unknown> {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  const proto: unknown = Object.getPrototypeOf(value)
  return proto === null || proto === Object.prototype
}
|
||||
|
||||
/**
 * Converts an arbitrary tool input into a canonical, JSON-serialisable value so
 * that equivalent inputs produce the same fingerprint.
 *
 * Rules:
 * - strings are NFC-normalised; booleans and finite numbers pass through;
 * - non-finite numbers and bigints become strings;
 * - `undefined`, functions and symbols yield `undefined` (dropped from objects,
 *   replaced by `null` inside arrays);
 * - dates become ISO strings (`'Invalid Date'` when the date is invalid);
 * - plain objects get sorted keys, with volatile keys removed;
 * - other objects use their `toJSON()` result when available, else `String(value)`;
 * - a container met again while it is still being traversed becomes `'[Circular]'`.
 *
 * @param value - Value to normalise.
 * @param volatileKeySet - Normalised key names to drop from plain objects.
 * @param seen - Containers on the current traversal path; pass a fresh `WeakSet` at the top level.
 * @returns The normalised value, or `undefined` when `value` has no JSON representation.
 */
function normalizeToolLoopValue(
  value: unknown,
  volatileKeySet: Set<string>,
  seen: WeakSet<object>,
): NormalizedValue | undefined {
  if (typeof value === 'string') return value.normalize('NFC')
  if (typeof value === 'boolean') return value
  if (typeof value === 'number') {
    // NaN and ±Infinity are not representable in JSON; keep them distinguishable as text.
    return Number.isFinite(value) ? value : String(value)
  }
  if (typeof value === 'bigint') return value.toString()

  // Of the remaining types only 'object' is representable (undefined/function/symbol are not).
  if (typeof value !== 'object') return undefined
  if (value === null) return null

  if (value instanceof Date) {
    // toISOString() throws a RangeError on an invalid date, so handle that case explicitly.
    return Number.isNaN(value.getTime()) ? 'Invalid Date' : value.toISOString()
  }

  // Cycle guard for every kind of container (arrays and toJSON objects included),
  // otherwise a self-referencing array would recurse until the stack overflows.
  if (seen.has(value)) {
    return '[Circular]'
  }

  if (Array.isArray(value)) {
    seen.add(value)
    const items = value.map(
      (item) => normalizeToolLoopValue(item, volatileKeySet, seen) ?? null,
    )
    seen.delete(value)
    return items
  }

  if (!isPlainObject(value)) {
    const maybeRecord = value as { toJSON?: () => unknown }
    if (typeof maybeRecord.toJSON !== 'function') {
      return String(value)
    }
    seen.add(value)
    const serialized = normalizeToolLoopValue(
      maybeRecord.toJSON(),
      volatileKeySet,
      seen,
    )
    seen.delete(value)
    return serialized
  }

  seen.add(value)

  const normalizedObject: { [key: string]: NormalizedValue } = {}
  // Sorted keys make the result independent of property insertion order.
  const keys = Object.keys(value).sort()
  for (const key of keys) {
    if (isVolatileKey(key, volatileKeySet)) {
      continue
    }
    const normalized = normalizeToolLoopValue(value[key], volatileKeySet, seen)
    if (normalized !== undefined) {
      normalizedObject[key] = normalized
    }
  }

  // Only the current path is tracked, so the same object may legitimately appear twice as a sibling.
  seen.delete(value)
  return normalizedObject
}
|
||||
|
||||
/**
 * Fingerprints a tool call.
 *
 * The fingerprint is the hex SHA-256 of the JSON serialisation of the trimmed tool
 * name plus the normalised input (`null` when the input has no representation).
 *
 * @param input - Tool name and raw input.
 * @param volatileKeySet - Normalised key names excluded from the fingerprint.
 */
function computeToolLoopHash(
  input: ToolLoopInspectInput,
  volatileKeySet: Set<string>,
): string {
  const toolName = input.toolName.trim()
  const normalizedInput = normalizeToolLoopValue(
    input.input,
    volatileKeySet,
    new WeakSet(),
  )
  const serialized = JSON.stringify({
    toolName,
    input: normalizedInput ?? null,
  })

  const hasher = createHash('sha256')
  hasher.update(serialized)
  return hasher.digest('hex')
}
|
||||
|
||||
/**
 * Creates a sliding-window n-gram repetition detector.
 *
 * The detector remembers the last `windowSize` code points it was given together
 * with every n-gram inside that window. `inspect` measures what fraction of a new
 * chunk's n-grams already occur in the window and then appends the chunk.
 *
 * @param options - `ngramSize`, `windowSize` (positive integers) and `overlapThreshold` (0..1);
 *                  each falls back to its module default.
 * @throws Error when an option is invalid or `windowSize` is smaller than `ngramSize`.
 */
export function createSential(options: SentialOptions = {}): Sential {
  const ngramSize = validatePositiveInt(
    'ngramSize',
    options.ngramSize ?? DEFAULT_NGRAM_SIZE,
  )
  const windowSize = validatePositiveInt(
    'windowSize',
    options.windowSize ?? DEFAULT_WINDOW_SIZE,
  )
  const overlapThreshold = validateThreshold(
    options.overlapThreshold ?? DEFAULT_OVERLAP_THRESHOLD,
  )
  if (ngramSize > windowSize) {
    throw new Error('windowSize must be greater than or equal to ngramSize')
  }

  // Characters currently inside the window, oldest first.
  const recentChars: string[] = []
  // N-grams in the order they were formed, so the oldest can be evicted with its first character.
  const recentGrams: string[] = []
  // Occurrence count per n-gram in the window; a gram is "in history" iff it has an entry.
  const gramCounts = new Map<string, number>()

  const rememberGram = (gram: string): void => {
    gramCounts.set(gram, (gramCounts.get(gram) ?? 0) + 1)
  }

  const forgetGram = (gram: string): void => {
    const count = gramCounts.get(gram)
    if (!count) return
    if (count > 1) {
      gramCounts.set(gram, count - 1)
    } else {
      gramCounts.delete(gram)
    }
  }

  const appendChar = (char: string): void => {
    recentChars.push(char)

    // Every character that completes an n-gram contributes that gram to the history.
    if (recentChars.length >= ngramSize) {
      const gram = buildNgram(
        recentChars,
        recentChars.length - ngramSize,
        ngramSize,
      )
      recentGrams.push(gram)
      rememberGram(gram)
    }

    // Once the window overflows, drop the oldest character and the oldest gram together.
    if (recentChars.length > windowSize) {
      recentChars.shift()
      const evicted = recentGrams.shift()
      if (evicted) {
        forgetGram(evicted)
      }
    }
  }

  return {
    inspect(text: string): SentialInspectResult {
      const incoming = normalizeChars(text)
      if (incoming.length === 0) {
        return { hit: false, overlap: 0, matchedGrams: 0, newGrams: 0 }
      }

      // Prefix the chunk with up to ngramSize - 1 trailing window characters so that
      // n-grams spanning the chunk boundary are scored as well.
      const contextSize = Math.max(ngramSize - 1, 0)
      const context = contextSize > 0 ? recentChars.slice(-contextSize) : []
      const candidate = context.concat(incoming)

      // Only grams whose last character lies in the incoming text are scored.
      const firstStart = Math.max(0, context.length - ngramSize + 1)
      const lastStart = candidate.length - ngramSize

      let matchedGrams = 0
      let newGrams = 0
      for (let start = firstStart; start <= lastStart; start += 1) {
        newGrams += 1
        if (gramCounts.has(buildNgram(candidate, start, ngramSize))) {
          matchedGrams += 1
        }
      }

      const overlap = newGrams === 0 ? 0 : matchedGrams / newGrams
      const hit = overlap > overlapThreshold

      // Scoring is done against the history *before* the chunk is added to it.
      incoming.forEach((char) => appendChar(char))

      return { hit, overlap, matchedGrams, newGrams }
    },
    reset(): void {
      recentChars.length = 0
      recentGrams.length = 0
      gramCounts.clear()
    },
  }
}
|
||||
|
||||
/**
 * Wraps a `Sential` with a streak counter and reports `abort` once enough
 * consecutive chunks have been flagged as repetitive.
 *
 * @param options - Sential options plus `consecutiveHitsToAbort` and `minNewGramsPerChunk`
 *                  (positive integers, each falling back to its module default).
 * @throws Error when an option fails validation.
 */
export function createTextLoopGuard(
  options: TextLoopGuardOptions = {},
): TextLoopGuard {
  const consecutiveHitsToAbort = validatePositiveInt(
    'consecutiveHitsToAbort',
    options.consecutiveHitsToAbort ?? DEFAULT_CONSECUTIVE_HITS_TO_ABORT,
  )
  const minNewGramsPerChunk = validatePositiveInt(
    'minNewGramsPerChunk',
    options.minNewGramsPerChunk ?? DEFAULT_MIN_NEW_GRAMS_PER_CHUNK,
  )
  const sential = createSential(options)

  let streak = 0

  return {
    inspect(text: string): TextLoopGuardInspectResult {
      const inspection = sential.inspect(text)

      // Chunks too small to yield enough n-grams neither extend nor break the streak.
      const chunkCounts = inspection.newGrams >= minNewGramsPerChunk
      if (chunkCounts) {
        streak = inspection.hit ? streak + 1 : 0
      }

      const abort = streak >= consecutiveHitsToAbort
      return { ...inspection, streak, abort }
    },
    reset(): void {
      sential.reset()
      streak = 0
    },
  }
}
|
||||
|
||||
/**
 * Creates a buffer that re-chunks streamed text into pieces of exactly
 * `chunkSize` code points and passes each piece to `inspect`.
 *
 * @param chunkSize - Chunk length in code points; must be a positive integer.
 * @param inspect - Callback invoked with every complete chunk, and with the
 *                  non-empty remainder on `flush`.
 * @throws Error when `chunkSize` is not a positive integer.
 */
export function createTextLoopProbeBuffer(
  chunkSize: number,
  inspect: (text: string) => void,
): TextLoopProbeBuffer {
  validatePositiveInt('chunkSize', chunkSize)
  let chars: string[] = []
  // Index of the first character in `chars` that has not been handed to `inspect` yet.
  let offset = 0

  // Drops the already-inspected prefix so the buffer does not grow without bound.
  const compact = () => {
    if (offset > 0) {
      chars = chars.slice(offset)
      offset = 0
    }
  }

  const inspectChunk = (text: string) => {
    if (text.length > 0) {
      inspect(text)
    }
  }

  return {
    push(text: string): void {
      if (!text) return
      // Append element by element: `chars.push(...array)` passes every element as a call
      // argument and throws a RangeError once the text is large enough.
      for (const char of normalizeChars(text)) {
        chars.push(char)
      }

      while (chars.length - offset >= chunkSize) {
        const chunk = chars.slice(offset, offset + chunkSize).join('')
        // Advance before calling out so a throwing inspector cannot cause a chunk to be replayed.
        offset += chunkSize
        inspectChunk(chunk)
      }

      // Prevent unbounded front-gaps after many chunks.
      if (offset >= chunkSize) {
        compact()
      }
    },
    flush(): void {
      if (chars.length - offset > 0) {
        const remainder = chars.slice(offset).join('')
        inspectChunk(remainder)
      }
      chars = []
      offset = 0
    },
  }
}
|
||||
|
||||
/**
 * Creates a guard that detects the same tool being called with the same
 * (normalised) input many times in a row.
 *
 * A "breach" happens when the number of consecutive identical calls exceeds
 * `repeatThreshold`. The first `warningsBeforeAbort` breaches of a fingerprint
 * produce `warn` and restart the consecutive count; later breaches produce `abort`.
 * A call with a different fingerprint restarts the breach count.
 *
 * @param options - See `ToolLoopGuardOptions`; numeric options must be positive integers.
 * @throws Error when a numeric option fails validation.
 */
export function createToolLoopGuard(
  options: ToolLoopGuardOptions = {},
): ToolLoopGuard {
  const repeatThreshold = validatePositiveInt(
    'repeatThreshold',
    options.repeatThreshold ?? DEFAULT_TOOL_LOOP_REPEAT_THRESHOLD,
  )
  const warningsBeforeAbort = validatePositiveInt(
    'warningsBeforeAbort',
    options.warningsBeforeAbort ?? DEFAULT_TOOL_LOOP_WARNINGS_BEFORE_ABORT,
  )

  // Built-in volatile keys plus any caller-supplied ones, all in normalised form.
  const volatileKeySet = new Set(DEFAULT_VOLATILE_KEYS.map(normalizeKeyName))
  const customKeys = options.volatileKeys ?? []
  customKeys.forEach((key) => {
    const normalizedKey = normalizeKeyName(key)
    if (normalizedKey) {
      volatileKeySet.add(normalizedKey)
    }
  })

  let previousHash = ''
  let consecutiveCalls = 0
  let breaches = 0
  let breachedHash = ''

  return {
    inspect(input: ToolLoopInspectInput): ToolLoopInspectResult {
      const hash = computeToolLoopHash(input, volatileKeySet)

      consecutiveCalls = hash === previousHash ? consecutiveCalls + 1 : 1
      previousHash = hash

      // Breach phase is fingerprint-specific: switching tool signature restarts it.
      if (breachedHash !== hash) {
        breachedHash = hash
        breaches = 0
      }

      const overThreshold = consecutiveCalls > repeatThreshold
      const warn = overThreshold && breaches < warningsBeforeAbort
      const abort = overThreshold && !warn

      if (overThreshold) {
        breaches += 1
      }
      if (warn) {
        // A warning restarts the consecutive count, giving the caller another
        // `repeatThreshold` identical calls before the next breach.
        previousHash = ''
        consecutiveCalls = 0
      }

      return {
        hash,
        repeatCount: consecutiveCalls,
        breachCount: breaches,
        warn,
        abort,
      }
    },
    reset(): void {
      previousHash = ''
      consecutiveCalls = 0
      breaches = 0
      breachedHash = ''
    },
  }
}
|
||||
@@ -0,0 +1,93 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import type { ToolSet } from 'ai'
|
||||
import { createToolLoopGuardedTools } from './tool-loop'
|
||||
|
||||
describe('tool loop guarded tools', () => {
  // Every scenario uses the tightest configuration: the 2nd identical call warns,
  // and after the counter restarts, the 4th call overall aborts.
  const guardTools = (
    baseTools: ToolSet,
    onAbortToolCall: (toolCallId: string) => void,
  ) =>
    createToolLoopGuardedTools(baseTools, {
      repeatThreshold: 1,
      warningsBeforeAbort: 1,
      onAbortToolCall,
      warningKey: '__warn',
      warningText: 'loop warning',
    })

  // Tool that echoes its input back under `result`.
  const createEchoTools = () =>
    ({
      echo: {
        execute: async (input: unknown) => ({ result: input }),
      },
    }) as unknown as ToolSet

  const sameInput = () => ({ value: 'same' }) as never
  const callOptions = (toolCallId: string) => ({ toolCallId }) as never

  it('preserves promised async-iterable tool outputs', async () => {
    const onAbortToolCall = vi.fn()
    const streamedChunks = ['chunk-1', 'chunk-2']
    const stream = {
      async *[Symbol.asyncIterator]() {
        yield* streamedChunks
      },
    }
    const baseTools = {
      streamy: {
        execute: async () => stream,
      },
    } as unknown as ToolSet

    const tools = guardTools(baseTools, onAbortToolCall)
    const output = await tools.streamy.execute?.(sameInput(), callOptions('t-stream'))

    // The very same stream object must come back, untouched and still iterable.
    expect(output).toBe(stream)
    const received: string[] = []
    for await (const chunk of output as AsyncIterable<string>) {
      received.push(chunk)
    }
    expect(received).toEqual(streamedChunks)
    expect(onAbortToolCall).not.toHaveBeenCalled()
  })

  it('defers abort to stream layer when onAbortToolCall is provided', async () => {
    const onAbortToolCall = vi.fn()
    const tools = guardTools(createEchoTools(), onAbortToolCall)
    const invoke = () => tools.echo.execute?.(sameInput(), callOptions('t-1'))

    await invoke()
    const warned = await invoke()
    expect(warned).toMatchObject({
      __warn: {
        marker: 'MEMOH_TOOL_LOOP_WARNING',
      },
    })

    await invoke()
    const abortedOutput = await invoke()

    expect(onAbortToolCall).toHaveBeenCalledWith('t-1')
    expect(abortedOutput).toEqual({ result: { value: 'same' } })
  })

  it('reports abort via callback without throwing inside tool execution', async () => {
    const onAbortToolCall = vi.fn()
    const tools = guardTools(createEchoTools(), onAbortToolCall)
    const invoke = () => tools.echo.execute?.(sameInput(), callOptions('t-2'))

    await invoke()
    await invoke()
    await invoke()
    const abortedOutput = await invoke()

    expect(onAbortToolCall).toHaveBeenCalledWith('t-2')
    expect(abortedOutput).toEqual({ result: { value: 'same' } })
  })
})
|
||||
@@ -0,0 +1,122 @@
|
||||
import { ToolExecutionOptions, ToolSet } from 'ai'
|
||||
import { createToolLoopGuard, type ToolLoopInspectResult } from './sential'
|
||||
|
||||
/**
 * Configuration for `createToolLoopGuardedTools`.
 */
export interface CreateToolLoopGuardedToolsOptions {
  /** Forwarded to the tool loop guard: consecutive identical calls tolerated before a breach. */
  repeatThreshold: number
  /** Forwarded to the tool loop guard: breaches that only warn before a breach aborts. */
  warningsBeforeAbort: number
  /** Invoked with the offending call's `toolCallId` when the guard decides to abort. */
  onAbortToolCall: (toolCallId: string) => void
  /** Property name under which the warning payload is attached to a tool result. */
  warningKey: string
  /** Human-readable message placed in the warning payload. */
  warningText: string
}
|
||||
|
||||
/**
 * Type guard for plain key/value objects (prototype `Object.prototype` or `null`).
 *
 * Arrays and non-plain objects such as `Date`, `Map` or typed arrays are rejected:
 * object-spreading those copies only own enumerable properties, so merging a warning
 * into them would silently discard or mangle the actual tool result.
 */
const isRecord = (value: unknown): value is Record<string, unknown> => {
  if (value === null || typeof value !== 'object' || Array.isArray(value)) {
    return false
  }
  const prototype: unknown = Object.getPrototypeOf(value)
  return prototype === Object.prototype || prototype === null
}
|
||||
|
||||
/**
 * Type guard for objects exposing `Symbol.asyncIterator` (own or inherited).
 * Primitives and `null` are never async-iterable.
 */
const isAsyncIterable = (value: unknown): value is AsyncIterable<unknown> => {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  return Symbol.asyncIterator in value
}
|
||||
|
||||
/**
 * Attaches a structured loop warning to a tool result.
 *
 * @param output - Resolved tool output.
 * @param inspectResult - Guard verdict that supplies the fingerprint and breach count.
 * @param warningKey - Property name that receives the warning payload.
 * @param warningText - Message stored in the payload.
 * @returns When `output` satisfies `isRecord`, a shallow copy of it with the warning
 *          added under `warningKey`; otherwise a new object holding the warning and
 *          the original output under `result`.
 */
const injectToolLoopWarning = (
  output: unknown,
  inspectResult: ToolLoopInspectResult,
  warningKey: string,
  warningText: string,
): unknown => {
  // Keep warning payload structured so UI/consumers can detect and render it.
  const { hash: fingerprint, breachCount } = inspectResult
  const warningPayload = {
    marker: 'MEMOH_TOOL_LOOP_WARNING',
    message: warningText,
    fingerprint,
    breachCount,
  }

  return isRecord(output)
    ? { ...output, [warningKey]: warningPayload }
    : { [warningKey]: warningPayload, result: output }
}
|
||||
|
||||
/**
 * Returns a copy of `tools` in which every executable tool reports to a shared
 * tool loop guard after it has run.
 *
 * - Tools without an `execute` function are returned as-is.
 * - Async-iterable (streamed) outputs, whether returned directly or via a promise,
 *   are passed through untouched and are not inspected.
 * - On a `warn` verdict the warning payload is injected into the tool result.
 * - On an `abort` verdict `onAbortToolCall` is called with the call's `toolCallId`
 *   and the tool result is returned unchanged; nothing is thrown here.
 *
 * @param tools - Tool set to wrap.
 * @param guardOptions - Guard thresholds, abort callback and warning settings.
 */
export function createToolLoopGuardedTools(
  tools: ToolSet,
  guardOptions: CreateToolLoopGuardedToolsOptions,
): ToolSet {
  const {
    repeatThreshold,
    warningsBeforeAbort,
    onAbortToolCall,
    warningKey,
    warningText,
  } = guardOptions

  // One guard instance is shared by all tools so repeats are tracked across the whole set.
  const guard = createToolLoopGuard({ repeatThreshold, warningsBeforeAbort })

  // Waits for the tool result, then consults the guard for this (toolName + input) call.
  const inspectAfterExecution = async (
    toolName: string,
    toolInput: unknown,
    executionOptions: ToolExecutionOptions,
    pendingOutput: unknown,
  ): Promise<unknown> => {
    const resolvedOutput = await pendingOutput

    // Tools may return Promise<AsyncIterable>; keep that stream untouched too.
    if (isAsyncIterable(resolvedOutput)) {
      return resolvedOutput
    }

    const verdict = guard.inspect({ toolName, input: toolInput })
    if (verdict.abort) {
      // Report loop abort to generation layer; it decides when/how to stop.
      onAbortToolCall(executionOptions.toolCallId)
      return resolvedOutput
    }
    if (verdict.warn) {
      return injectToolLoopWarning(resolvedOutput, verdict, warningKey, warningText)
    }
    return resolvedOutput
  }

  const guardedEntries = Object.entries(tools).map(([toolName, toolDefinition]) => {
    const execute = toolDefinition.execute
    if (typeof execute !== 'function') {
      return [toolName, toolDefinition]
    }

    const wrappedTool = {
      ...toolDefinition,
      execute: (toolInput: unknown, executionOptions: ToolExecutionOptions) => {
        const directOutput = execute(
          toolInput as never,
          executionOptions as never,
        ) as unknown

        // Streamed tool outputs are passed through unchanged to preserve streaming semantics.
        if (isAsyncIterable(directOutput)) {
          return directOutput as never
        }

        return inspectAfterExecution(
          toolName,
          toolInput,
          executionOptions,
          directOutput,
        ) as never
      },
    }

    return [toolName, wrappedTool]
  })

  return Object.fromEntries(guardedEntries) as ToolSet
}
|
||||
@@ -38,6 +38,10 @@ export interface InboxItem {
|
||||
createdAt: string
|
||||
}
|
||||
|
||||
/**
 * Per-agent loop detection settings.
 */
export interface LoopDetectionConfig {
  /** Whether loop detection is switched on for this agent. */
  enabled?: boolean
}
|
||||
|
||||
export interface AgentParams {
|
||||
model: ModelConfig
|
||||
language?: string
|
||||
@@ -50,6 +54,7 @@ export interface AgentParams {
|
||||
auth: AgentAuthContext
|
||||
skills?: AgentSkill[]
|
||||
inbox?: InboxItem[]
|
||||
loopDetection?: LoopDetectionConfig
|
||||
}
|
||||
|
||||
export interface AgentInput {
|
||||
|
||||
@@ -487,6 +487,7 @@
|
||||
"heartbeatModelPlaceholder": "Use chat model (default)",
|
||||
"allowGuest": "Allow Guest Access",
|
||||
"allowGuestPersonalHint": "Personal bots do not support guest access. Use a public bot instead.",
|
||||
"loopDetectionTitle": "Detect and auto-block output loops",
|
||||
"searchModel": "Search models…",
|
||||
"noModel": "No models available",
|
||||
"saveSuccess": "Settings saved",
|
||||
|
||||
@@ -483,6 +483,7 @@
|
||||
"heartbeatModelPlaceholder": "使用聊天模型(默认)",
|
||||
"allowGuest": "允许游客访问",
|
||||
"allowGuestPersonalHint": "个人 Bot 不支持游客访问,请使用公开 Bot。",
|
||||
"loopDetectionTitle": "自动检测并阻止模型循环输出",
|
||||
"searchModel": "搜索模型…",
|
||||
"noModel": "暂无可选模型",
|
||||
"saveSuccess": "设置已保存",
|
||||
|
||||
@@ -134,6 +134,17 @@
|
||||
<Separator />
|
||||
</template>
|
||||
|
||||
|
||||
<!-- Loop Detection -->
|
||||
<div class="flex items-center justify-between">
|
||||
<Label>{{ $t('bots.settings.loopDetectionTitle') }}</Label>
|
||||
<Switch
|
||||
:model-value="loopDetectionEnabled"
|
||||
:disabled="isLoopDetectionToggleDisabled"
|
||||
@update:model-value="handleLoopDetectionToggle"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Save -->
|
||||
<div class="flex justify-end">
|
||||
<Button
|
||||
@@ -195,7 +206,7 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@memoh/ui'
|
||||
import { reactive, computed, watch } from 'vue'
|
||||
import { reactive, computed, watch, ref } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { toast } from 'vue-sonner'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
@@ -203,7 +214,7 @@ import ConfirmPopover from '@/components/confirm-popover/index.vue'
|
||||
import ModelSelect from './model-select.vue'
|
||||
import SearchProviderSelect from './search-provider-select.vue'
|
||||
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
|
||||
import { getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders } from '@memoh/sdk'
|
||||
import { getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getBotsById, putBotsById } from '@memoh/sdk'
|
||||
import type { SettingsSettings } from '@memoh/sdk'
|
||||
import type { Ref } from 'vue'
|
||||
import { resolveApiErrorMessage } from '@/utils/api-error'
|
||||
@@ -232,6 +243,15 @@ const { data: settings } = useQuery({
|
||||
enabled: () => !!botIdRef.value,
|
||||
})
|
||||
|
||||
// Bot profile query; its `metadata` is where the loop-detection flag is read from.
// Disabled until a bot id is available.
const { data: botProfile, isLoading: isBotProfileLoading } = useQuery({
  key: () => ['bot', botIdRef.value],
  enabled: () => Boolean(botIdRef.value),
  query: async () => {
    const response = await getBotsById({
      path: { id: botIdRef.value },
      throwOnError: true,
    })
    return response.data
  },
})
|
||||
|
||||
const { data: modelData } = useQuery({
|
||||
key: ['all-models'],
|
||||
query: async () => {
|
||||
@@ -278,6 +298,21 @@ const { mutateAsync: deleteBot, isLoading: deleteLoading } = useMutation({
|
||||
},
|
||||
})
|
||||
|
||||
// Mutation that writes the bot's metadata. Whether it succeeds or fails, the cached
// bot list and this bot's profile are invalidated afterwards.
const { mutateAsync: updateBotProfile, isLoading: updateBotProfileLoading } = useMutation({
  mutation: async (metadata: Record<string, unknown>) => {
    const response = await putBotsById({
      path: { id: botIdRef.value },
      body: { metadata },
      throwOnError: true,
    })
    return response.data
  },
  onSettled: () => {
    const staleKeys = [['bots'], ['bot', botIdRef.value]]
    for (const key of staleKeys) {
      queryCache.invalidateQueries({ key })
    }
  },
})
|
||||
|
||||
const models = computed(() => modelData.value ?? [])
|
||||
const providers = computed(() => providerData.value ?? [])
|
||||
const searchProviders = computed(() => searchProviderData.value ?? [])
|
||||
@@ -302,6 +337,23 @@ const form = reactive({
|
||||
reasoning_effort: 'medium',
|
||||
})
|
||||
|
||||
// Current state of the loop-detection switch shown in the UI.
const loopDetectionEnabled = ref(false)
// Last known bot metadata; stays `null` until the bot profile has been loaded.
const loopDetectionMetadata = ref<Record<string, unknown> | null>(null)
|
||||
|
||||
/**
 * Returns `value` itself when it is a non-array object, otherwise a fresh empty object.
 * Lets callers read nested properties of untyped metadata without null checks.
 */
const asRecord = (value: unknown): Record<string, unknown> => {
  const isObjectLike = Boolean(value) && typeof value === 'object' && !Array.isArray(value)
  return isObjectLike ? (value as Record<string, unknown>) : {}
}
|
||||
|
||||
/**
 * Reads `metadata.features.loop_detection.enabled`.
 * Returns true only when that value is exactly `true`; any missing or malformed level yields false.
 */
const readLoopDetectionEnabled = (metadata: unknown): boolean => {
  const { features } = asRecord(metadata)
  const { loop_detection: loopDetection } = asRecord(features)
  return asRecord(loopDetection).enabled === true
}
|
||||
|
||||
watch(settings, (val) => {
|
||||
if (val) {
|
||||
form.chat_model_id = val.chat_model_id ?? ''
|
||||
@@ -317,6 +369,13 @@ watch(settings, (val) => {
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
// Mirror the loaded bot profile into the local loop-detection state.
// `undefined` means the profile has not been fetched yet, so nothing is touched.
watch(botProfile, (profile) => {
  if (profile !== undefined) {
    const metadata = asRecord(profile?.metadata)
    loopDetectionMetadata.value = metadata
    loopDetectionEnabled.value = readLoopDetectionEnabled(metadata)
  }
}, { immediate: true })
|
||||
|
||||
const hasChanges = computed(() => {
|
||||
if (!settings.value) return true
|
||||
const s = settings.value
|
||||
@@ -336,6 +395,40 @@ const hasChanges = computed(() => {
|
||||
return changed
|
||||
})
|
||||
|
||||
// The switch is locked while a metadata update is in flight, while the bot profile
// is loading, or before any metadata has been received.
const isLoopDetectionToggleDisabled = computed(() => {
  if (updateBotProfileLoading.value) return true
  if (isBotProfileLoading.value) return true
  return loopDetectionMetadata.value === null
})
|
||||
|
||||
/**
 * Handles the loop-detection switch.
 *
 * Updates the switch optimistically, then saves the bot metadata with
 * `features.loop_detection.enabled` set to the new value while keeping every other
 * metadata field. If saving fails, the switch is reverted and an error toast is shown.
 *
 * @param value - New switch state emitted by the UI.
 */
async function handleLoopDetectionToggle(value: boolean) {
  if (isLoopDetectionToggleDisabled.value) return
  const currentMetadata = loopDetectionMetadata.value
  if (currentMetadata === null) return

  const previous = loopDetectionEnabled.value
  const requested = value === true
  if (requested === previous) return

  // Optimistic update; rolled back in the catch branch below.
  loopDetectionEnabled.value = requested

  const features = asRecord(currentMetadata.features)
  const loopDetection = asRecord(features.loop_detection)
  const nextMetadata = {
    ...currentMetadata,
    features: {
      ...features,
      loop_detection: { ...loopDetection, enabled: requested },
    },
  }

  try {
    await updateBotProfile(nextMetadata)
    loopDetectionMetadata.value = nextMetadata
  } catch (error) {
    loopDetectionEnabled.value = previous
    toast.error(resolveApiErrorMessage(error, t('common.saveFailed')))
  }
}
|
||||
|
||||
async function handleSave() {
|
||||
try {
|
||||
await updateSettings({ ...form })
|
||||
|
||||
Reference in New Issue
Block a user