mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(agent): add readMedia tool for model to view the image (#165)
* feat(agent): add readMedia tool for loading local images into model context * feat(channel/inbound): include container attachment refs in inbound query * fix(agent): preserve ImagePart literal typing in buildNativeImageParts * chore: rename tool --------- Co-authored-by: 晨苒 <16112591+chen-ran@users.noreply.github.com>
This commit is contained in:
+31
-25
@@ -1,12 +1,13 @@
|
||||
import {
|
||||
generateText,
|
||||
ImagePart,
|
||||
type ImagePart,
|
||||
LanguageModelUsage,
|
||||
ModelMessage,
|
||||
stepCountIs,
|
||||
streamText,
|
||||
ToolSet,
|
||||
UserModelMessage,
|
||||
type PrepareStepFunction,
|
||||
} from 'ai'
|
||||
import {
|
||||
AgentInput,
|
||||
@@ -36,6 +37,7 @@ import { buildIdentityHeaders } from './utils/headers'
|
||||
import { createFS } from './utils'
|
||||
import { createTextLoopGuard, createTextLoopProbeBuffer } from './sential'
|
||||
import { createToolLoopGuardedTools } from './tool-loop'
|
||||
import { createPrepareStepWithReadMedia } from './utils/read-media-injector'
|
||||
|
||||
const ANTHROPIC_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
|
||||
const GOOGLE_BUDGET: Record<string, number> = { low: 5000, medium: 16000, high: 50000 }
|
||||
@@ -84,7 +86,7 @@ export const buildNativeImageParts = (attachments: GatewayInputAttachment[]): Im
|
||||
(attachment.transport === 'inline_data_url' || attachment.transport === 'public_url') &&
|
||||
Boolean(attachment.payload),
|
||||
)
|
||||
.map((attachment) => ({ type: 'image', image: attachment.payload } as ImagePart))
|
||||
.map((attachment): ImagePart => ({ type: 'image', image: attachment.payload }))
|
||||
}
|
||||
|
||||
export const createAgent = (
|
||||
@@ -110,6 +112,7 @@ export const createAgent = (
|
||||
fetch: AuthFetcher,
|
||||
) => {
|
||||
const model = createModel(modelConfig)
|
||||
const supportsImageInput = hasInputModality(modelConfig, ModelInput.Image)
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const providerOptions = buildProviderOptions(modelConfig) as any
|
||||
const loopDetectionEnabled = loopDetection?.enabled === true
|
||||
@@ -161,6 +164,7 @@ export const createAgent = (
|
||||
skills,
|
||||
enabledSkills,
|
||||
inbox,
|
||||
supportsImageInput,
|
||||
files,
|
||||
})
|
||||
}
|
||||
@@ -205,8 +209,7 @@ export const createAgent = (
|
||||
}
|
||||
|
||||
const generateUserPrompt = (input: AgentInput) => {
|
||||
const supportsImage = hasInputModality(modelConfig, ModelInput.Image)
|
||||
const imageParts = supportsImage ? buildNativeImageParts(input.attachments) : []
|
||||
const imageParts = supportsImageInput ? buildNativeImageParts(input.attachments) : []
|
||||
|
||||
const userMessage: UserModelMessage = {
|
||||
role: 'user',
|
||||
@@ -250,13 +253,20 @@ export const createAgent = (
|
||||
const runTextGeneration = async ({
|
||||
messages,
|
||||
systemPrompt,
|
||||
prepareStep,
|
||||
basePrepareStep,
|
||||
}: {
|
||||
messages: ModelMessage[]
|
||||
systemPrompt: string
|
||||
prepareStep?: () => { system: string }
|
||||
basePrepareStep?: PrepareStepFunction
|
||||
}) => {
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { tools: baseTools, close } = await getAgentTools()
|
||||
const { prepareStep, tools: readMediaTools } = createPrepareStepWithReadMedia({
|
||||
modelConfig,
|
||||
fs,
|
||||
systemPrompt,
|
||||
basePrepareStep,
|
||||
})
|
||||
const tools = { ...baseTools, ...readMediaTools }
|
||||
let shouldAbortForToolLoop = false
|
||||
const guardedTools = buildGuardedTools(tools, () => {
|
||||
shouldAbortForToolLoop = true
|
||||
@@ -270,7 +280,7 @@ export const createAgent = (
|
||||
system: systemPrompt,
|
||||
...(providerOptions && { providerOptions }),
|
||||
stopWhen: stepCountIs(Infinity),
|
||||
...(prepareStep && { prepareStep }),
|
||||
prepareStep,
|
||||
...(loopDetectionEnabled && {
|
||||
onStepFinish: ({ text }: { text: string }) => {
|
||||
if (shouldAbortForToolLoop) {
|
||||
@@ -306,11 +316,7 @@ export const createAgent = (
|
||||
const { response, reasoning, text, usage, steps } = await runTextGeneration({
|
||||
messages,
|
||||
systemPrompt,
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: systemPrompt,
|
||||
}
|
||||
},
|
||||
basePrepareStep: () => ({ system: systemPrompt }),
|
||||
})
|
||||
const stepUsages = buildStepUsages(steps)
|
||||
const { cleanedText, attachments: textAttachments } =
|
||||
@@ -352,15 +358,12 @@ export const createAgent = (
|
||||
description: params.description,
|
||||
})
|
||||
}
|
||||
const systemPrompt = generateSubagentSystemPrompt()
|
||||
const messages = [...params.messages, userPrompt]
|
||||
const { response, reasoning, text, usage, steps } = await runTextGeneration({
|
||||
messages,
|
||||
systemPrompt: generateSubagentSystemPrompt(),
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: generateSubagentSystemPrompt(),
|
||||
}
|
||||
},
|
||||
systemPrompt,
|
||||
basePrepareStep: () => ({ system: generateSubagentSystemPrompt() }),
|
||||
})
|
||||
const stepUsages = buildStepUsages(steps)
|
||||
return {
|
||||
@@ -497,7 +500,14 @@ export const createAgent = (
|
||||
usages: [],
|
||||
}
|
||||
const toolLoopAbortCallIds = new Set<string>()
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { tools: baseTools, close } = await getAgentTools()
|
||||
const { prepareStep, tools: readMediaTools } = createPrepareStepWithReadMedia({
|
||||
modelConfig,
|
||||
fs,
|
||||
systemPrompt,
|
||||
basePrepareStep: () => ({ system: systemPrompt }),
|
||||
})
|
||||
const tools = { ...baseTools, ...readMediaTools }
|
||||
// Stream path needs deferred abort to keep tool_call_start/tool_call_end event pairing.
|
||||
const guardedTools = buildGuardedTools(tools, (toolCallId) => {
|
||||
toolLoopAbortCallIds.add(toolCallId)
|
||||
@@ -517,11 +527,7 @@ export const createAgent = (
|
||||
system: systemPrompt,
|
||||
...(providerOptions && { providerOptions }),
|
||||
stopWhen: stepCountIs(Infinity),
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: systemPrompt,
|
||||
}
|
||||
},
|
||||
prepareStep,
|
||||
tools: guardedTools,
|
||||
onFinish: async ({ usage, reasoning, response, steps }) => {
|
||||
await closeTools()
|
||||
|
||||
@@ -14,6 +14,7 @@ export interface SystemParams {
|
||||
files: SystemFile[]
|
||||
attachments?: string[]
|
||||
inbox?: InboxItem[]
|
||||
supportsImageInput?: boolean
|
||||
}
|
||||
|
||||
export const skillPrompt = (skill: AgentSkill) => {
|
||||
@@ -65,6 +66,7 @@ export const system = ({
|
||||
enabledSkills,
|
||||
files,
|
||||
inbox = [],
|
||||
supportsImageInput = true,
|
||||
}: SystemParams) => {
|
||||
const home = '/data'
|
||||
// ── Static section (stable prefix for LLM prompt caching) ──────────
|
||||
@@ -80,6 +82,16 @@ export const system = ({
|
||||
'time-now': date.toISOString(),
|
||||
}
|
||||
|
||||
const basicTools = [
|
||||
`- ${quote('read')}: read file content`,
|
||||
supportsImageInput ? `- ${quote('read_media')}: view the media` : null,
|
||||
`- ${quote('write')}: write file content`,
|
||||
`- ${quote('list')}: list directory entries`,
|
||||
`- ${quote('edit')}: replace exact text in a file`,
|
||||
`- ${quote('exec')}: execute command`,
|
||||
]
|
||||
.filter((line): line is string => Boolean(line))
|
||||
.join('\n')
|
||||
console.log('inbox', inbox)
|
||||
|
||||
return `
|
||||
@@ -93,11 +105,7 @@ You are just woke up.
|
||||
${quote(home)} is your HOME — you can read and write files there freely.
|
||||
|
||||
## Basic Tools
|
||||
- ${quote('read')}: read file content
|
||||
- ${quote('write')}: write file content
|
||||
- ${quote('list')}: list directory entries
|
||||
- ${quote('edit')}: replace exact text in a file
|
||||
- ${quote('exec')}: execute command
|
||||
${basicTools}
|
||||
|
||||
## Safety
|
||||
- Keep private data private
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import type { AssistantModelMessage, ModelMessage, TextPart } from 'ai'
|
||||
import type { AgentAttachment, ContainerFileAttachment } from '../types/attachment'
|
||||
import type {
|
||||
AgentAttachment,
|
||||
ContainerFileAttachment,
|
||||
} from '../types/attachment'
|
||||
|
||||
const ATTACHMENTS_START = '<attachments>'
|
||||
const ATTACHMENTS_END = '</attachments>'
|
||||
@@ -196,4 +199,3 @@ export class AttachmentsStreamExtractor {
|
||||
return { visibleText: out, attachments: [] }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { createPrepareStepWithReadMedia } from './read-media-injector'
|
||||
import { ClientType, ModelInput, ModelConfig } from '../types/model'
|
||||
|
||||
const baseModelConfig: ModelConfig = {
|
||||
apiKey: 'test',
|
||||
baseUrl: 'http://example.com',
|
||||
modelId: 'model',
|
||||
clientType: ClientType.OpenAIResponses,
|
||||
input: [ModelInput.Image],
|
||||
}
|
||||
|
||||
describe('read_media runtime', () => {
|
||||
it('caches image and injects it into messages', async () => {
|
||||
const fs = {
|
||||
download: async () =>
|
||||
new Response(new Uint8Array([1, 2, 3]), {
|
||||
headers: { 'content-type': 'image/png' },
|
||||
}),
|
||||
}
|
||||
const { prepareStep, tools } = createPrepareStepWithReadMedia({
|
||||
modelConfig: baseModelConfig,
|
||||
fs,
|
||||
systemPrompt: 'sys',
|
||||
})
|
||||
const readMedia = tools.read_media
|
||||
const output = await readMedia.execute(
|
||||
{ path: '/data/media/a.png' },
|
||||
{ toolCallId: 'call-1' },
|
||||
)
|
||||
expect((output as { ok?: boolean }).ok).toBe(true)
|
||||
const prepared = await prepareStep({
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
steps: [],
|
||||
stepNumber: 0,
|
||||
model: {} as never,
|
||||
experimental_context: undefined,
|
||||
})
|
||||
const injected = prepared.messages?.[1]
|
||||
expect(injected?.role).toBe('user')
|
||||
const content = injected?.content as Array<{ type?: string; image?: string }>
|
||||
expect(content?.[0]?.type).toBe('image')
|
||||
expect(content?.[0]?.image?.startsWith('data:image/png;base64,')).toBe(true)
|
||||
})
|
||||
|
||||
it('returns error result on download failure', async () => {
|
||||
const fs = {
|
||||
download: async () => {
|
||||
throw new Error('boom')
|
||||
},
|
||||
}
|
||||
const { prepareStep, tools } = createPrepareStepWithReadMedia({
|
||||
modelConfig: baseModelConfig,
|
||||
fs,
|
||||
systemPrompt: 'sys',
|
||||
})
|
||||
const readMedia = tools.read_media
|
||||
const output = await readMedia.execute(
|
||||
{ path: '/data/media/a.png' },
|
||||
{ toolCallId: 'call-2' },
|
||||
)
|
||||
expect((output as { isError?: boolean }).isError).toBe(true)
|
||||
const prepared = await prepareStep({
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
steps: [],
|
||||
stepNumber: 0,
|
||||
model: {} as never,
|
||||
experimental_context: undefined,
|
||||
})
|
||||
expect(prepared.messages).toBeUndefined()
|
||||
})
|
||||
|
||||
it('preserves tool call order when downloads finish out of order', async () => {
|
||||
const fs = {
|
||||
download: async (path: string) => {
|
||||
const delay = path.includes('a.png') ? 20 : 0
|
||||
await new Promise((resolve) => setTimeout(resolve, delay))
|
||||
const payload = path.includes('a.png') ? new Uint8Array([1]) : new Uint8Array([2])
|
||||
return new Response(payload, { headers: { 'content-type': 'image/png' } })
|
||||
},
|
||||
}
|
||||
const { prepareStep, tools } = createPrepareStepWithReadMedia({
|
||||
modelConfig: baseModelConfig,
|
||||
fs,
|
||||
systemPrompt: 'sys',
|
||||
})
|
||||
const readMedia = tools.read_media
|
||||
const first = readMedia.execute(
|
||||
{ path: '/data/media/a.png' },
|
||||
{ toolCallId: 'call-1' },
|
||||
)
|
||||
const second = readMedia.execute(
|
||||
{ path: '/data/media/b.png' },
|
||||
{ toolCallId: 'call-2' },
|
||||
)
|
||||
await Promise.all([first, second])
|
||||
const prepared = await prepareStep({
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
steps: [],
|
||||
stepNumber: 0,
|
||||
model: {} as never,
|
||||
experimental_context: undefined,
|
||||
})
|
||||
const injected = prepared.messages?.[1]
|
||||
const content = injected?.content as Array<{ type?: string; image?: string }>
|
||||
expect(content?.[0]?.image?.includes('AQ==')).toBe(true)
|
||||
expect(content?.[1]?.image?.includes('Ag==')).toBe(true)
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,126 @@
|
||||
import { ImagePart, PrepareStepFunction, ToolSet, UserModelMessage, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
import { ModelConfig, ModelInput, hasInputModality } from '../types/model'
|
||||
|
||||
const READ_MEDIA_TOOL_NAME = 'read_media'
|
||||
|
||||
const isImageMime = (mime: string): boolean => {
|
||||
return mime.trim().toLowerCase().startsWith('image/')
|
||||
}
|
||||
|
||||
const toImagePart = (payload: string): ImagePart => {
|
||||
return { type: 'image', image: payload } as ImagePart
|
||||
}
|
||||
|
||||
type ReadMediaFS = {
|
||||
download: (path: string) => Promise<Response>
|
||||
}
|
||||
|
||||
const buildReadMediaToolError = (message: string) => ({
|
||||
isError: true,
|
||||
content: [{ type: 'text', text: message }],
|
||||
structuredContent: { ok: false, error: message },
|
||||
})
|
||||
|
||||
const loadImageAsDataUrl = async (
|
||||
fs: ReadMediaFS,
|
||||
path: string,
|
||||
): Promise<{ ok: true; dataUrl: string; mime: string } | { ok: false; error: string }> => {
|
||||
try {
|
||||
const response = await fs.download(path)
|
||||
const arrayBuffer = await response.arrayBuffer()
|
||||
const base64 = Buffer.from(arrayBuffer).toString('base64')
|
||||
const header = response.headers.get('content-type') ?? ''
|
||||
const mime = header.split(';')[0]?.trim() ?? ''
|
||||
if (!mime || !isImageMime(mime)) {
|
||||
return { ok: false, error: 'read_media only supports image files' }
|
||||
}
|
||||
return { ok: true, dataUrl: `data:${mime};base64,${base64}`, mime }
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
return { ok: false, error: `read_media failed to load image: ${message}` }
|
||||
}
|
||||
}
|
||||
|
||||
export const createPrepareStepWithReadMedia = (params: {
|
||||
modelConfig: ModelConfig
|
||||
fs: ReadMediaFS
|
||||
systemPrompt: string
|
||||
basePrepareStep?: PrepareStepFunction
|
||||
}) => {
|
||||
const supportsImage = hasInputModality(params.modelConfig, ModelInput.Image)
|
||||
if (!supportsImage) {
|
||||
const prepareStep = async (options: Parameters<PrepareStepFunction>[0]) => {
|
||||
return (params.basePrepareStep ? await params.basePrepareStep(options) : {}) ?? {}
|
||||
}
|
||||
return { prepareStep, tools: {} as ToolSet }
|
||||
}
|
||||
const cachedImages = new Map<string, ImagePart | null>()
|
||||
const callOrder: string[] = []
|
||||
|
||||
const readMediaTool = tool({
|
||||
description: 'Load an image file into context so the model can view it.',
|
||||
inputSchema: z.object({
|
||||
path: z.string().describe('Image file path inside the container.'),
|
||||
}),
|
||||
execute: async ({ path }, options) => {
|
||||
const trimmedPath = typeof path === 'string' ? path.trim() : ''
|
||||
if (!trimmedPath) {
|
||||
return buildReadMediaToolError('path is required')
|
||||
}
|
||||
const toolCallId = typeof options?.toolCallId === 'string' ? options.toolCallId : ''
|
||||
if (!toolCallId) {
|
||||
return buildReadMediaToolError('read_media missing toolCallId')
|
||||
}
|
||||
if (!cachedImages.has(toolCallId)) {
|
||||
cachedImages.set(toolCallId, null)
|
||||
callOrder.push(toolCallId)
|
||||
}
|
||||
const loaded = await loadImageAsDataUrl(params.fs, trimmedPath)
|
||||
if (!loaded.ok) {
|
||||
return buildReadMediaToolError(loaded.error)
|
||||
}
|
||||
cachedImages.set(toolCallId, toImagePart(loaded.dataUrl) as ImagePart)
|
||||
return { ok: true, path: trimmedPath, mime: loaded.mime }
|
||||
},
|
||||
})
|
||||
|
||||
const prepareStep = async (options: Parameters<PrepareStepFunction>[0]) => {
|
||||
const base = (params.basePrepareStep ? await params.basePrepareStep(options) : {}) ?? {}
|
||||
const baseMessages = base.messages ?? options.messages
|
||||
if (cachedImages.size === 0) {
|
||||
if (!base.system) {
|
||||
base.system = params.systemPrompt
|
||||
}
|
||||
return base
|
||||
}
|
||||
const imageParts = callOrder
|
||||
.map((toolCallId) => cachedImages.get(toolCallId))
|
||||
.filter((part): part is ImagePart => Boolean(part))
|
||||
if (imageParts.length === 0) {
|
||||
if (!base.system) {
|
||||
base.system = params.systemPrompt
|
||||
}
|
||||
return base
|
||||
}
|
||||
const injectedMessage: UserModelMessage = {
|
||||
role: 'user',
|
||||
content: imageParts,
|
||||
}
|
||||
const merged = {
|
||||
...base,
|
||||
messages: [...baseMessages, injectedMessage],
|
||||
}
|
||||
if (!merged.system) {
|
||||
merged.system = params.systemPrompt
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
const readMediaTools: ToolSet = {
|
||||
[READ_MEDIA_TOOL_NAME]: readMediaTool,
|
||||
}
|
||||
|
||||
return { prepareStep, tools: readMediaTools }
|
||||
}
|
||||
Reference in New Issue
Block a user