diff --git a/agent/src/agent.ts b/agent/src/agent.ts index e66d1c13..992511ed 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -1,10 +1,9 @@ import { generateText, ImagePart, LanguageModelUsage, ModelMessage, stepCountIs, streamText, UserModelMessage } from 'ai' -import { AgentInput, AgentParams, AgentSkill, allActions, HTTPMCPConnection, MCPConnection, Schedule, StdioMCPConnection } from './types' +import { AgentInput, AgentParams, AgentSkill, allActions, Schedule } from './types' import { system, schedule, user, subagentSystem } from './prompts' import { AuthFetcher } from './index' import { createModel } from './model' import { AgentAction } from './types/action' -import { getTools } from './tools' import { extractAttachmentsFromText, stripAttachmentsFromMessages, @@ -21,7 +20,6 @@ export const createAgent = ({ language = 'Same as the user input', allowedActions = allActions, channels = [], - mcpConnections = [], skills = [], currentChannel = 'Unknown Channel', identity = { @@ -47,18 +45,6 @@ export const createAgent = ({ return enabledSkills.map(skill => skill.name) } - const getDefaultMCPConnections = (): MCPConnection[] => { - const fs: HTTPMCPConnection = { - type: 'http', - name: 'fs', - url: `${auth.baseUrl}/bots/${identity.botId}/container/fs-mcp`, - headers: { - 'Authorization': `Bearer ${auth.bearer}`, - }, - } - return [fs] - } - const loadSystemFiles = async () => { if (!auth?.bearer || !identity.botId) { return { @@ -103,26 +89,38 @@ export const createAgent = ({ } const getAgentTools = async () => { - const tools = getTools(allowedActions, { - fetch, - model: modelConfig, - brave, - identity, - auth, - enableSkill, - }) - const defaultMCPConnections = getDefaultMCPConnections() - const { tools: mcpTools, close: closeMCP } = await getMCPTools([ - ...defaultMCPConnections, - ...mcpConnections, - ], { - botId: identity.botId, - auth, - fetch, - }) - Object.assign(tools, mcpTools) + const baseUrl = auth.baseUrl.replace(/\/$/, '') + const botId = identity.botId.trim() + if (!baseUrl || !botId) { + return { + tools: {}, + close: async () => {}, + } + } + const headers: Record = { + 'Authorization': `Bearer ${auth.bearer}`, + } + if (identity.sessionId) { + headers['X-Memoh-Chat-Id'] = identity.sessionId + } + if (identity.channelIdentityId) { + headers['X-Memoh-Channel-Identity-Id'] = identity.channelIdentityId + } + if (identity.sessionToken) { + headers['X-Memoh-Session-Token'] = identity.sessionToken + } + if (identity.currentPlatform) { + headers['X-Memoh-Current-Platform'] = identity.currentPlatform + } + if (identity.replyTarget) { + headers['X-Memoh-Reply-Target'] = identity.replyTarget + } + if (identity.displayName) { + headers['X-Memoh-Display-Name'] = identity.displayName + } + const { tools: mcpTools, close: closeMCP } = await getMCPTools(`${baseUrl}/bots/${botId}/tools`, headers) return { - tools, + tools: mcpTools, close: closeMCP, } } diff --git a/agent/src/modules/chat.ts b/agent/src/modules/chat.ts index a67fac09..db95a559 100644 --- a/agent/src/modules/chat.ts +++ b/agent/src/modules/chat.ts @@ -4,7 +4,7 @@ import { createAgent } from '../agent' import { createAuthFetcher, getBaseUrl, getBraveConfig } from '../index' import { ModelConfig } from '../types' import { bearerMiddleware } from '../middlewares/bearer' -import { AgentSkillModel, AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models' +import { AgentSkillModel, AllowedActionModel, AttachmentModel, IdentityContextModel, ModelConfigModel, ScheduleModel } from '../models' import { allActions } from '../types' const AgentModel = z.object({ @@ -18,7 +18,6 @@ const AgentModel = z.object({ skills: z.array(z.string()), identity: IdentityContextModel, attachments: z.array(AttachmentModel).optional().default([]), - mcpConnections: z.array(MCPConnectionModel).optional().default([]), }) export const chatModule = new Elysia({ prefix: '/chat' }) @@ -33,7 +32,6 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentChannel: body.currentChannel, allowedActions: body.allowedActions, identity: body.identity, - mcpConnections: body.mcpConnections, auth: { bearer: bearer!, baseUrl: getBaseUrl(), @@ -63,7 +61,6 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentChannel: body.currentChannel, allowedActions: body.allowedActions, identity: body.identity, - mcpConnections: body.mcpConnections, auth: { bearer: bearer!, baseUrl: getBaseUrl(), @@ -99,7 +96,6 @@ export const chatModule = new Elysia({ prefix: '/chat' }) channels: body.channels, currentChannel: body.currentChannel, identity: body.identity, - mcpConnections: body.mcpConnections, auth: { bearer: bearer!, baseUrl: getBaseUrl(), diff --git a/agent/src/test/unified_mcp_tools.test.ts b/agent/src/test/unified_mcp_tools.test.ts new file mode 100644 index 00000000..6da02ada --- /dev/null +++ b/agent/src/test/unified_mcp_tools.test.ts @@ -0,0 +1,90 @@ +import { describe, expect, test } from 'bun:test' +import { getMCPTools } from '../tools/mcp' + +describe('getMCPTools (unified endpoint)', () => { + test('loads tools from unified MCP HTTP endpoint', async () => { + const seenMethods: string[] = [] + const seenAuthHeaders: string[] = [] + + const server = Bun.serve({ + port: 0, + async fetch(request) { + seenAuthHeaders.push(request.headers.get('authorization') ?? '') + const body = await request.json().catch(() => ({} as Record)) + const method = typeof body?.method === 'string' ? body.method : '' + seenMethods.push(method) + + if (method === 'initialize') { + return Response.json({ + jsonrpc: '2.0', + id: body.id ?? null, + result: { + protocolVersion: '2025-06-18', + capabilities: { + tools: { + listChanged: false, + }, + }, + serverInfo: { + name: 'test-mcp', + version: '1.0.0', + }, + }, + }) + } + + if (method === 'notifications/initialized') { + return new Response(null, { status: 202 }) + } + + if (method === 'tools/list') { + return Response.json({ + jsonrpc: '2.0', + id: body.id ?? null, + result: { + tools: [ + { + name: 'search_memory', + description: 'Search memory', + inputSchema: { + type: 'object', + properties: { + query: { type: 'string' }, + }, + required: ['query'], + }, + }, + ], + }, + }) + } + + return Response.json({ + jsonrpc: '2.0', + id: body.id ?? null, + error: { + code: -32601, + message: 'method not found', + }, + }) + }, + }) + + try { + const endpoint = `http://127.0.0.1:${server.port}/bots/bot-1/tools` + const { tools, close } = await getMCPTools(endpoint, { + Authorization: 'Bearer test-token', + 'X-Memoh-Chat-Id': 'chat-1', + }) + + expect(Object.keys(tools)).toContain('search_memory') + expect(seenMethods).toContain('initialize') + expect(seenMethods).toContain('tools/list') + expect(seenAuthHeaders.some(value => value === 'Bearer test-token')).toBe(true) + + await close() + } finally { + server.stop(true) + } + }) +}) diff --git a/agent/src/tools/mcp.ts b/agent/src/tools/mcp.ts index d10a5387..a6a88447 100644 --- a/agent/src/tools/mcp.ts +++ b/agent/src/tools/mcp.ts @@ -1,110 +1,17 @@ -import { HTTPMCPConnection, MCPConnection, SSEMCPConnection, StdioMCPConnection } from '../types' import { createMCPClient } from '@ai-sdk/mcp' -import { AuthFetcher } from '../index' -import type { AgentAuthContext } from '../types/agent' - -type MCPToolOptions = { - botId?: string - auth?: AgentAuthContext - fetch?: AuthFetcher -} - -export const getMCPTools = async (connections: MCPConnection[], options: MCPToolOptions = {}) => { - const closeCallbacks: Array<() => Promise> = [] - - const getHTTPTools = async (connection: HTTPMCPConnection) => { - const client = await createMCPClient({ - transport: { - type: 'http', - url: connection.url, - headers: connection.headers, - } - }) - closeCallbacks.push(() => client.close()) - const tools = await client.tools() - return tools - } - - const getSSETools = async (connection: SSEMCPConnection) => { - const client = await createMCPClient({ - transport: { - type: 'sse', - url: connection.url, - headers: connection.headers, - } - }) - closeCallbacks.push(() => client.close()) - const tools = await client.tools() - return tools - } - - const getStdioTools = async (connection: StdioMCPConnection) => { - if (!options.fetch || !options.botId || !options.auth) { - throw new Error('stdio mcp requires auth fetcher and bot id') - } - const response = await options.fetch(`/bots/${options.botId}/mcp-stdio`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - name: connection.name, - command: connection.command, - args: connection.args ?? [], - env: connection.env ?? {}, - cwd: connection.cwd ?? '' - }) - }) - if (!response.ok) { - const text = await response.text().catch(() => '') - throw new Error(`mcp-stdio failed: ${response.status} ${text}`) - } - const data = await response.json().catch(() => ({} as { url?: string })) - const rawUrl = typeof data?.url === 'string' ? data.url : '' - if (!rawUrl) { - throw new Error('mcp-stdio response missing url') - } - const baseUrl = options.auth.baseUrl ?? '' - const url = rawUrl.startsWith('http') - ? rawUrl - : `${baseUrl.replace(/\/$/, '')}/${rawUrl.replace(/^\//, '')}` - return await getHTTPTools({ +export const getMCPTools = async (url: string, headers: Record = {}) => { + const client = await createMCPClient({ + transport: { type: 'http', - name: connection.name, url, - headers: { - 'Authorization': `Bearer ${options.auth.bearer}` - } - }) - } - - const toolSets = await Promise.all(connections.map(async (connection) => { - try { - switch (connection.type) { - case 'http': - return await getHTTPTools(connection) - case 'sse': - return await getSSETools(connection) - case 'stdio': - return await getStdioTools(connection) - default: - console.warn('unknown mcp connection type', connection) - return {} - } - } catch (error) { - console.warn('skip mcp connection due to initialization error', { - name: connection.name, - type: connection.type, - error: error instanceof Error ? error.message : String(error), - }) - return {} + headers, } - })) - + }) + const tools = await client.tools() return { - tools: Object.assign({}, ...toolSets), + tools, close: async () => { - await Promise.all(closeCallbacks.map(callback => callback())) + await client.close() } } } diff --git a/agent/src/types/agent.ts b/agent/src/types/agent.ts index eac7b312..0dc9a6c3 100644 --- a/agent/src/types/agent.ts +++ b/agent/src/types/agent.ts @@ -1,7 +1,6 @@ import { ModelMessage } from 'ai' import { ModelConfig } from './model' import { AgentAttachment } from './attachment' -import { MCPConnection } from './mcp' export interface IdentityContext { botId: string @@ -52,7 +51,6 @@ export interface AgentParams { brave?: BraveConfig channels?: string[] currentChannel?: string - mcpConnections?: MCPConnection[] identity?: IdentityContext auth: AgentAuthContext skills?: AgentSkill[] diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 6b2c4050..1a8f9db2 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -25,6 +25,10 @@ import ( "github.com/memohai/memoh/internal/handlers" "github.com/memohai/memoh/internal/logger" "github.com/memohai/memoh/internal/mcp" + mcpmemory "github.com/memohai/memoh/internal/mcp/providers/memory" + mcpmessage "github.com/memohai/memoh/internal/mcp/providers/message" + mcpschedule "github.com/memohai/memoh/internal/mcp/providers/schedule" + mcpfederation "github.com/memohai/memoh/internal/mcp/sources/federation" "github.com/memohai/memoh/internal/memory" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/policy" @@ -187,6 +191,23 @@ func main() { scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, accountService) subagentService := subagent.NewService(logger.L, queries) subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, accountService) + messageToolExecutor := mcpmessage.NewExecutor(logger.L, channelManager, channelRegistry) + scheduleToolExecutor := mcpschedule.NewExecutor(logger.L, scheduleService) + memoryToolExecutor := mcpmemory.NewExecutor(logger.L, memoryService, chatService, accountService) + federationGateway := handlers.NewMCPFederationGateway(logger.L, containerdHandler) + federatedToolSource := mcpfederation.NewSource(logger.L, federationGateway, mcpConnectionsService) + toolGatewayService := mcp.NewToolGatewayService( + logger.L, + []mcp.ToolExecutor{ + messageToolExecutor, + scheduleToolExecutor, + memoryToolExecutor, + }, + []mcp.ToolSource{ + federatedToolSource, + }, + ) + containerdHandler.SetToolGatewayService(toolGatewayService) srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, preauthHandler, bindHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler) if err := srv.Start(); err != nil { diff --git a/docs/docs.go b/docs/docs.go index b8bd4e5f..2a98fd12 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -2152,6 +2152,59 @@ const docTemplate = `{ } } }, + "/bots/{bot_id}/tools": { + "post": { + "description": "MCP endpoint for tool discovery and invocation.", + "tags": [ + "containerd" + ], + "summary": "Unified MCP tools gateway", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "JSON-RPC request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "JSON-RPC response: {jsonrpc,id,result|error}", + "schema": { + "type": "object" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/bots/{id}": { "get": { "description": "Get a bot by ID (owner/admin only)", diff --git a/docs/swagger.json b/docs/swagger.json index 465dd098..9240174e 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -2143,6 +2143,59 @@ } } }, + "/bots/{bot_id}/tools": { + "post": { + "description": "MCP endpoint for tool discovery and invocation.", + "tags": [ + "containerd" + ], + "summary": "Unified MCP tools gateway", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "JSON-RPC request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "JSON-RPC response: {jsonrpc,id,result|error}", + "schema": { + "type": "object" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/bots/{id}": { "get": { "description": "Get a bot by ID (owner/admin only)", diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 1e52852c..26eea9bc 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -2665,6 +2665,41 @@ paths: summary: Update subagent skills tags: - subagent + /bots/{bot_id}/tools: + post: + description: MCP endpoint for tool discovery and invocation. + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + - description: JSON-RPC request + in: body + name: payload + required: true + schema: + type: object + responses: + "200": + description: 'JSON-RPC response: {jsonrpc,id,result|error}' + schema: + type: object + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Unified MCP tools gateway + tags: + - containerd /bots/{id}: delete: description: Delete a bot user (owner/admin only) diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 12452e23..5929d599 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -132,7 +132,6 @@ type gatewayRequest struct { Channels []string `json:"channels"` CurrentChannel string `json:"currentChannel"` AllowedActions []string `json:"allowedActions,omitempty"` - MCPConnections []map[string]any `json:"mcpConnections"` Messages []ModelMessage `json:"messages"` Skills []string `json:"skills"` UsableSkills []gatewaySkill `json:"usableSkills"` @@ -163,7 +162,6 @@ type triggerScheduleRequest struct { Channels []string `json:"channels"` CurrentChannel string `json:"currentChannel"` AllowedActions []string `json:"allowedActions,omitempty"` - MCPConnections []map[string]any `json:"mcpConnections"` Messages []ModelMessage `json:"messages"` Skills []string `json:"skills"` UsableSkills []gatewaySkill `json:"usableSkills"` @@ -258,24 +256,6 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex usableSkills = []gatewaySkill{} } - mcpConnections := []map[string]any{} - if r.mcpService != nil { - items, err := r.mcpService.ListActiveByBot(ctx, req.BotID) - if err != nil { - r.logger.Warn("failed to load mcp connections", slog.String("bot_id", req.BotID), slog.Any("error", err)) - } else { - for _, item := range items { - payload := map[string]any{} - for k, v := range item.Config { - payload[k] = v - } - payload["name"] = item.Name - payload["type"] = item.Type - mcpConnections = append(mcpConnections, payload) - } - } - } - payload := gatewayRequest{ Model: gatewayModelConfig{ ModelID: chatModel.ModelID, @@ -288,7 +268,6 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex Channels: nonNilStrings(req.Channels), CurrentChannel: req.CurrentChannel, AllowedActions: req.AllowedActions, - MCPConnections: mcpConnections, Messages: nonNilModelMessages(messages), Skills: nonNilStrings(skills), UsableSkills: usableSkills, @@ -365,7 +344,6 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc Channels: rc.payload.Channels, CurrentChannel: rc.payload.CurrentChannel, AllowedActions: rc.payload.AllowedActions, - MCPConnections: rc.payload.MCPConnections, Messages: rc.payload.Messages, Skills: rc.payload.Skills, UsableSkills: rc.payload.UsableSkills, diff --git a/internal/chat/resolver_memory_context_test.go b/internal/chat/resolver_memory_context_test.go new file mode 100644 index 00000000..082bad6d --- /dev/null +++ b/internal/chat/resolver_memory_context_test.go @@ -0,0 +1,61 @@ +package chat + +import ( + "context" + "log/slog" + "strings" + "testing" + + "github.com/memohai/memoh/internal/memory" +) + +func TestLoadMemoryContextMessage_NoMemoryService(t *testing.T) { + resolver := &Resolver{ + memoryService: nil, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + }, Settings{ + EnableChatMemory: true, + }) + if msg != nil { + t.Fatalf("expected nil message when memory service is nil") + } +} + +func TestLoadMemoryContextMessage_SearchFailureFallback(t *testing.T) { + resolver := &Resolver{ + memoryService: &memory.Service{}, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + UserID: "user-1", + }, Settings{ + EnableChatMemory: true, + EnablePrivateMemory: true, + EnablePublicMemory: true, + }) + if msg != nil { + t.Fatalf("expected nil message when memory search cannot return results") + } +} + +func TestTruncateMemorySnippet(t *testing.T) { + longText := strings.Repeat("a", 20) + " " + got := truncateMemorySnippet(longText, 10) + if got != strings.Repeat("a", 10)+"..." { + t.Fatalf("unexpected truncated value: %q", got) + } + + shortText := " short " + got = truncateMemorySnippet(shortText, 10) + if got != "short" { + t.Fatalf("unexpected trimmed short value: %q", got) + } +} diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index 6abdbdcb..e65bfef2 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -37,6 +37,7 @@ type ContainerdHandler struct { cfg config.MCPConfig namespace string logger *slog.Logger + toolGateway *mcp.ToolGatewayService mcpMu sync.Mutex mcpSess map[string]*mcpSession mcpStdioMu sync.Mutex @@ -135,6 +136,7 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { fs.DELETE("", h.DeleteFS) root.POST("/mcp-stdio", h.CreateMCPStdio) root.POST("/mcp-stdio/:session_id", h.HandleMCPStdio) + root.POST("/tools", h.HandleMCPTools) } // CreateContainer godoc diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go index d10d5b35..815dbc08 100644 --- a/internal/handlers/fs.go +++ b/internal/handlers/fs.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -17,6 +18,8 @@ import ( "github.com/containerd/containerd/v2/pkg/namespaces" "github.com/containerd/errdefs" "github.com/labstack/echo/v4" + sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" ctr "github.com/memohai/memoh/internal/containerd" mcptools "github.com/memohai/memoh/internal/mcp" @@ -141,16 +144,27 @@ type mcpSession struct { stdout io.ReadCloser stderr io.ReadCloser cmd *exec.Cmd - initOnce sync.Once - writeMu sync.Mutex + initMu sync.Mutex + initState mcpSessionInitState + initWait chan struct{} pendingMu sync.Mutex - pending map[string]chan mcptools.JSONRPCResponse + pending map[string]chan *sdkjsonrpc.Response + conn sdkmcp.Connection closed chan struct{} closeOnce sync.Once closeErr error onClose func() } +type mcpSessionInitState uint8 + +const ( + mcpSessionInitStateNone mcpSessionInitState = iota + mcpSessionInitStateInitializing + mcpSessionInitStateInitialized + mcpSessionInitStateReady +) + func (h *ContainerdHandler) getMCPSession(ctx context.Context, containerID string) (*mcpSession, error) { h.mcpMu.Lock() if sess, ok := h.mcpSess[containerID]; ok { @@ -198,9 +212,19 @@ func (h *ContainerdHandler) startContainerdMCPSession(ctx context.Context, conta stdin: execSession.Stdin, stdout: execSession.Stdout, stderr: execSession.Stderr, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(ctx) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(execSession.Stderr, containerID) go sess.readLoop() @@ -265,9 +289,19 @@ func (h *ContainerdHandler) startLimaMCPSession(containerID string) (*mcpSession stdout: stdout, stderr: stderr, cmd: cmd, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(context.Background()) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(stderr, containerID) go sess.readLoop() @@ -291,11 +325,20 @@ func (s *mcpSession) closeWithError(err error) { for _, ch := range s.pending { close(ch) } - s.pending = map[string]chan mcptools.JSONRPCResponse{} + s.pending = map[string]chan *sdkjsonrpc.Response{} s.pendingMu.Unlock() - _ = s.stdin.Close() - _ = s.stdout.Close() - _ = s.stderr.Close() + if s.conn != nil { + _ = s.conn.Close() + } + if s.stdin != nil { + _ = s.stdin.Close() + } + if s.stdout != nil { + _ = s.stdout.Close() + } + if s.stderr != nil { + _ = s.stderr.Close() + } if s.cmd != nil && s.cmd.Process != nil { _ = s.cmd.Process.Kill() } @@ -326,18 +369,25 @@ func (h *ContainerdHandler) startMCPStderrLogger(stderr io.ReadCloser, container } func (s *mcpSession) readLoop() { - scanner := bufio.NewScanner(s.stdout) - scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { + if s.conn == nil { + s.closeWithError(io.EOF) + return + } + for { + msg, err := s.conn.Read(context.Background()) + if err != nil { + if errors.Is(err, io.EOF) { + s.closeWithError(io.EOF) + return + } + s.closeWithError(err) + return + } + resp, ok := msg.(*sdkjsonrpc.Response) + if !ok || !resp.ID.IsValid() { continue } - var resp mcptools.JSONRPCResponse - if err := json.Unmarshal([]byte(line), &resp); err != nil { - continue - } - id := strings.TrimSpace(string(resp.ID)) + id := sdkIDKey(resp.ID) if id == "" { continue } @@ -352,29 +402,43 @@ func (s *mcpSession) readLoop() { close(ch) } } - if err := scanner.Err(); err != nil { - s.closeWithError(err) - } else { - s.closeWithError(io.EOF) - } } func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { - payloads, targetID, err := mcptools.BuildPayloads(req, &s.initOnce) + method := strings.TrimSpace(req.Method) + if method == "initialize" { + return s.callInitialize(ctx, req) + } + if method != "notifications/initialized" { + if err := s.ensureInitialized(ctx); err != nil { + return nil, err + } + } + + targetID, err := parseRawJSONRPCID(req.ID) if err != nil { return nil, err } - target := strings.TrimSpace(string(targetID)) + target := sdkIDKey(targetID) if target == "" { return nil, fmt.Errorf("missing request id") } + if s.conn == nil { + return nil, io.EOF + } - respCh := make(chan mcptools.JSONRPCResponse, 1) + respCh := make(chan *sdkjsonrpc.Response, 1) s.pendingMu.Lock() s.pending[target] = respCh s.pendingMu.Unlock() - if err := s.writePayloads(payloads); err != nil { + callReq := &sdkjsonrpc.Request{ + ID: targetID, + Method: method, + Params: req.Params, + } + if err := s.conn.Write(ctx, callReq); err != nil { + s.removePending(target) return nil, err } @@ -386,46 +450,347 @@ func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map } return nil, io.EOF } - if resp.Error != nil { - return map[string]any{ - "jsonrpc": "2.0", - "id": resp.ID, - "error": map[string]any{ - "code": resp.Error.Code, - "message": resp.Error.Message, - }, - }, nil + if method == "notifications/initialized" { + s.setInitStateAtLeast(mcpSessionInitStateReady) } - return map[string]any{ - "jsonrpc": "2.0", - "id": resp.ID, - "result": resp.Result, - }, nil + return sdkResponsePayload(resp) case <-s.closed: if s.closeErr != nil { return nil, s.closeErr } return nil, io.EOF case <-ctx.Done(): + s.removePending(target) + return nil, ctx.Err() + } +} + +func (s *mcpSession) callInitialize(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { + payload, err := s.callRaw(ctx, req) + if err != nil { + return nil, err + } + if err := mcptools.PayloadError(payload); err != nil { + return payload, nil + } + s.setInitStateAtLeast(mcpSessionInitStateInitialized) + return payload, nil +} + +func (s *mcpSession) callRaw(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { + method := strings.TrimSpace(req.Method) + targetID, err := parseRawJSONRPCID(req.ID) + if err != nil { + return nil, err + } + target := sdkIDKey(targetID) + if target == "" { + return nil, fmt.Errorf("missing request id") + } + if s.conn == nil { + return nil, io.EOF + } + + respCh := make(chan *sdkjsonrpc.Response, 1) + s.pendingMu.Lock() + s.pending[target] = respCh + s.pendingMu.Unlock() + + callReq := &sdkjsonrpc.Request{ + ID: targetID, + Method: method, + Params: req.Params, + } + if err := s.conn.Write(ctx, callReq); err != nil { + s.removePending(target) + return nil, err + } + + select { + case resp, ok := <-respCh: + if !ok { + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + } + return sdkResponsePayload(resp) + case <-s.closed: + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + case <-ctx.Done(): + s.removePending(target) return nil, ctx.Err() } } func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error { - payloads, err := mcptools.BuildNotificationPayloads(req) - if err != nil { + if s.conn == nil { + return io.EOF + } + method := strings.TrimSpace(req.Method) + notification := &sdkjsonrpc.Request{ + Method: method, + Params: req.Params, + } + if err := s.conn.Write(ctx, notification); err != nil { return err } - return s.writePayloads(payloads) -} - -func (s *mcpSession) writePayloads(payloads []string) error { - s.writeMu.Lock() - defer s.writeMu.Unlock() - for _, payload := range payloads { - if _, err := s.stdin.Write([]byte(payload + "\n")); err != nil { - return err - } + if method == "notifications/initialized" { + s.setInitStateAtLeast(mcpSessionInitStateReady) } return nil } + +func (s *mcpSession) ensureInitialized(ctx context.Context) error { + for { + s.initMu.Lock() + switch s.initState { + case mcpSessionInitStateReady: + s.initMu.Unlock() + return nil + case mcpSessionInitStateInitializing: + waitCh := s.initWait + s.initMu.Unlock() + if waitCh == nil { + continue + } + select { + case <-waitCh: + continue + case <-ctx.Done(): + return ctx.Err() + case <-s.closed: + if s.closeErr != nil { + return s.closeErr + } + return io.EOF + } + case mcpSessionInitStateInitialized: + waitCh := make(chan struct{}) + s.initState = mcpSessionInitStateInitializing + s.initWait = waitCh + s.initMu.Unlock() + + err := s.sendInitializedNotification(ctx) + + s.initMu.Lock() + if err == nil { + s.initState = mcpSessionInitStateReady + } else { + s.initState = mcpSessionInitStateInitialized + } + s.initWait = nil + close(waitCh) + s.initMu.Unlock() + + if err != nil { + return err + } + return nil + default: + waitCh := make(chan struct{}) + s.initState = mcpSessionInitStateInitializing + s.initWait = waitCh + s.initMu.Unlock() + + nextState, err := s.initializeHandshake(ctx) + + s.initMu.Lock() + s.initState = nextState + s.initWait = nil + close(waitCh) + s.initMu.Unlock() + + if err != nil { + return err + } + if nextState == mcpSessionInitStateReady { + return nil + } + } + } +} + +func (s *mcpSession) initializeHandshake(ctx context.Context) (mcpSessionInitState, error) { + params, err := json.Marshal(map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{ + "roots": map[string]any{ + "listChanged": false, + }, + }, + "clientInfo": map[string]any{ + "name": "memoh-http-proxy", + "version": "v0", + }, + }) + if err != nil { + return mcpSessionInitStateNone, err + } + initID, err := sdkjsonrpc.MakeID("init-1") + if err != nil { + return mcpSessionInitStateNone, err + } + initResp, err := s.invokeCall(ctx, &sdkjsonrpc.Request{ + ID: initID, + Method: "initialize", + Params: params, + }) + if err != nil { + return mcpSessionInitStateNone, err + } + if initResp.Error != nil { + return mcpSessionInitStateNone, initResp.Error + } + if err := s.sendInitializedNotification(ctx); err != nil { + return mcpSessionInitStateInitialized, err + } + return mcpSessionInitStateReady, nil +} + +func (s *mcpSession) sendInitializedNotification(ctx context.Context) error { + if s.conn == nil { + return io.EOF + } + return s.conn.Write(ctx, &sdkjsonrpc.Request{ + Method: "notifications/initialized", + }) +} + +func (s *mcpSession) invokeCall(ctx context.Context, req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + if s.conn == nil { + return nil, io.EOF + } + if req == nil || !req.ID.IsValid() { + return nil, fmt.Errorf("missing request id") + } + key := sdkIDKey(req.ID) + if key == "" { + return nil, fmt.Errorf("invalid request id") + } + + respCh := make(chan *sdkjsonrpc.Response, 1) + s.pendingMu.Lock() + s.pending[key] = respCh + s.pendingMu.Unlock() + + if err := s.conn.Write(ctx, req); err != nil { + s.removePending(key) + return nil, err + } + + select { + case resp, ok := <-respCh: + if !ok { + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + } + return resp, nil + case <-s.closed: + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + case <-ctx.Done(): + s.removePending(key) + return nil, ctx.Err() + } +} + +func (s *mcpSession) removePending(key string) { + if strings.TrimSpace(key) == "" { + return + } + s.pendingMu.Lock() + delete(s.pending, key) + s.pendingMu.Unlock() +} + +func (s *mcpSession) setInitStateAtLeast(next mcpSessionInitState) { + s.initMu.Lock() + if s.initState != mcpSessionInitStateInitializing && s.initState < next { + s.initState = next + } + s.initMu.Unlock() +} + +func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { + if len(raw) == 0 { + return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + } + var idValue any + if err := json.Unmarshal(raw, &idValue); err != nil { + return sdkjsonrpc.ID{}, err + } + id, err := sdkjsonrpc.MakeID(idValue) + if err != nil { + return sdkjsonrpc.ID{}, err + } + if !id.IsValid() { + return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + } + return id, nil +} + +func sdkIDKey(id sdkjsonrpc.ID) string { + if !id.IsValid() { + return "" + } + raw, err := json.Marshal(id.Raw()) + if err != nil { + return "" + } + return string(raw) +} + +func sdkIDRaw(id sdkjsonrpc.ID) json.RawMessage { + if !id.IsValid() { + return nil + } + raw, err := json.Marshal(id.Raw()) + if err != nil { + return nil + } + return json.RawMessage(raw) +} + +func sdkResponsePayload(resp *sdkjsonrpc.Response) (map[string]any, error) { + if resp == nil { + return nil, io.EOF + } + if resp.Error != nil { + code := int64(-32603) + message := strings.TrimSpace(resp.Error.Error()) + if wireErr, ok := resp.Error.(*sdkjsonrpc.Error); ok { + code = wireErr.Code + message = strings.TrimSpace(wireErr.Message) + } + if message == "" { + message = "internal error" + } + return map[string]any{ + "jsonrpc": "2.0", + "id": sdkIDRaw(resp.ID), + "error": map[string]any{ + "code": code, + "message": message, + }, + }, nil + } + var result any + if len(resp.Result) > 0 { + if err := json.Unmarshal(resp.Result, &result); err != nil { + return nil, err + } + } + return map[string]any{ + "jsonrpc": "2.0", + "id": sdkIDRaw(resp.ID), + "result": result, + }, nil +} diff --git a/internal/handlers/fs_mcp_session_test.go b/internal/handlers/fs_mcp_session_test.go new file mode 100644 index 00000000..3ef000ca --- /dev/null +++ b/internal/handlers/fs_mcp_session_test.go @@ -0,0 +1,255 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sync" + "testing" + "time" + + sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" + + mcptools "github.com/memohai/memoh/internal/mcp" +) + +type fakeMCPConnection struct { + mu sync.Mutex + writes []*sdkjsonrpc.Request + readCh chan sdkjsonrpc.Message + closed chan struct{} + closeMu sync.Once + onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) +} + +func newFakeMCPConnection(onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error)) *fakeMCPConnection { + return &fakeMCPConnection{ + writes: make([]*sdkjsonrpc.Request, 0, 16), + readCh: make(chan sdkjsonrpc.Message, 32), + closed: make(chan struct{}), + onWrite: onWrite, + } +} + +func (c *fakeMCPConnection) Read(ctx context.Context) (sdkjsonrpc.Message, error) { + select { + case <-c.closed: + return nil, io.EOF + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-c.readCh: + if !ok { + return nil, io.EOF + } + return msg, nil + } +} + +func (c *fakeMCPConnection) Write(ctx context.Context, msg sdkjsonrpc.Message) error { + req, ok := msg.(*sdkjsonrpc.Request) + if !ok { + return fmt.Errorf("unsupported message type: %T", msg) + } + cloned := cloneJSONRPCRequest(req) + c.mu.Lock() + c.writes = append(c.writes, cloned) + c.mu.Unlock() + + if c.onWrite == nil { + return nil + } + resp, err := c.onWrite(cloned) + if err != nil { + return err + } + if resp == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closed: + return io.EOF + case c.readCh <- resp: + return nil + } +} + +func (c *fakeMCPConnection) Close() error { + c.closeMu.Do(func() { + close(c.closed) + close(c.readCh) + }) + return nil +} + +func (c *fakeMCPConnection) SessionID() string { + return "test-session" +} + +func cloneJSONRPCRequest(req *sdkjsonrpc.Request) *sdkjsonrpc.Request { + if req == nil { + return nil + } + params := append([]byte(nil), req.Params...) + return &sdkjsonrpc.Request{ + ID: req.ID, + Method: req.Method, + Params: params, + Extra: req.Extra, + } +} + +func jsonRPCSuccessResponse(id sdkjsonrpc.ID, payload map[string]any) *sdkjsonrpc.Response { + body, _ := json.Marshal(payload) + return &sdkjsonrpc.Response{ + ID: id, + Result: body, + } +} + +func newTestMCPSession(conn *fakeMCPConnection) *mcpSession { + return &mcpSession{ + pending: map[string]chan *sdkjsonrpc.Response{}, + conn: conn, + closed: make(chan struct{}), + } +} + +func TestMCPSessionRetriesInitializeAfterFailure(t *testing.T) { + initCalls := 0 + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + switch req.Method { + case "initialize": + initCalls++ + if initCalls == 1 { + return &sdkjsonrpc.Response{ + ID: req.ID, + Error: &sdkjsonrpc.Error{ + Code: -32603, + Message: "temporary init failure", + }, + }, nil + } + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "protocolVersion": "2025-06-18", + }), nil + case "tools/list": + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "tools": []any{}, + }), nil + default: + return nil, nil + } + }) + session := newTestMCPSession(conn) + go session.readLoop() + defer session.closeWithError(io.EOF) + + _, firstErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("1"), + Method: "tools/list", + }) + if firstErr == nil { + t.Fatalf("first call should fail when initialize fails") + } + + secondPayload, secondErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("2"), + Method: "tools/list", + }) + if secondErr != nil { + t.Fatalf("second call should recover by retrying initialize: %v", secondErr) + } + if initCalls != 2 { + t.Fatalf("initialize should be retried once, got calls: %d", initCalls) + } + result, ok := secondPayload["result"].(map[string]any) + if !ok { + t.Fatalf("missing tools/list result: %#v", secondPayload) + } + if _, ok := result["tools"].([]any); !ok { + t.Fatalf("missing tools field: %#v", result) + } +} + +func TestMCPSessionExplicitInitializeDoesNotDuplicateInitialize(t *testing.T) { + initializeCalls := 0 + initializedNotifications := 0 + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + switch req.Method { + case "initialize": + initializeCalls++ + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "protocolVersion": "2025-06-18", + }), nil + case "notifications/initialized": + initializedNotifications++ + return nil, nil + case "tools/list": + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "tools": []any{}, + }), nil + default: + return nil, nil + } + }) + session := newTestMCPSession(conn) + go session.readLoop() + defer session.closeWithError(io.EOF) + + _, initErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("100"), + Method: "initialize", + Params: json.RawMessage(`{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"v1"}}`), + }) + if initErr != nil { + t.Fatalf("explicit initialize should succeed: %v", initErr) + } + + _, listErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("101"), + Method: "tools/list", + }) + if listErr != nil { + t.Fatalf("tools/list after initialize should succeed: %v", listErr) + } + if initializeCalls != 1 { + t.Fatalf("initialize should not be duplicated, got: %d", initializeCalls) + } + if initializedNotifications != 1 { + t.Fatalf("should send exactly one notifications/initialized, got: %d", initializedNotifications) + } +} + +func TestMCPSessionRemovesPendingOnContextCancel(t *testing.T) { + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + // Intentionally do not reply; caller should timeout. + return nil, nil + }) + session := newTestMCPSession(conn) + session.initState = mcpSessionInitStateReady + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + _, err := session.call(ctx, mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("200"), + Method: "tools/list", + }) + if err == nil { + t.Fatalf("call should fail on context timeout") + } + + session.pendingMu.Lock() + pendingCount := len(session.pending) + session.pendingMu.Unlock() + if pendingCount != 0 { + t.Fatalf("pending map should be empty after cancellation, got: %d", pendingCount) + } +} diff --git a/internal/handlers/mcp_federation_gateway.go b/internal/handlers/mcp_federation_gateway.go new file mode 100644 index 00000000..4872c9f1 --- /dev/null +++ b/internal/handlers/mcp_federation_gateway.go @@ -0,0 +1,534 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "time" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type MCPFederationGateway struct { + handler *ContainerdHandler + logger *slog.Logger + client *http.Client +} + +func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPFederationGateway { + if log == nil { + log = slog.Default() + } + return &MCPFederationGateway{ + handler: handler, + logger: log.With(slog.String("gateway", "mcp_federation")), + client: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +func (g *MCPFederationGateway) ListFSMCPTools(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, error) { + if g.handler == nil { + return nil, fmt.Errorf("containerd handler not configured") + } + containerID, err := g.handler.botContainerID(ctx, botID) + if err != nil { + return nil, err + } + if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { + return nil, err + } + if err := g.handler.ensureTaskRunning(ctx, containerID); err != nil { + return nil, err + } + payload, err := g.handler.callMCPServer(ctx, containerID, mcpgw.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcpgw.RawStringID("federated-fs-tools-list"), + Method: "tools/list", + }) + if err != nil { + return nil, err + } + return parseGatewayToolsListPayload(payload) +} + +func (g *MCPFederationGateway) CallFSMCPTool(ctx context.Context, botID, toolName string, args map[string]any) (map[string]any, error) { + if g.handler == nil { + return nil, fmt.Errorf("containerd handler not configured") + } + containerID, err := g.handler.botContainerID(ctx, botID) + if err != nil { + return nil, err + } + if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { + return nil, err + } + if err := g.handler.ensureTaskRunning(ctx, containerID); err != nil { + return nil, err + } + params, err := json.Marshal(map[string]any{ + "name": toolName, + "arguments": args, + }) + if err != nil { + return nil, err + } + return g.handler.callMCPServer(ctx, containerID, mcpgw.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcpgw.RawStringID("federated-fs-tools-call"), + Method: "tools/call", + Params: params, + }) +} + +func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + session, err := g.connectStreamableSession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{}) + if err != nil { + return nil, err + } + return convertSDKTools(result.Tools), nil +} + +func (g *MCPFederationGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + session, err := g.connectStreamableSession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: strings.TrimSpace(toolName), + Arguments: args, + }) + if err != nil { + return nil, err + } + return wrapSDKToolResult(result) +} + +func (g *MCPFederationGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + session, err := g.connectSSESession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{}) + if err != nil { + return nil, err + } + return convertSDKTools(result.Tools), nil +} + +func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + session, err := g.connectSSESession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: strings.TrimSpace(toolName), + Arguments: args, + }) + if err != nil { + return nil, err + } + return wrapSDKToolResult(result) +} + +func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { + url := strings.TrimSpace(anyToString(connection.Config["url"])) + if url == "" { + return nil, fmt.Errorf("http mcp url is required") + } + client := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "memoh-federation-client", + Version: "v1", + }, nil) + transport := &sdkmcp.StreamableClientTransport{ + Endpoint: url, + HTTPClient: g.connectionHTTPClient(connection), + MaxRetries: -1, + } + return client.Connect(ctx, transport, nil) +} + +func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { + endpoints := resolveSSEEndpointCandidates(connection.Config) + if len(endpoints) == 0 { + return nil, fmt.Errorf("sse mcp url is required") + } + var lastErr error + for _, endpoint := range endpoints { + client := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "memoh-federation-client", + Version: "v1", + }, nil) + transport := &sdkmcp.SSEClientTransport{ + Endpoint: endpoint, + HTTPClient: g.connectionHTTPClient(connection), + } + session, err := client.Connect(ctx, transport, nil) + if err == nil { + return session, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = fmt.Errorf("no sse endpoint candidate available") + } + return nil, fmt.Errorf("connect sse mcp failed: %w", lastErr) +} + +func resolveSSEEndpointCandidates(config map[string]any) []string { + if config == nil { + return []string{} + } + + seen := map[string]struct{}{} + out := make([]string, 0, 4) + appendEndpoint := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + out = append(out, value) + } + + for _, key := range []string{"sse_url", "sseUrl"} { + appendEndpoint(anyToString(config[key])) + } + + baseURL := strings.TrimSpace(anyToString(config["url"])) + appendEndpoint(baseURL) + + var messageURL string + for _, key := range []string{"message_url", "messageUrl"} { + if value := strings.TrimSpace(anyToString(config[key])); value != "" { + messageURL = value + break + } + } + if messageURL != "" { + normalized := strings.TrimSuffix(messageURL, "/") + if strings.HasSuffix(normalized, "/message") { + appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse") + } + appendEndpoint(messageURL) + } + + if baseURL != "" { + normalized := strings.TrimSuffix(baseURL, "/") + if strings.HasSuffix(normalized, "/message") { + appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse") + } + } + + return out +} + +func (g *MCPFederationGateway) connectionHTTPClient(connection mcpgw.Connection) *http.Client { + base := g.client + if base == nil { + base = &http.Client{Timeout: 30 * time.Second} + } + headers := normalizeHeaderMap(connection.Config["headers"]) + if len(headers) == 0 { + return base + } + transport := base.Transport + if transport == nil { + transport = http.DefaultTransport + } + return &http.Client{ + Timeout: base.Timeout, + CheckRedirect: base.CheckRedirect, + Jar: base.Jar, + Transport: &staticHeaderRoundTripper{ + next: transport, + headers: headers, + }, + } +} + +func (g *MCPFederationGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + sess, err := g.startStdioConnectionSession(ctx, botID, connection) + if err != nil { + return nil, err + } + defer sess.closeWithError(io.EOF) + + payload, err := sess.call(ctx, mcpgw.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcpgw.RawStringID("federated-stdio-tools-list"), + Method: "tools/list", + }) + if err != nil { + return nil, err + } + return parseGatewayToolsListPayload(payload) +} + +func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + sess, err := g.startStdioConnectionSession(ctx, botID, connection) + if err != nil { + return nil, err + } + defer sess.closeWithError(io.EOF) + + params, err := json.Marshal(map[string]any{ + "name": toolName, + "arguments": args, + }) + if err != nil { + return nil, err + } + return sess.call(ctx, mcpgw.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcpgw.RawStringID("federated-stdio-tools-call"), + Method: "tools/call", + Params: params, + }) +} + +func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, botID string, connection mcpgw.Connection) (*mcpSession, error) { + if g.handler == nil { + return nil, fmt.Errorf("containerd handler not configured") + } + containerID, err := g.handler.botContainerID(ctx, botID) + if err != nil { + return nil, err + } + if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { + return nil, err + } + if err := g.handler.ensureTaskRunning(ctx, containerID); err != nil { + return nil, err + } + + command := strings.TrimSpace(anyToString(connection.Config["command"])) + if command == "" { + return nil, fmt.Errorf("stdio mcp command is required") + } + request := MCPStdioRequest{ + Name: strings.TrimSpace(connection.Name), + Command: command, + Args: normalizeStringSlice(connection.Config["args"]), + Env: normalizeStringMap(connection.Config["env"]), + Cwd: strings.TrimSpace(anyToString(connection.Config["cwd"])), + } + return g.handler.startContainerdMCPCommandSession(ctx, containerID, request) +} + +func parseGatewayToolsListPayload(payload map[string]any) ([]mcpgw.ToolDescriptor, error) { + if err := mcpgw.PayloadError(payload); err != nil { + return nil, err + } + result, ok := payload["result"].(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid tools/list result") + } + rawTools, ok := result["tools"].([]any) + if !ok { + return nil, fmt.Errorf("invalid tools/list tools field") + } + tools := make([]mcpgw.ToolDescriptor, 0, len(rawTools)) + for _, rawTool := range rawTools { + item, ok := rawTool.(map[string]any) + if !ok { + continue + } + name := strings.TrimSpace(anyToString(item["name"])) + if name == "" { + continue + } + description := strings.TrimSpace(anyToString(item["description"])) + inputSchema, _ := item["inputSchema"].(map[string]any) + if inputSchema == nil { + inputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + tools = append(tools, mcpgw.ToolDescriptor{ + Name: name, + Description: description, + InputSchema: inputSchema, + }) + } + return tools, nil +} + +func convertSDKTools(items []*sdkmcp.Tool) []mcpgw.ToolDescriptor { + if len(items) == 0 { + return []mcpgw.ToolDescriptor{} + } + tools := make([]mcpgw.ToolDescriptor, 0, len(items)) + for _, item := range items { + if item == nil { + continue + } + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + tools = append(tools, mcpgw.ToolDescriptor{ + Name: name, + Description: strings.TrimSpace(item.Description), + InputSchema: normalizeToolInputSchema(item.InputSchema), + }) + } + return tools +} + +func normalizeToolInputSchema(raw any) map[string]any { + if schema, ok := raw.(map[string]any); ok && schema != nil { + return schema + } + if raw != nil { + payload, err := json.Marshal(raw) + if err == nil { + var schema map[string]any + if err := json.Unmarshal(payload, &schema); err == nil && schema != nil { + return schema + } + } + } + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func wrapSDKToolResult(result *sdkmcp.CallToolResult) (map[string]any, error) { + if result == nil { + return map[string]any{ + "result": mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), + }, nil + } + payload, err := json.Marshal(result) + if err != nil { + return nil, err + } + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err != nil { + return nil, err + } + if parsed == nil { + parsed = map[string]any{} + } + return map[string]any{"result": parsed}, nil +} + +func normalizeHeaderMap(raw any) map[string]string { + switch value := raw.(type) { + case map[string]string: + return value + case map[string]any: + out := make(map[string]string, len(value)) + for k, v := range value { + key := strings.TrimSpace(k) + val := strings.TrimSpace(anyToString(v)) + if key == "" || val == "" { + continue + } + out[key] = val + } + return out + default: + return map[string]string{} + } +} + +func normalizeStringSlice(raw any) []string { + switch value := raw.(type) { + case []string: + out := make([]string, 0, len(value)) + for _, item := range value { + item = strings.TrimSpace(item) + if item != "" { + out = append(out, item) + } + } + return out + case []any: + out := make([]string, 0, len(value)) + for _, item := range value { + val := strings.TrimSpace(anyToString(item)) + if val != "" { + out = append(out, val) + } + } + return out + default: + return []string{} + } +} + +func normalizeStringMap(raw any) map[string]string { + switch value := raw.(type) { + case map[string]string: + return value + case map[string]any: + out := make(map[string]string, len(value)) + for k, v := range value { + key := strings.TrimSpace(k) + val := strings.TrimSpace(anyToString(v)) + if key == "" { + continue + } + out[key] = val + } + return out + default: + return map[string]string{} + } +} + +func anyToString(v any) string { + if v == nil { + return "" + } + switch value := v.(type) { + case string: + return value + default: + return fmt.Sprintf("%v", v) + } +} + +type staticHeaderRoundTripper struct { + next http.RoundTripper + headers map[string]string +} + +func (t *staticHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + next := t.next + if next == nil { + next = http.DefaultTransport + } + clone := req.Clone(req.Context()) + clone.Header = req.Header.Clone() + for key, value := range t.headers { + headerKey := strings.TrimSpace(key) + headerVal := strings.TrimSpace(value) + if headerKey == "" || headerVal == "" { + continue + } + clone.Header.Set(headerKey, headerVal) + } + return next.RoundTrip(clone) +} diff --git a/internal/handlers/mcp_federation_gateway_test.go b/internal/handlers/mcp_federation_gateway_test.go new file mode 100644 index 00000000..ff453626 --- /dev/null +++ b/internal/handlers/mcp_federation_gateway_test.go @@ -0,0 +1,188 @@ +package handlers + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type testToolInput struct { + Query string `json:"query"` +} + +type testToolOutput struct { + Echo string `json:"echo"` +} + +func newTestMCPServer() *sdkmcp.Server { + server := sdkmcp.NewServer(&sdkmcp.Implementation{ + Name: "test-federation-server", + Version: "v1", + }, nil) + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "echo", + Description: "Echo query", + }, func(ctx context.Context, request *sdkmcp.CallToolRequest, input testToolInput) (*sdkmcp.CallToolResult, testToolOutput, error) { + return nil, testToolOutput{Echo: input.Query}, nil + }) + return server +} + +func withAuthHeader(next http.Handler, token string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != token { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func TestFederationGatewayHTTPConnectionViaSDK(t *testing.T) { + server := newTestMCPServer() + handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server { + return server + }, nil) + httpServer := httptest.NewServer(withAuthHeader(handler, "Bearer test-token")) + defer httpServer.Close() + + gateway := &MCPFederationGateway{ + client: httpServer.Client(), + } + connection := mcpgw.Connection{ + Config: map[string]any{ + "url": httpServer.URL, + "headers": map[string]any{ + "Authorization": "Bearer test-token", + }, + }, + } + + tools, err := gateway.ListHTTPConnectionTools(context.Background(), connection) + if err != nil { + t.Fatalf("list http tools failed: %v", err) + } + if len(tools) != 1 || tools[0].Name != "echo" { + t.Fatalf("unexpected tool list: %#v", tools) + } + + payload, err := gateway.CallHTTPConnectionTool(context.Background(), connection, "echo", map[string]any{ + "query": "hello-http", + }) + if err != nil { + t.Fatalf("call http tool failed: %v", err) + } + assertEchoResult(t, payload, "hello-http") +} + +func TestFederationGatewaySSEConnectionViaSDK(t *testing.T) { + server := newTestMCPServer() + handler := sdkmcp.NewSSEHandler(func(*http.Request) *sdkmcp.Server { + return server + }, nil) + httpServer := httptest.NewServer(withAuthHeader(handler, "Bearer test-token")) + defer httpServer.Close() + + gateway := &MCPFederationGateway{ + client: httpServer.Client(), + } + connection := mcpgw.Connection{ + Config: map[string]any{ + "url": httpServer.URL, + "headers": map[string]any{ + "Authorization": "Bearer test-token", + }, + }, + } + + tools, err := gateway.ListSSEConnectionTools(context.Background(), connection) + if err != nil { + t.Fatalf("list sse tools failed: %v", err) + } + if len(tools) != 1 || tools[0].Name != "echo" { + t.Fatalf("unexpected tool list: %#v", tools) + } + + payload, err := gateway.CallSSEConnectionTool(context.Background(), connection, "echo", map[string]any{ + "query": "hello-sse", + }) + if err != nil { + t.Fatalf("call sse tool failed: %v", err) + } + assertEchoResult(t, payload, "hello-sse") +} + +func TestResolveSSEEndpointCandidatesCompatibility(t *testing.T) { + tests := []struct { + name string + config map[string]any + contains string + firstWant string + }{ + { + name: "prefer explicit sse_url", + config: map[string]any{"sse_url": "http://example.com/custom-sse", "url": "http://example.com/sse"}, + firstWant: "http://example.com/custom-sse", + contains: "http://example.com/sse", + }, + { + name: "fallback to url as endpoint", + config: map[string]any{"url": "http://example.com/sse"}, + firstWant: "http://example.com/sse", + contains: "http://example.com/sse", + }, + { + name: "derive endpoint from message url", + config: map[string]any{"message_url": "http://example.com/message"}, + firstWant: "http://example.com/sse", + contains: "http://example.com/message", + }, + { + name: "derive endpoint from url message suffix", + config: map[string]any{"url": "http://example.com/message"}, + firstWant: "http://example.com/message", + contains: "http://example.com/sse", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resolveSSEEndpointCandidates(tt.config) + if len(got) == 0 { + t.Fatalf("resolve sse endpoints should not be empty") + } + if got[0] != tt.firstWant { + t.Fatalf("unexpected first endpoint: got=%s want=%s", got[0], tt.firstWant) + } + found := false + for _, item := range got { + if item == tt.contains { + found = true + break + } + } + if !found { + t.Fatalf("endpoint candidates missing expected value: %s in %#v", tt.contains, got) + } + }) + } +} + +func assertEchoResult(t *testing.T, payload map[string]any, expected string) { + t.Helper() + result, ok := payload["result"].(map[string]any) + if !ok { + t.Fatalf("missing result payload: %#v", payload) + } + structured, ok := result["structuredContent"].(map[string]any) + if !ok { + t.Fatalf("missing structured content: %#v", result) + } + if got := anyToString(structured["echo"]); got != expected { + t.Fatalf("unexpected echo result: got=%s want=%s", got, expected) + } +} diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go index a262adaf..c6af9ca1 100644 --- a/internal/handlers/mcp_stdio.go +++ b/internal/handlers/mcp_stdio.go @@ -14,6 +14,8 @@ import ( "github.com/google/uuid" "github.com/labstack/echo/v4" + sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" ctr "github.com/memohai/memoh/internal/containerd" mcptools "github.com/memohai/memoh/internal/mcp" @@ -187,9 +189,19 @@ func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context stdin: execSession.Stdin, stdout: execSession.Stdout, stderr: execSession.Stderr, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(ctx) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(execSession.Stderr, containerID) go sess.readLoop() go func() { @@ -334,9 +346,19 @@ func (h *ContainerdHandler) startLimaMCPCommandSession(containerID string, req M stdout: stdout, stderr: stderr, cmd: cmd, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(context.Background()) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(stderr, containerID) go sess.readLoop() diff --git a/internal/handlers/mcp_tools.go b/internal/handlers/mcp_tools.go new file mode 100644 index 00000000..d5b34e06 --- /dev/null +++ b/internal/handlers/mcp_tools.go @@ -0,0 +1,244 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/labstack/echo/v4" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/memohai/memoh/internal/auth" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const ( + headerChatID = "X-Memoh-Chat-Id" + headerChannelIdentityID = "X-Memoh-Channel-Identity-Id" + headerSessionToken = "X-Memoh-Session-Token" + headerCurrentPlatform = "X-Memoh-Current-Platform" + headerReplyTarget = "X-Memoh-Reply-Target" + headerDisplayName = "X-Memoh-Display-Name" +) + +func (h *ContainerdHandler) SetToolGatewayService(service *mcpgw.ToolGatewayService) { + h.toolGateway = service +} + +// HandleMCPTools godoc +// @Summary Unified MCP tools gateway +// @Description MCP endpoint for tool discovery and invocation. +// @Tags containerd +// @Param bot_id path string true "Bot ID" +// @Param payload body object true "JSON-RPC request" +// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}" +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/tools [post] +func (h *ContainerdHandler) HandleMCPTools(c echo.Context) error { + if h.toolGateway == nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, "tool gateway not configured") + } + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } + return h.handleMCPToolsWithBotID(c, botID) +} + +func (h *ContainerdHandler) handleMCPToolsWithBotID(c echo.Context, botID string) error { + session := h.buildToolSessionContext(c, botID) + + req := c.Request() + ensureStreamableAcceptHeader(req) + ctx := context.WithValue(req.Context(), toolSessionContextKey{}, session) + req = req.WithContext(ctx) + + handler := sdkmcp.NewStreamableHTTPHandler( + func(r *http.Request) *sdkmcp.Server { + return h.buildToolMCPServer(r.Context()) + }, + &sdkmcp.StreamableHTTPOptions{ + Stateless: true, + JSONResponse: true, + Logger: h.logger, + }, + ) + handler.ServeHTTP(c.Response().Writer, req) + return nil +} + +func ensureStreamableAcceptHeader(req *http.Request) { + if req == nil { + return + } + acceptValues := req.Header.Values("Accept") + joined := strings.ToLower(strings.Join(acceptValues, ",")) + hasJSON := strings.Contains(joined, "application/json") || strings.Contains(joined, "application/*") || strings.Contains(joined, "*/*") + hasStream := strings.Contains(joined, "text/event-stream") || strings.Contains(joined, "text/*") || strings.Contains(joined, "*/*") + if hasJSON && hasStream { + return + } + + base := strings.TrimSpace(strings.Join(acceptValues, ",")) + parts := make([]string, 0, 3) + if base != "" { + parts = append(parts, base) + } + if !hasJSON { + parts = append(parts, "application/json") + } + if !hasStream { + parts = append(parts, "text/event-stream") + } + if len(parts) == 0 { + parts = append(parts, "application/json", "text/event-stream") + } + req.Header.Set("Accept", strings.Join(parts, ", ")) +} + +type toolSessionContextKey struct{} + +func (h *ContainerdHandler) buildToolMCPServer(ctx context.Context) *sdkmcp.Server { + if h.toolGateway == nil { + return nil + } + session, ok := ctx.Value(toolSessionContextKey{}).(mcpgw.ToolSessionContext) + if !ok { + return nil + } + + server := sdkmcp.NewServer( + &sdkmcp.Implementation{ + Name: "memoh-tools-gateway", + Version: "1.0.0", + }, + &sdkmcp.ServerOptions{ + Capabilities: &sdkmcp.ServerCapabilities{ + Tools: &sdkmcp.ToolCapabilities{ + ListChanged: false, + }, + }, + }, + ) + server.AddReceivingMiddleware(h.toolGatewayMiddleware(session)) + return server +} + +func (h *ContainerdHandler) toolGatewayMiddleware(session mcpgw.ToolSessionContext) sdkmcp.Middleware { + return func(next sdkmcp.MethodHandler) sdkmcp.MethodHandler { + return func(ctx context.Context, method string, req sdkmcp.Request) (sdkmcp.Result, error) { + switch strings.TrimSpace(method) { + case "tools/list": + tools, err := h.toolGateway.ListTools(ctx, session) + if err != nil { + return nil, err + } + return &sdkmcp.ListToolsResult{ + Tools: convertGatewayToolsToSDK(tools), + }, nil + case "tools/call": + callReq, ok := req.(*sdkmcp.ServerRequest[*sdkmcp.CallToolParamsRaw]) + if !ok || callReq == nil || callReq.Params == nil { + return nil, fmt.Errorf("tools/call params is required") + } + payload, err := buildToolCallPayloadFromRaw(callReq.Params) + if err != nil { + return nil, err + } + result, err := h.toolGateway.CallTool(ctx, session, payload) + if err != nil { + return nil, err + } + return convertGatewayCallResultToSDK(result) + default: + return next(ctx, method, req) + } + } + } +} + +func buildToolCallPayloadFromRaw(params *sdkmcp.CallToolParamsRaw) (mcpgw.ToolCallPayload, error) { + if params == nil { + return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call params is required") + } + name := strings.TrimSpace(params.Name) + if name == "" { + return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call name is required") + } + arguments := map[string]any{} + if len(params.Arguments) > 0 { + if err := json.Unmarshal(params.Arguments, &arguments); err != nil { + return mcpgw.ToolCallPayload{}, err + } + } + if arguments == nil { + arguments = map[string]any{} + } + return mcpgw.ToolCallPayload{ + Name: name, + Arguments: arguments, + }, nil +} + +func convertGatewayToolsToSDK(items []mcpgw.ToolDescriptor) []*sdkmcp.Tool { + if len(items) == 0 { + return []*sdkmcp.Tool{} + } + tools := make([]*sdkmcp.Tool, 0, len(items)) + for _, item := range items { + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + inputSchema := item.InputSchema + if inputSchema == nil { + inputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + tools = append(tools, &sdkmcp.Tool{ + Name: name, + Description: strings.TrimSpace(item.Description), + InputSchema: inputSchema, + }) + } + return tools +} + +func convertGatewayCallResultToSDK(result map[string]any) (*sdkmcp.CallToolResult, error) { + if result == nil { + result = mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}) + } + payload, err := json.Marshal(result) + if err != nil { + return nil, err + } + var out sdkmcp.CallToolResult + if err := json.Unmarshal(payload, &out); err != nil { + return nil, err + } + return &out, nil +} + +func (h *ContainerdHandler) buildToolSessionContext(c echo.Context, botID string) mcpgw.ToolSessionContext { + channelIdentityID := strings.TrimSpace(c.Request().Header.Get(headerChannelIdentityID)) + if channelIdentityID == "" { + if ctxIdentityID, err := auth.UserIDFromContext(c); err == nil { + channelIdentityID = strings.TrimSpace(ctxIdentityID) + } + } + return mcpgw.ToolSessionContext{ + BotID: strings.TrimSpace(botID), + ChatID: strings.TrimSpace(c.Request().Header.Get(headerChatID)), + ChannelIdentityID: channelIdentityID, + SessionToken: strings.TrimSpace(c.Request().Header.Get(headerSessionToken)), + CurrentPlatform: strings.TrimSpace(c.Request().Header.Get(headerCurrentPlatform)), + ReplyTarget: strings.TrimSpace(c.Request().Header.Get(headerReplyTarget)), + DisplayName: strings.TrimSpace(c.Request().Header.Get(headerDisplayName)), + } +} diff --git a/internal/handlers/mcp_tools_test.go b/internal/handlers/mcp_tools_test.go new file mode 100644 index 00000000..6ac0ce43 --- /dev/null +++ b/internal/handlers/mcp_tools_test.go @@ -0,0 +1,167 @@ +package handlers + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v4" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +func TestBuildToolCallPayloadFromRaw(t *testing.T) { + params := &sdkmcp.CallToolParamsRaw{ + Name: " tool_a ", + Arguments: json.RawMessage(`{"x":1}`), + } + payload, err := buildToolCallPayloadFromRaw(params) + if err != nil { + t.Fatalf("valid payload should parse: %v", err) + } + if payload.Name != "tool_a" { + t.Fatalf("unexpected tool name: %s", payload.Name) + } + if _, ok := payload.Arguments["x"]; !ok { + t.Fatalf("expected argument x") + } + + invalid := &sdkmcp.CallToolParamsRaw{ + Name: "", + Arguments: json.RawMessage(`{}`), + } + if _, err := buildToolCallPayloadFromRaw(invalid); err == nil { + t.Fatalf("empty tool name should fail") + } +} + +func TestHandleMCPToolsWithoutGateway(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPath("/bots/:bot_id/tools") + c.SetParamNames("bot_id") + c.SetParamValues("bot-1") + + handler := &ContainerdHandler{} + err := handler.HandleMCPTools(c) + if err == nil { + t.Fatalf("expected service unavailable error") + } + httpErr, ok := err.(*echo.HTTPError) + if !ok { + t.Fatalf("expected echo HTTP error, got %T", err) + } + if httpErr.Code != http.StatusServiceUnavailable { + t.Fatalf("unexpected status code: %d", httpErr.Code) + } +} + +type mcpToolsTestExecutor struct { + lastSession mcpgw.ToolSessionContext +} + +func (e *mcpToolsTestExecutor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + e.lastSession = session + return []mcpgw.ToolDescriptor{ + { + Name: "echo_tool", + Description: "echo input", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "input": map[string]any{"type": "string"}, + }, + }, + }, + }, nil +} + +func (e *mcpToolsTestExecutor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + e.lastSession = session + if strings.TrimSpace(toolName) != "echo_tool" { + return nil, mcpgw.ErrToolNotFound + } + return mcpgw.BuildToolSuccessResult(map[string]any{ + "ok": true, + "echo": mcpgw.StringArg(arguments, "input"), + "chat_id": session.ChatID, + "channel_identity_id": session.ChannelIdentityID, + }), nil +} + +func TestHandleMCPToolsWithGatewayAcceptCompatibility(t *testing.T) { + e := echo.New() + executor := &mcpToolsTestExecutor{} + toolGateway := mcpgw.NewToolGatewayService(slog.Default(), []mcpgw.ToolExecutor{executor}, nil) + handler := &ContainerdHandler{ + logger: slog.Default(), + toolGateway: toolGateway, + } + + listReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) + listReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + listReq.Header.Set("Accept", "application/json") + listReq.Header.Set("X-Memoh-Chat-Id", "chat-1") + listReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") + listRec := httptest.NewRecorder() + listCtx := e.NewContext(listReq, listRec) + + if err := handler.handleMCPToolsWithBotID(listCtx, "bot-1"); err != nil { + t.Fatalf("list tools should succeed: %v", err) + } + if listRec.Code != http.StatusOK { + t.Fatalf("unexpected list status: %d body=%s", listRec.Code, listRec.Body.String()) + } + if !strings.Contains(strings.ToLower(listReq.Header.Get("Accept")), "text/event-stream") { + t.Fatalf("accept header should include text/event-stream: %s", listReq.Header.Get("Accept")) + } + + var listPayload map[string]any + if err := json.Unmarshal(listRec.Body.Bytes(), &listPayload); err != nil { + t.Fatalf("decode list payload failed: %v", err) + } + result, _ := listPayload["result"].(map[string]any) + tools, _ := result["tools"].([]any) + if len(tools) != 1 { + t.Fatalf("expected one tool, got: %#v", result["tools"]) + } + + callReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"2","method":"tools/call","params":{"name":"echo_tool","arguments":{"input":"hello"}}}`)) + callReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + callReq.Header.Set("Accept", "application/json") + callReq.Header.Set("X-Memoh-Chat-Id", "chat-1") + callReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") + callRec := httptest.NewRecorder() + callCtx := e.NewContext(callReq, callRec) + + if err := handler.handleMCPToolsWithBotID(callCtx, "bot-1"); err != nil { + t.Fatalf("call tool should succeed: %v", err) + } + if callRec.Code != http.StatusOK { + t.Fatalf("unexpected call status: %d body=%s", callRec.Code, callRec.Body.String()) + } + + var callPayload map[string]any + if err := json.Unmarshal(callRec.Body.Bytes(), &callPayload); err != nil { + t.Fatalf("decode call payload failed: %v", err) + } + callResult, _ := callPayload["result"].(map[string]any) + structured, _ := callResult["structuredContent"].(map[string]any) + if echoValue := strings.TrimSpace(mcpgw.StringArg(structured, "echo")); echoValue != "hello" { + t.Fatalf("unexpected echo value: %#v", structured["echo"]) + } + if strings.TrimSpace(mcpgw.StringArg(structured, "chat_id")) != "chat-1" { + t.Fatalf("unexpected chat id: %#v", structured["chat_id"]) + } + if strings.TrimSpace(mcpgw.StringArg(structured, "channel_identity_id")) != "user-1" { + t.Fatalf("unexpected channel identity id: %#v", structured["channel_identity_id"]) + } +} diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index fff1ee0d..360e8641 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -155,13 +155,22 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error { if err != nil { return err } + chatObj, err := h.chatService.Get(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + botID := strings.TrimSpace(chatObj.BotID) // Search across all enabled namespaces and merge results. var allResults []memory.MemoryItem for _, scope := range scopes { filters := buildNamespaceFilters(scope.Namespace, scope.ScopeID, payload.Filters) + if botID != "" { + filters["botId"] = botID + } req := memory.SearchRequest{ Query: payload.Query, + BotID: botID, RunID: payload.RunID, Limit: payload.Limit, Filters: filters, diff --git a/internal/mcp/jsonrpc.go b/internal/mcp/jsonrpc.go index d6d4933e..18912243 100644 --- a/internal/mcp/jsonrpc.go +++ b/internal/mcp/jsonrpc.go @@ -2,9 +2,7 @@ package mcp import ( "encoding/json" - "fmt" "strings" - "sync" ) func IsNotification(req JSONRPCRequest) bool { @@ -18,78 +16,3 @@ func JSONRPCErrorResponse(id json.RawMessage, code int, message string) JSONRPCR Error: &JSONRPCError{Code: code, Message: message}, } } - -func BuildPayloads(req JSONRPCRequest, initOnce *sync.Once) ([]string, json.RawMessage, error) { - if req.JSONRPC == "" { - req.JSONRPC = "2.0" - } - targetID := req.ID - payloads := []string{} - shouldInit := req.Method != "initialize" && req.Method != "notifications/initialized" - if initOnce != nil { - ran := false - initOnce.Do(func() { - ran = true - }) - if ran { - // This is the first call on the session. - } else { - shouldInit = false - } - } - if shouldInit { - initReq := map[string]any{ - "jsonrpc": "2.0", - "id": "init-1", - "method": "initialize", - "params": map[string]any{ - "protocolVersion": "2025-06-18", - "capabilities": map[string]any{ - "roots": map[string]any{ - "listChanged": false, - }, - }, - "clientInfo": map[string]any{ - "name": "memoh-http-proxy", - "version": "v0", - }, - }, - } - initBytes, err := json.Marshal(initReq) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(initBytes)) - - initialized := map[string]any{ - "jsonrpc": "2.0", - "method": "notifications/initialized", - } - initializedBytes, err := json.Marshal(initialized) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(initializedBytes)) - } - - reqBytes, err := json.Marshal(req) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(reqBytes)) - return payloads, targetID, nil -} - -func BuildNotificationPayloads(req JSONRPCRequest) ([]string, error) { - if req.JSONRPC == "" { - req.JSONRPC = "2.0" - } - if strings.TrimSpace(req.Method) == "" { - return nil, fmt.Errorf("missing method") - } - reqBytes, err := json.Marshal(req) - if err != nil { - return nil, err - } - return []string{string(reqBytes)}, nil -} diff --git a/internal/mcp/providers/memory/provider.go b/internal/mcp/providers/memory/provider.go new file mode 100644 index 00000000..137e8b19 --- /dev/null +++ b/internal/mcp/providers/memory/provider.go @@ -0,0 +1,228 @@ +package memory + +import ( + "context" + "log/slog" + "sort" + "strings" + + "github.com/memohai/memoh/internal/chat" + mcpgw "github.com/memohai/memoh/internal/mcp" + mem "github.com/memohai/memoh/internal/memory" +) + +const ( + toolSearchMemory = "search_memory" + defaultMemoryToolLimit = 8 + maxMemoryToolLimit = 50 +) + +type MemorySearcher interface { + Search(ctx context.Context, req mem.SearchRequest) (mem.SearchResponse, error) +} + +type ChatAccessor interface { + Get(ctx context.Context, chatID string) (chat.Chat, error) + GetSettings(ctx context.Context, chatID string) (chat.Settings, error) + IsParticipant(ctx context.Context, chatID, channelIdentityID string) (bool, error) +} + +type AdminChecker interface { + IsAdmin(ctx context.Context, channelIdentityID string) (bool, error) +} + +type Executor struct { + searcher MemorySearcher + chatAccessor ChatAccessor + adminChecker AdminChecker + logger *slog.Logger +} + +func NewExecutor(log *slog.Logger, searcher MemorySearcher, chatAccessor ChatAccessor, adminChecker AdminChecker) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + searcher: searcher, + chatAccessor: chatAccessor, + adminChecker: adminChecker, + logger: log.With(slog.String("provider", "memory_tool")), + } +} + +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.searcher == nil || p.chatAccessor == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolSearchMemory, + Description: "Search for memories relevant to the current chat", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "The query to search memories", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of memory results", + }, + }, + "required": []string{"query"}, + }, + }, + }, nil +} + +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if toolName != toolSearchMemory { + return nil, mcpgw.ErrToolNotFound + } + if p.searcher == nil || p.chatAccessor == nil { + return mcpgw.BuildToolErrorResult("memory service not available"), nil + } + + query := mcpgw.StringArg(arguments, "query") + if query == "" { + return mcpgw.BuildToolErrorResult("query is required"), nil + } + botID := strings.TrimSpace(session.BotID) + chatID := strings.TrimSpace(session.ChatID) + channelIdentityID := strings.TrimSpace(session.ChannelIdentityID) + if botID == "" || chatID == "" { + return mcpgw.BuildToolErrorResult("bot_id and chat_id are required"), nil + } + + limit := defaultMemoryToolLimit + if value, ok, err := mcpgw.IntArg(arguments, "limit"); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } else if ok { + limit = value + } + if limit <= 0 { + limit = defaultMemoryToolLimit + } + if limit > maxMemoryToolLimit { + limit = maxMemoryToolLimit + } + + chatObj, err := p.chatAccessor.Get(ctx, chatID) + if err != nil { + return mcpgw.BuildToolErrorResult("chat not found"), nil + } + if strings.TrimSpace(chatObj.BotID) != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + if channelIdentityID != "" { + allowed, err := p.canAccessChat(ctx, chatID, channelIdentityID) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if !allowed { + return mcpgw.BuildToolErrorResult("not a chat participant"), nil + } + } + + settings, err := p.chatAccessor.GetSettings(ctx, chatID) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + type memoryScope struct { + namespace string + scopeID string + } + scopes := make([]memoryScope, 0, 3) + if settings.EnableChatMemory { + scopes = append(scopes, memoryScope{namespace: "chat", scopeID: chatID}) + } + if settings.EnablePrivateMemory && channelIdentityID != "" { + scopes = append(scopes, memoryScope{namespace: "private", scopeID: channelIdentityID}) + } + if settings.EnablePublicMemory { + scopes = append(scopes, memoryScope{namespace: "public", scopeID: botID}) + } + if len(scopes) == 0 { + scopes = append(scopes, memoryScope{namespace: "chat", scopeID: chatID}) + } + + allResults := make([]mem.MemoryItem, 0, len(scopes)*limit) + for _, scope := range scopes { + resp, err := p.searcher.Search(ctx, mem.SearchRequest{ + Query: query, + BotID: botID, + Limit: limit, + Filters: map[string]any{ + "namespace": scope.namespace, + "scopeId": scope.scopeID, + "botId": botID, + }, + }) + if err != nil { + p.logger.Warn("memory search namespace failed", slog.String("namespace", scope.namespace), slog.Any("error", err)) + continue + } + allResults = append(allResults, resp.Results...) + } + + allResults = deduplicateMemoryItems(allResults) + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].Score > allResults[j].Score + }) + if len(allResults) > limit { + allResults = allResults[:limit] + } + + results := make([]map[string]any, 0, len(allResults)) + for _, item := range allResults { + results = append(results, map[string]any{ + "id": item.ID, + "memory": item.Memory, + "score": item.Score, + }) + } + + return mcpgw.BuildToolSuccessResult(map[string]any{ + "query": query, + "total": len(results), + "results": results, + }), nil +} + +func (p *Executor) canAccessChat(ctx context.Context, chatID, channelIdentityID string) (bool, error) { + if p.adminChecker != nil { + isAdmin, err := p.adminChecker.IsAdmin(ctx, channelIdentityID) + if err != nil { + return false, err + } + if isAdmin { + return true, nil + } + } + return p.chatAccessor.IsParticipant(ctx, chatID, channelIdentityID) +} + +func deduplicateMemoryItems(items []mem.MemoryItem) []mem.MemoryItem { + if len(items) == 0 { + return items + } + seen := make(map[string]struct{}, len(items)) + result := make([]mem.MemoryItem, 0, len(items)) + for _, item := range items { + id := strings.TrimSpace(item.ID) + if id == "" { + id = strings.TrimSpace(item.Memory) + } + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + result = append(result, item) + } + return result +} diff --git a/internal/mcp/providers/message/provider.go b/internal/mcp/providers/message/provider.go new file mode 100644 index 00000000..ad964a6c --- /dev/null +++ b/internal/mcp/providers/message/provider.go @@ -0,0 +1,147 @@ +package message + +import ( + "context" + "log/slog" + "strings" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const toolSendMessage = "send_message" + +type Sender interface { + Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error +} + +type ChannelTypeResolver interface { + ParseChannelType(raw string) (channel.ChannelType, error) +} + +type Executor struct { + sender Sender + resolver ChannelTypeResolver + logger *slog.Logger +} + +func NewExecutor(log *slog.Logger, sender Sender, resolver ChannelTypeResolver) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + sender: sender, + resolver: resolver, + logger: log.With(slog.String("provider", "message_tool")), + } +} + +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.sender == nil || p.resolver == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolSendMessage, + Description: "Send a message to a channel or session", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "bot_id": map[string]any{ + "type": "string", + "description": "Bot ID, optional and defaults to current bot", + }, + "platform": map[string]any{ + "type": "string", + "description": "Channel platform name", + }, + "target": map[string]any{ + "type": "string", + "description": "Channel target (chat/group/thread ID)", + }, + "channel_identity_id": map[string]any{ + "type": "string", + "description": "Target identity ID when direct target is absent", + }, + "to_user_id": map[string]any{ + "type": "string", + "description": "Alias for channel_identity_id", + }, + "message": map[string]any{ + "type": "string", + "description": "Message text content", + }, + }, + "required": []string{"message"}, + }, + }, + }, nil +} + +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if toolName != toolSendMessage { + return nil, mcpgw.ErrToolNotFound + } + if p.sender == nil || p.resolver == nil { + return mcpgw.BuildToolErrorResult("message service not available"), nil + } + + botID := mcpgw.FirstStringArg(arguments, "bot_id") + if botID == "" { + botID = strings.TrimSpace(session.BotID) + } + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + if strings.TrimSpace(session.BotID) != "" && botID != strings.TrimSpace(session.BotID) { + return mcpgw.BuildToolErrorResult("bot_id mismatch"), nil + } + + platform := mcpgw.FirstStringArg(arguments, "platform") + if platform == "" { + platform = strings.TrimSpace(session.CurrentPlatform) + } + if platform == "" { + return mcpgw.BuildToolErrorResult("platform is required"), nil + } + channelType, err := p.resolver.ParseChannelType(platform) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + messageText := mcpgw.FirstStringArg(arguments, "message") + if messageText == "" { + return mcpgw.BuildToolErrorResult("message is required"), nil + } + + target := mcpgw.FirstStringArg(arguments, "target") + if target == "" { + target = strings.TrimSpace(session.ReplyTarget) + } + channelIdentityID := mcpgw.FirstStringArg(arguments, "channel_identity_id", "to_user_id") + if target == "" && channelIdentityID == "" { + return mcpgw.BuildToolErrorResult("target or channel_identity_id is required"), nil + } + + sendReq := channel.SendRequest{ + Target: target, + ChannelIdentityID: channelIdentityID, + Message: channel.Message{ + Text: messageText, + }, + } + if err := p.sender.Send(ctx, botID, channelType, sendReq); err != nil { + p.logger.Warn("send message failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("platform", platform)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + payload := map[string]any{ + "ok": true, + "bot_id": botID, + "platform": channelType.String(), + "target": target, + "channel_identity_id": channelIdentityID, + "instruction": "Message delivered successfully. You have completed your response. Please STOP now and do not call any more tools.", + } + return mcpgw.BuildToolSuccessResult(payload), nil +} diff --git a/internal/mcp/providers/schedule/provider.go b/internal/mcp/providers/schedule/provider.go new file mode 100644 index 00000000..71e37b3e --- /dev/null +++ b/internal/mcp/providers/schedule/provider.go @@ -0,0 +1,259 @@ +package schedule + +import ( + "context" + "log/slog" + "strings" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sched "github.com/memohai/memoh/internal/schedule" +) + +const ( + toolScheduleList = "schedule_list" + toolScheduleGet = "schedule_get" + toolScheduleCreate = "schedule_create" + toolScheduleUpdate = "schedule_update" + toolScheduleDelete = "schedule_delete" +) + +type Scheduler interface { + List(ctx context.Context, botID string) ([]sched.Schedule, error) + Get(ctx context.Context, id string) (sched.Schedule, error) + Create(ctx context.Context, botID string, req sched.CreateRequest) (sched.Schedule, error) + Update(ctx context.Context, id string, req sched.UpdateRequest) (sched.Schedule, error) + Delete(ctx context.Context, id string) error +} + +type Executor struct { + service Scheduler + logger *slog.Logger +} + +func NewExecutor(log *slog.Logger, service Scheduler) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + service: service, + logger: log.With(slog.String("provider", "schedule_tool")), + } +} + +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.service == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolScheduleList, + Description: "List schedules for current bot", + InputSchema: emptyObjectSchema(), + }, + { + Name: toolScheduleGet, + Description: "Get a schedule by id", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string", "description": "Schedule ID"}, + }, + "required": []string{"id"}, + }, + }, + { + Name: toolScheduleCreate, + Description: "Create a new schedule", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "description": map[string]any{"type": "string"}, + "pattern": map[string]any{"type": "string"}, + "max_calls": map[string]any{ + "type": []string{"integer", "null"}, + "description": "Optional max calls, null means unlimited", + }, + "enabled": map[string]any{"type": "boolean"}, + "command": map[string]any{"type": "string"}, + }, + "required": []string{"name", "description", "pattern", "command"}, + }, + }, + { + Name: toolScheduleUpdate, + Description: "Update an existing schedule", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string"}, + "name": map[string]any{"type": "string"}, + "description": map[string]any{"type": "string"}, + "pattern": map[string]any{"type": "string"}, + "max_calls": map[string]any{"type": []string{"integer", "null"}}, + "enabled": map[string]any{"type": "boolean"}, + "command": map[string]any{"type": "string"}, + }, + "required": []string{"id"}, + }, + }, + { + Name: toolScheduleDelete, + Description: "Delete a schedule by id", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string", "description": "Schedule ID"}, + }, + "required": []string{"id"}, + }, + }, + }, nil +} + +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if p.service == nil { + return mcpgw.BuildToolErrorResult("schedule service not available"), nil + } + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + + switch toolName { + case toolScheduleList: + items, err := p.service.List(ctx, botID) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{ + "items": items, + }), nil + case toolScheduleGet: + id := mcpgw.StringArg(arguments, "id") + if id == "" { + return mcpgw.BuildToolErrorResult("id is required"), nil + } + item, err := p.service.Get(ctx, id) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if item.BotID != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + return mcpgw.BuildToolSuccessResult(item), nil + case toolScheduleCreate: + name := mcpgw.StringArg(arguments, "name") + description := mcpgw.StringArg(arguments, "description") + pattern := mcpgw.StringArg(arguments, "pattern") + command := mcpgw.StringArg(arguments, "command") + if name == "" || description == "" || pattern == "" || command == "" { + return mcpgw.BuildToolErrorResult("name, description, pattern, command are required"), nil + } + + req := sched.CreateRequest{ + Name: name, + Description: description, + Pattern: pattern, + Command: command, + } + maxCalls, err := parseNullableIntArg(arguments, "max_calls") + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + req.MaxCalls = maxCalls + if enabled, ok, err := mcpgw.BoolArg(arguments, "enabled"); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } else if ok { + req.Enabled = &enabled + } + item, err := p.service.Create(ctx, botID, req) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(item), nil + case toolScheduleUpdate: + id := mcpgw.StringArg(arguments, "id") + if id == "" { + return mcpgw.BuildToolErrorResult("id is required"), nil + } + req := sched.UpdateRequest{} + maxCalls, err := parseNullableIntArg(arguments, "max_calls") + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + req.MaxCalls = maxCalls + if value := mcpgw.StringArg(arguments, "name"); value != "" { + req.Name = &value + } + if value := mcpgw.StringArg(arguments, "description"); value != "" { + req.Description = &value + } + if value := mcpgw.StringArg(arguments, "pattern"); value != "" { + req.Pattern = &value + } + if value := mcpgw.StringArg(arguments, "command"); value != "" { + req.Command = &value + } + if enabled, ok, err := mcpgw.BoolArg(arguments, "enabled"); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } else if ok { + req.Enabled = &enabled + } + item, err := p.service.Update(ctx, id, req) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if item.BotID != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + return mcpgw.BuildToolSuccessResult(item), nil + case toolScheduleDelete: + id := mcpgw.StringArg(arguments, "id") + if id == "" { + return mcpgw.BuildToolErrorResult("id is required"), nil + } + item, err := p.service.Get(ctx, id) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if item.BotID != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + if err := p.service.Delete(ctx, id); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"success": true}), nil + default: + return nil, mcpgw.ErrToolNotFound + } +} + +func parseNullableIntArg(arguments map[string]any, key string) (sched.NullableInt, error) { + req := sched.NullableInt{} + if arguments == nil { + return req, nil + } + raw, exists := arguments[key] + if !exists { + return req, nil + } + req.Set = true + if raw == nil { + req.Value = nil + return req, nil + } + value, _, err := mcpgw.IntArg(arguments, key) + if err != nil { + return sched.NullableInt{}, err + } + req.Value = &value + return req, nil +} + +func emptyObjectSchema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} diff --git a/internal/mcp/sources/federation/source.go b/internal/mcp/sources/federation/source.go new file mode 100644 index 00000000..a444cd48 --- /dev/null +++ b/internal/mcp/sources/federation/source.go @@ -0,0 +1,293 @@ +package federation + +import ( + "context" + "fmt" + "log/slog" + "sort" + "strconv" + "strings" + "sync" + "time" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const cacheTTL = 5 * time.Second + +type ConnectionLister interface { + ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) +} + +type Gateway interface { + ListFSMCPTools(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, error) + CallFSMCPTool(ctx context.Context, botID, toolName string, args map[string]any) (map[string]any, error) + + ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) + CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) + + ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) + CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) + + ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) + CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) +} + +type toolRoute struct { + sourceType string + originalName string + connection mcpgw.Connection +} + +type cacheEntry struct { + expiresAt time.Time + routes map[string]toolRoute + tools []mcpgw.ToolDescriptor +} + +type Source struct { + logger *slog.Logger + gateway Gateway + connections ConnectionLister + + mu sync.Mutex + cache map[string]cacheEntry +} + +func NewSource(log *slog.Logger, gateway Gateway, connections ConnectionLister) *Source { + if log == nil { + log = slog.Default() + } + return &Source{ + logger: log.With(slog.String("source", "federated_mcp_tool")), + gateway: gateway, + connections: connections, + cache: map[string]cacheEntry{}, + } +} + +func (s *Source) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + botID := strings.TrimSpace(session.BotID) + if botID == "" || s.gateway == nil { + return []mcpgw.ToolDescriptor{}, nil + } + if cached, ok := s.getCache(botID); ok { + return cloneTools(cached.tools), nil + } + tools, routes := s.buildToolsAndRoutes(ctx, botID) + s.setCache(botID, cacheEntry{ + expiresAt: time.Now().Add(cacheTTL), + routes: routes, + tools: tools, + }) + return cloneTools(tools), nil +} + +func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if s.gateway == nil { + return mcpgw.BuildToolErrorResult("federation gateway not available"), nil + } + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + + route, ok := s.getRoute(botID, toolName) + if !ok { + _, _ = s.ListTools(ctx, session) + route, ok = s.getRoute(botID, toolName) + if !ok { + return nil, mcpgw.ErrToolNotFound + } + } + if arguments == nil { + arguments = map[string]any{} + } + + var ( + payload map[string]any + err error + ) + switch route.sourceType { + case "fs": + payload, err = s.gateway.CallFSMCPTool(ctx, botID, route.originalName, arguments) + case "http": + payload, err = s.gateway.CallHTTPConnectionTool(ctx, route.connection, route.originalName, arguments) + case "sse": + payload, err = s.gateway.CallSSEConnectionTool(ctx, route.connection, route.originalName, arguments) + case "stdio": + payload, err = s.gateway.CallStdioConnectionTool(ctx, botID, route.connection, route.originalName, arguments) + default: + return mcpgw.BuildToolErrorResult("unsupported federated source"), nil + } + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if err := mcpgw.PayloadError(payload); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if result, ok := payload["result"].(map[string]any); ok { + return result, nil + } + return mcpgw.BuildToolSuccessResult(payload), nil +} + +func (s *Source) buildToolsAndRoutes(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, map[string]toolRoute) { + routes := map[string]toolRoute{} + tools := make([]mcpgw.ToolDescriptor, 0, 16) + + addTool := func(descriptor mcpgw.ToolDescriptor, route toolRoute) { + name := strings.TrimSpace(descriptor.Name) + if name == "" { + return + } + finalName := name + if _, exists := routes[finalName]; exists { + seed := strings.ReplaceAll(finalName, ".", "_") + if seed == "" { + seed = "tool" + } + for i := 2; ; i++ { + candidate := seed + "_" + strconv.Itoa(i) + if _, ok := routes[candidate]; ok { + continue + } + finalName = candidate + break + } + } + descriptor.Name = finalName + routes[finalName] = route + tools = append(tools, descriptor) + } + + fsTools, err := s.gateway.ListFSMCPTools(ctx, botID) + if err != nil { + s.logger.Warn("list fs mcp tools failed", slog.String("bot_id", botID), slog.Any("error", err)) + } else { + for _, tool := range fsTools { + addTool(tool, toolRoute{ + sourceType: "fs", + originalName: tool.Name, + }) + } + } + + if s.connections != nil { + items, err := s.connections.ListActiveByBot(ctx, botID) + if err != nil { + s.logger.Warn("list mcp connections failed", slog.String("bot_id", botID), slog.Any("error", err)) + } else { + sort.Slice(items, func(i, j int) bool { + if items[i].Name == items[j].Name { + return items[i].ID < items[j].ID + } + return items[i].Name < items[j].Name + }) + for _, connection := range items { + var connTools []mcpgw.ToolDescriptor + switch strings.ToLower(strings.TrimSpace(connection.Type)) { + case "http": + connTools, err = s.gateway.ListHTTPConnectionTools(ctx, connection) + case "sse": + connTools, err = s.gateway.ListSSEConnectionTools(ctx, connection) + case "stdio": + connTools, err = s.gateway.ListStdioConnectionTools(ctx, botID, connection) + default: + s.logger.Warn("unsupported mcp connection type", slog.String("connection_id", connection.ID), slog.String("type", connection.Type)) + continue + } + if err != nil { + s.logger.Warn("list tools from connection failed", slog.String("connection_id", connection.ID), slog.String("name", connection.Name), slog.Any("error", err)) + continue + } + prefix := sanitizePrefix(connection.Name) + for _, tool := range connTools { + origin := strings.TrimSpace(tool.Name) + alias := origin + if prefix != "" { + alias = prefix + "." + origin + } + tool.Name = alias + if strings.TrimSpace(tool.Description) != "" { + tool.Description = "[" + strings.TrimSpace(connection.Name) + "] " + tool.Description + } else { + tool.Description = "[" + strings.TrimSpace(connection.Name) + "] " + origin + } + addTool(tool, toolRoute{ + sourceType: strings.ToLower(strings.TrimSpace(connection.Type)), + originalName: origin, + connection: connection, + }) + } + } + } + } + return tools, routes +} + +func sanitizePrefix(raw string) string { + raw = strings.TrimSpace(strings.ToLower(raw)) + if raw == "" { + return "mcp" + } + builder := strings.Builder{} + for _, ch := range raw { + if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' { + builder.WriteRune(ch) + continue + } + builder.WriteRune('_') + } + normalized := strings.Trim(builder.String(), "._-") + if normalized == "" { + return "mcp" + } + return normalized +} + +func cloneTools(items []mcpgw.ToolDescriptor) []mcpgw.ToolDescriptor { + if len(items) == 0 { + return []mcpgw.ToolDescriptor{} + } + out := make([]mcpgw.ToolDescriptor, 0, len(items)) + for _, item := range items { + out = append(out, mcpgw.ToolDescriptor{ + Name: item.Name, + Description: item.Description, + InputSchema: item.InputSchema, + }) + } + return out +} + +func (s *Source) getCache(botID string) (cacheEntry, bool) { + s.mu.Lock() + defer s.mu.Unlock() + cached, ok := s.cache[botID] + if !ok || time.Now().After(cached.expiresAt) { + return cacheEntry{}, false + } + return cached, true +} + +func (s *Source) setCache(botID string, entry cacheEntry) { + s.mu.Lock() + s.cache[botID] = entry + s.mu.Unlock() +} + +func (s *Source) getRoute(botID, toolName string) (toolRoute, bool) { + s.mu.Lock() + defer s.mu.Unlock() + cached, ok := s.cache[botID] + if !ok || time.Now().After(cached.expiresAt) { + return toolRoute{}, false + } + route, exists := cached.routes[strings.TrimSpace(toolName)] + return route, exists +} + +func (s *Source) String() string { + return fmt.Sprintf("FederationSource(%p)", s) +} diff --git a/internal/mcp/sources/federation/source_test.go b/internal/mcp/sources/federation/source_test.go new file mode 100644 index 00000000..ab4902f9 --- /dev/null +++ b/internal/mcp/sources/federation/source_test.go @@ -0,0 +1,136 @@ +package federation + +import ( + "context" + "log/slog" + "testing" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +type testConnectionLister struct { + items []mcpgw.Connection + err error +} + +func (l *testConnectionLister) ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) { + if l.err != nil { + return nil, l.err + } + return l.items, nil +} + +type testGateway struct { + listFS []mcpgw.ToolDescriptor + listHTTP []mcpgw.ToolDescriptor + listSSE []mcpgw.ToolDescriptor + listStdio []mcpgw.ToolDescriptor + + lastCallType string +} + +func (g *testGateway) ListFSMCPTools(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, error) { + return g.listFS, nil +} + +func (g *testGateway) CallFSMCPTool(ctx context.Context, botID, toolName string, args map[string]any) (map[string]any, error) { + g.lastCallType = "fs" + return map[string]any{"result": map[string]any{"ok": true, "route": "fs"}}, nil +} + +func (g *testGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + return g.listHTTP, nil +} + +func (g *testGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + g.lastCallType = "http" + return map[string]any{"result": map[string]any{"ok": true, "route": "http"}}, nil +} + +func (g *testGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + return g.listSSE, nil +} + +func (g *testGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + g.lastCallType = "sse" + return map[string]any{"result": map[string]any{"ok": true, "route": "sse"}}, nil +} + +func (g *testGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + return g.listStdio, nil +} + +func (g *testGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + g.lastCallType = "stdio" + return map[string]any{"result": map[string]any{"ok": true, "route": "stdio"}}, nil +} + +func TestSourceListToolsIncludesSSETools(t *testing.T) { + gateway := &testGateway{ + listSSE: []mcpgw.ToolDescriptor{ + { + Name: "search", + Description: "search remote data", + InputSchema: map[string]any{"type": "object"}, + }, + }, + } + lister := &testConnectionLister{ + items: []mcpgw.Connection{ + { + ID: "conn-1", + Name: "Remote SSE", + Type: "sse", + Active: true, + Config: map[string]any{"url": "http://example.com/sse"}, + }, + }, + } + + source := NewSource(slog.Default(), gateway, lister) + tools, err := source.ListTools(context.Background(), mcpgw.ToolSessionContext{BotID: "bot-1"}) + if err != nil { + t.Fatalf("list tools failed: %v", err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != "remote_sse.search" { + t.Fatalf("unexpected tool alias: %s", tools[0].Name) + } +} + +func TestSourceCallToolRoutesToSSEConnection(t *testing.T) { + gateway := &testGateway{ + listSSE: []mcpgw.ToolDescriptor{ + { + Name: "search", + Description: "search remote data", + InputSchema: map[string]any{"type": "object"}, + }, + }, + } + lister := &testConnectionLister{ + items: []mcpgw.Connection{ + { + ID: "conn-1", + Name: "Remote SSE", + Type: "sse", + Active: true, + Config: map[string]any{"url": "http://example.com/sse"}, + }, + }, + } + source := NewSource(slog.Default(), gateway, lister) + + result, err := source.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot-1"}, "remote_sse.search", map[string]any{"query": "hello"}) + if err != nil { + t.Fatalf("call tool failed: %v", err) + } + if gateway.lastCallType != "sse" { + t.Fatalf("expected sse route, got: %s", gateway.lastCallType) + } + if ok, _ := result["ok"].(bool); !ok { + t.Fatalf("expected ok=true in result") + } +} diff --git a/internal/mcp/tool_gateway_service.go b/internal/mcp/tool_gateway_service.go new file mode 100644 index 00000000..f7a921d5 --- /dev/null +++ b/internal/mcp/tool_gateway_service.go @@ -0,0 +1,168 @@ +package mcp + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" +) + +const ( + defaultToolRegistryCacheTTL = 5 * time.Second +) + +type cachedToolRegistry struct { + expiresAt time.Time + registry *ToolRegistry +} + +// ToolGatewayService federates tools from executors and sources. +type ToolGatewayService struct { + logger *slog.Logger + executors []ToolExecutor + sources []ToolSource + cacheTTL time.Duration + + mu sync.Mutex + cache map[string]cachedToolRegistry +} + +func NewToolGatewayService(log *slog.Logger, executors []ToolExecutor, sources []ToolSource) *ToolGatewayService { + if log == nil { + log = slog.Default() + } + filteredExecutors := make([]ToolExecutor, 0, len(executors)) + for _, executor := range executors { + if executor != nil { + filteredExecutors = append(filteredExecutors, executor) + } + } + filteredSources := make([]ToolSource, 0, len(sources)) + for _, source := range sources { + if source != nil { + filteredSources = append(filteredSources, source) + } + } + return &ToolGatewayService{ + logger: log.With(slog.String("service", "tool_gateway")), + executors: filteredExecutors, + sources: filteredSources, + cacheTTL: defaultToolRegistryCacheTTL, + cache: map[string]cachedToolRegistry{}, + } +} + +func (s *ToolGatewayService) InitializeResult() map[string]any { + return map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{ + "tools": map[string]any{ + "listChanged": false, + }, + }, + "serverInfo": map[string]any{ + "name": "memoh-tools-gateway", + "version": "1.0.0", + }, + } +} + +func (s *ToolGatewayService) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { + registry, err := s.getRegistry(ctx, session, false) + if err != nil { + return nil, err + } + return registry.List(), nil +} + +func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionContext, payload ToolCallPayload) (map[string]any, error) { + toolName := strings.TrimSpace(payload.Name) + if toolName == "" { + return nil, fmt.Errorf("tool name is required") + } + + registry, err := s.getRegistry(ctx, session, false) + if err != nil { + return nil, err + } + executor, _, ok := registry.Lookup(toolName) + if !ok { + // Refresh once for dynamic executors/sources. + registry, err = s.getRegistry(ctx, session, true) + if err != nil { + return nil, err + } + executor, _, ok = registry.Lookup(toolName) + if !ok { + return BuildToolErrorResult("tool not found: " + toolName), nil + } + } + + arguments := payload.Arguments + if arguments == nil { + arguments = map[string]any{} + } + result, err := executor.CallTool(ctx, session, toolName, arguments) + if err != nil { + if err == ErrToolNotFound { + return BuildToolErrorResult("tool not found: " + toolName), nil + } + return BuildToolErrorResult(err.Error()), nil + } + if result == nil { + return BuildToolSuccessResult(map[string]any{"ok": true}), nil + } + return result, nil +} + +func (s *ToolGatewayService) getRegistry(ctx context.Context, session ToolSessionContext, force bool) (*ToolRegistry, error) { + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return nil, fmt.Errorf("bot id is required") + } + if !force { + s.mu.Lock() + cached, ok := s.cache[botID] + if ok && time.Now().Before(cached.expiresAt) && cached.registry != nil { + s.mu.Unlock() + return cached.registry, nil + } + s.mu.Unlock() + } + + registry := NewToolRegistry() + for _, executor := range s.executors { + tools, err := executor.ListTools(ctx, session) + if err != nil { + s.logger.Warn("list tools from executor failed", slog.Any("error", err)) + continue + } + for _, tool := range tools { + if err := registry.Register(executor, tool); err != nil { + s.logger.Warn("skip duplicated/invalid tool", slog.String("tool", tool.Name), slog.Any("error", err)) + } + } + } + for _, source := range s.sources { + tools, err := source.ListTools(ctx, session) + if err != nil { + s.logger.Warn("list tools from source failed", slog.Any("error", err)) + continue + } + for _, tool := range tools { + if err := registry.Register(source, tool); err != nil { + s.logger.Warn("skip duplicated/invalid tool", slog.String("tool", tool.Name), slog.Any("error", err)) + } + } + } + + s.mu.Lock() + s.cache[botID] = cachedToolRegistry{ + expiresAt: time.Now().Add(s.cacheTTL), + registry: registry, + } + s.mu.Unlock() + return registry, nil +} diff --git a/internal/mcp/tool_gateway_service_test.go b/internal/mcp/tool_gateway_service_test.go new file mode 100644 index 00000000..3509f7ef --- /dev/null +++ b/internal/mcp/tool_gateway_service_test.go @@ -0,0 +1,126 @@ +package mcp + +import ( + "context" + "errors" + "log/slog" + "testing" +) + +type gatewayTestProvider struct { + tools []ToolDescriptor + callResult map[string]map[string]any + callErr map[string]error +} + +func (p *gatewayTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { + return p.tools, nil +} + +func (p *gatewayTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if err, ok := p.callErr[toolName]; ok { + return nil, err + } + if result, ok := p.callResult[toolName]; ok { + return result, nil + } + return nil, ErrToolNotFound +} + +func TestToolGatewayServiceListTools(t *testing.T) { + providerA := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "tool_a", InputSchema: map[string]any{"type": "object"}}, + {Name: "dup_tool", InputSchema: map[string]any{"type": "object"}}, + }, + } + providerB := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "tool_b", InputSchema: map[string]any{"type": "object"}}, + {Name: "dup_tool", InputSchema: map[string]any{"type": "object"}}, + }, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{providerA, providerB}, nil) + + tools, err := service.ListTools(context.Background(), ToolSessionContext{BotID: "bot-1"}) + if err != nil { + t.Fatalf("list tools failed: %v", err) + } + if len(tools) != 3 { + t.Fatalf("expected 3 tools after dedupe, got %d", len(tools)) + } +} + +func TestToolGatewayServiceCallToolSuccess(t *testing.T) { + provider := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "echo_tool", InputSchema: map[string]any{"type": "object"}}, + }, + callResult: map[string]map[string]any{ + "echo_tool": { + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + }, + }, + callErr: map[string]error{}, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{provider}, nil) + + result, err := service.CallTool(context.Background(), ToolSessionContext{BotID: "bot-1"}, ToolCallPayload{ + Name: "echo_tool", + Arguments: map[string]any{"value": "hello"}, + }) + if err != nil { + t.Fatalf("call tool should not fail: %v", err) + } + if _, ok := result["content"]; !ok { + t.Fatalf("expected content in tool result") + } +} + +func TestToolGatewayServiceCallToolNotFound(t *testing.T) { + provider := &gatewayTestProvider{ + tools: []ToolDescriptor{}, + callResult: map[string]map[string]any{}, + callErr: map[string]error{}, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{provider}, nil) + + result, err := service.CallTool(context.Background(), ToolSessionContext{BotID: "bot-1"}, ToolCallPayload{ + Name: "missing_tool", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatalf("call should return mcp error result instead of failing: %v", err) + } + isErr, _ := result["isError"].(bool) + if !isErr { + t.Fatalf("expected isError=true for missing tool") + } +} + +func TestToolGatewayServiceCallToolProviderError(t *testing.T) { + provider := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "broken_tool", InputSchema: map[string]any{"type": "object"}}, + }, + callResult: map[string]map[string]any{}, + callErr: map[string]error{ + "broken_tool": errors.New("boom"), + }, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{provider}, nil) + + result, err := service.CallTool(context.Background(), ToolSessionContext{BotID: "bot-1"}, ToolCallPayload{ + Name: "broken_tool", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatalf("call should not return hard error: %v", err) + } + isErr, _ := result["isError"].(bool) + if !isErr { + t.Fatalf("expected isError=true for provider failure") + } +} diff --git a/internal/mcp/tool_registry.go b/internal/mcp/tool_registry.go new file mode 100644 index 00000000..edd7552c --- /dev/null +++ b/internal/mcp/tool_registry.go @@ -0,0 +1,72 @@ +package mcp + +import ( + "fmt" + "sort" + "strings" +) + +type registryItem struct { + executor ToolExecutor + tool ToolDescriptor +} + +// ToolRegistry stores provider ownership and descriptor metadata. +type ToolRegistry struct { + items map[string]registryItem +} + +func NewToolRegistry() *ToolRegistry { + return &ToolRegistry{ + items: map[string]registryItem{}, + } +} + +func (r *ToolRegistry) Register(executor ToolExecutor, tool ToolDescriptor) error { + if executor == nil { + return fmt.Errorf("tool executor is required") + } + name := strings.TrimSpace(tool.Name) + if name == "" { + return fmt.Errorf("tool name is required") + } + if tool.InputSchema == nil { + tool.InputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + if _, exists := r.items[name]; exists { + return fmt.Errorf("tool already registered: %s", name) + } + tool.Name = name + r.items[name] = registryItem{ + executor: executor, + tool: tool, + } + return nil +} + +func (r *ToolRegistry) Lookup(name string) (ToolExecutor, ToolDescriptor, bool) { + item, ok := r.items[strings.TrimSpace(name)] + if !ok { + return nil, ToolDescriptor{}, false + } + return item.executor, item.tool, true +} + +func (r *ToolRegistry) List() []ToolDescriptor { + if len(r.items) == 0 { + return []ToolDescriptor{} + } + names := make([]string, 0, len(r.items)) + for name := range r.items { + names = append(names, name) + } + sort.Strings(names) + tools := make([]ToolDescriptor, 0, len(names)) + for _, name := range names { + tools = append(tools, r.items[name].tool) + } + return tools +} diff --git a/internal/mcp/tool_registry_test.go b/internal/mcp/tool_registry_test.go new file mode 100644 index 00000000..f5001d9d --- /dev/null +++ b/internal/mcp/tool_registry_test.go @@ -0,0 +1,83 @@ +package mcp + +import ( + "context" + "testing" +) + +type registryTestProvider struct{} + +func (p *registryTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { + return nil, nil +} + +func (p *registryTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + return nil, nil +} + +func TestToolRegistryRegisterAndLookup(t *testing.T) { + registry := NewToolRegistry() + provider := ®istryTestProvider{} + if err := registry.Register(provider, ToolDescriptor{ + Name: "tool_a", + Description: "test", + InputSchema: map[string]any{"type": "object"}, + }); err != nil { + t.Fatalf("register should succeed: %v", err) + } + + gotProvider, descriptor, ok := registry.Lookup("tool_a") + if !ok { + t.Fatalf("lookup should find registered tool") + } + if gotProvider != provider { + t.Fatalf("lookup provider mismatch") + } + if descriptor.Name != "tool_a" { + t.Fatalf("lookup descriptor mismatch, got: %s", descriptor.Name) + } +} + +func TestToolRegistryRegisterDuplicate(t *testing.T) { + registry := NewToolRegistry() + provider := ®istryTestProvider{} + first := ToolDescriptor{ + Name: "dup_tool", + Description: "first", + InputSchema: map[string]any{"type": "object"}, + } + second := ToolDescriptor{ + Name: "dup_tool", + Description: "second", + InputSchema: map[string]any{"type": "object"}, + } + if err := registry.Register(provider, first); err != nil { + t.Fatalf("first register should succeed: %v", err) + } + if err := registry.Register(provider, second); err == nil { + t.Fatalf("duplicate register should fail") + } +} + +func TestToolRegistryListStableOrder(t *testing.T) { + registry := NewToolRegistry() + provider := ®istryTestProvider{} + tools := []ToolDescriptor{ + {Name: "b_tool", InputSchema: map[string]any{"type": "object"}}, + {Name: "a_tool", InputSchema: map[string]any{"type": "object"}}, + {Name: "c_tool", InputSchema: map[string]any{"type": "object"}}, + } + for _, tool := range tools { + if err := registry.Register(provider, tool); err != nil { + t.Fatalf("register %s failed: %v", tool.Name, err) + } + } + + list := registry.List() + if len(list) != 3 { + t.Fatalf("expected 3 tools, got %d", len(list)) + } + if list[0].Name != "a_tool" || list[1].Name != "b_tool" || list[2].Name != "c_tool" { + t.Fatalf("unexpected order: %#v", list) + } +} diff --git a/internal/mcp/tool_types.go b/internal/mcp/tool_types.go new file mode 100644 index 00000000..89e59c78 --- /dev/null +++ b/internal/mcp/tool_types.go @@ -0,0 +1,198 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strings" +) + +// ToolSessionContext carries request-scoped identity for tool execution. +type ToolSessionContext struct { + BotID string + ChatID string + ChannelIdentityID string + SessionToken string + CurrentPlatform string + ReplyTarget string + DisplayName string +} + +// ToolDescriptor is the MCP tools/list item shape used by the gateway. +type ToolDescriptor struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]any `json:"inputSchema"` +} + +// ToolExecutor represents business-facing tools (message/schedule/memory). +type ToolExecutor interface { + ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) + CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) +} + +// ToolSource represents infrastructure-level tool sources (federation/connectors). +// A source is not a business tool itself; it supplies and routes downstream tools. +type ToolSource interface { + ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) + CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) +} + +// ToolCallPayload is the MCP tools/call params payload. +type ToolCallPayload struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +// ErrToolNotFound indicates the provider does not own the requested tool. +var ErrToolNotFound = fmt.Errorf("tool not found") + +// BuildToolSuccessResult builds a standard MCP tool success result object. +func BuildToolSuccessResult(structured any) map[string]any { + result := map[string]any{} + if structured != nil { + result["structuredContent"] = structured + if text := stringifyStructuredContent(structured); text != "" { + result["content"] = []map[string]any{ + { + "type": "text", + "text": text, + }, + } + } + } + if len(result) == 0 { + result["content"] = []map[string]any{ + { + "type": "text", + "text": "ok", + }, + } + } + return result +} + +// BuildToolErrorResult builds a standard MCP tool error result object. +func BuildToolErrorResult(message string) map[string]any { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "tool execution failed" + } + return map[string]any{ + "isError": true, + "content": []map[string]any{ + { + "type": "text", + "text": msg, + }, + }, + } +} + +func stringifyStructuredContent(v any) string { + if v == nil { + return "" + } + switch value := v.(type) { + case string: + return strings.TrimSpace(value) + default: + payload, err := json.Marshal(value) + if err != nil { + return "" + } + return string(payload) + } +} + +func StringArg(arguments map[string]any, key string) string { + if arguments == nil { + return "" + } + raw, ok := arguments[key] + if !ok { + return "" + } + switch value := raw.(type) { + case string: + return strings.TrimSpace(value) + default: + return strings.TrimSpace(fmt.Sprintf("%v", raw)) + } +} + +func FirstStringArg(arguments map[string]any, keys ...string) string { + for _, key := range keys { + if value := StringArg(arguments, key); value != "" { + return value + } + } + return "" +} + +func IntArg(arguments map[string]any, key string) (int, bool, error) { + if arguments == nil { + return 0, false, nil + } + raw, ok := arguments[key] + if !ok || raw == nil { + return 0, false, nil + } + switch value := raw.(type) { + case int: + return value, true, nil + case int8: + return int(value), true, nil + case int16: + return int(value), true, nil + case int32: + return int(value), true, nil + case int64: + return int(value), true, nil + case uint: + return int(value), true, nil + case uint8: + return int(value), true, nil + case uint16: + return int(value), true, nil + case uint32: + return int(value), true, nil + case uint64: + return int(value), true, nil + case float32: + f := float64(value) + if math.IsNaN(f) || math.IsInf(f, 0) { + return 0, true, fmt.Errorf("%s must be a valid number", key) + } + return int(f), true, nil + case float64: + if math.IsNaN(value) || math.IsInf(value, 0) { + return 0, true, fmt.Errorf("%s must be a valid number", key) + } + return int(value), true, nil + case json.Number: + i, err := value.Int64() + if err != nil { + return 0, true, fmt.Errorf("%s must be an integer", key) + } + return int(i), true, nil + default: + return 0, true, fmt.Errorf("%s must be a number", key) + } +} + +func BoolArg(arguments map[string]any, key string) (bool, bool, error) { + if arguments == nil { + return false, false, nil + } + raw, ok := arguments[key] + if !ok || raw == nil { + return false, false, nil + } + value, ok := raw.(bool) + if !ok { + return false, true, fmt.Errorf("%s must be a boolean", key) + } + return value, true, nil +}