mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(agent): restore read_media in pure Go (#257)
This commit is contained in:
@@ -572,6 +572,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
|
||||
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
||||
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
||||
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
||||
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
|
||||
agenttools.NewInboxProvider(log, inboxService),
|
||||
agenttools.NewEmailProvider(log, emailService, emailManager),
|
||||
agenttools.NewWebFetchProvider(log),
|
||||
|
||||
@@ -434,6 +434,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
|
||||
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
||||
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
||||
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
||||
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
|
||||
agenttools.NewInboxProvider(log, inboxService),
|
||||
agenttools.NewEmailProvider(log, emailService, emailManager),
|
||||
agenttools.NewWebFetchProvider(log),
|
||||
|
||||
+34
-12
@@ -62,6 +62,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}
|
||||
return
|
||||
}
|
||||
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
|
||||
|
||||
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
||||
copy(enabledSkills, cfg.EnabledSkillNames)
|
||||
@@ -103,7 +104,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
tagResolvers := DefaultTagResolvers()
|
||||
tagExtractor := NewStreamTagExtractor(tagResolvers)
|
||||
|
||||
opts := a.buildGenerateOptions(cfg, tools)
|
||||
var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
|
||||
if readMediaState != nil {
|
||||
prepareStep = readMediaState.prepareStep
|
||||
}
|
||||
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
|
||||
|
||||
streamResult, err := a.client.StreamText(ctx, opts...)
|
||||
if err != nil {
|
||||
@@ -251,7 +256,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
textLoopProbeBuffer.Flush()
|
||||
}
|
||||
|
||||
finalMessages := StripTagsFromMessages(streamResult.Messages)
|
||||
finalMessages := streamResult.Messages
|
||||
if readMediaState != nil {
|
||||
finalMessages = readMediaState.mergeMessages(streamResult.Steps, finalMessages)
|
||||
}
|
||||
finalMessages = StripTagsFromMessages(finalMessages)
|
||||
|
||||
var totalUsage sdk.Usage
|
||||
perStepUsages := make([]json.RawMessage, 0, len(streamResult.Steps))
|
||||
@@ -286,6 +295,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("assemble tools: %w", err)
|
||||
}
|
||||
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
|
||||
|
||||
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
||||
copy(enabledSkills, cfg.EnabledSkillNames)
|
||||
@@ -315,7 +325,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs)
|
||||
}
|
||||
|
||||
opts := a.buildGenerateOptions(cfg, tools)
|
||||
var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
|
||||
if readMediaState != nil {
|
||||
prepareStep = readMediaState.prepareStep
|
||||
}
|
||||
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
|
||||
opts = append(opts,
|
||||
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
||||
if cfg.LoopDetection.Enabled {
|
||||
@@ -376,7 +390,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
}
|
||||
}
|
||||
|
||||
finalMessages := StripTagsFromMessages(genResult.Messages)
|
||||
finalMessages := genResult.Messages
|
||||
if readMediaState != nil {
|
||||
finalMessages = readMediaState.mergeMessages(genResult.Steps, finalMessages)
|
||||
}
|
||||
finalMessages = StripTagsFromMessages(finalMessages)
|
||||
|
||||
return &GenerateResult{
|
||||
Messages: finalMessages,
|
||||
@@ -389,7 +407,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.GenerateOption {
|
||||
func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams) []sdk.GenerateOption {
|
||||
opts := []sdk.GenerateOption{
|
||||
sdk.WithModel(cfg.Model),
|
||||
sdk.WithMessages(cfg.Messages),
|
||||
@@ -399,6 +417,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.Genera
|
||||
if len(tools) > 0 {
|
||||
opts = append(opts, sdk.WithTools(tools))
|
||||
}
|
||||
if prepareStep != nil {
|
||||
opts = append(opts, sdk.WithPrepareStep(prepareStep))
|
||||
}
|
||||
opts = append(opts, BuildReasoningOptions(ModelConfig{
|
||||
ClientType: resolveClientType(cfg.Model),
|
||||
ReasoningConfig: &ReasoningConfig{
|
||||
@@ -432,13 +453,14 @@ func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, e
|
||||
return nil, nil
|
||||
}
|
||||
session := tools.SessionContext{
|
||||
BotID: cfg.Identity.BotID,
|
||||
ChatID: cfg.Identity.ChatID,
|
||||
ChannelIdentityID: cfg.Identity.ChannelIdentityID,
|
||||
SessionToken: cfg.Identity.SessionToken,
|
||||
CurrentPlatform: cfg.Identity.CurrentPlatform,
|
||||
ReplyTarget: cfg.Identity.ReplyTarget,
|
||||
IsSubagent: cfg.Identity.IsSubagent,
|
||||
BotID: cfg.Identity.BotID,
|
||||
ChatID: cfg.Identity.ChatID,
|
||||
ChannelIdentityID: cfg.Identity.ChannelIdentityID,
|
||||
SessionToken: cfg.Identity.SessionToken,
|
||||
CurrentPlatform: cfg.Identity.CurrentPlatform,
|
||||
ReplyTarget: cfg.Identity.ReplyTarget,
|
||||
SupportsImageInput: cfg.SupportsImageInput,
|
||||
IsSubagent: cfg.Identity.IsSubagent,
|
||||
}
|
||||
|
||||
var allTools []sdk.Tool
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
||||
)
|
||||
|
||||
func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
|
||||
if len(tools) == 0 {
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
clientType := resolveClientType(model)
|
||||
state := &readMediaDecorationState{
|
||||
pendingImages: make(map[string]sdk.ImagePart),
|
||||
}
|
||||
wrapped := make([]sdk.Tool, 0, len(tools))
|
||||
found := false
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Name != agenttools.ReadMediaToolName || tool.Execute == nil {
|
||||
wrapped = append(wrapped, tool)
|
||||
continue
|
||||
}
|
||||
|
||||
found = true
|
||||
originalExecute := tool.Execute
|
||||
toolCopy := tool
|
||||
toolCopy.Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||
output, err := originalExecute(ctx, input)
|
||||
if err != nil {
|
||||
return output, err
|
||||
}
|
||||
|
||||
publicResult, image, ok := normalizeReadMediaOutput(output, clientType)
|
||||
if !ok {
|
||||
return output, nil
|
||||
}
|
||||
if ctx != nil && strings.TrimSpace(ctx.ToolCallID) != "" && strings.TrimSpace(image.Image) != "" {
|
||||
if _, exists := state.pendingImages[ctx.ToolCallID]; !exists {
|
||||
state.pendingOrder = append(state.pendingOrder, ctx.ToolCallID)
|
||||
}
|
||||
state.pendingImages[ctx.ToolCallID] = image
|
||||
}
|
||||
return publicResult, nil
|
||||
}
|
||||
wrapped = append(wrapped, toolCopy)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
return wrapped, state
|
||||
}
|
||||
|
||||
type readMediaDecorationState struct {
|
||||
pendingOrder []string
|
||||
pendingImages map[string]sdk.ImagePart
|
||||
prepareCalls int
|
||||
injections []readMediaInjection
|
||||
}
|
||||
|
||||
type readMediaInjection struct {
|
||||
afterStep int
|
||||
message sdk.Message
|
||||
}
|
||||
|
||||
func (s *readMediaDecorationState) prepareStep(params *sdk.GenerateParams) *sdk.GenerateParams {
|
||||
if s == nil || params == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
afterStep := s.prepareCalls
|
||||
s.prepareCalls++
|
||||
|
||||
if len(s.pendingOrder) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := make([]sdk.MessagePart, 0, len(s.pendingOrder))
|
||||
for _, toolCallID := range s.pendingOrder {
|
||||
image, ok := s.pendingImages[toolCallID]
|
||||
delete(s.pendingImages, toolCallID)
|
||||
if !ok || strings.TrimSpace(image.Image) == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, image)
|
||||
}
|
||||
s.pendingOrder = s.pendingOrder[:0]
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
message := sdk.Message{
|
||||
Role: sdk.MessageRoleUser,
|
||||
Content: parts,
|
||||
}
|
||||
s.injections = append(s.injections, readMediaInjection{
|
||||
afterStep: afterStep,
|
||||
message: message,
|
||||
})
|
||||
|
||||
next := *params
|
||||
next.Messages = append(append([]sdk.Message(nil), params.Messages...), message)
|
||||
return &next
|
||||
}
|
||||
|
||||
func (s *readMediaDecorationState) mergeMessages(steps []sdk.StepResult, fallback []sdk.Message) []sdk.Message {
|
||||
if s == nil || len(s.injections) == 0 {
|
||||
return fallback
|
||||
}
|
||||
if len(steps) == 0 {
|
||||
merged := append([]sdk.Message(nil), fallback...)
|
||||
for _, injection := range s.injections {
|
||||
merged = append(merged, injection.message)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
merged := make([]sdk.Message, 0, len(fallback)+len(s.injections))
|
||||
injectionIndex := 0
|
||||
for stepIndex, step := range steps {
|
||||
merged = append(merged, step.Messages...)
|
||||
for injectionIndex < len(s.injections) && s.injections[injectionIndex].afterStep == stepIndex {
|
||||
merged = append(merged, s.injections[injectionIndex].message)
|
||||
injectionIndex++
|
||||
}
|
||||
}
|
||||
for injectionIndex < len(s.injections) {
|
||||
merged = append(merged, s.injections[injectionIndex].message)
|
||||
injectionIndex++
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func normalizeReadMediaOutput(output any, clientType string) (any, sdk.ImagePart, bool) {
|
||||
switch value := output.(type) {
|
||||
case agenttools.ReadMediaToolOutput:
|
||||
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
|
||||
case *agenttools.ReadMediaToolOutput:
|
||||
if value == nil {
|
||||
return nil, sdk.ImagePart{}, false
|
||||
}
|
||||
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
|
||||
default:
|
||||
return nil, sdk.ImagePart{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.ImagePart {
|
||||
imageBase64 = strings.TrimSpace(imageBase64)
|
||||
mediaType = strings.TrimSpace(mediaType)
|
||||
if imageBase64 == "" {
|
||||
return sdk.ImagePart{}
|
||||
}
|
||||
if mediaType == "" {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
|
||||
image := imageBase64
|
||||
if clientType != ClientTypeAnthropicMessages {
|
||||
image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
|
||||
}
|
||||
return sdk.ImagePart{
|
||||
Image: image,
|
||||
MediaType: mediaType,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,394 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||
)
|
||||
|
||||
// agentReadMediaTestBufSize is the size handed to bufconn.Listen for the
// in-memory gRPC listener used by these tests (1 MiB).
const agentReadMediaTestBufSize = 1024 * 1024
|
||||
|
||||
type agentReadMediaContainerService struct {
|
||||
pb.UnimplementedContainerServiceServer
|
||||
files map[string][]byte
|
||||
}
|
||||
|
||||
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
||||
data, ok := s.files[req.GetPath()]
|
||||
if !ok {
|
||||
return status.Error(codes.NotFound, "not found")
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return stream.Send(&pb.DataChunk{Data: data})
|
||||
}
|
||||
|
||||
type agentReadMediaBridgeProvider struct {
|
||||
client *bridge.Client
|
||||
}
|
||||
|
||||
func (p *agentReadMediaBridgeProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
|
||||
return p.client, nil
|
||||
}
|
||||
|
||||
func newAgentReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
|
||||
t.Helper()
|
||||
|
||||
lis := bufconn.Listen(agentReadMediaTestBufSize)
|
||||
srv := grpc.NewServer()
|
||||
pb.RegisterContainerServiceServer(srv, &agentReadMediaContainerService{files: files})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_ = srv.Serve(lis)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
srv.Stop()
|
||||
<-done
|
||||
})
|
||||
|
||||
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
|
||||
return lis.DialContext(ctx)
|
||||
}
|
||||
conn, err := grpc.NewClient(
|
||||
"passthrough://bufnet",
|
||||
grpc.WithContextDialer(dialer),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.NewClient: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
return &agentReadMediaBridgeProvider{client: bridge.NewClientFromConn(conn)}
|
||||
}
|
||||
|
||||
type agentReadMediaMockProvider struct {
|
||||
name string
|
||||
calls int
|
||||
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
|
||||
}
|
||||
|
||||
func (m *agentReadMediaMockProvider) Name() string {
|
||||
if m.name != "" {
|
||||
return m.name
|
||||
}
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (*agentReadMediaMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*agentReadMediaMockProvider) Test(context.Context) *sdk.ProviderTestResult {
|
||||
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
|
||||
}
|
||||
|
||||
func (*agentReadMediaMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
|
||||
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
|
||||
}
|
||||
|
||||
func (m *agentReadMediaMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
m.calls++
|
||||
return m.handler(m.calls, params)
|
||||
}
|
||||
|
||||
func (m *agentReadMediaMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||
result, err := m.DoGenerate(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch := make(chan sdk.StreamPart, 8)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- &sdk.StartPart{}
|
||||
ch <- &sdk.StartStepPart{}
|
||||
if result.Text != "" {
|
||||
ch <- &sdk.TextStartPart{ID: "mock"}
|
||||
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
|
||||
ch <- &sdk.TextEndPart{ID: "mock"}
|
||||
}
|
||||
for _, tc := range result.ToolCalls {
|
||||
ch <- &sdk.StreamToolCallPart{
|
||||
ToolCallID: tc.ToolCallID,
|
||||
ToolName: tc.ToolName,
|
||||
Input: tc.Input,
|
||||
}
|
||||
}
|
||||
ch <- &sdk.FinishStepPart{FinishReason: result.FinishReason, Usage: result.Usage, Response: result.Response}
|
||||
ch <- &sdk.FinishPart{FinishReason: result.FinishReason, TotalUsage: result.Usage}
|
||||
}()
|
||||
return &sdk.StreamResult{Stream: ch}, nil
|
||||
}
|
||||
|
||||
func assertInjectedReadMediaMessage(t *testing.T, msg sdk.Message, expectedImage, expectedMediaType string) {
|
||||
t.Helper()
|
||||
|
||||
if msg.Role != sdk.MessageRoleUser {
|
||||
t.Fatalf("expected injected read_media message role %q, got %q", sdk.MessageRoleUser, msg.Role)
|
||||
}
|
||||
if len(msg.Content) != 1 {
|
||||
t.Fatalf("expected one injected content part, got %d", len(msg.Content))
|
||||
}
|
||||
image, ok := msg.Content[0].(sdk.ImagePart)
|
||||
if !ok {
|
||||
t.Fatalf("expected sdk.ImagePart, got %T", msg.Content[0])
|
||||
}
|
||||
if image.Image != expectedImage {
|
||||
t.Fatalf("unexpected injected image payload: %q", image.Image)
|
||||
}
|
||||
if image.MediaType != expectedMediaType {
|
||||
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
|
||||
|
||||
modelProvider := &agentReadMediaMockProvider{
|
||||
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
if call == 1 {
|
||||
return &sdk.GenerateResult{
|
||||
FinishReason: sdk.FinishReasonToolCalls,
|
||||
ToolCalls: []sdk.ToolCall{{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "read_media",
|
||||
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(params.Messages) < 4 {
|
||||
t.Fatalf("expected prior tool and injected messages, got %d", len(params.Messages))
|
||||
}
|
||||
|
||||
last := params.Messages[len(params.Messages)-1]
|
||||
if last.Role != sdk.MessageRoleUser {
|
||||
t.Fatalf("expected last message to be injected user image, got %s", last.Role)
|
||||
}
|
||||
if len(last.Content) != 1 {
|
||||
t.Fatalf("expected one injected content part, got %d", len(last.Content))
|
||||
}
|
||||
image, ok := last.Content[0].(sdk.ImagePart)
|
||||
if !ok {
|
||||
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
|
||||
}
|
||||
if image.Image != expectedDataURL {
|
||||
t.Fatalf("unexpected injected image payload: %q", image.Image)
|
||||
}
|
||||
if image.MediaType != "image/png" {
|
||||
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
||||
}
|
||||
|
||||
var toolResult sdk.ToolResultPart
|
||||
foundToolMessage := false
|
||||
for _, msg := range params.Messages {
|
||||
if msg.Role != sdk.MessageRoleTool || len(msg.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
part, ok := msg.Content[0].(sdk.ToolResultPart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolResult = part
|
||||
foundToolMessage = true
|
||||
break
|
||||
}
|
||||
if !foundToolMessage {
|
||||
t.Fatal("expected tool result message before second step")
|
||||
}
|
||||
raw, err := json.Marshal(toolResult.Result)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal tool result: %v", err)
|
||||
}
|
||||
if !bytes.Contains(raw, []byte(`"ok":true`)) {
|
||||
t.Fatalf("expected compact success metadata, got %s", raw)
|
||||
}
|
||||
if bytes.Contains(raw, []byte(expectedDataURL)) || bytes.Contains(raw, []byte("payload")) {
|
||||
t.Fatalf("tool result leaked image bytes: %s", raw)
|
||||
}
|
||||
|
||||
return &sdk.GenerateResult{
|
||||
Text: "done",
|
||||
FinishReason: sdk.FinishReasonStop,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a := New(Deps{})
|
||||
a.SetToolProviders([]agenttools.ToolProvider{
|
||||
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||
"/data/images/demo.png": pngBytes,
|
||||
}), "/data"),
|
||||
})
|
||||
|
||||
result, err := a.Generate(context.Background(), RunConfig{
|
||||
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
||||
SupportsImageInput: true,
|
||||
Identity: SessionContext{
|
||||
BotID: "bot-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate returned error: %v", err)
|
||||
}
|
||||
if result.Text != "done" {
|
||||
t.Fatalf("unexpected result text: %q", result.Text)
|
||||
}
|
||||
if len(result.Messages) != 4 {
|
||||
t.Fatalf("expected persisted step + injected history, got %d messages", len(result.Messages))
|
||||
}
|
||||
assertInjectedReadMediaMessage(t, result.Messages[2], expectedDataURL, "image/png")
|
||||
if result.Messages[3].Role != sdk.MessageRoleAssistant {
|
||||
t.Fatalf("expected final persisted message to be assistant, got %s", result.Messages[3].Role)
|
||||
}
|
||||
if modelProvider.calls != 2 {
|
||||
t.Fatalf("expected 2 model calls, got %d", modelProvider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
|
||||
|
||||
modelProvider := &agentReadMediaMockProvider{
|
||||
name: "anthropic-messages",
|
||||
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
if call == 1 {
|
||||
return &sdk.GenerateResult{
|
||||
FinishReason: sdk.FinishReasonToolCalls,
|
||||
ToolCalls: []sdk.ToolCall{{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "read_media",
|
||||
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
last := params.Messages[len(params.Messages)-1]
|
||||
image, ok := last.Content[0].(sdk.ImagePart)
|
||||
if !ok {
|
||||
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
|
||||
}
|
||||
if image.Image != expectedBase64 {
|
||||
t.Fatalf("expected raw base64 for anthropic, got %q", image.Image)
|
||||
}
|
||||
if image.MediaType != "image/png" {
|
||||
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
||||
}
|
||||
if strings.HasPrefix(image.Image, "data:") {
|
||||
t.Fatalf("anthropic image payload must not be a data URL: %q", image.Image)
|
||||
}
|
||||
|
||||
return &sdk.GenerateResult{
|
||||
Text: "done",
|
||||
FinishReason: sdk.FinishReasonStop,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a := New(Deps{})
|
||||
a.SetToolProviders([]agenttools.ToolProvider{
|
||||
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||
"/data/images/demo.png": pngBytes,
|
||||
}), "/data"),
|
||||
})
|
||||
|
||||
_, err := a.Generate(context.Background(), RunConfig{
|
||||
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
||||
SupportsImageInput: true,
|
||||
Identity: SessionContext{
|
||||
BotID: "bot-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
|
||||
|
||||
modelProvider := &agentReadMediaMockProvider{
|
||||
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
if call == 1 {
|
||||
return &sdk.GenerateResult{
|
||||
FinishReason: sdk.FinishReasonToolCalls,
|
||||
ToolCalls: []sdk.ToolCall{{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "read_media",
|
||||
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
return &sdk.GenerateResult{
|
||||
Text: "done",
|
||||
FinishReason: sdk.FinishReasonStop,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a := New(Deps{})
|
||||
a.SetToolProviders([]agenttools.ToolProvider{
|
||||
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||
"/data/images/demo.png": pngBytes,
|
||||
}), "/data"),
|
||||
})
|
||||
|
||||
var terminal StreamEvent
|
||||
for event := range a.Stream(context.Background(), RunConfig{
|
||||
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
||||
SupportsImageInput: true,
|
||||
Identity: SessionContext{
|
||||
BotID: "bot-1",
|
||||
},
|
||||
}) {
|
||||
if event.IsTerminal() {
|
||||
terminal = event
|
||||
}
|
||||
}
|
||||
|
||||
if terminal.Type != EventAgentEnd {
|
||||
t.Fatalf("expected terminal event %q, got %q", EventAgentEnd, terminal.Type)
|
||||
}
|
||||
|
||||
var messages []sdk.Message
|
||||
if err := json.Unmarshal(terminal.Messages, &messages); err != nil {
|
||||
t.Fatalf("unmarshal terminal messages: %v", err)
|
||||
}
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected persisted step + injected history, got %d messages", len(messages))
|
||||
}
|
||||
assertInjectedReadMediaMessage(t, messages[2], expectedDataURL, "image/png")
|
||||
if messages[3].Role != sdk.MessageRoleAssistant {
|
||||
t.Fatalf("expected final persisted message to be assistant, got %s", messages[3].Role)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
)
|
||||
|
||||
const (
	// ReadMediaToolName is the exported name of the read_media tool.
	ReadMediaToolName = "read_media"
	// toolReadMedia is the package-internal alias for ReadMediaToolName.
	toolReadMedia = ReadMediaToolName
	// defaultReadMediaRoot is the directory used when no root is configured.
	defaultReadMediaRoot = "/data"
	// defaultReadMediaMaxBytes is the default file size limit (20 MiB).
	defaultReadMediaMaxBytes = 20 << 20
)
|
||||
|
||||
// readMediaSupportedMimeTypes is the set of MIME types read_media accepts.
var readMediaSupportedMimeTypes = map[string]struct{}{
	"image/png":  {},
	"image/jpeg": {},
	"image/gif":  {},
	"image/webp": {},
}
|
||||
|
||||
// ReadMediaToolResult is the lightweight, model-visible outcome of a
// read_media call. It never carries the image bytes themselves.
type ReadMediaToolResult struct {
	// OK reports whether the image was loaded.
	OK bool `json:"ok"`
	// Path is the resolved path of the file that was read.
	Path string `json:"path,omitempty"`
	// Mime is the detected MIME type of the file.
	Mime string `json:"mime,omitempty"`
	// Size is the number of bytes read.
	Size int `json:"size,omitempty"`
	// Error holds the failure message when OK is false.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
// ReadMediaToolOutput is the internal execution result used by the agent to
|
||||
// inject the image into the next Twilight AI step while keeping the visible
|
||||
// tool result lightweight.
|
||||
type ReadMediaToolOutput struct {
|
||||
Public ReadMediaToolResult
|
||||
ImageBase64 string
|
||||
ImageMediaType string
|
||||
}
|
||||
|
||||
type readMediaToolOutput = ReadMediaToolOutput
|
||||
|
||||
type ReadMediaProvider struct {
|
||||
clients bridge.Provider
|
||||
rootDir string
|
||||
maxBytes int64
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewReadMediaProvider(log *slog.Logger, clients bridge.Provider, rootDir string) *ReadMediaProvider {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
root := strings.TrimSpace(rootDir)
|
||||
if root == "" {
|
||||
root = defaultReadMediaRoot
|
||||
}
|
||||
return &ReadMediaProvider{
|
||||
clients: clients,
|
||||
rootDir: path.Clean(root),
|
||||
maxBytes: defaultReadMediaMaxBytes,
|
||||
logger: log.With(slog.String("tool", "read_media")),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReadMediaProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
|
||||
if p == nil || p.clients == nil || !session.SupportsImageInput {
|
||||
return nil, nil
|
||||
}
|
||||
root := p.rootDir
|
||||
if root == "" {
|
||||
root = defaultReadMediaRoot
|
||||
}
|
||||
sess := session
|
||||
return []sdk.Tool{
|
||||
{
|
||||
Name: toolReadMedia,
|
||||
Description: fmt.Sprintf("Load an image file from %s into model context so you can inspect it. Relative paths are resolved under %s.", root, root),
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": fmt.Sprintf("Image file path under %s. Absolute paths must stay under %s; relative paths are resolved under %s.", root, root, root),
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||
return p.execReadMedia(ctx.Context, sess, inputAsMap(input))
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *ReadMediaProvider) execReadMedia(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
|
||||
client, err := p.getClient(ctx, session.BotID)
|
||||
if err != nil {
|
||||
return readMediaErrorResult(err.Error()), nil
|
||||
}
|
||||
|
||||
resolvedPath, err := p.resolveImagePath(StringArg(args, "path"))
|
||||
if err != nil {
|
||||
return readMediaErrorResult(err.Error()), nil
|
||||
}
|
||||
|
||||
reader, err := client.ReadRaw(ctx, resolvedPath)
|
||||
if err != nil {
|
||||
return readMediaErrorResult(err.Error()), nil
|
||||
}
|
||||
defer func() { _ = reader.Close() }()
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(reader, p.maxBytes+1))
|
||||
if err != nil {
|
||||
return readMediaErrorResult("read_media failed to load image: " + err.Error()), nil
|
||||
}
|
||||
if int64(len(data)) > p.maxBytes {
|
||||
return readMediaErrorResult(fmt.Sprintf("read_media failed to load image: file exceeds %d bytes", p.maxBytes)), nil
|
||||
}
|
||||
|
||||
mimeType, err := detectReadMediaMime(data)
|
||||
if err != nil {
|
||||
return readMediaErrorResult(err.Error()), nil
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
return ReadMediaToolOutput{
|
||||
Public: ReadMediaToolResult{
|
||||
OK: true,
|
||||
Path: resolvedPath,
|
||||
Mime: mimeType,
|
||||
Size: len(data),
|
||||
},
|
||||
ImageBase64: encoded,
|
||||
ImageMediaType: mimeType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *ReadMediaProvider) getClient(ctx context.Context, botID string) (*bridge.Client, error) {
|
||||
botID = strings.TrimSpace(botID)
|
||||
if botID == "" {
|
||||
return nil, errors.New("bot_id is required")
|
||||
}
|
||||
client, err := p.clients.MCPClient(ctx, botID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("container not reachable: %w", err)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (p *ReadMediaProvider) resolveImagePath(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("path is required")
|
||||
}
|
||||
|
||||
root := p.rootDir
|
||||
if root == "" {
|
||||
root = defaultReadMediaRoot
|
||||
}
|
||||
root = path.Clean(root)
|
||||
|
||||
resolved := trimmed
|
||||
if !strings.HasPrefix(resolved, "/") {
|
||||
resolved = path.Join(root, resolved)
|
||||
}
|
||||
resolved = path.Clean(resolved)
|
||||
|
||||
if resolved == root || !strings.HasPrefix(resolved, root+"/") {
|
||||
return "", fmt.Errorf("path must be under %s", root)
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func readMediaErrorResult(message string) ReadMediaToolOutput {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "read_media failed"
|
||||
}
|
||||
return ReadMediaToolOutput{
|
||||
Public: ReadMediaToolResult{
|
||||
OK: false,
|
||||
Error: msg,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func detectReadMediaMime(data []byte) (string, error) {
|
||||
sniffedMime := ""
|
||||
if len(data) > 0 {
|
||||
sniffedMime = strings.ToLower(strings.TrimSpace(http.DetectContentType(data)))
|
||||
}
|
||||
|
||||
switch {
|
||||
case sniffedMime == "":
|
||||
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
|
||||
case isSupportedReadMediaMime(sniffedMime):
|
||||
return sniffedMime, nil
|
||||
default:
|
||||
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func isSupportedReadMediaMime(mimeType string) bool {
|
||||
_, ok := readMediaSupportedMimeTypes[strings.ToLower(strings.TrimSpace(mimeType))]
|
||||
return ok
|
||||
}
|
||||
@@ -0,0 +1,306 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||
)
|
||||
|
||||
// readMediaTestBufSize is the buffer size (1 MiB) passed to bufconn.Listen for
// the in-memory listener used by these tests.
const readMediaTestBufSize = 1 << 20
|
||||
|
||||
type readMediaTestContainerService struct {
|
||||
pb.UnimplementedContainerServiceServer
|
||||
files map[string][]byte
|
||||
}
|
||||
|
||||
func (s *readMediaTestContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
||||
data, ok := s.files[req.GetPath()]
|
||||
if !ok {
|
||||
return status.Error(codes.NotFound, "not found")
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return stream.Send(&pb.DataChunk{Data: data})
|
||||
}
|
||||
|
||||
type readMediaStaticProvider struct {
|
||||
client *bridge.Client
|
||||
}
|
||||
|
||||
func (p *readMediaStaticProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
|
||||
return p.client, nil
|
||||
}
|
||||
|
||||
func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
|
||||
t.Helper()
|
||||
|
||||
lis := bufconn.Listen(readMediaTestBufSize)
|
||||
srv := grpc.NewServer()
|
||||
pb.RegisterContainerServiceServer(srv, &readMediaTestContainerService{files: files})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_ = srv.Serve(lis)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
srv.Stop()
|
||||
<-done
|
||||
})
|
||||
|
||||
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
|
||||
return lis.DialContext(ctx)
|
||||
}
|
||||
conn, err := grpc.NewClient(
|
||||
"passthrough://bufnet",
|
||||
grpc.WithContextDialer(dialer),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.NewClient: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
return &readMediaStaticProvider{client: bridge.NewClientFromConn(conn)}
|
||||
}
|
||||
|
||||
func findToolByName(tools []sdk.Tool, name string) (sdk.Tool, bool) {
|
||||
for _, tool := range tools {
|
||||
if tool.Name == name {
|
||||
return tool, true
|
||||
}
|
||||
}
|
||||
return sdk.Tool{}, false
|
||||
}
|
||||
|
||||
func TestReadMediaProviderToolsOnlyWhenImageInputIsSupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
|
||||
|
||||
toolsWithoutImage, err := provider.Tools(context.Background(), SessionContext{
|
||||
BotID: "bot-1",
|
||||
SupportsImageInput: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Tools without image input returned error: %v", err)
|
||||
}
|
||||
if len(toolsWithoutImage) != 0 {
|
||||
t.Fatalf("expected no tools without image input support, got %d", len(toolsWithoutImage))
|
||||
}
|
||||
|
||||
toolsWithImage, err := provider.Tools(context.Background(), SessionContext{
|
||||
BotID: "bot-1",
|
||||
SupportsImageInput: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Tools with image input returned error: %v", err)
|
||||
}
|
||||
|
||||
tool, ok := findToolByName(toolsWithImage, toolReadMedia)
|
||||
if !ok {
|
||||
t.Fatalf("expected %q tool to be exposed", toolReadMedia)
|
||||
}
|
||||
if tool.Execute == nil {
|
||||
t.Fatal("expected read_media tool to be executable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
||||
"/data/images/demo.png": pngBytes,
|
||||
}), "/data")
|
||||
|
||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||
BotID: "bot-1",
|
||||
SupportsImageInput: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Tools returned error: %v", err)
|
||||
}
|
||||
|
||||
tool, ok := findToolByName(tools, toolReadMedia)
|
||||
if !ok {
|
||||
t.Fatalf("expected %q tool", toolReadMedia)
|
||||
}
|
||||
|
||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||
Context: context.Background(),
|
||||
ToolCallID: "call-1",
|
||||
ToolName: toolReadMedia,
|
||||
}, map[string]any{"path": "images/demo.png"})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute returned error: %v", err)
|
||||
}
|
||||
|
||||
result, ok := output.(readMediaToolOutput)
|
||||
if !ok {
|
||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||
}
|
||||
if !result.Public.OK {
|
||||
t.Fatalf("expected success result, got %+v", result.Public)
|
||||
}
|
||||
if result.Public.Path != "/data/images/demo.png" {
|
||||
t.Fatalf("unexpected path: %q", result.Public.Path)
|
||||
}
|
||||
if result.Public.Mime != "image/png" {
|
||||
t.Fatalf("unexpected mime: %q", result.Public.Mime)
|
||||
}
|
||||
if result.Public.Size != len(pngBytes) {
|
||||
t.Fatalf("unexpected size: %d", result.Public.Size)
|
||||
}
|
||||
|
||||
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
|
||||
if result.ImageBase64 != expectedBase64 {
|
||||
t.Fatalf("unexpected image payload: %q", result.ImageBase64)
|
||||
}
|
||||
if result.ImageMediaType != "image/png" {
|
||||
t.Fatalf("unexpected image media type: %q", result.ImageMediaType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMediaProviderExecuteRejectsPathOutsideData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
|
||||
|
||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||
BotID: "bot-1",
|
||||
SupportsImageInput: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Tools returned error: %v", err)
|
||||
}
|
||||
|
||||
tool, ok := findToolByName(tools, toolReadMedia)
|
||||
if !ok {
|
||||
t.Fatalf("expected %q tool", toolReadMedia)
|
||||
}
|
||||
|
||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||
Context: context.Background(),
|
||||
ToolCallID: "call-2",
|
||||
ToolName: toolReadMedia,
|
||||
}, map[string]any{"path": "/tmp/demo.png"})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute returned error: %v", err)
|
||||
}
|
||||
|
||||
result, ok := output.(readMediaToolOutput)
|
||||
if !ok {
|
||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||
}
|
||||
if result.Public.OK {
|
||||
t.Fatalf("expected error result, got %+v", result.Public)
|
||||
}
|
||||
if !strings.Contains(result.Public.Error, "path must be under /data") {
|
||||
t.Fatalf("unexpected error: %q", result.Public.Error)
|
||||
}
|
||||
if result.ImageBase64 != "" {
|
||||
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
||||
"/data/images/demo.svg": []byte(`<svg xmlns="http://www.w3.org/2000/svg"></svg>`),
|
||||
}), "/data")
|
||||
|
||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||
BotID: "bot-1",
|
||||
SupportsImageInput: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Tools returned error: %v", err)
|
||||
}
|
||||
|
||||
tool, ok := findToolByName(tools, toolReadMedia)
|
||||
if !ok {
|
||||
t.Fatalf("expected %q tool", toolReadMedia)
|
||||
}
|
||||
|
||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||
Context: context.Background(),
|
||||
ToolCallID: "call-3",
|
||||
ToolName: toolReadMedia,
|
||||
}, map[string]any{"path": "images/demo.svg"})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute returned error: %v", err)
|
||||
}
|
||||
|
||||
result, ok := output.(readMediaToolOutput)
|
||||
if !ok {
|
||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||
}
|
||||
if result.Public.OK {
|
||||
t.Fatalf("expected error result, got %+v", result.Public)
|
||||
}
|
||||
if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") {
|
||||
t.Fatalf("unexpected error: %q", result.Public.Error)
|
||||
}
|
||||
if result.ImageBase64 != "" {
|
||||
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
||||
"/data/images/demo.png": []byte("definitely not a png"),
|
||||
}), "/data")
|
||||
|
||||
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||
BotID: "bot-1",
|
||||
SupportsImageInput: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Tools returned error: %v", err)
|
||||
}
|
||||
|
||||
tool, ok := findToolByName(tools, toolReadMedia)
|
||||
if !ok {
|
||||
t.Fatalf("expected %q tool", toolReadMedia)
|
||||
}
|
||||
|
||||
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||
Context: context.Background(),
|
||||
ToolCallID: "call-4",
|
||||
ToolName: toolReadMedia,
|
||||
}, map[string]any{"path": "images/demo.png"})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute returned error: %v", err)
|
||||
}
|
||||
|
||||
result, ok := output.(readMediaToolOutput)
|
||||
if !ok {
|
||||
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||
}
|
||||
if result.Public.OK {
|
||||
t.Fatalf("expected error result, got %+v", result.Public)
|
||||
}
|
||||
if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") {
|
||||
t.Fatalf("unexpected error: %q", result.Public.Error)
|
||||
}
|
||||
if result.ImageBase64 != "" {
|
||||
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||
}
|
||||
}
|
||||
@@ -12,13 +12,14 @@ import (
|
||||
|
||||
// SessionContext carries request-scoped identity for tool execution.
//
// Each field is declared exactly once; SupportsImageInput sits between
// ReplyTarget and IsSubagent.
type SessionContext struct {
	BotID             string
	ChatID            string
	ChannelIdentityID string
	SessionToken      string //nolint:gosec // carries session credential material at runtime
	CurrentPlatform   string
	ReplyTarget       string
	// SupportsImageInput marks a session as able to take image input; the
	// read_media provider tests expect the tool to be offered only when it is true.
	SupportsImageInput bool
	IsSubagent         bool
}
|
||||
|
||||
// ToolProvider supplies a set of tools for the agent.
|
||||
|
||||
+15
-14
@@ -54,20 +54,21 @@ type LoopDetectionConfig struct {
|
||||
|
||||
// RunConfig holds everything needed for a single agent invocation.
|
||||
type RunConfig struct {
|
||||
Model *sdk.Model
|
||||
ReasoningEffort string
|
||||
Messages []sdk.Message
|
||||
Query string
|
||||
System string
|
||||
Tools []sdk.Tool
|
||||
Channels []string
|
||||
CurrentChannel string
|
||||
Identity SessionContext
|
||||
Skills []SkillEntry
|
||||
EnabledSkillNames []string
|
||||
Inbox []InboxItem
|
||||
LoopDetection LoopDetectionConfig
|
||||
ActiveContextTime int
|
||||
Model *sdk.Model
|
||||
ReasoningEffort string
|
||||
Messages []sdk.Message
|
||||
Query string
|
||||
System string
|
||||
Tools []sdk.Tool
|
||||
SupportsImageInput bool
|
||||
Channels []string
|
||||
CurrentChannel string
|
||||
Identity SessionContext
|
||||
Skills []SkillEntry
|
||||
EnabledSkillNames []string
|
||||
Inbox []InboxItem
|
||||
LoopDetection LoopDetectionConfig
|
||||
ActiveContextTime int
|
||||
}
|
||||
|
||||
// GenerateResult holds the result of a non-streaming agent invocation.
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
agentpkg "github.com/memohai/memoh/internal/agent"
|
||||
)
|
||||
|
||||
func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resolver := &Resolver{}
|
||||
cfg := agentpkg.RunConfig{
|
||||
Query: "describe this image",
|
||||
SupportsImageInput: true,
|
||||
Identity: agentpkg.SessionContext{
|
||||
BotID: "bot-1",
|
||||
},
|
||||
}
|
||||
|
||||
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
||||
if !strings.Contains(prepared.System, "`read_media`") {
|
||||
t.Fatalf("expected system prompt to include read_media tool, got:\n%s", prepared.System)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resolver := &Resolver{}
|
||||
cfg := agentpkg.RunConfig{
|
||||
Query: "describe this image",
|
||||
SupportsImageInput: false,
|
||||
Identity: agentpkg.SessionContext{
|
||||
BotID: "bot-1",
|
||||
},
|
||||
}
|
||||
|
||||
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
||||
if strings.Contains(prepared.System, "`read_media`") {
|
||||
t.Fatalf("expected system prompt to omit read_media tool, got:\n%s", prepared.System)
|
||||
}
|
||||
}
|
||||
@@ -309,12 +309,13 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
|
||||
|
||||
runCfg := agentpkg.RunConfig{
|
||||
Model: sdkModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Messages: sdkMessages,
|
||||
Query: headerifiedQuery,
|
||||
Channels: nonNilStrings(req.Channels),
|
||||
CurrentChannel: req.CurrentChannel,
|
||||
Model: sdkModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Messages: sdkMessages,
|
||||
Query: headerifiedQuery,
|
||||
SupportsImageInput: chatModel.HasInputModality(models.ModelInputImage),
|
||||
Channels: nonNilStrings(req.Channels),
|
||||
CurrentChannel: req.CurrentChannel,
|
||||
Identity: agentpkg.SessionContext{
|
||||
BotID: req.BotID,
|
||||
ChatID: req.ChatID,
|
||||
@@ -368,10 +369,7 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv
|
||||
|
||||
// prepareRunConfig generates the system prompt and appends the user message.
|
||||
func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig) agentpkg.RunConfig {
|
||||
supportsImageInput := false
|
||||
for _, m := range cfg.Identity.CurrentPlatform {
|
||||
_ = m
|
||||
}
|
||||
supportsImageInput := cfg.SupportsImageInput
|
||||
// Build system prompt
|
||||
var files []agentpkg.SystemFile
|
||||
if r.agent != nil {
|
||||
|
||||
Reference in New Issue
Block a user