Files
Memoh/internal/agent/generate_loop_test.go
Fodesu db777b98ac fix(agent): stream loop abort, mid-stream retry parity, collector cleanup (#376)
* fix(agent): align stream retry abort and event collection

* fix(agent): cancel stream on loop detect, harden retry and tool events

* fix(agent): drain previous stream before retry

* fix(lint): ctx ci lint

---------

Co-authored-by: 晨苒 <16112591+chen-ran@users.noreply.github.com>
2026-04-18 03:19:58 +08:00

540 lines
15 KiB
Go

package agent
import (
"context"
"errors"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/jsonschema-go/jsonschema"
sdk "github.com/memohai/twilight-ai/sdk"
agenttools "github.com/memohai/memoh/internal/agent/tools"
)
type staticToolProvider struct {
tools []sdk.Tool
}
func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([]sdk.Tool, error) {
return p.tools, nil
}
type atomicMockProvider struct {
calls atomic.Int32
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
stream func(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error)
}
func (*atomicMockProvider) Name() string {
return "mock"
}
func (*atomicMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
return nil, nil
}
func (*atomicMockProvider) Test(context.Context) *sdk.ProviderTestResult {
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
}
func (*atomicMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
}
func (m *atomicMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
call := int(m.calls.Add(1))
return m.handler(call, params)
}
func (m *atomicMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
if m.stream != nil {
return m.stream(ctx, params)
}
result, err := m.DoGenerate(ctx, params)
if err != nil {
return nil, err
}
ch := make(chan sdk.StreamPart, 8)
go func() {
defer close(ch)
ch <- &sdk.StartPart{}
ch <- &sdk.StartStepPart{}
if result.Text != "" {
ch <- &sdk.TextStartPart{ID: "mock"}
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
ch <- &sdk.TextEndPart{ID: "mock"}
}
for _, tc := range result.ToolCalls {
ch <- &sdk.StreamToolCallPart{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Input: tc.Input,
}
}
ch <- &sdk.FinishStepPart{
FinishReason: result.FinishReason,
Usage: result.Usage,
Response: result.Response,
}
ch <- &sdk.FinishPart{
FinishReason: result.FinishReason,
TotalUsage: result.Usage,
}
}()
return &sdk.StreamResult{Stream: ch}, nil
}
func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
t.Parallel()
modelProvider := &atomicMockProvider{
handler: func(_ int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-same",
ToolName: "loop_tool",
Input: map[string]any{"query": "same"},
}},
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "loop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
})
if !errors.Is(err, ErrToolLoopDetected) {
t.Fatalf("expected ErrToolLoopDetected, got %v", err)
}
if modelProvider.calls.Load() >= 20 {
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls.Load())
}
}
func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
t.Parallel()
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
modelProvider := &atomicMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
return &sdk.GenerateResult{
Text: repeatedText,
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-text",
ToolName: "noop_tool",
Input: map[string]any{"step": call},
}},
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "noop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop text")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
})
if !errors.Is(err, ErrTextLoopDetected) {
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
}
if modelProvider.calls.Load() >= 10 {
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls.Load())
}
}
func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
t.Parallel()
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
modelProvider := &atomicMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
finishReason := sdk.FinishReasonToolCalls
var toolCalls []sdk.ToolCall
if call < 4 {
toolCalls = []sdk.ToolCall{{
ToolCallID: "call-terminal",
ToolName: "noop_tool",
Input: map[string]any{"step": call},
}}
} else {
finishReason = sdk.FinishReasonStop
}
return &sdk.GenerateResult{
Text: repeatedText,
FinishReason: finishReason,
ToolCalls: toolCalls,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "noop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop text terminal")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
})
if !errors.Is(err, ErrTextLoopDetected) {
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
}
if modelProvider.calls.Load() != 4 {
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls.Load())
}
}
func TestAgentStreamStopsOnToolLoopAbort(t *testing.T) {
t.Parallel()
modelProvider := &atomicMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call >= 20 {
return &sdk.GenerateResult{
Text: "unexpected-final-step",
FinishReason: sdk.FinishReasonStop,
}, nil
}
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-stream",
ToolName: "loop_tool",
Input: map[string]any{"query": "same"},
}},
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
staticToolProvider{
tools: []sdk.Tool{{
Name: "loop_tool",
Parameters: &jsonschema.Schema{Type: "object"},
Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) {
return map[string]any{"ok": true}, nil
},
}},
},
})
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop stream")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
}) {
if event.IsTerminal() {
terminal = event
}
}
if terminal.Type != EventAgentAbort {
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
}
}
func TestAgentStreamMarksTerminalTextLoopAsAbort(t *testing.T) {
t.Parallel()
repeatedChunk := strings.Repeat("abcd", 64)
var observedCancel atomic.Bool
modelProvider := &atomicMockProvider{
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
ch := make(chan sdk.StreamPart, 16)
go func() {
defer close(ch)
send := func(part sdk.StreamPart) bool {
select {
case <-ctx.Done():
observedCancel.Store(true)
return false
case ch <- part:
return true
}
}
if !send(&sdk.StartPart{}) {
return
}
if !send(&sdk.StartStepPart{}) {
return
}
if !send(&sdk.TextStartPart{ID: "mock"}) {
return
}
for i := 0; i < 4; i++ {
if !send(&sdk.TextDeltaPart{ID: "mock", Text: repeatedChunk}) {
return
}
}
select {
case <-ctx.Done():
observedCancel.Store(true)
return
case <-time.After(50 * time.Millisecond):
}
if !send(&sdk.TextEndPart{ID: "mock"}) {
return
}
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
return
}
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
}()
return &sdk.StreamResult{Stream: ch}, nil
},
}
a := New(Deps{})
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop stream text")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
}) {
if event.IsTerminal() {
terminal = event
}
}
if !observedCancel.Load() {
t.Fatal("expected stream provider to observe context cancellation from text-loop abort")
}
if terminal.Type != EventAgentAbort {
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
}
}
func TestAgentStreamMarksRetryTextLoopAsAbort(t *testing.T) {
t.Parallel()
repeatedChunk := strings.Repeat("abcd", 64)
var streamCalls atomic.Int32
var observedCancel atomic.Bool
modelProvider := &atomicMockProvider{
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
call := streamCalls.Add(1)
ch := make(chan sdk.StreamPart, 16)
go func() {
defer close(ch)
send := func(part sdk.StreamPart) bool {
select {
case <-ctx.Done():
observedCancel.Store(true)
return false
case ch <- part:
return true
}
}
if !send(&sdk.StartPart{}) {
return
}
if !send(&sdk.StartStepPart{}) {
return
}
if call == 1 {
_ = send(&sdk.ErrorPart{Error: errors.New("api error 500")})
return
}
if !send(&sdk.TextStartPart{ID: "mock-retry"}) {
return
}
for i := 0; i < 4; i++ {
if !send(&sdk.TextDeltaPart{ID: "mock-retry", Text: repeatedChunk}) {
return
}
}
select {
case <-ctx.Done():
observedCancel.Store(true)
return
case <-time.After(50 * time.Millisecond):
}
if !send(&sdk.TextEndPart{ID: "mock-retry"}) {
return
}
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
return
}
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
}()
return &sdk.StreamResult{Stream: ch}, nil
},
}
a := New(Deps{})
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("loop stream retry text")},
SupportsToolCall: true,
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
}) {
if event.IsTerminal() {
terminal = event
}
}
if streamCalls.Load() != 2 {
t.Fatalf("expected one retry stream attempt, got %d stream calls", streamCalls.Load())
}
if !observedCancel.Load() {
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
}
if terminal.Type != EventAgentAbort {
t.Fatalf("expected EventAgentAbort after retry text-loop abort, got %q", terminal.Type)
}
}
func TestRunMidStreamRetryMarksTextLoopCancellationAsAborted(t *testing.T) {
t.Parallel()
repeatedChunk := strings.Repeat("abcd", 64)
var observedCancel atomic.Bool
modelProvider := &atomicMockProvider{
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
ch := make(chan sdk.StreamPart)
go func() {
defer close(ch)
send := func(part sdk.StreamPart) bool {
select {
case <-ctx.Done():
observedCancel.Store(true)
return false
case ch <- part:
return true
}
}
if !send(&sdk.StartPart{}) {
return
}
if !send(&sdk.StartStepPart{}) {
return
}
if !send(&sdk.TextStartPart{ID: "mock-retry-only"}) {
return
}
for i := 0; i < 4; i++ {
if !send(&sdk.TextDeltaPart{ID: "mock-retry-only", Text: repeatedChunk}) {
return
}
}
select {
case <-ctx.Done():
observedCancel.Store(true)
return
case <-time.After(200 * time.Millisecond):
t.Error("expected text-loop detection to cancel retry stream before any extra part was sent")
return
}
}()
return &sdk.StreamResult{Stream: ch}, nil
},
}
a := New(Deps{})
streamCtx, cancel := context.WithCancelCause(context.Background())
defer cancel(nil)
textLoopGuard := NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
textLoopProbeBuffer := NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
result := textLoopGuard.Inspect(text)
if result.Abort {
cancel(ErrTextLoopDetected)
}
})
retryResult, aborted := a.runMidStreamRetry(
context.Background(),
streamCtx,
cancel,
newToolAbortRegistry(),
make(chan StreamEvent, 32),
RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("retry text loop")},
Identity: SessionContext{BotID: "bot-1"},
LoopDetection: LoopDetectionConfig{Enabled: true},
},
nil,
nil,
&sdk.StreamResult{Messages: []sdk.Message{sdk.UserMessage("previous step")}},
0,
"api error 500",
&strings.Builder{},
textLoopProbeBuffer,
)
if retryResult == nil {
t.Fatal("expected retry result")
}
if !observedCancel.Load() {
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
}
if !errors.Is(context.Cause(streamCtx), ErrTextLoopDetected) {
t.Fatalf("expected stream context cause ErrTextLoopDetected, got %v", context.Cause(streamCtx))
}
if !aborted {
t.Fatal("expected runMidStreamRetry to report aborted when retry stream hit text-loop cancellation")
}
}