Reviewer note: this patch was re-flowed from a whitespace-collapsed copy.
Tabs and blank context lines were reconstructed following gofmt, so context
whitespace may not match the target tree byte-for-byte (try
`git apply --ignore-whitespace` if a hunk is rejected). The `index` lines are
kept from the original patch; their post-image hashes no longer match because
`+` lines were edited (doc comments added, loop-abort check reordered).

diff --git a/internal/agent/agent.go b/internal/agent/agent.go
index cd570255..5bd888af 100644
--- a/internal/agent/agent.go
+++ b/internal/agent/agent.go
@@ -435,6 +435,12 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
 }
 
 func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
+	// genCtx is cancelled with a loop sentinel as its cause when loop
+	// detection trips; loopAbort keeps the same reason for the no-error path.
+	genCtx, cancel := context.WithCancelCause(ctx)
+	defer cancel(nil)
+	loopAbort := newLoopAbortState()
+
 	// Collecting emitter: tools push side-effect events here during generation.
 	var collected []tools.ToolStreamEvent
 	var collectedMu sync.Mutex
@@ -447,7 +453,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
 	var sdkTools []sdk.Tool
 	if cfg.SupportsToolCall {
 		var err error
-		sdkTools, err = a.assembleTools(ctx, cfg, collectEmitter)
+		sdkTools, err = a.assembleTools(genCtx, cfg, collectEmitter)
 		if err != nil {
 			return nil, fmt.Errorf("assemble tools: %w", err)
 		}
@@ -491,12 +497,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
 		sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
 			if cfg.LoopDetection.Enabled {
 				if toolLoopAbortCallIDs.Any() {
-					return nil // stop
+					loopAbort.Set(ErrToolLoopDetected)
+					cancel(ErrToolLoopDetected)
+					return nil
 				}
 				if textLoopGuard != nil && isNonEmptyString(step.Text) {
 					result := textLoopGuard.Inspect(step.Text)
 					if result.Abort {
-						return nil // stop
+						loopAbort.Set(ErrTextLoopDetected)
+						cancel(ErrTextLoopDetected)
+						return nil
 					}
 				}
 			}
@@ -504,10 +514,19 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
 		}),
 	)
 
-	genResult, err := a.client.GenerateTextResult(ctx, opts...)
+	genResult, err := a.client.GenerateTextResult(genCtx, opts...)
+	// A loop abort recorded by the step callback takes precedence over the
+	// SDK outcome, so callers get the sentinel even when the SDK returns a
+	// result or an error that does not wrap the context cancellation.
+	if loopErr := loopAbort.Err(); loopErr != nil {
+		return nil, loopErr
+	}
 	if err != nil {
+		if loopErr := detectGenerateLoopAbort(genCtx, err); loopErr != nil {
+			return nil, loopErr
+		}
 		return nil, fmt.Errorf("generate: %w", err)
 	}
 
 	// Drain collected tool-emitted side effects into the result.
 	var attachments []FileAttachment
@@ -717,7 +736,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
 						"type": "text",
 						"text": ToolLoopDetectedAbortMessage,
 					}},
-				}, errors.New(ToolLoopDetectedAbortMessage)
+				}, ErrToolLoopDetected
 			}
 			if warn {
 				return map[string]any{
@@ -979,3 +998,61 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
 		return ctx.Err()
 	}
 }
