fix(agent): stop generate loop aborts correctly

This commit is contained in:
Fodesu
2026-04-14 00:38:26 +08:00
committed by 晨苒
parent 59147b255d
commit 33461d7ac1
3 changed files with 235 additions and 5 deletions
+66 -5
View File
@@ -435,6 +435,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
} }
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) { func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
genCtx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
loopAbort := newLoopAbortState()
// Collecting emitter: tools push side-effect events here during generation. // Collecting emitter: tools push side-effect events here during generation.
var collected []tools.ToolStreamEvent var collected []tools.ToolStreamEvent
var collectedMu sync.Mutex var collectedMu sync.Mutex
@@ -447,7 +451,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
var sdkTools []sdk.Tool var sdkTools []sdk.Tool
if cfg.SupportsToolCall { if cfg.SupportsToolCall {
var err error var err error
sdkTools, err = a.assembleTools(ctx, cfg, collectEmitter) sdkTools, err = a.assembleTools(genCtx, cfg, collectEmitter)
if err != nil { if err != nil {
return nil, fmt.Errorf("assemble tools: %w", err) return nil, fmt.Errorf("assemble tools: %w", err)
} }
@@ -491,12 +495,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams { sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
if cfg.LoopDetection.Enabled { if cfg.LoopDetection.Enabled {
if toolLoopAbortCallIDs.Any() { if toolLoopAbortCallIDs.Any() {
return nil // stop loopAbort.Set(ErrToolLoopDetected)
cancel(ErrToolLoopDetected)
return nil
} }
if textLoopGuard != nil && isNonEmptyString(step.Text) { if textLoopGuard != nil && isNonEmptyString(step.Text) {
result := textLoopGuard.Inspect(step.Text) result := textLoopGuard.Inspect(step.Text)
if result.Abort { if result.Abort {
return nil // stop loopAbort.Set(ErrTextLoopDetected)
cancel(ErrTextLoopDetected)
return nil
} }
} }
} }
@@ -504,10 +512,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
}), }),
) )
genResult, err := a.client.GenerateTextResult(ctx, opts...) genResult, err := a.client.GenerateTextResult(genCtx, opts...)
if err != nil { if err != nil {
if loopErr := detectGenerateLoopAbort(genCtx, err); loopErr != nil {
return nil, loopErr
}
return nil, fmt.Errorf("generate: %w", err) return nil, fmt.Errorf("generate: %w", err)
} }
if loopErr := loopAbort.Err(); loopErr != nil {
return nil, loopErr
}
// Drain collected tool-emitted side effects into the result. // Drain collected tool-emitted side effects into the result.
var attachments []FileAttachment var attachments []FileAttachment
@@ -717,7 +731,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
"type": "text", "type": "text",
"text": ToolLoopDetectedAbortMessage, "text": ToolLoopDetectedAbortMessage,
}}, }},
}, errors.New(ToolLoopDetectedAbortMessage) }, ErrToolLoopDetected
} }
if warn { if warn {
return map[string]any{ return map[string]any{
@@ -979,3 +993,50 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
return ctx.Err() return ctx.Err()
} }
} }
func detectGenerateLoopAbort(ctx context.Context, err error) error {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
return nil
}
cause := context.Cause(ctx)
switch {
case errors.Is(cause, ErrToolLoopDetected):
return ErrToolLoopDetected
case errors.Is(cause, ErrTextLoopDetected):
return ErrTextLoopDetected
default:
return nil
}
}
type loopAbortState struct {
mu sync.Mutex
err error
}
func newLoopAbortState() *loopAbortState {
return &loopAbortState{}
}
func (s *loopAbortState) Set(err error) {
if s == nil || err == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if s.err == nil {
s.err = err
}
}
func (s *loopAbortState) Err() error {
if s == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
return s.err
}
+163
View File
@@ -0,0 +1,163 @@
package agent
import (
"context"
"errors"
"testing"
"github.com/google/jsonschema-go/jsonschema"
sdk "github.com/memohai/twilight-ai/sdk"
agenttools "github.com/memohai/memoh/internal/agent/tools"
)
type staticToolProvider struct {
tools []sdk.Tool
}
func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([]sdk.Tool, error) {
return p.tools, nil
}
func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
t.Parallel()
modelProvider := &agentReadMediaMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-same",
ToolName: "loop_tool",
Input: map[string]any{"query": "same"},
}},
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "loop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
})
if !errors.Is(err, ErrToolLoopDetected) {
t.Fatalf("expected ErrToolLoopDetected, got %v", err)
}
if modelProvider.calls >= 20 {
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls)
}
}
func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
t.Parallel()
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
modelProvider := &agentReadMediaMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
return &sdk.GenerateResult{
Text: repeatedText,
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-text",
ToolName: "noop_tool",
Input: map[string]any{"step": call},
}},
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "noop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop text")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
})
if !errors.Is(err, ErrTextLoopDetected) {
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
}
if modelProvider.calls >= 10 {
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls)
}
}
func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
t.Parallel()
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
modelProvider := &agentReadMediaMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
finishReason := sdk.FinishReasonToolCalls
var toolCalls []sdk.ToolCall
if call < 4 {
toolCalls = []sdk.ToolCall{{
ToolCallID: "call-terminal",
ToolName: "noop_tool",
Input: map[string]any{"step": call},
}}
} else {
finishReason = sdk.FinishReasonStop
}
return &sdk.GenerateResult{
Text: repeatedText,
FinishReason: finishReason,
ToolCalls: toolCalls,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "noop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop text terminal")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
})
if !errors.Is(err, ErrTextLoopDetected) {
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
}
if modelProvider.calls != 4 {
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls)
}
}
+6
View File
@@ -4,6 +4,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
@@ -31,6 +32,11 @@ const (
defaultMinNewGramsPerChunk = 1 defaultMinNewGramsPerChunk = 1
) )
var (
ErrTextLoopDetected = errors.New(LoopDetectedAbortMessage)
ErrToolLoopDetected = errors.New(ToolLoopDetectedAbortMessage)
)
// --- Sential: n-gram overlap detector --- // --- Sential: n-gram overlap detector ---
// SentialOptions configures the n-gram overlap detector. // SentialOptions configures the n-gram overlap detector.