diff --git a/internal/agent/agent.go b/internal/agent/agent.go index efdc4808..cd570255 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "strings" + "sync" "time" sdk "github.com/memohai/twilight-ai/sdk" @@ -99,7 +100,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv var textLoopGuard *TextLoopGuard var textLoopProbeBuffer *TextLoopProbeBuffer var toolLoopGuard *ToolLoopGuard - toolLoopAbortCallIDs := make(map[string]struct{}) + toolLoopAbortCallIDs := newToolAbortRegistry() if cfg.LoopDetection.Enabled { textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{}) textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) { @@ -314,11 +315,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } case *sdk.StreamToolResultPart: - shouldAbort := false - if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok { - delete(toolLoopAbortCallIDs, p.ToolCallID) - shouldAbort = true - } + shouldAbort := toolLoopAbortCallIDs.Take(p.ToolCallID) stepNumber++ if !sendEvent(ctx, ch, StreamEvent{ Type: EventToolCallEnd, @@ -440,7 +437,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) { // Collecting emitter: tools push side-effect events here during generation. var collected []tools.ToolStreamEvent + var collectedMu sync.Mutex collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) { + collectedMu.Lock() + defer collectedMu.Unlock() collected = append(collected, evt) }) @@ -456,7 +456,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult var toolLoopGuard *ToolLoopGuard var textLoopGuard *TextLoopGuard - toolLoopAbortCallIDs := make(map[string]struct{}) + toolLoopAbortCallIDs := newToolAbortRegistry() if cfg.LoopDetection.Enabled { toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort) textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{}) @@ -490,7 +490,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult opts = append(opts, sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams { if cfg.LoopDetection.Enabled { - if len(toolLoopAbortCallIDs) > 0 { + if toolLoopAbortCallIDs.Any() { return nil // stop } if textLoopGuard != nil && isNonEmptyString(step.Text) { @@ -701,7 +701,7 @@ func drainBackgroundNotifications( return p } -func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool { +func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs *toolAbortRegistry) []sdk.Tool { wrapped := make([]sdk.Tool, len(tools)) for i, tool := range tools { originalExecute := tool.Execute @@ -710,7 +710,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) { warn, abort := guard.Guard(toolName, input) if abort { - abortCallIDs[ctx.ToolCallID] = struct{}{} + abortCallIDs.Add(ctx.ToolCallID) return map[string]any{ "isError": true, "content": []map[string]any{{ diff --git a/internal/agent/guard_state.go b/internal/agent/guard_state.go new file mode 100644 index 00000000..f54fe902 --- /dev/null +++ b/internal/agent/guard_state.go @@ -0,0 +1,48 @@ +package agent + +import "sync" + +type toolAbortRegistry struct { + mu sync.Mutex + ids map[string]struct{} +} + +func newToolAbortRegistry() *toolAbortRegistry { + return &toolAbortRegistry{ + ids: make(map[string]struct{}), + } +} + +func (r *toolAbortRegistry) Add(toolCallID string) { + if r == nil || toolCallID == "" { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + r.ids[toolCallID] = struct{}{} +} + +func (r *toolAbortRegistry) Take(toolCallID string) bool { + if r == nil || toolCallID == "" { + return false + } + + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.ids[toolCallID]; !ok { + return false + } + delete(r.ids, toolCallID) + return true +} + +func (r *toolAbortRegistry) Any() bool { + if r == nil { + return false + } + + r.mu.Lock() + defer r.mu.Unlock() + return len(r.ids) > 0 +} diff --git a/internal/agent/guard_state_test.go b/internal/agent/guard_state_test.go new file mode 100644 index 00000000..f4b6b5f8 --- /dev/null +++ b/internal/agent/guard_state_test.go @@ -0,0 +1,32 @@ +package agent + +import ( + "fmt" + "sync" + "testing" +) + +func TestToolAbortRegistryConcurrentAccess(t *testing.T) { + registry := newToolAbortRegistry() + + var wg sync.WaitGroup + for i := 0; i < 32; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + id := fmt.Sprintf("call-%d", i) + registry.Add(id) + if !registry.Any() { + t.Error("expected registry to report pending aborts") + } + if !registry.Take(id) { + t.Errorf("expected to take %s", id) + } + }(i) + } + wg.Wait() + + if registry.Any() { + t.Fatal("expected registry to be empty") + } +} diff --git a/internal/agent/sential.go b/internal/agent/sential.go index 51d5784d..23ec52f7 100644 --- a/internal/agent/sential.go +++ b/internal/agent/sential.go @@ -7,6 +7,7 @@ import ( "fmt" "sort" "strings" + "sync" "unicode" "unicode/utf8" ) @@ -322,6 +323,7 @@ type ToolLoopResult struct { // ToolLoopGuard detects repeated identical tool calls. type ToolLoopGuard struct { + mu sync.Mutex repeatThreshold int warningsBeforeAbort int volatileKeySet map[string]struct{} @@ -352,8 +354,17 @@ func NewToolLoopGuard(repeatThreshold, warningsBeforeAbort int) *ToolLoopGuard { // Inspect checks a tool call for repetition. func (g *ToolLoopGuard) Inspect(input ToolLoopInput) ToolLoopResult { + if g == nil { + return ToolLoopResult{ + Hash: computeToolLoopHash(input, nil), + } + } + hash := computeToolLoopHash(input, g.volatileKeySet) + g.mu.Lock() + defer g.mu.Unlock() + if hash == g.lastHash { g.repeatCount++ } else { @@ -391,6 +402,13 @@ func (g *ToolLoopGuard) Inspect(input ToolLoopInput) ToolLoopResult { // Reset clears the guard state. func (g *ToolLoopGuard) Reset() { + if g == nil { + return + } + + g.mu.Lock() + defer g.mu.Unlock() + g.lastHash = "" g.repeatCount = 0 g.breachCount = 0 diff --git a/internal/agent/sential_test.go b/internal/agent/sential_test.go new file mode 100644 index 00000000..acef5ec8 --- /dev/null +++ b/internal/agent/sential_test.go @@ -0,0 +1,57 @@ +package agent + +import ( + "sync" + "testing" +) + +func TestToolLoopGuardNilReceiver(t *testing.T) { + var guard *ToolLoopGuard + + result := guard.Inspect(ToolLoopInput{ + ToolName: "search", + Input: map[string]any{ + "query": "memoh", + }, + }) + + if result.Hash == "" { + t.Fatal("expected hash for nil receiver") + } + if result.Warn { + t.Fatal("did not expect warning for nil receiver") + } + if result.Abort { + t.Fatal("did not expect abort for nil receiver") + } +} + +func TestToolLoopGuardConcurrentInspectAndReset(t *testing.T) { + guard := NewToolLoopGuard(2, 1) + input := ToolLoopInput{ + ToolName: "web_search", + Input: map[string]any{ + "query": "memoh logs", + "requestId": "volatile", + }, + } + + var wg sync.WaitGroup + for i := 0; i < 16; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + for j := 0; j < 200; j++ { + result := guard.Inspect(input) + if result.Hash == "" { + t.Error("expected non-empty hash") + return + } + if (i+j)%25 == 0 { + guard.Reset() + } + } + }(i) + } + wg.Wait() +}