Files
Memoh/internal/agent/agent.go
T
2026-04-11 13:29:41 +08:00

544 lines
16 KiB
Go

package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/agent/tools"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/workspace/bridge"
)
// Agent is the core agent that handles LLM interactions.
//
// An Agent is built with New; tool providers are attached afterwards with
// SetToolProviders. Runs are started with Stream or Generate.
type Agent struct {
	// client issues the model calls (StreamText / GenerateTextResult).
	client *sdk.Client
	// toolProviders are queried by assembleTools on every tool-enabled run.
	// New leaves this empty; it is populated via SetToolProviders.
	toolProviders []tools.ToolProvider
	// bridgeProvider is the workspace bridge provider (workspace manager)
	// passed to New as deps.BridgeProvider.
	bridgeProvider bridge.Provider
	// logger is tagged with service=agent by New.
	logger *slog.Logger
}
// New constructs an Agent from deps.
//
// When deps.Logger is nil the process default logger (slog.Default) is
// used; in both cases the logger is tagged with service=agent. New does not
// install any tool providers — call SetToolProviders once they are available.
func New(deps Deps) *Agent {
	base := deps.Logger
	if base == nil {
		base = slog.Default()
	}
	ag := &Agent{
		client:         sdk.NewClient(),
		bridgeProvider: deps.BridgeProvider,
	}
	ag.logger = base.With(slog.String("service", "agent"))
	return ag
}
// BridgeProvider returns the underlying bridge provider (workspace manager).
// It is the value New received as deps.BridgeProvider.
func (a *Agent) BridgeProvider() bridge.Provider {
	return a.bridgeProvider
}
// SetToolProviders sets the tool providers after construction.
// This allows breaking dependency cycles in the DI graph: New installs no
// providers, so until this is called assembleTools finds none and runs
// proceed without tools. The slice is stored as given (it is not copied).
func (a *Agent) SetToolProviders(providers []tools.ToolProvider) {
	a.toolProviders = providers
}
// Stream starts a streaming run on a background goroutine and returns the
// channel it reports on.
//
// The channel is unbuffered and is closed after the run has sent its final
// event. Every send blocks, so callers should keep receiving until the
// channel is closed.
func (a *Agent) Stream(ctx context.Context, cfg RunConfig) <-chan StreamEvent {
	events := make(chan StreamEvent)
	go func(out chan StreamEvent) {
		defer close(out)
		a.runStream(ctx, cfg, out)
	}(events)
	return events
}
// Generate runs the agent in non-streaming mode, returning the complete result.
// It is a thin wrapper over runGenerate; see that method for how tool side
// effects (attachments, reactions, speech) end up on the result.
func (a *Agent) Generate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
	return a.runGenerate(ctx, cfg)
}
// runStream executes one streaming run and writes its events to ch.
//
// On a tool-assembly or stream-start failure it sends a single EventError
// and returns. Otherwise it sends EventAgentStart, then forwards text,
// reasoning, tool-call, tool-progress and file parts as the SDK yields them
// (tools may additionally push attachment/reaction/speech events through the
// stream emitter), and finishes with exactly one terminal event —
// EventAgentAbort or EventAgentEnd — carrying the final messages and the
// token usage summed over all steps.
//
// The run is treated as aborted when ctx is cancelled, the SDK yields an
// ErrorPart or AbortPart, the tool loop guard trips, or (with loop detection
// enabled) the text loop guard flags the streamed text.
//
// All sends on ch block, so the reader must keep draining ch until Stream
// closes it.
func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) {
	// Stream emitter: tools targeting the current conversation push
	// side-effect events (attachments, reactions, speech) directly here.
	streamEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
		ch <- toolStreamEventToAgentEvent(evt)
	})

	var sdkTools []sdk.Tool
	if cfg.SupportsToolCall {
		var err error
		sdkTools, err = a.assembleTools(ctx, cfg, streamEmitter)
		if err != nil {
			ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}
			return
		}
	}
	sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools)

	// Loop detection setup. textLoopDetected is set by the probe-buffer
	// callback and turned into an abort inside the stream loop below.
	// NOTE(review): this assumes TextLoopProbeBuffer invokes its callback
	// synchronously from Push/Flush (on this goroutine) — confirm.
	var textLoopProbeBuffer *TextLoopProbeBuffer
	var toolLoopGuard *ToolLoopGuard
	textLoopDetected := false
	toolLoopAbortCallIDs := make(map[string]struct{})
	if cfg.LoopDetection.Enabled {
		textLoopGuard := NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
		textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
			if textLoopGuard.Inspect(text).Abort {
				a.logger.Warn("text loop detected, will abort")
				textLoopDetected = true
			}
		})
		toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
	}

	// Wrap tools with loop detection.
	if toolLoopGuard != nil {
		sdkTools = wrapToolsWithLoopGuard(sdkTools, toolLoopGuard, toolLoopAbortCallIDs)
	}

	var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
	if readMediaState != nil {
		prepareStep = readMediaState.prepareStep
	}
	if cfg.InjectCh != nil {
		prepareStep = a.withInjectedMessages(cfg, prepareStep)
	}

	opts := a.buildGenerateOptions(cfg, sdkTools, prepareStep)

	// runCtx is cancelled when runStream returns, so that leaving the stream
	// loop early (any abort path) tells the SDK producer and in-flight tool
	// calls to stop instead of running on against a stream nobody reads.
	runCtx, cancelRun := context.WithCancel(ctx)
	defer cancelRun()

	streamResult, err := a.client.StreamText(runCtx, opts...)
	if err != nil {
		ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: %v", err)}
		return
	}
	ch <- StreamEvent{Type: EventAgentStart}

	aborted := false
	for part := range streamResult.Stream {
		if ctx.Err() != nil {
			aborted = true
			break
		}
		switch p := part.(type) {
		case *sdk.StartPart:
			// EventAgentStart was already emitted above.
		case *sdk.TextStartPart:
			ch <- StreamEvent{Type: EventTextStart}
		case *sdk.TextDeltaPart:
			if p.Text != "" {
				if textLoopProbeBuffer != nil {
					textLoopProbeBuffer.Push(p.Text)
				}
				ch <- StreamEvent{Type: EventTextDelta, Delta: p.Text}
			}
		case *sdk.TextEndPart:
			if textLoopProbeBuffer != nil {
				textLoopProbeBuffer.Flush()
			}
			ch <- StreamEvent{Type: EventTextEnd}
		case *sdk.ReasoningStartPart:
			ch <- StreamEvent{Type: EventReasoningStart}
		case *sdk.ReasoningDeltaPart:
			ch <- StreamEvent{Type: EventReasoningDelta, Delta: p.Text}
		case *sdk.ReasoningEndPart:
			ch <- StreamEvent{Type: EventReasoningEnd}
		case *sdk.StreamToolCallPart:
			// Text preceding a tool call is a complete segment; inspect it now.
			if textLoopProbeBuffer != nil {
				textLoopProbeBuffer.Flush()
			}
			ch <- StreamEvent{
				Type:       EventToolCallStart,
				ToolName:   p.ToolName,
				ToolCallID: p.ToolCallID,
				Input:      p.Input,
			}
		case *sdk.ToolProgressPart:
			ch <- StreamEvent{
				Type:       EventToolCallProgress,
				ToolName:   p.ToolName,
				ToolCallID: p.ToolCallID,
				Progress:   p.Content,
			}
		case *sdk.StreamToolResultPart:
			// The loop-guard wrapper records the call ID of an aborted call;
			// forward its result first, then stop the run.
			shouldAbort := false
			if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok {
				delete(toolLoopAbortCallIDs, p.ToolCallID)
				shouldAbort = true
			}
			ch <- StreamEvent{
				Type:       EventToolCallEnd,
				ToolName:   p.ToolName,
				ToolCallID: p.ToolCallID,
				Input:      p.Input,
				Result:     p.Output,
			}
			if shouldAbort {
				a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
				aborted = true
			}
		case *sdk.StreamToolErrorPart:
			ch <- StreamEvent{
				Type:       EventToolCallEnd,
				ToolName:   p.ToolName,
				ToolCallID: p.ToolCallID,
				Error:      p.Error.Error(),
			}
		case *sdk.StreamFilePart:
			mediaType := p.File.MediaType
			if mediaType == "" {
				mediaType = "image/png"
			}
			ch <- StreamEvent{
				Type: EventAttachment,
				Attachments: []FileAttachment{{
					Type: "image",
					URL:  fmt.Sprintf("data:%s;base64,%s", mediaType, p.File.Data),
					Mime: mediaType,
				}},
			}
		case *sdk.ErrorPart:
			ch <- StreamEvent{Type: EventError, Error: p.Error.Error()}
			aborted = true
		case *sdk.AbortPart:
			aborted = true
		case *sdk.FinishPart:
			// handled after loop
		}
		// A text loop flagged by Push/Flush above ends the run.
		if textLoopDetected {
			aborted = true
		}
		if aborted {
			break
		}
	}
	if textLoopProbeBuffer != nil {
		textLoopProbeBuffer.Flush()
	}

	finalMessages := streamResult.Messages
	if readMediaState != nil {
		finalMessages = readMediaState.mergeMessages(streamResult.Steps, finalMessages)
	}

	var totalUsage sdk.Usage
	for _, step := range streamResult.Steps {
		totalUsage.InputTokens += step.Usage.InputTokens
		totalUsage.OutputTokens += step.Usage.OutputTokens
		totalUsage.TotalTokens += step.Usage.TotalTokens
		totalUsage.ReasoningTokens += step.Usage.ReasoningTokens
		totalUsage.CachedInputTokens += step.Usage.CachedInputTokens
		totalUsage.InputTokenDetails.NoCacheTokens += step.Usage.InputTokenDetails.NoCacheTokens
		totalUsage.InputTokenDetails.CacheReadTokens += step.Usage.InputTokenDetails.CacheReadTokens
		totalUsage.InputTokenDetails.CacheWriteTokens += step.Usage.InputTokenDetails.CacheWriteTokens
		totalUsage.OutputTokenDetails.TextTokens += step.Usage.OutputTokenDetails.TextTokens
		totalUsage.OutputTokenDetails.ReasoningTokens += step.Usage.OutputTokenDetails.ReasoningTokens
	}
	// The counters summed above are numeric; a Marshal failure is not
	// expected, and the terminal event is still sent without usage if it does.
	usageJSON, _ := json.Marshal(totalUsage)

	termEvent := StreamEvent{
		Messages: mustMarshal(finalMessages),
		Usage:    usageJSON,
	}
	if aborted {
		termEvent.Type = EventAgentAbort
	} else {
		termEvent.Type = EventAgentEnd
	}
	ch <- termEvent
}

