mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(agent): stream loop abort, mid-stream retry parity, collector cleanup (#376)
* fix(agent): align stream retry abort and event collection * fix(agent): cancel stream on loop detect, harden retry and tool events * fix(agent): drain previous stream before retry * fix(lint): ctx ci lint --------- Co-authored-by: 晨苒 <16112591+chen-ran@users.noreply.github.com>
This commit is contained in:
+84
-27
@@ -78,6 +78,9 @@ func sendEvent(ctx context.Context, ch chan<- StreamEvent, evt StreamEvent) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) {
|
func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) {
|
||||||
|
streamCtx, cancel := context.WithCancelCause(ctx)
|
||||||
|
defer cancel(nil)
|
||||||
|
|
||||||
// Stream emitter: tools targeting the current conversation push
|
// Stream emitter: tools targeting the current conversation push
|
||||||
// side-effect events (attachments, reactions, speech) directly here.
|
// side-effect events (attachments, reactions, speech) directly here.
|
||||||
// Uses sendEvent to avoid goroutine leaks when the consumer stops reading.
|
// Uses sendEvent to avoid goroutine leaks when the consumer stops reading.
|
||||||
@@ -88,7 +91,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
var sdkTools []sdk.Tool
|
var sdkTools []sdk.Tool
|
||||||
if cfg.SupportsToolCall {
|
if cfg.SupportsToolCall {
|
||||||
var err error
|
var err error
|
||||||
sdkTools, err = a.assembleTools(ctx, cfg, streamEmitter)
|
sdkTools, err = a.assembleTools(streamCtx, cfg, streamEmitter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)})
|
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)})
|
||||||
return
|
return
|
||||||
@@ -96,6 +99,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}
|
}
|
||||||
sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools)
|
sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools)
|
||||||
|
|
||||||
|
aborted := false
|
||||||
|
|
||||||
// Loop detection setup
|
// Loop detection setup
|
||||||
var textLoopGuard *TextLoopGuard
|
var textLoopGuard *TextLoopGuard
|
||||||
var textLoopProbeBuffer *TextLoopProbeBuffer
|
var textLoopProbeBuffer *TextLoopProbeBuffer
|
||||||
@@ -107,6 +112,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
result := textLoopGuard.Inspect(text)
|
result := textLoopGuard.Inspect(text)
|
||||||
if result.Abort {
|
if result.Abort {
|
||||||
a.logger.Warn("text loop detected, will abort")
|
a.logger.Warn("text loop detected, will abort")
|
||||||
|
aborted = true
|
||||||
|
cancel(ErrTextLoopDetected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
||||||
@@ -198,7 +205,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
var streamResult *sdk.StreamResult
|
var streamResult *sdk.StreamResult
|
||||||
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
|
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
|
||||||
var err error
|
var err error
|
||||||
streamResult, err = a.client.StreamText(ctx, opts...)
|
streamResult, err = a.client.StreamText(streamCtx, opts...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -225,7 +232,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}
|
}
|
||||||
delay := retryDelay(attempt, retryCfg)
|
delay := retryDelay(attempt, retryCfg)
|
||||||
if delay > 0 {
|
if delay > 0 {
|
||||||
if err := sleepWithContext(ctx, delay); err != nil {
|
if err := sleepWithContext(streamCtx, delay); err != nil {
|
||||||
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: context cancelled during retry: %v", err)})
|
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: context cancelled during retry: %v", err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -235,11 +242,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
sendEvent(ctx, ch, StreamEvent{Type: EventAgentStart})
|
sendEvent(ctx, ch, StreamEvent{Type: EventAgentStart})
|
||||||
|
|
||||||
var allText strings.Builder
|
var allText strings.Builder
|
||||||
aborted := false
|
|
||||||
stepNumber := 0
|
stepNumber := 0
|
||||||
|
|
||||||
for part := range streamResult.Stream {
|
for part := range streamResult.Stream {
|
||||||
if ctx.Err() != nil {
|
if streamCtx.Err() != nil {
|
||||||
aborted = true
|
aborted = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -319,11 +325,13 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}
|
}
|
||||||
|
|
||||||
case *sdk.ToolProgressPart:
|
case *sdk.ToolProgressPart:
|
||||||
ch <- StreamEvent{
|
if !sendEvent(ctx, ch, StreamEvent{
|
||||||
Type: EventToolCallProgress,
|
Type: EventToolCallProgress,
|
||||||
ToolName: p.ToolName,
|
ToolName: p.ToolName,
|
||||||
ToolCallID: p.ToolCallID,
|
ToolCallID: p.ToolCallID,
|
||||||
Progress: p.Content,
|
Progress: p.Content,
|
||||||
|
}) {
|
||||||
|
aborted = true
|
||||||
}
|
}
|
||||||
|
|
||||||
case *sdk.StreamToolResultPart:
|
case *sdk.StreamToolResultPart:
|
||||||
@@ -345,10 +353,14 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}
|
}
|
||||||
if shouldAbort {
|
if shouldAbort {
|
||||||
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
|
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
|
||||||
|
cancel(ErrToolLoopDetected)
|
||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
|
|
||||||
case *sdk.StreamToolErrorPart:
|
case *sdk.StreamToolErrorPart:
|
||||||
|
// Take before errors.Is so registry IDs from the loop guard are always cleared.
|
||||||
|
tookLoopAbort := toolLoopAbortCallIDs.Take(p.ToolCallID)
|
||||||
|
shouldAbort := errors.Is(p.Error, ErrToolLoopDetected) || tookLoopAbort
|
||||||
if !sendEvent(ctx, ch, StreamEvent{
|
if !sendEvent(ctx, ch, StreamEvent{
|
||||||
Type: EventToolCallEnd,
|
Type: EventToolCallEnd,
|
||||||
ToolName: p.ToolName,
|
ToolName: p.ToolName,
|
||||||
@@ -357,6 +369,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}) {
|
}) {
|
||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
|
if shouldAbort {
|
||||||
|
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
|
||||||
|
cancel(ErrToolLoopDetected)
|
||||||
|
aborted = true
|
||||||
|
}
|
||||||
|
|
||||||
case *sdk.StreamFilePart:
|
case *sdk.StreamFilePart:
|
||||||
mediaType := p.File.MediaType
|
mediaType := p.File.MediaType
|
||||||
@@ -384,7 +401,8 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
// no work has been completed yet and retrying from the start is safe.
|
// no work has been completed yet and retrying from the start is safe.
|
||||||
if isRetryableStreamError(p.Error) {
|
if isRetryableStreamError(p.Error) {
|
||||||
streamResult, aborted = a.runMidStreamRetry(
|
streamResult, aborted = a.runMidStreamRetry(
|
||||||
ctx, ch, cfg, sdkTools, prepareStep, streamResult,
|
ctx, streamCtx, cancel, toolLoopAbortCallIDs,
|
||||||
|
ch, cfg, sdkTools, prepareStep, streamResult,
|
||||||
stepNumber, errMsg, &allText, textLoopProbeBuffer,
|
stepNumber, errMsg, &allText, textLoopProbeBuffer,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@@ -403,6 +421,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if aborted {
|
||||||
|
for range streamResult.Stream {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if textLoopProbeBuffer != nil {
|
if textLoopProbeBuffer != nil {
|
||||||
textLoopProbeBuffer.Flush()
|
textLoopProbeBuffer.Flush()
|
||||||
}
|
}
|
||||||
@@ -452,12 +475,10 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
loopAbort := newLoopAbortState()
|
loopAbort := newLoopAbortState()
|
||||||
|
|
||||||
// Collecting emitter: tools push side-effect events here during generation.
|
// Collecting emitter: tools push side-effect events here during generation.
|
||||||
var collected []tools.ToolStreamEvent
|
collected := newToolEventCollector()
|
||||||
var collectedMu sync.Mutex
|
defer collected.Close()
|
||||||
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
|
collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
|
||||||
collectedMu.Lock()
|
collected.Add(evt)
|
||||||
defer collectedMu.Unlock()
|
|
||||||
collected = append(collected, evt)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
var sdkTools []sdk.Tool
|
var sdkTools []sdk.Tool
|
||||||
@@ -536,10 +557,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Drain collected tool-emitted side effects into the result.
|
// Drain collected tool-emitted side effects into the result.
|
||||||
|
collectedEvents := collected.CloseAndSnapshot()
|
||||||
var attachments []FileAttachment
|
var attachments []FileAttachment
|
||||||
var reactions []ReactionItem
|
var reactions []ReactionItem
|
||||||
var speeches []SpeechItem
|
var speeches []SpeechItem
|
||||||
for _, evt := range collected {
|
for _, evt := range collectedEvents {
|
||||||
switch evt.Type {
|
switch evt.Type {
|
||||||
case tools.StreamEventAttachment:
|
case tools.StreamEventAttachment:
|
||||||
for _, a := range evt.Attachments {
|
for _, a := range evt.Attachments {
|
||||||
@@ -860,8 +882,15 @@ func pruneOldToolResults(p *sdk.GenerateParams, keepSteps, threshold int) *sdk.G
|
|||||||
// runMidStreamRetry attempts to continue the agent stream after a retryable
|
// runMidStreamRetry attempts to continue the agent stream after a retryable
|
||||||
// mid-stream error. It re-invokes StreamText with the accumulated messages
|
// mid-stream error. It re-invokes StreamText with the accumulated messages
|
||||||
// and drains the new stream into the same output channel.
|
// and drains the new stream into the same output channel.
|
||||||
|
//
|
||||||
|
// sendCtx is used for sendEvent so consumer disconnect (parent ctx) still
|
||||||
|
// controls channel back-pressure; streamCtx is passed to the SDK for the same
|
||||||
|
// cancellation semantics as the main stream (including loop-detect cancel).
|
||||||
func (a *Agent) runMidStreamRetry(
|
func (a *Agent) runMidStreamRetry(
|
||||||
ctx context.Context,
|
sendCtx context.Context,
|
||||||
|
streamCtx context.Context,
|
||||||
|
cancel context.CancelCauseFunc,
|
||||||
|
toolLoopAbortCallIDs *toolAbortRegistry,
|
||||||
ch chan<- StreamEvent,
|
ch chan<- StreamEvent,
|
||||||
cfg RunConfig,
|
cfg RunConfig,
|
||||||
sdkTools []sdk.Tool,
|
sdkTools []sdk.Tool,
|
||||||
@@ -872,6 +901,13 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
allText *strings.Builder,
|
allText *strings.Builder,
|
||||||
textLoopProbeBuffer *TextLoopProbeBuffer,
|
textLoopProbeBuffer *TextLoopProbeBuffer,
|
||||||
) (*sdk.StreamResult, bool) {
|
) (*sdk.StreamResult, bool) {
|
||||||
|
// Drain the previous stream before reading prevResult.Messages.
|
||||||
|
// This avoids racing with the SDK's final StreamResult write.
|
||||||
|
if prevResult.Stream != nil {
|
||||||
|
for range prevResult.Stream {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
retryCfg := DefaultRetryConfig()
|
retryCfg := DefaultRetryConfig()
|
||||||
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
|
for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ {
|
||||||
a.logger.Warn("mid-stream error, retrying",
|
a.logger.Warn("mid-stream error, retrying",
|
||||||
@@ -880,7 +916,7 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
slog.Int("max_attempts", retryCfg.MaxAttempts),
|
slog.Int("max_attempts", retryCfg.MaxAttempts),
|
||||||
slog.String("error", errMsg),
|
slog.String("error", errMsg),
|
||||||
)
|
)
|
||||||
if !sendEvent(ctx, ch, StreamEvent{
|
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||||
Type: EventRetry,
|
Type: EventRetry,
|
||||||
Attempt: attempt + 1,
|
Attempt: attempt + 1,
|
||||||
MaxAttempt: retryCfg.MaxAttempts,
|
MaxAttempt: retryCfg.MaxAttempts,
|
||||||
@@ -891,7 +927,7 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
|
|
||||||
delay := retryDelay(attempt, retryCfg)
|
delay := retryDelay(attempt, retryCfg)
|
||||||
if delay > 0 {
|
if delay > 0 {
|
||||||
if err := sleepWithContext(ctx, delay); err != nil {
|
if err := sleepWithContext(streamCtx, delay); err != nil {
|
||||||
return prevResult, true // aborted
|
return prevResult, true // aborted
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -903,7 +939,7 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
retryCfgCopy.Messages = prevResult.Messages
|
retryCfgCopy.Messages = prevResult.Messages
|
||||||
retryOpts := a.buildGenerateOptions(retryCfgCopy, sdkTools, prepareStep)
|
retryOpts := a.buildGenerateOptions(retryCfgCopy, sdkTools, prepareStep)
|
||||||
|
|
||||||
retryResult, retryErr := a.client.StreamText(ctx, retryOpts...)
|
retryResult, retryErr := a.client.StreamText(streamCtx, retryOpts...)
|
||||||
if retryErr != nil {
|
if retryErr != nil {
|
||||||
a.logger.Warn("mid-stream retry failed to start",
|
a.logger.Warn("mid-stream retry failed to start",
|
||||||
slog.Int("attempt", attempt+1),
|
slog.Int("attempt", attempt+1),
|
||||||
@@ -917,9 +953,13 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
// Drain the retry stream into the main event loop
|
// Drain the retry stream into the main event loop
|
||||||
aborted := false
|
aborted := false
|
||||||
for retryPart := range retryResult.Stream {
|
for retryPart := range retryResult.Stream {
|
||||||
|
if streamCtx.Err() != nil {
|
||||||
|
aborted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
switch rp := retryPart.(type) {
|
switch rp := retryPart.(type) {
|
||||||
case *sdk.TextStartPart:
|
case *sdk.TextStartPart:
|
||||||
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextStart}) {
|
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextStart}) {
|
||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
case *sdk.TextDeltaPart:
|
case *sdk.TextDeltaPart:
|
||||||
@@ -927,7 +967,7 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
if textLoopProbeBuffer != nil {
|
if textLoopProbeBuffer != nil {
|
||||||
textLoopProbeBuffer.Push(rp.Text)
|
textLoopProbeBuffer.Push(rp.Text)
|
||||||
}
|
}
|
||||||
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) {
|
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) {
|
||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
allText.WriteString(rp.Text)
|
allText.WriteString(rp.Text)
|
||||||
@@ -937,14 +977,14 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
textLoopProbeBuffer.Flush()
|
textLoopProbeBuffer.Flush()
|
||||||
}
|
}
|
||||||
stepNumber++
|
stepNumber++
|
||||||
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) {
|
if !sendEvent(sendCtx, ch, StreamEvent{Type: EventTextEnd}) {
|
||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
case *sdk.ToolInputStartPart:
|
case *sdk.ToolInputStartPart:
|
||||||
if textLoopProbeBuffer != nil {
|
if textLoopProbeBuffer != nil {
|
||||||
textLoopProbeBuffer.Flush()
|
textLoopProbeBuffer.Flush()
|
||||||
}
|
}
|
||||||
if !sendEvent(ctx, ch, StreamEvent{
|
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||||
Type: EventToolCallStart,
|
Type: EventToolCallStart,
|
||||||
ToolName: rp.ToolName,
|
ToolName: rp.ToolName,
|
||||||
ToolCallID: rp.ID,
|
ToolCallID: rp.ID,
|
||||||
@@ -955,7 +995,7 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
if textLoopProbeBuffer != nil {
|
if textLoopProbeBuffer != nil {
|
||||||
textLoopProbeBuffer.Flush()
|
textLoopProbeBuffer.Flush()
|
||||||
}
|
}
|
||||||
if !sendEvent(ctx, ch, StreamEvent{
|
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||||
Type: EventToolCallStart,
|
Type: EventToolCallStart,
|
||||||
ToolName: rp.ToolName,
|
ToolName: rp.ToolName,
|
||||||
ToolCallID: rp.ToolCallID,
|
ToolCallID: rp.ToolCallID,
|
||||||
@@ -964,14 +1004,15 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
case *sdk.StreamToolResultPart:
|
case *sdk.StreamToolResultPart:
|
||||||
|
shouldAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID)
|
||||||
stepNumber++
|
stepNumber++
|
||||||
if !sendEvent(ctx, ch, StreamEvent{
|
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||||
Type: EventToolCallEnd,
|
Type: EventToolCallEnd,
|
||||||
ToolName: rp.ToolName,
|
ToolName: rp.ToolName,
|
||||||
ToolCallID: rp.ToolCallID,
|
ToolCallID: rp.ToolCallID,
|
||||||
Input: rp.Input,
|
Input: rp.Input,
|
||||||
Result: rp.Output,
|
Result: rp.Output,
|
||||||
}) || !sendEvent(ctx, ch, StreamEvent{
|
}) || !sendEvent(sendCtx, ch, StreamEvent{
|
||||||
Type: EventProgress,
|
Type: EventProgress,
|
||||||
StepNumber: stepNumber,
|
StepNumber: stepNumber,
|
||||||
ToolName: rp.ToolName,
|
ToolName: rp.ToolName,
|
||||||
@@ -979,8 +1020,15 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
}) {
|
}) {
|
||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
|
if shouldAbort {
|
||||||
|
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID))
|
||||||
|
cancel(ErrToolLoopDetected)
|
||||||
|
aborted = true
|
||||||
|
}
|
||||||
case *sdk.StreamToolErrorPart:
|
case *sdk.StreamToolErrorPart:
|
||||||
if !sendEvent(ctx, ch, StreamEvent{
|
tookLoopAbort := toolLoopAbortCallIDs.Take(rp.ToolCallID)
|
||||||
|
shouldAbort := errors.Is(rp.Error, ErrToolLoopDetected) || tookLoopAbort
|
||||||
|
if !sendEvent(sendCtx, ch, StreamEvent{
|
||||||
Type: EventToolCallEnd,
|
Type: EventToolCallEnd,
|
||||||
ToolName: rp.ToolName,
|
ToolName: rp.ToolName,
|
||||||
ToolCallID: rp.ToolCallID,
|
ToolCallID: rp.ToolCallID,
|
||||||
@@ -988,8 +1036,13 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
}) {
|
}) {
|
||||||
aborted = true
|
aborted = true
|
||||||
}
|
}
|
||||||
|
if shouldAbort {
|
||||||
|
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", rp.ToolCallID))
|
||||||
|
cancel(ErrToolLoopDetected)
|
||||||
|
aborted = true
|
||||||
|
}
|
||||||
case *sdk.ErrorPart:
|
case *sdk.ErrorPart:
|
||||||
sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()})
|
sendEvent(sendCtx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()})
|
||||||
aborted = true
|
aborted = true
|
||||||
case *sdk.AbortPart:
|
case *sdk.AbortPart:
|
||||||
aborted = true
|
aborted = true
|
||||||
@@ -1000,7 +1053,11 @@ func (a *Agent) runMidStreamRetry(
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return retryResult, aborted
|
if aborted {
|
||||||
|
for range retryResult.Stream {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return retryResult, aborted || detectGenerateLoopAbort(streamCtx, streamCtx.Err()) != nil
|
||||||
}
|
}
|
||||||
// All retry attempts failed
|
// All retry attempts failed
|
||||||
return prevResult, true
|
return prevResult, true
|
||||||
|
|||||||
@@ -488,17 +488,25 @@ func (m *Manager) RunningTasksSummary(botID, sessionID string) string {
|
|||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
var lines []string
|
var lines []string
|
||||||
for _, t := range m.tasks {
|
for _, t := range m.tasks {
|
||||||
if t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning {
|
t.mu.Lock()
|
||||||
desc := t.Description
|
matches := t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning
|
||||||
if desc == "" {
|
id := t.ID
|
||||||
desc = truncate(t.Command, 80)
|
desc := t.Description
|
||||||
}
|
command := t.Command
|
||||||
lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)",
|
startedAt := t.StartedAt
|
||||||
t.ID, desc,
|
outputFile := t.OutputFile
|
||||||
time.Since(t.StartedAt).Round(time.Second),
|
t.mu.Unlock()
|
||||||
t.OutputFile,
|
if !matches {
|
||||||
))
|
continue
|
||||||
}
|
}
|
||||||
|
if desc == "" {
|
||||||
|
desc = truncate(command, 80)
|
||||||
|
}
|
||||||
|
lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)",
|
||||||
|
id, desc,
|
||||||
|
time.Since(startedAt).Round(time.Second),
|
||||||
|
outputFile,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ package agent
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/jsonschema-go/jsonschema"
|
"github.com/google/jsonschema-go/jsonschema"
|
||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
@@ -19,10 +22,76 @@ func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([
|
|||||||
return p.tools, nil
|
return p.tools, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type atomicMockProvider struct {
|
||||||
|
calls atomic.Int32
|
||||||
|
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
|
||||||
|
stream func(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*atomicMockProvider) Name() string {
|
||||||
|
return "mock"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*atomicMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*atomicMockProvider) Test(context.Context) *sdk.ProviderTestResult {
|
||||||
|
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*atomicMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
|
||||||
|
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *atomicMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
call := int(m.calls.Add(1))
|
||||||
|
return m.handler(call, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *atomicMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||||
|
if m.stream != nil {
|
||||||
|
return m.stream(ctx, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := m.DoGenerate(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ch := make(chan sdk.StreamPart, 8)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
ch <- &sdk.StartPart{}
|
||||||
|
ch <- &sdk.StartStepPart{}
|
||||||
|
if result.Text != "" {
|
||||||
|
ch <- &sdk.TextStartPart{ID: "mock"}
|
||||||
|
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
|
||||||
|
ch <- &sdk.TextEndPart{ID: "mock"}
|
||||||
|
}
|
||||||
|
for _, tc := range result.ToolCalls {
|
||||||
|
ch <- &sdk.StreamToolCallPart{
|
||||||
|
ToolCallID: tc.ToolCallID,
|
||||||
|
ToolName: tc.ToolName,
|
||||||
|
Input: tc.Input,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ch <- &sdk.FinishStepPart{
|
||||||
|
FinishReason: result.FinishReason,
|
||||||
|
Usage: result.Usage,
|
||||||
|
Response: result.Response,
|
||||||
|
}
|
||||||
|
ch <- &sdk.FinishPart{
|
||||||
|
FinishReason: result.FinishReason,
|
||||||
|
TotalUsage: result.Usage,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return &sdk.StreamResult{Stream: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
|
func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
modelProvider := &agentReadMediaMockProvider{
|
modelProvider := &atomicMockProvider{
|
||||||
handler: func(_ int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
handler: func(_ int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
return &sdk.GenerateResult{
|
return &sdk.GenerateResult{
|
||||||
FinishReason: sdk.FinishReasonToolCalls,
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
@@ -58,8 +127,8 @@ func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
|
|||||||
if !errors.Is(err, ErrToolLoopDetected) {
|
if !errors.Is(err, ErrToolLoopDetected) {
|
||||||
t.Fatalf("expected ErrToolLoopDetected, got %v", err)
|
t.Fatalf("expected ErrToolLoopDetected, got %v", err)
|
||||||
}
|
}
|
||||||
if modelProvider.calls >= 20 {
|
if modelProvider.calls.Load() >= 20 {
|
||||||
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls)
|
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls.Load())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,7 +136,7 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
||||||
modelProvider := &agentReadMediaMockProvider{
|
modelProvider := &atomicMockProvider{
|
||||||
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
return &sdk.GenerateResult{
|
return &sdk.GenerateResult{
|
||||||
Text: repeatedText,
|
Text: repeatedText,
|
||||||
@@ -104,8 +173,8 @@ func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
|
|||||||
if !errors.Is(err, ErrTextLoopDetected) {
|
if !errors.Is(err, ErrTextLoopDetected) {
|
||||||
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
||||||
}
|
}
|
||||||
if modelProvider.calls >= 10 {
|
if modelProvider.calls.Load() >= 10 {
|
||||||
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls)
|
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls.Load())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,7 +182,7 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
||||||
modelProvider := &agentReadMediaMockProvider{
|
modelProvider := &atomicMockProvider{
|
||||||
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
finishReason := sdk.FinishReasonToolCalls
|
finishReason := sdk.FinishReasonToolCalls
|
||||||
var toolCalls []sdk.ToolCall
|
var toolCalls []sdk.ToolCall
|
||||||
@@ -157,7 +226,314 @@ func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
|
|||||||
if !errors.Is(err, ErrTextLoopDetected) {
|
if !errors.Is(err, ErrTextLoopDetected) {
|
||||||
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
||||||
}
|
}
|
||||||
if modelProvider.calls != 4 {
|
if modelProvider.calls.Load() != 4 {
|
||||||
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls)
|
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentStreamStopsOnToolLoopAbort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
modelProvider := &atomicMockProvider{
|
||||||
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
if call >= 20 {
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
Text: "unexpected-final-step",
|
||||||
|
FinishReason: sdk.FinishReasonStop,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
|
ToolCalls: []sdk.ToolCall{{
|
||||||
|
ToolCallID: "call-stream",
|
||||||
|
ToolName: "loop_tool",
|
||||||
|
Input: map[string]any{"query": "same"},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
|
staticToolProvider{
|
||||||
|
tools: []sdk.Tool{{
|
||||||
|
Name: "loop_tool",
|
||||||
|
Parameters: &jsonschema.Schema{Type: "object"},
|
||||||
|
Execute: func(_ *sdk.ToolExecContext, _ any) (any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
var terminal StreamEvent
|
||||||
|
for event := range a.Stream(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("loop stream")},
|
||||||
|
SupportsToolCall: true,
|
||||||
|
Identity: SessionContext{BotID: "bot-1"},
|
||||||
|
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||||
|
}) {
|
||||||
|
if event.IsTerminal() {
|
||||||
|
terminal = event
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if terminal.Type != EventAgentAbort {
|
||||||
|
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentStreamMarksTerminalTextLoopAsAbort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repeatedChunk := strings.Repeat("abcd", 64)
|
||||||
|
var observedCancel atomic.Bool
|
||||||
|
modelProvider := &atomicMockProvider{
|
||||||
|
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||||
|
ch := make(chan sdk.StreamPart, 16)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
send := func(part sdk.StreamPart) bool {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
observedCancel.Store(true)
|
||||||
|
return false
|
||||||
|
case ch <- part:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !send(&sdk.StartPart{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !send(&sdk.StartStepPart{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !send(&sdk.TextStartPart{ID: "mock"}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
if !send(&sdk.TextDeltaPart{ID: "mock", Text: repeatedChunk}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
observedCancel.Store(true)
|
||||||
|
return
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
if !send(&sdk.TextEndPart{ID: "mock"}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
|
||||||
|
}()
|
||||||
|
return &sdk.StreamResult{Stream: ch}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
|
||||||
|
var terminal StreamEvent
|
||||||
|
for event := range a.Stream(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("loop stream text")},
|
||||||
|
SupportsToolCall: true,
|
||||||
|
Identity: SessionContext{BotID: "bot-1"},
|
||||||
|
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||||
|
}) {
|
||||||
|
if event.IsTerminal() {
|
||||||
|
terminal = event
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !observedCancel.Load() {
|
||||||
|
t.Fatal("expected stream provider to observe context cancellation from text-loop abort")
|
||||||
|
}
|
||||||
|
if terminal.Type != EventAgentAbort {
|
||||||
|
t.Fatalf("expected EventAgentAbort, got %q", terminal.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentStreamMarksRetryTextLoopAsAbort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repeatedChunk := strings.Repeat("abcd", 64)
|
||||||
|
var streamCalls atomic.Int32
|
||||||
|
var observedCancel atomic.Bool
|
||||||
|
modelProvider := &atomicMockProvider{
|
||||||
|
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||||
|
call := streamCalls.Add(1)
|
||||||
|
ch := make(chan sdk.StreamPart, 16)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
send := func(part sdk.StreamPart) bool {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
observedCancel.Store(true)
|
||||||
|
return false
|
||||||
|
case ch <- part:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !send(&sdk.StartPart{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !send(&sdk.StartStepPart{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if call == 1 {
|
||||||
|
_ = send(&sdk.ErrorPart{Error: errors.New("api error 500")})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !send(&sdk.TextStartPart{ID: "mock-retry"}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
if !send(&sdk.TextDeltaPart{ID: "mock-retry", Text: repeatedChunk}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
observedCancel.Store(true)
|
||||||
|
return
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
if !send(&sdk.TextEndPart{ID: "mock-retry"}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !send(&sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = send(&sdk.FinishPart{FinishReason: sdk.FinishReasonStop})
|
||||||
|
}()
|
||||||
|
return &sdk.StreamResult{Stream: ch}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
|
||||||
|
var terminal StreamEvent
|
||||||
|
for event := range a.Stream(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("loop stream retry text")},
|
||||||
|
SupportsToolCall: true,
|
||||||
|
Identity: SessionContext{BotID: "bot-1"},
|
||||||
|
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||||
|
}) {
|
||||||
|
if event.IsTerminal() {
|
||||||
|
terminal = event
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if streamCalls.Load() != 2 {
|
||||||
|
t.Fatalf("expected one retry stream attempt, got %d stream calls", streamCalls.Load())
|
||||||
|
}
|
||||||
|
if !observedCancel.Load() {
|
||||||
|
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
|
||||||
|
}
|
||||||
|
if terminal.Type != EventAgentAbort {
|
||||||
|
t.Fatalf("expected EventAgentAbort after retry text-loop abort, got %q", terminal.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMidStreamRetryMarksTextLoopCancellationAsAborted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repeatedChunk := strings.Repeat("abcd", 64)
|
||||||
|
var observedCancel atomic.Bool
|
||||||
|
modelProvider := &atomicMockProvider{
|
||||||
|
stream: func(ctx context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||||
|
ch := make(chan sdk.StreamPart)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
send := func(part sdk.StreamPart) bool {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
observedCancel.Store(true)
|
||||||
|
return false
|
||||||
|
case ch <- part:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !send(&sdk.StartPart{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !send(&sdk.StartStepPart{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !send(&sdk.TextStartPart{ID: "mock-retry-only"}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
if !send(&sdk.TextDeltaPart{ID: "mock-retry-only", Text: repeatedChunk}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
observedCancel.Store(true)
|
||||||
|
return
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("expected text-loop detection to cancel retry stream before any extra part was sent")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return &sdk.StreamResult{Stream: ch}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
streamCtx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
defer cancel(nil)
|
||||||
|
|
||||||
|
textLoopGuard := NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
||||||
|
textLoopProbeBuffer := NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
|
||||||
|
result := textLoopGuard.Inspect(text)
|
||||||
|
if result.Abort {
|
||||||
|
cancel(ErrTextLoopDetected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
retryResult, aborted := a.runMidStreamRetry(
|
||||||
|
context.Background(),
|
||||||
|
streamCtx,
|
||||||
|
cancel,
|
||||||
|
newToolAbortRegistry(),
|
||||||
|
make(chan StreamEvent, 32),
|
||||||
|
RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("retry text loop")},
|
||||||
|
Identity: SessionContext{BotID: "bot-1"},
|
||||||
|
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
&sdk.StreamResult{Messages: []sdk.Message{sdk.UserMessage("previous step")}},
|
||||||
|
0,
|
||||||
|
"api error 500",
|
||||||
|
&strings.Builder{},
|
||||||
|
textLoopProbeBuffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
if retryResult == nil {
|
||||||
|
t.Fatal("expected retry result")
|
||||||
|
}
|
||||||
|
if !observedCancel.Load() {
|
||||||
|
t.Fatal("expected retry stream provider to observe context cancellation from text-loop abort")
|
||||||
|
}
|
||||||
|
if !errors.Is(context.Cause(streamCtx), ErrTextLoopDetected) {
|
||||||
|
t.Fatalf("expected stream context cause ErrTextLoopDetected, got %v", context.Cause(streamCtx))
|
||||||
|
}
|
||||||
|
if !aborted {
|
||||||
|
t.Fatal("expected runMidStreamRetry to report aborted when retry stream hit text-loop cancellation")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/memohai/memoh/internal/agent/tools"
|
||||||
|
)
|
||||||
|
|
||||||
type toolAbortRegistry struct {
|
type toolAbortRegistry struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -46,3 +50,65 @@ func (r *toolAbortRegistry) Any() bool {
|
|||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
return len(r.ids) > 0
|
return len(r.ids) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type toolEventCollector struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
events []tools.ToolStreamEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func newToolEventCollector() *toolEventCollector {
|
||||||
|
return &toolEventCollector{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *toolEventCollector) Add(evt tools.ToolStreamEvent) bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.closed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.events = append(c.events, evt)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *toolEventCollector) Close() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.closed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *toolEventCollector) CloseAndSnapshot() []tools.ToolStreamEvent {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.closed = true
|
||||||
|
snapshot := make([]tools.ToolStreamEvent, len(c.events))
|
||||||
|
copy(snapshot, c.events)
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot returns a copy of collected events without closing the collector.
|
||||||
|
// Callers that own the collector lifetime should still invoke Close (or
|
||||||
|
// CloseAndSnapshot) so late emits are rejected.
|
||||||
|
func (c *toolEventCollector) Snapshot() []tools.ToolStreamEvent {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
out := make([]tools.ToolStreamEvent, len(c.events))
|
||||||
|
copy(out, c.events)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ package agent
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/memohai/memoh/internal/agent/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
|
func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
|
||||||
@@ -30,3 +33,33 @@ func TestToolAbortRegistryConcurrentAccess(t *testing.T) {
|
|||||||
t.Fatal("expected registry to be empty")
|
t.Fatal("expected registry to be empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolEventCollectorCloseIgnoresLateAdds(t *testing.T) {
|
||||||
|
collector := newToolEventCollector()
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
if !collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) {
|
||||||
|
t.Fatalf("unexpected add failure before close, i=%d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
snapshot := collector.CloseAndSnapshot()
|
||||||
|
if len(snapshot) != 5 {
|
||||||
|
t.Fatalf("expected snapshot len 5, got %d", len(snapshot))
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var postCloseAdds atomic.Int32
|
||||||
|
for i := 0; i < 16; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if collector.Add(tools.ToolStreamEvent{Type: tools.StreamEventSpawnHeartbeat}) {
|
||||||
|
postCloseAdds.Add(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
if postCloseAdds.Load() != 0 {
|
||||||
|
t.Fatalf("expected 0 successful adds after close, got %d", postCloseAdds.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user