mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(agent): stop generate loop aborts correctly
This commit is contained in:
+66
-5
@@ -435,6 +435,10 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
||||||
|
genCtx, cancel := context.WithCancelCause(ctx)
|
||||||
|
defer cancel(nil)
|
||||||
|
loopAbort := newLoopAbortState()
|
||||||
|
|
||||||
// Collecting emitter: tools push side-effect events here during generation.
|
// Collecting emitter: tools push side-effect events here during generation.
|
||||||
var collected []tools.ToolStreamEvent
|
var collected []tools.ToolStreamEvent
|
||||||
var collectedMu sync.Mutex
|
var collectedMu sync.Mutex
|
||||||
@@ -447,7 +451,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
var sdkTools []sdk.Tool
|
var sdkTools []sdk.Tool
|
||||||
if cfg.SupportsToolCall {
|
if cfg.SupportsToolCall {
|
||||||
var err error
|
var err error
|
||||||
sdkTools, err = a.assembleTools(ctx, cfg, collectEmitter)
|
sdkTools, err = a.assembleTools(genCtx, cfg, collectEmitter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("assemble tools: %w", err)
|
return nil, fmt.Errorf("assemble tools: %w", err)
|
||||||
}
|
}
|
||||||
@@ -491,12 +495,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
||||||
if cfg.LoopDetection.Enabled {
|
if cfg.LoopDetection.Enabled {
|
||||||
if toolLoopAbortCallIDs.Any() {
|
if toolLoopAbortCallIDs.Any() {
|
||||||
return nil // stop
|
loopAbort.Set(ErrToolLoopDetected)
|
||||||
|
cancel(ErrToolLoopDetected)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
if textLoopGuard != nil && isNonEmptyString(step.Text) {
|
if textLoopGuard != nil && isNonEmptyString(step.Text) {
|
||||||
result := textLoopGuard.Inspect(step.Text)
|
result := textLoopGuard.Inspect(step.Text)
|
||||||
if result.Abort {
|
if result.Abort {
|
||||||
return nil // stop
|
loopAbort.Set(ErrTextLoopDetected)
|
||||||
|
cancel(ErrTextLoopDetected)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -504,10 +512,16 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
genResult, err := a.client.GenerateTextResult(ctx, opts...)
|
genResult, err := a.client.GenerateTextResult(genCtx, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if loopErr := detectGenerateLoopAbort(genCtx, err); loopErr != nil {
|
||||||
|
return nil, loopErr
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("generate: %w", err)
|
return nil, fmt.Errorf("generate: %w", err)
|
||||||
}
|
}
|
||||||
|
if loopErr := loopAbort.Err(); loopErr != nil {
|
||||||
|
return nil, loopErr
|
||||||
|
}
|
||||||
|
|
||||||
// Drain collected tool-emitted side effects into the result.
|
// Drain collected tool-emitted side effects into the result.
|
||||||
var attachments []FileAttachment
|
var attachments []FileAttachment
|
||||||
@@ -717,7 +731,7 @@ func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs
|
|||||||
"type": "text",
|
"type": "text",
|
||||||
"text": ToolLoopDetectedAbortMessage,
|
"text": ToolLoopDetectedAbortMessage,
|
||||||
}},
|
}},
|
||||||
}, errors.New(ToolLoopDetectedAbortMessage)
|
}, ErrToolLoopDetected
|
||||||
}
|
}
|
||||||
if warn {
|
if warn {
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
@@ -979,3 +993,50 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func detectGenerateLoopAbort(ctx context.Context, err error) error {
|
||||||
|
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cause := context.Cause(ctx)
|
||||||
|
switch {
|
||||||
|
case errors.Is(cause, ErrToolLoopDetected):
|
||||||
|
return ErrToolLoopDetected
|
||||||
|
case errors.Is(cause, ErrTextLoopDetected):
|
||||||
|
return ErrTextLoopDetected
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type loopAbortState struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLoopAbortState() *loopAbortState {
|
||||||
|
return &loopAbortState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loopAbortState) Set(err error) {
|
||||||
|
if s == nil || err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.err == nil {
|
||||||
|
s.err = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loopAbortState) Err() error {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,163 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/jsonschema-go/jsonschema"
|
||||||
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
|
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
type staticToolProvider struct {
|
||||||
|
tools []sdk.Tool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p staticToolProvider) Tools(context.Context, agenttools.SessionContext) ([]sdk.Tool, error) {
|
||||||
|
return p.tools, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentGenerateStopsOnToolLoopAbort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
modelProvider := &agentReadMediaMockProvider{
|
||||||
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
|
ToolCalls: []sdk.ToolCall{{
|
||||||
|
ToolCallID: "call-same",
|
||||||
|
ToolName: "loop_tool",
|
||||||
|
Input: map[string]any{"query": "same"},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
|
staticToolProvider{
|
||||||
|
tools: []sdk.Tool{{
|
||||||
|
Name: "loop_tool",
|
||||||
|
Parameters: &jsonschema.Schema{Type: "object"},
|
||||||
|
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := a.Generate(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("loop")},
|
||||||
|
SupportsToolCall: true,
|
||||||
|
Identity: SessionContext{BotID: "bot-1"},
|
||||||
|
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||||
|
})
|
||||||
|
if !errors.Is(err, ErrToolLoopDetected) {
|
||||||
|
t.Fatalf("expected ErrToolLoopDetected, got %v", err)
|
||||||
|
}
|
||||||
|
if modelProvider.calls >= 20 {
|
||||||
|
t.Fatalf("expected tool loop to stop generation, got %d provider calls", modelProvider.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentGenerateStopsOnTextLoopAbort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
||||||
|
modelProvider := &agentReadMediaMockProvider{
|
||||||
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
Text: repeatedText,
|
||||||
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
|
ToolCalls: []sdk.ToolCall{{
|
||||||
|
ToolCallID: "call-text",
|
||||||
|
ToolName: "noop_tool",
|
||||||
|
Input: map[string]any{"step": call},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
|
staticToolProvider{
|
||||||
|
tools: []sdk.Tool{{
|
||||||
|
Name: "noop_tool",
|
||||||
|
Parameters: &jsonschema.Schema{Type: "object"},
|
||||||
|
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := a.Generate(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("loop text")},
|
||||||
|
SupportsToolCall: true,
|
||||||
|
Identity: SessionContext{BotID: "bot-1"},
|
||||||
|
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||||
|
})
|
||||||
|
if !errors.Is(err, ErrTextLoopDetected) {
|
||||||
|
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
||||||
|
}
|
||||||
|
if modelProvider.calls >= 10 {
|
||||||
|
t.Fatalf("expected text loop to stop generation, got %d provider calls", modelProvider.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentGenerateStopsOnTerminalTextLoopAbort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repeatedText := "abcdefghijklmnopqrstuvwxyz0123456789 repeated text chunk for loop detection"
|
||||||
|
modelProvider := &agentReadMediaMockProvider{
|
||||||
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
finishReason := sdk.FinishReasonToolCalls
|
||||||
|
var toolCalls []sdk.ToolCall
|
||||||
|
if call < 4 {
|
||||||
|
toolCalls = []sdk.ToolCall{{
|
||||||
|
ToolCallID: "call-terminal",
|
||||||
|
ToolName: "noop_tool",
|
||||||
|
Input: map[string]any{"step": call},
|
||||||
|
}}
|
||||||
|
} else {
|
||||||
|
finishReason = sdk.FinishReasonStop
|
||||||
|
}
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
Text: repeatedText,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
|
staticToolProvider{
|
||||||
|
tools: []sdk.Tool{{
|
||||||
|
Name: "noop_tool",
|
||||||
|
Parameters: &jsonschema.Schema{Type: "object"},
|
||||||
|
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := a.Generate(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("loop text terminal")},
|
||||||
|
SupportsToolCall: true,
|
||||||
|
Identity: SessionContext{BotID: "bot-1"},
|
||||||
|
LoopDetection: LoopDetectionConfig{Enabled: true},
|
||||||
|
})
|
||||||
|
if !errors.Is(err, ErrTextLoopDetected) {
|
||||||
|
t.Fatalf("expected ErrTextLoopDetected, got %v", err)
|
||||||
|
}
|
||||||
|
if modelProvider.calls != 4 {
|
||||||
|
t.Fatalf("expected terminal text loop to abort on final step, got %d provider calls", modelProvider.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -31,6 +32,11 @@ const (
|
|||||||
defaultMinNewGramsPerChunk = 1
|
defaultMinNewGramsPerChunk = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTextLoopDetected = errors.New(LoopDetectedAbortMessage)
|
||||||
|
ErrToolLoopDetected = errors.New(ToolLoopDetectedAbortMessage)
|
||||||
|
)
|
||||||
|
|
||||||
// --- Sential: n-gram overlap detector ---
|
// --- Sential: n-gram overlap detector ---
|
||||||
|
|
||||||
// SentialOptions configures the n-gram overlap detector.
|
// SentialOptions configures the n-gram overlap detector.
|
||||||
|
|||||||
Reference in New Issue
Block a user