fix(agent): guard tool loop state against concurrent tool execution

This commit is contained in:
Fodesu
2026-04-13 23:50:22 +08:00
committed by 晨苒
parent 1a6d12a137
commit 59147b255d
5 changed files with 165 additions and 10 deletions
+10 -10
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"time"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -99,7 +100,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
var textLoopGuard *TextLoopGuard
var textLoopProbeBuffer *TextLoopProbeBuffer
var toolLoopGuard *ToolLoopGuard
toolLoopAbortCallIDs := make(map[string]struct{})
toolLoopAbortCallIDs := newToolAbortRegistry()
if cfg.LoopDetection.Enabled {
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
@@ -314,11 +315,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
case *sdk.StreamToolResultPart:
shouldAbort := false
if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok {
delete(toolLoopAbortCallIDs, p.ToolCallID)
shouldAbort = true
}
shouldAbort := toolLoopAbortCallIDs.Take(p.ToolCallID)
stepNumber++
if !sendEvent(ctx, ch, StreamEvent{
Type: EventToolCallEnd,
@@ -440,7 +437,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
// Collecting emitter: tools push side-effect events here during generation.
var collected []tools.ToolStreamEvent
var collectedMu sync.Mutex
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
collectedMu.Lock()
defer collectedMu.Unlock()
collected = append(collected, evt)
})
@@ -456,7 +456,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
var toolLoopGuard *ToolLoopGuard
var textLoopGuard *TextLoopGuard
toolLoopAbortCallIDs := make(map[string]struct{})
toolLoopAbortCallIDs := newToolAbortRegistry()
if cfg.LoopDetection.Enabled {
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
@@ -490,7 +490,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
opts = append(opts,
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
if cfg.LoopDetection.Enabled {
if len(toolLoopAbortCallIDs) > 0 {
if toolLoopAbortCallIDs.Any() {
return nil // stop
}
if textLoopGuard != nil && isNonEmptyString(step.Text) {
@@ -701,7 +701,7 @@ func drainBackgroundNotifications(
return p
}
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool {
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs *toolAbortRegistry) []sdk.Tool {
wrapped := make([]sdk.Tool, len(tools))
for i, tool := range tools {
originalExecute := tool.Execute
@@ -710,7 +710,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
warn, abort := guard.Guard(toolName, input)
if abort {
abortCallIDs[ctx.ToolCallID] = struct{}{}
abortCallIDs.Add(ctx.ToolCallID)
return map[string]any{
"isError": true,
"content": []map[string]any{{