mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(agent): guard tool loop state against concurrent tool execution
This commit is contained in:
+10
-10
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
@@ -99,7 +100,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
var textLoopGuard *TextLoopGuard
|
var textLoopGuard *TextLoopGuard
|
||||||
var textLoopProbeBuffer *TextLoopProbeBuffer
|
var textLoopProbeBuffer *TextLoopProbeBuffer
|
||||||
var toolLoopGuard *ToolLoopGuard
|
var toolLoopGuard *ToolLoopGuard
|
||||||
toolLoopAbortCallIDs := make(map[string]struct{})
|
toolLoopAbortCallIDs := newToolAbortRegistry()
|
||||||
if cfg.LoopDetection.Enabled {
|
if cfg.LoopDetection.Enabled {
|
||||||
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
||||||
textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
|
textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
|
||||||
@@ -314,11 +315,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}
|
}
|
||||||
|
|
||||||
case *sdk.StreamToolResultPart:
|
case *sdk.StreamToolResultPart:
|
||||||
shouldAbort := false
|
shouldAbort := toolLoopAbortCallIDs.Take(p.ToolCallID)
|
||||||
if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok {
|
|
||||||
delete(toolLoopAbortCallIDs, p.ToolCallID)
|
|
||||||
shouldAbort = true
|
|
||||||
}
|
|
||||||
stepNumber++
|
stepNumber++
|
||||||
if !sendEvent(ctx, ch, StreamEvent{
|
if !sendEvent(ctx, ch, StreamEvent{
|
||||||
Type: EventToolCallEnd,
|
Type: EventToolCallEnd,
|
||||||
@@ -440,7 +437,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
||||||
// Collecting emitter: tools push side-effect events here during generation.
|
// Collecting emitter: tools push side-effect events here during generation.
|
||||||
var collected []tools.ToolStreamEvent
|
var collected []tools.ToolStreamEvent
|
||||||
|
var collectedMu sync.Mutex
|
||||||
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
|
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
|
||||||
|
collectedMu.Lock()
|
||||||
|
defer collectedMu.Unlock()
|
||||||
collected = append(collected, evt)
|
collected = append(collected, evt)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -456,7 +456,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
|
|
||||||
var toolLoopGuard *ToolLoopGuard
|
var toolLoopGuard *ToolLoopGuard
|
||||||
var textLoopGuard *TextLoopGuard
|
var textLoopGuard *TextLoopGuard
|
||||||
toolLoopAbortCallIDs := make(map[string]struct{})
|
toolLoopAbortCallIDs := newToolAbortRegistry()
|
||||||
if cfg.LoopDetection.Enabled {
|
if cfg.LoopDetection.Enabled {
|
||||||
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
||||||
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
||||||
@@ -490,7 +490,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
||||||
if cfg.LoopDetection.Enabled {
|
if cfg.LoopDetection.Enabled {
|
||||||
if len(toolLoopAbortCallIDs) > 0 {
|
if toolLoopAbortCallIDs.Any() {
|
||||||
return nil // stop
|
return nil // stop
|
||||||
}
|
}
|
||||||
if textLoopGuard != nil && isNonEmptyString(step.Text) {
|
if textLoopGuard != nil && isNonEmptyString(step.Text) {
|
||||||
@@ -701,7 +701,7 @@ func drainBackgroundNotifications(
|
|||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool {
|
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs *toolAbortRegistry) []sdk.Tool {
|
||||||
wrapped := make([]sdk.Tool, len(tools))
|
wrapped := make([]sdk.Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
originalExecute := tool.Execute
|
originalExecute := tool.Execute
|
||||||
@@ -710,7 +710,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
|
|||||||
wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||||
warn, abort := guard.Guard(toolName, input)
|
warn, abort := guard.Guard(toolName, input)
|
||||||
if abort {
|
if abort {
|
||||||
abortCallIDs[ctx.ToolCallID] = struct{}{}
|
abortCallIDs.Add(ctx.ToolCallID)
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"isError": true,
|
"isError": true,
|
||||||
"content": []map[string]any{{
|
"content": []map[string]any{{
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
type toolAbortRegistry struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
ids map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newToolAbortRegistry() *toolAbortRegistry {
|
||||||
|
return &toolAbortRegistry{
|
||||||
|
ids: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *toolAbortRegistry) Add(toolCallID string) {
|
||||||
|
if r == nil || toolCallID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.ids[toolCallID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *toolAbortRegistry) Take(toolCallID string) bool {
|
||||||
|
if r == nil || toolCallID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if _, ok := r.ids[toolCallID]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(r.ids, toolCallID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *toolAbortRegistry) Any() bool {
|
||||||
|
if r == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
return len(r.ids) > 0
|
||||||
|
}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
|
||||||
|
registry := newToolAbortRegistry()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 32; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
id := fmt.Sprintf("call-%d", i)
|
||||||
|
registry.Add(id)
|
||||||
|
if !registry.Any() {
|
||||||
|
t.Error("expected registry to report pending aborts")
|
||||||
|
}
|
||||||
|
if !registry.Take(id) {
|
||||||
|
t.Errorf("expected to take %s", id)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if registry.Any() {
|
||||||
|
t.Fatal("expected registry to be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"unicode"
|
"unicode"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
@@ -322,6 +323,7 @@ type ToolLoopResult struct {
|
|||||||
|
|
||||||
// ToolLoopGuard detects repeated identical tool calls.
|
// ToolLoopGuard detects repeated identical tool calls.
|
||||||
type ToolLoopGuard struct {
|
type ToolLoopGuard struct {
|
||||||
|
mu sync.Mutex
|
||||||
repeatThreshold int
|
repeatThreshold int
|
||||||
warningsBeforeAbort int
|
warningsBeforeAbort int
|
||||||
volatileKeySet map[string]struct{}
|
volatileKeySet map[string]struct{}
|
||||||
@@ -352,8 +354,17 @@ func NewToolLoopGuard(repeatThreshold, warningsBeforeAbort int) *ToolLoopGuard {
|
|||||||
|
|
||||||
// Inspect checks a tool call for repetition.
|
// Inspect checks a tool call for repetition.
|
||||||
func (g *ToolLoopGuard) Inspect(input ToolLoopInput) ToolLoopResult {
|
func (g *ToolLoopGuard) Inspect(input ToolLoopInput) ToolLoopResult {
|
||||||
|
if g == nil {
|
||||||
|
return ToolLoopResult{
|
||||||
|
Hash: computeToolLoopHash(input, nil),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hash := computeToolLoopHash(input, g.volatileKeySet)
|
hash := computeToolLoopHash(input, g.volatileKeySet)
|
||||||
|
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
if hash == g.lastHash {
|
if hash == g.lastHash {
|
||||||
g.repeatCount++
|
g.repeatCount++
|
||||||
} else {
|
} else {
|
||||||
@@ -391,6 +402,13 @@ func (g *ToolLoopGuard) Inspect(input ToolLoopInput) ToolLoopResult {
|
|||||||
|
|
||||||
// Reset clears the guard state.
|
// Reset clears the guard state.
|
||||||
func (g *ToolLoopGuard) Reset() {
|
func (g *ToolLoopGuard) Reset() {
|
||||||
|
if g == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
g.lastHash = ""
|
g.lastHash = ""
|
||||||
g.repeatCount = 0
|
g.repeatCount = 0
|
||||||
g.breachCount = 0
|
g.breachCount = 0
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToolLoopGuardNilReceiver(t *testing.T) {
|
||||||
|
var guard *ToolLoopGuard
|
||||||
|
|
||||||
|
result := guard.Inspect(ToolLoopInput{
|
||||||
|
ToolName: "search",
|
||||||
|
Input: map[string]any{
|
||||||
|
"query": "memoh",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.Hash == "" {
|
||||||
|
t.Fatal("expected hash for nil receiver")
|
||||||
|
}
|
||||||
|
if result.Warn {
|
||||||
|
t.Fatal("did not expect warning for nil receiver")
|
||||||
|
}
|
||||||
|
if result.Abort {
|
||||||
|
t.Fatal("did not expect abort for nil receiver")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolLoopGuardConcurrentInspectAndReset(t *testing.T) {
|
||||||
|
guard := NewToolLoopGuard(2, 1)
|
||||||
|
input := ToolLoopInput{
|
||||||
|
ToolName: "web_search",
|
||||||
|
Input: map[string]any{
|
||||||
|
"query": "memoh logs",
|
||||||
|
"requestId": "volatile",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 16; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 200; j++ {
|
||||||
|
result := guard.Inspect(input)
|
||||||
|
if result.Hash == "" {
|
||||||
|
t.Error("expected non-empty hash")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (i+j)%25 == 0 {
|
||||||
|
guard.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user