mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(agent): guard tool loop state against concurrent tool execution
This commit is contained in:
+10
-10
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
@@ -99,7 +100,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
var textLoopGuard *TextLoopGuard
|
||||
var textLoopProbeBuffer *TextLoopProbeBuffer
|
||||
var toolLoopGuard *ToolLoopGuard
|
||||
toolLoopAbortCallIDs := make(map[string]struct{})
|
||||
toolLoopAbortCallIDs := newToolAbortRegistry()
|
||||
if cfg.LoopDetection.Enabled {
|
||||
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
||||
textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
|
||||
@@ -314,11 +315,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}
|
||||
|
||||
case *sdk.StreamToolResultPart:
|
||||
shouldAbort := false
|
||||
if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok {
|
||||
delete(toolLoopAbortCallIDs, p.ToolCallID)
|
||||
shouldAbort = true
|
||||
}
|
||||
shouldAbort := toolLoopAbortCallIDs.Take(p.ToolCallID)
|
||||
stepNumber++
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
Type: EventToolCallEnd,
|
||||
@@ -440,7 +437,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
||||
// Collecting emitter: tools push side-effect events here during generation.
|
||||
var collected []tools.ToolStreamEvent
|
||||
var collectedMu sync.Mutex
|
||||
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
|
||||
collectedMu.Lock()
|
||||
defer collectedMu.Unlock()
|
||||
collected = append(collected, evt)
|
||||
})
|
||||
|
||||
@@ -456,7 +456,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
|
||||
var toolLoopGuard *ToolLoopGuard
|
||||
var textLoopGuard *TextLoopGuard
|
||||
toolLoopAbortCallIDs := make(map[string]struct{})
|
||||
toolLoopAbortCallIDs := newToolAbortRegistry()
|
||||
if cfg.LoopDetection.Enabled {
|
||||
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
||||
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
||||
@@ -490,7 +490,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
opts = append(opts,
|
||||
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
||||
if cfg.LoopDetection.Enabled {
|
||||
if len(toolLoopAbortCallIDs) > 0 {
|
||||
if toolLoopAbortCallIDs.Any() {
|
||||
return nil // stop
|
||||
}
|
||||
if textLoopGuard != nil && isNonEmptyString(step.Text) {
|
||||
@@ -701,7 +701,7 @@ func drainBackgroundNotifications(
|
||||
return p
|
||||
}
|
||||
|
||||
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool {
|
||||
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs *toolAbortRegistry) []sdk.Tool {
|
||||
wrapped := make([]sdk.Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
originalExecute := tool.Execute
|
||||
@@ -710,7 +710,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
|
||||
wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||
warn, abort := guard.Guard(toolName, input)
|
||||
if abort {
|
||||
abortCallIDs[ctx.ToolCallID] = struct{}{}
|
||||
abortCallIDs.Add(ctx.ToolCallID)
|
||||
return map[string]any{
|
||||
"isError": true,
|
||||
"content": []map[string]any{{
|
||||
|
||||
Reference in New Issue
Block a user