From 0a2a17ecc8dfa5ba12f9af03657e53521e577f5a Mon Sep 17 00:00:00 2001 From: "Ringo.Typowriter" Date: Wed, 4 Mar 2026 11:24:01 +0800 Subject: [PATCH] feat(agent): add readMedia tool for model to view the image (#165) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(agent): add readMedia tool for loading local images into model context * feat(channel/inbound): include container attachment refs in inbound query * fix(agent): preserve ImagePart literal typing in buildNativeImageParts * chore: rename tool --------- Co-authored-by: 晨苒 <16112591+chen-ran@users.noreply.github.com> --- internal/channel/inbound/channel.go | 47 ++++++- internal/channel/inbound/channel_test.go | 27 +++- packages/agent/src/agent.ts | 56 ++++---- packages/agent/src/prompts/system.ts | 18 ++- packages/agent/src/utils/attachments.ts | 6 +- .../src/utils/read-media-injector.test.ts | 109 +++++++++++++++ .../agent/src/utils/read-media-injector.ts | 126 ++++++++++++++++++ 7 files changed, 350 insertions(+), 39 deletions(-) create mode 100644 packages/agent/src/utils/read-media-injector.test.ts create mode 100644 packages/agent/src/utils/read-media-injector.ts diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index b5afec46..a9e98e91 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -142,7 +142,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel if sender == nil { return fmt.Errorf("reply sender not configured") } - text := buildInboundQuery(msg.Message) + text := buildInboundQuery(msg.Message, nil) if p.logger != nil { p.logger.Debug("inbound handle start", slog.String("channel", msg.Channel.String()), @@ -153,7 +153,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel slog.String("conversation_id", strings.TrimSpace(msg.Conversation.ID)), ) } - if strings.TrimSpace(text) == "" && len(msg.Message.Attachments) == 0 { + if strings.TrimSpace(msg.Message.PlainText()) == "" && len(msg.Message.Attachments) == 0 { if p.logger != nil { p.logger.Debug("inbound dropped empty", slog.String("channel", msg.Channel.String())) } @@ -185,6 +185,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel identity := state.Identity resolvedAttachments := p.ingestInboundAttachments(ctx, cfg, msg, strings.TrimSpace(identity.BotID), msg.Message.Attachments) attachments := mapChannelToChatAttachments(resolvedAttachments) + text = buildInboundQuery(msg.Message, attachments) // Resolve or create the route via channel_routes. if p.routeResolver == nil { @@ -1052,7 +1053,7 @@ func mapStreamChunkToChannelEvents(chunk conversation.StreamChunk) ([]channel.St } } -func buildInboundQuery(message channel.Message) string { +func buildInboundQuery(message channel.Message, attachments []conversation.ChatAttachment) string { text := strings.TrimSpace(message.PlainText()) if text != "" { return text @@ -1061,10 +1062,46 @@ func buildInboundQuery(message channel.Message) string { return "" } count := len(message.Attachments) + fallback := fmt.Sprintf("[User sent %d attachments]", count) if count == 1 { - return "[User sent 1 attachment]" + fallback = "[User sent 1 attachment]" } - return fmt.Sprintf("[User sent %d attachments]", count) + refs := collectContainerAttachmentRefs(attachments) + if len(refs) == 0 { + return fallback + } + var sb strings.Builder + sb.WriteString(fallback) + sb.WriteString("\n[Attachment refs: container paths]\n") + for _, ref := range refs { + sb.WriteString("- ") + sb.WriteString(ref) + sb.WriteByte('\n') + } + return strings.TrimSpace(sb.String()) +} + +func collectContainerAttachmentRefs(attachments []conversation.ChatAttachment) []string { + if len(attachments) == 0 { + return nil + } + seen := make(map[string]struct{}, len(attachments)) + refs := make([]string, 0, len(attachments)) + for _, att := range attachments { + ref := strings.TrimSpace(att.Path) + if ref == "" { + continue + } + if _, exists := seen[ref]; exists { + continue + } + seen[ref] = struct{}{} + refs = append(refs, ref) + } + if len(refs) == 0 { + return nil + } + return refs } func normalizeContentPartType(raw string) channel.MessagePartType { diff --git a/internal/channel/inbound/channel_test.go b/internal/channel/inbound/channel_test.go index 9ec09a65..75ccc0fc 100644 --- a/internal/channel/inbound/channel_test.go +++ b/internal/channel/inbound/channel_test.go @@ -387,7 +387,7 @@ func TestBuildInboundQueryAttachmentFallback(t *testing.T) { {Type: channel.AttachmentImage}, }, } - if got := buildInboundQuery(one); got != "[User sent 1 attachment]" { + if got := buildInboundQuery(one, nil); got != "[User sent 1 attachment]" { t.Fatalf("unexpected single attachment fallback: %q", got) } @@ -397,11 +397,34 @@ func TestBuildInboundQueryAttachmentFallback(t *testing.T) { {Type: channel.AttachmentImage}, }, } - if got := buildInboundQuery(two); got != "[User sent 2 attachments]" { + if got := buildInboundQuery(two, nil); got != "[User sent 2 attachments]" { t.Fatalf("unexpected multiple attachment fallback: %q", got) } } +func TestBuildInboundQueryAttachmentFallbackWithContainerRefs(t *testing.T) { + t.Parallel() + + msg := channel.Message{ + Attachments: []channel.Attachment{ + {Type: channel.AttachmentImage}, + {Type: channel.AttachmentImage}, + }, + } + atts := []conversation.ChatAttachment{ + {Path: "/data/media/ab/first.png"}, + {Path: "/data/media/cd/second.png"}, + {Path: "/data/media/ab/first.png"}, + } + want := "[User sent 2 attachments]\n" + + "[Attachment refs: container paths]\n" + + "- /data/media/ab/first.png\n" + + "- /data/media/cd/second.png" + if got := buildInboundQuery(msg, atts); got != want { + t.Fatalf("unexpected attachment refs fallback:\nwant:\n%s\n\ngot:\n%s", want, got) + } +} + func TestChannelInboundProcessorAttachmentOnlyUsesFallbackQuery(t *testing.T) { channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-fallback"}} memberSvc := &fakeMemberService{isMember: true} diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index 3539b978..ba767586 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -1,12 +1,13 @@ import { generateText, - ImagePart, + type ImagePart, LanguageModelUsage, ModelMessage, stepCountIs, streamText, ToolSet, UserModelMessage, + type PrepareStepFunction, } from 'ai' import { AgentInput, @@ -36,6 +37,7 @@ import { buildIdentityHeaders } from './utils/headers' import { createFS } from './utils' import { createTextLoopGuard, createTextLoopProbeBuffer } from './sential' import { createToolLoopGuardedTools } from './tool-loop' +import { createPrepareStepWithReadMedia } from './utils/read-media-injector' const ANTHROPIC_BUDGET: Record = { low: 5000, medium: 16000, high: 50000 } const GOOGLE_BUDGET: Record = { low: 5000, medium: 16000, high: 50000 } @@ -84,7 +86,7 @@ export const buildNativeImageParts = (attachments: GatewayInputAttachment[]): Im (attachment.transport === 'inline_data_url' || attachment.transport === 'public_url') && Boolean(attachment.payload), ) - .map((attachment) => ({ type: 'image', image: attachment.payload } as ImagePart)) + .map((attachment): ImagePart => ({ type: 'image', image: attachment.payload })) } export const createAgent = ( @@ -110,6 +112,7 @@ export const createAgent = ( fetch: AuthFetcher, ) => { const model = createModel(modelConfig) + const supportsImageInput = hasInputModality(modelConfig, ModelInput.Image) // eslint-disable-next-line @typescript-eslint/no-explicit-any const providerOptions = buildProviderOptions(modelConfig) as any const loopDetectionEnabled = loopDetection?.enabled === true @@ -161,6 +164,7 @@ export const createAgent = ( skills, enabledSkills, inbox, + supportsImageInput, files, }) } @@ -205,8 +209,7 @@ export const createAgent = ( } const generateUserPrompt = (input: AgentInput) => { - const supportsImage = hasInputModality(modelConfig, ModelInput.Image) - const imageParts = supportsImage ? buildNativeImageParts(input.attachments) : [] + const imageParts = supportsImageInput ? buildNativeImageParts(input.attachments) : [] const userMessage: UserModelMessage = { role: 'user', @@ -250,13 +253,20 @@ export const createAgent = ( const runTextGeneration = async ({ messages, systemPrompt, - prepareStep, + basePrepareStep, }: { messages: ModelMessage[] systemPrompt: string - prepareStep?: () => { system: string } + basePrepareStep?: PrepareStepFunction }) => { - const { tools, close } = await getAgentTools() + const { tools: baseTools, close } = await getAgentTools() + const { prepareStep, tools: readMediaTools } = createPrepareStepWithReadMedia({ + modelConfig, + fs, + systemPrompt, + basePrepareStep, + }) + const tools = { ...baseTools, ...readMediaTools } let shouldAbortForToolLoop = false const guardedTools = buildGuardedTools(tools, () => { shouldAbortForToolLoop = true @@ -270,7 +280,7 @@ export const createAgent = ( system: systemPrompt, ...(providerOptions && { providerOptions }), stopWhen: stepCountIs(Infinity), - ...(prepareStep && { prepareStep }), + prepareStep, ...(loopDetectionEnabled && { onStepFinish: ({ text }: { text: string }) => { if (shouldAbortForToolLoop) { @@ -306,11 +316,7 @@ export const createAgent = ( const { response, reasoning, text, usage, steps } = await runTextGeneration({ messages, systemPrompt, - prepareStep: () => { - return { - system: systemPrompt, - } - }, + basePrepareStep: () => ({ system: systemPrompt }), }) const stepUsages = buildStepUsages(steps) const { cleanedText, attachments: textAttachments } = @@ -352,15 +358,12 @@ export const createAgent = ( description: params.description, }) } + const systemPrompt = generateSubagentSystemPrompt() const messages = [...params.messages, userPrompt] const { response, reasoning, text, usage, steps } = await runTextGeneration({ messages, - systemPrompt: generateSubagentSystemPrompt(), - prepareStep: () => { - return { - system: generateSubagentSystemPrompt(), - } - }, + systemPrompt, + basePrepareStep: () => ({ system: generateSubagentSystemPrompt() }), }) const stepUsages = buildStepUsages(steps) return { @@ -497,7 +500,14 @@ export const createAgent = ( usages: [], } const toolLoopAbortCallIds = new Set() - const { tools, close } = await getAgentTools() + const { tools: baseTools, close } = await getAgentTools() + const { prepareStep, tools: readMediaTools } = createPrepareStepWithReadMedia({ + modelConfig, + fs, + systemPrompt, + basePrepareStep: () => ({ system: systemPrompt }), + }) + const tools = { ...baseTools, ...readMediaTools } // Stream path needs deferred abort to keep tool_call_start/tool_call_end event pairing. const guardedTools = buildGuardedTools(tools, (toolCallId) => { toolLoopAbortCallIds.add(toolCallId) @@ -517,11 +527,7 @@ export const createAgent = ( system: systemPrompt, ...(providerOptions && { providerOptions }), stopWhen: stepCountIs(Infinity), - prepareStep: () => { - return { - system: systemPrompt, - } - }, + prepareStep, tools: guardedTools, onFinish: async ({ usage, reasoning, response, steps }) => { await closeTools() diff --git a/packages/agent/src/prompts/system.ts b/packages/agent/src/prompts/system.ts index 6d86a61c..d123e06c 100644 --- a/packages/agent/src/prompts/system.ts +++ b/packages/agent/src/prompts/system.ts @@ -14,6 +14,7 @@ export interface SystemParams { files: SystemFile[] attachments?: string[] inbox?: InboxItem[] + supportsImageInput?: boolean } export const skillPrompt = (skill: AgentSkill) => { @@ -65,6 +66,7 @@ export const system = ({ enabledSkills, files, inbox = [], + supportsImageInput = true, }: SystemParams) => { const home = '/data' // ── Static section (stable prefix for LLM prompt caching) ────────── @@ -80,6 +82,16 @@ export const system = ({ 'time-now': date.toISOString(), } + const basicTools = [ + `- ${quote('read')}: read file content`, + supportsImageInput ? `- ${quote('read_media')}: view the media` : null, + `- ${quote('write')}: write file content`, + `- ${quote('list')}: list directory entries`, + `- ${quote('edit')}: replace exact text in a file`, + `- ${quote('exec')}: execute command`, + ] + .filter((line): line is string => Boolean(line)) + .join('\n') console.log('inbox', inbox) return ` @@ -93,11 +105,7 @@ You are just woke up. ${quote(home)} is your HOME — you can read and write files there freely. ## Basic Tools -- ${quote('read')}: read file content -- ${quote('write')}: write file content -- ${quote('list')}: list directory entries -- ${quote('edit')}: replace exact text in a file -- ${quote('exec')}: execute command +${basicTools} ## Safety - Keep private data private diff --git a/packages/agent/src/utils/attachments.ts b/packages/agent/src/utils/attachments.ts index 69876f4a..87e767f2 100644 --- a/packages/agent/src/utils/attachments.ts +++ b/packages/agent/src/utils/attachments.ts @@ -1,5 +1,8 @@ import type { AssistantModelMessage, ModelMessage, TextPart } from 'ai' -import type { AgentAttachment, ContainerFileAttachment } from '../types/attachment' +import type { + AgentAttachment, + ContainerFileAttachment, +} from '../types/attachment' const ATTACHMENTS_START = '' const ATTACHMENTS_END = '' @@ -196,4 +199,3 @@ export class AttachmentsStreamExtractor { return { visibleText: out, attachments: [] } } } - diff --git a/packages/agent/src/utils/read-media-injector.test.ts b/packages/agent/src/utils/read-media-injector.test.ts new file mode 100644 index 00000000..f053565d --- /dev/null +++ b/packages/agent/src/utils/read-media-injector.test.ts @@ -0,0 +1,109 @@ +import { describe, expect, it } from 'vitest' +import { createPrepareStepWithReadMedia } from './read-media-injector' +import { ClientType, ModelInput, ModelConfig } from '../types/model' + +const baseModelConfig: ModelConfig = { + apiKey: 'test', + baseUrl: 'http://example.com', + modelId: 'model', + clientType: ClientType.OpenAIResponses, + input: [ModelInput.Image], +} + +describe('read_media runtime', () => { + it('caches image and injects it into messages', async () => { + const fs = { + download: async () => + new Response(new Uint8Array([1, 2, 3]), { + headers: { 'content-type': 'image/png' }, + }), + } + const { prepareStep, tools } = createPrepareStepWithReadMedia({ + modelConfig: baseModelConfig, + fs, + systemPrompt: 'sys', + }) + const readMedia = tools.read_media + const output = await readMedia.execute( + { path: '/data/media/a.png' }, + { toolCallId: 'call-1' }, + ) + expect((output as { ok?: boolean }).ok).toBe(true) + const prepared = await prepareStep({ + messages: [{ role: 'user', content: 'hi' }], + steps: [], + stepNumber: 0, + model: {} as never, + experimental_context: undefined, + }) + const injected = prepared.messages?.[1] + expect(injected?.role).toBe('user') + const content = injected?.content as Array<{ type?: string; image?: string }> + expect(content?.[0]?.type).toBe('image') + expect(content?.[0]?.image?.startsWith('data:image/png;base64,')).toBe(true) + }) + + it('returns error result on download failure', async () => { + const fs = { + download: async () => { + throw new Error('boom') + }, + } + const { prepareStep, tools } = createPrepareStepWithReadMedia({ + modelConfig: baseModelConfig, + fs, + systemPrompt: 'sys', + }) + const readMedia = tools.read_media + const output = await readMedia.execute( + { path: '/data/media/a.png' }, + { toolCallId: 'call-2' }, + ) + expect((output as { isError?: boolean }).isError).toBe(true) + const prepared = await prepareStep({ + messages: [{ role: 'user', content: 'hi' }], + steps: [], + stepNumber: 0, + model: {} as never, + experimental_context: undefined, + }) + expect(prepared.messages).toBeUndefined() + }) + + it('preserves tool call order when downloads finish out of order', async () => { + const fs = { + download: async (path: string) => { + const delay = path.includes('a.png') ? 20 : 0 + await new Promise((resolve) => setTimeout(resolve, delay)) + const payload = path.includes('a.png') ? new Uint8Array([1]) : new Uint8Array([2]) + return new Response(payload, { headers: { 'content-type': 'image/png' } }) + }, + } + const { prepareStep, tools } = createPrepareStepWithReadMedia({ + modelConfig: baseModelConfig, + fs, + systemPrompt: 'sys', + }) + const readMedia = tools.read_media + const first = readMedia.execute( + { path: '/data/media/a.png' }, + { toolCallId: 'call-1' }, + ) + const second = readMedia.execute( + { path: '/data/media/b.png' }, + { toolCallId: 'call-2' }, + ) + await Promise.all([first, second]) + const prepared = await prepareStep({ + messages: [{ role: 'user', content: 'hi' }], + steps: [], + stepNumber: 0, + model: {} as never, + experimental_context: undefined, + }) + const injected = prepared.messages?.[1] + const content = injected?.content as Array<{ type?: string; image?: string }> + expect(content?.[0]?.image?.includes('AQ==')).toBe(true) + expect(content?.[1]?.image?.includes('Ag==')).toBe(true) + }) +}) diff --git a/packages/agent/src/utils/read-media-injector.ts b/packages/agent/src/utils/read-media-injector.ts new file mode 100644 index 00000000..da227712 --- /dev/null +++ b/packages/agent/src/utils/read-media-injector.ts @@ -0,0 +1,126 @@ +import { ImagePart, PrepareStepFunction, ToolSet, UserModelMessage, tool } from 'ai' +import { z } from 'zod' +import { ModelConfig, ModelInput, hasInputModality } from '../types/model' + +const READ_MEDIA_TOOL_NAME = 'read_media' + +const isImageMime = (mime: string): boolean => { + return mime.trim().toLowerCase().startsWith('image/') +} + +const toImagePart = (payload: string): ImagePart => { + return { type: 'image', image: payload } as ImagePart +} + +type ReadMediaFS = { + download: (path: string) => Promise +} + +const buildReadMediaToolError = (message: string) => ({ + isError: true, + content: [{ type: 'text', text: message }], + structuredContent: { ok: false, error: message }, +}) + +const loadImageAsDataUrl = async ( + fs: ReadMediaFS, + path: string, +): Promise<{ ok: true; dataUrl: string; mime: string } | { ok: false; error: string }> => { + try { + const response = await fs.download(path) + const arrayBuffer = await response.arrayBuffer() + const base64 = Buffer.from(arrayBuffer).toString('base64') + const header = response.headers.get('content-type') ?? '' + const mime = header.split(';')[0]?.trim() ?? '' + if (!mime || !isImageMime(mime)) { + return { ok: false, error: 'read_media only supports image files' } + } + return { ok: true, dataUrl: `data:${mime};base64,${base64}`, mime } + } catch (error) { + console.error(error) + const message = error instanceof Error ? error.message : String(error) + return { ok: false, error: `read_media failed to load image: ${message}` } + } +} + +export const createPrepareStepWithReadMedia = (params: { + modelConfig: ModelConfig + fs: ReadMediaFS + systemPrompt: string + basePrepareStep?: PrepareStepFunction +}) => { + const supportsImage = hasInputModality(params.modelConfig, ModelInput.Image) + if (!supportsImage) { + const prepareStep = async (options: Parameters[0]) => { + return (params.basePrepareStep ? await params.basePrepareStep(options) : {}) ?? {} + } + return { prepareStep, tools: {} as ToolSet } + } + const cachedImages = new Map() + const callOrder: string[] = [] + + const readMediaTool = tool({ + description: 'Load an image file into context so the model can view it.', + inputSchema: z.object({ + path: z.string().describe('Image file path inside the container.'), + }), + execute: async ({ path }, options) => { + const trimmedPath = typeof path === 'string' ? path.trim() : '' + if (!trimmedPath) { + return buildReadMediaToolError('path is required') + } + const toolCallId = typeof options?.toolCallId === 'string' ? options.toolCallId : '' + if (!toolCallId) { + return buildReadMediaToolError('read_media missing toolCallId') + } + if (!cachedImages.has(toolCallId)) { + cachedImages.set(toolCallId, null) + callOrder.push(toolCallId) + } + const loaded = await loadImageAsDataUrl(params.fs, trimmedPath) + if (!loaded.ok) { + return buildReadMediaToolError(loaded.error) + } + cachedImages.set(toolCallId, toImagePart(loaded.dataUrl) as ImagePart) + return { ok: true, path: trimmedPath, mime: loaded.mime } + }, + }) + + const prepareStep = async (options: Parameters[0]) => { + const base = (params.basePrepareStep ? await params.basePrepareStep(options) : {}) ?? {} + const baseMessages = base.messages ?? options.messages + if (cachedImages.size === 0) { + if (!base.system) { + base.system = params.systemPrompt + } + return base + } + const imageParts = callOrder + .map((toolCallId) => cachedImages.get(toolCallId)) + .filter((part): part is ImagePart => Boolean(part)) + if (imageParts.length === 0) { + if (!base.system) { + base.system = params.systemPrompt + } + return base + } + const injectedMessage: UserModelMessage = { + role: 'user', + content: imageParts, + } + const merged = { + ...base, + messages: [...baseMessages, injectedMessage], + } + if (!merged.system) { + merged.system = params.systemPrompt + } + return merged + } + + const readMediaTools: ToolSet = { + [READ_MEDIA_TOOL_NAME]: readMediaTool, + } + + return { prepareStep, tools: readMediaTools } +}