mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
refactor(agent): merge read_media tool into read tool (#326)
This commit is contained in:
@@ -640,7 +640,6 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
|
|||||||
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
||||||
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
||||||
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
||||||
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
|
|
||||||
agenttools.NewEmailProvider(log, emailService, emailManager),
|
agenttools.NewEmailProvider(log, emailService, emailManager),
|
||||||
agenttools.NewWebFetchProvider(log),
|
agenttools.NewWebFetchProvider(log),
|
||||||
agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService),
|
agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService),
|
||||||
|
|||||||
@@ -540,7 +540,6 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
|
|||||||
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
||||||
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
||||||
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
||||||
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
|
|
||||||
agenttools.NewEmailProvider(log, emailService, emailManager),
|
agenttools.NewEmailProvider(log, emailService, emailManager),
|
||||||
agenttools.NewWebFetchProvider(log),
|
agenttools.NewWebFetchProvider(log),
|
||||||
agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService),
|
agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService),
|
||||||
|
|||||||
@@ -110,12 +110,11 @@ func GenerateSystemPrompt(params SystemPromptParams) string {
|
|||||||
timezoneName = "UTC"
|
timezoneName = "UTC"
|
||||||
}
|
}
|
||||||
|
|
||||||
basicTools := []string{
|
readToolDesc := "- `read`: read file content"
|
||||||
"- `read`: read file content",
|
|
||||||
}
|
|
||||||
if params.SupportsImageInput {
|
if params.SupportsImageInput {
|
||||||
basicTools = append(basicTools, "- `read_media`: view the media")
|
readToolDesc += " (also supports images: PNG, JPEG, GIF, WebP)"
|
||||||
}
|
}
|
||||||
|
basicTools := []string{readToolDesc}
|
||||||
basicTools = append(basicTools,
|
basicTools = append(basicTools,
|
||||||
"- `write`: write file content",
|
"- `write`: write file content",
|
||||||
"- `list`: list directory entries",
|
"- `list`: list directory entries",
|
||||||
|
|||||||
@@ -23,11 +23,24 @@ import (
|
|||||||
|
|
||||||
const agentReadMediaTestBufSize = 1 << 20
|
const agentReadMediaTestBufSize = 1 << 20
|
||||||
|
|
||||||
|
// agentReadMediaContainerService implements both ReadFile and ReadRaw so
|
||||||
|
// that the merged read tool (ContainerProvider) can detect binary files
|
||||||
|
// and then delegate to ReadImageFromContainer.
|
||||||
type agentReadMediaContainerService struct {
|
type agentReadMediaContainerService struct {
|
||||||
pb.UnimplementedContainerServiceServer
|
pb.UnimplementedContainerServiceServer
|
||||||
files map[string][]byte
|
files map[string][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *agentReadMediaContainerService) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) {
|
||||||
|
data, ok := s.files[req.GetPath()]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Error(codes.NotFound, "not found")
|
||||||
|
}
|
||||||
|
_ = data
|
||||||
|
// All files in this test fixture are images → binary.
|
||||||
|
return &pb.ReadFileResponse{Binary: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
||||||
data, ok := s.files[req.GetPath()]
|
data, ok := s.files[req.GetPath()]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -172,7 +185,7 @@ func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
|
|||||||
FinishReason: sdk.FinishReasonToolCalls,
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
ToolCalls: []sdk.ToolCall{{
|
ToolCalls: []sdk.ToolCall{{
|
||||||
ToolCallID: "call-1",
|
ToolCallID: "call-1",
|
||||||
ToolName: "read_media",
|
ToolName: "read",
|
||||||
Input: map[string]any{"path": "/data/images/demo.png"},
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||||
}},
|
}},
|
||||||
}, nil
|
}, nil
|
||||||
@@ -235,11 +248,15 @@ func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContainerProvider normalizes paths by stripping the workdir prefix,
|
||||||
|
// so the mock files map must use the normalized (relative) path.
|
||||||
|
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"images/demo.png": pngBytes,
|
||||||
|
})
|
||||||
|
|
||||||
a := New(Deps{})
|
a := New(Deps{})
|
||||||
a.SetToolProviders([]agenttools.ToolProvider{
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
agenttools.NewContainerProvider(nil, bp, "/data"),
|
||||||
"/data/images/demo.png": pngBytes,
|
|
||||||
}), "/data"),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
result, err := a.Generate(context.Background(), RunConfig{
|
result, err := a.Generate(context.Background(), RunConfig{
|
||||||
@@ -283,7 +300,7 @@ func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.
|
|||||||
FinishReason: sdk.FinishReasonToolCalls,
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
ToolCalls: []sdk.ToolCall{{
|
ToolCalls: []sdk.ToolCall{{
|
||||||
ToolCallID: "call-1",
|
ToolCallID: "call-1",
|
||||||
ToolName: "read_media",
|
ToolName: "read",
|
||||||
Input: map[string]any{"path": "/data/images/demo.png"},
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||||
}},
|
}},
|
||||||
}, nil
|
}, nil
|
||||||
@@ -311,11 +328,15 @@ func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContainerProvider normalizes paths by stripping the workdir prefix,
|
||||||
|
// so the mock files map must use the normalized (relative) path.
|
||||||
|
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"images/demo.png": pngBytes,
|
||||||
|
})
|
||||||
|
|
||||||
a := New(Deps{})
|
a := New(Deps{})
|
||||||
a.SetToolProviders([]agenttools.ToolProvider{
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
agenttools.NewContainerProvider(nil, bp, "/data"),
|
||||||
"/data/images/demo.png": pngBytes,
|
|
||||||
}), "/data"),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err := a.Generate(context.Background(), RunConfig{
|
_, err := a.Generate(context.Background(), RunConfig{
|
||||||
@@ -345,7 +366,7 @@ func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.
|
|||||||
FinishReason: sdk.FinishReasonToolCalls,
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
ToolCalls: []sdk.ToolCall{{
|
ToolCalls: []sdk.ToolCall{{
|
||||||
ToolCallID: "call-1",
|
ToolCallID: "call-1",
|
||||||
ToolName: "read_media",
|
ToolName: "read",
|
||||||
Input: map[string]any{"path": "/data/images/demo.png"},
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||||
}},
|
}},
|
||||||
}, nil
|
}, nil
|
||||||
@@ -357,11 +378,15 @@ func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContainerProvider normalizes paths by stripping the workdir prefix,
|
||||||
|
// so the mock files map must use the normalized (relative) path.
|
||||||
|
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"images/demo.png": pngBytes,
|
||||||
|
})
|
||||||
|
|
||||||
a := New(Deps{})
|
a := New(Deps{})
|
||||||
a.SetToolProviders([]agenttools.ToolProvider{
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
agenttools.NewContainerProvider(nil, bp, "/data"),
|
||||||
"/data/images/demo.png": pngBytes,
|
|
||||||
}), "/data"),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
var terminal StreamEvent
|
var terminal StreamEvent
|
||||||
|
|||||||
@@ -36,10 +36,16 @@ func NewContainerProvider(log *slog.Logger, clients bridge.Provider, execWorkDir
|
|||||||
func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
|
func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
|
||||||
wd := p.execWorkDir
|
wd := p.execWorkDir
|
||||||
sess := session
|
sess := session
|
||||||
|
|
||||||
|
readDesc := fmt.Sprintf("Read file content inside the bot container. Supports pagination for large files. Max %d lines / %d bytes per call.", readMaxLines, readMaxBytes)
|
||||||
|
if sess.SupportsImageInput {
|
||||||
|
readDesc += " Also supports reading image files (PNG, JPEG, GIF, WebP) — binary images are loaded into model context automatically."
|
||||||
|
}
|
||||||
|
|
||||||
return []sdk.Tool{
|
return []sdk.Tool{
|
||||||
{
|
{
|
||||||
Name: "read",
|
Name: "read",
|
||||||
Description: fmt.Sprintf("Read file content inside the bot container. Supports pagination for large files. Max %d lines / %d bytes per call.", readMaxLines, readMaxBytes),
|
Description: readDesc,
|
||||||
Parameters: map[string]any{
|
Parameters: map[string]any{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]any{
|
"properties": map[string]any{
|
||||||
@@ -187,7 +193,10 @@ func (p *ContainerProvider) execRead(ctx context.Context, session SessionContext
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if resp.GetBinary() {
|
if resp.GetBinary() {
|
||||||
return nil, errors.New("file appears to be binary. Read tool only supports text files")
|
if !session.SupportsImageInput {
|
||||||
|
return nil, errors.New("file appears to be binary. Read tool only supports text files (image reading not available for this model)")
|
||||||
|
}
|
||||||
|
return ReadImageFromContainer(ctx, client, filePath, defaultReadMediaMaxBytes), nil
|
||||||
}
|
}
|
||||||
content := addLineNumbers(resp.GetContent(), lineOffset)
|
content := addLineNumbers(resp.GetContent(), lineOffset)
|
||||||
return map[string]any{"content": content, "total_lines": resp.GetTotalLines()}, nil
|
return map[string]any{"content": content, "total_lines": resp.GetTotalLines()}, nil
|
||||||
|
|||||||
@@ -6,20 +6,16 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ReadMediaToolName = "read_media"
|
// ReadMediaToolName is the tool name that the agent decoration layer
|
||||||
toolReadMedia = ReadMediaToolName
|
// matches on to intercept image payloads. After the merge this is "read".
|
||||||
defaultReadMediaRoot = "/data"
|
ReadMediaToolName = "read"
|
||||||
defaultReadMediaMaxBytes = 20 * 1024 * 1024
|
defaultReadMediaMaxBytes = 20 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,144 +44,68 @@ type ReadMediaToolOutput struct {
|
|||||||
ImageMediaType string
|
ImageMediaType string
|
||||||
}
|
}
|
||||||
|
|
||||||
type readMediaToolOutput = ReadMediaToolOutput
|
// mimeSniffSize is the number of bytes http.DetectContentType needs.
|
||||||
|
const mimeSniffSize = 512
|
||||||
|
|
||||||
type ReadMediaProvider struct {
|
// ReadImageFromContainer reads a binary file through the bridge client,
|
||||||
clients bridge.Provider
|
// validates that it is a supported image format, and returns a
|
||||||
rootDir string
|
// ReadMediaToolOutput ready for the agent decoration pipeline.
|
||||||
maxBytes int64
|
//
|
||||||
logger *slog.Logger
|
// It reads only a small header first to sniff the MIME type, avoiding
|
||||||
}
|
// buffering large non-image binaries just to reject them.
|
||||||
|
func ReadImageFromContainer(ctx context.Context, client *bridge.Client, path string, maxBytes int64) ReadMediaToolOutput {
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = defaultReadMediaMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
func NewReadMediaProvider(log *slog.Logger, clients bridge.Provider, rootDir string) *ReadMediaProvider {
|
reader, err := client.ReadRaw(ctx, path)
|
||||||
if log == nil {
|
|
||||||
log = slog.Default()
|
|
||||||
}
|
|
||||||
root := strings.TrimSpace(rootDir)
|
|
||||||
if root == "" {
|
|
||||||
root = defaultReadMediaRoot
|
|
||||||
}
|
|
||||||
return &ReadMediaProvider{
|
|
||||||
clients: clients,
|
|
||||||
rootDir: path.Clean(root),
|
|
||||||
maxBytes: defaultReadMediaMaxBytes,
|
|
||||||
logger: log.With(slog.String("tool", "read_media")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ReadMediaProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
|
|
||||||
if p == nil || p.clients == nil || !session.SupportsImageInput {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
root := p.rootDir
|
|
||||||
if root == "" {
|
|
||||||
root = defaultReadMediaRoot
|
|
||||||
}
|
|
||||||
sess := session
|
|
||||||
return []sdk.Tool{
|
|
||||||
{
|
|
||||||
Name: toolReadMedia,
|
|
||||||
Description: fmt.Sprintf("Load an image file from %s into model context so you can inspect it. Relative paths are resolved under %s.", root, root),
|
|
||||||
Parameters: map[string]any{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]any{
|
|
||||||
"path": map[string]any{
|
|
||||||
"type": "string",
|
|
||||||
"description": fmt.Sprintf("Image file path under %s. Absolute paths must stay under %s; relative paths are resolved under %s.", root, root, root),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": []string{"path"},
|
|
||||||
},
|
|
||||||
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
|
||||||
return p.execReadMedia(ctx.Context, sess, inputAsMap(input))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ReadMediaProvider) execReadMedia(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
|
|
||||||
client, err := p.getClient(ctx, session.BotID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return readMediaErrorResult(err.Error()), nil
|
return readMediaErrorResult(err.Error())
|
||||||
}
|
|
||||||
|
|
||||||
resolvedPath, err := p.resolveImagePath(StringArg(args, "path"))
|
|
||||||
if err != nil {
|
|
||||||
return readMediaErrorResult(err.Error()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
reader, err := client.ReadRaw(ctx, resolvedPath)
|
|
||||||
if err != nil {
|
|
||||||
return readMediaErrorResult(err.Error()), nil
|
|
||||||
}
|
}
|
||||||
defer func() { _ = reader.Close() }()
|
defer func() { _ = reader.Close() }()
|
||||||
|
|
||||||
data, err := io.ReadAll(io.LimitReader(reader, p.maxBytes+1))
|
// Read only the sniff header first so non-image binaries fail fast.
|
||||||
if err != nil {
|
header := make([]byte, mimeSniffSize)
|
||||||
return readMediaErrorResult("read_media failed to load image: " + err.Error()), nil
|
n, err := io.ReadAtLeast(reader, header, 1)
|
||||||
|
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
return readMediaErrorResult("failed to load image: " + err.Error())
|
||||||
}
|
}
|
||||||
if int64(len(data)) > p.maxBytes {
|
header = header[:n]
|
||||||
return readMediaErrorResult(fmt.Sprintf("read_media failed to load image: file exceeds %d bytes", p.maxBytes)), nil
|
|
||||||
|
mimeType, err := detectReadMediaMime(header)
|
||||||
|
if err != nil {
|
||||||
|
return readMediaErrorResult(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
mimeType, err := detectReadMediaMime(data)
|
// MIME looks good — read the remainder up to the size limit.
|
||||||
|
rest, err := io.ReadAll(io.LimitReader(reader, maxBytes-int64(n)+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return readMediaErrorResult(err.Error()), nil
|
return readMediaErrorResult("failed to load image: " + err.Error())
|
||||||
|
}
|
||||||
|
data := make([]byte, 0, len(header)+len(rest))
|
||||||
|
data = append(data, header...)
|
||||||
|
data = append(data, rest...)
|
||||||
|
if int64(len(data)) > maxBytes {
|
||||||
|
return readMediaErrorResult(fmt.Sprintf("failed to load image: file exceeds %d bytes", maxBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
encoded := base64.StdEncoding.EncodeToString(data)
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
return ReadMediaToolOutput{
|
return ReadMediaToolOutput{
|
||||||
Public: ReadMediaToolResult{
|
Public: ReadMediaToolResult{
|
||||||
OK: true,
|
OK: true,
|
||||||
Path: resolvedPath,
|
Path: path,
|
||||||
Mime: mimeType,
|
Mime: mimeType,
|
||||||
Size: len(data),
|
Size: len(data),
|
||||||
},
|
},
|
||||||
ImageBase64: encoded,
|
ImageBase64: encoded,
|
||||||
ImageMediaType: mimeType,
|
ImageMediaType: mimeType,
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ReadMediaProvider) getClient(ctx context.Context, botID string) (*bridge.Client, error) {
|
|
||||||
botID = strings.TrimSpace(botID)
|
|
||||||
if botID == "" {
|
|
||||||
return nil, errors.New("bot_id is required")
|
|
||||||
}
|
}
|
||||||
client, err := p.clients.MCPClient(ctx, botID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("container not reachable: %w", err)
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ReadMediaProvider) resolveImagePath(raw string) (string, error) {
|
|
||||||
trimmed := strings.TrimSpace(raw)
|
|
||||||
if trimmed == "" {
|
|
||||||
return "", errors.New("path is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
root := p.rootDir
|
|
||||||
if root == "" {
|
|
||||||
root = defaultReadMediaRoot
|
|
||||||
}
|
|
||||||
root = path.Clean(root)
|
|
||||||
|
|
||||||
resolved := trimmed
|
|
||||||
if !strings.HasPrefix(resolved, "/") {
|
|
||||||
resolved = path.Join(root, resolved)
|
|
||||||
}
|
|
||||||
resolved = path.Clean(resolved)
|
|
||||||
|
|
||||||
if resolved == root || !strings.HasPrefix(resolved, root+"/") {
|
|
||||||
return "", fmt.Errorf("path must be under %s", root)
|
|
||||||
}
|
|
||||||
return resolved, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func readMediaErrorResult(message string) ReadMediaToolOutput {
|
func readMediaErrorResult(message string) ReadMediaToolOutput {
|
||||||
msg := strings.TrimSpace(message)
|
msg := strings.TrimSpace(message)
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
msg = "read_media failed"
|
msg = "read failed"
|
||||||
}
|
}
|
||||||
return ReadMediaToolOutput{
|
return ReadMediaToolOutput{
|
||||||
Public: ReadMediaToolResult{
|
Public: ReadMediaToolResult{
|
||||||
@@ -203,11 +123,11 @@ func detectReadMediaMime(data []byte) (string, error) {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case sniffedMime == "":
|
case sniffedMime == "":
|
||||||
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
|
return "", errors.New("only supports PNG, JPEG, GIF, or WebP image bytes")
|
||||||
case isSupportedReadMediaMime(sniffedMime):
|
case isSupportedReadMediaMime(sniffedMime):
|
||||||
return sniffedMime, nil
|
return sniffedMime, nil
|
||||||
default:
|
default:
|
||||||
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
|
return "", errors.New("only supports PNG, JPEG, GIF, or WebP image bytes")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
@@ -36,15 +35,7 @@ func (s *readMediaTestContainerService) ReadRaw(req *pb.ReadRawRequest, stream p
|
|||||||
return stream.Send(&pb.DataChunk{Data: data})
|
return stream.Send(&pb.DataChunk{Data: data})
|
||||||
}
|
}
|
||||||
|
|
||||||
type readMediaStaticProvider struct {
|
func newReadMediaTestClient(t *testing.T, files map[string][]byte) *bridge.Client {
|
||||||
client *bridge.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *readMediaStaticProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
|
|
||||||
return p.client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
lis := bufconn.Listen(readMediaTestBufSize)
|
lis := bufconn.Listen(readMediaTestBufSize)
|
||||||
@@ -74,85 +65,19 @@ func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Pr
|
|||||||
}
|
}
|
||||||
t.Cleanup(func() { _ = conn.Close() })
|
t.Cleanup(func() { _ = conn.Close() })
|
||||||
|
|
||||||
return &readMediaStaticProvider{client: bridge.NewClientFromConn(conn)}
|
return bridge.NewClientFromConn(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func findToolByName(tools []sdk.Tool, name string) (sdk.Tool, bool) {
|
func TestReadImageFromContainerSuccess(t *testing.T) {
|
||||||
for _, tool := range tools {
|
|
||||||
if tool.Name == name {
|
|
||||||
return tool, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sdk.Tool{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadMediaProviderToolsOnlyWhenImageInputIsSupported(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
|
|
||||||
|
|
||||||
toolsWithoutImage, err := provider.Tools(context.Background(), SessionContext{
|
|
||||||
BotID: "bot-1",
|
|
||||||
SupportsImageInput: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Tools without image input returned error: %v", err)
|
|
||||||
}
|
|
||||||
if len(toolsWithoutImage) != 0 {
|
|
||||||
t.Fatalf("expected no tools without image input support, got %d", len(toolsWithoutImage))
|
|
||||||
}
|
|
||||||
|
|
||||||
toolsWithImage, err := provider.Tools(context.Background(), SessionContext{
|
|
||||||
BotID: "bot-1",
|
|
||||||
SupportsImageInput: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Tools with image input returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tool, ok := findToolByName(toolsWithImage, toolReadMedia)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected %q tool to be exposed", toolReadMedia)
|
|
||||||
}
|
|
||||||
if tool.Execute == nil {
|
|
||||||
t.Fatal("expected read_media tool to be executable")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) {
|
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
client := newReadMediaTestClient(t, map[string][]byte{
|
||||||
"/data/images/demo.png": pngBytes,
|
"/data/images/demo.png": pngBytes,
|
||||||
}), "/data")
|
|
||||||
|
|
||||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
|
||||||
BotID: "bot-1",
|
|
||||||
SupportsImageInput: true,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Tools returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tool, ok := findToolByName(tools, toolReadMedia)
|
result := ReadImageFromContainer(context.Background(), client, "/data/images/demo.png", 0)
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected %q tool", toolReadMedia)
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
|
||||||
Context: context.Background(),
|
|
||||||
ToolCallID: "call-1",
|
|
||||||
ToolName: toolReadMedia,
|
|
||||||
}, map[string]any{"path": "images/demo.png"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Execute returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, ok := output.(readMediaToolOutput)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
|
||||||
}
|
|
||||||
if !result.Public.OK {
|
if !result.Public.OK {
|
||||||
t.Fatalf("expected success result, got %+v", result.Public)
|
t.Fatalf("expected success result, got %+v", result.Public)
|
||||||
}
|
}
|
||||||
@@ -175,81 +100,16 @@ func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadMediaProviderExecuteRejectsPathOutsideData(t *testing.T) {
|
func TestReadImageFromContainerRejectsUnsupportedMime(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
|
svgBytes := []byte(`<svg xmlns="http://www.w3.org/2000/svg"></svg>`)
|
||||||
|
client := newReadMediaTestClient(t, map[string][]byte{
|
||||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
"/data/images/demo.svg": svgBytes,
|
||||||
BotID: "bot-1",
|
|
||||||
SupportsImageInput: true,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Tools returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tool, ok := findToolByName(tools, toolReadMedia)
|
result := ReadImageFromContainer(context.Background(), client, "/data/images/demo.svg", 0)
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected %q tool", toolReadMedia)
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
|
||||||
Context: context.Background(),
|
|
||||||
ToolCallID: "call-2",
|
|
||||||
ToolName: toolReadMedia,
|
|
||||||
}, map[string]any{"path": "/tmp/demo.png"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Execute returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, ok := output.(readMediaToolOutput)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
|
||||||
}
|
|
||||||
if result.Public.OK {
|
|
||||||
t.Fatalf("expected error result, got %+v", result.Public)
|
|
||||||
}
|
|
||||||
if !strings.Contains(result.Public.Error, "path must be under /data") {
|
|
||||||
t.Fatalf("unexpected error: %q", result.Public.Error)
|
|
||||||
}
|
|
||||||
if result.ImageBase64 != "" {
|
|
||||||
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
|
||||||
"/data/images/demo.svg": []byte(`<svg xmlns="http://www.w3.org/2000/svg"></svg>`),
|
|
||||||
}), "/data")
|
|
||||||
|
|
||||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
|
||||||
BotID: "bot-1",
|
|
||||||
SupportsImageInput: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Tools returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tool, ok := findToolByName(tools, toolReadMedia)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected %q tool", toolReadMedia)
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
|
||||||
Context: context.Background(),
|
|
||||||
ToolCallID: "call-3",
|
|
||||||
ToolName: toolReadMedia,
|
|
||||||
}, map[string]any{"path": "images/demo.svg"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Execute returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, ok := output.(readMediaToolOutput)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
|
||||||
}
|
|
||||||
if result.Public.OK {
|
if result.Public.OK {
|
||||||
t.Fatalf("expected error result, got %+v", result.Public)
|
t.Fatalf("expected error result, got %+v", result.Public)
|
||||||
}
|
}
|
||||||
@@ -261,39 +121,15 @@ func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) {
|
func TestReadImageFromContainerRejectsCorruptedBytes(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
client := newReadMediaTestClient(t, map[string][]byte{
|
||||||
"/data/images/demo.png": []byte("definitely not a png"),
|
"/data/images/demo.png": []byte("definitely not a png"),
|
||||||
}), "/data")
|
|
||||||
|
|
||||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
|
||||||
BotID: "bot-1",
|
|
||||||
SupportsImageInput: true,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Tools returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tool, ok := findToolByName(tools, toolReadMedia)
|
result := ReadImageFromContainer(context.Background(), client, "/data/images/demo.png", 0)
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected %q tool", toolReadMedia)
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
|
||||||
Context: context.Background(),
|
|
||||||
ToolCallID: "call-4",
|
|
||||||
ToolName: toolReadMedia,
|
|
||||||
}, map[string]any{"path": "images/demo.png"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Execute returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, ok := output.(readMediaToolOutput)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
|
||||||
}
|
|
||||||
if result.Public.OK {
|
if result.Public.OK {
|
||||||
t.Fatalf("expected error result, got %+v", result.Public)
|
t.Fatalf("expected error result, got %+v", result.Public)
|
||||||
}
|
}
|
||||||
@@ -304,3 +140,18 @@ func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) {
|
|||||||
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadImageFromContainerNotFound(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := newReadMediaTestClient(t, map[string][]byte{})
|
||||||
|
|
||||||
|
result := ReadImageFromContainer(context.Background(), client, "/data/images/missing.png", 0)
|
||||||
|
|
||||||
|
if result.Public.OK {
|
||||||
|
t.Fatalf("expected error result, got %+v", result.Public)
|
||||||
|
}
|
||||||
|
if result.ImageBase64 != "" {
|
||||||
|
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ import (
|
|||||||
agentpkg "github.com/memohai/memoh/internal/agent"
|
agentpkg "github.com/memohai/memoh/internal/agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T) {
|
const imageReadHint = "also supports images: PNG, JPEG, GIF, WebP"
|
||||||
|
|
||||||
|
func TestPrepareRunConfigIncludesImageReadHintWhenImageInputIsSupported(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
resolver := &Resolver{}
|
resolver := &Resolver{}
|
||||||
@@ -21,12 +23,12 @@ func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T
|
|||||||
}
|
}
|
||||||
|
|
||||||
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
||||||
if !strings.Contains(prepared.System, "`read_media`") {
|
if !strings.Contains(prepared.System, imageReadHint) {
|
||||||
t.Fatalf("expected system prompt to include read_media tool, got:\n%s", prepared.System)
|
t.Fatalf("expected system prompt to contain %q, got:\n%s", imageReadHint, prepared.System)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T) {
|
func TestPrepareRunConfigOmitsImageReadHintWhenImageInputIsUnsupported(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
resolver := &Resolver{}
|
resolver := &Resolver{}
|
||||||
@@ -39,7 +41,7 @@ func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T)
|
|||||||
}
|
}
|
||||||
|
|
||||||
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
||||||
if strings.Contains(prepared.System, "`read_media`") {
|
if strings.Contains(prepared.System, imageReadHint) {
|
||||||
t.Fatalf("expected system prompt to omit read_media tool, got:\n%s", prepared.System)
|
t.Fatalf("expected system prompt to NOT contain %q, got:\n%s", imageReadHint, prepared.System)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user