From db777b98ac80b3fb31b5e4042722a0228b9e5348 Mon Sep 17 00:00:00 2001 From: Fodesu <75713465+Fodesu@users.noreply.github.com> Date: Sat, 18 Apr 2026 03:19:58 +0800 Subject: [PATCH] fix(agent): stream loop abort, mid-stream retry parity, collector cleanup (#376) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(agent): align stream retry abort and event collection * fix(agent): cancel stream on loop detect, harden retry and tool events * fix(agent): drain previous stream before retry * fix(lint): ctx ci lint --------- Co-authored-by: 晨苒 <16112591+chen-ran@users.noreply.github.com> --- internal/agent/agent.go | 111 ++++++-- internal/agent/background/manager.go | 28 +- internal/agent/generate_loop_test.go | 394 ++++++++++++++++++++++++++- internal/agent/guard_state.go | 68 ++++- internal/agent/guard_state_test.go | 33 +++ 5 files changed, 587 insertions(+), 47 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 5c30204c..de78c696 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -78,6 +78,9 @@ func sendEvent(ctx context.Context, ch chan<- StreamEvent, evt StreamEvent) bool } func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) { + streamCtx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + // Stream emitter: tools targeting the current conversation push // side-effect events (attachments, reactions, speech) directly here. // Uses sendEvent to avoid goroutine leaks when the consumer stops reading. @@ -88,7 +91,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv var sdkTools []sdk.Tool if cfg.SupportsToolCall { var err error - sdkTools, err = a.assembleTools(ctx, cfg, streamEmitter) + sdkTools, err = a.assembleTools(streamCtx, cfg, streamEmitter) if err != nil { sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}) return @@ -96,6 +99,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools) + aborted := false + // Loop detection setup var textLoopGuard *TextLoopGuard var textLoopProbeBuffer *TextLoopProbeBuffer @@ -107,6 +112,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv result := textLoopGuard.Inspect(text) if result.Abort { a.logger.Warn("text loop detected, will abort") + aborted = true + cancel(ErrTextLoopDetected) } }) toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort) @@ -198,7 +205,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv var streamResult *sdk.StreamResult for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ { var err error - streamResult, err = a.client.StreamText(ctx, opts...) + streamResult, err = a.client.StreamText(streamCtx, opts...) if err == nil { break } @@ -225,7 +232,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } delay := retryDelay(attempt, retryCfg) if delay > 0 { - if err := sleepWithContext(ctx, delay); err != nil { + if err := sleepWithContext(streamCtx, delay); err != nil { sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: context cancelled during retry: %v", err)}) return } @@ -235,11 +242,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv sendEvent(ctx, ch, StreamEvent{Type: EventAgentStart}) var allText strings.Builder - aborted := false stepNumber := 0 for part := range streamResult.Stream { - if ctx.Err() != nil { + if streamCtx.Err() != nil { aborted = true break } @@ -319,11 +325,13 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } case *sdk.ToolProgressPart: - ch <- StreamEvent{ + if !sendEvent(ctx, ch, StreamEvent{ Type: EventToolCallProgress, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Progress: p.Content, + }) { + aborted = true } case *sdk.StreamToolResultPart: @@ -345,10 +353,14 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } if shouldAbort { a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID)) + cancel(ErrToolLoopDetected) aborted = true } case *sdk.StreamToolErrorPart: + // Take before errors.Is so registry IDs from the loop guard are always cleared. + tookLoopAbort := toolLoopAbortCallIDs.Take(p.ToolCallID) + shouldAbort := errors.Is(p.Error, ErrToolLoopDetected) || tookLoopAbort if !sendEvent(ctx, ch, StreamEvent{ Type: EventToolCallEnd, ToolName: p.ToolName, @@ -357,6 +369,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv }) { aborted = true } + if shouldAbort { + a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID)) + cancel(ErrToolLoopDetected) + aborted = true + } case *sdk.StreamFilePart: mediaType := p.File.MediaType @@ -384,7 +401,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv // no work has been completed yet and retrying from the start is safe. if isRetryableStreamError(p.Error) { streamResult, aborted = a.runMidStreamRetry( - ctx, ch, cfg, sdkTools, prepareStep, streamResult, + ctx, streamCtx, cancel, toolLoopAbortCallIDs, + ch, cfg, sdkTools, prepareStep, streamResult, stepNumber, errMsg, &allText, textLoopProbeBuffer, ) } else { @@ -403,6 +421,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } } + if aborted { + for range streamResult.Stream { + } + } + if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } @@ -452,12 +475,10 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult loopAbort := newLoopAbortState() // Collecting emitter: tools push side-effect events here during generation. - var collected []tools.ToolStreamEvent - var collectedMu sync.Mutex + collected := newToolEventCollector() + defer collected.Close() collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) { - collectedMu.Lock() - defer collectedMu.Unlock() - collected = append(collected, evt) + collected.Add(evt) }) var sdkTools []sdk.Tool @@ -536,10 +557,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult } // Drain collected tool-emitted side effects into the result. + collectedEvents := collected.CloseAndSnapshot() var attachments []FileAttachment var reactions []ReactionItem var speeches []SpeechItem - for _, evt := range collected { + for _, evt := range collectedEvents { switch evt.Type { case tools.StreamEventAttachment: for _, a := range evt.Attachments { @@ -860,8 +882,15 @@ func pruneOldToolResults(p *sdk.GenerateParams, keepSteps, threshold int) *sdk.G // runMidStreamRetry attempts to continue the agent stream after a retryable // mid-stream error. It re-invokes StreamText with the accumulated messages // and drains the new stream into the same output channel. +// +// sendCtx is used for sendEvent so consumer disconnect (parent ctx) still +// controls channel back-pressure; streamCtx is passed to the SDK for the same +// cancellation semantics as the main stream (including loop-detect cancel). func (a *Agent) runMidStreamRetry( - ctx context.Context, + sendCtx context.Context, + streamCtx context.Context, + cancel context.CancelCauseFunc, + toolLoopAbortCallIDs *toolAbortRegistry, ch chan<- StreamEvent, cfg RunConfig, sdkTools []sdk.Tool, @@ -872,6 +901,13 @@ func (a *Agent) runMidStreamRetry( allText *strings.Builder, textLoopProbeBuffer *TextLoopProbeBuffer, ) (*sdk.StreamResult, bool) { + // Drain the previous stream before reading prevResult.Messages. + // This avoids racing with the SDK's final StreamResult write. + if prevResult.Stream != nil { + for range prevResult.Stream { + } + } + retryCfg := DefaultRetryConfig() for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ { a.logger.Warn("mid-stream error, retrying", @@ -880,7 +916,7 @@ func (a *Agent) runMidStreamRetry( slog.Int("max_attempts", retryCfg.MaxAttempts), slog.String("error", errMsg), ) - if !sendEvent(ctx, ch, StreamEvent{ + if !sendEvent(sendCtx, ch, StreamEvent{ Type: EventRetry, Attempt: attempt + 1, MaxAttempt: retryCfg.MaxAttempts, @@ -891,7 +927,7 @@ func (a *Agent) runMidStreamRetry( delay := retryDelay(attempt, retryCfg) if delay > 0 { - if err := sleepWithContext(ctx, delay); err != nil { + if err := sleepWithContext(streamCtx, delay); err != nil { return prevResult, true // aborted } } @@ -903,7 +939,7 @@ func (a *Agent) runMidStreamRetry( retryCfgCopy.Messages = prevResult.Messages retryOpts := a.buildGenerateOptions(retryCfgCopy, sdkTools, prepareStep) - retryResult, retryErr := a.client.StreamText(ctx, retryOpts...) + retryResult, retryErr := a.client.StreamText(streamCtx, retryOpts...) if retryErr != nil { a.logger.Warn("mid-stream retry failed to start", slog.Int("attempt", attempt+1), @@ -917,9 +953,13 @@ func (a *Agent) runMidStreamRetry( // Drain the retry stream into the main event loop aborted := false for retryPart := range retryResult.Stream { + if streamCtx.Err() != nil { + aborted = true + break + } switch rp := retryPart.(type) { case *sdk.TextStartPart: - if !sendEvent(ctx, ch, StreamEvent{Type: EventTextStart}) { + if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextStart}) { aborted = true } case *sdk.TextDeltaPart: @@ -927,7 +967,7 @@ func (a *Agent) runMidStreamRetry( if textLoopProbeBuffer != nil { textLoopProbeBuffer.Push(rp.Text) } - if !sendEvent(ctx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) { + if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) { aborted = true } allText.WriteString(rp.Text) @@ -937,14 +977,14 @@ func (a *Agent) runMidStreamRetry( textLoopProbeBuffer.Flush() } stepNumber++ - if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) { + if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextEnd}) { aborted = true } case *sdk.ToolInputStartPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } - if !sendEvent(ctx, ch, StreamEvent{ + if !sendEvent(sendCtx, ch, StreamEvent{ Type: EventToolCallStart, ToolName: rp.ToolName, ToolCallID: rp.ID, @@ -955,7 +995,7 @@ func (a *Agent) runMidStreamRetry( if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } - if !sendEvent(ctx, ch, StreamEvent{ + if !sendEvent(sendCtx, ch, StreamEvent{ Type: EventToolCallStart, ToolName: rp.ToolName, ToolCallID: rp.ToolCallID, @@ -964,14 +1004,15 @@ func (a *Agent) runMidStreamRetry( aborted = true } case *sdk.StreamToolResultPart: + shouldAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID) stepNumber++ - if !sendEvent(ctx, ch, StreamEvent{ + if !sendEvent(sendCtx, ch, StreamEvent{ Type: EventToolCallEnd, ToolName: rp.ToolName, ToolCallID: rp.ToolCallID, Input: rp.Input, Result: rp.Output, - }) || !sendEvent(ctx, ch, StreamEvent{ + }) || !sendEvent(sendCtx, ch, StreamEvent{ Type: EventProgress, StepNumber: stepNumber, ToolName: rp.ToolName, @@ -979,8 +1020,15 @@ func (a *Agent) runMidStreamRetry( }) { aborted = true } + if shouldAbort { + a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID)) + cancel(ErrToolLoopDetected) + aborted = true + } case *sdk.StreamToolErrorPart: - if !sendEvent(ctx, ch, StreamEvent{ + tookLoopAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID) + shouldAbort := errors.Is(rp.Error, ErrToolLoopDetected) || tookLoopAbort + if !sendEvent(sendCtx, ch, StreamEvent{ Type: EventToolCallEnd, ToolName: rp.ToolName, ToolCallID: rp.ToolCallID, @@ -988,8 +1036,13 @@ func (a *Agent) runMidStreamRetry( }) { aborted = true } + if shouldAbort { + a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID)) + cancel(ErrToolLoopDetected) + aborted = true + } case *sdk.ErrorPart: - sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()}) + sendEvent(sendCtx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()}) aborted = true case *sdk.AbortPart: aborted = true @@ -1000,7 +1053,11 @@ func (a *Agent) runMidStreamRetry( break } } - return retryResult, aborted + if aborted { + for range retryResult.Stream { + } + } + return retryResult, aborted || detectGenerateLoopAbort(streamCtx, streamCtx.Err()) != nil } // All retry attempts failed return prevResult, true diff --git a/internal/agent/background/manager.go b/internal/agent/background/manager.go index e80472cf..2764d0ad 100644 --- a/internal/agent/background/manager.go +++ b/internal/agent/background/manager.go @@ -488,17 +488,25 @@ func (m *Manager) RunningTasksSummary(botID, sessionID string) string { defer m.mu.Unlock() var lines []string for _, t := range m.tasks { - if t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning { - desc := t.Description - if desc == "" { - desc = truncate(t.Command, 80) - } - lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)", - t.ID, desc, - time.Since(t.StartedAt).Round(time.Second), - t.OutputFile, - )) + t.mu.Lock() + matches := t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning + id := t.ID + desc := t.Description + command := t.Command + startedAt := t.StartedAt + outputFile := t.OutputFile + t.mu.Unlock() + if !matches { + continue } + if desc == "" { + desc = truncate(command, 80) + } + lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)", + id, desc, + time.Since(startedAt).Round(time.Second), + outputFile, + )) } if len(lines) == 0 { return "" diff --git a/internal/agent/generate_loop_test.go b/internal/agent/generate_loop_test.go index e3bae6e1..416c1249 100644 --- a/internal/agent/generate_loop_test.go +++ b/internal/agent/generate_loop_test.go @@ -3,7 +3,10 @@ package agent import ( "context" "errors" + "strings" + "sync/atomic" "testing" + "time" "github.com/google/jsonschema-go/jsonschema" sdk "github.com/memohai/twilight-ai/sdk" @@ -19,10 +22,76 @@ func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([ return p.tools, nil } +type atomicMockProvider struct { + calls atomic.Int32 + handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) + stream func(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) +} + +func (*atomicMockProvider) Name() string { + return "mock" +} + +func (*atomicMockProvider) ListModels(context.Context) ([]sdk.Model, error) { + return nil, nil +} + +func (*atomicMockProvider) Test(context.Context) *sdk.ProviderTestResult { + return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"} +} + +func (*atomicMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) { + return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil +} + +func (m *atomicMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + call := int(m.calls.Add(1)) + return m.handler(call, params) +} + +func (m *atomicMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) { + if m.stream != nil { + return m.stream(ctx, params) + } + + result, err := m.DoGenerate(ctx, params) + if err != nil { + return nil, err + } + ch := make(chan sdk.StreamPart, 8) + go func() { + defer close(ch) + ch <- &sdk.StartPart{} + ch <- &sdk.StartStepPart{} + if result.Text != "" { + ch <- &sdk.TextStartPart{ID: "mock"} + ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text} + ch <- &sdk.TextEndPart{ID: "mock"} + } + for _, tc := range result.ToolCalls { + ch <- &sdk.StreamToolCallPart{ + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + Input: tc.Input, + } + } + ch <- &sdk.FinishStepPart{ + FinishReason: result.FinishReason, + Usage: result.Usage, + Response: result.Response, + } + ch <- &sdk.FinishPart{ + FinishReason: result.FinishReason, + TotalUsage: result.Usage, + } + }() + return &sdk.StreamResult{Stream: ch}, nil +} + func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) { t.Parallel() - modelProvider := &agentReadMediaMockProvider{ + modelProvider := &atomicMockProvider{ handler: func(_ int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) { return &sdk.GenerateResult{ FinishReason: sdk.FinishReasonToolCalls, @@ -58,8 +127,8 @@ func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) { if !errors.Is(err, ErrToolLoopDetected) { t.Fatalf("expected ErrToolLoopDetected, got %v", err) } - if modelProvider.calls >= 20 { - t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls) + if modelProvider.calls.Load() >= 20 { + t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls.Load()) } } @@ -67,7 +136,7 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) { t.Parallel() repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection" - modelProvider := &agentReadMediaMockProvider{ + modelProvider := &atomicMockProvider{ handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) { return &sdk.GenerateResult{ Text: repeatedText, @@ -104,8 +173,8 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) { if !errors.Is(err, ErrTextLoopDetected) { t.Fatalf("expected ErrTextLoopDetected, got %v", err) } - if modelProvider.calls >= 10 { - t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls) + if modelProvider.calls.Load() >= 10 { + t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls.Load()) } } @@ -113,7 +182,7 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) { t.Parallel() repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection" - modelProvider := &agentReadMediaMockProvider{ + modelProvider := &atomicMockProvider{ handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) { finishReason := sdk.FinishReasonToolCalls var toolCalls []sdk.ToolCall @@ -157,7 +226,314 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) { if !errors.Is(err, ErrTextLoopDetected) { t.Fatalf("expected ErrTextLoopDetected, got %v", err) } - if modelProvider.calls != 4 { - t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls) + if modelProvider.calls.Load() != 4 { + t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls.Load()) + } +} + +func TestAgentStreamStopsOnToolLoopAbort(t *testing.T) { + t.Parallel() + + modelProvider := &atomicMockProvider{ + handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) { + if call >= 20 { + return &sdk.GenerateResult{ + Text: "unexpected-final-step", + FinishReason: sdk.FinishReasonStop, + }, nil + } + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-stream", + ToolName: "loop_tool", + Input: map[string]any{"query": "same"}, + }}, + }, nil + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + staticToolProvider{ + tools: []sdk.Tool{{ + Name: "loop_tool", + Parameters: &jsonschema.Schema{Type: "object"}, + Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) { + return map[string]any{"ok": true}, nil + }, + }}, + }, + }) + + var terminal StreamEvent + for event := range a.Stream(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock-model", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("loop stream")}, + SupportsToolCall: true, + Identity: SessionContext{BotID: "bot-1"}, + LoopDetection: LoopDetectionConfig{Enabled: true}, + }) { + if event.IsTerminal() { + terminal = event + } + } + + if terminal.Type != EventAgentAbort { + t.Fatalf("expected EventAgentAbort, got %q", terminal.Type) + } +} + +func TestAgentStreamMarksTerminalTextLoopAsAbort(t *testing.T) { + t.Parallel() + + repeatedChunk := strings.Repeat("abcd", 64) + var observedCancel atomic.Bool + modelProvider := &atomicMockProvider{ + stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) { + ch := make(chan sdk.StreamPart, 16) + go func() { + defer close(ch) + send := func(part sdk.StreamPart) bool { + select { + case <-ctx.Done(): + observedCancel.Store(true) + return false + case ch <- part: + return true + } + } + if !send(&sdk.StartPart{}) { + return + } + if !send(&sdk.StartStepPart{}) { + return + } + if !send(&sdk.TextStartPart{ID: "mock"}) { + return + } + for i := 0; i < 4; i++ { + if !send(&sdk.TextDeltaPart{ID: "mock", Text: repeatedChunk}) { + return + } + } + select { + case <-ctx.Done(): + observedCancel.Store(true) + return + case <-time.After(50 * time.Millisecond): + } + if !send(&sdk.TextEndPart{ID: "mock"}) { + return + } + if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) { + return + } + _ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop}) + }() + return &sdk.StreamResult{Stream: ch}, nil + }, + } + + a := New(Deps{}) + + var terminal StreamEvent + for event := range a.Stream(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock-model", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("loop stream text")}, + SupportsToolCall: true, + Identity: SessionContext{BotID: "bot-1"}, + LoopDetection: LoopDetectionConfig{Enabled: true}, + }) { + if event.IsTerminal() { + terminal = event + } + } + + if !observedCancel.Load() { + t.Fatal("expected stream provider to observe context cancellation from text-loop abort") + } + if terminal.Type != EventAgentAbort { + t.Fatalf("expected EventAgentAbort, got %q", terminal.Type) + } +} + +func TestAgentStreamMarksRetryTextLoopAsAbort(t *testing.T) { + t.Parallel() + + repeatedChunk := strings.Repeat("abcd", 64) + var streamCalls atomic.Int32 + var observedCancel atomic.Bool + modelProvider := &atomicMockProvider{ + stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) { + call := streamCalls.Add(1) + ch := make(chan sdk.StreamPart, 16) + go func() { + defer close(ch) + send := func(part sdk.StreamPart) bool { + select { + case <-ctx.Done(): + observedCancel.Store(true) + return false + case ch <- part: + return true + } + } + + if !send(&sdk.StartPart{}) { + return + } + if !send(&sdk.StartStepPart{}) { + return + } + + if call == 1 { + _ = send(&sdk.ErrorPart{Error: errors.New("api error 500")}) + return + } + + if !send(&sdk.TextStartPart{ID: "mock-retry"}) { + return + } + for i := 0; i < 4; i++ { + if !send(&sdk.TextDeltaPart{ID: "mock-retry", Text: repeatedChunk}) { + return + } + } + select { + case <-ctx.Done(): + observedCancel.Store(true) + return + case <-time.After(50 * time.Millisecond): + } + if !send(&sdk.TextEndPart{ID: "mock-retry"}) { + return + } + if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) { + return + } + _ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop}) + }() + return &sdk.StreamResult{Stream: ch}, nil + }, + } + + a := New(Deps{}) + + var terminal StreamEvent + for event := range a.Stream(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock-model", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("loop stream retry text")}, + SupportsToolCall: true, + Identity: SessionContext{BotID: "bot-1"}, + LoopDetection: LoopDetectionConfig{Enabled: true}, + }) { + if event.IsTerminal() { + terminal = event + } + } + + if streamCalls.Load() != 2 { + t.Fatalf("expected one retry stream attempt, got %d stream calls", streamCalls.Load()) + } + if !observedCancel.Load() { + t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort") + } + if terminal.Type != EventAgentAbort { + t.Fatalf("expected EventAgentAbort after retry text-loop abort, got %q", terminal.Type) + } +} + +func TestRunMidStreamRetryMarksTextLoopCancellationAsAborted(t *testing.T) { + t.Parallel() + + repeatedChunk := strings.Repeat("abcd", 64) + var observedCancel atomic.Bool + modelProvider := &atomicMockProvider{ + stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) { + ch := make(chan sdk.StreamPart) + go func() { + defer close(ch) + send := func(part sdk.StreamPart) bool { + select { + case <-ctx.Done(): + observedCancel.Store(true) + return false + case ch <- part: + return true + } + } + + if !send(&sdk.StartPart{}) { + return + } + if !send(&sdk.StartStepPart{}) { + return + } + if !send(&sdk.TextStartPart{ID: "mock-retry-only"}) { + return + } + for i := 0; i < 4; i++ { + if !send(&sdk.TextDeltaPart{ID: "mock-retry-only", Text: repeatedChunk}) { + return + } + } + select { + case <-ctx.Done(): + observedCancel.Store(true) + return + case <-time.After(200 * time.Millisecond): + t.Error("expected text-loop detection to cancel retry stream before any extra part was sent") + return + } + }() + return &sdk.StreamResult{Stream: ch}, nil + }, + } + + a := New(Deps{}) + streamCtx, cancel := context.WithCancelCause(context.Background()) + defer cancel(nil) + + textLoopGuard := NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{}) + textLoopProbeBuffer := NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) { + result := textLoopGuard.Inspect(text) + if result.Abort { + cancel(ErrTextLoopDetected) + } + }) + + retryResult, aborted := a.runMidStreamRetry( + context.Background(), + streamCtx, + cancel, + newToolAbortRegistry(), + make(chan StreamEvent, 32), + RunConfig{ + Model: &sdk.Model{ID: "mock-model", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("retry text loop")}, + Identity: SessionContext{BotID: "bot-1"}, + LoopDetection: LoopDetectionConfig{Enabled: true}, + }, + nil, + nil, + &sdk.StreamResult{Messages: []sdk.Message{sdk.UserMessage("previous step")}}, + 0, + "api error 500", + &strings.Builder{}, + textLoopProbeBuffer, + ) + + if retryResult == nil { + t.Fatal("expected retry result") + } + if !observedCancel.Load() { + t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort") + } + if !errors.Is(context.Cause(streamCtx), ErrTextLoopDetected) { + t.Fatalf("expected stream context cause ErrTextLoopDetected, got %v", context.Cause(streamCtx)) + } + if !aborted { + t.Fatal("expected runMidStreamRetry to report aborted when retry stream hit text-loop cancellation") } } diff --git a/internal/agent/guard_state.go b/internal/agent/guard_state.go index f54fe902..56e61864 100644 --- a/internal/agent/guard_state.go +++ b/internal/agent/guard_state.go @@ -1,6 +1,10 @@ package agent -import "sync" +import ( + "sync" + + "github.com/memohai/memoh/internal/agent/tools" +) type toolAbortRegistry struct { mu sync.Mutex @@ -46,3 +50,65 @@ func (r *toolAbortRegistry) Any() bool { defer r.mu.Unlock() return len(r.ids) > 0 } + +type toolEventCollector struct { + mu sync.Mutex + closed bool + events []tools.ToolStreamEvent +} + +func newToolEventCollector() *toolEventCollector { + return &toolEventCollector{} +} + +func (c *toolEventCollector) Add(evt tools.ToolStreamEvent) bool { + if c == nil { + return false + } + + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return false + } + c.events = append(c.events, evt) + return true +} + +func (c *toolEventCollector) Close() { + if c == nil { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true +} + +func (c *toolEventCollector) CloseAndSnapshot() []tools.ToolStreamEvent { + if c == nil { + return nil + } + + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + snapshot := make([]tools.ToolStreamEvent, len(c.events)) + copy(snapshot, c.events) + return snapshot +} + +// Snapshot returns a copy of collected events without closing the collector. +// Callers that own the collector lifetime should still invoke Close (or +// CloseAndSnapshot) so late emits are rejected. +func (c *toolEventCollector) Snapshot() []tools.ToolStreamEvent { + if c == nil { + return nil + } + + c.mu.Lock() + defer c.mu.Unlock() + out := make([]tools.ToolStreamEvent, len(c.events)) + copy(out, c.events) + return out +} diff --git a/internal/agent/guard_state_test.go b/internal/agent/guard_state_test.go index f4b6b5f8..36e9317b 100644 --- a/internal/agent/guard_state_test.go +++ b/internal/agent/guard_state_test.go @@ -3,7 +3,10 @@ package agent import ( "fmt" "sync" + "sync/atomic" "testing" + + "github.com/memohai/memoh/internal/agent/tools" ) func TestToolAbortRegistryConcurrentAccess(t *testing.T) { @@ -30,3 +33,33 @@ func TestToolAbortRegistryConcurrentAccess(t *testing.T) { t.Fatal("expected registry to be empty") } } + +func TestToolEventCollectorCloseIgnoresLateAdds(t *testing.T) { + collector := newToolEventCollector() + + for i := 0; i < 5; i++ { + if !collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) { + t.Fatalf("unexpected add failure before close, i=%d", i) + } + } + snapshot := collector.CloseAndSnapshot() + if len(snapshot) != 5 { + t.Fatalf("expected snapshot len 5, got %d", len(snapshot)) + } + + var wg sync.WaitGroup + var postCloseAdds atomic.Int32 + for i := 0; i < 16; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) { + postCloseAdds.Add(1) + } + }() + } + wg.Wait() + if postCloseAdds.Load() != 0 { + t.Fatalf("expected 0 successful adds after close, got %d", postCloseAdds.Load()) + } +}