mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: add mcp support
This commit is contained in:
+32
-7
@@ -12,6 +12,7 @@ import {
|
||||
AttachmentsStreamExtractor,
|
||||
} from './utils/attachments'
|
||||
import type { ContainerFileAttachment } from './types/attachment'
|
||||
import { getMCPTools } from './tools/mcp'
|
||||
|
||||
export const createAgent = ({
|
||||
model: modelConfig,
|
||||
@@ -21,6 +22,7 @@ export const createAgent = ({
|
||||
allowedActions = allActions,
|
||||
identity,
|
||||
channels = [],
|
||||
mcpConnections = [],
|
||||
currentChannel = 'Unknown Channel',
|
||||
}: AgentParams, fetch: AuthFetcher) => {
|
||||
const model = createModel(modelConfig)
|
||||
@@ -36,12 +38,20 @@ export const createAgent = ({
|
||||
})
|
||||
}
|
||||
|
||||
const tools = getTools(allowedActions, {
|
||||
fetch,
|
||||
model: modelConfig,
|
||||
brave,
|
||||
identity,
|
||||
})
|
||||
const getAgentTools = async () => {
|
||||
const tools = getTools(allowedActions, {
|
||||
fetch,
|
||||
model: modelConfig,
|
||||
brave,
|
||||
identity,
|
||||
})
|
||||
const { tools: mcpTools, close: closeMCP } = await getMCPTools(mcpConnections)
|
||||
Object.assign(tools, mcpTools)
|
||||
return {
|
||||
tools,
|
||||
close: closeMCP,
|
||||
}
|
||||
}
|
||||
|
||||
const generateUserPrompt = (input: AgentInput) => {
|
||||
const images = input.attachments.filter(attachment => attachment.type === 'image')
|
||||
@@ -67,6 +77,7 @@ export const createAgent = ({
|
||||
const userPrompt = generateUserPrompt(input)
|
||||
const messages = [...input.messages, userPrompt]
|
||||
const systemPrompt = generateSystemPrompt()
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { response, reasoning, text, usage } = await generateText({
|
||||
model,
|
||||
messages,
|
||||
@@ -77,6 +88,9 @@ export const createAgent = ({
|
||||
system: systemPrompt,
|
||||
}
|
||||
},
|
||||
onFinish: async () => {
|
||||
await close()
|
||||
},
|
||||
tools,
|
||||
})
|
||||
const { cleanedText, attachments: textAttachments } = extractAttachmentsFromText(text)
|
||||
@@ -111,6 +125,7 @@ export const createAgent = ({
|
||||
})
|
||||
}
|
||||
const messages = [...params.messages, userPrompt]
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { response, reasoning, text, usage } = await generateText({
|
||||
model,
|
||||
messages,
|
||||
@@ -121,6 +136,9 @@ export const createAgent = ({
|
||||
system: generateSubagentSystemPrompt(),
|
||||
}
|
||||
},
|
||||
onFinish: async () => {
|
||||
await close()
|
||||
},
|
||||
tools,
|
||||
})
|
||||
return {
|
||||
@@ -142,11 +160,16 @@ export const createAgent = ({
|
||||
]
|
||||
}
|
||||
const messages = [...params.messages, scheduleMessage]
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { response, reasoning, text, usage } = await generateText({
|
||||
model,
|
||||
messages,
|
||||
system: generateSystemPrompt(),
|
||||
stopWhen: stepCountIs(Infinity),
|
||||
onFinish: async () => {
|
||||
await close()
|
||||
},
|
||||
tools,
|
||||
})
|
||||
return {
|
||||
messages: [scheduleMessage, ...response.messages],
|
||||
@@ -170,6 +193,7 @@ export const createAgent = ({
|
||||
reasoning: [],
|
||||
usage: null
|
||||
}
|
||||
const { tools, close } = await getAgentTools()
|
||||
const { fullStream } = streamText({
|
||||
model,
|
||||
messages,
|
||||
@@ -181,7 +205,8 @@ export const createAgent = ({
|
||||
}
|
||||
},
|
||||
tools,
|
||||
onFinish: ({ usage, reasoning, response }) => {
|
||||
onFinish: async ({ usage, reasoning, response }) => {
|
||||
await close()
|
||||
result.usage = usage as never
|
||||
result.reasoning = reasoning.map(part => part.text)
|
||||
result.messages = response.messages
|
||||
|
||||
+26
-1
@@ -54,4 +54,29 @@ export const FileAttachmentModel = z.object({
|
||||
metadata: z.record(z.string(), z.any()).optional(),
|
||||
})
|
||||
|
||||
export const AttachmentModel = z.union([ImageAttachmentModel, FileAttachmentModel])
|
||||
export const AttachmentModel = z.union([ImageAttachmentModel, FileAttachmentModel])
|
||||
|
||||
export const HTTPMCPConnectionModel = z.object({
|
||||
name: z.string().min(1, 'Name is required'),
|
||||
type: z.literal('http'),
|
||||
url: z.string().min(1, 'URL is required'),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
})
|
||||
|
||||
export const SSEMCPConnectionModel = z.object({
|
||||
name: z.string().min(1, 'Name is required'),
|
||||
type: z.literal('sse'),
|
||||
url: z.string().min(1, 'URL is required'),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
})
|
||||
|
||||
export const StdioMCPConnectionModel = z.object({
|
||||
name: z.string().min(1, 'Name is required'),
|
||||
type: z.literal('stdio'),
|
||||
command: z.string().min(1, 'Command is required'),
|
||||
args: z.array(z.string()),
|
||||
env: z.record(z.string(), z.string()).optional(),
|
||||
cwd: z.string().optional(),
|
||||
})
|
||||
|
||||
export const MCPConnectionModel = z.union([HTTPMCPConnectionModel, SSEMCPConnectionModel, StdioMCPConnectionModel])
|
||||
+20
-18
@@ -4,7 +4,7 @@ import { createAgent } from '../agent'
|
||||
import { createAuthFetcher } from '../index'
|
||||
import { ModelConfig } from '../types'
|
||||
import { bearerMiddleware } from '../middlewares/bearer'
|
||||
import { AllowedActionModel, AttachmentModel, IdentityContextModel, ModelConfigModel } from '../models'
|
||||
import { AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel } from '../models'
|
||||
import { allActions } from '../types'
|
||||
|
||||
const AgentModel = z.object({
|
||||
@@ -18,6 +18,7 @@ const AgentModel = z.object({
|
||||
query: z.string(),
|
||||
identity: IdentityContextModel,
|
||||
attachments: z.array(AttachmentModel).optional().default([]),
|
||||
mcpConnections: z.array(MCPConnectionModel).optional().default([]),
|
||||
})
|
||||
|
||||
export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
@@ -31,6 +32,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
currentChannel: body.currentChannel,
|
||||
allowedActions: body.allowedActions,
|
||||
identity: body.identity,
|
||||
mcpConnections: body.mcpConnections,
|
||||
}, authFetcher)
|
||||
return ask({
|
||||
query: body.query,
|
||||
@@ -44,22 +46,23 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
.post('/stream', async function* ({ body, bearer }) {
|
||||
try {
|
||||
const authFetcher = createAuthFetcher(bearer)
|
||||
const { stream } = createAgent({
|
||||
model: body.model as ModelConfig,
|
||||
activeContextTime: body.activeContextTime,
|
||||
channels: body.channels,
|
||||
currentChannel: body.currentChannel,
|
||||
allowedActions: body.allowedActions,
|
||||
identity: body.identity,
|
||||
}, authFetcher)
|
||||
for await (const action of stream({
|
||||
query: body.query,
|
||||
messages: body.messages,
|
||||
skills: body.skills,
|
||||
attachments: body.attachments,
|
||||
})) {
|
||||
yield sse(JSON.stringify(action))
|
||||
}
|
||||
const { stream } = createAgent({
|
||||
model: body.model as ModelConfig,
|
||||
activeContextTime: body.activeContextTime,
|
||||
channels: body.channels,
|
||||
currentChannel: body.currentChannel,
|
||||
allowedActions: body.allowedActions,
|
||||
identity: body.identity,
|
||||
mcpConnections: body.mcpConnections,
|
||||
}, authFetcher)
|
||||
for await (const action of stream({
|
||||
query: body.query,
|
||||
messages: body.messages,
|
||||
skills: body.skills,
|
||||
attachments: body.attachments,
|
||||
})) {
|
||||
yield sse(JSON.stringify(action))
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
yield sse(JSON.stringify({
|
||||
@@ -70,4 +73,3 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
}, {
|
||||
body: AgentModel,
|
||||
})
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
import { HTTPMCPConnection, MCPConnection, SSEMCPConnection, StdioMCPConnection } from '../types'
|
||||
import { createMCPClient } from '@ai-sdk/mcp'
|
||||
|
||||
export const getMCPTools = async (connections: MCPConnection[]) => {
|
||||
const closeCallbacks: Array<() => Promise<void>> = []
|
||||
|
||||
const getHTTPTools = async (connection: HTTPMCPConnection) => {
|
||||
const client = await createMCPClient({
|
||||
transport: {
|
||||
type: 'http',
|
||||
url: connection.url,
|
||||
headers: connection.headers,
|
||||
}
|
||||
})
|
||||
closeCallbacks.push(client.close)
|
||||
return await client.tools()
|
||||
}
|
||||
|
||||
const getSSETools = async (connection: SSEMCPConnection) => {
|
||||
const client = await createMCPClient({
|
||||
transport: {
|
||||
type: 'sse',
|
||||
url: connection.url,
|
||||
headers: connection.headers,
|
||||
}
|
||||
})
|
||||
closeCallbacks.push(client.close)
|
||||
return await client.tools()
|
||||
}
|
||||
|
||||
const getStdioTools = async (connection: StdioMCPConnection) => {
|
||||
// TODO: Implement stdio tools
|
||||
return []
|
||||
}
|
||||
|
||||
return {
|
||||
tools: await Promise.all(connections.map(connection => {
|
||||
switch (connection.type) {
|
||||
case 'http':
|
||||
return getHTTPTools(connection)
|
||||
case 'sse':
|
||||
return getSSETools(connection)
|
||||
case 'stdio':
|
||||
return getStdioTools(connection)
|
||||
}
|
||||
})),
|
||||
close: async () => {
|
||||
await Promise.all(closeCallbacks.map(callback => callback()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import { ModelMessage } from 'ai'
|
||||
import { ModelConfig } from './model'
|
||||
import { AgentAttachment } from './attachment'
|
||||
import { MCPConnection } from './mcp'
|
||||
|
||||
export interface IdentityContext {
|
||||
botId: string
|
||||
@@ -43,6 +44,7 @@ export interface AgentParams {
|
||||
identity: IdentityContext
|
||||
channels?: string[]
|
||||
currentChannel?: string
|
||||
mcpConnections?: MCPConnection[]
|
||||
}
|
||||
|
||||
export interface AgentInput {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
export * from './agent'
|
||||
export * from './model'
|
||||
export * from './schedule'
|
||||
export * from './attachment'
|
||||
export * from './attachment'
|
||||
export * from './mcp'
|
||||
@@ -0,0 +1,29 @@
|
||||
export interface BaseMCPConnection {
|
||||
type: string
|
||||
name: string
|
||||
}
|
||||
|
||||
export interface StdioMCPConnection extends BaseMCPConnection {
|
||||
type: 'stdio'
|
||||
command: string
|
||||
args: string[]
|
||||
env?: Record<string, string>
|
||||
cwd?: string
|
||||
}
|
||||
|
||||
export interface HTTPMCPConnection extends BaseMCPConnection {
|
||||
type: 'http'
|
||||
url: string
|
||||
headers?: Record<string, string>
|
||||
}
|
||||
|
||||
export interface SSEMCPConnection extends BaseMCPConnection {
|
||||
type: 'sse'
|
||||
url: string
|
||||
headers?: Record<string, string>
|
||||
}
|
||||
|
||||
export type MCPConnection =
|
||||
| StdioMCPConnection
|
||||
| HTTPMCPConnection
|
||||
| SSEMCPConnection
|
||||
Reference in New Issue
Block a user