mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
427 lines
13 KiB
Go
427 lines
13 KiB
Go
package agent
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
|
|
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
|
"github.com/memohai/memoh/internal/workspace/bridge"
|
|
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
|
)
|
|
|
|
const agentReadMediaTestBufSize = 1 << 20
|
|
|
|
// agentReadMediaContainerService implements both ReadFile and ReadRaw so
|
|
// that the merged read tool (ContainerProvider) can detect binary files
|
|
// and then delegate to ReadImageFromContainer.
|
|
type agentReadMediaContainerService struct {
|
|
pb.UnimplementedContainerServiceServer
|
|
files map[string][]byte
|
|
}
|
|
|
|
func (s *agentReadMediaContainerService) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) {
|
|
data, ok := s.files[req.GetPath()]
|
|
if !ok {
|
|
return nil, status.Error(codes.NotFound, "not found")
|
|
}
|
|
_ = data
|
|
// All files in this test fixture are images → binary.
|
|
return &pb.ReadFileResponse{Binary: true}, nil
|
|
}
|
|
|
|
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
|
data, ok := s.files[req.GetPath()]
|
|
if !ok {
|
|
return status.Error(codes.NotFound, "not found")
|
|
}
|
|
if len(data) == 0 {
|
|
return nil
|
|
}
|
|
return stream.Send(&pb.DataChunk{Data: data})
|
|
}
|
|
|
|
type agentReadMediaBridgeProvider struct {
|
|
client *bridge.Client
|
|
}
|
|
|
|
func (p *agentReadMediaBridgeProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
|
|
return p.client, nil
|
|
}
|
|
|
|
func newAgentReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
|
|
t.Helper()
|
|
|
|
lis := bufconn.Listen(agentReadMediaTestBufSize)
|
|
srv := grpc.NewServer()
|
|
pb.RegisterContainerServiceServer(srv, &agentReadMediaContainerService{files: files})
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
_ = srv.Serve(lis)
|
|
}()
|
|
t.Cleanup(func() {
|
|
srv.Stop()
|
|
<-done
|
|
})
|
|
|
|
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
|
|
return lis.DialContext(ctx)
|
|
}
|
|
conn, err := grpc.NewClient(
|
|
"passthrough://bufnet",
|
|
grpc.WithContextDialer(dialer),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("grpc.NewClient: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = conn.Close() })
|
|
|
|
return &agentReadMediaBridgeProvider{client: bridge.NewClientFromConn(conn)}
|
|
}
|
|
|
|
type agentReadMediaMockProvider struct {
|
|
name string
|
|
calls int
|
|
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
|
|
}
|
|
|
|
func (m *agentReadMediaMockProvider) Name() string {
|
|
if m.name != "" {
|
|
return m.name
|
|
}
|
|
return "mock"
|
|
}
|
|
|
|
func (*agentReadMediaMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (*agentReadMediaMockProvider) Test(context.Context) *sdk.ProviderTestResult {
|
|
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
|
|
}
|
|
|
|
func (*agentReadMediaMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
|
|
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
|
|
}
|
|
|
|
func (m *agentReadMediaMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
|
m.calls++
|
|
return m.handler(m.calls, params)
|
|
}
|
|
|
|
func (m *agentReadMediaMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
|
|
result, err := m.DoGenerate(ctx, params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ch := make(chan sdk.StreamPart, 8)
|
|
go func() {
|
|
defer close(ch)
|
|
ch <- &sdk.StartPart{}
|
|
ch <- &sdk.StartStepPart{}
|
|
if result.Text != "" {
|
|
ch <- &sdk.TextStartPart{ID: "mock"}
|
|
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
|
|
ch <- &sdk.TextEndPart{ID: "mock"}
|
|
}
|
|
for _, tc := range result.ToolCalls {
|
|
ch <- &sdk.StreamToolCallPart{
|
|
ToolCallID: tc.ToolCallID,
|
|
ToolName: tc.ToolName,
|
|
Input: tc.Input,
|
|
}
|
|
}
|
|
ch <- &sdk.FinishStepPart{FinishReason: result.FinishReason, Usage: result.Usage, Response: result.Response}
|
|
ch <- &sdk.FinishPart{FinishReason: result.FinishReason, TotalUsage: result.Usage}
|
|
}()
|
|
return &sdk.StreamResult{Stream: ch}, nil
|
|
}
|
|
|
|
func assertInjectedReadMediaMessage(t *testing.T, msg sdk.Message, expectedImage, expectedMediaType string) {
|
|
t.Helper()
|
|
|
|
if msg.Role != sdk.MessageRoleUser {
|
|
t.Fatalf("expected injected read_media message role %q, got %q", sdk.MessageRoleUser, msg.Role)
|
|
}
|
|
if len(msg.Content) != 1 {
|
|
t.Fatalf("expected one injected content part, got %d", len(msg.Content))
|
|
}
|
|
image, ok := msg.Content[0].(sdk.ImagePart)
|
|
if !ok {
|
|
t.Fatalf("expected sdk.ImagePart, got %T", msg.Content[0])
|
|
}
|
|
if image.Image != expectedImage {
|
|
t.Fatalf("unexpected injected image payload: %q", image.Image)
|
|
}
|
|
if image.MediaType != expectedMediaType {
|
|
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
|
}
|
|
}
|
|
|
|
func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// The PNG data must contain a null byte (\x00) so that the execRead
|
|
// binary probe (bytes.IndexByte(probe, 0)) detects it as binary and
|
|
// delegates to ReadImageFromContainer. Real PNG files always contain
|
|
// null bytes in their IHDR and other chunks.
|
|
pngBytes := []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00payload")
|
|
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
|
|
|
|
modelProvider := &agentReadMediaMockProvider{
|
|
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
|
if call == 1 {
|
|
return &sdk.GenerateResult{
|
|
FinishReason: sdk.FinishReasonToolCalls,
|
|
ToolCalls: []sdk.ToolCall{{
|
|
ToolCallID: "call-1",
|
|
ToolName: "read",
|
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
if len(params.Messages) < 4 {
|
|
t.Fatalf("expected prior tool and injected messages, got %d", len(params.Messages))
|
|
}
|
|
|
|
last := params.Messages[len(params.Messages)-1]
|
|
if last.Role != sdk.MessageRoleUser {
|
|
t.Fatalf("expected last message to be injected user image, got %s", last.Role)
|
|
}
|
|
if len(last.Content) != 1 {
|
|
t.Fatalf("expected one injected content part, got %d", len(last.Content))
|
|
}
|
|
image, ok := last.Content[0].(sdk.ImagePart)
|
|
if !ok {
|
|
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
|
|
}
|
|
if image.Image != expectedDataURL {
|
|
t.Fatalf("unexpected injected image payload: %q", image.Image)
|
|
}
|
|
if image.MediaType != "image/png" {
|
|
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
|
}
|
|
|
|
var toolResult sdk.ToolResultPart
|
|
foundToolMessage := false
|
|
for _, msg := range params.Messages {
|
|
if msg.Role != sdk.MessageRoleTool || len(msg.Content) == 0 {
|
|
continue
|
|
}
|
|
part, ok := msg.Content[0].(sdk.ToolResultPart)
|
|
if !ok {
|
|
continue
|
|
}
|
|
toolResult = part
|
|
foundToolMessage = true
|
|
break
|
|
}
|
|
if !foundToolMessage {
|
|
t.Fatal("expected tool result message before second step")
|
|
}
|
|
raw, err := json.Marshal(toolResult.Result)
|
|
if err != nil {
|
|
t.Fatalf("marshal tool result: %v", err)
|
|
}
|
|
if !bytes.Contains(raw, []byte(`"ok":true`)) {
|
|
t.Fatalf("expected compact success metadata, got %s", raw)
|
|
}
|
|
if bytes.Contains(raw, []byte(expectedDataURL)) || bytes.Contains(raw, []byte("payload")) {
|
|
t.Fatalf("tool result leaked image bytes: %s", raw)
|
|
}
|
|
|
|
return &sdk.GenerateResult{
|
|
Text: "done",
|
|
FinishReason: sdk.FinishReasonStop,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
// ContainerProvider normalizes paths by stripping the workdir prefix,
|
|
// so the mock files map must use the normalized (relative) path.
|
|
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
|
"images/demo.png": pngBytes,
|
|
})
|
|
|
|
a := New(Deps{})
|
|
a.SetToolProviders([]agenttools.ToolProvider{
|
|
agenttools.NewContainerProvider(nil, bp, nil, "/data"),
|
|
})
|
|
|
|
result, err := a.Generate(context.Background(), RunConfig{
|
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
|
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
|
SupportsImageInput: true,
|
|
SupportsToolCall: true,
|
|
Identity: SessionContext{
|
|
BotID: "bot-1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate returned error: %v", err)
|
|
}
|
|
if result.Text != "done" {
|
|
t.Fatalf("unexpected result text: %q", result.Text)
|
|
}
|
|
if len(result.Messages) != 4 {
|
|
t.Fatalf("expected persisted step + injected history, got %d messages", len(result.Messages))
|
|
}
|
|
assertInjectedReadMediaMessage(t, result.Messages[2], expectedDataURL, "image/png")
|
|
if result.Messages[3].Role != sdk.MessageRoleAssistant {
|
|
t.Fatalf("expected final persisted message to be assistant, got %s", result.Messages[3].Role)
|
|
}
|
|
if modelProvider.calls != 2 {
|
|
t.Fatalf("expected 2 model calls, got %d", modelProvider.calls)
|
|
}
|
|
}
|
|
|
|
func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
pngBytes := []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00payload")
|
|
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
|
|
|
|
modelProvider := &agentReadMediaMockProvider{
|
|
name: "anthropic-messages",
|
|
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
|
if call == 1 {
|
|
return &sdk.GenerateResult{
|
|
FinishReason: sdk.FinishReasonToolCalls,
|
|
ToolCalls: []sdk.ToolCall{{
|
|
ToolCallID: "call-1",
|
|
ToolName: "read",
|
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
last := params.Messages[len(params.Messages)-1]
|
|
image, ok := last.Content[0].(sdk.ImagePart)
|
|
if !ok {
|
|
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
|
|
}
|
|
if image.Image != expectedBase64 {
|
|
t.Fatalf("expected raw base64 for anthropic, got %q", image.Image)
|
|
}
|
|
if image.MediaType != "image/png" {
|
|
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
|
}
|
|
if strings.HasPrefix(image.Image, "data:") {
|
|
t.Fatalf("anthropic image payload must not be a data URL: %q", image.Image)
|
|
}
|
|
|
|
return &sdk.GenerateResult{
|
|
Text: "done",
|
|
FinishReason: sdk.FinishReasonStop,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
// ContainerProvider normalizes paths by stripping the workdir prefix,
|
|
// so the mock files map must use the normalized (relative) path.
|
|
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
|
"images/demo.png": pngBytes,
|
|
})
|
|
|
|
a := New(Deps{})
|
|
a.SetToolProviders([]agenttools.ToolProvider{
|
|
agenttools.NewContainerProvider(nil, bp, nil, "/data"),
|
|
})
|
|
|
|
_, err := a.Generate(context.Background(), RunConfig{
|
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
|
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
|
SupportsImageInput: true,
|
|
SupportsToolCall: true,
|
|
Identity: SessionContext{
|
|
BotID: "bot-1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate returned error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
pngBytes := []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00payload")
|
|
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
|
|
|
|
modelProvider := &agentReadMediaMockProvider{
|
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
|
if call == 1 {
|
|
return &sdk.GenerateResult{
|
|
FinishReason: sdk.FinishReasonToolCalls,
|
|
ToolCalls: []sdk.ToolCall{{
|
|
ToolCallID: "call-1",
|
|
ToolName: "read",
|
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
|
}},
|
|
}, nil
|
|
}
|
|
return &sdk.GenerateResult{
|
|
Text: "done",
|
|
FinishReason: sdk.FinishReasonStop,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
// ContainerProvider normalizes paths by stripping the workdir prefix,
|
|
// so the mock files map must use the normalized (relative) path.
|
|
bp := newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
|
"images/demo.png": pngBytes,
|
|
})
|
|
|
|
a := New(Deps{})
|
|
a.SetToolProviders([]agenttools.ToolProvider{
|
|
agenttools.NewContainerProvider(nil, bp, nil, "/data"),
|
|
})
|
|
|
|
var terminal StreamEvent
|
|
for event := range a.Stream(context.Background(), RunConfig{
|
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
|
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
|
SupportsImageInput: true,
|
|
SupportsToolCall: true,
|
|
Identity: SessionContext{
|
|
BotID: "bot-1",
|
|
},
|
|
}) {
|
|
if event.IsTerminal() {
|
|
terminal = event
|
|
}
|
|
}
|
|
|
|
if terminal.Type != EventAgentEnd {
|
|
t.Fatalf("expected terminal event %q, got %q", EventAgentEnd, terminal.Type)
|
|
}
|
|
|
|
var messages []sdk.Message
|
|
if err := json.Unmarshal(terminal.Messages, &messages); err != nil {
|
|
t.Fatalf("unmarshal terminal messages: %v", err)
|
|
}
|
|
if len(messages) != 4 {
|
|
t.Fatalf("expected persisted step + injected history, got %d messages", len(messages))
|
|
}
|
|
assertInjectedReadMediaMessage(t, messages[2], expectedDataURL, "image/png")
|
|
if messages[3].Role != sdk.MessageRoleAssistant {
|
|
t.Fatalf("expected final persisted message to be assistant, got %s", messages[3].Role)
|
|
}
|
|
}
|