mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
547 lines
15 KiB
Go
547 lines
15 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
|
|
"github.com/memohai/memoh/internal/agent/tools"
|
|
"github.com/memohai/memoh/internal/workspace/bridge"
|
|
)
|
|
|
|
// Agent is the core agent that handles LLM interactions.
|
|
type Agent struct {
|
|
client *sdk.Client
|
|
toolProviders []tools.ToolProvider
|
|
bridgeProvider bridge.Provider
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// New creates a new Agent with the given dependencies.
|
|
func New(deps Deps) *Agent {
|
|
logger := deps.Logger
|
|
if logger == nil {
|
|
logger = slog.Default()
|
|
}
|
|
return &Agent{
|
|
client: sdk.NewClient(),
|
|
bridgeProvider: deps.BridgeProvider,
|
|
logger: logger.With(slog.String("service", "agent")),
|
|
}
|
|
}
|
|
|
|
// SetToolProviders sets the tool providers after construction.
|
|
// This allows breaking dependency cycles in the DI graph.
|
|
func (a *Agent) SetToolProviders(providers []tools.ToolProvider) {
|
|
a.toolProviders = providers
|
|
}
|
|
|
|
// Stream runs the agent in streaming mode, emitting events to the returned channel.
|
|
func (a *Agent) Stream(ctx context.Context, cfg RunConfig) <-chan StreamEvent {
|
|
ch := make(chan StreamEvent)
|
|
go func() {
|
|
defer close(ch)
|
|
a.runStream(ctx, cfg, ch)
|
|
}()
|
|
return ch
|
|
}
|
|
|
|
// Generate runs the agent in non-streaming mode, returning the complete result.
|
|
func (a *Agent) Generate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
|
return a.runGenerate(ctx, cfg)
|
|
}
|
|
|
|
func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) {
|
|
tools, err := a.assembleTools(ctx, cfg)
|
|
if err != nil {
|
|
ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}
|
|
return
|
|
}
|
|
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
|
|
|
|
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
|
copy(enabledSkills, cfg.EnabledSkillNames)
|
|
enableSkill := func(name string) {
|
|
for _, s := range cfg.Skills {
|
|
if s.Name == name {
|
|
for _, existing := range enabledSkills {
|
|
if existing == name {
|
|
return
|
|
}
|
|
}
|
|
enabledSkills = append(enabledSkills, name)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Loop detection setup
|
|
var textLoopGuard *TextLoopGuard
|
|
var textLoopProbeBuffer *TextLoopProbeBuffer
|
|
var toolLoopGuard *ToolLoopGuard
|
|
toolLoopAbortCallIDs := make(map[string]struct{})
|
|
if cfg.LoopDetection.Enabled {
|
|
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
|
textLoopProbeBuffer = NewTextLoopProbeBuffer(LoopDetectedProbeChars, func(text string) {
|
|
result := textLoopGuard.Inspect(text)
|
|
if result.Abort {
|
|
a.logger.Warn("text loop detected, will abort")
|
|
}
|
|
})
|
|
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
|
}
|
|
|
|
// Wrap tools with loop detection
|
|
if toolLoopGuard != nil {
|
|
tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs)
|
|
}
|
|
|
|
tagResolvers := DefaultTagResolvers()
|
|
tagExtractor := NewStreamTagExtractor(tagResolvers)
|
|
|
|
var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
|
|
if readMediaState != nil {
|
|
prepareStep = readMediaState.prepareStep
|
|
}
|
|
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
|
|
|
|
streamResult, err := a.client.StreamText(ctx, opts...)
|
|
if err != nil {
|
|
ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: %v", err)}
|
|
return
|
|
}
|
|
|
|
ch <- StreamEvent{Type: EventAgentStart}
|
|
|
|
var allText strings.Builder
|
|
aborted := false
|
|
|
|
for part := range streamResult.Stream {
|
|
if ctx.Err() != nil {
|
|
aborted = true
|
|
break
|
|
}
|
|
|
|
switch p := part.(type) {
|
|
case *sdk.StartPart:
|
|
_ = p // stream start already emitted
|
|
|
|
case *sdk.TextStartPart:
|
|
ch <- StreamEvent{Type: EventTextStart}
|
|
|
|
case *sdk.TextDeltaPart:
|
|
result := tagExtractor.Push(p.Text)
|
|
if result.VisibleText != "" {
|
|
if textLoopProbeBuffer != nil {
|
|
textLoopProbeBuffer.Push(result.VisibleText)
|
|
}
|
|
ch <- StreamEvent{Type: EventTextDelta, Delta: result.VisibleText}
|
|
allText.WriteString(result.VisibleText)
|
|
}
|
|
emitTagEvents(ch, result.Events)
|
|
|
|
case *sdk.TextEndPart:
|
|
remainder := tagExtractor.FlushRemainder()
|
|
if remainder.VisibleText != "" {
|
|
if textLoopProbeBuffer != nil {
|
|
textLoopProbeBuffer.Push(remainder.VisibleText)
|
|
}
|
|
ch <- StreamEvent{Type: EventTextDelta, Delta: remainder.VisibleText}
|
|
allText.WriteString(remainder.VisibleText)
|
|
}
|
|
if textLoopProbeBuffer != nil {
|
|
textLoopProbeBuffer.Flush()
|
|
}
|
|
emitTagEvents(ch, remainder.Events)
|
|
ch <- StreamEvent{Type: EventTextEnd}
|
|
|
|
case *sdk.ReasoningStartPart:
|
|
ch <- StreamEvent{Type: EventReasoningStart}
|
|
|
|
case *sdk.ReasoningDeltaPart:
|
|
ch <- StreamEvent{Type: EventReasoningDelta, Delta: p.Text}
|
|
|
|
case *sdk.ReasoningEndPart:
|
|
ch <- StreamEvent{Type: EventReasoningEnd}
|
|
|
|
case *sdk.StreamToolCallPart:
|
|
remainder := tagExtractor.FlushRemainder()
|
|
if remainder.VisibleText != "" {
|
|
if textLoopProbeBuffer != nil {
|
|
textLoopProbeBuffer.Push(remainder.VisibleText)
|
|
}
|
|
ch <- StreamEvent{Type: EventTextDelta, Delta: remainder.VisibleText}
|
|
allText.WriteString(remainder.VisibleText)
|
|
}
|
|
if textLoopProbeBuffer != nil {
|
|
textLoopProbeBuffer.Flush()
|
|
}
|
|
emitTagEvents(ch, remainder.Events)
|
|
ch <- StreamEvent{
|
|
Type: EventToolCallStart,
|
|
ToolName: p.ToolName,
|
|
ToolCallID: p.ToolCallID,
|
|
Input: p.Input,
|
|
}
|
|
|
|
case *sdk.StreamToolResultPart:
|
|
shouldAbort := false
|
|
if _, ok := toolLoopAbortCallIDs[p.ToolCallID]; ok {
|
|
delete(toolLoopAbortCallIDs, p.ToolCallID)
|
|
shouldAbort = true
|
|
}
|
|
ch <- StreamEvent{
|
|
Type: EventToolCallEnd,
|
|
ToolName: p.ToolName,
|
|
ToolCallID: p.ToolCallID,
|
|
Input: p.Input,
|
|
Result: p.Output,
|
|
}
|
|
if p.ToolName == "use_skill" {
|
|
if resultMap, ok := p.Output.(map[string]any); ok {
|
|
if skillName, ok := resultMap["skillName"].(string); ok && skillName != "" {
|
|
enableSkill(skillName)
|
|
}
|
|
}
|
|
}
|
|
if shouldAbort {
|
|
a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID))
|
|
aborted = true
|
|
}
|
|
|
|
case *sdk.StreamToolErrorPart:
|
|
ch <- StreamEvent{
|
|
Type: EventToolCallEnd,
|
|
ToolName: p.ToolName,
|
|
ToolCallID: p.ToolCallID,
|
|
Error: p.Error.Error(),
|
|
}
|
|
|
|
case *sdk.StreamFilePart:
|
|
mediaType := p.File.MediaType
|
|
if mediaType == "" {
|
|
mediaType = "image/png"
|
|
}
|
|
ch <- StreamEvent{
|
|
Type: EventAttachment,
|
|
Attachments: []FileAttachment{{
|
|
Type: "image",
|
|
URL: fmt.Sprintf("data:%s;base64,%s", mediaType, p.File.Data),
|
|
Mime: mediaType,
|
|
}},
|
|
}
|
|
|
|
case *sdk.ErrorPart:
|
|
ch <- StreamEvent{Type: EventError, Error: p.Error.Error()}
|
|
aborted = true
|
|
|
|
case *sdk.AbortPart:
|
|
aborted = true
|
|
|
|
case *sdk.FinishPart:
|
|
// handled after loop
|
|
}
|
|
|
|
if aborted {
|
|
break
|
|
}
|
|
}
|
|
|
|
if textLoopProbeBuffer != nil {
|
|
textLoopProbeBuffer.Flush()
|
|
}
|
|
|
|
finalMessages := streamResult.Messages
|
|
if readMediaState != nil {
|
|
finalMessages = readMediaState.mergeMessages(streamResult.Steps, finalMessages)
|
|
}
|
|
finalMessages = StripTagsFromMessages(finalMessages)
|
|
|
|
var totalUsage sdk.Usage
|
|
perStepUsages := make([]json.RawMessage, 0, len(streamResult.Steps))
|
|
for _, step := range streamResult.Steps {
|
|
totalUsage.InputTokens += step.Usage.InputTokens
|
|
totalUsage.OutputTokens += step.Usage.OutputTokens
|
|
totalUsage.TotalTokens += step.Usage.TotalTokens
|
|
totalUsage.ReasoningTokens += step.Usage.ReasoningTokens
|
|
totalUsage.CachedInputTokens += step.Usage.CachedInputTokens
|
|
stepJSON, _ := json.Marshal(step.Usage)
|
|
perStepUsages = append(perStepUsages, stepJSON)
|
|
}
|
|
usageJSON, _ := json.Marshal(totalUsage)
|
|
usagesJSON, _ := json.Marshal(perStepUsages)
|
|
|
|
termEvent := StreamEvent{
|
|
Messages: mustMarshal(finalMessages),
|
|
Usage: usageJSON,
|
|
Usages: usagesJSON,
|
|
Skills: enabledSkills,
|
|
}
|
|
if aborted {
|
|
termEvent.Type = EventAgentAbort
|
|
} else {
|
|
termEvent.Type = EventAgentEnd
|
|
}
|
|
ch <- termEvent
|
|
}
|
|
|
|
func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) {
|
|
tools, err := a.assembleTools(ctx, cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("assemble tools: %w", err)
|
|
}
|
|
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
|
|
|
|
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
|
copy(enabledSkills, cfg.EnabledSkillNames)
|
|
enableSkill := func(name string) {
|
|
for _, s := range cfg.Skills {
|
|
if s.Name == name {
|
|
for _, existing := range enabledSkills {
|
|
if existing == name {
|
|
return
|
|
}
|
|
}
|
|
enabledSkills = append(enabledSkills, name)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
var toolLoopGuard *ToolLoopGuard
|
|
var textLoopGuard *TextLoopGuard
|
|
toolLoopAbortCallIDs := make(map[string]struct{})
|
|
if cfg.LoopDetection.Enabled {
|
|
toolLoopGuard = NewToolLoopGuard(ToolLoopRepeatThreshold, ToolLoopWarningsBeforeAbort)
|
|
textLoopGuard = NewTextLoopGuard(LoopDetectedStreakThreshold, LoopDetectedMinNewGramsPerChunk, SentialOptions{})
|
|
}
|
|
|
|
if toolLoopGuard != nil {
|
|
tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs)
|
|
}
|
|
|
|
var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
|
|
if readMediaState != nil {
|
|
prepareStep = readMediaState.prepareStep
|
|
}
|
|
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
|
|
opts = append(opts,
|
|
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
|
if cfg.LoopDetection.Enabled {
|
|
if len(toolLoopAbortCallIDs) > 0 {
|
|
return nil // stop
|
|
}
|
|
if textLoopGuard != nil && isNonEmptyString(step.Text) {
|
|
result := textLoopGuard.Inspect(step.Text)
|
|
if result.Abort {
|
|
return nil // stop
|
|
}
|
|
}
|
|
}
|
|
for _, tr := range step.ToolResults {
|
|
if tr.ToolName == "use_skill" {
|
|
if resultMap, ok := tr.Output.(map[string]any); ok {
|
|
if skillName, ok := resultMap["skillName"].(string); ok && skillName != "" {
|
|
enableSkill(skillName)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
genResult, err := a.client.GenerateTextResult(ctx, opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate: %w", err)
|
|
}
|
|
|
|
resolvers := DefaultTagResolvers()
|
|
cleanedText, events := ExtractTagsFromText(genResult.Text, resolvers)
|
|
|
|
var attachments []FileAttachment
|
|
var reactions []ReactionItem
|
|
var speeches []SpeechItem
|
|
for _, ev := range events {
|
|
switch ev.Tag {
|
|
case "attachments":
|
|
for _, d := range ev.Data {
|
|
if att, ok := d.(FileAttachment); ok {
|
|
attachments = append(attachments, att)
|
|
}
|
|
}
|
|
case "reactions":
|
|
for _, d := range ev.Data {
|
|
if r, ok := d.(ReactionItem); ok {
|
|
reactions = append(reactions, r)
|
|
}
|
|
}
|
|
case "speech":
|
|
for _, d := range ev.Data {
|
|
if s, ok := d.(SpeechItem); ok {
|
|
speeches = append(speeches, s)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
finalMessages := genResult.Messages
|
|
if readMediaState != nil {
|
|
finalMessages = readMediaState.mergeMessages(genResult.Steps, finalMessages)
|
|
}
|
|
finalMessages = StripTagsFromMessages(finalMessages)
|
|
|
|
return &GenerateResult{
|
|
Messages: finalMessages,
|
|
Text: cleanedText,
|
|
Attachments: attachments,
|
|
Reactions: reactions,
|
|
Speeches: speeches,
|
|
Skills: enabledSkills,
|
|
Usage: &genResult.Usage,
|
|
}, nil
|
|
}
|
|
|
|
func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams) []sdk.GenerateOption {
|
|
opts := []sdk.GenerateOption{
|
|
sdk.WithModel(cfg.Model),
|
|
sdk.WithMessages(cfg.Messages),
|
|
sdk.WithSystem(cfg.System),
|
|
sdk.WithMaxSteps(-1),
|
|
}
|
|
if len(tools) > 0 {
|
|
opts = append(opts, sdk.WithTools(tools))
|
|
}
|
|
if prepareStep != nil {
|
|
opts = append(opts, sdk.WithPrepareStep(prepareStep))
|
|
}
|
|
opts = append(opts, BuildReasoningOptions(ModelConfig{
|
|
ClientType: resolveClientType(cfg.Model),
|
|
ReasoningConfig: &ReasoningConfig{
|
|
Enabled: cfg.ReasoningEffort != "",
|
|
Effort: cfg.ReasoningEffort,
|
|
},
|
|
})...)
|
|
return opts
|
|
}
|
|
|
|
func resolveClientType(model *sdk.Model) string {
|
|
if model == nil || model.Provider == nil {
|
|
return ClientTypeOpenAICompletions
|
|
}
|
|
name := model.Provider.Name()
|
|
switch {
|
|
case strings.Contains(name, "anthropic"):
|
|
return ClientTypeAnthropicMessages
|
|
case strings.Contains(name, "google"):
|
|
return ClientTypeGoogleGenerativeAI
|
|
case strings.Contains(name, "responses"):
|
|
return ClientTypeOpenAIResponses
|
|
default:
|
|
return ClientTypeOpenAICompletions
|
|
}
|
|
}
|
|
|
|
// assembleTools collects tools from all registered ToolProviders.
|
|
func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, error) {
|
|
if len(a.toolProviders) == 0 {
|
|
return nil, nil
|
|
}
|
|
session := tools.SessionContext{
|
|
BotID: cfg.Identity.BotID,
|
|
ChatID: cfg.Identity.ChatID,
|
|
ChannelIdentityID: cfg.Identity.ChannelIdentityID,
|
|
SessionToken: cfg.Identity.SessionToken,
|
|
CurrentPlatform: cfg.Identity.CurrentPlatform,
|
|
ReplyTarget: cfg.Identity.ReplyTarget,
|
|
SupportsImageInput: cfg.SupportsImageInput,
|
|
IsSubagent: cfg.Identity.IsSubagent,
|
|
}
|
|
|
|
var allTools []sdk.Tool
|
|
for _, provider := range a.toolProviders {
|
|
providerTools, err := provider.Tools(ctx, session)
|
|
if err != nil {
|
|
a.logger.Warn("tool provider failed", slog.Any("error", err))
|
|
continue
|
|
}
|
|
allTools = append(allTools, providerTools...)
|
|
}
|
|
return allTools, nil
|
|
}
|
|
|
|
func emitTagEvents(ch chan<- StreamEvent, events []TagEvent) {
|
|
for _, ev := range events {
|
|
switch ev.Tag {
|
|
case "attachments":
|
|
var atts []FileAttachment
|
|
for _, d := range ev.Data {
|
|
if att, ok := d.(FileAttachment); ok {
|
|
atts = append(atts, att)
|
|
}
|
|
}
|
|
if len(atts) > 0 {
|
|
ch <- StreamEvent{Type: EventAttachment, Attachments: atts}
|
|
}
|
|
case "reactions":
|
|
var reactions []ReactionItem
|
|
for _, d := range ev.Data {
|
|
if r, ok := d.(ReactionItem); ok {
|
|
reactions = append(reactions, r)
|
|
}
|
|
}
|
|
if len(reactions) > 0 {
|
|
ch <- StreamEvent{Type: EventReaction, Reactions: reactions}
|
|
}
|
|
case "speech":
|
|
var speeches []SpeechItem
|
|
for _, d := range ev.Data {
|
|
if s, ok := d.(SpeechItem); ok {
|
|
speeches = append(speeches, s)
|
|
}
|
|
}
|
|
if len(speeches) > 0 {
|
|
ch <- StreamEvent{Type: EventSpeech, Speeches: speeches}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool {
|
|
wrapped := make([]sdk.Tool, len(tools))
|
|
for i, tool := range tools {
|
|
originalExecute := tool.Execute
|
|
toolName := tool.Name
|
|
wrapped[i] = tool
|
|
wrapped[i].Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
|
warn, abort := guard.Guard(toolName, input)
|
|
if abort {
|
|
abortCallIDs[ctx.ToolCallID] = struct{}{}
|
|
return map[string]any{
|
|
"isError": true,
|
|
"content": []map[string]any{{
|
|
"type": "text",
|
|
"text": ToolLoopDetectedAbortMessage,
|
|
}},
|
|
}, errors.New(ToolLoopDetectedAbortMessage)
|
|
}
|
|
if warn {
|
|
return map[string]any{
|
|
ToolLoopWarningKey: true,
|
|
"content": []map[string]any{{
|
|
"type": "text",
|
|
"text": ToolLoopWarningText,
|
|
}},
|
|
}, nil
|
|
}
|
|
return originalExecute(ctx, input)
|
|
}
|
|
}
|
|
return wrapped
|
|
}
|