mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(agent): restore read_media in pure Go (#257)
This commit is contained in:
@@ -572,6 +572,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
|
|||||||
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
||||||
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
||||||
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
||||||
|
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
|
||||||
agenttools.NewInboxProvider(log, inboxService),
|
agenttools.NewInboxProvider(log, inboxService),
|
||||||
agenttools.NewEmailProvider(log, emailService, emailManager),
|
agenttools.NewEmailProvider(log, emailService, emailManager),
|
||||||
agenttools.NewWebFetchProvider(log),
|
agenttools.NewWebFetchProvider(log),
|
||||||
|
|||||||
@@ -434,6 +434,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c
|
|||||||
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
agenttools.NewMemoryProvider(log, memoryRegistry, settingsService),
|
||||||
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
agenttools.NewWebProvider(log, settingsService, searchProviderService),
|
||||||
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
agenttools.NewContainerProvider(log, manager, config.DefaultDataMount),
|
||||||
|
agenttools.NewReadMediaProvider(log, manager, config.DefaultDataMount),
|
||||||
agenttools.NewInboxProvider(log, inboxService),
|
agenttools.NewInboxProvider(log, inboxService),
|
||||||
agenttools.NewEmailProvider(log, emailService, emailManager),
|
agenttools.NewEmailProvider(log, emailService, emailManager),
|
||||||
agenttools.NewWebFetchProvider(log),
|
agenttools.NewWebFetchProvider(log),
|
||||||
|
|||||||
+34
-12
@@ -62,6 +62,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}
|
ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
|
||||||
|
|
||||||
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
||||||
copy(enabledSkills, cfg.EnabledSkillNames)
|
copy(enabledSkills, cfg.EnabledSkillNames)
|
||||||
@@ -103,7 +104,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
tagResolvers := DefaultTagResolvers()
|
tagResolvers := DefaultTagResolvers()
|
||||||
tagExtractor := NewStreamTagExtractor(tagResolvers)
|
tagExtractor := NewStreamTagExtractor(tagResolvers)
|
||||||
|
|
||||||
opts := a.buildGenerateOptions(cfg, tools)
|
var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
|
||||||
|
if readMediaState != nil {
|
||||||
|
prepareStep = readMediaState.prepareStep
|
||||||
|
}
|
||||||
|
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
|
||||||
|
|
||||||
streamResult, err := a.client.StreamText(ctx, opts...)
|
streamResult, err := a.client.StreamText(ctx, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -251,7 +256,11 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
|||||||
textLoopProbeBuffer.Flush()
|
textLoopProbeBuffer.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
finalMessages := StripTagsFromMessages(streamResult.Messages)
|
finalMessages := streamResult.Messages
|
||||||
|
if readMediaState != nil {
|
||||||
|
finalMessages = readMediaState.mergeMessages(streamResult.Steps, finalMessages)
|
||||||
|
}
|
||||||
|
finalMessages = StripTagsFromMessages(finalMessages)
|
||||||
|
|
||||||
var totalUsage sdk.Usage
|
var totalUsage sdk.Usage
|
||||||
perStepUsages := make([]json.RawMessage, 0, len(streamResult.Steps))
|
perStepUsages := make([]json.RawMessage, 0, len(streamResult.Steps))
|
||||||
@@ -286,6 +295,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("assemble tools: %w", err)
|
return nil, fmt.Errorf("assemble tools: %w", err)
|
||||||
}
|
}
|
||||||
|
tools, readMediaState := decorateReadMediaTools(cfg.Model, tools)
|
||||||
|
|
||||||
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
enabledSkills := make([]string, 0, len(cfg.EnabledSkillNames))
|
||||||
copy(enabledSkills, cfg.EnabledSkillNames)
|
copy(enabledSkills, cfg.EnabledSkillNames)
|
||||||
@@ -315,7 +325,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs)
|
tools = wrapToolsWithLoopGuard(tools, toolLoopGuard, toolLoopAbortCallIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := a.buildGenerateOptions(cfg, tools)
|
var prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams
|
||||||
|
if readMediaState != nil {
|
||||||
|
prepareStep = readMediaState.prepareStep
|
||||||
|
}
|
||||||
|
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams {
|
||||||
if cfg.LoopDetection.Enabled {
|
if cfg.LoopDetection.Enabled {
|
||||||
@@ -376,7 +390,11 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
finalMessages := StripTagsFromMessages(genResult.Messages)
|
finalMessages := genResult.Messages
|
||||||
|
if readMediaState != nil {
|
||||||
|
finalMessages = readMediaState.mergeMessages(genResult.Steps, finalMessages)
|
||||||
|
}
|
||||||
|
finalMessages = StripTagsFromMessages(finalMessages)
|
||||||
|
|
||||||
return &GenerateResult{
|
return &GenerateResult{
|
||||||
Messages: finalMessages,
|
Messages: finalMessages,
|
||||||
@@ -389,7 +407,7 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.GenerateOption {
|
func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams) []sdk.GenerateOption {
|
||||||
opts := []sdk.GenerateOption{
|
opts := []sdk.GenerateOption{
|
||||||
sdk.WithModel(cfg.Model),
|
sdk.WithModel(cfg.Model),
|
||||||
sdk.WithMessages(cfg.Messages),
|
sdk.WithMessages(cfg.Messages),
|
||||||
@@ -399,6 +417,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool) []sdk.Genera
|
|||||||
if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
opts = append(opts, sdk.WithTools(tools))
|
opts = append(opts, sdk.WithTools(tools))
|
||||||
}
|
}
|
||||||
|
if prepareStep != nil {
|
||||||
|
opts = append(opts, sdk.WithPrepareStep(prepareStep))
|
||||||
|
}
|
||||||
opts = append(opts, BuildReasoningOptions(ModelConfig{
|
opts = append(opts, BuildReasoningOptions(ModelConfig{
|
||||||
ClientType: resolveClientType(cfg.Model),
|
ClientType: resolveClientType(cfg.Model),
|
||||||
ReasoningConfig: &ReasoningConfig{
|
ReasoningConfig: &ReasoningConfig{
|
||||||
@@ -432,13 +453,14 @@ func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, e
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
session := tools.SessionContext{
|
session := tools.SessionContext{
|
||||||
BotID: cfg.Identity.BotID,
|
BotID: cfg.Identity.BotID,
|
||||||
ChatID: cfg.Identity.ChatID,
|
ChatID: cfg.Identity.ChatID,
|
||||||
ChannelIdentityID: cfg.Identity.ChannelIdentityID,
|
ChannelIdentityID: cfg.Identity.ChannelIdentityID,
|
||||||
SessionToken: cfg.Identity.SessionToken,
|
SessionToken: cfg.Identity.SessionToken,
|
||||||
CurrentPlatform: cfg.Identity.CurrentPlatform,
|
CurrentPlatform: cfg.Identity.CurrentPlatform,
|
||||||
ReplyTarget: cfg.Identity.ReplyTarget,
|
ReplyTarget: cfg.Identity.ReplyTarget,
|
||||||
IsSubagent: cfg.Identity.IsSubagent,
|
SupportsImageInput: cfg.SupportsImageInput,
|
||||||
|
IsSubagent: cfg.Identity.IsSubagent,
|
||||||
}
|
}
|
||||||
|
|
||||||
var allTools []sdk.Tool
|
var allTools []sdk.Tool
|
||||||
|
|||||||
@@ -0,0 +1,174 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
|
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
|
||||||
|
if len(tools) == 0 {
|
||||||
|
return tools, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
clientType := resolveClientType(model)
|
||||||
|
state := &readMediaDecorationState{
|
||||||
|
pendingImages: make(map[string]sdk.ImagePart),
|
||||||
|
}
|
||||||
|
wrapped := make([]sdk.Tool, 0, len(tools))
|
||||||
|
found := false
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
if tool.Name != agenttools.ReadMediaToolName || tool.Execute == nil {
|
||||||
|
wrapped = append(wrapped, tool)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
found = true
|
||||||
|
originalExecute := tool.Execute
|
||||||
|
toolCopy := tool
|
||||||
|
toolCopy.Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||||
|
output, err := originalExecute(ctx, input)
|
||||||
|
if err != nil {
|
||||||
|
return output, err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicResult, image, ok := normalizeReadMediaOutput(output, clientType)
|
||||||
|
if !ok {
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
if ctx != nil && strings.TrimSpace(ctx.ToolCallID) != "" && strings.TrimSpace(image.Image) != "" {
|
||||||
|
if _, exists := state.pendingImages[ctx.ToolCallID]; !exists {
|
||||||
|
state.pendingOrder = append(state.pendingOrder, ctx.ToolCallID)
|
||||||
|
}
|
||||||
|
state.pendingImages[ctx.ToolCallID] = image
|
||||||
|
}
|
||||||
|
return publicResult, nil
|
||||||
|
}
|
||||||
|
wrapped = append(wrapped, toolCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return tools, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return wrapped, state
|
||||||
|
}
|
||||||
|
|
||||||
|
type readMediaDecorationState struct {
|
||||||
|
pendingOrder []string
|
||||||
|
pendingImages map[string]sdk.ImagePart
|
||||||
|
prepareCalls int
|
||||||
|
injections []readMediaInjection
|
||||||
|
}
|
||||||
|
|
||||||
|
type readMediaInjection struct {
|
||||||
|
afterStep int
|
||||||
|
message sdk.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *readMediaDecorationState) prepareStep(params *sdk.GenerateParams) *sdk.GenerateParams {
|
||||||
|
if s == nil || params == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
afterStep := s.prepareCalls
|
||||||
|
s.prepareCalls++
|
||||||
|
|
||||||
|
if len(s.pendingOrder) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := make([]sdk.MessagePart, 0, len(s.pendingOrder))
|
||||||
|
for _, toolCallID := range s.pendingOrder {
|
||||||
|
image, ok := s.pendingImages[toolCallID]
|
||||||
|
delete(s.pendingImages, toolCallID)
|
||||||
|
if !ok || strings.TrimSpace(image.Image) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts = append(parts, image)
|
||||||
|
}
|
||||||
|
s.pendingOrder = s.pendingOrder[:0]
|
||||||
|
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
message := sdk.Message{
|
||||||
|
Role: sdk.MessageRoleUser,
|
||||||
|
Content: parts,
|
||||||
|
}
|
||||||
|
s.injections = append(s.injections, readMediaInjection{
|
||||||
|
afterStep: afterStep,
|
||||||
|
message: message,
|
||||||
|
})
|
||||||
|
|
||||||
|
next := *params
|
||||||
|
next.Messages = append(append([]sdk.Message(nil), params.Messages...), message)
|
||||||
|
return &next
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *readMediaDecorationState) mergeMessages(steps []sdk.StepResult, fallback []sdk.Message) []sdk.Message {
|
||||||
|
if s == nil || len(s.injections) == 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if len(steps) == 0 {
|
||||||
|
merged := append([]sdk.Message(nil), fallback...)
|
||||||
|
for _, injection := range s.injections {
|
||||||
|
merged = append(merged, injection.message)
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := make([]sdk.Message, 0, len(fallback)+len(s.injections))
|
||||||
|
injectionIndex := 0
|
||||||
|
for stepIndex, step := range steps {
|
||||||
|
merged = append(merged, step.Messages...)
|
||||||
|
for injectionIndex < len(s.injections) && s.injections[injectionIndex].afterStep == stepIndex {
|
||||||
|
merged = append(merged, s.injections[injectionIndex].message)
|
||||||
|
injectionIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for injectionIndex < len(s.injections) {
|
||||||
|
merged = append(merged, s.injections[injectionIndex].message)
|
||||||
|
injectionIndex++
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeReadMediaOutput(output any, clientType string) (any, sdk.ImagePart, bool) {
|
||||||
|
switch value := output.(type) {
|
||||||
|
case agenttools.ReadMediaToolOutput:
|
||||||
|
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
|
||||||
|
case *agenttools.ReadMediaToolOutput:
|
||||||
|
if value == nil {
|
||||||
|
return nil, sdk.ImagePart{}, false
|
||||||
|
}
|
||||||
|
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
|
||||||
|
default:
|
||||||
|
return nil, sdk.ImagePart{}, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.ImagePart {
|
||||||
|
imageBase64 = strings.TrimSpace(imageBase64)
|
||||||
|
mediaType = strings.TrimSpace(mediaType)
|
||||||
|
if imageBase64 == "" {
|
||||||
|
return sdk.ImagePart{}
|
||||||
|
}
|
||||||
|
if mediaType == "" {
|
||||||
|
mediaType = "image/png"
|
||||||
|
}
|
||||||
|
|
||||||
|
image := imageBase64
|
||||||
|
if clientType != ClientTypeAnthropicMessages {
|
||||||
|
image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
|
||||||
|
}
|
||||||
|
return sdk.ImagePart{
|
||||||
|
Image: image,
|
||||||
|
MediaType: mediaType,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,394 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
|
||||||
|
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
||||||
|
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||||
|
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// agentReadMediaTestBufSize is the buffer size (1 MiB) handed to
// bufconn.Listen for the in-memory gRPC listener used by these tests.
const agentReadMediaTestBufSize = 1024 * 1024
|
||||||
|
|
||||||
|
type agentReadMediaContainerService struct {
|
||||||
|
pb.UnimplementedContainerServiceServer
|
||||||
|
files map[string][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *agentReadMediaContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
||||||
|
data, ok := s.files[req.GetPath()]
|
||||||
|
if !ok {
|
||||||
|
return status.Error(codes.NotFound, "not found")
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return stream.Send(&pb.DataChunk{Data: data})
|
||||||
|
}
|
||||||
|
|
||||||
|
type agentReadMediaBridgeProvider struct {
|
||||||
|
client *bridge.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *agentReadMediaBridgeProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
|
||||||
|
return p.client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAgentReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
lis := bufconn.Listen(agentReadMediaTestBufSize)
|
||||||
|
srv := grpc.NewServer()
|
||||||
|
pb.RegisterContainerServiceServer(srv, &agentReadMediaContainerService{files: files})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
_ = srv.Serve(lis)
|
||||||
|
}()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
srv.Stop()
|
||||||
|
<-done
|
||||||
|
})
|
||||||
|
|
||||||
|
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
|
||||||
|
return lis.DialContext(ctx)
|
||||||
|
}
|
||||||
|
conn, err := grpc.NewClient(
|
||||||
|
"passthrough://bufnet",
|
||||||
|
grpc.WithContextDialer(dialer),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("grpc.NewClient: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = conn.Close() })
|
||||||
|
|
||||||
|
return &agentReadMediaBridgeProvider{client: bridge.NewClientFromConn(conn)}
|
||||||
|
}
|
||||||
|
|
||||||
|
type agentReadMediaMockProvider struct {
|
||||||
|
name string
|
||||||
|
calls int
|
||||||
|
handler func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *agentReadMediaMockProvider) Name() string {
|
||||||
|
if m.name != "" {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
return "mock"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*agentReadMediaMockProvider) ListModels(context.Context) ([]sdk.Model, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*agentReadMediaMockProvider) Test(context.Context) *sdk.ProviderTestResult {
|
||||||
|
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*agentReadMediaMockProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
|
||||||
|
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *agentReadMediaMockProvider) DoGenerate(_ context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
m.calls++
|
||||||
|
return m.handler(m.calls, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *agentReadMediaMockProvider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||||
|
result, err := m.DoGenerate(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ch := make(chan sdk.StreamPart, 8)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
ch <- &sdk.StartPart{}
|
||||||
|
ch <- &sdk.StartStepPart{}
|
||||||
|
if result.Text != "" {
|
||||||
|
ch <- &sdk.TextStartPart{ID: "mock"}
|
||||||
|
ch <- &sdk.TextDeltaPart{ID: "mock", Text: result.Text}
|
||||||
|
ch <- &sdk.TextEndPart{ID: "mock"}
|
||||||
|
}
|
||||||
|
for _, tc := range result.ToolCalls {
|
||||||
|
ch <- &sdk.StreamToolCallPart{
|
||||||
|
ToolCallID: tc.ToolCallID,
|
||||||
|
ToolName: tc.ToolName,
|
||||||
|
Input: tc.Input,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ch <- &sdk.FinishStepPart{FinishReason: result.FinishReason, Usage: result.Usage, Response: result.Response}
|
||||||
|
ch <- &sdk.FinishPart{FinishReason: result.FinishReason, TotalUsage: result.Usage}
|
||||||
|
}()
|
||||||
|
return &sdk.StreamResult{Stream: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertInjectedReadMediaMessage(t *testing.T, msg sdk.Message, expectedImage, expectedMediaType string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if msg.Role != sdk.MessageRoleUser {
|
||||||
|
t.Fatalf("expected injected read_media message role %q, got %q", sdk.MessageRoleUser, msg.Role)
|
||||||
|
}
|
||||||
|
if len(msg.Content) != 1 {
|
||||||
|
t.Fatalf("expected one injected content part, got %d", len(msg.Content))
|
||||||
|
}
|
||||||
|
image, ok := msg.Content[0].(sdk.ImagePart)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected sdk.ImagePart, got %T", msg.Content[0])
|
||||||
|
}
|
||||||
|
if image.Image != expectedImage {
|
||||||
|
t.Fatalf("unexpected injected image payload: %q", image.Image)
|
||||||
|
}
|
||||||
|
if image.MediaType != expectedMediaType {
|
||||||
|
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||||
|
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
|
||||||
|
|
||||||
|
modelProvider := &agentReadMediaMockProvider{
|
||||||
|
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
if call == 1 {
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
|
ToolCalls: []sdk.ToolCall{{
|
||||||
|
ToolCallID: "call-1",
|
||||||
|
ToolName: "read_media",
|
||||||
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.Messages) < 4 {
|
||||||
|
t.Fatalf("expected prior tool and injected messages, got %d", len(params.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
last := params.Messages[len(params.Messages)-1]
|
||||||
|
if last.Role != sdk.MessageRoleUser {
|
||||||
|
t.Fatalf("expected last message to be injected user image, got %s", last.Role)
|
||||||
|
}
|
||||||
|
if len(last.Content) != 1 {
|
||||||
|
t.Fatalf("expected one injected content part, got %d", len(last.Content))
|
||||||
|
}
|
||||||
|
image, ok := last.Content[0].(sdk.ImagePart)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
|
||||||
|
}
|
||||||
|
if image.Image != expectedDataURL {
|
||||||
|
t.Fatalf("unexpected injected image payload: %q", image.Image)
|
||||||
|
}
|
||||||
|
if image.MediaType != "image/png" {
|
||||||
|
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolResult sdk.ToolResultPart
|
||||||
|
foundToolMessage := false
|
||||||
|
for _, msg := range params.Messages {
|
||||||
|
if msg.Role != sdk.MessageRoleTool || len(msg.Content) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
part, ok := msg.Content[0].(sdk.ToolResultPart)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toolResult = part
|
||||||
|
foundToolMessage = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !foundToolMessage {
|
||||||
|
t.Fatal("expected tool result message before second step")
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(toolResult.Result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal tool result: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Contains(raw, []byte(`"ok":true`)) {
|
||||||
|
t.Fatalf("expected compact success metadata, got %s", raw)
|
||||||
|
}
|
||||||
|
if bytes.Contains(raw, []byte(expectedDataURL)) || bytes.Contains(raw, []byte("payload")) {
|
||||||
|
t.Fatalf("tool result leaked image bytes: %s", raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
Text: "done",
|
||||||
|
FinishReason: sdk.FinishReasonStop,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
|
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"/data/images/demo.png": pngBytes,
|
||||||
|
}), "/data"),
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := a.Generate(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
||||||
|
SupportsImageInput: true,
|
||||||
|
Identity: SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Generate returned error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Text != "done" {
|
||||||
|
t.Fatalf("unexpected result text: %q", result.Text)
|
||||||
|
}
|
||||||
|
if len(result.Messages) != 4 {
|
||||||
|
t.Fatalf("expected persisted step + injected history, got %d messages", len(result.Messages))
|
||||||
|
}
|
||||||
|
assertInjectedReadMediaMessage(t, result.Messages[2], expectedDataURL, "image/png")
|
||||||
|
if result.Messages[3].Role != sdk.MessageRoleAssistant {
|
||||||
|
t.Fatalf("expected final persisted message to be assistant, got %s", result.Messages[3].Role)
|
||||||
|
}
|
||||||
|
if modelProvider.calls != 2 {
|
||||||
|
t.Fatalf("expected 2 model calls, got %d", modelProvider.calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||||
|
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
|
||||||
|
|
||||||
|
modelProvider := &agentReadMediaMockProvider{
|
||||||
|
name: "anthropic-messages",
|
||||||
|
handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
if call == 1 {
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
|
ToolCalls: []sdk.ToolCall{{
|
||||||
|
ToolCallID: "call-1",
|
||||||
|
ToolName: "read_media",
|
||||||
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
last := params.Messages[len(params.Messages)-1]
|
||||||
|
image, ok := last.Content[0].(sdk.ImagePart)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected sdk.ImagePart, got %T", last.Content[0])
|
||||||
|
}
|
||||||
|
if image.Image != expectedBase64 {
|
||||||
|
t.Fatalf("expected raw base64 for anthropic, got %q", image.Image)
|
||||||
|
}
|
||||||
|
if image.MediaType != "image/png" {
|
||||||
|
t.Fatalf("unexpected injected media type: %q", image.MediaType)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(image.Image, "data:") {
|
||||||
|
t.Fatalf("anthropic image payload must not be a data URL: %q", image.Image)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
Text: "done",
|
||||||
|
FinishReason: sdk.FinishReasonStop,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
|
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"/data/images/demo.png": pngBytes,
|
||||||
|
}), "/data"),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := a.Generate(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
||||||
|
SupportsImageInput: true,
|
||||||
|
Identity: SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Generate returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||||
|
expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes)
|
||||||
|
|
||||||
|
modelProvider := &agentReadMediaMockProvider{
|
||||||
|
handler: func(call int, _ sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||||
|
if call == 1 {
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
FinishReason: sdk.FinishReasonToolCalls,
|
||||||
|
ToolCalls: []sdk.ToolCall{{
|
||||||
|
ToolCallID: "call-1",
|
||||||
|
ToolName: "read_media",
|
||||||
|
Input: map[string]any{"path": "/data/images/demo.png"},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &sdk.GenerateResult{
|
||||||
|
Text: "done",
|
||||||
|
FinishReason: sdk.FinishReasonStop,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := New(Deps{})
|
||||||
|
a.SetToolProviders([]agenttools.ToolProvider{
|
||||||
|
agenttools.NewReadMediaProvider(nil, newAgentReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"/data/images/demo.png": pngBytes,
|
||||||
|
}), "/data"),
|
||||||
|
})
|
||||||
|
|
||||||
|
var terminal StreamEvent
|
||||||
|
for event := range a.Stream(context.Background(), RunConfig{
|
||||||
|
Model: &sdk.Model{ID: "mock-model", Provider: modelProvider},
|
||||||
|
Messages: []sdk.Message{sdk.UserMessage("look at the image")},
|
||||||
|
SupportsImageInput: true,
|
||||||
|
Identity: SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
},
|
||||||
|
}) {
|
||||||
|
if event.IsTerminal() {
|
||||||
|
terminal = event
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if terminal.Type != EventAgentEnd {
|
||||||
|
t.Fatalf("expected terminal event %q, got %q", EventAgentEnd, terminal.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages []sdk.Message
|
||||||
|
if err := json.Unmarshal(terminal.Messages, &messages); err != nil {
|
||||||
|
t.Fatalf("unmarshal terminal messages: %v", err)
|
||||||
|
}
|
||||||
|
if len(messages) != 4 {
|
||||||
|
t.Fatalf("expected persisted step + injected history, got %d messages", len(messages))
|
||||||
|
}
|
||||||
|
assertInjectedReadMediaMessage(t, messages[2], expectedDataURL, "image/png")
|
||||||
|
if messages[3].Role != sdk.MessageRoleAssistant {
|
||||||
|
t.Fatalf("expected final persisted message to be assistant, got %s", messages[3].Role)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,217 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
|
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// ReadMediaToolName is the tool name under which read_media is exposed.
	ReadMediaToolName = "read_media"
	// toolReadMedia is the package-internal spelling of ReadMediaToolName.
	toolReadMedia = ReadMediaToolName
	// defaultReadMediaRoot is the root directory used when none is configured.
	defaultReadMediaRoot = "/data"
	// defaultReadMediaMaxBytes is the default size limit: 20 MiB.
	defaultReadMediaMaxBytes = 20 << 20
)
|
||||||
|
|
||||||
|
// readMediaSupportedMimeTypes is the set of image MIME types listed as
// supported for read_media.
var readMediaSupportedMimeTypes = map[string]struct{}{
	"image/gif":  struct{}{},
	"image/jpeg": struct{}{},
	"image/png":  struct{}{},
	"image/webp": struct{}{},
}
|
||||||
|
|
||||||
|
// ReadMediaToolResult is the public, model-visible result of a read_media
// call. Every field except OK is omitted from the JSON encoding when it holds
// its zero value.
type ReadMediaToolResult struct {
	// OK is always encoded, as "ok".
	OK bool `json:"ok"`
	// Path is encoded as "path".
	Path string `json:"path,omitempty"`
	// Mime is encoded as "mime".
	Mime string `json:"mime,omitempty"`
	// Size is encoded as "size".
	Size int `json:"size,omitempty"`
	// Error is encoded as "error".
	Error string `json:"error,omitempty"`
}
|
||||||
|
|
||||||
|
// ReadMediaToolOutput is the internal execution result used by the agent to
|
||||||
|
// inject the image into the next Twilight AI step while keeping the visible
|
||||||
|
// tool result lightweight.
|
||||||
|
type ReadMediaToolOutput struct {
|
||||||
|
Public ReadMediaToolResult
|
||||||
|
ImageBase64 string
|
||||||
|
ImageMediaType string
|
||||||
|
}
|
||||||
|
|
||||||
|
type readMediaToolOutput = ReadMediaToolOutput
|
||||||
|
|
||||||
|
type ReadMediaProvider struct {
|
||||||
|
clients bridge.Provider
|
||||||
|
rootDir string
|
||||||
|
maxBytes int64
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReadMediaProvider(log *slog.Logger, clients bridge.Provider, rootDir string) *ReadMediaProvider {
|
||||||
|
if log == nil {
|
||||||
|
log = slog.Default()
|
||||||
|
}
|
||||||
|
root := strings.TrimSpace(rootDir)
|
||||||
|
if root == "" {
|
||||||
|
root = defaultReadMediaRoot
|
||||||
|
}
|
||||||
|
return &ReadMediaProvider{
|
||||||
|
clients: clients,
|
||||||
|
rootDir: path.Clean(root),
|
||||||
|
maxBytes: defaultReadMediaMaxBytes,
|
||||||
|
logger: log.With(slog.String("tool", "read_media")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReadMediaProvider) Tools(_ context.Context, session SessionContext) ([]sdk.Tool, error) {
|
||||||
|
if p == nil || p.clients == nil || !session.SupportsImageInput {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
root := p.rootDir
|
||||||
|
if root == "" {
|
||||||
|
root = defaultReadMediaRoot
|
||||||
|
}
|
||||||
|
sess := session
|
||||||
|
return []sdk.Tool{
|
||||||
|
{
|
||||||
|
Name: toolReadMedia,
|
||||||
|
Description: fmt.Sprintf("Load an image file from %s into model context so you can inspect it. Relative paths are resolved under %s.", root, root),
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"path": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": fmt.Sprintf("Image file path under %s. Absolute paths must stay under %s; relative paths are resolved under %s.", root, root, root),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"path"},
|
||||||
|
},
|
||||||
|
Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
||||||
|
return p.execReadMedia(ctx.Context, sess, inputAsMap(input))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReadMediaProvider) execReadMedia(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
|
||||||
|
client, err := p.getClient(ctx, session.BotID)
|
||||||
|
if err != nil {
|
||||||
|
return readMediaErrorResult(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedPath, err := p.resolveImagePath(StringArg(args, "path"))
|
||||||
|
if err != nil {
|
||||||
|
return readMediaErrorResult(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := client.ReadRaw(ctx, resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
return readMediaErrorResult(err.Error()), nil
|
||||||
|
}
|
||||||
|
defer func() { _ = reader.Close() }()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(io.LimitReader(reader, p.maxBytes+1))
|
||||||
|
if err != nil {
|
||||||
|
return readMediaErrorResult("read_media failed to load image: " + err.Error()), nil
|
||||||
|
}
|
||||||
|
if int64(len(data)) > p.maxBytes {
|
||||||
|
return readMediaErrorResult(fmt.Sprintf("read_media failed to load image: file exceeds %d bytes", p.maxBytes)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType, err := detectReadMediaMime(data)
|
||||||
|
if err != nil {
|
||||||
|
return readMediaErrorResult(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
|
return ReadMediaToolOutput{
|
||||||
|
Public: ReadMediaToolResult{
|
||||||
|
OK: true,
|
||||||
|
Path: resolvedPath,
|
||||||
|
Mime: mimeType,
|
||||||
|
Size: len(data),
|
||||||
|
},
|
||||||
|
ImageBase64: encoded,
|
||||||
|
ImageMediaType: mimeType,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReadMediaProvider) getClient(ctx context.Context, botID string) (*bridge.Client, error) {
|
||||||
|
botID = strings.TrimSpace(botID)
|
||||||
|
if botID == "" {
|
||||||
|
return nil, errors.New("bot_id is required")
|
||||||
|
}
|
||||||
|
client, err := p.clients.MCPClient(ctx, botID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("container not reachable: %w", err)
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReadMediaProvider) resolveImagePath(raw string) (string, error) {
|
||||||
|
trimmed := strings.TrimSpace(raw)
|
||||||
|
if trimmed == "" {
|
||||||
|
return "", errors.New("path is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
root := p.rootDir
|
||||||
|
if root == "" {
|
||||||
|
root = defaultReadMediaRoot
|
||||||
|
}
|
||||||
|
root = path.Clean(root)
|
||||||
|
|
||||||
|
resolved := trimmed
|
||||||
|
if !strings.HasPrefix(resolved, "/") {
|
||||||
|
resolved = path.Join(root, resolved)
|
||||||
|
}
|
||||||
|
resolved = path.Clean(resolved)
|
||||||
|
|
||||||
|
if resolved == root || !strings.HasPrefix(resolved, root+"/") {
|
||||||
|
return "", fmt.Errorf("path must be under %s", root)
|
||||||
|
}
|
||||||
|
return resolved, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMediaErrorResult(message string) ReadMediaToolOutput {
|
||||||
|
msg := strings.TrimSpace(message)
|
||||||
|
if msg == "" {
|
||||||
|
msg = "read_media failed"
|
||||||
|
}
|
||||||
|
return ReadMediaToolOutput{
|
||||||
|
Public: ReadMediaToolResult{
|
||||||
|
OK: false,
|
||||||
|
Error: msg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectReadMediaMime(data []byte) (string, error) {
|
||||||
|
sniffedMime := ""
|
||||||
|
if len(data) > 0 {
|
||||||
|
sniffedMime = strings.ToLower(strings.TrimSpace(http.DetectContentType(data)))
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case sniffedMime == "":
|
||||||
|
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
|
||||||
|
case isSupportedReadMediaMime(sniffedMime):
|
||||||
|
return sniffedMime, nil
|
||||||
|
default:
|
||||||
|
return "", errors.New("read_media only supports PNG, JPEG, GIF, or WebP image bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSupportedReadMediaMime(mimeType string) bool {
|
||||||
|
_, ok := readMediaSupportedMimeTypes[strings.ToLower(strings.TrimSpace(mimeType))]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
@@ -0,0 +1,306 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
|
||||||
|
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||||
|
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// readMediaTestBufSize is the buffer size (1 MiB) passed to bufconn.Listen for
// the in-memory gRPC listener used by these tests.
const readMediaTestBufSize = 1 << 20
|
||||||
|
|
||||||
|
// readMediaTestContainerService is a test double for the container gRPC
// service that serves file contents from memory.
type readMediaTestContainerService struct {
	// Embedded so that the RPCs this fake does not override are still provided.
	pb.UnimplementedContainerServiceServer
	// files maps a path to the bytes returned for it.
	files map[string][]byte
}
|
||||||
|
|
||||||
|
func (s *readMediaTestContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
||||||
|
data, ok := s.files[req.GetPath()]
|
||||||
|
if !ok {
|
||||||
|
return status.Error(codes.NotFound, "not found")
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return stream.Send(&pb.DataChunk{Data: data})
|
||||||
|
}
|
||||||
|
|
||||||
|
// readMediaStaticProvider is a test provider that wraps a single, fixed
// bridge client.
type readMediaStaticProvider struct {
	// client is the bridge client held by this provider.
	client *bridge.Client
}
|
||||||
|
|
||||||
|
func (p *readMediaStaticProvider) MCPClient(_ context.Context, _ string) (*bridge.Client, error) {
|
||||||
|
return p.client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReadMediaBridgeProvider(t *testing.T, files map[string][]byte) bridge.Provider {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
lis := bufconn.Listen(readMediaTestBufSize)
|
||||||
|
srv := grpc.NewServer()
|
||||||
|
pb.RegisterContainerServiceServer(srv, &readMediaTestContainerService{files: files})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
_ = srv.Serve(lis)
|
||||||
|
}()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
srv.Stop()
|
||||||
|
<-done
|
||||||
|
})
|
||||||
|
|
||||||
|
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
|
||||||
|
return lis.DialContext(ctx)
|
||||||
|
}
|
||||||
|
conn, err := grpc.NewClient(
|
||||||
|
"passthrough://bufnet",
|
||||||
|
grpc.WithContextDialer(dialer),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("grpc.NewClient: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = conn.Close() })
|
||||||
|
|
||||||
|
return &readMediaStaticProvider{client: bridge.NewClientFromConn(conn)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findToolByName(tools []sdk.Tool, name string) (sdk.Tool, bool) {
|
||||||
|
for _, tool := range tools {
|
||||||
|
if tool.Name == name {
|
||||||
|
return tool, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sdk.Tool{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadMediaProviderToolsOnlyWhenImageInputIsSupported(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
|
||||||
|
|
||||||
|
toolsWithoutImage, err := provider.Tools(context.Background(), SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
SupportsImageInput: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tools without image input returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(toolsWithoutImage) != 0 {
|
||||||
|
t.Fatalf("expected no tools without image input support, got %d", len(toolsWithoutImage))
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsWithImage, err := provider.Tools(context.Background(), SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
SupportsImageInput: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tools with image input returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tool, ok := findToolByName(toolsWithImage, toolReadMedia)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected %q tool to be exposed", toolReadMedia)
|
||||||
|
}
|
||||||
|
if tool.Execute == nil {
|
||||||
|
t.Fatal("expected read_media tool to be executable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadMediaProviderExecuteReadsImageUnderData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pngBytes := []byte("\x89PNG\r\n\x1a\npayload")
|
||||||
|
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"/data/images/demo.png": pngBytes,
|
||||||
|
}), "/data")
|
||||||
|
|
||||||
|
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
SupportsImageInput: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tools returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tool, ok := findToolByName(tools, toolReadMedia)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected %q tool", toolReadMedia)
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
ToolCallID: "call-1",
|
||||||
|
ToolName: toolReadMedia,
|
||||||
|
}, map[string]any{"path": "images/demo.png"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, ok := output.(readMediaToolOutput)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||||
|
}
|
||||||
|
if !result.Public.OK {
|
||||||
|
t.Fatalf("expected success result, got %+v", result.Public)
|
||||||
|
}
|
||||||
|
if result.Public.Path != "/data/images/demo.png" {
|
||||||
|
t.Fatalf("unexpected path: %q", result.Public.Path)
|
||||||
|
}
|
||||||
|
if result.Public.Mime != "image/png" {
|
||||||
|
t.Fatalf("unexpected mime: %q", result.Public.Mime)
|
||||||
|
}
|
||||||
|
if result.Public.Size != len(pngBytes) {
|
||||||
|
t.Fatalf("unexpected size: %d", result.Public.Size)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes)
|
||||||
|
if result.ImageBase64 != expectedBase64 {
|
||||||
|
t.Fatalf("unexpected image payload: %q", result.ImageBase64)
|
||||||
|
}
|
||||||
|
if result.ImageMediaType != "image/png" {
|
||||||
|
t.Fatalf("unexpected image media type: %q", result.ImageMediaType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadMediaProviderExecuteRejectsPathOutsideData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, nil), "/data")
|
||||||
|
|
||||||
|
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
SupportsImageInput: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tools returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tool, ok := findToolByName(tools, toolReadMedia)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected %q tool", toolReadMedia)
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
ToolCallID: "call-2",
|
||||||
|
ToolName: toolReadMedia,
|
||||||
|
}, map[string]any{"path": "/tmp/demo.png"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, ok := output.(readMediaToolOutput)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||||
|
}
|
||||||
|
if result.Public.OK {
|
||||||
|
t.Fatalf("expected error result, got %+v", result.Public)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.Public.Error, "path must be under /data") {
|
||||||
|
t.Fatalf("unexpected error: %q", result.Public.Error)
|
||||||
|
}
|
||||||
|
if result.ImageBase64 != "" {
|
||||||
|
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadMediaProviderExecuteRejectsExtensionOnlySVG(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"/data/images/demo.svg": []byte(`<svg xmlns="http://www.w3.org/2000/svg"></svg>`),
|
||||||
|
}), "/data")
|
||||||
|
|
||||||
|
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
SupportsImageInput: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tools returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tool, ok := findToolByName(tools, toolReadMedia)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected %q tool", toolReadMedia)
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
ToolCallID: "call-3",
|
||||||
|
ToolName: toolReadMedia,
|
||||||
|
}, map[string]any{"path": "images/demo.svg"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, ok := output.(readMediaToolOutput)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||||
|
}
|
||||||
|
if result.Public.OK {
|
||||||
|
t.Fatalf("expected error result, got %+v", result.Public)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") {
|
||||||
|
t.Fatalf("unexpected error: %q", result.Public.Error)
|
||||||
|
}
|
||||||
|
if result.ImageBase64 != "" {
|
||||||
|
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadMediaProviderExecuteRejectsCorruptedRasterBytes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
provider := NewReadMediaProvider(nil, newReadMediaBridgeProvider(t, map[string][]byte{
|
||||||
|
"/data/images/demo.png": []byte("definitely not a png"),
|
||||||
|
}), "/data")
|
||||||
|
|
||||||
|
tools, err := provider.Tools(context.Background(), SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
SupportsImageInput: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tools returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tool, ok := findToolByName(tools, toolReadMedia)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected %q tool", toolReadMedia)
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := tool.Execute(&sdk.ToolExecContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
ToolCallID: "call-4",
|
||||||
|
ToolName: toolReadMedia,
|
||||||
|
}, map[string]any{"path": "images/demo.png"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, ok := output.(readMediaToolOutput)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected readMediaToolOutput, got %T", output)
|
||||||
|
}
|
||||||
|
if result.Public.OK {
|
||||||
|
t.Fatalf("expected error result, got %+v", result.Public)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.Public.Error, "PNG, JPEG, GIF, or WebP") {
|
||||||
|
t.Fatalf("unexpected error: %q", result.Public.Error)
|
||||||
|
}
|
||||||
|
if result.ImageBase64 != "" {
|
||||||
|
t.Fatalf("expected no injected image for error result, got %q", result.ImageBase64)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,13 +12,14 @@ import (
|
|||||||
|
|
||||||
// SessionContext carries request-scoped identity for tool execution.
|
// SessionContext carries request-scoped identity for tool execution.
|
||||||
type SessionContext struct {
|
type SessionContext struct {
|
||||||
BotID string
|
BotID string
|
||||||
ChatID string
|
ChatID string
|
||||||
ChannelIdentityID string
|
ChannelIdentityID string
|
||||||
SessionToken string //nolint:gosec // carries session credential material at runtime
|
SessionToken string //nolint:gosec // carries session credential material at runtime
|
||||||
CurrentPlatform string
|
CurrentPlatform string
|
||||||
ReplyTarget string
|
ReplyTarget string
|
||||||
IsSubagent bool
|
SupportsImageInput bool
|
||||||
|
IsSubagent bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolProvider supplies a set of tools for the agent.
|
// ToolProvider supplies a set of tools for the agent.
|
||||||
|
|||||||
+15
-14
@@ -54,20 +54,21 @@ type LoopDetectionConfig struct {
|
|||||||
|
|
||||||
// RunConfig holds everything needed for a single agent invocation.
|
// RunConfig holds everything needed for a single agent invocation.
|
||||||
type RunConfig struct {
|
type RunConfig struct {
|
||||||
Model *sdk.Model
|
Model *sdk.Model
|
||||||
ReasoningEffort string
|
ReasoningEffort string
|
||||||
Messages []sdk.Message
|
Messages []sdk.Message
|
||||||
Query string
|
Query string
|
||||||
System string
|
System string
|
||||||
Tools []sdk.Tool
|
Tools []sdk.Tool
|
||||||
Channels []string
|
SupportsImageInput bool
|
||||||
CurrentChannel string
|
Channels []string
|
||||||
Identity SessionContext
|
CurrentChannel string
|
||||||
Skills []SkillEntry
|
Identity SessionContext
|
||||||
EnabledSkillNames []string
|
Skills []SkillEntry
|
||||||
Inbox []InboxItem
|
EnabledSkillNames []string
|
||||||
LoopDetection LoopDetectionConfig
|
Inbox []InboxItem
|
||||||
ActiveContextTime int
|
LoopDetection LoopDetectionConfig
|
||||||
|
ActiveContextTime int
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateResult holds the result of a non-streaming agent invocation.
|
// GenerateResult holds the result of a non-streaming agent invocation.
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package flow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
agentpkg "github.com/memohai/memoh/internal/agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrepareRunConfigIncludesReadMediaWhenImageInputIsSupported(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
resolver := &Resolver{}
|
||||||
|
cfg := agentpkg.RunConfig{
|
||||||
|
Query: "describe this image",
|
||||||
|
SupportsImageInput: true,
|
||||||
|
Identity: agentpkg.SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
||||||
|
if !strings.Contains(prepared.System, "`read_media`") {
|
||||||
|
t.Fatalf("expected system prompt to include read_media tool, got:\n%s", prepared.System)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareRunConfigOmitsReadMediaWhenImageInputIsUnsupported(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
resolver := &Resolver{}
|
||||||
|
cfg := agentpkg.RunConfig{
|
||||||
|
Query: "describe this image",
|
||||||
|
SupportsImageInput: false,
|
||||||
|
Identity: agentpkg.SessionContext{
|
||||||
|
BotID: "bot-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
prepared := resolver.prepareRunConfig(context.Background(), cfg)
|
||||||
|
if strings.Contains(prepared.System, "`read_media`") {
|
||||||
|
t.Fatalf("expected system prompt to omit read_media tool, got:\n%s", prepared.System)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -309,12 +309,13 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
|||||||
sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
|
sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
|
||||||
|
|
||||||
runCfg := agentpkg.RunConfig{
|
runCfg := agentpkg.RunConfig{
|
||||||
Model: sdkModel,
|
Model: sdkModel,
|
||||||
ReasoningEffort: reasoningEffort,
|
ReasoningEffort: reasoningEffort,
|
||||||
Messages: sdkMessages,
|
Messages: sdkMessages,
|
||||||
Query: headerifiedQuery,
|
Query: headerifiedQuery,
|
||||||
Channels: nonNilStrings(req.Channels),
|
SupportsImageInput: chatModel.HasInputModality(models.ModelInputImage),
|
||||||
CurrentChannel: req.CurrentChannel,
|
Channels: nonNilStrings(req.Channels),
|
||||||
|
CurrentChannel: req.CurrentChannel,
|
||||||
Identity: agentpkg.SessionContext{
|
Identity: agentpkg.SessionContext{
|
||||||
BotID: req.BotID,
|
BotID: req.BotID,
|
||||||
ChatID: req.ChatID,
|
ChatID: req.ChatID,
|
||||||
@@ -368,10 +369,7 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv
|
|||||||
|
|
||||||
// prepareRunConfig generates the system prompt and appends the user message.
|
// prepareRunConfig generates the system prompt and appends the user message.
|
||||||
func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig) agentpkg.RunConfig {
|
func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig) agentpkg.RunConfig {
|
||||||
supportsImageInput := false
|
supportsImageInput := cfg.SupportsImageInput
|
||||||
for _, m := range cfg.Identity.CurrentPlatform {
|
|
||||||
_ = m
|
|
||||||
}
|
|
||||||
// Build system prompt
|
// Build system prompt
|
||||||
var files []agentpkg.SystemFile
|
var files []agentpkg.SystemFile
|
||||||
if r.agent != nil {
|
if r.agent != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user