feat(agent): restore read_media in pure Go (#257)

This commit is contained in:
Ringo.Typowriter
2026-03-21 14:28:50 +08:00
committed by GitHub
parent e379450702
commit ad08f335eb
11 changed files with 1203 additions and 43 deletions
+1
View File
@@ -572,6 +572,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
agenttools.NewWebProvider(log, settingsService, searchProviderService), agenttools.NewWebProvider(log, settingsService, searchProviderService),
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
agenttools.NewInboxProvider(log, inboxService), agenttools.NewInboxProvider(log, inboxService),
agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewEmailProvider(log, emailService, emailManager),
agenttools.NewWebFetchProvider(log), agenttools.NewWebFetchProvider(log),
+1
View File
@@ -434,6 +434,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
agenttools.NewWebProvider(log, settingsService, searchProviderService), agenttools.NewWebProvider(log, settingsService, searchProviderService),
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
agenttools.NewInboxProvider(log, inboxService), agenttools.NewInboxProvider(log, inboxService),
agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewEmailProvider(log, emailService, emailManager),
agenttools.NewWebFetchProvider(log), agenttools.NewWebFetchProvider(log),
+34 -12
View File
@@ -62,6 +62,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)} ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}
return return
} }
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames)) enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
copy(enabledSkills, cfg.EnabledSkillNames) copy(enabledSkills, cfg.EnabledSkillNames)
@@ -103,7 +104,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
tagResolvers := DefaultTagResolvers() tagResolvers := DefaultTagResolvers()
tagExtractor := NewStreamTagExtractor(tagResolvers) tagExtractor := NewStreamTagExtractor(tagResolvers)
opts := a.buildGenerateOptions(cfg, tools) var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
if readMediaState != nil {
prepareStep = readMediaState.prepareStep
}
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
streamResult, err := a.client.StreamText(ctx, opts...) streamResult, err := a.client.StreamText(ctx, opts...)
if err != nil { if err != nil {
@@ -251,7 +256,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
textLoopProbeBuffer.Flush() textLoopProbeBuffer.Flush()
} }
finalMessages := StripTagsFromMessages(streamResult.Messages) finalMessages := streamResult.Messages
if readMediaState != nil {
finalMessages = readMediaState.mergeMessages(streamResult.Steps, finalMessages)
}
finalMessages = StripTagsFromMessages(finalMessages)
var totalUsage sdk.Usage var totalUsage sdk.Usage
perStepUsages := make([]json.RawMessage, 0, len(streamResult.Steps)) perStepUsages := make([]json.RawMessage, 0, len(streamResult.Steps))
@@ -286,6 +295,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
if err != nil { if err != nil {
return nil, fmt.Errorf("assemble tools: %w", err) return nil, fmt.Errorf("assemble tools: %w", err)
} }
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames)) enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
copy(enabledSkills, cfg.EnabledSkillNames) copy(enabledSkills, cfg.EnabledSkillNames)
@@ -315,7 +325,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs) tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs)
} }
opts := a.buildGenerateOptions(cfg, tools) var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
if readMediaState != nil {
prepareStep = readMediaState.prepareStep
}
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
opts = append(opts, opts = append(opts,
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams { sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
if cfg.LoopDetection.Enabled { if cfg.LoopDetection.Enabled {
@@ -376,7 +390,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
} }
} }
finalMessages := StripTagsFromMessages(genResult.Messages) finalMessages := genResult.Messages
if readMediaState != nil {
finalMessages = readMediaState.mergeMessages(genResult.Steps, finalMessages)
}
finalMessages = StripTagsFromMessages(finalMessages)
return &GenerateResult{ return &GenerateResult{
Messages: finalMessages, Messages: finalMessages,
@@ -389,7 +407,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
}, nil }, nil
} }
func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.GenerateOption { func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams) []sdk.GenerateOption {
opts := []sdk.GenerateOption{ opts := []sdk.GenerateOption{
sdk.WithModel(cfg.Model), sdk.WithModel(cfg.Model),
sdk.WithMessages(cfg.Messages), sdk.WithMessages(cfg.Messages),
@@ -399,6 +417,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.Genera
if len(tools) > 0 { if len(tools) > 0 {
opts = append(opts, sdk.WithTools(tools)) opts = append(opts, sdk.WithTools(tools))
} }
if prepareStep != nil {
opts = append(opts, sdk.WithPrepareStep(prepareStep))
}
opts = append(opts, BuildReasoningOptions(ModelConfig{ opts = append(opts, BuildReasoningOptions(ModelConfig{
ClientType: resolveClientType(cfg.Model), ClientType: resolveClientType(cfg.Model),
ReasoningConfig: &ReasoningConfig{ ReasoningConfig: &ReasoningConfig{
@@ -432,13 +453,14 @@ func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, e
return nil, nil return nil, nil
} }
session := tools.SessionContext{ session := tools.SessionContext{
BotID: cfg.Identity.BotID, BotID: cfg.Identity.BotID,
ChatID: cfg.Identity.ChatID, ChatID: cfg.Identity.ChatID,
ChannelIdentityID: cfg.Identity.ChannelIdentityID, ChannelIdentityID: cfg.Identity.ChannelIdentityID,
SessionToken: cfg.Identity.SessionToken, SessionToken: cfg.Identity.SessionToken,
CurrentPlatform: cfg.Identity.CurrentPlatform, CurrentPlatform: cfg.Identity.CurrentPlatform,
ReplyTarget: cfg.Identity.ReplyTarget, ReplyTarget: cfg.Identity.ReplyTarget,
IsSubagent: cfg.Identity.IsSubagent, SupportsImageInput: cfg.SupportsImageInput,
IsSubagent: cfg.Identity.IsSubagent,
} }
var allTools []sdk.Tool var allTools []sdk.Tool
+174
View File
@@ -0,0 +1,174 @@
package agent
import (
"fmt"
"strings"
sdk "github.com/memohai/twilight-ai/sdk"
agenttools "github.com/memohai/memoh/internal/agent/tools"
)
func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
if len(tools) == 0 {
return tools, nil
}
clientType := resolveClientType(model)
state := &readMediaDecorationState{
pendingImages: make(map[string]sdk.ImagePart),
}
wrapped := make([]sdk.Tool, 0, len(tools))
found := false
for _, tool := range tools {
if tool.Name != agenttools.ReadMediaToolName || tool.Execute == nil {
wrapped = append(wrapped, tool)
continue
}
found = true
originalExecute := tool.Execute
toolCopy := tool
toolCopy.Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
output, err := originalExecute(ctx, input)
if err != nil {
return output, err
}
publicResult, image, ok := normalizeReadMediaOutput(output, clientType)
if !ok {
return output, nil
}
if ctx != nil && strings.TrimSpace(ctx.ToolCallID) != "" && strings.TrimSpace(image.Image) != "" {
if _, exists := state.pendingImages[ctx.ToolCallID]; !exists {
state.pendingOrder = append(state.pendingOrder, ctx.ToolCallID)
}
state.pendingImages[ctx.ToolCallID] = image
}
return publicResult, nil
}
wrapped = append(wrapped, toolCopy)
}
if !found {
return tools, nil
}
return wrapped, state
}
type readMediaDecorationState struct {
pendingOrder []string
pendingImages map[string]sdk.ImagePart
prepareCalls int
injections []readMediaInjection
}
type readMediaInjection struct {
afterStep int
message sdk.Message
}
func (s *readMediaDecorationState) prepareStep(params *sdk.GenerateParams) *sdk.GenerateParams {
if s == nil || params == nil {
return nil
}
afterStep := s.prepareCalls
s.prepareCalls++
if len(s.pendingOrder) == 0 {
return nil
}
parts := make([]sdk.MessagePart, 0, len(s.pendingOrder))
for _, toolCallID := range s.pendingOrder {
image, ok := s.pendingImages[toolCallID]
delete(s.pendingImages, toolCallID)
if !ok || strings.TrimSpace(image.Image) == "" {
continue
}
parts = append(parts, image)
}
s.pendingOrder = s.pendingOrder[:0]
if len(parts) == 0 {
return nil
}
message := sdk.Message{
Role: sdk.MessageRoleUser,
Content: parts,
}
s.injections = append(s.injections, readMediaInjection{
afterStep: afterStep,
message: message,
})
next := *params
next.Messages = append(append([]sdk.Message(nil), params.Messages...), message)
return &next
}
func (s *readMediaDecorationState) mergeMessages(steps []sdk.StepResult, fallback []sdk.Message) []sdk.Message {
if s == nil || len(s.injections) == 0 {
return fallback
}
if len(steps) == 0 {
merged := append([]sdk.Message(nil), fallback...)
for _, injection := range s.injections {
merged = append(merged, injection.message)
}
return merged
}
merged := make([]sdk.Message, 0, len(fallback)+len(s.injections))
injectionIndex := 0
for stepIndex, step := range steps {
merged = append(merged, step.Messages...)
for injectionIndex < len(s.injections) && s.injections[injectionIndex].afterStep == stepIndex {
merged = append(merged, s.injections[injectionIndex].message)
injectionIndex++
}
}
for injectionIndex < len(s.injections) {
merged = append(merged, s.injections[injectionIndex].message)
injectionIndex++
}
return merged
}
func normalizeReadMediaOutput(output any, clientType string) (any, sdk.ImagePart, bool) {
switch value := output.(type) {
case agenttools.ReadMediaToolOutput:
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
case *agenttools.ReadMediaToolOutput:
if value == nil {
return nil, sdk.ImagePart{}, false
}
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
default:
return nil, sdk.ImagePart{}, false
}
}
func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.ImagePart {
imageBase64 = strings.TrimSpace(imageBase64)
mediaType = strings.TrimSpace(mediaType)
if imageBase64 == "" {
return sdk.ImagePart{}
}
if mediaType == "" {
mediaType = "image/png"
}
image := imageBase64
if clientType != ClientTypeAnthropicMessages {
image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
}
return sdk.ImagePart{
Image: image,
MediaType: mediaType,
}
}
+394
View File
@@ -0,0 +1,394 @@
package agent
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"net"
"strings"
"testing"
sdk "github.com/memohai/twilight-ai/sdk"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
agenttools "github.com/memohai/memoh/internal/agent/tools"
"github.com/memohai/memoh/internal/workspace/bridge"
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
)
const agentReadMediaTestBufSize = 1 << 20
type agentReadMediaContainerService struct {
pb.UnimplementedContainerServiceServer
files map[string][]byte
}
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
data, ok := s.files[req.GetPath()]
if !ok {
return status.Error(codes.NotFound, "not found")
}
if len(data) == 0 {
return nil
}
return stream.Send(&pb.DataChunk{Data: data})
}
type agentReadMediaBridgeProvider struct {
client *bridge.Client
}
func (p *agentReadMediaBridgeProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
return p.client, nil
}
func newAgentReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
t.Helper()
lis := bufconn.Listen(agentReadMediaTestBufSize)
srv := grpc.NewServer()
pb.RegisterContainerServiceServer(srv, &agentReadMediaContainerService{files: files})
done := make(chan struct{})
go func() {
defer close(done)
_ = srv.Serve(lis)
}()
t.Cleanup(func() {
srv.Stop()
<-done
})
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}
conn, err := grpc.NewClient(
"passthrough://bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
t.Fatalf("grpc.NewClient: %v", err)
}
t.Cleanup(func() { _ = conn.Close() })
return &agentReadMediaBridgeProvider{client: bridge.NewClientFromConn(conn)}
}
type agentReadMediaMockProvider struct {
name string
calls int
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
}
func (m *agentReadMediaMockProvider) Name() string {
if m.name != "" {
return m.name
}
return "mock"
}
func (*agentReadMediaMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
return nil, nil
}
func (*agentReadMediaMockProvider) Test(context.Context) *sdk.ProviderTestResult {
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
}
func (*agentReadMediaMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
}
func (m *agentReadMediaMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
m.calls++
return m.handler(m.calls, params)
}
func (m *agentReadMediaMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
result, err := m.DoGenerate(ctx, params)
if err != nil {
return nil, err
}
ch := make(chan sdk.StreamPart, 8)
go func() {
defer close(ch)
ch <- &sdk.StartPart{}
ch <- &sdk.StartStepPart{}
if result.Text != "" {
ch <- &sdk.TextStartPart{ID: "mock"}
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
ch <- &sdk.TextEndPart{ID: "mock"}
}
for _, tc := range result.ToolCalls {
ch <- &sdk.StreamToolCallPart{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Input: tc.Input,
}
}
ch <- &sdk.FinishStepPart{FinishReason: result.FinishReason, Usage: result.Usage, Response: result.Response}
ch <- &sdk.FinishPart{FinishReason: result.FinishReason, TotalUsage: result.Usage}
}()
return &sdk.StreamResult{Stream: ch}, nil
}
func assertInjectedReadMediaMessage(t *testing.T, msg sdk.Message, expectedImage, expectedMediaType string) {
t.Helper()
if msg.Role != sdk.MessageRoleUser {
t.Fatalf("expected injected read_media message role %q, got %q", sdk.MessageRoleUser, msg.Role)
}
if len(msg.Content) != 1 {
t.Fatalf("expected one injected content part, got %d", len(msg.Content))
}
image, ok := msg.Content[0].(sdk.ImagePart)
if !ok {
t.Fatalf("expected sdk.ImagePart, got %T", msg.Content[0])
}
if image.Image != expectedImage {
t.Fatalf("unexpected injected image payload: %q", image.Image)
}
if image.MediaType != expectedMediaType {
t.Fatalf("unexpected injected media type: %q", image.MediaType)
}
}
func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
modelProvider := &agentReadMediaMockProvider{
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call == 1 {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1",
ToolName: "read_media",
Input: map[string]any{"path": "/data/images/demo.png"},
}},
}, nil
}
if len(params.Messages) < 4 {
t.Fatalf("expected prior tool and injected messages, got %d", len(params.Messages))
}
last := params.Messages[len(params.Messages)-1]
if last.Role != sdk.MessageRoleUser {
t.Fatalf("expected last message to be injected user image, got %s", last.Role)
}
if len(last.Content) != 1 {
t.Fatalf("expected one injected content part, got %d", len(last.Content))
}
image, ok := last.Content[0].(sdk.ImagePart)
if !ok {
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
}
if image.Image != expectedDataURL {
t.Fatalf("unexpected injected image payload: %q", image.Image)
}
if image.MediaType != "image/png" {
t.Fatalf("unexpected injected media type: %q", image.MediaType)
}
var toolResult sdk.ToolResultPart
foundToolMessage := false
for _, msg := range params.Messages {
if msg.Role != sdk.MessageRoleTool || len(msg.Content) == 0 {
continue
}
part, ok := msg.Content[0].(sdk.ToolResultPart)
if !ok {
continue
}
toolResult = part
foundToolMessage = true
break
}
if !foundToolMessage {
t.Fatal("expected tool result message before second step")
}
raw, err := json.Marshal(toolResult.Result)
if err != nil {
t.Fatalf("marshal tool result: %v", err)
}
if !bytes.Contains(raw, []byte(`"ok":true`)) {
t.Fatalf("expected compact success metadata, got %s", raw)
}
if bytes.Contains(raw, []byte(expectedDataURL)) || bytes.Contains(raw, []byte("payload")) {
t.Fatalf("tool result leaked image bytes: %s", raw)
}
return &sdk.GenerateResult{
Text: "done",
FinishReason: sdk.FinishReasonStop,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": pngBytes,
}), "/data"),
})
result, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
SupportsImageInput: true,
Identity: SessionContext{
BotID: "bot-1",
},
})
if err != nil {
t.Fatalf("Generate returned error: %v", err)
}
if result.Text != "done" {
t.Fatalf("unexpected result text: %q", result.Text)
}
if len(result.Messages) != 4 {
t.Fatalf("expected persisted step + injected history, got %d messages", len(result.Messages))
}
assertInjectedReadMediaMessage(t, result.Messages[2], expectedDataURL, "image/png")
if result.Messages[3].Role != sdk.MessageRoleAssistant {
t.Fatalf("expected final persisted message to be assistant, got %s", result.Messages[3].Role)
}
if modelProvider.calls != 2 {
t.Fatalf("expected 2 model calls, got %d", modelProvider.calls)
}
}
func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.T) {
t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
modelProvider := &agentReadMediaMockProvider{
name: "anthropic-messages",
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call == 1 {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1",
ToolName: "read_media",
Input: map[string]any{"path": "/data/images/demo.png"},
}},
}, nil
}
last := params.Messages[len(params.Messages)-1]
image, ok := last.Content[0].(sdk.ImagePart)
if !ok {
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
}
if image.Image != expectedBase64 {
t.Fatalf("expected raw base64 for anthropic, got %q", image.Image)
}
if image.MediaType != "image/png" {
t.Fatalf("unexpected injected media type: %q", image.MediaType)
}
if strings.HasPrefix(image.Image, "data:") {
t.Fatalf("anthropic image payload must not be a data URL: %q", image.Image)
}
return &sdk.GenerateResult{
Text: "done",
FinishReason: sdk.FinishReasonStop,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": pngBytes,
}), "/data"),
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
SupportsImageInput: true,
Identity: SessionContext{
BotID: "bot-1",
},
})
if err != nil {
t.Fatalf("Generate returned error: %v", err)
}
}
func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.T) {
t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
modelProvider := &agentReadMediaMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call == 1 {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1",
ToolName: "read_media",
Input: map[string]any{"path": "/data/images/demo.png"},
}},
}, nil
}
return &sdk.GenerateResult{
Text: "done",
FinishReason: sdk.FinishReasonStop,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": pngBytes,
}), "/data"),
})
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
SupportsImageInput: true,
Identity: SessionContext{
BotID: "bot-1",
},
}) {
if event.IsTerminal() {
terminal = event
}
}
if terminal.Type != EventAgentEnd {
t.Fatalf("expected terminal event %q, got %q", EventAgentEnd, terminal.Type)
}
var messages []sdk.Message
if err := json.Unmarshal(terminal.Messages, &messages); err != nil {
t.Fatalf("unmarshal terminal messages: %v", err)
}
if len(messages) != 4 {
t.Fatalf("expected persisted step + injected history, got %d messages", len(messages))
}
assertInjectedReadMediaMessage(t, messages[2], expectedDataURL, "image/png")
if messages[3].Role != sdk.MessageRoleAssistant {
t.Fatalf("expected final persisted message to be assistant, got %s", messages[3].Role)
}
}
+217
View File
@@ -0,0 +1,217 @@
package tools
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"path"
"strings"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/workspace/bridge"
)
const (
ReadMediaToolName = "read_media"
toolReadMedia = ReadMediaToolName
defaultReadMediaRoot = "/data"
defaultReadMediaMaxBytes = 20 * 1024 * 1024
)
var readMediaSupportedMimeTypes = map[string]struct{}{
"image/gif": {},
"image/jpeg": {},
"image/png": {},
"image/webp": {},
}
// ReadMediaToolResult is the public result returned to the model.
type ReadMediaToolResult struct {
OK bool `json:"ok"`
Path string `json:"path,omitempty"`
Mime string `json:"mime,omitempty"`
Size int `json:"size,omitempty"`
Error string `json:"error,omitempty"`
}
// ReadMediaToolOutput is the internal execution result used by the agent to
// inject the image into the next Twilight AI step while keeping the visible
// tool result lightweight.
type ReadMediaToolOutput struct {
Public ReadMediaToolResult
ImageBase64 string
ImageMediaType string
}
type readMediaToolOutput = ReadMediaToolOutput
type ReadMediaProvider struct {
clients bridge.Provider
rootDir string
maxBytes int64
logger *slog.Logger
}
func NewReadMediaProvider(log *slog.Logger, clients bridge.Provider, rootDir string) *ReadMediaProvider {
if log == nil {
log = slog.Default()
}
root := strings.TrimSpace(rootDir)
if root == "" {
root = defaultReadMediaRoot
}
return &ReadMediaProvider{
clients: clients,
rootDir: path.Clean(root),
maxBytes: defaultReadMediaMaxBytes,
logger: log.With(slog.String("tool", "read_media")),
}
}
func (p *ReadMediaProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
if p == nil || p.clients == nil || !session.SupportsImageInput {
return nil, nil
}
root := p.rootDir
if root == "" {
root = defaultReadMediaRoot
}
sess := session
return []sdk.Tool{
{
Name: toolReadMedia,
Description: fmt.Sprintf("Load an image file from %s into model context so you can inspect it. Relative paths are resolved under %s.", root, root),
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": fmt.Sprintf("Image file path under %s. Absolute paths must stay under %s; relative paths are resolved under %s.", root, root, root),
},
},
"required": []string{"path"},
},
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
return p.execReadMedia(ctx.Context, sess, inputAsMap(input))
},
},
}, nil
}
func (p *ReadMediaProvider) execReadMedia(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
client, err := p.getClient(ctx, session.BotID)
if err != nil {
return readMediaErrorResult(err.Error()), nil
}
resolvedPath, err := p.resolveImagePath(StringArg(args, "path"))
if err != nil {
return readMediaErrorResult(err.Error()), nil
}
reader, err := client.ReadRaw(ctx, resolvedPath)
if err != nil {
return readMediaErrorResult(err.Error()), nil
}
defer func() { _ = reader.Close() }()
data, err := io.ReadAll(io.LimitReader(reader, p.maxBytes+1))
if err != nil {
return readMediaErrorResult("read_media failed to load image: " + err.Error()), nil
}
if int64(len(data)) > p.maxBytes {
return readMediaErrorResult(fmt.Sprintf("read_media failed to load image: file exceeds %d bytes", p.maxBytes)), nil
}
mimeType, err := detectReadMediaMime(data)
if err != nil {
return readMediaErrorResult(err.Error()), nil
}
encoded := base64.StdEncoding.EncodeToString(data)
return ReadMediaToolOutput{
Public: ReadMediaToolResult{
OK: true,
Path: resolvedPath,
Mime: mimeType,
Size: len(data),
},
ImageBase64: encoded,
ImageMediaType: mimeType,
}, nil
}
func (p *ReadMediaProvider) getClient(ctx context.Context, botID string) (*bridge.Client, error) {
botID = strings.TrimSpace(botID)
if botID == "" {
return nil, errors.New("bot_id is required")
}
client, err := p.clients.MCPClient(ctx, botID)
if err != nil {
return nil, fmt.Errorf("container not reachable: %w", err)
}
return client, nil
}
func (p *ReadMediaProvider) resolveImagePath(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("path is required")
}
root := p.rootDir
if root == "" {
root = defaultReadMediaRoot
}
root = path.Clean(root)
resolved := trimmed
if !strings.HasPrefix(resolved, "/") {
resolved = path.Join(root, resolved)
}
resolved = path.Clean(resolved)
if resolved == root || !strings.HasPrefix(resolved, root+"/") {
return "", fmt.Errorf("path must be under %s", root)
}
return resolved, nil
}
func readMediaErrorResult(message string) ReadMediaToolOutput {
msg := strings.TrimSpace(message)
if msg == "" {
msg = "read_media failed"
}
return ReadMediaToolOutput{
Public: ReadMediaToolResult{
OK: false,
Error: msg,
},
}
}
func detectReadMediaMime(data []byte) (string, error) {
sniffedMime := ""
if len(data) > 0 {
sniffedMime = strings.ToLower(strings.TrimSpace(http.DetectContentType(data)))
}
switch {
case sniffedMime == "":
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
case isSupportedReadMediaMime(sniffedMime):
return sniffedMime, nil
default:
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
}
}
func isSupportedReadMediaMime(mimeType string) bool {
_, ok := readMediaSupportedMimeTypes[strings.ToLower(strings.TrimSpace(mimeType))]
return ok
}
+306
View File
@@ -0,0 +1,306 @@
package tools
import (
"context"
"encoding/base64"
"net"
"strings"
"testing"
sdk "github.com/memohai/twilight-ai/sdk"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"github.com/memohai/memoh/internal/workspace/bridge"
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
)
const readMediaTestBufSize = 1 << 20
type readMediaTestContainerService struct {
pb.UnimplementedContainerServiceServer
files map[string][]byte
}
func (s *readMediaTestContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
data, ok := s.files[req.GetPath()]
if !ok {
return status.Error(codes.NotFound, "not found")
}
if len(data) == 0 {
return nil
}
return stream.Send(&pb.DataChunk{Data: data})
}
type readMediaStaticProvider struct {
client *bridge.Client
}
func (p *readMediaStaticProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
return p.client, nil
}
func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
t.Helper()
lis := bufconn.Listen(readMediaTestBufSize)
srv := grpc.NewServer()
pb.RegisterContainerServiceServer(srv, &readMediaTestContainerService{files: files})
done := make(chan struct{})
go func() {
defer close(done)
_ = srv.Serve(lis)
}()
t.Cleanup(func() {
srv.Stop()
<-done
})
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}
conn, err := grpc.NewClient(
"passthrough://bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
t.Fatalf("grpc.NewClient: %v", err)
}
t.Cleanup(func() { _ = conn.Close() })
return &readMediaStaticProvider{client: bridge.NewClientFromConn(conn)}
}
func findToolByName(tools []sdk.Tool, name string) (sdk.Tool, bool) {
for _, tool := range tools {
if tool.Name == name {
return tool, true
}
}
return sdk.Tool{}, false
}
func TestReadMediaProviderToolsOnlyWhenImageInputIsSupported(t *testing.T) {
t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
toolsWithoutImage, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: false,
})
if err != nil {
t.Fatalf("Tools without image input returned error: %v", err)
}
if len(toolsWithoutImage) != 0 {
t.Fatalf("expected no tools without image input support, got %d", len(toolsWithoutImage))
}
toolsWithImage, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
})
if err != nil {
t.Fatalf("Tools with image input returned error: %v", err)
}
tool, ok := findToolByName(toolsWithImage, toolReadMedia)
if !ok {
t.Fatalf("expected %q tool to be exposed", toolReadMedia)
}
if tool.Execute == nil {
t.Fatal("expected read_media tool to be executable")
}
}
func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) {
t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": pngBytes,
}), "/data")
tools, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
})
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-1",
ToolName: toolReadMedia,
}, map[string]any{"path": "images/demo.png"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if !result.Public.OK {
t.Fatalf("expected success result, got %+v", result.Public)
}
if result.Public.Path != "/data/images/demo.png" {
t.Fatalf("unexpected path: %q", result.Public.Path)
}
if result.Public.Mime != "image/png" {
t.Fatalf("unexpected mime: %q", result.Public.Mime)
}
if result.Public.Size != len(pngBytes) {
t.Fatalf("unexpected size: %d", result.Public.Size)
}
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
if result.ImageBase64 != expectedBase64 {
t.Fatalf("unexpected image payload: %q", result.ImageBase64)
}
if result.ImageMediaType != "image/png" {
t.Fatalf("unexpected image media type: %q", result.ImageMediaType)
}
}
func TestReadMediaProviderExecuteRejectsPathOutsideData(t *testing.T) {
t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
tools, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
})
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-2",
ToolName: toolReadMedia,
}, map[string]any{"path": "/tmp/demo.png"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if result.Public.OK {
t.Fatalf("expected error result, got %+v", result.Public)
}
if !strings.Contains(result.Public.Error, "path must be under /data") {
t.Fatalf("unexpected error: %q", result.Public.Error)
}
if result.ImageBase64 != "" {
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
}
}
func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) {
t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.svg": []byte(`<svg xmlns="http://www.w3.org/2000/svg"></svg>`),
}), "/data")
tools, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
})
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-3",
ToolName: toolReadMedia,
}, map[string]any{"path": "images/demo.svg"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if result.Public.OK {
t.Fatalf("expected error result, got %+v", result.Public)
}
if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") {
t.Fatalf("unexpected error: %q", result.Public.Error)
}
if result.ImageBase64 != "" {
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
}
}
func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) {
t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": []byte("definitely not a png"),
}), "/data")
tools, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
})
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-4",
ToolName: toolReadMedia,
}, map[string]any{"path": "images/demo.png"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if result.Public.OK {
t.Fatalf("expected error result, got %+v", result.Public)
}
if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") {
t.Fatalf("unexpected error: %q", result.Public.Error)
}
if result.ImageBase64 != "" {
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
}
}
+8 -7
View File
@@ -12,13 +12,14 @@ import (
// SessionContext carries request-scoped identity for tool execution. // SessionContext carries request-scoped identity for tool execution.
type SessionContext struct { type SessionContext struct {
BotID string BotID string
ChatID string ChatID string
ChannelIdentityID string ChannelIdentityID string
SessionToken string //nolint:gosec // carries session credential material at runtime SessionToken string //nolint:gosec // carries session credential material at runtime
CurrentPlatform string CurrentPlatform string
ReplyTarget string ReplyTarget string
IsSubagent bool SupportsImageInput bool
IsSubagent bool
} }
// ToolProvider supplies a set of tools for the agent. // ToolProvider supplies a set of tools for the agent.
+15 -14
View File
@@ -54,20 +54,21 @@ type LoopDetectionConfig struct {
// RunConfig holds everything needed for a single agent invocation. // RunConfig holds everything needed for a single agent invocation.
type RunConfig struct { type RunConfig struct {
Model *sdk.Model Model *sdk.Model
ReasoningEffort string ReasoningEffort string
Messages []sdk.Message Messages []sdk.Message
Query string Query string
System string System string
Tools []sdk.Tool Tools []sdk.Tool
Channels []string SupportsImageInput bool
CurrentChannel string Channels []string
Identity SessionContext CurrentChannel string
Skills []SkillEntry Identity SessionContext
EnabledSkillNames []string Skills []SkillEntry
Inbox []InboxItem EnabledSkillNames []string
LoopDetection LoopDetectionConfig Inbox []InboxItem
ActiveContextTime int LoopDetection LoopDetectionConfig
ActiveContextTime int
} }
// GenerateResult holds the result of a non-streaming agent invocation. // GenerateResult holds the result of a non-streaming agent invocation.
@@ -0,0 +1,45 @@
package flow
import (
"context"
"strings"
"testing"
agentpkg "github.com/memohai/memoh/internal/agent"
)
func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T) {
t.Parallel()
resolver := &Resolver{}
cfg := agentpkg.RunConfig{
Query: "describe this image",
SupportsImageInput: true,
Identity: agentpkg.SessionContext{
BotID: "bot-1",
},
}
prepared := resolver.prepareRunConfig(context.Background(), cfg)
if !strings.Contains(prepared.System, "`read_media`") {
t.Fatalf("expected system prompt to include read_media tool, got:\n%s", prepared.System)
}
}
func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T) {
t.Parallel()
resolver := &Resolver{}
cfg := agentpkg.RunConfig{
Query: "describe this image",
SupportsImageInput: false,
Identity: agentpkg.SessionContext{
BotID: "bot-1",
},
}
prepared := resolver.prepareRunConfig(context.Background(), cfg)
if strings.Contains(prepared.System, "`read_media`") {
t.Fatalf("expected system prompt to omit read_media tool, got:\n%s", prepared.System)
}
}
+8 -10
View File
@@ -309,12 +309,13 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages)) sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
runCfg := agentpkg.RunConfig{ runCfg := agentpkg.RunConfig{
Model: sdkModel, Model: sdkModel,
ReasoningEffort: reasoningEffort, ReasoningEffort: reasoningEffort,
Messages: sdkMessages, Messages: sdkMessages,
Query: headerifiedQuery, Query: headerifiedQuery,
Channels: nonNilStrings(req.Channels), SupportsImageInput: chatModel.HasInputModality(models.ModelInputImage),
CurrentChannel: req.CurrentChannel, Channels: nonNilStrings(req.Channels),
CurrentChannel: req.CurrentChannel,
Identity: agentpkg.SessionContext{ Identity: agentpkg.SessionContext{
BotID: req.BotID, BotID: req.BotID,
ChatID: req.ChatID, ChatID: req.ChatID,
@@ -368,10 +369,7 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv
// prepareRunConfig generates the system prompt and appends the user message. // prepareRunConfig generates the system prompt and appends the user message.
func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig) agentpkg.RunConfig { func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig) agentpkg.RunConfig {
supportsImageInput := false supportsImageInput := cfg.SupportsImageInput
for _, m := range cfg.Identity.CurrentPlatform {
_ = m
}
// Build system prompt // Build system prompt
var files []agentpkg.SystemFile var files []agentpkg.SystemFile
if r.agent != nil { if r.agent != nil {