refactor(agent): merge read_media tool into read tool (#326)

This commit is contained in:
Ringo.Typowriter
2026-04-04 20:56:00 +08:00
committed by GitHub
parent 5cfbaa40e2
commit 09c523f0b8
8 changed files with 127 additions and 323 deletions
-1
View File
@@ -640,7 +640,6 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
agenttools.NewWebProvider(log, settingsService, searchProviderService), agenttools.NewWebProvider(log, settingsService, searchProviderService),
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewEmailProvider(log, emailService, emailManager),
agenttools.NewWebFetchProvider(log), agenttools.NewWebFetchProvider(log),
agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService), agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService),
-1
View File
@@ -540,7 +540,6 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
agenttools.NewWebProvider(log, settingsService, searchProviderService), agenttools.NewWebProvider(log, settingsService, searchProviderService),
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewEmailProvider(log, emailService, emailManager),
agenttools.NewWebFetchProvider(log), agenttools.NewWebFetchProvider(log),
agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService), agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService),
+3 -4
View File
@@ -110,12 +110,11 @@ func GenerateSystemPrompt(params SystemPromptParams) string {
timezoneName = "UTC" timezoneName = "UTC"
} }
basicTools := []string{ readToolDesc := "- `read`: read file content"
"- `read`: read file content",
}
if params.SupportsImageInput { if params.SupportsImageInput {
basicTools = append(basicTools, "- `read_media`: view the media") readToolDesc += " (also supports images: PNG, JPEG, GIF, WebP)"
} }
basicTools := []string{readToolDesc}
basicTools = append(basicTools, basicTools = append(basicTools,
"- `write`: write file content", "- `write`: write file content",
"- `list`: list directory entries", "- `list`: list directory entries",
+37 -12
View File
@@ -23,11 +23,24 @@ import (
const agentReadMediaTestBufSize = 1 << 20 const agentReadMediaTestBufSize = 1 << 20
// agentReadMediaContainerService implements both ReadFile and ReadRaw so
// that the merged read tool (ContainerProvider) can detect binary files
// and then delegate to ReadImageFromContainer.
type agentReadMediaContainerService struct { type agentReadMediaContainerService struct {
pb.UnimplementedContainerServiceServer pb.UnimplementedContainerServiceServer
files map[string][]byte files map[string][]byte
} }
func (s *agentReadMediaContainerService) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) {
data, ok := s.files[req.GetPath()]
if !ok {
return nil, status.Error(codes.NotFound, "not found")
}
_ = data
// All files in this test fixture are images → binary.
return &pb.ReadFileResponse{Binary: true}, nil
}
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
data, ok := s.files[req.GetPath()] data, ok := s.files[req.GetPath()]
if !ok { if !ok {
@@ -172,7 +185,7 @@ func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
FinishReason: sdk.FinishReasonToolCalls, FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{ ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1", ToolCallID: "call-1",
ToolName: "read_media", ToolName: "read",
Input: map[string]any{"path": "/data/images/demo.png"}, Input: map[string]any{"path": "/data/images/demo.png"},
}}, }},
}, nil }, nil
@@ -235,11 +248,15 @@ func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
}, },
} }
// ContainerProvider normalizes paths by stripping the workdir prefix,
// so the mock files map must use the normalized (relative) path.
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
"images/demo.png": pngBytes,
})
a := New(Deps{}) a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{ a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{ agenttools.NewContainerProvider(nil, bp, "/data"),
"/data/images/demo.png": pngBytes,
}), "/data"),
}) })
result, err := a.Generate(context.Background(), RunConfig{ result, err := a.Generate(context.Background(), RunConfig{
@@ -283,7 +300,7 @@ func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.
FinishReason: sdk.FinishReasonToolCalls, FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{ ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1", ToolCallID: "call-1",
ToolName: "read_media", ToolName: "read",
Input: map[string]any{"path": "/data/images/demo.png"}, Input: map[string]any{"path": "/data/images/demo.png"},
}}, }},
}, nil }, nil
@@ -311,11 +328,15 @@ func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.
}, },
} }
// ContainerProvider normalizes paths by stripping the workdir prefix,
// so the mock files map must use the normalized (relative) path.
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
"images/demo.png": pngBytes,
})
a := New(Deps{}) a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{ a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{ agenttools.NewContainerProvider(nil, bp, "/data"),
"/data/images/demo.png": pngBytes,
}), "/data"),
}) })
_, err := a.Generate(context.Background(), RunConfig{ _, err := a.Generate(context.Background(), RunConfig{
@@ -345,7 +366,7 @@ func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.
FinishReason: sdk.FinishReasonToolCalls, FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{ ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1", ToolCallID: "call-1",
ToolName: "read_media", ToolName: "read",
Input: map[string]any{"path": "/data/images/demo.png"}, Input: map[string]any{"path": "/data/images/demo.png"},
}}, }},
}, nil }, nil
@@ -357,11 +378,15 @@ func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.
}, },
} }
// ContainerProvider normalizes paths by stripping the workdir prefix,
// so the mock files map must use the normalized (relative) path.
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
"images/demo.png": pngBytes,
})
a := New(Deps{}) a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{ a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{ agenttools.NewContainerProvider(nil, bp, "/data"),
"/data/images/demo.png": pngBytes,
}), "/data"),
}) })
var terminal StreamEvent var terminal StreamEvent
+11 -2
View File
@@ -36,10 +36,16 @@ func NewContainerProvider(log *slog.Logger, clients bridge.Provider, execWorkDir
func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) { func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
wd := p.execWorkDir wd := p.execWorkDir
sess := session sess := session
readDesc := fmt.Sprintf("Read file content inside the bot container. Supports pagination for large files. Max %d lines / %d bytes per call.", readMaxLines, readMaxBytes)
if sess.SupportsImageInput {
readDesc += " Also supports reading image files (PNG, JPEG, GIF, WebP) — binary images are loaded into model context automatically."
}
return []sdk.Tool{ return []sdk.Tool{
{ {
Name: "read", Name: "read",
Description: fmt.Sprintf("Read file content inside the bot container. Supports pagination for large files. Max %d lines / %d bytes per call.", readMaxLines, readMaxBytes), Description: readDesc,
Parameters: map[string]any{ Parameters: map[string]any{
"type": "object", "type": "object",
"properties": map[string]any{ "properties": map[string]any{
@@ -187,7 +193,10 @@ func (p *ContainerProvider) execRead(ctx context.Context, session SessionContext
return nil, err return nil, err
} }
if resp.GetBinary() { if resp.GetBinary() {
return nil, errors.New("file appears to be binary. Read tool only supports text files") if !session.SupportsImageInput {
return nil, errors.New("file appears to be binary. Read tool only supports text files (image reading not available for this model)")
}
return ReadImageFromContainer(ctx, client, filePath, defaultReadMediaMaxBytes), nil
} }
content := addLineNumbers(resp.GetContent(), lineOffset) content := addLineNumbers(resp.GetContent(), lineOffset)
return map[string]any{"content": content, "total_lines": resp.GetTotalLines()}, nil return map[string]any{"content": content, "total_lines": resp.GetTotalLines()}, nil
+40 -120
View File
@@ -6,20 +6,16 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"path"
"strings" "strings"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/workspace/bridge" "github.com/memohai/memoh/internal/workspace/bridge"
) )
const ( const (
ReadMediaToolName = "read_media" // ReadMediaToolName is the tool name that the agent decoration layer
toolReadMedia = ReadMediaToolName // matches on to intercept image payloads. After the merge this is "read".
defaultReadMediaRoot = "/data" ReadMediaToolName = "read"
defaultReadMediaMaxBytes = 20 * 1024 * 1024 defaultReadMediaMaxBytes = 20 * 1024 * 1024
) )
@@ -48,144 +44,68 @@ type ReadMediaToolOutput struct {
ImageMediaType string ImageMediaType string
} }
type readMediaToolOutput = ReadMediaToolOutput // mimeSniffSize is the number of bytes http.DetectContentType needs.
const mimeSniffSize = 512
type ReadMediaProvider struct { // ReadImageFromContainer reads a binary file through the bridge client,
clients bridge.Provider // validates that it is a supported image format, and returns a
rootDir string // ReadMediaToolOutput ready for the agent decoration pipeline.
maxBytes int64 //
logger *slog.Logger // It reads only a small header first to sniff the MIME type, avoiding
} // buffering large non-image binaries just to reject them.
func ReadImageFromContainer(ctx context.Context, client *bridge.Client, path string, maxBytes int64) ReadMediaToolOutput {
if maxBytes <= 0 {
maxBytes = defaultReadMediaMaxBytes
}
func NewReadMediaProvider(log *slog.Logger, clients bridge.Provider, rootDir string) *ReadMediaProvider { reader, err := client.ReadRaw(ctx, path)
if log == nil {
log = slog.Default()
}
root := strings.TrimSpace(rootDir)
if root == "" {
root = defaultReadMediaRoot
}
return &ReadMediaProvider{
clients: clients,
rootDir: path.Clean(root),
maxBytes: defaultReadMediaMaxBytes,
logger: log.With(slog.String("tool", "read_media")),
}
}
func (p *ReadMediaProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
if p == nil || p.clients == nil || !session.SupportsImageInput {
return nil, nil
}
root := p.rootDir
if root == "" {
root = defaultReadMediaRoot
}
sess := session
return []sdk.Tool{
{
Name: toolReadMedia,
Description: fmt.Sprintf("Load an image file from %s into model context so you can inspect it. Relative paths are resolved under %s.", root, root),
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": fmt.Sprintf("Image file path under %s. Absolute paths must stay under %s; relative paths are resolved under %s.", root, root, root),
},
},
"required": []string{"path"},
},
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
return p.execReadMedia(ctx.Context, sess, inputAsMap(input))
},
},
}, nil
}
func (p *ReadMediaProvider) execReadMedia(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
client, err := p.getClient(ctx, session.BotID)
if err != nil { if err != nil {
return readMediaErrorResult(err.Error()), nil return readMediaErrorResult(err.Error())
}
resolvedPath, err := p.resolveImagePath(StringArg(args, "path"))
if err != nil {
return readMediaErrorResult(err.Error()), nil
}
reader, err := client.ReadRaw(ctx, resolvedPath)
if err != nil {
return readMediaErrorResult(err.Error()), nil
} }
defer func() { _ = reader.Close() }() defer func() { _ = reader.Close() }()
data, err := io.ReadAll(io.LimitReader(reader, p.maxBytes+1)) // Read only the sniff header first so non-image binaries fail fast.
if err != nil { header := make([]byte, mimeSniffSize)
return readMediaErrorResult("read_media failed to load image: " + err.Error()), nil n, err := io.ReadAtLeast(reader, header, 1)
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
return readMediaErrorResult("failed to load image: " + err.Error())
} }
if int64(len(data)) > p.maxBytes { header = header[:n]
return readMediaErrorResult(fmt.Sprintf("read_media failed to load image: file exceeds %d bytes", p.maxBytes)), nil
mimeType, err := detectReadMediaMime(header)
if err != nil {
return readMediaErrorResult(err.Error())
} }
mimeType, err := detectReadMediaMime(data) // MIME looks good — read the remainder up to the size limit.
rest, err := io.ReadAll(io.LimitReader(reader, maxBytes-int64(n)+1))
if err != nil { if err != nil {
return readMediaErrorResult(err.Error()), nil return readMediaErrorResult("failed to load image: " + err.Error())
}
data := make([]byte, 0, len(header)+len(rest))
data = append(data, header...)
data = append(data, rest...)
if int64(len(data)) > maxBytes {
return readMediaErrorResult(fmt.Sprintf("failed to load image: file exceeds %d bytes", maxBytes))
} }
encoded := base64.StdEncoding.EncodeToString(data) encoded := base64.StdEncoding.EncodeToString(data)
return ReadMediaToolOutput{ return ReadMediaToolOutput{
Public: ReadMediaToolResult{ Public: ReadMediaToolResult{
OK: true, OK: true,
Path: resolvedPath, Path: path,
Mime: mimeType, Mime: mimeType,
Size: len(data), Size: len(data),
}, },
ImageBase64: encoded, ImageBase64: encoded,
ImageMediaType: mimeType, ImageMediaType: mimeType,
}, nil
}
func (p *ReadMediaProvider) getClient(ctx context.Context, botID string) (*bridge.Client, error) {
botID = strings.TrimSpace(botID)
if botID == "" {
return nil, errors.New("bot_id is required")
} }
client, err := p.clients.MCPClient(ctx, botID)
if err != nil {
return nil, fmt.Errorf("container not reachable: %w", err)
}
return client, nil
}
func (p *ReadMediaProvider) resolveImagePath(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("path is required")
}
root := p.rootDir
if root == "" {
root = defaultReadMediaRoot
}
root = path.Clean(root)
resolved := trimmed
if !strings.HasPrefix(resolved, "/") {
resolved = path.Join(root, resolved)
}
resolved = path.Clean(resolved)
if resolved == root || !strings.HasPrefix(resolved, root+"/") {
return "", fmt.Errorf("path must be under %s", root)
}
return resolved, nil
} }
func readMediaErrorResult(message string) ReadMediaToolOutput { func readMediaErrorResult(message string) ReadMediaToolOutput {
msg := strings.TrimSpace(message) msg := strings.TrimSpace(message)
if msg == "" { if msg == "" {
msg = "read_media failed" msg = "read failed"
} }
return ReadMediaToolOutput{ return ReadMediaToolOutput{
Public: ReadMediaToolResult{ Public: ReadMediaToolResult{
@@ -203,11 +123,11 @@ func detectReadMediaMime(data []byte) (string, error) {
switch { switch {
case sniffedMime == "": case sniffedMime == "":
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes") return "", errors.New("only supports PNG, JPEG, GIF, or WebP image bytes")
case isSupportedReadMediaMime(sniffedMime): case isSupportedReadMediaMime(sniffedMime):
return sniffedMime, nil return sniffedMime, nil
default: default:
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes") return "", errors.New("only supports PNG, JPEG, GIF, or WebP image bytes")
} }
} }
+28 -177
View File
@@ -7,7 +7,6 @@ import (
"strings" "strings"
"testing" "testing"
sdk "github.com/memohai/twilight-ai/sdk"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@@ -36,15 +35,7 @@ func (s *readMediaTestContainerService) ReadRaw(req *pb.ReadRawRequest, stream p
return stream.Send(&pb.DataChunk{Data: data}) return stream.Send(&pb.DataChunk{Data: data})
} }
type readMediaStaticProvider struct { func newReadMediaTestClient(t *testing.T, files map[string][]byte) *bridge.Client {
client *bridge.Client
}
func (p *readMediaStaticProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
return p.client, nil
}
func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
t.Helper() t.Helper()
lis := bufconn.Listen(readMediaTestBufSize) lis := bufconn.Listen(readMediaTestBufSize)
@@ -74,85 +65,19 @@ func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Pr
} }
t.Cleanup(func() { _ = conn.Close() }) t.Cleanup(func() { _ = conn.Close() })
return &readMediaStaticProvider{client: bridge.NewClientFromConn(conn)} return bridge.NewClientFromConn(conn)
} }
func findToolByName(tools []sdk.Tool, name string) (sdk.Tool, bool) { func TestReadImageFromContainerSuccess(t *testing.T) {
for _, tool := range tools {
if tool.Name == name {
return tool, true
}
}
return sdk.Tool{}, false
}
func TestReadMediaProviderToolsOnlyWhenImageInputIsSupported(t *testing.T) {
t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
toolsWithoutImage, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: false,
})
if err != nil {
t.Fatalf("Tools without image input returned error: %v", err)
}
if len(toolsWithoutImage) != 0 {
t.Fatalf("expected no tools without image input support, got %d", len(toolsWithoutImage))
}
toolsWithImage, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
})
if err != nil {
t.Fatalf("Tools with image input returned error: %v", err)
}
tool, ok := findToolByName(toolsWithImage, toolReadMedia)
if !ok {
t.Fatalf("expected %q tool to be exposed", toolReadMedia)
}
if tool.Execute == nil {
t.Fatal("expected read_media tool to be executable")
}
}
func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) {
t.Parallel() t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload") pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{ client := newReadMediaTestClient(t, map[string][]byte{
"/data/images/demo.png": pngBytes, "/data/images/demo.png": pngBytes,
}), "/data")
tools, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
}) })
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia) result := ReadImageFromContainer(context.Background(), client, "/data/images/demo.png", 0)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-1",
ToolName: toolReadMedia,
}, map[string]any{"path": "images/demo.png"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if !result.Public.OK { if !result.Public.OK {
t.Fatalf("expected success result, got %+v", result.Public) t.Fatalf("expected success result, got %+v", result.Public)
} }
@@ -175,81 +100,16 @@ func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) {
} }
} }
func TestReadMediaProviderExecuteRejectsPathOutsideData(t *testing.T) { func TestReadImageFromContainerRejectsUnsupportedMime(t *testing.T) {
t.Parallel() t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data") svgBytes := []byte(`<svg xmlns="http://www.w3.org/2000/svg"></svg>`)
client := newReadMediaTestClient(t, map[string][]byte{
tools, err := provider.Tools(context.Background(), SessionContext{ "/data/images/demo.svg": svgBytes,
BotID: "bot-1",
SupportsImageInput: true,
}) })
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia) result := ReadImageFromContainer(context.Background(), client, "/data/images/demo.svg", 0)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-2",
ToolName: toolReadMedia,
}, map[string]any{"path": "/tmp/demo.png"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if result.Public.OK {
t.Fatalf("expected error result, got %+v", result.Public)
}
if !strings.Contains(result.Public.Error, "path must be under /data") {
t.Fatalf("unexpected error: %q", result.Public.Error)
}
if result.ImageBase64 != "" {
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
}
}
func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) {
t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.svg": []byte(`<svg xmlns="http://www.w3.org/2000/svg"></svg>`),
}), "/data")
tools, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
})
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-3",
ToolName: toolReadMedia,
}, map[string]any{"path": "images/demo.svg"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if result.Public.OK { if result.Public.OK {
t.Fatalf("expected error result, got %+v", result.Public) t.Fatalf("expected error result, got %+v", result.Public)
} }
@@ -261,39 +121,15 @@ func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) {
} }
} }
func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) { func TestReadImageFromContainerRejectsCorruptedBytes(t *testing.T) {
t.Parallel() t.Parallel()
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{ client := newReadMediaTestClient(t, map[string][]byte{
"/data/images/demo.png": []byte("definitely not a png"), "/data/images/demo.png": []byte("definitely not a png"),
}), "/data")
tools, err := provider.Tools(context.Background(), SessionContext{
BotID: "bot-1",
SupportsImageInput: true,
}) })
if err != nil {
t.Fatalf("Tools returned error: %v", err)
}
tool, ok := findToolByName(tools, toolReadMedia) result := ReadImageFromContainer(context.Background(), client, "/data/images/demo.png", 0)
if !ok {
t.Fatalf("expected %q tool", toolReadMedia)
}
output, err := tool.Execute(&sdk.ToolExecContext{
Context: context.Background(),
ToolCallID: "call-4",
ToolName: toolReadMedia,
}, map[string]any{"path": "images/demo.png"})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
result, ok := output.(readMediaToolOutput)
if !ok {
t.Fatalf("expected readMediaToolOutput, got %T", output)
}
if result.Public.OK { if result.Public.OK {
t.Fatalf("expected error result, got %+v", result.Public) t.Fatalf("expected error result, got %+v", result.Public)
} }
@@ -304,3 +140,18 @@ func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) {
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64) t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
} }
} }
func TestReadImageFromContainerNotFound(t *testing.T) {
t.Parallel()
client := newReadMediaTestClient(t, map[string][]byte{})
result := ReadImageFromContainer(context.Background(), client, "/data/images/missing.png", 0)
if result.Public.OK {
t.Fatalf("expected error result, got %+v", result.Public)
}
if result.ImageBase64 != "" {
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
}
}
@@ -8,7 +8,9 @@ import (
agentpkg "github.com/memohai/memoh/internal/agent" agentpkg "github.com/memohai/memoh/internal/agent"
) )
func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T) { const imageReadHint = "also supports images: PNG, JPEG, GIF, WebP"
func TestPrepareRunConfigIncludesImageReadHintWhenImageInputIsSupported(t *testing.T) {
t.Parallel() t.Parallel()
resolver := &Resolver{} resolver := &Resolver{}
@@ -21,12 +23,12 @@ func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T
} }
prepared := resolver.prepareRunConfig(context.Background(), cfg) prepared := resolver.prepareRunConfig(context.Background(), cfg)
if !strings.Contains(prepared.System, "`read_media`") { if !strings.Contains(prepared.System, imageReadHint) {
t.Fatalf("expected system prompt to include read_media tool, got:\n%s", prepared.System) t.Fatalf("expected system prompt to contain %q, got:\n%s", imageReadHint, prepared.System)
} }
} }
func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T) { func TestPrepareRunConfigOmitsImageReadHintWhenImageInputIsUnsupported(t *testing.T) {
t.Parallel() t.Parallel()
resolver := &Resolver{} resolver := &Resolver{}
@@ -39,7 +41,7 @@ func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T)
} }
prepared := resolver.prepareRunConfig(context.Background(), cfg) prepared := resolver.prepareRunConfig(context.Background(), cfg)
if strings.Contains(prepared.System, "`read_media`") { if strings.Contains(prepared.System, imageReadHint) {
t.Fatalf("expected system prompt to omit read_media tool, got:\n%s", prepared.System) t.Fatalf("expected system prompt to NOT contain %q, got:\n%s", imageReadHint, prepared.System)
} }
} }