fix(agent): stream loop abort, mid-stream retry parity, collector cleanup (#376)

* fix(agent): align stream retry abort and event collection

* fix(agent): cancel stream on loop detect, harden retry and tool events

* fix(agent): drain previous stream before retry

* fix(lint): ctx ci lint

---------

Co-authored-by: 晨苒 <16112591+chen-ran@users.noreply.github.com>
This commit is contained in:
Fodesu
2026-04-18 03:19:58 +08:00
committed by GitHub
parent b534248e19
commit db777b98ac
5 changed files with 587 additions and 47 deletions
+84 -27
View File
@@ -78,6 +78,9 @@ func sendEvent(ctx context.Context, ch chan<- StreamEvent, evt StreamEvent) bool
}
func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) {
streamCtx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
// Stream emitter: tools targeting the current conversation push
// side-effect events (attachments, reactions, speech) directly here.
// Uses sendEvent to avoid goroutine leaks when the consumer stops reading.
@@ -88,7 +91,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
var sdkTools []sdk.Tool
if cfg.SupportsToolCall {
var err error
sdkTools, err = a.assembleTools(ctx, cfg, streamEmitter)
sdkTools, err = a.assembleTools(streamCtx, cfg, streamEmitter)
if err != nil {
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)})
return
@@ -96,6 +99,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools)
aborted := false
// Loop detection setup
var textLoopGuard *TextLoopGuard
var textLoopProbeBuffer *TextLoopProbeBuffer
@@ -107,6 +112,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
result := textLoopGuard.Inspect(text)
if result.Abort {
a.logger.Warn("text loop detected, will abort")
aborted = true
cancel(ErrTextLoopDetected)
}
})
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
@@ -198,7 +205,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
var streamResult *sdk.StreamResult
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
var err error
streamResult, err = a.client.StreamText(ctx, opts...)
streamResult, err = a.client.StreamText(streamCtx, opts...)
if err == nil {
break
}
@@ -225,7 +232,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
delay := retryDelay(attempt, retryCfg)
if delay > 0 {
if err := sleepWithContext(ctx, delay); err != nil {
if err := sleepWithContext(streamCtx, delay); err != nil {
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: context cancelled during retry: %v", err)})
return
}
@@ -235,11 +242,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
sendEvent(ctx, ch, StreamEvent{Type: EventAgentStart})
var allText strings.Builder
aborted := false
stepNumber := 0
for part := range streamResult.Stream {
if ctx.Err() != nil {
if streamCtx.Err() != nil {
aborted = true
break
}
@@ -319,11 +325,13 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
case *sdk.ToolProgressPart:
ch <- StreamEvent{
if !sendEvent(ctx, ch, StreamEvent{
Type: EventToolCallProgress,
ToolName: p.ToolName,
ToolCallID: p.ToolCallID,
Progress: p.Content,
}) {
aborted = true
}
case *sdk.StreamToolResultPart:
@@ -345,10 +353,14 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
if shouldAbort {
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
cancel(ErrToolLoopDetected)
aborted = true
}
case *sdk.StreamToolErrorPart:
// Take before errors.Is so registry IDs from the loop guard are always cleared.
tookLoopAbort := toolLoopAbortCallIDs.Take(p.ToolCallID)
shouldAbort := errors.Is(p.Error, ErrToolLoopDetected) || tookLoopAbort
if !sendEvent(ctx, ch, StreamEvent{
Type: EventToolCallEnd,
ToolName: p.ToolName,
@@ -357,6 +369,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}) {
aborted = true
}
if shouldAbort {
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
cancel(ErrToolLoopDetected)
aborted = true
}
case *sdk.StreamFilePart:
mediaType := p.File.MediaType
@@ -384,7 +401,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
// no work has been completed yet and retrying from the start is safe.
if isRetryableStreamError(p.Error) {
streamResult, aborted = a.runMidStreamRetry(
ctx, ch, cfg, sdkTools, prepareStep, streamResult,
ctx, streamCtx, cancel, toolLoopAbortCallIDs,
ch, cfg, sdkTools, prepareStep, streamResult,
stepNumber, errMsg, &allText, textLoopProbeBuffer,
)
} else {
@@ -403,6 +421,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
}
if aborted {
for range streamResult.Stream {
}
}
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
}
@@ -452,12 +475,10 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
loopAbort := newLoopAbortState()
// Collecting emitter: tools push side-effect events here during generation.
var collected []tools.ToolStreamEvent
var collectedMu sync.Mutex
collected := newToolEventCollector()
defer collected.Close()
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
collectedMu.Lock()
defer collectedMu.Unlock()
collected = append(collected, evt)
collected.Add(evt)
})
var sdkTools []sdk.Tool
@@ -536,10 +557,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
}
// Drain collected tool-emitted side effects into the result.
collectedEvents := collected.CloseAndSnapshot()
var attachments []FileAttachment
var reactions []ReactionItem
var speeches []SpeechItem
for _, evt := range collected {
for _, evt := range collectedEvents {
switch evt.Type {
case tools.StreamEventAttachment:
for _, a := range evt.Attachments {
@@ -860,8 +882,15 @@ func pruneOldToolResults(p *sdk.GenerateParams, keepSteps, threshold int) *sdk.G
// runMidStreamRetry attempts to continue the agent stream after a retryable
// mid-stream error. It re-invokes StreamText with the accumulated messages
// and drains the new stream into the same output channel.
//
// sendCtx is used for sendEvent so consumer disconnect (parent ctx) still
// controls channel back-pressure; streamCtx is passed to the SDK for the same
// cancellation semantics as the main stream (including loop-detect cancel).
func (a *Agent) runMidStreamRetry(
ctx context.Context,
sendCtx context.Context,
streamCtx context.Context,
cancel context.CancelCauseFunc,
toolLoopAbortCallIDs *toolAbortRegistry,
ch chan<- StreamEvent,
cfg RunConfig,
sdkTools []sdk.Tool,
@@ -872,6 +901,13 @@ func (a *Agent) runMidStreamRetry(
allText *strings.Builder,
textLoopProbeBuffer *TextLoopProbeBuffer,
) (*sdk.StreamResult, bool) {
// Drain the previous stream before reading prevResult.Messages.
// This avoids racing with the SDK's final StreamResult write.
if prevResult.Stream != nil {
for range prevResult.Stream {
}
}
retryCfg := DefaultRetryConfig()
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
a.logger.Warn("mid-stream error, retrying",
@@ -880,7 +916,7 @@ func (a *Agent) runMidStreamRetry(
slog.Int("max_attempts", retryCfg.MaxAttempts),
slog.String("error", errMsg),
)
if !sendEvent(ctx, ch, StreamEvent{
if !sendEvent(sendCtx, ch, StreamEvent{
Type: EventRetry,
Attempt: attempt + 1,
MaxAttempt: retryCfg.MaxAttempts,
@@ -891,7 +927,7 @@ func (a *Agent) runMidStreamRetry(
delay := retryDelay(attempt, retryCfg)
if delay > 0 {
if err := sleepWithContext(ctx, delay); err != nil {
if err := sleepWithContext(streamCtx, delay); err != nil {
return prevResult, true // aborted
}
}
@@ -903,7 +939,7 @@ func (a *Agent) runMidStreamRetry(
retryCfgCopy.Messages = prevResult.Messages
retryOpts := a.buildGenerateOptions(retryCfgCopy, sdkTools, prepareStep)
retryResult, retryErr := a.client.StreamText(ctx, retryOpts...)
retryResult, retryErr := a.client.StreamText(streamCtx, retryOpts...)
if retryErr != nil {
a.logger.Warn("mid-stream retry failed to start",
slog.Int("attempt", attempt+1),
@@ -917,9 +953,13 @@ func (a *Agent) runMidStreamRetry(
// Drain the retry stream into the main event loop
aborted := false
for retryPart := range retryResult.Stream {
if streamCtx.Err() != nil {
aborted = true
break
}
switch rp := retryPart.(type) {
case *sdk.TextStartPart:
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextStart}) {
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextStart}) {
aborted = true
}
case *sdk.TextDeltaPart:
@@ -927,7 +967,7 @@ func (a *Agent) runMidStreamRetry(
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Push(rp.Text)
}
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) {
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) {
aborted = true
}
allText.WriteString(rp.Text)
@@ -937,14 +977,14 @@ func (a *Agent) runMidStreamRetry(
textLoopProbeBuffer.Flush()
}
stepNumber++
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) {
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextEnd}) {
aborted = true
}
case *sdk.ToolInputStartPart:
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
}
if !sendEvent(ctx, ch, StreamEvent{
if !sendEvent(sendCtx, ch, StreamEvent{
Type: EventToolCallStart,
ToolName: rp.ToolName,
ToolCallID: rp.ID,
@@ -955,7 +995,7 @@ func (a *Agent) runMidStreamRetry(
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
}
if !sendEvent(ctx, ch, StreamEvent{
if !sendEvent(sendCtx, ch, StreamEvent{
Type: EventToolCallStart,
ToolName: rp.ToolName,
ToolCallID: rp.ToolCallID,
@@ -964,14 +1004,15 @@ func (a *Agent) runMidStreamRetry(
aborted = true
}
case *sdk.StreamToolResultPart:
shouldAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID)
stepNumber++
if !sendEvent(ctx, ch, StreamEvent{
if !sendEvent(sendCtx, ch, StreamEvent{
Type: EventToolCallEnd,
ToolName: rp.ToolName,
ToolCallID: rp.ToolCallID,
Input: rp.Input,
Result: rp.Output,
}) || !sendEvent(ctx, ch, StreamEvent{
}) || !sendEvent(sendCtx, ch, StreamEvent{
Type: EventProgress,
StepNumber: stepNumber,
ToolName: rp.ToolName,
@@ -979,8 +1020,15 @@ func (a *Agent) runMidStreamRetry(
}) {
aborted = true
}
if shouldAbort {
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID))
cancel(ErrToolLoopDetected)
aborted = true
}
case *sdk.StreamToolErrorPart:
if !sendEvent(ctx, ch, StreamEvent{
tookLoopAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID)
shouldAbort := errors.Is(rp.Error, ErrToolLoopDetected) || tookLoopAbort
if !sendEvent(sendCtx, ch, StreamEvent{
Type: EventToolCallEnd,
ToolName: rp.ToolName,
ToolCallID: rp.ToolCallID,
@@ -988,8 +1036,13 @@ func (a *Agent) runMidStreamRetry(
}) {
aborted = true
}
if shouldAbort {
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID))
cancel(ErrToolLoopDetected)
aborted = true
}
case *sdk.ErrorPart:
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()})
sendEvent(sendCtx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()})
aborted = true
case *sdk.AbortPart:
aborted = true
@@ -1000,7 +1053,11 @@ func (a *Agent) runMidStreamRetry(
break
}
}
return retryResult, aborted
if aborted {
for range retryResult.Stream {
}
}
return retryResult, aborted || detectGenerateLoopAbort(streamCtx, streamCtx.Err()) != nil
}
// All retry attempts failed
return prevResult, true
+18 -10
View File
@@ -488,17 +488,25 @@ func (m *Manager) RunningTasksSummary(botID, sessionID string) string {
defer m.mu.Unlock()
var lines []string
for _, t := range m.tasks {
if t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning {
desc := t.Description
if desc == "" {
desc = truncate(t.Command, 80)
}
lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)",
t.ID, desc,
time.Since(t.StartedAt).Round(time.Second),
t.OutputFile,
))
t.mu.Lock()
matches := t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning
id := t.ID
desc := t.Description
command := t.Command
startedAt := t.StartedAt
outputFile := t.OutputFile
t.mu.Unlock()
if !matches {
continue
}
if desc == "" {
desc = truncate(command, 80)
}
lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)",
id, desc,
time.Since(startedAt).Round(time.Second),
outputFile,
))
}
if len(lines) == 0 {
return ""
+385 -9
View File
@@ -3,7 +3,10 @@ package agent
import (
"context"
"errors"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/jsonschema-go/jsonschema"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -19,10 +22,76 @@ func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([
return p.tools, nil
}
// atomicMockProvider is a test double for a model provider whose call counter
// can be read and updated from multiple goroutines.
type atomicMockProvider struct {
// calls counts DoGenerate invocations; it is incremented atomically.
calls atomic.Int32
// handler produces the DoGenerate result; it receives the 1-based call number.
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
// stream, when non-nil, is used by DoStream instead of replaying DoGenerate.
stream func(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error)
}
// Name reports the fixed identifier of the mock provider.
func (*atomicMockProvider) Name() string {
	const providerName = "mock"
	return providerName
}
// ListModels reports no models and never fails.
func (*atomicMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
	var models []sdk.Model
	return models, nil
}
// Test always reports the provider as healthy.
func (*atomicMockProvider) Test(context.Context) *sdk.ProviderTestResult {
	healthy := sdk.ProviderTestResult{
		Status:  sdk.ProviderStatusOK,
		Message: "ok",
	}
	return &healthy
}
// TestModel reports every model as supported and never fails.
func (*atomicMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
	supported := &sdk.ModelTestResult{
		Supported: true,
		Message:   "supported",
	}
	return supported, nil
}
// DoGenerate atomically bumps the call counter and delegates to the
// configured handler, passing the 1-based number of this invocation.
func (m *atomicMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
	return m.handler(int(m.calls.Add(1)), params)
}
// DoStream returns a stream for params.
//
// When a custom stream function is configured it is used as-is. Otherwise the
// result of DoGenerate is replayed on a buffered channel as a fixed sequence
// of parts: start, start-step, optional text start/delta/end, one part per
// tool call, finish-step and finish. An error from DoGenerate is returned
// unchanged together with a nil result.
//
// The replay goroutine stops as soon as ctx is cancelled, so a consumer that
// abandons the stream cannot leave the goroutine blocked on a full channel.
// The channel is always closed when the goroutine exits.
func (m *atomicMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
	if m.stream != nil {
		return m.stream(ctx, params)
	}
	result, err := m.DoGenerate(ctx, params)
	if err != nil {
		return nil, err
	}

	ch := make(chan sdk.StreamPart, 8)
	go func() {
		defer close(ch)

		// send delivers one part unless ctx is cancelled first.
		send := func(part sdk.StreamPart) bool {
			select {
			case <-ctx.Done():
				return false
			case ch <- part:
				return true
			}
		}

		if !send(&sdk.StartPart{}) || !send(&sdk.StartStepPart{}) {
			return
		}
		if result.Text != "" {
			if !send(&sdk.TextStartPart{ID: "mock"}) ||
				!send(&sdk.TextDeltaPart{ID: "mock", Text: result.Text}) ||
				!send(&sdk.TextEndPart{ID: "mock"}) {
				return
			}
		}
		for _, tc := range result.ToolCalls {
			if !send(&sdk.StreamToolCallPart{
				ToolCallID: tc.ToolCallID,
				ToolName:   tc.ToolName,
				Input:      tc.Input,
			}) {
				return
			}
		}
		if !send(&sdk.FinishStepPart{
			FinishReason: result.FinishReason,
			Usage:        result.Usage,
			Response:     result.Response,
		}) {
			return
		}
		// Last part: nothing follows, so the outcome of the send is irrelevant.
		_ = send(&sdk.FinishPart{
			FinishReason: result.FinishReason,
			TotalUsage:   result.Usage,
		})
	}()
	return &sdk.StreamResult{Stream: ch}, nil
}
func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
t.Parallel()
modelProvider := &agentReadMediaMockProvider{
modelProvider := &atomicMockProvider{
handler: func(_ int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
@@ -58,8 +127,8 @@ func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
if !errors.Is(err, ErrToolLoopDetected) {
t.Fatalf("expected ErrToolLoopDetected, got %v", err)
}
if modelProvider.calls >= 20 {
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls)
if modelProvider.calls.Load() >= 20 {
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls.Load())
}
}
@@ -67,7 +136,7 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
t.Parallel()
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
modelProvider := &agentReadMediaMockProvider{
modelProvider := &atomicMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
return &sdk.GenerateResult{
Text: repeatedText,
@@ -104,8 +173,8 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
if !errors.Is(err, ErrTextLoopDetected) {
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
}
if modelProvider.calls >= 10 {
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls)
if modelProvider.calls.Load() >= 10 {
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls.Load())
}
}
@@ -113,7 +182,7 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
t.Parallel()
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
modelProvider := &agentReadMediaMockProvider{
modelProvider := &atomicMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
finishReason := sdk.FinishReasonToolCalls
var toolCalls []sdk.ToolCall
@@ -157,7 +226,314 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
if !errors.Is(err, ErrTextLoopDetected) {
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
}
if modelProvider.calls != 4 {
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls)
if modelProvider.calls.Load() != 4 {
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls.Load())
}
}
func TestAgentStreamStopsOnToolLoopAbort(t *testing.T) {
t.Parallel()
modelProvider := &atomicMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call >= 20 {
return &sdk.GenerateResult{
Text: "unexpected-final-step",
FinishReason: sdk.FinishReasonStop,
}, nil
}
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-stream",
ToolName: "loop_tool",
Input: map[string]any{"query": "same"},
}},
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "loop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop stream")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
}) {
if event.IsTerminal() {
terminal = event
}
}
if terminal.Type != EventAgentAbort {
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
}
}
// TestAgentStreamMarksTerminalTextLoopAsAbort streams four identical text
// deltas from a mock provider with loop detection enabled. It then asserts
// that the provider goroutine observed cancellation of its context and that
// the last terminal event of the agent stream is EventAgentAbort.
func TestAgentStreamMarksTerminalTextLoopAsAbort(t *testing.T) {
t.Parallel()
// "abcd" repeated 64 times: one 256-byte chunk that is sent several times.
repeatedChunk := strings.Repeat("abcd", 64)
// Set by the provider goroutine whenever it sees ctx.Done().
var observedCancel atomic.Bool
modelProvider := &atomicMockProvider{
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
ch := make(chan sdk.StreamPart, 16)
go func() {
defer close(ch)
// send forwards one part unless ctx is cancelled first; a cancellation
// is recorded in observedCancel.
send := func(part sdk.StreamPart) bool {
select {
case <-ctx.Done():
observedCancel.Store(true)
return false
case ch <- part:
return true
}
}
if !send(&sdk.StartPart{}) {
return
}
if !send(&sdk.StartStepPart{}) {
return
}
if !send(&sdk.TextStartPart{ID: "mock"}) {
return
}
// Emit the same chunk four times in a row.
for i := 0; i < 4; i++ {
if !send(&sdk.TextDeltaPart{ID: "mock", Text: repeatedChunk}) {
return
}
}
// Wait up to 50ms for a cancellation before finishing the stream normally.
select {
case <-ctx.Done():
observedCancel.Store(true)
return
case <-time.After(50 * time.Millisecond):
}
if !send(&sdk.TextEndPart{ID: "mock"}) {
return
}
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
return
}
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
}()
return &sdk.StreamResult{Stream: ch}, nil
},
}
a := New(Deps{})
// Keep the most recent terminal event seen on the agent stream.
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop stream text")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
}) {
if event.IsTerminal() {
terminal = event
}
}
if !observedCancel.Load() {
t.Fatal("expected stream provider to observe context cancellation from text-loop abort")
}
if terminal.Type != EventAgentAbort {
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
}
}
// TestAgentStreamMarksRetryTextLoopAsAbort uses a mock provider whose first
// stream ends with an error part and whose later streams emit four identical
// text deltas. It asserts that exactly two streams were opened, that a
// provider goroutine observed cancellation of its context, and that the last
// terminal event of the agent stream is EventAgentAbort.
func TestAgentStreamMarksRetryTextLoopAsAbort(t *testing.T) {
t.Parallel()
// "abcd" repeated 64 times: one 256-byte chunk that is sent several times.
repeatedChunk := strings.Repeat("abcd", 64)
// Number of times the provider's stream function has been invoked.
var streamCalls atomic.Int32
// Set by a provider goroutine whenever it sees ctx.Done().
var observedCancel atomic.Bool
modelProvider := &atomicMockProvider{
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
call := streamCalls.Add(1)
ch := make(chan sdk.StreamPart, 16)
go func() {
defer close(ch)
// send forwards one part unless ctx is cancelled first; a cancellation
// is recorded in observedCancel.
send := func(part sdk.StreamPart) bool {
select {
case <-ctx.Done():
observedCancel.Store(true)
return false
case ch <- part:
return true
}
}
if !send(&sdk.StartPart{}) {
return
}
if !send(&sdk.StartStepPart{}) {
return
}
// First stream only: fail mid-stream and stop.
// NOTE(review): presumably "api error 500" is classified as retryable by
// the agent — confirm against isRetryableStreamError.
if call == 1 {
_ = send(&sdk.ErrorPart{Error: errors.New("api error 500")})
return
}
if !send(&sdk.TextStartPart{ID: "mock-retry"}) {
return
}
// Emit the same chunk four times in a row.
for i := 0; i < 4; i++ {
if !send(&sdk.TextDeltaPart{ID: "mock-retry", Text: repeatedChunk}) {
return
}
}
// Wait up to 50ms for a cancellation before finishing the stream normally.
select {
case <-ctx.Done():
observedCancel.Store(true)
return
case <-time.After(50 * time.Millisecond):
}
if !send(&sdk.TextEndPart{ID: "mock-retry"}) {
return
}
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
return
}
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
}()
return &sdk.StreamResult{Stream: ch}, nil
},
}
a := New(Deps{})
// Keep the most recent terminal event seen on the agent stream.
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop stream retry text")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
}) {
if event.IsTerminal() {
terminal = event
}
}
// One initial stream plus one retry stream.
if streamCalls.Load() != 2 {
t.Fatalf("expected one retry stream attempt, got %d stream calls", streamCalls.Load())
}
if !observedCancel.Load() {
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
}
if terminal.Type != EventAgentAbort {
t.Fatalf("expected EventAgentAbort after retry text-loop abort, got %q", terminal.Type)
}
}
// TestRunMidStreamRetryMarksTextLoopCancellationAsAborted calls
// runMidStreamRetry directly with a mock provider that emits four identical
// text deltas and a probe buffer whose callback cancels the stream context
// with ErrTextLoopDetected. It asserts that a result is returned, that the
// provider goroutine observed the cancellation, that the context cause is
// ErrTextLoopDetected, and that the call reports aborted.
func TestRunMidStreamRetryMarksTextLoopCancellationAsAborted(t *testing.T) {
t.Parallel()
// "abcd" repeated 64 times: one 256-byte chunk that is sent several times.
repeatedChunk := strings.Repeat("abcd", 64)
// Set by the provider goroutine whenever it sees ctx.Done().
var observedCancel atomic.Bool
modelProvider := &atomicMockProvider{
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
// Unbuffered: every send blocks until the consumer receives it or ctx ends.
ch := make(chan sdk.StreamPart)
go func() {
defer close(ch)
// send forwards one part unless ctx is cancelled first; a cancellation
// is recorded in observedCancel.
send := func(part sdk.StreamPart) bool {
select {
case <-ctx.Done():
observedCancel.Store(true)
return false
case ch <- part:
return true
}
}
if !send(&sdk.StartPart{}) {
return
}
if !send(&sdk.StartStepPart{}) {
return
}
if !send(&sdk.TextStartPart{ID: "mock-retry-only"}) {
return
}
// Emit the same chunk four times in a row.
for i := 0; i < 4; i++ {
if !send(&sdk.TextDeltaPart{ID: "mock-retry-only", Text: repeatedChunk}) {
return
}
}
// The context must be cancelled within 200ms; otherwise the test is
// marked as failed and the stream is closed without further parts.
select {
case <-ctx.Done():
observedCancel.Store(true)
return
case <-time.After(200 * time.Millisecond):
t.Error("expected text-loop detection to cancel retry stream before any extra part was sent")
return
}
}()
return &sdk.StreamResult{Stream: ch}, nil
},
}
a := New(Deps{})
streamCtx, cancel := context.WithCancelCause(context.Background())
defer cancel(nil)
textLoopGuard := NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
// The probe callback cancels streamCtx with ErrTextLoopDetected as soon as
// the guard asks for an abort.
textLoopProbeBuffer := NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
result := textLoopGuard.Inspect(text)
if result.Abort {
cancel(ErrTextLoopDetected)
}
})
// Send context is a fresh background context, so only streamCtx can be
// cancelled here; the event channel is buffered with 32 slots and never read.
retryResult, aborted := a.runMidStreamRetry(
context.Background(),
streamCtx,
cancel,
newToolAbortRegistry(),
make(chan StreamEvent, 32),
RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("retry text loop")},
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
},
nil,
nil,
&sdk.StreamResult{Messages: []sdk.Message{sdk.UserMessage("previous step")}},
0,
"api error 500",
&strings.Builder{},
textLoopProbeBuffer,
)
if retryResult == nil {
t.Fatal("expected retry result")
}
if !observedCancel.Load() {
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
}
if !errors.Is(context.Cause(streamCtx), ErrTextLoopDetected) {
t.Fatalf("expected stream context cause ErrTextLoopDetected, got %v", context.Cause(streamCtx))
}
if !aborted {
t.Fatal("expected runMidStreamRetry to report aborted when retry stream hit text-loop cancellation")
}
}
+67 -1
View File
@@ -1,6 +1,10 @@
package agent
import "sync"
import (
"sync"
"github.com/memohai/memoh/internal/agent/tools"
)
type toolAbortRegistry struct {
mu sync.Mutex
@@ -46,3 +50,65 @@ func (r *toolAbortRegistry) Any() bool {
defer r.mu.Unlock()
return len(r.ids) > 0
}
// toolEventCollector accumulates tool stream events behind a mutex and can be
// closed so that later additions are rejected.
type toolEventCollector struct {
// mu guards closed and events.
mu sync.Mutex
// closed is set once the collector stops accepting events.
closed bool
// events holds the accepted events in insertion order.
events []tools.ToolStreamEvent
}
// newToolEventCollector returns an empty, open collector.
func newToolEventCollector() *toolEventCollector {
	return new(toolEventCollector)
}
// Add appends evt to the collector and reports whether it was stored.
// It returns false for a nil collector or once the collector is closed.
func (c *toolEventCollector) Add(evt tools.ToolStreamEvent) bool {
	if c == nil {
		return false
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	accepted := !c.closed
	if accepted {
		c.events = append(c.events, evt)
	}
	return accepted
}
// Close marks the collector as closed so that subsequent Add calls are
// rejected. Calling it on a nil collector is a no-op.
func (c *toolEventCollector) Close() {
	if c != nil {
		c.mu.Lock()
		c.closed = true
		c.mu.Unlock()
	}
}
// CloseAndSnapshot closes the collector and, under the same lock, returns a
// copy of the events collected so far. A nil collector yields nil.
func (c *toolEventCollector) CloseAndSnapshot() []tools.ToolStreamEvent {
	if c == nil {
		return nil
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	c.closed = true
	// Copy into a fresh backing array so callers never alias c.events.
	return append(make([]tools.ToolStreamEvent, 0, len(c.events)), c.events...)
}
// Snapshot copies the events collected so far and leaves the collector open.
// Owners of the collector's lifetime should still call Close (or
// CloseAndSnapshot) so that late emits are rejected. A nil collector yields
// nil.
func (c *toolEventCollector) Snapshot() []tools.ToolStreamEvent {
	if c == nil {
		return nil
	}

	c.mu.Lock()
	// Copy into a fresh backing array so callers never alias c.events.
	events := append(make([]tools.ToolStreamEvent, 0, len(c.events)), c.events...)
	c.mu.Unlock()
	return events
}
+33
View File
@@ -3,7 +3,10 @@ package agent
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"github.com/memohai/memoh/internal/agent/tools"
)
func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
@@ -30,3 +33,33 @@ func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
t.Fatal("expected registry to be empty")
}
}
func TestToolEventCollectorCloseIgnoresLateAdds(t *testing.T) {
collector := newToolEventCollector()
for i := 0; i < 5; i++ {
if !collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) {
t.Fatalf("unexpected add failure before close, i=%d", i)
}
}
snapshot := collector.CloseAndSnapshot()
if len(snapshot) != 5 {
t.Fatalf("expected snapshot len 5, got %d", len(snapshot))
}
var wg sync.WaitGroup
var postCloseAdds atomic.Int32
for i := 0; i < 16; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) {
postCloseAdds.Add(1)
}
}()
}
wg.Wait()
if postCloseAdds.Load() != 0 {
t.Fatalf("expected 0 successful adds after close, got %d", postCloseAdds.Load())
}
}