mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(agent): stop generate loop aborts correctly
This commit is contained in:
+66
-5
@@ -435,6 +435,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}
|
||||
|
||||
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
||||
genCtx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
loopAbort := newLoopAbortState()
|
||||
|
||||
// Collecting emitter: tools push side-effect events here during generation.
|
||||
var collected []tools.ToolStreamEvent
|
||||
var collectedMu sync.Mutex
|
||||
@@ -447,7 +451,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
var sdkTools []sdk.Tool
|
||||
if cfg.SupportsToolCall {
|
||||
var err error
|
||||
sdkTools, err = a.assembleTools(ctx, cfg, collectEmitter)
|
||||
sdkTools, err = a.assembleTools(genCtx, cfg, collectEmitter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("assemble tools: %w", err)
|
||||
}
|
||||
@@ -491,12 +495,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
||||
if cfg.LoopDetection.Enabled {
|
||||
if toolLoopAbortCallIDs.Any() {
|
||||
return nil // stop
|
||||
loopAbort.Set(ErrToolLoopDetected)
|
||||
cancel(ErrToolLoopDetected)
|
||||
return nil
|
||||
}
|
||||
if textLoopGuard != nil && isNonEmptyString(step.Text) {
|
||||
result := textLoopGuard.Inspect(step.Text)
|
||||
if result.Abort {
|
||||
return nil // stop
|
||||
loopAbort.Set(ErrTextLoopDetected)
|
||||
cancel(ErrTextLoopDetected)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -504,10 +512,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
}),
|
||||
)
|
||||
|
||||
genResult, err := a.client.GenerateTextResult(ctx, opts...)
|
||||
genResult, err := a.client.GenerateTextResult(genCtx, opts...)
|
||||
if err != nil {
|
||||
if loopErr := detectGenerateLoopAbort(genCtx, err); loopErr != nil {
|
||||
return nil, loopErr
|
||||
}
|
||||
return nil, fmt.Errorf("generate: %w", err)
|
||||
}
|
||||
if loopErr := loopAbort.Err(); loopErr != nil {
|
||||
return nil, loopErr
|
||||
}
|
||||
|
||||
// Drain collected tool-emitted side effects into the result.
|
||||
var attachments []FileAttachment
|
||||
@@ -717,7 +731,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
|
||||
"type": "text",
|
||||
"text": ToolLoopDetectedAbortMessage,
|
||||
}},
|
||||
}, errors.New(ToolLoopDetectedAbortMessage)
|
||||
}, ErrToolLoopDetected
|
||||
}
|
||||
if warn {
|
||||
return map[string]any{
|
||||
@@ -979,3 +993,50 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func detectGenerateLoopAbort(ctx context.Context, err error) error {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
cause := context.Cause(ctx)
|
||||
switch {
|
||||
case errors.Is(cause, ErrToolLoopDetected):
|
||||
return ErrToolLoopDetected
|
||||
case errors.Is(cause, ErrTextLoopDetected):
|
||||
return ErrTextLoopDetected
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type loopAbortState struct {
|
||||
mu sync.Mutex
|
||||
err error
|
||||
}
|
||||
|
||||
func newLoopAbortState() *loopAbortState {
|
||||
return &loopAbortState{}
|
||||
}
|
||||
|
||||
func (s *loopAbortState) Set(err error) {
|
||||
if s == nil || err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *loopAbortState) Err() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user