Files
Memoh/internal/agent/read_media_test.go
T
Acbox fc2b603018 fix(agent): skip tools for models without tool-call capability and parse image output
- Add SupportsToolCall to RunConfig; only inject tools into SDK when set
- Update twilight-ai to 497ad09 which adds SSE scanner 10MB buffer
  (fixes token-too-long on large image payloads) and parses the images
  array from OpenAI-compatible chat completions into StreamFilePart
2026-04-03 00:01:14 +08:00

398 lines
12 KiB
Go

package agent
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"net"
"strings"
"testing"
sdk "github.com/memohai/twilight-ai/sdk"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
agenttools "github.com/memohai/memoh/internal/agent/tools"
"github.com/memohai/memoh/internal/workspace/bridge"
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
)
const agentReadMediaTestBufSize = 1 << 20
type agentReadMediaContainerService struct {
pb.UnimplementedContainerServiceServer
files map[string][]byte
}
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
data, ok := s.files[req.GetPath()]
if !ok {
return status.Error(codes.NotFound, "not found")
}
if len(data) == 0 {
return nil
}
return stream.Send(&pb.DataChunk{Data: data})
}
type agentReadMediaBridgeProvider struct {
client *bridge.Client
}
func (p *agentReadMediaBridgeProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
return p.client, nil
}
func newAgentReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
t.Helper()
lis := bufconn.Listen(agentReadMediaTestBufSize)
srv := grpc.NewServer()
pb.RegisterContainerServiceServer(srv, &agentReadMediaContainerService{files: files})
done := make(chan struct{})
go func() {
defer close(done)
_ = srv.Serve(lis)
}()
t.Cleanup(func() {
srv.Stop()
<-done
})
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}
conn, err := grpc.NewClient(
"passthrough://bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
t.Fatalf("grpc.NewClient: %v", err)
}
t.Cleanup(func() { _ = conn.Close() })
return &agentReadMediaBridgeProvider{client: bridge.NewClientFromConn(conn)}
}
type agentReadMediaMockProvider struct {
name string
calls int
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
}
func (m *agentReadMediaMockProvider) Name() string {
if m.name != "" {
return m.name
}
return "mock"
}
func (*agentReadMediaMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
return nil, nil
}
func (*agentReadMediaMockProvider) Test(context.Context) *sdk.ProviderTestResult {
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
}
func (*agentReadMediaMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
}
func (m *agentReadMediaMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
m.calls++
return m.handler(m.calls, params)
}
func (m *agentReadMediaMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
result, err := m.DoGenerate(ctx, params)
if err != nil {
return nil, err
}
ch := make(chan sdk.StreamPart, 8)
go func() {
defer close(ch)
ch <- &sdk.StartPart{}
ch <- &sdk.StartStepPart{}
if result.Text != "" {
ch <- &sdk.TextStartPart{ID: "mock"}
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
ch <- &sdk.TextEndPart{ID: "mock"}
}
for _, tc := range result.ToolCalls {
ch <- &sdk.StreamToolCallPart{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Input: tc.Input,
}
}
ch <- &sdk.FinishStepPart{FinishReason: result.FinishReason, Usage: result.Usage, Response: result.Response}
ch <- &sdk.FinishPart{FinishReason: result.FinishReason, TotalUsage: result.Usage}
}()
return &sdk.StreamResult{Stream: ch}, nil
}
func assertInjectedReadMediaMessage(t *testing.T, msg sdk.Message, expectedImage, expectedMediaType string) {
t.Helper()
if msg.Role != sdk.MessageRoleUser {
t.Fatalf("expected injected read_media message role %q, got %q", sdk.MessageRoleUser, msg.Role)
}
if len(msg.Content) != 1 {
t.Fatalf("expected one injected content part, got %d", len(msg.Content))
}
image, ok := msg.Content[0].(sdk.ImagePart)
if !ok {
t.Fatalf("expected sdk.ImagePart, got %T", msg.Content[0])
}
if image.Image != expectedImage {
t.Fatalf("unexpected injected image payload: %q", image.Image)
}
if image.MediaType != expectedMediaType {
t.Fatalf("unexpected injected media type: %q", image.MediaType)
}
}
func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
modelProvider := &agentReadMediaMockProvider{
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call == 1 {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1",
ToolName: "read_media",
Input: map[string]any{"path": "/data/images/demo.png"},
}},
}, nil
}
if len(params.Messages) < 4 {
t.Fatalf("expected prior tool and injected messages, got %d", len(params.Messages))
}
last := params.Messages[len(params.Messages)-1]
if last.Role != sdk.MessageRoleUser {
t.Fatalf("expected last message to be injected user image, got %s", last.Role)
}
if len(last.Content) != 1 {
t.Fatalf("expected one injected content part, got %d", len(last.Content))
}
image, ok := last.Content[0].(sdk.ImagePart)
if !ok {
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
}
if image.Image != expectedDataURL {
t.Fatalf("unexpected injected image payload: %q", image.Image)
}
if image.MediaType != "image/png" {
t.Fatalf("unexpected injected media type: %q", image.MediaType)
}
var toolResult sdk.ToolResultPart
foundToolMessage := false
for _, msg := range params.Messages {
if msg.Role != sdk.MessageRoleTool || len(msg.Content) == 0 {
continue
}
part, ok := msg.Content[0].(sdk.ToolResultPart)
if !ok {
continue
}
toolResult = part
foundToolMessage = true
break
}
if !foundToolMessage {
t.Fatal("expected tool result message before second step")
}
raw, err := json.Marshal(toolResult.Result)
if err != nil {
t.Fatalf("marshal tool result: %v", err)
}
if !bytes.Contains(raw, []byte(`"ok":true`)) {
t.Fatalf("expected compact success metadata, got %s", raw)
}
if bytes.Contains(raw, []byte(expectedDataURL)) || bytes.Contains(raw, []byte("payload")) {
t.Fatalf("tool result leaked image bytes: %s", raw)
}
return &sdk.GenerateResult{
Text: "done",
FinishReason: sdk.FinishReasonStop,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": pngBytes,
}), "/data"),
})
result, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
SupportsImageInput: true,
SupportsToolCall: true,
Identity: SessionContext{
BotID: "bot-1",
},
})
if err != nil {
t.Fatalf("Generate returned error: %v", err)
}
if result.Text != "done" {
t.Fatalf("unexpected result text: %q", result.Text)
}
if len(result.Messages) != 4 {
t.Fatalf("expected persisted step + injected history, got %d messages", len(result.Messages))
}
assertInjectedReadMediaMessage(t, result.Messages[2], expectedDataURL, "image/png")
if result.Messages[3].Role != sdk.MessageRoleAssistant {
t.Fatalf("expected final persisted message to be assistant, got %s", result.Messages[3].Role)
}
if modelProvider.calls != 2 {
t.Fatalf("expected 2 model calls, got %d", modelProvider.calls)
}
}
func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.T) {
t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
modelProvider := &agentReadMediaMockProvider{
name: "anthropic-messages",
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call == 1 {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1",
ToolName: "read_media",
Input: map[string]any{"path": "/data/images/demo.png"},
}},
}, nil
}
last := params.Messages[len(params.Messages)-1]
image, ok := last.Content[0].(sdk.ImagePart)
if !ok {
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
}
if image.Image != expectedBase64 {
t.Fatalf("expected raw base64 for anthropic, got %q", image.Image)
}
if image.MediaType != "image/png" {
t.Fatalf("unexpected injected media type: %q", image.MediaType)
}
if strings.HasPrefix(image.Image, "data:") {
t.Fatalf("anthropic image payload must not be a data URL: %q", image.Image)
}
return &sdk.GenerateResult{
Text: "done",
FinishReason: sdk.FinishReasonStop,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": pngBytes,
}), "/data"),
})
_, err := a.Generate(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
SupportsImageInput: true,
SupportsToolCall: true,
Identity: SessionContext{
BotID: "bot-1",
},
})
if err != nil {
t.Fatalf("Generate returned error: %v", err)
}
}
func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.T) {
t.Parallel()
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
modelProvider := &agentReadMediaMockProvider{
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
if call == 1 {
return &sdk.GenerateResult{
FinishReason: sdk.FinishReasonToolCalls,
ToolCalls: []sdk.ToolCall{{
ToolCallID: "call-1",
ToolName: "read_media",
Input: map[string]any{"path": "/data/images/demo.png"},
}},
}, nil
}
return &sdk.GenerateResult{
Text: "done",
FinishReason: sdk.FinishReasonStop,
}, nil
},
}
a := New(Deps{})
a.SetToolProviders([]agenttools.ToolProvider{
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
"/data/images/demo.png": pngBytes,
}), "/data"),
})
var terminal StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
SupportsImageInput: true,
SupportsToolCall: true,
Identity: SessionContext{
BotID: "bot-1",
},
}) {
if event.IsTerminal() {
terminal = event
}
}
if terminal.Type != EventAgentEnd {
t.Fatalf("expected terminal event %q, got %q", EventAgentEnd, terminal.Type)
}
var messages []sdk.Message
if err := json.Unmarshal(terminal.Messages, &messages); err != nil {
t.Fatalf("unmarshal terminal messages: %v", err)
}
if len(messages) != 4 {
t.Fatalf("expected persisted step + injected history, got %d messages", len(messages))
}
assertInjectedReadMediaMessage(t, messages[2], expectedDataURL, "image/png")
if messages[3].Role != sdk.MessageRoleAssistant {
t.Fatalf("expected final persisted message to be assistant, got %s", messages[3].Role)
}
}