package agent import ( "context" "encoding/json" "errors" "fmt" "log/slog" "strings" sdk "github.com/memohai/twilight-ai/sdk" "github.com/memohai/memoh/internal/agent/tools" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/workspace/bridge" ) // Agent is the core agent that handles LLM interactions. type Agent struct { client *sdk.Client toolProviders []tools.ToolProvider bridgeProvider bridge.Provider logger *slog.Logger } // New creates a new Agent with the given dependencies. func New(deps Deps) *Agent { logger := deps.Logger if logger == nil { logger = slog.Default() } return &Agent{ client: sdk.NewClient(), bridgeProvider: deps.BridgeProvider, logger: logger.With(slog.String("service", "agent")), } } // BridgeProvider returns the underlying bridge provider (workspace manager). func (a *Agent) BridgeProvider() bridge.Provider { return a.bridgeProvider } // SetToolProviders sets the tool providers after construction. // This allows breaking dependency cycles in the DI graph. func (a *Agent) SetToolProviders(providers []tools.ToolProvider) { a.toolProviders = providers } // Stream runs the agent in streaming mode, emitting events to the returned channel. func (a *Agent) Stream(ctx context.Context, cfg RunConfig) <-chan StreamEvent { ch := make(chan StreamEvent) go func() { defer close(ch) a.runStream(ctx, cfg, ch) }() return ch } // Generate runs the agent in non-streaming mode, returning the complete result. func (a *Agent) Generate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) { return a.runGenerate(ctx, cfg) } func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) { // Stream emitter: tools targeting the current conversation push // side-effect events (attachments, reactions, speech) directly here. streamEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) { ch <- toolStreamEventToAgentEvent(evt) }) var sdkTools []sdk.Tool if cfg.SupportsToolCall { var err error sdkTools, err = a.assembleTools(ctx, cfg, streamEmitter) if err != nil { ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)} return } } sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools) // Loop detection setup var textLoopGuard *TextLoopGuard var textLoopProbeBuffer *TextLoopProbeBuffer var toolLoopGuard *ToolLoopGuard toolLoopAbortCallIDs := make(map[string]struct{}) if cfg.LoopDetection.Enabled { textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{}) textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) { result := textLoopGuard.Inspect(text) if result.Abort { a.logger.Warn("text loop detected, will abort") } }) toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort) } // Wrap tools with loop detection if toolLoopGuard != nil { sdkTools = wrapToolsWithLoopGuard(sdkTools, toolLoopGuard, toolLoopAbortCallIDs) } var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams if readMediaState != nil { prepareStep = readMediaState.prepareStep } initialMsgCount := len(cfg.Messages) if cfg.InjectCh != nil { basePrepare := prepareStep prepareStep = func(p *sdk.GenerateParams) *sdk.GenerateParams { if basePrepare != nil { if override := basePrepare(p); override != nil { p = override } } for { select { case injected, ok := <-cfg.InjectCh: if !ok { break } text := strings.TrimSpace(injected.HeaderifiedText) if text == "" { text = strings.TrimSpace(injected.Text) } if text != "" || (cfg.SupportsImageInput && len(injected.ImageParts) > 0) { insertAfter := len(p.Messages) - initialMsgCount var extra []sdk.MessagePart if cfg.SupportsImageInput { for _, img := range injected.ImageParts { if strings.TrimSpace(img.Image) != "" { extra = append(extra, img) } } } p.Messages = append(p.Messages, sdk.UserMessage(text, extra...)) if cfg.InjectedRecorder != nil { cfg.InjectedRecorder(text, insertAfter) } a.logger.Info("injected user message into agent stream", slog.String("bot_id", cfg.Identity.BotID), slog.Int("insert_after", insertAfter), slog.Int("image_parts", len(extra)), ) } continue default: } break } return p } } opts := a.buildGenerateOptions(cfg, sdkTools, prepareStep) streamResult, err := a.client.StreamText(ctx, opts...) if err != nil { ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: %v", err)} return } ch <- StreamEvent{Type: EventAgentStart} var allText strings.Builder aborted := false for part := range streamResult.Stream { if ctx.Err() != nil { aborted = true break } switch p := part.(type) { case *sdk.StartPart: _ = p // stream start already emitted case *sdk.TextStartPart: ch <- StreamEvent{Type: EventTextStart} case *sdk.TextDeltaPart: if p.Text != "" { if textLoopProbeBuffer != nil { textLoopProbeBuffer.Push(p.Text) } ch <- StreamEvent{Type: EventTextDelta, Delta: p.Text} allText.WriteString(p.Text) } case *sdk.TextEndPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } ch <- StreamEvent{Type: EventTextEnd} case *sdk.ReasoningStartPart: ch <- StreamEvent{Type: EventReasoningStart} case *sdk.ReasoningDeltaPart: ch <- StreamEvent{Type: EventReasoningDelta, Delta: p.Text} case *sdk.ReasoningEndPart: ch <- StreamEvent{Type: EventReasoningEnd} case *sdk.StreamToolCallPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } ch <- StreamEvent{ Type: EventToolCallStart, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Input: p.Input, } case *sdk.ToolProgressPart: ch <- StreamEvent{ Type: EventToolCallProgress, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Progress: p.Content, } case *sdk.StreamToolResultPart: shouldAbort := false if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok { delete(toolLoopAbortCallIDs, p.ToolCallID) shouldAbort = true } ch <- StreamEvent{ Type: EventToolCallEnd, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Input: p.Input, Result: p.Output, } if shouldAbort { a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID)) aborted = true } case *sdk.StreamToolErrorPart: ch <- StreamEvent{ Type: EventToolCallEnd, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Error: p.Error.Error(), } case *sdk.StreamFilePart: mediaType := p.File.MediaType if mediaType == "" { mediaType = "image/png" } ch <- StreamEvent{ Type: EventAttachment, Attachments: []FileAttachment{{ Type: "image", URL: fmt.Sprintf("data:%s;base64,%s", mediaType, p.File.Data), Mime: mediaType, }}, } case *sdk.ErrorPart: ch <- StreamEvent{Type: EventError, Error: p.Error.Error()} aborted = true case *sdk.AbortPart: aborted = true case *sdk.FinishPart: // handled after loop } if aborted { break } } if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } finalMessages := streamResult.Messages if readMediaState != nil { finalMessages = readMediaState.mergeMessages(streamResult.Steps, finalMessages) } var totalUsage sdk.Usage for _, step := range streamResult.Steps { totalUsage.InputTokens += step.Usage.InputTokens totalUsage.OutputTokens += step.Usage.OutputTokens totalUsage.TotalTokens += step.Usage.TotalTokens totalUsage.ReasoningTokens += step.Usage.ReasoningTokens totalUsage.CachedInputTokens += step.Usage.CachedInputTokens totalUsage.InputTokenDetails.NoCacheTokens += step.Usage.InputTokenDetails.NoCacheTokens totalUsage.InputTokenDetails.CacheReadTokens += step.Usage.InputTokenDetails.CacheReadTokens totalUsage.InputTokenDetails.CacheWriteTokens += step.Usage.InputTokenDetails.CacheWriteTokens totalUsage.OutputTokenDetails.TextTokens += step.Usage.OutputTokenDetails.TextTokens totalUsage.OutputTokenDetails.ReasoningTokens += step.Usage.OutputTokenDetails.ReasoningTokens } usageJSON, _ := json.Marshal(totalUsage) termEvent := StreamEvent{ Messages: mustMarshal(finalMessages), Usage: usageJSON, } if aborted { termEvent.Type = EventAgentAbort } else { termEvent.Type = EventAgentEnd } ch <- termEvent } func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) { // Collecting emitter: tools push side-effect events here during generation. var collected []tools.ToolStreamEvent collectEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) { collected = append(collected, evt) }) var sdkTools []sdk.Tool if cfg.SupportsToolCall { var err error sdkTools, err = a.assembleTools(ctx, cfg, collectEmitter) if err != nil { return nil, fmt.Errorf("assemble tools: %w", err) } } sdkTools, readMediaState := decorateReadMediaTools(cfg.Model, sdkTools) var toolLoopGuard *ToolLoopGuard var textLoopGuard *TextLoopGuard toolLoopAbortCallIDs := make(map[string]struct{}) if cfg.LoopDetection.Enabled { toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort) textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{}) } if toolLoopGuard != nil { sdkTools = wrapToolsWithLoopGuard(sdkTools, toolLoopGuard, toolLoopAbortCallIDs) } var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams if readMediaState != nil { prepareStep = readMediaState.prepareStep } opts := a.buildGenerateOptions(cfg, sdkTools, prepareStep) opts = append(opts, sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams { if cfg.LoopDetection.Enabled { if len(toolLoopAbortCallIDs) > 0 { return nil // stop } if textLoopGuard != nil && isNonEmptyString(step.Text) { result := textLoopGuard.Inspect(step.Text) if result.Abort { return nil // stop } } } return nil }), ) genResult, err := a.client.GenerateTextResult(ctx, opts...) if err != nil { return nil, fmt.Errorf("generate: %w", err) } // Drain collected tool-emitted side effects into the result. var attachments []FileAttachment var reactions []ReactionItem var speeches []SpeechItem for _, evt := range collected { switch evt.Type { case tools.StreamEventAttachment: for _, a := range evt.Attachments { attachments = append(attachments, FileAttachment{ Type: a.Type, Path: a.Path, URL: a.URL, Mime: a.Mime, Name: a.Name, ContentHash: a.ContentHash, Size: a.Size, Metadata: a.Metadata, }) } case tools.StreamEventReaction: for _, r := range evt.Reactions { reactions = append(reactions, ReactionItem{Emoji: r.Emoji}) } case tools.StreamEventSpeech: for _, s := range evt.Speeches { speeches = append(speeches, SpeechItem{Text: s.Text}) } } } finalMessages := genResult.Messages if readMediaState != nil { finalMessages = readMediaState.mergeMessages(genResult.Steps, finalMessages) } return &GenerateResult{ Messages: finalMessages, Text: genResult.Text, Attachments: attachments, Reactions: reactions, Speeches: speeches, Usage: &genResult.Usage, }, nil } func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams) []sdk.GenerateOption { opts := []sdk.GenerateOption{ sdk.WithModel(cfg.Model), sdk.WithMessages(cfg.Messages), sdk.WithSystem(cfg.System), sdk.WithMaxSteps(-1), } if len(tools) > 0 && cfg.SupportsToolCall { opts = append(opts, sdk.WithTools(tools)) } if prepareStep != nil { opts = append(opts, sdk.WithPrepareStep(prepareStep)) } opts = append(opts, models.BuildReasoningOptions(models.SDKModelConfig{ ClientType: models.ResolveClientType(cfg.Model), ReasoningConfig: &models.ReasoningConfig{ Enabled: cfg.ReasoningEffort != "", Effort: cfg.ReasoningEffort, }, })...) return opts } // assembleTools collects tools from all registered ToolProviders. // emitter is injected into the session context so that tools targeting the // current conversation can push side-effect events (attachments, reactions, // speech) directly into the agent stream. func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig, emitter tools.StreamEmitter) ([]sdk.Tool, error) { if len(a.toolProviders) == 0 { return nil, nil } skillsMap := make(map[string]tools.SkillDetail, len(cfg.Skills)) for _, s := range cfg.Skills { skillsMap[s.Name] = tools.SkillDetail{ Description: s.Description, Content: s.Content, } } session := tools.SessionContext{ BotID: cfg.Identity.BotID, ChatID: cfg.Identity.ChatID, SessionID: cfg.Identity.SessionID, SessionType: cfg.SessionType, ChannelIdentityID: cfg.Identity.ChannelIdentityID, SessionToken: cfg.Identity.SessionToken, CurrentPlatform: cfg.Identity.CurrentPlatform, ReplyTarget: cfg.Identity.ReplyTarget, SupportsImageInput: cfg.SupportsImageInput, IsSubagent: cfg.Identity.IsSubagent, Skills: skillsMap, TimezoneLocation: cfg.Identity.TimezoneLocation, Emitter: emitter, } var allTools []sdk.Tool for _, provider := range a.toolProviders { providerTools, err := provider.Tools(ctx, session) if err != nil { a.logger.Warn("tool provider failed", slog.Any("error", err)) continue } allTools = append(allTools, providerTools...) } return allTools, nil } // toolStreamEventToAgentEvent converts a tool-layer ToolStreamEvent into an // agent-layer StreamEvent suitable for the output channel. func toolStreamEventToAgentEvent(evt tools.ToolStreamEvent) StreamEvent { switch evt.Type { case tools.StreamEventAttachment: atts := make([]FileAttachment, 0, len(evt.Attachments)) for _, a := range evt.Attachments { atts = append(atts, FileAttachment{ Type: a.Type, Path: a.Path, URL: a.URL, Mime: a.Mime, Name: a.Name, ContentHash: a.ContentHash, Size: a.Size, Metadata: a.Metadata, }) } return StreamEvent{Type: EventAttachment, Attachments: atts} case tools.StreamEventReaction: rs := make([]ReactionItem, 0, len(evt.Reactions)) for _, r := range evt.Reactions { rs = append(rs, ReactionItem{Emoji: r.Emoji}) } return StreamEvent{Type: EventReaction, Reactions: rs} case tools.StreamEventSpeech: ss := make([]SpeechItem, 0, len(evt.Speeches)) for _, s := range evt.Speeches { ss = append(ss, SpeechItem{Text: s.Text}) } return StreamEvent{Type: EventSpeech, Speeches: ss} default: return StreamEvent{} } } func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool { wrapped := make([]sdk.Tool, len(tools)) for i, tool := range tools { originalExecute := tool.Execute toolName := tool.Name wrapped[i] = tool wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) { warn, abort := guard.Guard(toolName, input) if abort { abortCallIDs[ctx.ToolCallID] = struct{}{} return map[string]any{ "isError": true, "content": []map[string]any{{ "type": "text", "text": ToolLoopDetectedAbortMessage, }}, }, errors.New(ToolLoopDetectedAbortMessage) } if warn { return map[string]any{ ToolLoopWarningKey: true, "content": []map[string]any{{ "type": "text", "text": ToolLoopWarningText, }}, }, nil } return originalExecute(ctx, input) } } return wrapped }