mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(agent): stream loop abort, mid-stream retry parity, collector cleanup (#376)
* fix(agent): align stream retry abort and event collection * fix(agent): cancel stream on loop detect, harden retry and tool events * fix(agent): drain previous stream before retry * fix(lint): ctx ci lint --------- Co-authored-by: 晨苒 <16112591+chen-ran@users.noreply.github.com>
This commit is contained in:
+84
-27
@@ -78,6 +78,9 @@ func sendEvent(ctx context.Context, ch chan<- StreamEvent, evt StreamEvent) bool
|
||||
}
|
||||
|
||||
func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) {
|
||||
streamCtx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
// Stream emitter: tools targeting the current conversation push
|
||||
// side-effect events (attachments, reactions, speech) directly here.
|
||||
// Uses sendEvent to avoid goroutine leaks when the consumer stops reading.
|
||||
@@ -88,7 +91,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
var sdkTools []sdk.Tool
|
||||
if cfg.SupportsToolCall {
|
||||
var err error
|
||||
sdkTools, err = a.assembleTools(ctx, cfg, streamEmitter)
|
||||
sdkTools, err = a.assembleTools(streamCtx, cfg, streamEmitter)
|
||||
if err != nil {
|
||||
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)})
|
||||
return
|
||||
@@ -96,6 +99,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}
|
||||
sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools)
|
||||
|
||||
aborted := false
|
||||
|
||||
// Loop detection setup
|
||||
var textLoopGuard *TextLoopGuard
|
||||
var textLoopProbeBuffer *TextLoopProbeBuffer
|
||||
@@ -107,6 +112,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
result := textLoopGuard.Inspect(text)
|
||||
if result.Abort {
|
||||
a.logger.Warn("text loop detected, will abort")
|
||||
aborted = true
|
||||
cancel(ErrTextLoopDetected)
|
||||
}
|
||||
})
|
||||
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
||||
@@ -198,7 +205,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
var streamResult *sdk.StreamResult
|
||||
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
|
||||
var err error
|
||||
streamResult, err = a.client.StreamText(ctx, opts...)
|
||||
streamResult, err = a.client.StreamText(streamCtx, opts...)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
@@ -225,7 +232,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}
|
||||
delay := retryDelay(attempt, retryCfg)
|
||||
if delay > 0 {
|
||||
if err := sleepWithContext(ctx, delay); err != nil {
|
||||
if err := sleepWithContext(streamCtx, delay); err != nil {
|
||||
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: context cancelled during retry: %v", err)})
|
||||
return
|
||||
}
|
||||
@@ -235,11 +242,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
sendEvent(ctx, ch, StreamEvent{Type: EventAgentStart})
|
||||
|
||||
var allText strings.Builder
|
||||
aborted := false
|
||||
stepNumber := 0
|
||||
|
||||
for part := range streamResult.Stream {
|
||||
if ctx.Err() != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
aborted = true
|
||||
break
|
||||
}
|
||||
@@ -319,11 +325,13 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}
|
||||
|
||||
case *sdk.ToolProgressPart:
|
||||
ch <- StreamEvent{
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
Type: EventToolCallProgress,
|
||||
ToolName: p.ToolName,
|
||||
ToolCallID: p.ToolCallID,
|
||||
Progress: p.Content,
|
||||
}) {
|
||||
aborted = true
|
||||
}
|
||||
|
||||
case *sdk.StreamToolResultPart:
|
||||
@@ -345,10 +353,14 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}
|
||||
if shouldAbort {
|
||||
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
|
||||
cancel(ErrToolLoopDetected)
|
||||
aborted = true
|
||||
}
|
||||
|
||||
case *sdk.StreamToolErrorPart:
|
||||
// Take before errors.Is so registry IDs from the loop guard are always cleared.
|
||||
tookLoopAbort := toolLoopAbortCallIDs.Take(p.ToolCallID)
|
||||
shouldAbort := errors.Is(p.Error, ErrToolLoopDetected) || tookLoopAbort
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
Type: EventToolCallEnd,
|
||||
ToolName: p.ToolName,
|
||||
@@ -357,6 +369,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}) {
|
||||
aborted = true
|
||||
}
|
||||
if shouldAbort {
|
||||
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
|
||||
cancel(ErrToolLoopDetected)
|
||||
aborted = true
|
||||
}
|
||||
|
||||
case *sdk.StreamFilePart:
|
||||
mediaType := p.File.MediaType
|
||||
@@ -384,7 +401,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
// no work has been completed yet and retrying from the start is safe.
|
||||
if isRetryableStreamError(p.Error) {
|
||||
streamResult, aborted = a.runMidStreamRetry(
|
||||
ctx, ch, cfg, sdkTools, prepareStep, streamResult,
|
||||
ctx, streamCtx, cancel, toolLoopAbortCallIDs,
|
||||
ch, cfg, sdkTools, prepareStep, streamResult,
|
||||
stepNumber, errMsg, &allText, textLoopProbeBuffer,
|
||||
)
|
||||
} else {
|
||||
@@ -403,6 +421,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
}
|
||||
}
|
||||
|
||||
if aborted {
|
||||
for range streamResult.Stream {
|
||||
}
|
||||
}
|
||||
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Flush()
|
||||
}
|
||||
@@ -452,12 +475,10 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
loopAbort := newLoopAbortState()
|
||||
|
||||
// Collecting emitter: tools push side-effect events here during generation.
|
||||
var collected []tools.ToolStreamEvent
|
||||
var collectedMu sync.Mutex
|
||||
collected := newToolEventCollector()
|
||||
defer collected.Close()
|
||||
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
|
||||
collectedMu.Lock()
|
||||
defer collectedMu.Unlock()
|
||||
collected = append(collected, evt)
|
||||
collected.Add(evt)
|
||||
})
|
||||
|
||||
var sdkTools []sdk.Tool
|
||||
@@ -536,10 +557,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
}
|
||||
|
||||
// Drain collected tool-emitted side effects into the result.
|
||||
collectedEvents := collected.CloseAndSnapshot()
|
||||
var attachments []FileAttachment
|
||||
var reactions []ReactionItem
|
||||
var speeches []SpeechItem
|
||||
for _, evt := range collected {
|
||||
for _, evt := range collectedEvents {
|
||||
switch evt.Type {
|
||||
case tools.StreamEventAttachment:
|
||||
for _, a := range evt.Attachments {
|
||||
@@ -860,8 +882,15 @@ func pruneOldToolResults(p *sdk.GenerateParams, keepSteps, threshold int) *sdk.G
|
||||
// runMidStreamRetry attempts to continue the agent stream after a retryable
|
||||
// mid-stream error. It re-invokes StreamText with the accumulated messages
|
||||
// and drains the new stream into the same output channel.
|
||||
//
|
||||
// sendCtx is used for sendEvent so consumer disconnect (parent ctx) still
|
||||
// controls channel back-pressure; streamCtx is passed to the SDK for the same
|
||||
// cancellation semantics as the main stream (including loop-detect cancel).
|
||||
func (a *Agent) runMidStreamRetry(
|
||||
ctx context.Context,
|
||||
sendCtx context.Context,
|
||||
streamCtx context.Context,
|
||||
cancel context.CancelCauseFunc,
|
||||
toolLoopAbortCallIDs *toolAbortRegistry,
|
||||
ch chan<- StreamEvent,
|
||||
cfg RunConfig,
|
||||
sdkTools []sdk.Tool,
|
||||
@@ -872,6 +901,13 @@ func (a *Agent) runMidStreamRetry(
|
||||
allText *strings.Builder,
|
||||
textLoopProbeBuffer *TextLoopProbeBuffer,
|
||||
) (*sdk.StreamResult, bool) {
|
||||
// Drain the previous stream before reading prevResult.Messages.
|
||||
// This avoids racing with the SDK's final StreamResult write.
|
||||
if prevResult.Stream != nil {
|
||||
for range prevResult.Stream {
|
||||
}
|
||||
}
|
||||
|
||||
retryCfg := DefaultRetryConfig()
|
||||
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
|
||||
a.logger.Warn("mid-stream error, retrying",
|
||||
@@ -880,7 +916,7 @@ func (a *Agent) runMidStreamRetry(
|
||||
slog.Int("max_attempts", retryCfg.MaxAttempts),
|
||||
slog.String("error", errMsg),
|
||||
)
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||
Type: EventRetry,
|
||||
Attempt: attempt + 1,
|
||||
MaxAttempt: retryCfg.MaxAttempts,
|
||||
@@ -891,7 +927,7 @@ func (a *Agent) runMidStreamRetry(
|
||||
|
||||
delay := retryDelay(attempt, retryCfg)
|
||||
if delay > 0 {
|
||||
if err := sleepWithContext(ctx, delay); err != nil {
|
||||
if err := sleepWithContext(streamCtx, delay); err != nil {
|
||||
return prevResult, true // aborted
|
||||
}
|
||||
}
|
||||
@@ -903,7 +939,7 @@ func (a *Agent) runMidStreamRetry(
|
||||
retryCfgCopy.Messages = prevResult.Messages
|
||||
retryOpts := a.buildGenerateOptions(retryCfgCopy, sdkTools, prepareStep)
|
||||
|
||||
retryResult, retryErr := a.client.StreamText(ctx, retryOpts...)
|
||||
retryResult, retryErr := a.client.StreamText(streamCtx, retryOpts...)
|
||||
if retryErr != nil {
|
||||
a.logger.Warn("mid-stream retry failed to start",
|
||||
slog.Int("attempt", attempt+1),
|
||||
@@ -917,9 +953,13 @@ func (a *Agent) runMidStreamRetry(
|
||||
// Drain the retry stream into the main event loop
|
||||
aborted := false
|
||||
for retryPart := range retryResult.Stream {
|
||||
if streamCtx.Err() != nil {
|
||||
aborted = true
|
||||
break
|
||||
}
|
||||
switch rp := retryPart.(type) {
|
||||
case *sdk.TextStartPart:
|
||||
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextStart}) {
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextStart}) {
|
||||
aborted = true
|
||||
}
|
||||
case *sdk.TextDeltaPart:
|
||||
@@ -927,7 +967,7 @@ func (a *Agent) runMidStreamRetry(
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Push(rp.Text)
|
||||
}
|
||||
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) {
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) {
|
||||
aborted = true
|
||||
}
|
||||
allText.WriteString(rp.Text)
|
||||
@@ -937,14 +977,14 @@ func (a *Agent) runMidStreamRetry(
|
||||
textLoopProbeBuffer.Flush()
|
||||
}
|
||||
stepNumber++
|
||||
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) {
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextEnd}) {
|
||||
aborted = true
|
||||
}
|
||||
case *sdk.ToolInputStartPart:
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Flush()
|
||||
}
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||
Type: EventToolCallStart,
|
||||
ToolName: rp.ToolName,
|
||||
ToolCallID: rp.ID,
|
||||
@@ -955,7 +995,7 @@ func (a *Agent) runMidStreamRetry(
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Flush()
|
||||
}
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||
Type: EventToolCallStart,
|
||||
ToolName: rp.ToolName,
|
||||
ToolCallID: rp.ToolCallID,
|
||||
@@ -964,14 +1004,15 @@ func (a *Agent) runMidStreamRetry(
|
||||
aborted = true
|
||||
}
|
||||
case *sdk.StreamToolResultPart:
|
||||
shouldAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID)
|
||||
stepNumber++
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||
Type: EventToolCallEnd,
|
||||
ToolName: rp.ToolName,
|
||||
ToolCallID: rp.ToolCallID,
|
||||
Input: rp.Input,
|
||||
Result: rp.Output,
|
||||
}) || !sendEvent(ctx, ch, StreamEvent{
|
||||
}) || !sendEvent(sendCtx, ch, StreamEvent{
|
||||
Type: EventProgress,
|
||||
StepNumber: stepNumber,
|
||||
ToolName: rp.ToolName,
|
||||
@@ -979,8 +1020,15 @@ func (a *Agent) runMidStreamRetry(
|
||||
}) {
|
||||
aborted = true
|
||||
}
|
||||
if shouldAbort {
|
||||
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID))
|
||||
cancel(ErrToolLoopDetected)
|
||||
aborted = true
|
||||
}
|
||||
case *sdk.StreamToolErrorPart:
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
tookLoopAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID)
|
||||
shouldAbort := errors.Is(rp.Error, ErrToolLoopDetected) || tookLoopAbort
|
||||
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||
Type: EventToolCallEnd,
|
||||
ToolName: rp.ToolName,
|
||||
ToolCallID: rp.ToolCallID,
|
||||
@@ -988,8 +1036,13 @@ func (a *Agent) runMidStreamRetry(
|
||||
}) {
|
||||
aborted = true
|
||||
}
|
||||
if shouldAbort {
|
||||
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID))
|
||||
cancel(ErrToolLoopDetected)
|
||||
aborted = true
|
||||
}
|
||||
case *sdk.ErrorPart:
|
||||
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()})
|
||||
sendEvent(sendCtx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()})
|
||||
aborted = true
|
||||
case *sdk.AbortPart:
|
||||
aborted = true
|
||||
@@ -1000,7 +1053,11 @@ func (a *Agent) runMidStreamRetry(
|
||||
break
|
||||
}
|
||||
}
|
||||
return retryResult, aborted
|
||||
if aborted {
|
||||
for range retryResult.Stream {
|
||||
}
|
||||
}
|
||||
return retryResult, aborted || detectGenerateLoopAbort(streamCtx, streamCtx.Err()) != nil
|
||||
}
|
||||
// All retry attempts failed
|
||||
return prevResult, true
|
||||
|
||||
@@ -488,17 +488,25 @@ func (m *Manager) RunningTasksSummary(botID, sessionID string) string {
|
||||
defer m.mu.Unlock()
|
||||
var lines []string
|
||||
for _, t := range m.tasks {
|
||||
if t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning {
|
||||
desc := t.Description
|
||||
if desc == "" {
|
||||
desc = truncate(t.Command, 80)
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)",
|
||||
t.ID, desc,
|
||||
time.Since(t.StartedAt).Round(time.Second),
|
||||
t.OutputFile,
|
||||
))
|
||||
t.mu.Lock()
|
||||
matches := t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning
|
||||
id := t.ID
|
||||
desc := t.Description
|
||||
command := t.Command
|
||||
startedAt := t.StartedAt
|
||||
outputFile := t.OutputFile
|
||||
t.mu.Unlock()
|
||||
if !matches {
|
||||
continue
|
||||
}
|
||||
if desc == "" {
|
||||
desc = truncate(command, 80)
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)",
|
||||
id, desc,
|
||||
time.Since(startedAt).Round(time.Second),
|
||||
outputFile,
|
||||
))
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
|
||||
@@ -3,7 +3,10 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
@@ -19,10 +22,76 @@ func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([
|
||||
return p.tools, nil
|
||||
}
|
||||
|
||||
type atomicMockProvider struct {
|
||||
calls atomic.Int32
|
||||
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
|
||||
stream func(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error)
|
||||
}
|
||||
|
||||
func (*atomicMockProvider) Name() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (*atomicMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*atomicMockProvider) Test(context.Context) *sdk.ProviderTestResult {
|
||||
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
|
||||
}
|
||||
|
||||
func (*atomicMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
|
||||
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
|
||||
}
|
||||
|
||||
func (m *atomicMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
call := int(m.calls.Add(1))
|
||||
return m.handler(call, params)
|
||||
}
|
||||
|
||||
func (m *atomicMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||
if m.stream != nil {
|
||||
return m.stream(ctx, params)
|
||||
}
|
||||
|
||||
result, err := m.DoGenerate(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch := make(chan sdk.StreamPart, 8)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- &sdk.StartPart{}
|
||||
ch <- &sdk.StartStepPart{}
|
||||
if result.Text != "" {
|
||||
ch <- &sdk.TextStartPart{ID: "mock"}
|
||||
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
|
||||
ch <- &sdk.TextEndPart{ID: "mock"}
|
||||
}
|
||||
for _, tc := range result.ToolCalls {
|
||||
ch <- &sdk.StreamToolCallPart{
|
||||
ToolCallID: tc.ToolCallID,
|
||||
ToolName: tc.ToolName,
|
||||
Input: tc.Input,
|
||||
}
|
||||
}
|
||||
ch <- &sdk.FinishStepPart{
|
||||
FinishReason: result.FinishReason,
|
||||
Usage: result.Usage,
|
||||
Response: result.Response,
|
||||
}
|
||||
ch <- &sdk.FinishPart{
|
||||
FinishReason: result.FinishReason,
|
||||
TotalUsage: result.Usage,
|
||||
}
|
||||
}()
|
||||
return &sdk.StreamResult{Stream: ch}, nil
|
||||
}
|
||||
|
||||
func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelProvider := &agentReadMediaMockProvider{
|
||||
modelProvider := &atomicMockProvider{
|
||||
handler: func(_ int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
return &sdk.GenerateResult{
|
||||
FinishReason: sdk.FinishReasonToolCalls,
|
||||
@@ -58,8 +127,8 @@ func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
|
||||
if !errors.Is(err, ErrToolLoopDetected) {
|
||||
t.Fatalf("expected ErrToolLoopDetected, got %v", err)
|
||||
}
|
||||
if modelProvider.calls >= 20 {
|
||||
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls)
|
||||
if modelProvider.calls.Load() >= 20 {
|
||||
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +136,7 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
||||
modelProvider := &agentReadMediaMockProvider{
|
||||
modelProvider := &atomicMockProvider{
|
||||
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
return &sdk.GenerateResult{
|
||||
Text: repeatedText,
|
||||
@@ -104,8 +173,8 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
|
||||
if !errors.Is(err, ErrTextLoopDetected) {
|
||||
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
||||
}
|
||||
if modelProvider.calls >= 10 {
|
||||
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls)
|
||||
if modelProvider.calls.Load() >= 10 {
|
||||
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +182,7 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
||||
modelProvider := &agentReadMediaMockProvider{
|
||||
modelProvider := &atomicMockProvider{
|
||||
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
finishReason := sdk.FinishReasonToolCalls
|
||||
var toolCalls []sdk.ToolCall
|
||||
@@ -157,7 +226,314 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
|
||||
if !errors.Is(err, ErrTextLoopDetected) {
|
||||
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
||||
}
|
||||
if modelProvider.calls != 4 {
|
||||
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls)
|
||||
if modelProvider.calls.Load() != 4 {
|
||||
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentStreamStopsOnToolLoopAbort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelProvider := &atomicMockProvider{
|
||||
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
if call >= 20 {
|
||||
return &sdk.GenerateResult{
|
||||
Text: "unexpected-final-step",
|
||||
FinishReason: sdk.FinishReasonStop,
|
||||
}, nil
|
||||
}
|
||||
return &sdk.GenerateResult{
|
||||
FinishReason: sdk.FinishReasonToolCalls,
|
||||
ToolCalls: []sdk.ToolCall{{
|
||||
ToolCallID: "call-stream",
|
||||
ToolName: "loop_tool",
|
||||
Input: map[string]any{"query": "same"},
|
||||
}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a := New(Deps{})
|
||||
a.SetToolProviders([]agenttools.ToolProvider{
|
||||
staticToolProvider{
|
||||
tools: []sdk.Tool{{
|
||||
Name: "loop_tool",
|
||||
Parameters: &jsonschema.Schema{Type: "object"},
|
||||
Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
},
|
||||
}},
|
||||
},
|
||||
})
|
||||
|
||||
var terminal StreamEvent
|
||||
for event := range a.Stream(context.Background(), RunConfig{
|
||||
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||
Messages: []sdk.Message{sdk.UserMessage("loop stream")},
|
||||
SupportsToolCall: true,
|
||||
Identity: SessionContext{BotID: "bot-1"},
|
||||
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||
}) {
|
||||
if event.IsTerminal() {
|
||||
terminal = event
|
||||
}
|
||||
}
|
||||
|
||||
if terminal.Type != EventAgentAbort {
|
||||
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentStreamMarksTerminalTextLoopAsAbort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repeatedChunk := strings.Repeat("abcd", 64)
|
||||
var observedCancel atomic.Bool
|
||||
modelProvider := &atomicMockProvider{
|
||||
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||
ch := make(chan sdk.StreamPart, 16)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
send := func(part sdk.StreamPart) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
observedCancel.Store(true)
|
||||
return false
|
||||
case ch <- part:
|
||||
return true
|
||||
}
|
||||
}
|
||||
if !send(&sdk.StartPart{}) {
|
||||
return
|
||||
}
|
||||
if !send(&sdk.StartStepPart{}) {
|
||||
return
|
||||
}
|
||||
if !send(&sdk.TextStartPart{ID: "mock"}) {
|
||||
return
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
if !send(&sdk.TextDeltaPart{ID: "mock", Text: repeatedChunk}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
observedCancel.Store(true)
|
||||
return
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
if !send(&sdk.TextEndPart{ID: "mock"}) {
|
||||
return
|
||||
}
|
||||
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
|
||||
return
|
||||
}
|
||||
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
|
||||
}()
|
||||
return &sdk.StreamResult{Stream: ch}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a := New(Deps{})
|
||||
|
||||
var terminal StreamEvent
|
||||
for event := range a.Stream(context.Background(), RunConfig{
|
||||
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||
Messages: []sdk.Message{sdk.UserMessage("loop stream text")},
|
||||
SupportsToolCall: true,
|
||||
Identity: SessionContext{BotID: "bot-1"},
|
||||
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||
}) {
|
||||
if event.IsTerminal() {
|
||||
terminal = event
|
||||
}
|
||||
}
|
||||
|
||||
if !observedCancel.Load() {
|
||||
t.Fatal("expected stream provider to observe context cancellation from text-loop abort")
|
||||
}
|
||||
if terminal.Type != EventAgentAbort {
|
||||
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentStreamMarksRetryTextLoopAsAbort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repeatedChunk := strings.Repeat("abcd", 64)
|
||||
var streamCalls atomic.Int32
|
||||
var observedCancel atomic.Bool
|
||||
modelProvider := &atomicMockProvider{
|
||||
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||
call := streamCalls.Add(1)
|
||||
ch := make(chan sdk.StreamPart, 16)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
send := func(part sdk.StreamPart) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
observedCancel.Store(true)
|
||||
return false
|
||||
case ch <- part:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if !send(&sdk.StartPart{}) {
|
||||
return
|
||||
}
|
||||
if !send(&sdk.StartStepPart{}) {
|
||||
return
|
||||
}
|
||||
|
||||
if call == 1 {
|
||||
_ = send(&sdk.ErrorPart{Error: errors.New("api error 500")})
|
||||
return
|
||||
}
|
||||
|
||||
if !send(&sdk.TextStartPart{ID: "mock-retry"}) {
|
||||
return
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
if !send(&sdk.TextDeltaPart{ID: "mock-retry", Text: repeatedChunk}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
observedCancel.Store(true)
|
||||
return
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
if !send(&sdk.TextEndPart{ID: "mock-retry"}) {
|
||||
return
|
||||
}
|
||||
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
|
||||
return
|
||||
}
|
||||
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
|
||||
}()
|
||||
return &sdk.StreamResult{Stream: ch}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a := New(Deps{})
|
||||
|
||||
var terminal StreamEvent
|
||||
for event := range a.Stream(context.Background(), RunConfig{
|
||||
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||
Messages: []sdk.Message{sdk.UserMessage("loop stream retry text")},
|
||||
SupportsToolCall: true,
|
||||
Identity: SessionContext{BotID: "bot-1"},
|
||||
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||
}) {
|
||||
if event.IsTerminal() {
|
||||
terminal = event
|
||||
}
|
||||
}
|
||||
|
||||
if streamCalls.Load() != 2 {
|
||||
t.Fatalf("expected one retry stream attempt, got %d stream calls", streamCalls.Load())
|
||||
}
|
||||
if !observedCancel.Load() {
|
||||
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
|
||||
}
|
||||
if terminal.Type != EventAgentAbort {
|
||||
t.Fatalf("expected EventAgentAbort after retry text-loop abort, got %q", terminal.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunMidStreamRetryMarksTextLoopCancellationAsAborted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repeatedChunk := strings.Repeat("abcd", 64)
|
||||
var observedCancel atomic.Bool
|
||||
modelProvider := &atomicMockProvider{
|
||||
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||
ch := make(chan sdk.StreamPart)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
send := func(part sdk.StreamPart) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
observedCancel.Store(true)
|
||||
return false
|
||||
case ch <- part:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if !send(&sdk.StartPart{}) {
|
||||
return
|
||||
}
|
||||
if !send(&sdk.StartStepPart{}) {
|
||||
return
|
||||
}
|
||||
if !send(&sdk.TextStartPart{ID: "mock-retry-only"}) {
|
||||
return
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
if !send(&sdk.TextDeltaPart{ID: "mock-retry-only", Text: repeatedChunk}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
observedCancel.Store(true)
|
||||
return
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("expected text-loop detection to cancel retry stream before any extra part was sent")
|
||||
return
|
||||
}
|
||||
}()
|
||||
return &sdk.StreamResult{Stream: ch}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a := New(Deps{})
|
||||
streamCtx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(nil)
|
||||
|
||||
textLoopGuard := NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
||||
textLoopProbeBuffer := NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
|
||||
result := textLoopGuard.Inspect(text)
|
||||
if result.Abort {
|
||||
cancel(ErrTextLoopDetected)
|
||||
}
|
||||
})
|
||||
|
||||
retryResult, aborted := a.runMidStreamRetry(
|
||||
context.Background(),
|
||||
streamCtx,
|
||||
cancel,
|
||||
newToolAbortRegistry(),
|
||||
make(chan StreamEvent, 32),
|
||||
RunConfig{
|
||||
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||
Messages: []sdk.Message{sdk.UserMessage("retry text loop")},
|
||||
Identity: SessionContext{BotID: "bot-1"},
|
||||
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
&sdk.StreamResult{Messages: []sdk.Message{sdk.UserMessage("previous step")}},
|
||||
0,
|
||||
"api error 500",
|
||||
&strings.Builder{},
|
||||
textLoopProbeBuffer,
|
||||
)
|
||||
|
||||
if retryResult == nil {
|
||||
t.Fatal("expected retry result")
|
||||
}
|
||||
if !observedCancel.Load() {
|
||||
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
|
||||
}
|
||||
if !errors.Is(context.Cause(streamCtx), ErrTextLoopDetected) {
|
||||
t.Fatalf("expected stream context cause ErrTextLoopDetected, got %v", context.Cause(streamCtx))
|
||||
}
|
||||
if !aborted {
|
||||
t.Fatal("expected runMidStreamRetry to report aborted when retry stream hit text-loop cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package agent
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/memohai/memoh/internal/agent/tools"
|
||||
)
|
||||
|
||||
type toolAbortRegistry struct {
|
||||
mu sync.Mutex
|
||||
@@ -46,3 +50,65 @@ func (r *toolAbortRegistry) Any() bool {
|
||||
defer r.mu.Unlock()
|
||||
return len(r.ids) > 0
|
||||
}
|
||||
|
||||
type toolEventCollector struct {
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
events []tools.ToolStreamEvent
|
||||
}
|
||||
|
||||
func newToolEventCollector() *toolEventCollector {
|
||||
return &toolEventCollector{}
|
||||
}
|
||||
|
||||
func (c *toolEventCollector) Add(evt tools.ToolStreamEvent) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return false
|
||||
}
|
||||
c.events = append(c.events, evt)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *toolEventCollector) Close() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closed = true
|
||||
}
|
||||
|
||||
func (c *toolEventCollector) CloseAndSnapshot() []tools.ToolStreamEvent {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closed = true
|
||||
snapshot := make([]tools.ToolStreamEvent, len(c.events))
|
||||
copy(snapshot, c.events)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of collected events without closing the collector.
|
||||
// Callers that own the collector lifetime should still invoke Close (or
|
||||
// CloseAndSnapshot) so late emits are rejected.
|
||||
func (c *toolEventCollector) Snapshot() []tools.ToolStreamEvent {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
out := make([]tools.ToolStreamEvent, len(c.events))
|
||||
copy(out, c.events)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ package agent
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/agent/tools"
|
||||
)
|
||||
|
||||
func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
|
||||
@@ -30,3 +33,33 @@ func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
|
||||
t.Fatal("expected registry to be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolEventCollectorCloseIgnoresLateAdds(t *testing.T) {
|
||||
collector := newToolEventCollector()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if !collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) {
|
||||
t.Fatalf("unexpected add failure before close, i=%d", i)
|
||||
}
|
||||
}
|
||||
snapshot := collector.CloseAndSnapshot()
|
||||
if len(snapshot) != 5 {
|
||||
t.Fatalf("expected snapshot len 5, got %d", len(snapshot))
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var postCloseAdds atomic.Int32
|
||||
for i := 0; i < 16; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) {
|
||||
postCloseAdds.Add(1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if postCloseAdds.Load() != 0 {
|
||||
t.Fatalf("expected 0 successful adds after close, got %d", postCloseAdds.Load())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user