+
+// detectGenerateLoopAbort translates a context cancellation or deadline error
+// into the loop sentinel recorded as ctx's cancel cause. It returns nil when
+// err is not such an error or when the cause is neither ErrToolLoopDetected
+// nor ErrTextLoopDetected.
+func detectGenerateLoopAbort(ctx context.Context, err error) error {
+	if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
+		return nil
+	}
+
+	cause := context.Cause(ctx)
+	switch {
+	case errors.Is(cause, ErrToolLoopDetected):
+		return ErrToolLoopDetected
+	case errors.Is(cause, ErrTextLoopDetected):
+		return ErrTextLoopDetected
+	default:
+		return nil
+	}
+}
+
+// loopAbortState holds the first loop-abort error reported during a run.
+// Access to the stored error is guarded by mu.
+type loopAbortState struct {
+	mu  sync.Mutex
+	err error
+}
+
+// newLoopAbortState returns an empty loopAbortState.
+func newLoopAbortState() *loopAbortState {
+	return &loopAbortState{}
+}
+
+// Set stores err as the abort reason unless one is already recorded.
+// Calls with a nil receiver or a nil err are ignored.
+func (s *loopAbortState) Set(err error) {
+	if s == nil || err == nil {
+		return
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.err == nil {
+		s.err = err
+	}
+}
+
+// Err returns the recorded abort reason, or nil if none was set or the
+// receiver is nil.
+func (s *loopAbortState) Err() error {
+	if s == nil {
+		return nil
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.err
+}
diff --git a/internal/agent/generate_loop_test.go b/internal/agent/generate_loop_test.go
new file mode 100644
index 00000000..3afa8804
--- /dev/null
+++ b/internal/agent/generate_loop_test.go
@@ -0,0 +1,173 @@
+package agent
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/google/jsonschema-go/jsonschema"
+	sdk "github.com/memohai/twilight-ai/sdk"
+
+	agenttools "github.com/memohai/memoh/internal/agent/tools"
+)
+
+// staticToolProvider is a test double that serves a fixed set of tools
+// regardless of context or session.
+type staticToolProvider struct {
+	tools []sdk.Tool
+}
+
+// Tools returns the configured tools and never fails.
+func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([]sdk.Tool, error) {
+	return p.tools, nil
+}
+
+// TestAgentGenerateStopsOnToolLoopAbort checks that Generate returns
+// ErrToolLoopDetected when the model keeps issuing the same tool call.
+func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
+	t.Parallel()
+
+	modelProvider := &agentReadMediaMockProvider{
+		handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
+			return &sdk.GenerateResult{
+				FinishReason: sdk.FinishReasonToolCalls,
+				ToolCalls: []sdk.ToolCall{{
+					ToolCallID: "call-same",
+					ToolName:   "loop_tool",
+					Input:      map[string]any{"query": "same"},
+				}},
+			}, nil
+		},
+	}
+
+	a := New(Deps{})
+	a.SetToolProviders([]agenttools.ToolProvider{
+		staticToolProvider{
+			tools: []sdk.Tool{{
+				Name:       "loop_tool",
+				Parameters: &jsonschema.Schema{Type: "object"},
+				Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
+					return map[string]any{"ok": true}, nil
+				},
+			}},
+		},
+	})
+
+	_, err := a.Generate(context.Background(), RunConfig{
+		Model:            &sdk.Model{ID: "mock-model", Provider: modelProvider},
+		Messages:         []sdk.Message{sdk.UserMessage("loop")},
+		SupportsToolCall: true,
+		Identity:         SessionContext{BotID: "bot-1"},
+		LoopDetection:    LoopDetectionConfig{Enabled: true},
+	})
+	if !errors.Is(err, ErrToolLoopDetected) {
+		t.Fatalf("expected ErrToolLoopDetected, got %v", err)
+	}
+	if modelProvider.calls >= 20 {
+		t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls)
+	}
+}
+
+// TestAgentGenerateStopsOnTextLoopAbort checks that Generate returns
+// ErrTextLoopDetected when every step repeats the same text.
+func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
+	t.Parallel()
+
+	repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
+	modelProvider := &agentReadMediaMockProvider{
+		handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
+			return &sdk.GenerateResult{
+				Text:         repeatedText,
+				FinishReason: sdk.FinishReasonToolCalls,
+				ToolCalls: []sdk.ToolCall{{
+					ToolCallID: "call-text",
+					ToolName:   "noop_tool",
+					Input:      map[string]any{"step": call},
+				}},
+			}, nil
+		},
+	}
+
+	a := New(Deps{})
+	a.SetToolProviders([]agenttools.ToolProvider{
+		staticToolProvider{
+			tools: []sdk.Tool{{
+				Name:       "noop_tool",
+				Parameters: &jsonschema.Schema{Type: "object"},
+				Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
+					return map[string]any{"ok": true}, nil
+				},
+			}},
+		},
+	})
+
+	_, err := a.Generate(context.Background(), RunConfig{
+		Model:            &sdk.Model{ID: "mock-model", Provider: modelProvider},
+		Messages:         []sdk.Message{sdk.UserMessage("loop text")},
+		SupportsToolCall: true,
+		Identity:         SessionContext{BotID: "bot-1"},
+		LoopDetection:    LoopDetectionConfig{Enabled: true},
+	})
+	if !errors.Is(err, ErrTextLoopDetected) {
+		t.Fatalf("expected ErrTextLoopDetected, got %v", err)
+	}
+	if modelProvider.calls >= 10 {
+		t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls)
+	}
+}
+
+// TestAgentGenerateStopsOnTerminalTextLoopAbort checks that a text loop
+// detected on the final (non-tool-call) step still surfaces
+// ErrTextLoopDetected.
+func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
+	t.Parallel()
+
+	repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
+	modelProvider := &agentReadMediaMockProvider{
+		handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
+			finishReason := sdk.FinishReasonToolCalls
+			var toolCalls []sdk.ToolCall
+			if call < 4 {
+				toolCalls = []sdk.ToolCall{{
+					ToolCallID: "call-terminal",
+					ToolName:   "noop_tool",
+					Input:      map[string]any{"step": call},
+				}}
+			} else {
+				finishReason = sdk.FinishReasonStop
+			}
+			return &sdk.GenerateResult{
+				Text:         repeatedText,
+				FinishReason: finishReason,
+				ToolCalls:    toolCalls,
+			}, nil
+		},
+	}
+
+	a := New(Deps{})
+	a.SetToolProviders([]agenttools.ToolProvider{
+		staticToolProvider{
+			tools: []sdk.Tool{{
+				Name:       "noop_tool",
+				Parameters: &jsonschema.Schema{Type: "object"},
+				Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
+					return map[string]any{"ok": true}, nil
+				},
+			}},
+		},
+	})
+
+	_, err := a.Generate(context.Background(), RunConfig{
+		Model:            &sdk.Model{ID: "mock-model", Provider: modelProvider},
+		Messages:         []sdk.Message{sdk.UserMessage("loop text terminal")},
+		SupportsToolCall: true,
+		Identity:         SessionContext{BotID: "bot-1"},
+		LoopDetection:    LoopDetectionConfig{Enabled: true},
+	})
+	if !errors.Is(err, ErrTextLoopDetected) {
+		t.Fatalf("expected ErrTextLoopDetected, got %v", err)
+	}
+	if modelProvider.calls != 4 {
+		t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls)
+	}
+}
diff --git a/internal/agent/sential.go b/internal/agent/sential.go
index 23ec52f7..8ac8c408 100644
--- a/internal/agent/sential.go
+++ b/internal/agent/sential.go
@@ -4,6 +4,7 @@ import (
 	"crypto/sha256"
 	"encoding/hex"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"sort"
 	"strings"
@@ -31,6 +32,15 @@ const (
 	defaultMinNewGramsPerChunk = 1
 )
 
+// Sentinel errors reported when loop detection aborts a run. Match them with
+// errors.Is.
+var (
+	// ErrTextLoopDetected is reported when the text loop guard aborts a run.
+	ErrTextLoopDetected = errors.New(LoopDetectedAbortMessage)
+	// ErrToolLoopDetected is reported when the tool loop guard aborts a run.
+	ErrToolLoopDetected = errors.New(ToolLoopDetectedAbortMessage)
+)
+
 // --- Sential: n-gram overlap detector ---
 
 // SentialOptions configures the n-gram overlap detector.