fix(agent): stop generate loop aborts correctly

This commit is contained in:
Fodesu
2026-04-14 00:38:26 +08:00
committed by 晨苒
parent 59147b255d
commit 33461d7ac1
3 changed files with 235 additions and 5 deletions
+66 -5
View File
@@ -435,6 +435,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
genCtx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
loopAbort := newLoopAbortState()
// Collecting emitter: tools push side-effect events here during generation.
var collected []tools.ToolStreamEvent
var collectedMu sync.Mutex
@@ -447,7 +451,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
var sdkTools []sdk.Tool
if cfg.SupportsToolCall {
var err error
sdkTools, err = a.assembleTools(ctx, cfg, collectEmitter)
sdkTools, err = a.assembleTools(genCtx, cfg, collectEmitter)
if err != nil {
return nil, fmt.Errorf("assemble tools: %w", err)
}
@@ -491,12 +495,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
if cfg.LoopDetection.Enabled {
if toolLoopAbortCallIDs.Any() {
return nil // stop
loopAbort.Set(ErrToolLoopDetected)
cancel(ErrToolLoopDetected)
return nil
}
if textLoopGuard != nil && isNonEmptyString(step.Text) {
result := textLoopGuard.Inspect(step.Text)
if result.Abort {
return nil // stop
loopAbort.Set(ErrTextLoopDetected)
cancel(ErrTextLoopDetected)
return nil
}
}
}
@@ -504,10 +512,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
}),
)
genResult, err := a.client.GenerateTextResult(ctx, opts...)
genResult, err := a.client.GenerateTextResult(genCtx, opts...)
if err != nil {
if loopErr := detectGenerateLoopAbort(genCtx, err); loopErr != nil {
return nil, loopErr
}
return nil, fmt.Errorf("generate: %w", err)
}
if loopErr := loopAbort.Err(); loopErr != nil {
return nil, loopErr
}
// Drain collected tool-emitted side effects into the result.
var attachments []FileAttachment
@@ -717,7 +731,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
"type": "text",
"text": ToolLoopDetectedAbortMessage,
}},
}, errors.New(ToolLoopDetectedAbortMessage)
}, ErrToolLoopDetected
}
if warn {
return map[string]any{
@@ -979,3 +993,50 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
return ctx.Err()
}
}
func detectGenerateLoopAbort(ctx context.Context, err error) error {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
return nil
}
cause := context.Cause(ctx)
switch {
case errors.Is(cause, ErrToolLoopDetected):
return ErrToolLoopDetected
case errors.Is(cause, ErrTextLoopDetected):
return ErrTextLoopDetected
default:
return nil
}
}
type loopAbortState struct {
mu sync.Mutex
err error
}
func newLoopAbortState() *loopAbortState {
return &loopAbortState{}
}
func (s *loopAbortState) Set(err error) {
if s == nil || err == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if s.err == nil {
s.err = err
}
}
func (s *loopAbortState) Err() error {
if s == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
return s.err
}