// withInjectedMessages wraps base (which may be nil) in a prepare-step hook.
// Before every step, the hook drains all messages currently queued on
// cfg.InjectCh, without blocking, and appends each one as a user message.
//
// Text comes from HeaderifiedText, falling back to Text (both trimmed).
// Image parts are attached only when cfg.SupportsImageInput is set, and
// parts whose Image is blank are dropped. A queued item left with neither
// text nor an image is skipped. For every appended message,
// cfg.InjectedRecorder (when set) receives the text and
// len(p.Messages) - len(cfg.Messages) as measured just before insertion.
func (a *Agent) withInjectedMessages(cfg RunConfig, base func(*sdk.GenerateParams) *sdk.GenerateParams) func(*sdk.GenerateParams) *sdk.GenerateParams {
	initialMsgCount := len(cfg.Messages)
	return func(p *sdk.GenerateParams) *sdk.GenerateParams {
		if base != nil {
			if override := base(p); override != nil {
				p = override
			}
		}
		for {
			select {
			case injected, ok := <-cfg.InjectCh:
				if !ok {
					// Channel closed: nothing more will be injected.
					return p
				}
				text := strings.TrimSpace(injected.HeaderifiedText)
				if text == "" {
					text = strings.TrimSpace(injected.Text)
				}
				var extra []sdk.MessagePart
				if cfg.SupportsImageInput {
					for _, img := range injected.ImageParts {
						if strings.TrimSpace(img.Image) != "" {
							extra = append(extra, img)
						}
					}
				}
				// Decide on the filtered parts, so an item whose images are
				// all blank does not become an empty user message.
				if text == "" && len(extra) == 0 {
					continue
				}
				insertAfter := len(p.Messages) - initialMsgCount
				p.Messages = append(p.Messages, sdk.UserMessage(text, extra...))
				if cfg.InjectedRecorder != nil {
					cfg.InjectedRecorder(text, insertAfter)
				}
				a.logger.Info("injected user message into agent stream",
					slog.String("bot_id", cfg.Identity.BotID),
					slog.Int("insert_after", insertAfter),
					slog.Int("image_parts", len(extra)),
				)
			default:
				// Queue is empty for now; never block the step.
				return p
			}
		}
	}
}
// runGenerate executes one non-streaming run and returns the final result.
//
// Side effects that tools emit during the run (attachments, reactions,
// speech) are collected rather than streamed and are returned on the
// GenerateResult, converted exactly as the streaming path converts them.
// It returns an error if tool assembly or the SDK call fails.
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
	// Collecting emitter: tools push side-effect events here during generation.
	var collected []tools.ToolStreamEvent
	collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) {
		collected = append(collected, evt)
	})

	var sdkTools []sdk.Tool
	if cfg.SupportsToolCall {
		var err error
		sdkTools, err = a.assembleTools(ctx, cfg, collectEmitter)
		if err != nil {
			return nil, fmt.Errorf("assemble tools: %w", err)
		}
	}
	sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools)

	var toolLoopGuard *ToolLoopGuard
	var textLoopGuard *TextLoopGuard
	toolLoopAbortCallIDs := make(map[string]struct{})
	if cfg.LoopDetection.Enabled {
		toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
		textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
	}
	if toolLoopGuard != nil {
		sdkTools = wrapToolsWithLoopGuard(sdkTools, toolLoopGuard, toolLoopAbortCallIDs)
	}

	var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
	if readMediaState != nil {
		prepareStep = readMediaState.prepareStep
	}

	opts := a.buildGenerateOptions(cfg, sdkTools, prepareStep)
	opts = append(opts,
		sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
			// NOTE(review): every path below returns nil, so the loop guards
			// are consulted but cannot currently stop generation. Confirm how
			// sdk.WithOnStep signals "stop" and return that on the abort paths.
			if cfg.LoopDetection.Enabled {
				if len(toolLoopAbortCallIDs) > 0 {
					return nil // intended: stop after a tool-loop abort
				}
				if textLoopGuard != nil && isNonEmptyString(step.Text) {
					if textLoopGuard.Inspect(step.Text).Abort {
						return nil // intended: stop on a detected text loop
					}
				}
			}
			return nil
		}),
	)

	genResult, err := a.client.GenerateTextResult(ctx, opts...)
	if err != nil {
		return nil, fmt.Errorf("generate: %w", err)
	}

	// Drain collected tool-emitted side effects into the result, reusing the
	// streaming conversion. A converted event only fills the payload that
	// matches its type, so appending all three payloads is safe, and slices
	// stay nil when nothing of that kind was emitted.
	var attachments []FileAttachment
	var reactions []ReactionItem
	var speeches []SpeechItem
	for _, evt := range collected {
		converted := toolStreamEventToAgentEvent(evt)
		attachments = append(attachments, converted.Attachments...)
		reactions = append(reactions, converted.Reactions...)
		speeches = append(speeches, converted.Speeches...)
	}

	finalMessages := genResult.Messages
	if readMediaState != nil {
		finalMessages = readMediaState.mergeMessages(genResult.Steps, finalMessages)
	}
	return &GenerateResult{
		Messages:    finalMessages,
		Text:        genResult.Text,
		Attachments: attachments,
		Reactions:   reactions,
		Speeches:    speeches,
		Usage:       &genResult.Usage,
	}, nil
}
// buildGenerateOptions assembles the SDK options shared by the streaming
// and non-streaming paths, in this order:
//   - model, message history, and system prompt from cfg;
//   - a step limit of -1 (NOTE(review): presumably "no step cap" — confirm
//     against the sdk);
//   - the tool set, only when cfg.SupportsToolCall is set and stepTools is
//     non-empty;
//   - the prepareStep hook, when non-nil;
//   - reasoning options for the model's client type, with reasoning enabled
//     exactly when cfg.ReasoningEffort is non-empty.
func (*Agent) buildGenerateOptions(cfg RunConfig, stepTools []sdk.Tool, prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams) []sdk.GenerateOption {
	var opts []sdk.GenerateOption
	opts = append(opts,
		sdk.WithModel(cfg.Model),
		sdk.WithMessages(cfg.Messages),
		sdk.WithSystem(cfg.System),
		sdk.WithMaxSteps(-1),
	)

	useTools := len(stepTools) > 0 && cfg.SupportsToolCall
	if useTools {
		opts = append(opts, sdk.WithTools(stepTools))
	}
	if prepareStep != nil {
		opts = append(opts, sdk.WithPrepareStep(prepareStep))
	}

	reasoning := &models.ReasoningConfig{
		Enabled: cfg.ReasoningEffort != "",
		Effort:  cfg.ReasoningEffort,
	}
	modelCfg := models.SDKModelConfig{
		ClientType:      models.ResolveClientType(cfg.Model),
		ReasoningConfig: reasoning,
	}
	return append(opts, models.BuildReasoningOptions(modelCfg)...)
}
// assembleTools collects tools from all registered ToolProviders.
//
// emitter is injected into the session context so that tools targeting the
// current conversation can push side-effect events (attachments, reactions,
// speech) directly into the agent stream.
//
// Providers are best-effort. A provider that returns an error is logged,
// with its concrete type so the failing one can be identified, and skipped;
// the remaining providers' tools are still returned. The error result is
// currently always nil. With no registered providers it returns nil, nil.
func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig, emitter tools.StreamEmitter) ([]sdk.Tool, error) {
	if len(a.toolProviders) == 0 {
		return nil, nil
	}

	skillsMap := make(map[string]tools.SkillDetail, len(cfg.Skills))
	for _, s := range cfg.Skills {
		skillsMap[s.Name] = tools.SkillDetail{
			Description: s.Description,
			Content:     s.Content,
		}
	}

	session := tools.SessionContext{
		BotID:              cfg.Identity.BotID,
		ChatID:             cfg.Identity.ChatID,
		SessionID:          cfg.Identity.SessionID,
		SessionType:        cfg.SessionType,
		ChannelIdentityID:  cfg.Identity.ChannelIdentityID,
		SessionToken:       cfg.Identity.SessionToken,
		CurrentPlatform:    cfg.Identity.CurrentPlatform,
		ReplyTarget:        cfg.Identity.ReplyTarget,
		SupportsImageInput: cfg.SupportsImageInput,
		IsSubagent:         cfg.Identity.IsSubagent,
		Skills:             skillsMap,
		TimezoneLocation:   cfg.Identity.TimezoneLocation,
		Emitter:            emitter,
	}

	var allTools []sdk.Tool
	for _, provider := range a.toolProviders {
		providerTools, err := provider.Tools(ctx, session)
		if err != nil {
			a.logger.Warn("tool provider failed",
				slog.String("provider", fmt.Sprintf("%T", provider)),
				slog.String("bot_id", cfg.Identity.BotID),
				slog.Any("error", err),
			)
			continue
		}
		allTools = append(allTools, providerTools...)
	}
	return allTools, nil
}
// toolStreamEventToAgentEvent translates an event emitted by the tool layer
// into the agent-layer StreamEvent written to the output channel.
//
// Attachment, reaction and speech events become EventAttachment,
// EventReaction and EventSpeech, with their payloads copied field by field.
// The payload slice is always non-nil, even when empty. Any other tool
// event type yields the zero StreamEvent.
func toolStreamEventToAgentEvent(evt tools.ToolStreamEvent) StreamEvent {
	var out StreamEvent
	switch evt.Type {
	case tools.StreamEventAttachment:
		out.Type = EventAttachment
		out.Attachments = make([]FileAttachment, len(evt.Attachments))
		for i, src := range evt.Attachments {
			out.Attachments[i] = FileAttachment{
				Type:        src.Type,
				Path:        src.Path,
				URL:         src.URL,
				Mime:        src.Mime,
				Name:        src.Name,
				ContentHash: src.ContentHash,
				Size:        src.Size,
				Metadata:    src.Metadata,
			}
		}
	case tools.StreamEventReaction:
		out.Type = EventReaction
		out.Reactions = make([]ReactionItem, len(evt.Reactions))
		for i, r := range evt.Reactions {
			out.Reactions[i] = ReactionItem{Emoji: r.Emoji}
		}
	case tools.StreamEventSpeech:
		out.Type = EventSpeech
		out.Speeches = make([]SpeechItem, len(evt.Speeches))
		for i, s := range evt.Speeches {
			out.Speeches[i] = SpeechItem{Text: s.Text}
		}
	}
	return out
}
// wrapToolsWithLoopGuard returns copies of toolset whose Execute first
// consults guard with the tool name and input:
//   - abort: the call's ToolCallID is recorded in abortCallIDs, and an
//     error-shaped result (isError plus ToolLoopDetectedAbortMessage) is
//     returned together with an error carrying the same message;
//   - warn: a result flagged with ToolLoopWarningKey and containing
//     ToolLoopWarningText is returned instead of running the tool;
//   - otherwise the original Execute runs unchanged.
//
// abortCallIDs is written from inside Execute, i.e. wherever the SDK runs
// tools. NOTE(review): the map is not synchronized. If the SDK runs tool
// calls concurrently with each other, or with the reader of abortCallIDs,
// access to the map needs a lock.
func wrapToolsWithLoopGuard(toolset []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool {
	guarded := make([]sdk.Tool, 0, len(toolset))
	for _, t := range toolset {
		inner := t.Execute
		name := t.Name
		t.Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
			warn, abort := guard.Guard(name, input)
			switch {
			case abort:
				abortCallIDs[ctx.ToolCallID] = struct{}{}
				abortResult := map[string]any{
					"isError": true,
					"content": []map[string]any{{
						"type": "text",
						"text": ToolLoopDetectedAbortMessage,
					}},
				}
				return abortResult, errors.New(ToolLoopDetectedAbortMessage)
			case warn:
				warnResult := map[string]any{
					ToolLoopWarningKey: true,
					"content": []map[string]any{{
						"type": "text",
						"text": ToolLoopWarningText,
					}},
				}
				return warnResult, nil
			default:
				return inner(ctx, input)
			}
		}
		guarded = append(guarded, t)
	}
	return guarded
}