diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 20d9b577..9cf84a9f 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -572,6 +572,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewWebProvider(log, settingsService, searchProviderService), agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), + agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount), agenttools.NewInboxProvider(log, inboxService), agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewWebFetchProvider(log), diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 9e65186d..5a238a1b 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -434,6 +434,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewWebProvider(log, settingsService, searchProviderService), agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), + agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount), agenttools.NewInboxProvider(log, inboxService), agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewWebFetchProvider(log), diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 0acdabb6..d8fad39e 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -62,6 +62,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)} return } + tools, readMediaState := decorateReadMediaTools(cfg.Model, tools) enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames)) copy(enabledSkills, cfg.EnabledSkillNames) @@ -103,7 +104,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv tagResolvers := DefaultTagResolvers() tagExtractor := NewStreamTagExtractor(tagResolvers) - opts := a.buildGenerateOptions(cfg, tools) + var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams + if readMediaState != nil { + prepareStep = readMediaState.prepareStep + } + opts := a.buildGenerateOptions(cfg, tools, prepareStep) streamResult, err := a.client.StreamText(ctx, opts...) if err != nil { @@ -251,7 +256,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv textLoopProbeBuffer.Flush() } - finalMessages := StripTagsFromMessages(streamResult.Messages) + finalMessages := streamResult.Messages + if readMediaState != nil { + finalMessages = readMediaState.mergeMessages(streamResult.Steps, finalMessages) + } + finalMessages = StripTagsFromMessages(finalMessages) var totalUsage sdk.Usage perStepUsages := make([]json.RawMessage, 0, len(streamResult.Steps)) @@ -286,6 +295,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult if err != nil { return nil, fmt.Errorf("assemble tools: %w", err) } + tools, readMediaState := decorateReadMediaTools(cfg.Model, tools) enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames)) copy(enabledSkills, cfg.EnabledSkillNames) @@ -315,7 +325,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs) } - opts := a.buildGenerateOptions(cfg, tools) + var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams + if readMediaState != nil { + prepareStep = readMediaState.prepareStep + } + opts := a.buildGenerateOptions(cfg, tools, prepareStep) opts = append(opts, sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams { if cfg.LoopDetection.Enabled { @@ -376,7 +390,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult } } - finalMessages := StripTagsFromMessages(genResult.Messages) + finalMessages := genResult.Messages + if readMediaState != nil { + finalMessages = readMediaState.mergeMessages(genResult.Steps, finalMessages) + } + finalMessages = StripTagsFromMessages(finalMessages) return &GenerateResult{ Messages: finalMessages, @@ -389,7 +407,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult }, nil } -func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.GenerateOption { +func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams) []sdk.GenerateOption { opts := []sdk.GenerateOption{ sdk.WithModel(cfg.Model), sdk.WithMessages(cfg.Messages), @@ -399,6 +417,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.Genera if len(tools) > 0 { opts = append(opts, sdk.WithTools(tools)) } + if prepareStep != nil { + opts = append(opts, sdk.WithPrepareStep(prepareStep)) + } opts = append(opts, BuildReasoningOptions(ModelConfig{ ClientType: resolveClientType(cfg.Model), ReasoningConfig: &ReasoningConfig{ @@ -432,13 +453,14 @@ func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, e return nil, nil } session := tools.SessionContext{ - BotID: cfg.Identity.BotID, - ChatID: cfg.Identity.ChatID, - ChannelIdentityID: cfg.Identity.ChannelIdentityID, - SessionToken: cfg.Identity.SessionToken, - CurrentPlatform: cfg.Identity.CurrentPlatform, - ReplyTarget: cfg.Identity.ReplyTarget, - IsSubagent: cfg.Identity.IsSubagent, + BotID: cfg.Identity.BotID, + ChatID: cfg.Identity.ChatID, + ChannelIdentityID: cfg.Identity.ChannelIdentityID, + SessionToken: cfg.Identity.SessionToken, + CurrentPlatform: cfg.Identity.CurrentPlatform, + ReplyTarget: cfg.Identity.ReplyTarget, + SupportsImageInput: cfg.SupportsImageInput, + IsSubagent: cfg.Identity.IsSubagent, } var allTools []sdk.Tool diff --git a/internal/agent/read_media.go b/internal/agent/read_media.go new file mode 100644 index 00000000..4a5269a6 --- /dev/null +++ b/internal/agent/read_media.go @@ -0,0 +1,174 @@ +package agent + +import ( + "fmt" + "strings" + + sdk "github.com/memohai/twilight-ai/sdk" + + agenttools "github.com/memohai/memoh/internal/agent/tools" +) + +func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) { + if len(tools) == 0 { + return tools, nil + } + + clientType := resolveClientType(model) + state := &readMediaDecorationState{ + pendingImages: make(map[string]sdk.ImagePart), + } + wrapped := make([]sdk.Tool, 0, len(tools)) + found := false + + for _, tool := range tools { + if tool.Name != agenttools.ReadMediaToolName || tool.Execute == nil { + wrapped = append(wrapped, tool) + continue + } + + found = true + originalExecute := tool.Execute + toolCopy := tool + toolCopy.Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) { + output, err := originalExecute(ctx, input) + if err != nil { + return output, err + } + + publicResult, image, ok := normalizeReadMediaOutput(output, clientType) + if !ok { + return output, nil + } + if ctx != nil && strings.TrimSpace(ctx.ToolCallID) != "" && strings.TrimSpace(image.Image) != "" { + if _, exists := state.pendingImages[ctx.ToolCallID]; !exists { + state.pendingOrder = append(state.pendingOrder, ctx.ToolCallID) + } + state.pendingImages[ctx.ToolCallID] = image + } + return publicResult, nil + } + wrapped = append(wrapped, toolCopy) + } + + if !found { + return tools, nil + } + + return wrapped, state +} + +type readMediaDecorationState struct { + pendingOrder []string + pendingImages map[string]sdk.ImagePart + prepareCalls int + injections []readMediaInjection +} + +type readMediaInjection struct { + afterStep int + message sdk.Message +} + +func (s *readMediaDecorationState) prepareStep(params *sdk.GenerateParams) *sdk.GenerateParams { + if s == nil || params == nil { + return nil + } + + afterStep := s.prepareCalls + s.prepareCalls++ + + if len(s.pendingOrder) == 0 { + return nil + } + + parts := make([]sdk.MessagePart, 0, len(s.pendingOrder)) + for _, toolCallID := range s.pendingOrder { + image, ok := s.pendingImages[toolCallID] + delete(s.pendingImages, toolCallID) + if !ok || strings.TrimSpace(image.Image) == "" { + continue + } + parts = append(parts, image) + } + s.pendingOrder = s.pendingOrder[:0] + + if len(parts) == 0 { + return nil + } + + message := sdk.Message{ + Role: sdk.MessageRoleUser, + Content: parts, + } + s.injections = append(s.injections, readMediaInjection{ + afterStep: afterStep, + message: message, + }) + + next := *params + next.Messages = append(append([]sdk.Message(nil), params.Messages...), message) + return &next +} + +func (s *readMediaDecorationState) mergeMessages(steps []sdk.StepResult, fallback []sdk.Message) []sdk.Message { + if s == nil || len(s.injections) == 0 { + return fallback + } + if len(steps) == 0 { + merged := append([]sdk.Message(nil), fallback...) + for _, injection := range s.injections { + merged = append(merged, injection.message) + } + return merged + } + + merged := make([]sdk.Message, 0, len(fallback)+len(s.injections)) + injectionIndex := 0 + for stepIndex, step := range steps { + merged = append(merged, step.Messages...) + for injectionIndex < len(s.injections) && s.injections[injectionIndex].afterStep == stepIndex { + merged = append(merged, s.injections[injectionIndex].message) + injectionIndex++ + } + } + for injectionIndex < len(s.injections) { + merged = append(merged, s.injections[injectionIndex].message) + injectionIndex++ + } + return merged +} + +func normalizeReadMediaOutput(output any, clientType string) (any, sdk.ImagePart, bool) { + switch value := output.(type) { + case agenttools.ReadMediaToolOutput: + return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true + case *agenttools.ReadMediaToolOutput: + if value == nil { + return nil, sdk.ImagePart{}, false + } + return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true + default: + return nil, sdk.ImagePart{}, false + } +} + +func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.ImagePart { + imageBase64 = strings.TrimSpace(imageBase64) + mediaType = strings.TrimSpace(mediaType) + if imageBase64 == "" { + return sdk.ImagePart{} + } + if mediaType == "" { + mediaType = "image/png" + } + + image := imageBase64 + if clientType != ClientTypeAnthropicMessages { + image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64) + } + return sdk.ImagePart{ + Image: image, + MediaType: mediaType, + } +} diff --git a/internal/agent/read_media_test.go b/internal/agent/read_media_test.go new file mode 100644 index 00000000..7c5a7b31 --- /dev/null +++ b/internal/agent/read_media_test.go @@ -0,0 +1,394 @@ +package agent + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "net" + "strings" + "testing" + + sdk "github.com/memohai/twilight-ai/sdk" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" + + agenttools "github.com/memohai/memoh/internal/agent/tools" + "github.com/memohai/memoh/internal/workspace/bridge" + pb "github.com/memohai/memoh/internal/workspace/bridgepb" +) + +const agentReadMediaTestBufSize = 1 << 20 + +type agentReadMediaContainerService struct { + pb.UnimplementedContainerServiceServer + files map[string][]byte +} + +func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { + data, ok := s.files[req.GetPath()] + if !ok { + return status.Error(codes.NotFound, "not found") + } + if len(data) == 0 { + return nil + } + return stream.Send(&pb.DataChunk{Data: data}) +} + +type agentReadMediaBridgeProvider struct { + client *bridge.Client +} + +func (p *agentReadMediaBridgeProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) { + return p.client, nil +} + +func newAgentReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider { + t.Helper() + + lis := bufconn.Listen(agentReadMediaTestBufSize) + srv := grpc.NewServer() + pb.RegisterContainerServiceServer(srv, &agentReadMediaContainerService{files: files}) + + done := make(chan struct{}) + go func() { + defer close(done) + _ = srv.Serve(lis) + }() + t.Cleanup(func() { + srv.Stop() + <-done + }) + + dialer := func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + } + conn, err := grpc.NewClient( + "passthrough://bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("grpc.NewClient: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + return &agentReadMediaBridgeProvider{client: bridge.NewClientFromConn(conn)} +} + +type agentReadMediaMockProvider struct { + name string + calls int + handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) +} + +func (m *agentReadMediaMockProvider) Name() string { + if m.name != "" { + return m.name + } + return "mock" +} + +func (*agentReadMediaMockProvider) ListModels(context.Context) ([]sdk.Model, error) { + return nil, nil +} + +func (*agentReadMediaMockProvider) Test(context.Context) *sdk.ProviderTestResult { + return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"} +} + +func (*agentReadMediaMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) { + return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil +} + +func (m *agentReadMediaMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + m.calls++ + return m.handler(m.calls, params) +} + +func (m *agentReadMediaMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) { + result, err := m.DoGenerate(ctx, params) + if err != nil { + return nil, err + } + ch := make(chan sdk.StreamPart, 8) + go func() { + defer close(ch) + ch <- &sdk.StartPart{} + ch <- &sdk.StartStepPart{} + if result.Text != "" { + ch <- &sdk.TextStartPart{ID: "mock"} + ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text} + ch <- &sdk.TextEndPart{ID: "mock"} + } + for _, tc := range result.ToolCalls { + ch <- &sdk.StreamToolCallPart{ + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + Input: tc.Input, + } + } + ch <- &sdk.FinishStepPart{FinishReason: result.FinishReason, Usage: result.Usage, Response: result.Response} + ch <- &sdk.FinishPart{FinishReason: result.FinishReason, TotalUsage: result.Usage} + }() + return &sdk.StreamResult{Stream: ch}, nil +} + +func assertInjectedReadMediaMessage(t *testing.T, msg sdk.Message, expectedImage, expectedMediaType string) { + t.Helper() + + if msg.Role != sdk.MessageRoleUser { + t.Fatalf("expected injected read_media message role %q, got %q", sdk.MessageRoleUser, msg.Role) + } + if len(msg.Content) != 1 { + t.Fatalf("expected one injected content part, got %d", len(msg.Content)) + } + image, ok := msg.Content[0].(sdk.ImagePart) + if !ok { + t.Fatalf("expected sdk.ImagePart, got %T", msg.Content[0]) + } + if image.Image != expectedImage { + t.Fatalf("unexpected injected image payload: %q", image.Image) + } + if image.MediaType != expectedMediaType { + t.Fatalf("unexpected injected media type: %q", image.MediaType) + } +} + +func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) { + t.Parallel() + + pngBytes := []byte("\x89PNG\r\n\x1a\npayload") + expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes) + + modelProvider := &agentReadMediaMockProvider{ + handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + if call == 1 { + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-1", + ToolName: "read_media", + Input: map[string]any{"path": "/data/images/demo.png"}, + }}, + }, nil + } + + if len(params.Messages) < 4 { + t.Fatalf("expected prior tool and injected messages, got %d", len(params.Messages)) + } + + last := params.Messages[len(params.Messages)-1] + if last.Role != sdk.MessageRoleUser { + t.Fatalf("expected last message to be injected user image, got %s", last.Role) + } + if len(last.Content) != 1 { + t.Fatalf("expected one injected content part, got %d", len(last.Content)) + } + image, ok := last.Content[0].(sdk.ImagePart) + if !ok { + t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0]) + } + if image.Image != expectedDataURL { + t.Fatalf("unexpected injected image payload: %q", image.Image) + } + if image.MediaType != "image/png" { + t.Fatalf("unexpected injected media type: %q", image.MediaType) + } + + var toolResult sdk.ToolResultPart + foundToolMessage := false + for _, msg := range params.Messages { + if msg.Role != sdk.MessageRoleTool || len(msg.Content) == 0 { + continue + } + part, ok := msg.Content[0].(sdk.ToolResultPart) + if !ok { + continue + } + toolResult = part + foundToolMessage = true + break + } + if !foundToolMessage { + t.Fatal("expected tool result message before second step") + } + raw, err := json.Marshal(toolResult.Result) + if err != nil { + t.Fatalf("marshal tool result: %v", err) + } + if !bytes.Contains(raw, []byte(`"ok":true`)) { + t.Fatalf("expected compact success metadata, got %s", raw) + } + if bytes.Contains(raw, []byte(expectedDataURL)) || bytes.Contains(raw, []byte("payload")) { + t.Fatalf("tool result leaked image bytes: %s", raw) + } + + return &sdk.GenerateResult{ + Text: "done", + FinishReason: sdk.FinishReasonStop, + }, nil + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{ + "/data/images/demo.png": pngBytes, + }), "/data"), + }) + + result, err := a.Generate(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock-model", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("look at the image")}, + SupportsImageInput: true, + Identity: SessionContext{ + BotID: "bot-1", + }, + }) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } + if result.Text != "done" { + t.Fatalf("unexpected result text: %q", result.Text) + } + if len(result.Messages) != 4 { + t.Fatalf("expected persisted step + injected history, got %d messages", len(result.Messages)) + } + assertInjectedReadMediaMessage(t, result.Messages[2], expectedDataURL, "image/png") + if result.Messages[3].Role != sdk.MessageRoleAssistant { + t.Fatalf("expected final persisted message to be assistant, got %s", result.Messages[3].Role) + } + if modelProvider.calls != 2 { + t.Fatalf("expected 2 model calls, got %d", modelProvider.calls) + } +} + +func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.T) { + t.Parallel() + + pngBytes := []byte("\x89PNG\r\n\x1a\npayload") + expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes) + + modelProvider := &agentReadMediaMockProvider{ + name: "anthropic-messages", + handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + if call == 1 { + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-1", + ToolName: "read_media", + Input: map[string]any{"path": "/data/images/demo.png"}, + }}, + }, nil + } + + last := params.Messages[len(params.Messages)-1] + image, ok := last.Content[0].(sdk.ImagePart) + if !ok { + t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0]) + } + if image.Image != expectedBase64 { + t.Fatalf("expected raw base64 for anthropic, got %q", image.Image) + } + if image.MediaType != "image/png" { + t.Fatalf("unexpected injected media type: %q", image.MediaType) + } + if strings.HasPrefix(image.Image, "data:") { + t.Fatalf("anthropic image payload must not be a data URL: %q", image.Image) + } + + return &sdk.GenerateResult{ + Text: "done", + FinishReason: sdk.FinishReasonStop, + }, nil + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{ + "/data/images/demo.png": pngBytes, + }), "/data"), + }) + + _, err := a.Generate(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock-model", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("look at the image")}, + SupportsImageInput: true, + Identity: SessionContext{ + BotID: "bot-1", + }, + }) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } +} + +func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.T) { + t.Parallel() + + pngBytes := []byte("\x89PNG\r\n\x1a\npayload") + expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes) + + modelProvider := &agentReadMediaMockProvider{ + handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) { + if call == 1 { + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-1", + ToolName: "read_media", + Input: map[string]any{"path": "/data/images/demo.png"}, + }}, + }, nil + } + return &sdk.GenerateResult{ + Text: "done", + FinishReason: sdk.FinishReasonStop, + }, nil + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{ + "/data/images/demo.png": pngBytes, + }), "/data"), + }) + + var terminal StreamEvent + for event := range a.Stream(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock-model", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("look at the image")}, + SupportsImageInput: true, + Identity: SessionContext{ + BotID: "bot-1", + }, + }) { + if event.IsTerminal() { + terminal = event + } + } + + if terminal.Type != EventAgentEnd { + t.Fatalf("expected terminal event %q, got %q", EventAgentEnd, terminal.Type) + } + + var messages []sdk.Message + if err := json.Unmarshal(terminal.Messages, &messages); err != nil { + t.Fatalf("unmarshal terminal messages: %v", err) + } + if len(messages) != 4 { + t.Fatalf("expected persisted step + injected history, got %d messages", len(messages)) + } + assertInjectedReadMediaMessage(t, messages[2], expectedDataURL, "image/png") + if messages[3].Role != sdk.MessageRoleAssistant { + t.Fatalf("expected final persisted message to be assistant, got %s", messages[3].Role) + } +} diff --git a/internal/agent/tools/read_media.go b/internal/agent/tools/read_media.go new file mode 100644 index 00000000..8d5b959c --- /dev/null +++ b/internal/agent/tools/read_media.go @@ -0,0 +1,217 @@ +package tools + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "path" + "strings" + + sdk "github.com/memohai/twilight-ai/sdk" + + "github.com/memohai/memoh/internal/workspace/bridge" +) + +const ( + ReadMediaToolName = "read_media" + toolReadMedia = ReadMediaToolName + defaultReadMediaRoot = "/data" + defaultReadMediaMaxBytes = 20 * 1024 * 1024 +) + +var readMediaSupportedMimeTypes = map[string]struct{}{ + "image/gif": {}, + "image/jpeg": {}, + "image/png": {}, + "image/webp": {}, +} + +// ReadMediaToolResult is the public result returned to the model. +type ReadMediaToolResult struct { + OK bool `json:"ok"` + Path string `json:"path,omitempty"` + Mime string `json:"mime,omitempty"` + Size int `json:"size,omitempty"` + Error string `json:"error,omitempty"` +} + +// ReadMediaToolOutput is the internal execution result used by the agent to +// inject the image into the next Twilight AI step while keeping the visible +// tool result lightweight. +type ReadMediaToolOutput struct { + Public ReadMediaToolResult + ImageBase64 string + ImageMediaType string +} + +type readMediaToolOutput = ReadMediaToolOutput + +type ReadMediaProvider struct { + clients bridge.Provider + rootDir string + maxBytes int64 + logger *slog.Logger +} + +func NewReadMediaProvider(log *slog.Logger, clients bridge.Provider, rootDir string) *ReadMediaProvider { + if log == nil { + log = slog.Default() + } + root := strings.TrimSpace(rootDir) + if root == "" { + root = defaultReadMediaRoot + } + return &ReadMediaProvider{ + clients: clients, + rootDir: path.Clean(root), + maxBytes: defaultReadMediaMaxBytes, + logger: log.With(slog.String("tool", "read_media")), + } +} + +func (p *ReadMediaProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) { + if p == nil || p.clients == nil || !session.SupportsImageInput { + return nil, nil + } + root := p.rootDir + if root == "" { + root = defaultReadMediaRoot + } + sess := session + return []sdk.Tool{ + { + Name: toolReadMedia, + Description: fmt.Sprintf("Load an image file from %s into model context so you can inspect it. Relative paths are resolved under %s.", root, root), + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": fmt.Sprintf("Image file path under %s. Absolute paths must stay under %s; relative paths are resolved under %s.", root, root, root), + }, + }, + "required": []string{"path"}, + }, + Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) { + return p.execReadMedia(ctx.Context, sess, inputAsMap(input)) + }, + }, + }, nil +} + +func (p *ReadMediaProvider) execReadMedia(ctx context.Context, session SessionContext, args map[string]any) (any, error) { + client, err := p.getClient(ctx, session.BotID) + if err != nil { + return readMediaErrorResult(err.Error()), nil + } + + resolvedPath, err := p.resolveImagePath(StringArg(args, "path")) + if err != nil { + return readMediaErrorResult(err.Error()), nil + } + + reader, err := client.ReadRaw(ctx, resolvedPath) + if err != nil { + return readMediaErrorResult(err.Error()), nil + } + defer func() { _ = reader.Close() }() + + data, err := io.ReadAll(io.LimitReader(reader, p.maxBytes+1)) + if err != nil { + return readMediaErrorResult("read_media failed to load image: " + err.Error()), nil + } + if int64(len(data)) > p.maxBytes { + return readMediaErrorResult(fmt.Sprintf("read_media failed to load image: file exceeds %d bytes", p.maxBytes)), nil + } + + mimeType, err := detectReadMediaMime(data) + if err != nil { + return readMediaErrorResult(err.Error()), nil + } + + encoded := base64.StdEncoding.EncodeToString(data) + return ReadMediaToolOutput{ + Public: ReadMediaToolResult{ + OK: true, + Path: resolvedPath, + Mime: mimeType, + Size: len(data), + }, + ImageBase64: encoded, + ImageMediaType: mimeType, + }, nil +} + +func (p *ReadMediaProvider) getClient(ctx context.Context, botID string) (*bridge.Client, error) { + botID = strings.TrimSpace(botID) + if botID == "" { + return nil, errors.New("bot_id is required") + } + client, err := p.clients.MCPClient(ctx, botID) + if err != nil { + return nil, fmt.Errorf("container not reachable: %w", err) + } + return client, nil +} + +func (p *ReadMediaProvider) resolveImagePath(raw string) (string, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "", errors.New("path is required") + } + + root := p.rootDir + if root == "" { + root = defaultReadMediaRoot + } + root = path.Clean(root) + + resolved := trimmed + if !strings.HasPrefix(resolved, "/") { + resolved = path.Join(root, resolved) + } + resolved = path.Clean(resolved) + + if resolved == root || !strings.HasPrefix(resolved, root+"/") { + return "", fmt.Errorf("path must be under %s", root) + } + return resolved, nil +} + +func readMediaErrorResult(message string) ReadMediaToolOutput { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "read_media failed" + } + return ReadMediaToolOutput{ + Public: ReadMediaToolResult{ + OK: false, + Error: msg, + }, + } +} + +func detectReadMediaMime(data []byte) (string, error) { + sniffedMime := "" + if len(data) > 0 { + sniffedMime = strings.ToLower(strings.TrimSpace(http.DetectContentType(data))) + } + + switch { + case sniffedMime == "": + return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes") + case isSupportedReadMediaMime(sniffedMime): + return sniffedMime, nil + default: + return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes") + } +} + +func isSupportedReadMediaMime(mimeType string) bool { + _, ok := readMediaSupportedMimeTypes[strings.ToLower(strings.TrimSpace(mimeType))] + return ok +} diff --git a/internal/agent/tools/read_media_test.go b/internal/agent/tools/read_media_test.go new file mode 100644 index 00000000..7ab897cc --- /dev/null +++ b/internal/agent/tools/read_media_test.go @@ -0,0 +1,306 @@ +package tools + +import ( + "context" + "encoding/base64" + "net" + "strings" + "testing" + + sdk "github.com/memohai/twilight-ai/sdk" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" + + "github.com/memohai/memoh/internal/workspace/bridge" + pb "github.com/memohai/memoh/internal/workspace/bridgepb" +) + +const readMediaTestBufSize = 1 << 20 + +type readMediaTestContainerService struct { + pb.UnimplementedContainerServiceServer + files map[string][]byte +} + +func (s *readMediaTestContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { + data, ok := s.files[req.GetPath()] + if !ok { + return status.Error(codes.NotFound, "not found") + } + if len(data) == 0 { + return nil + } + return stream.Send(&pb.DataChunk{Data: data}) +} + +type readMediaStaticProvider struct { + client *bridge.Client +} + +func (p *readMediaStaticProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) { + return p.client, nil +} + +func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider { + t.Helper() + + lis := bufconn.Listen(readMediaTestBufSize) + srv := grpc.NewServer() + pb.RegisterContainerServiceServer(srv, &readMediaTestContainerService{files: files}) + + done := make(chan struct{}) + go func() { + defer close(done) + _ = srv.Serve(lis) + }() + t.Cleanup(func() { + srv.Stop() + <-done + }) + + dialer := func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + } + conn, err := grpc.NewClient( + "passthrough://bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("grpc.NewClient: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + return &readMediaStaticProvider{client: bridge.NewClientFromConn(conn)} +} + +func findToolByName(tools []sdk.Tool, name string) (sdk.Tool, bool) { + for _, tool := range tools { + if tool.Name == name { + return tool, true + } + } + return sdk.Tool{}, false +} + +func TestReadMediaProviderToolsOnlyWhenImageInputIsSupported(t *testing.T) { + t.Parallel() + + provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data") + + toolsWithoutImage, err := provider.Tools(context.Background(), SessionContext{ + BotID: "bot-1", + SupportsImageInput: false, + }) + if err != nil { + t.Fatalf("Tools without image input returned error: %v", err) + } + if len(toolsWithoutImage) != 0 { + t.Fatalf("expected no tools without image input support, got %d", len(toolsWithoutImage)) + } + + toolsWithImage, err := provider.Tools(context.Background(), SessionContext{ + BotID: "bot-1", + SupportsImageInput: true, + }) + if err != nil { + t.Fatalf("Tools with image input returned error: %v", err) + } + + tool, ok := findToolByName(toolsWithImage, toolReadMedia) + if !ok { + t.Fatalf("expected %q tool to be exposed", toolReadMedia) + } + if tool.Execute == nil { + t.Fatal("expected read_media tool to be executable") + } +} + +func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) { + t.Parallel() + + pngBytes := []byte("\x89PNG\r\n\x1a\npayload") + provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{ + "/data/images/demo.png": pngBytes, + }), "/data") + + tools, err := provider.Tools(context.Background(), SessionContext{ + BotID: "bot-1", + SupportsImageInput: true, + }) + if err != nil { + t.Fatalf("Tools returned error: %v", err) + } + + tool, ok := findToolByName(tools, toolReadMedia) + if !ok { + t.Fatalf("expected %q tool", toolReadMedia) + } + + output, err := tool.Execute(&sdk.ToolExecContext{ + Context: context.Background(), + ToolCallID: "call-1", + ToolName: toolReadMedia, + }, map[string]any{"path": "images/demo.png"}) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + result, ok := output.(readMediaToolOutput) + if !ok { + t.Fatalf("expected readMediaToolOutput, got %T", output) + } + if !result.Public.OK { + t.Fatalf("expected success result, got %+v", result.Public) + } + if result.Public.Path != "/data/images/demo.png" { + t.Fatalf("unexpected path: %q", result.Public.Path) + } + if result.Public.Mime != "image/png" { + t.Fatalf("unexpected mime: %q", result.Public.Mime) + } + if result.Public.Size != len(pngBytes) { + t.Fatalf("unexpected size: %d", result.Public.Size) + } + + expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes) + if result.ImageBase64 != expectedBase64 { + t.Fatalf("unexpected image payload: %q", result.ImageBase64) + } + if result.ImageMediaType != "image/png" { + t.Fatalf("unexpected image media type: %q", result.ImageMediaType) + } +} + +func TestReadMediaProviderExecuteRejectsPathOutsideData(t *testing.T) { + t.Parallel() + + provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data") + + tools, err := provider.Tools(context.Background(), SessionContext{ + BotID: "bot-1", + SupportsImageInput: true, + }) + if err != nil { + t.Fatalf("Tools returned error: %v", err) + } + + tool, ok := findToolByName(tools, toolReadMedia) + if !ok { + t.Fatalf("expected %q tool", toolReadMedia) + } + + output, err := tool.Execute(&sdk.ToolExecContext{ + Context: context.Background(), + ToolCallID: "call-2", + ToolName: toolReadMedia, + }, map[string]any{"path": "/tmp/demo.png"}) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + result, ok := output.(readMediaToolOutput) + if !ok { + t.Fatalf("expected readMediaToolOutput, got %T", output) + } + if result.Public.OK { + t.Fatalf("expected error result, got %+v", result.Public) + } + if !strings.Contains(result.Public.Error, "path must be under /data") { + t.Fatalf("unexpected error: %q", result.Public.Error) + } + if result.ImageBase64 != "" { + t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64) + } +} + +func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) { + t.Parallel() + + provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{ + "/data/images/demo.svg": []byte(``), + }), "/data") + + tools, err := provider.Tools(context.Background(), SessionContext{ + BotID: "bot-1", + SupportsImageInput: true, + }) + if err != nil { + t.Fatalf("Tools returned error: %v", err) + } + + tool, ok := findToolByName(tools, toolReadMedia) + if !ok { + t.Fatalf("expected %q tool", toolReadMedia) + } + + output, err := tool.Execute(&sdk.ToolExecContext{ + Context: context.Background(), + ToolCallID: "call-3", + ToolName: toolReadMedia, + }, map[string]any{"path": "images/demo.svg"}) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + result, ok := output.(readMediaToolOutput) + if !ok { + t.Fatalf("expected readMediaToolOutput, got %T", output) + } + if result.Public.OK { + t.Fatalf("expected error result, got %+v", result.Public) + } + if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") { + t.Fatalf("unexpected error: %q", result.Public.Error) + } + if result.ImageBase64 != "" { + t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64) + } +} + +func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) { + t.Parallel() + + provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{ + "/data/images/demo.png": []byte("definitely not a png"), + }), "/data") + + tools, err := provider.Tools(context.Background(), SessionContext{ + BotID: "bot-1", + SupportsImageInput: true, + }) + if err != nil { + t.Fatalf("Tools returned error: %v", err) + } + + tool, ok := findToolByName(tools, toolReadMedia) + if !ok { + t.Fatalf("expected %q tool", toolReadMedia) + } + + output, err := tool.Execute(&sdk.ToolExecContext{ + Context: context.Background(), + ToolCallID: "call-4", + ToolName: toolReadMedia, + }, map[string]any{"path": "images/demo.png"}) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + result, ok := output.(readMediaToolOutput) + if !ok { + t.Fatalf("expected readMediaToolOutput, got %T", output) + } + if result.Public.OK { + t.Fatalf("expected error result, got %+v", result.Public) + } + if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") { + t.Fatalf("unexpected error: %q", result.Public.Error) + } + if result.ImageBase64 != "" { + t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64) + } +} diff --git a/internal/agent/tools/types.go b/internal/agent/tools/types.go index 28134583..9aee08c1 100644 --- a/internal/agent/tools/types.go +++ b/internal/agent/tools/types.go @@ -12,13 +12,14 @@ import ( // SessionContext carries request-scoped identity for tool execution. type SessionContext struct { - BotID string - ChatID string - ChannelIdentityID string - SessionToken string //nolint:gosec // carries session credential material at runtime - CurrentPlatform string - ReplyTarget string - IsSubagent bool + BotID string + ChatID string + ChannelIdentityID string + SessionToken string //nolint:gosec // carries session credential material at runtime + CurrentPlatform string + ReplyTarget string + SupportsImageInput bool + IsSubagent bool } // ToolProvider supplies a set of tools for the agent. diff --git a/internal/agent/types.go b/internal/agent/types.go index b2c81574..bf2bc57c 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -54,20 +54,21 @@ type LoopDetectionConfig struct { // RunConfig holds everything needed for a single agent invocation. type RunConfig struct { - Model *sdk.Model - ReasoningEffort string - Messages []sdk.Message - Query string - System string - Tools []sdk.Tool - Channels []string - CurrentChannel string - Identity SessionContext - Skills []SkillEntry - EnabledSkillNames []string - Inbox []InboxItem - LoopDetection LoopDetectionConfig - ActiveContextTime int + Model *sdk.Model + ReasoningEffort string + Messages []sdk.Message + Query string + System string + Tools []sdk.Tool + SupportsImageInput bool + Channels []string + CurrentChannel string + Identity SessionContext + Skills []SkillEntry + EnabledSkillNames []string + Inbox []InboxItem + LoopDetection LoopDetectionConfig + ActiveContextTime int } // GenerateResult holds the result of a non-streaming agent invocation. diff --git a/internal/conversation/flow/read_media_prompt_test.go b/internal/conversation/flow/read_media_prompt_test.go new file mode 100644 index 00000000..2c39efad --- /dev/null +++ b/internal/conversation/flow/read_media_prompt_test.go @@ -0,0 +1,45 @@ +package flow + +import ( + "context" + "strings" + "testing" + + agentpkg "github.com/memohai/memoh/internal/agent" +) + +func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T) { + t.Parallel() + + resolver := &Resolver{} + cfg := agentpkg.RunConfig{ + Query: "describe this image", + SupportsImageInput: true, + Identity: agentpkg.SessionContext{ + BotID: "bot-1", + }, + } + + prepared := resolver.prepareRunConfig(context.Background(), cfg) + if !strings.Contains(prepared.System, "`read_media`") { + t.Fatalf("expected system prompt to include read_media tool, got:\n%s", prepared.System) + } +} + +func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T) { + t.Parallel() + + resolver := &Resolver{} + cfg := agentpkg.RunConfig{ + Query: "describe this image", + SupportsImageInput: false, + Identity: agentpkg.SessionContext{ + BotID: "bot-1", + }, + } + + prepared := resolver.prepareRunConfig(context.Background(), cfg) + if strings.Contains(prepared.System, "`read_media`") { + t.Fatalf("expected system prompt to omit read_media tool, got:\n%s", prepared.System) + } +} diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 574d0f96..4595a9e1 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -309,12 +309,13 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages)) runCfg := agentpkg.RunConfig{ - Model: sdkModel, - ReasoningEffort: reasoningEffort, - Messages: sdkMessages, - Query: headerifiedQuery, - Channels: nonNilStrings(req.Channels), - CurrentChannel: req.CurrentChannel, + Model: sdkModel, + ReasoningEffort: reasoningEffort, + Messages: sdkMessages, + Query: headerifiedQuery, + SupportsImageInput: chatModel.HasInputModality(models.ModelInputImage), + Channels: nonNilStrings(req.Channels), + CurrentChannel: req.CurrentChannel, Identity: agentpkg.SessionContext{ BotID: req.BotID, ChatID: req.ChatID, @@ -368,10 +369,7 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv // prepareRunConfig generates the system prompt and appends the user message. func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig) agentpkg.RunConfig { - supportsImageInput := false - for _, m := range cfg.Identity.CurrentPlatform { - _ = m - } + supportsImageInput := cfg.SupportsImageInput // Build system prompt var files []agentpkg.SystemFile if r.agent != nil {