fix(agent): guard tool loop state against concurrent tool execution

This commit is contained in:
Fodesu
2026-04-13 23:50:22 +08:00
committed by 晨苒
parent 1a6d12a137
commit 59147b255d
5 changed files with 165 additions and 10 deletions
+10 -10
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"time"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -99,7 +100,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
var textLoopGuard *TextLoopGuard
var textLoopProbeBuffer *TextLoopProbeBuffer
var toolLoopGuard *ToolLoopGuard
toolLoopAbortCallIDs := make(map[string]struct{})
toolLoopAbortCallIDs := newToolAbortRegistry()
if cfg.LoopDetection.Enabled {
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
@@ -314,11 +315,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
case *sdk.StreamToolResultPart:
shouldAbort := false
if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok {
delete(toolLoopAbortCallIDs, p.ToolCallID)
shouldAbort = true
}
shouldAbort := toolLoopAbortCallIDs.Take(p.ToolCallID)
stepNumber++
if !sendEvent(ctx, ch, StreamEvent{
Type: EventToolCallEnd,
@@ -440,7 +437,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
// Collecting emitter: tools push side-effect events here during generation.
var collected []tools.ToolStreamEvent
var collectedMu sync.Mutex
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
collectedMu.Lock()
defer collectedMu.Unlock()
collected = append(collected, evt)
})
@@ -456,7 +456,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
var toolLoopGuard *ToolLoopGuard
var textLoopGuard *TextLoopGuard
toolLoopAbortCallIDs := make(map[string]struct{})
toolLoopAbortCallIDs := newToolAbortRegistry()
if cfg.LoopDetection.Enabled {
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
@@ -490,7 +490,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
opts = append(opts,
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
if cfg.LoopDetection.Enabled {
if len(toolLoopAbortCallIDs) > 0 {
if toolLoopAbortCallIDs.Any() {
return nil // stop
}
if textLoopGuard != nil && isNonEmptyString(step.Text) {
@@ -701,7 +701,7 @@ func drainBackgroundNotifications(
return p
}
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool {
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs *toolAbortRegistry) []sdk.Tool {
wrapped := make([]sdk.Tool, len(tools))
for i, tool := range tools {
originalExecute := tool.Execute
@@ -710,7 +710,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
warn, abort := guard.Guard(toolName, input)
if abort {
abortCallIDs[ctx.ToolCallID] = struct{}{}
abortCallIDs.Add(ctx.ToolCallID)
return map[string]any{
"isError": true,
"content": []map[string]any{{
+48
View File
@@ -0,0 +1,48 @@
package agent
import "sync"
// toolAbortRegistry is a mutex-guarded set of tool call IDs that have been
// flagged for abort. Every method locks mu, so a single registry may be shared
// between goroutines. Methods tolerate a nil receiver and the zero value.
type toolAbortRegistry struct {
	mu  sync.Mutex
	ids map[string]struct{} // allocated by newToolAbortRegistry or lazily by Add
}

// newToolAbortRegistry returns an empty registry with its set pre-allocated.
func newToolAbortRegistry() *toolAbortRegistry {
	return &toolAbortRegistry{
		ids: make(map[string]struct{}),
	}
}

// Add records toolCallID in the registry. Calls on a nil receiver or with an
// empty ID are ignored. Adding an ID that is already present is a no-op.
func (r *toolAbortRegistry) Add(toolCallID string) {
	if r == nil || toolCallID == "" {
		return
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.ids == nil {
		// Writing to a nil map panics; allocate on first use so a registry
		// declared as a zero value behaves like one built by the constructor.
		r.ids = make(map[string]struct{})
	}
	r.ids[toolCallID] = struct{}{}
}

// Take removes toolCallID from the registry and reports whether it was
// present. It returns false for a nil receiver, an empty ID, or an unknown ID,
// so a given ID can be taken successfully at most once per Add.
func (r *toolAbortRegistry) Take(toolCallID string) bool {
	if r == nil || toolCallID == "" {
		return false
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	_, ok := r.ids[toolCallID]
	// delete is a no-op for a missing key (and for a nil map), so no branch
	// is needed before removing the entry.
	delete(r.ids, toolCallID)
	return ok
}

// Any reports whether at least one ID is currently recorded. A nil receiver
// reports false.
func (r *toolAbortRegistry) Any() bool {
	if r == nil {
		return false
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	return len(r.ids) > 0
}
+32
View File
@@ -0,0 +1,32 @@
package agent
import (
"fmt"
"sync"
"testing"
)
// TestToolAbortRegistryConcurrentAccess drives Add, Any and Take from many
// goroutines at once, each using its own call ID, and then checks that the
// registry has been fully drained.
func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
	const workers = 32

	reg := newToolAbortRegistry()

	var wg sync.WaitGroup
	wg.Add(workers)
	for n := 0; n < workers; n++ {
		// Declared inside the loop body so every goroutine captures its own ID.
		callID := fmt.Sprintf("call-%d", n)
		go func() {
			defer wg.Done()

			reg.Add(callID)
			// This goroutine's own ID stays registered until it is taken
			// below, so Any must be true here regardless of interleaving.
			if pending := reg.Any(); !pending {
				t.Error("expected registry to report pending aborts")
			}
			if taken := reg.Take(callID); !taken {
				t.Errorf("expected to take %s", callID)
			}
		}()
	}
	wg.Wait()

	// Every worker removed the ID it added, so nothing should remain.
	if reg.Any() {
		t.Fatal("expected registry to be empty")
	}
}
+18
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
@@ -322,6 +323,7 @@ type ToolLoopResult struct {
// ToolLoopGuard detects repeated identical tool calls.
type ToolLoopGuard struct {
mu sync.Mutex
repeatThreshold int
warningsBeforeAbort int
volatileKeySet map[string]struct{}
@@ -352,8 +354,17 @@ func NewToolLoopGuard(repeatThreshold, warningsBeforeAbort int) *ToolLoopGuard {
// Inspect checks a tool call for repetition.
func (g *ToolLoopGuard) Inspect(input ToolLoopInput) ToolLoopResult {
if g == nil {
return ToolLoopResult{
Hash: computeToolLoopHash(input, nil),
}
}
hash := computeToolLoopHash(input, g.volatileKeySet)
g.mu.Lock()
defer g.mu.Unlock()
if hash == g.lastHash {
g.repeatCount++
} else {
@@ -391,6 +402,13 @@ func (g *ToolLoopGuard) Inspect(input ToolLoopInput) ToolLoopResult {
// Reset clears the guard state.
func (g *ToolLoopGuard) Reset() {
if g == nil {
return
}
g.mu.Lock()
defer g.mu.Unlock()
g.lastHash = ""
g.repeatCount = 0
g.breachCount = 0
+57
View File
@@ -0,0 +1,57 @@
package agent
import (
"sync"
"testing"
)
// TestToolLoopGuardNilReceiver checks that calling Inspect on a nil
// *ToolLoopGuard still yields a hash and raises neither Warn nor Abort.
func TestToolLoopGuardNilReceiver(t *testing.T) {
	var nilGuard *ToolLoopGuard

	call := ToolLoopInput{
		ToolName: "search",
		Input:    map[string]any{"query": "memoh"},
	}
	got := nilGuard.Inspect(call)

	// t.Fatal stops the test, so only the first failing condition is reported.
	switch {
	case got.Hash == "":
		t.Fatal("expected hash for nil receiver")
	case got.Warn:
		t.Fatal("did not expect warning for nil receiver")
	case got.Abort:
		t.Fatal("did not expect abort for nil receiver")
	}
}
// TestToolLoopGuardConcurrentInspectAndReset calls Inspect and Reset on one
// shared guard from several goroutines. It only asserts that every Inspect
// returns a non-empty hash; it is most useful when run with the race detector.
func TestToolLoopGuardConcurrentInspectAndReset(t *testing.T) {
	const (
		workers    = 16
		iterations = 200
		resetEvery = 25
	)

	loopGuard := NewToolLoopGuard(2, 1)
	// The same input value is shared, read-only, by every goroutine.
	call := ToolLoopInput{
		ToolName: "web_search",
		Input: map[string]any{
			"query":     "memoh logs",
			"requestId": "volatile",
		},
	}

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func(offset int) {
			defer wg.Done()
			for iter := 0; iter < iterations; iter++ {
				if loopGuard.Inspect(call).Hash == "" {
					t.Error("expected non-empty hash")
					return
				}
				// Interleave a Reset periodically, staggered per worker.
				if (offset+iter)%resetEvery != 0 {
					continue
				}
				loopGuard.Reset()
			}
		}(w)
	}
	wg.Wait()
}