diff --git a/agent/src/agent.ts b/agent/src/agent.ts index a536ce51..6cc3cf7d 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -1,5 +1,5 @@ import { generateText, ImagePart, LanguageModelUsage, ModelMessage, stepCountIs, streamText, UserModelMessage } from 'ai' -import { AgentInput, AgentParams, AgentSkill, allActions, HTTPMCPConnection, MCPConnection, Schedule } from './types' +import { AgentInput, AgentParams, AgentSkill, allActions, HTTPMCPConnection, MCPConnection, Schedule, StdioMCPConnection } from './types' import { system, schedule, user, subagentSystem } from './prompts' import { AuthFetcher } from './index' import { createModel } from './model' @@ -56,7 +56,14 @@ export const createAgent = ({ 'Authorization': `Bearer ${auth.bearer}`, }, } - return [fs] + const mcpFetch: StdioMCPConnection = { + type: 'stdio', + name: 'mcp-fetch', + command: 'npx', + args: ['fetch-mcp'], + env: {}, + } + return [fs, mcpFetch] } const generateSystemPrompt = () => { @@ -82,7 +89,11 @@ export const createAgent = ({ const { tools: mcpTools, close: closeMCP } = await getMCPTools([ ...defaultMCPConnections, ...mcpConnections, - ]) + ], { + botId: identity.botId, + auth, + fetch, + }) Object.assign(tools, mcpTools) return { tools, diff --git a/agent/src/tools/mcp.ts b/agent/src/tools/mcp.ts index b7bcbac4..1bb4466f 100644 --- a/agent/src/tools/mcp.ts +++ b/agent/src/tools/mcp.ts @@ -1,7 +1,15 @@ import { HTTPMCPConnection, MCPConnection, SSEMCPConnection, StdioMCPConnection } from '../types' import { createMCPClient } from '@ai-sdk/mcp' +import { AuthFetcher } from '../index' +import type { AgentAuthContext } from '../types/agent' -export const getMCPTools = async (connections: MCPConnection[]) => { +type MCPToolOptions = { + botId?: string + auth?: AgentAuthContext + fetch?: AuthFetcher +} + +export const getMCPTools = async (connections: MCPConnection[], options: MCPToolOptions = {}) => { const closeCallbacks: Array<() => Promise> = [] const getHTTPTools = async (connection: HTTPMCPConnection) => { @@ -13,7 +21,8 @@ export const getMCPTools = async (connections: MCPConnection[]) => { } }) closeCallbacks.push(() => client.close()) - return await client.tools() + const tools = await client.tools() + return tools } const getSSETools = async (connection: SSEMCPConnection) => { @@ -25,15 +34,51 @@ export const getMCPTools = async (connections: MCPConnection[]) => { } }) closeCallbacks.push(() => client.close()) - return await client.tools() + const tools = await client.tools() + return tools } const getStdioTools = async (connection: StdioMCPConnection) => { - // TODO: Implement stdio tools - return {} + if (!options.fetch || !options.botId || !options.auth) { + throw new Error('stdio mcp requires auth fetcher and bot id') + } + const response = await options.fetch(`/bots/${options.botId}/mcp-stdio`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + name: connection.name, + command: connection.command, + args: connection.args ?? [], + env: connection.env ?? {}, + cwd: connection.cwd ?? '' + }) + }) + if (!response.ok) { + const text = await response.text().catch(() => '') + throw new Error(`mcp-stdio failed: ${response.status} ${text}`) + } + const data = await response.json().catch(() => ({} as { url?: string })) + const rawUrl = typeof data?.url === 'string' ? data.url : '' + if (!rawUrl) { + throw new Error('mcp-stdio response missing url') + } + const baseUrl = options.auth.baseUrl ?? '' + const url = rawUrl.startsWith('http') + ? rawUrl + : `${baseUrl.replace(/\/$/, '')}/${rawUrl.replace(/^\//, '')}` + return await getHTTPTools({ + type: 'http', + name: connection.name, + url, + headers: { + 'Authorization': `Bearer ${options.auth.bearer}` + } + }) } - const toolSets = await Promise.all(connections.map(connection => { + const toolSets = await Promise.all(connections.map(async (connection) => { switch (connection.type) { case 'http': return getHTTPTools(connection) @@ -41,6 +86,9 @@ export const getMCPTools = async (connections: MCPConnection[]) => { return getSSETools(connection) case 'stdio': return getStdioTools(connection) + default: + console.warn('unknown mcp connection type', connection) + return {} } })) @@ -50,4 +98,4 @@ export const getMCPTools = async (connections: MCPConnection[]) => { await Promise.all(closeCallbacks.map(callback => callback())) } } -} \ No newline at end of file +} diff --git a/cmd/agent/main.go b/cmd/agent/main.go index ef18c7d0..bee71fee 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -154,8 +154,10 @@ func main() { contactsHandler := handlers.NewContactsHandler(contactsService, botService, usersService) preauthService := preauth.NewService(queries) preauthHandler := handlers.NewPreauthHandler(preauthService, botService, usersService) + mcpConnectionsService := mcp.NewConnectionService(logger.L, queries) + mcpHandler := handlers.NewMCPHandler(logger.L, mcpConnectionsService, botService, usersService) - chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, historyService, settingsService, cfg.AgentGateway.BaseURL(), 120*time.Second) + chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, historyService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) chatResolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries) swaggerHandler := handlers.NewSwaggerHandler(logger.L) @@ -186,7 +188,7 @@ func main() { scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, usersService) subagentService := subagent.NewService(logger.L, queries) subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, usersService) - srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, contactsHandler, preauthHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, cliHandler, webHandler) + srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, contactsHandler, preauthHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler) if err := srv.Start(); err != nil { logger.Error("server failed", slog.Any("error", err)) diff --git a/db/migrations/0001_init.down.sql b/db/migrations/0001_init.down.sql index 687bea5b..28f5128d 100644 --- a/db/migrations/0001_init.down.sql +++ b/db/migrations/0001_init.down.sql @@ -12,6 +12,7 @@ DROP TABLE IF EXISTS bot_channel_configs; DROP TABLE IF EXISTS user_channel_bindings; DROP TABLE IF EXISTS history; DROP TABLE IF EXISTS conversations; +DROP TABLE IF EXISTS mcp_connections; DROP TABLE IF EXISTS bot_model_configs; DROP TABLE IF EXISTS bot_settings; DROP TABLE IF EXISTS bot_members; diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 8e0b7b08..b9facd8b 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -106,6 +106,21 @@ CREATE TABLE IF NOT EXISTS bot_model_configs ( memory_model_id UUID REFERENCES models(id) ON DELETE SET NULL ); +CREATE TABLE IF NOT EXISTS mcp_connections ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, + name TEXT NOT NULL, + type TEXT NOT NULL, + config JSONB NOT NULL DEFAULT '{}'::jsonb, + is_active BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT mcp_connections_type_check CHECK (type IN ('stdio', 'http', 'sse')), + CONSTRAINT mcp_connections_unique UNIQUE (bot_id, name) +); + +CREATE INDEX IF NOT EXISTS idx_mcp_connections_bot_id ON mcp_connections(bot_id); + CREATE TABLE IF NOT EXISTS conversations ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, diff --git a/db/queries/mcp.sql b/db/queries/mcp.sql new file mode 100644 index 00000000..9ccc3ee1 --- /dev/null +++ b/db/queries/mcp.sql @@ -0,0 +1,30 @@ +-- name: GetMCPConnectionByID :one +SELECT id, bot_id, name, type, config, is_active, created_at, updated_at +FROM mcp_connections +WHERE bot_id = $1 AND id = $2 +LIMIT 1; + +-- name: ListMCPConnectionsByBotID :many +SELECT id, bot_id, name, type, config, is_active, created_at, updated_at +FROM mcp_connections +WHERE bot_id = $1 +ORDER BY created_at DESC; + +-- name: CreateMCPConnection :one +INSERT INTO mcp_connections (bot_id, name, type, config, is_active) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, bot_id, name, type, config, is_active, created_at, updated_at; + +-- name: UpdateMCPConnection :one +UPDATE mcp_connections +SET name = $3, + type = $4, + config = $5, + is_active = $6, + updated_at = now() +WHERE bot_id = $1 AND id = $2 +RETURNING id, bot_id, name, type, config, is_active, created_at, updated_at; + +-- name: DeleteMCPConnection :exec +DELETE FROM mcp_connections +WHERE bot_id = $1 AND id = $2; diff --git a/docs/docs.go b/docs/docs.go index f1df0f4e..063aa69f 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -887,6 +887,362 @@ const docTemplate = `{ } } }, + "/bots/{bot_id}/mcp": { + "get": { + "description": "List MCP connections for a bot", + "tags": [ + "mcp" + ], + "summary": "List MCP connections", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/mcp.ListResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "post": { + "description": "Create a MCP connection for a bot", + "tags": [ + "mcp" + ], + "summary": "Create MCP connection", + "parameters": [ + { + "description": "MCP payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/mcp.UpsertRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/mcp-stdio": { + "post": { + "description": "Start a stdio MCP process in the bot container and expose it as MCP HTTP endpoint.", + "tags": [ + "containerd" + ], + "summary": "Create MCP stdio proxy", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "Stdio MCP payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.MCPStdioRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.MCPStdioResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/mcp-stdio/{session_id}": { + "post": { + "description": "Proxies MCP JSON-RPC requests to a stdio MCP process in the container.", + "tags": [ + "containerd" + ], + "summary": "MCP stdio proxy (JSON-RPC)", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Session ID", + "name": "session_id", + "in": "path", + "required": true + }, + { + "description": "JSON-RPC request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "JSON-RPC response: {jsonrpc,id,result|error}", + "schema": { + "type": "object" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/mcp/{id}": { + "get": { + "description": "Get a MCP connection by ID", + "tags": [ + "mcp" + ], + "summary": "Get MCP connection", + "parameters": [ + { + "type": "string", + "description": "MCP ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update a MCP connection by ID", + "tags": [ + "mcp" + ], + "summary": "Update MCP connection", + "parameters": [ + { + "type": "string", + "description": "MCP ID", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "MCP payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/mcp.UpsertRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a MCP connection by ID", + "tags": [ + "mcp" + ], + "summary": "Delete MCP connection", + "parameters": [ + { + "type": "string", + "description": "MCP ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/bots/{bot_id}/memory/add": { "post": { "description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", @@ -4678,6 +5034,36 @@ const docTemplate = `{ } } }, + "github_com_memohai_memoh_internal_mcp.Connection": { + "type": "object", + "properties": { + "active": { + "type": "boolean" + }, + "bot_id": { + "type": "string" + }, + "config": { + "type": "object", + "additionalProperties": {} + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + }, + "updated_at": { + "type": "string" + } + } + }, "handlers.ChannelMeta": { "type": "object", "properties": { @@ -4933,6 +5319,49 @@ const docTemplate = `{ } } }, + "handlers.MCPStdioRequest": { + "type": "object", + "properties": { + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "type": "string" + }, + "cwd": { + "type": "string" + }, + "env": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "name": { + "type": "string" + } + } + }, + "handlers.MCPStdioResponse": { + "type": "object", + "properties": { + "session_id": { + "type": "string" + }, + "tools": { + "type": "array", + "items": { + "type": "string" + } + }, + "url": { + "type": "string" + } + } + }, "handlers.SkillItem": { "type": "object", "properties": { @@ -5185,6 +5614,35 @@ const docTemplate = `{ } } }, + "mcp.ListResponse": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + } + } + }, + "mcp.UpsertRequest": { + "type": "object", + "properties": { + "active": { + "type": "boolean" + }, + "config": { + "type": "object", + "additionalProperties": {} + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, "memory.DeleteResponse": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index 81af9224..7b795f7c 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -878,6 +878,362 @@ } } }, + "/bots/{bot_id}/mcp": { + "get": { + "description": "List MCP connections for a bot", + "tags": [ + "mcp" + ], + "summary": "List MCP connections", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/mcp.ListResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "post": { + "description": "Create a MCP connection for a bot", + "tags": [ + "mcp" + ], + "summary": "Create MCP connection", + "parameters": [ + { + "description": "MCP payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/mcp.UpsertRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/mcp-stdio": { + "post": { + "description": "Start a stdio MCP process in the bot container and expose it as MCP HTTP endpoint.", + "tags": [ + "containerd" + ], + "summary": "Create MCP stdio proxy", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "Stdio MCP payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.MCPStdioRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.MCPStdioResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/mcp-stdio/{session_id}": { + "post": { + "description": "Proxies MCP JSON-RPC requests to a stdio MCP process in the container.", + "tags": [ + "containerd" + ], + "summary": "MCP stdio proxy (JSON-RPC)", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Session ID", + "name": "session_id", + "in": "path", + "required": true + }, + { + "description": "JSON-RPC request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "JSON-RPC response: {jsonrpc,id,result|error}", + "schema": { + "type": "object" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/mcp/{id}": { + "get": { + "description": "Get a MCP connection by ID", + "tags": [ + "mcp" + ], + "summary": "Get MCP connection", + "parameters": [ + { + "type": "string", + "description": "MCP ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update a MCP connection by ID", + "tags": [ + "mcp" + ], + "summary": "Update MCP connection", + "parameters": [ + { + "type": "string", + "description": "MCP ID", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "MCP payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/mcp.UpsertRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a MCP connection by ID", + "tags": [ + "mcp" + ], + "summary": "Delete MCP connection", + "parameters": [ + { + "type": "string", + "description": "MCP ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/bots/{bot_id}/memory/add": { "post": { "description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", @@ -4669,6 +5025,36 @@ } } }, + "github_com_memohai_memoh_internal_mcp.Connection": { + "type": "object", + "properties": { + "active": { + "type": "boolean" + }, + "bot_id": { + "type": "string" + }, + "config": { + "type": "object", + "additionalProperties": {} + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + }, + "updated_at": { + "type": "string" + } + } + }, "handlers.ChannelMeta": { "type": "object", "properties": { @@ -4924,6 +5310,49 @@ } } }, + "handlers.MCPStdioRequest": { + "type": "object", + "properties": { + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "command": { + "type": "string" + }, + "cwd": { + "type": "string" + }, + "env": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "name": { + "type": "string" + } + } + }, + "handlers.MCPStdioResponse": { + "type": "object", + "properties": { + "session_id": { + "type": "string" + }, + "tools": { + "type": "array", + "items": { + "type": "string" + } + }, + "url": { + "type": "string" + } + } + }, "handlers.SkillItem": { "type": "object", "properties": { @@ -5176,6 +5605,35 @@ } } }, + "mcp.ListResponse": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "$ref": "#/definitions/github_com_memohai_memoh_internal_mcp.Connection" + } + } + } + }, + "mcp.UpsertRequest": { + "type": "object", + "properties": { + "active": { + "type": "boolean" + }, + "config": { + "type": "object", + "additionalProperties": {} + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, "memory.DeleteResponse": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index ff7551bb..5bf23b95 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -483,6 +483,26 @@ definitions: name: type: string type: object + github_com_memohai_memoh_internal_mcp.Connection: + properties: + active: + type: boolean + bot_id: + type: string + config: + additionalProperties: {} + type: object + created_at: + type: string + id: + type: string + name: + type: string + type: + type: string + updated_at: + type: string + type: object handlers.ChannelMeta: properties: capabilities: @@ -648,6 +668,34 @@ definitions: username: type: string type: object + handlers.MCPStdioRequest: + properties: + args: + items: + type: string + type: array + command: + type: string + cwd: + type: string + env: + additionalProperties: + type: string + type: object + name: + type: string + type: object + handlers.MCPStdioResponse: + properties: + session_id: + type: string + tools: + items: + type: string + type: array + url: + type: string + type: object handlers.SkillItem: properties: content: @@ -815,6 +863,25 @@ definitions: timestamp: type: string type: object + mcp.ListResponse: + properties: + items: + items: + $ref: '#/definitions/github_com_memohai_memoh_internal_mcp.Connection' + type: array + type: object + mcp.UpsertRequest: + properties: + active: + type: boolean + config: + additionalProperties: {} + type: object + name: + type: string + type: + type: string + type: object memory.DeleteResponse: properties: message: @@ -1913,6 +1980,243 @@ paths: summary: Get history record tags: - history + /bots/{bot_id}/mcp: + get: + description: List MCP connections for a bot + responses: + "200": + description: OK + schema: + $ref: '#/definitions/mcp.ListResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "403": + description: Forbidden + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: List MCP connections + tags: + - mcp + post: + description: Create a MCP connection for a bot + parameters: + - description: MCP payload + in: body + name: payload + required: true + schema: + $ref: '#/definitions/mcp.UpsertRequest' + responses: + "201": + description: Created + schema: + $ref: '#/definitions/github_com_memohai_memoh_internal_mcp.Connection' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "403": + description: Forbidden + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Create MCP connection + tags: + - mcp + /bots/{bot_id}/mcp-stdio: + post: + description: Start a stdio MCP process in the bot container and expose it as + MCP HTTP endpoint. + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + - description: Stdio MCP payload + in: body + name: payload + required: true + schema: + $ref: '#/definitions/handlers.MCPStdioRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.MCPStdioResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Create MCP stdio proxy + tags: + - containerd + /bots/{bot_id}/mcp-stdio/{session_id}: + post: + description: Proxies MCP JSON-RPC requests to a stdio MCP process in the container. + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + - description: Session ID + in: path + name: session_id + required: true + type: string + - description: JSON-RPC request + in: body + name: payload + required: true + schema: + type: object + responses: + "200": + description: 'JSON-RPC response: {jsonrpc,id,result|error}' + schema: + type: object + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: MCP stdio proxy (JSON-RPC) + tags: + - containerd + /bots/{bot_id}/mcp/{id}: + delete: + description: Delete a MCP connection by ID + parameters: + - description: MCP ID + in: path + name: id + required: true + type: string + responses: + "204": + description: No Content + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "403": + description: Forbidden + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Delete MCP connection + tags: + - mcp + get: + description: Get a MCP connection by ID + parameters: + - description: MCP ID + in: path + name: id + required: true + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/github_com_memohai_memoh_internal_mcp.Connection' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "403": + description: Forbidden + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get MCP connection + tags: + - mcp + put: + description: Update a MCP connection by ID + parameters: + - description: MCP ID + in: path + name: id + required: true + type: string + - description: MCP payload + in: body + name: payload + required: true + schema: + $ref: '#/definitions/mcp.UpsertRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/github_com_memohai_memoh_internal_mcp.Connection' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "403": + description: Forbidden + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Update MCP connection + tags: + - mcp /bots/{bot_id}/memory/add: post: description: 'Add memory for a user via memory. Auth: Bearer JWT determines diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 44cbb0c5..70a1b412 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -16,6 +16,7 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/history" + "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/schedule" @@ -44,6 +45,7 @@ type Resolver struct { memoryService *memory.Service historyService *history.Service settingsService *settings.Service + mcpService *mcp.ConnectionService skillLoader SkillLoader gatewayBaseURL string timeout time.Duration @@ -60,6 +62,7 @@ func NewResolver( memoryService *memory.Service, historyService *history.Service, settingsService *settings.Service, + mcpService *mcp.ConnectionService, gatewayBaseURL string, timeout time.Duration, ) *Resolver { @@ -76,6 +79,7 @@ func NewResolver( memoryService: memoryService, historyService: historyService, settingsService: settingsService, + mcpService: mcpService, gatewayBaseURL: gatewayBaseURL, timeout: timeout, logger: log.With(slog.String("service", "chat")), @@ -125,6 +129,7 @@ type gatewayRequest struct { Channels []string `json:"channels"` CurrentChannel string `json:"currentChannel"` AllowedActions []string `json:"allowedActions,omitempty"` + MCPConnections []map[string]any `json:"mcpConnections"` Messages []ModelMessage `json:"messages"` Skills []string `json:"skills"` UsableSkills []gatewaySkill `json:"usableSkills"` @@ -155,6 +160,7 @@ type triggerScheduleRequest struct { Channels []string `json:"channels"` CurrentChannel string `json:"currentChannel"` AllowedActions []string `json:"allowedActions,omitempty"` + MCPConnections []map[string]any `json:"mcpConnections"` Messages []ModelMessage `json:"messages"` Skills []string `json:"skills"` UsableSkills []gatewaySkill `json:"usableSkills"` @@ -166,8 +172,8 @@ type triggerScheduleRequest struct { // --- resolved context (shared by Chat / StreamChat / TriggerSchedule) --- type resolvedContext struct { - payload gatewayRequest - model models.GetResponse + payload gatewayRequest + model models.GetResponse provider sqlc.LlmProvider } @@ -240,6 +246,24 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex usableSkills = []gatewaySkill{} } + mcpConnections := []map[string]any{} + if r.mcpService != nil { + items, err := r.mcpService.ListActiveByBot(ctx, req.BotID) + if err != nil { + r.logger.Warn("failed to load mcp connections", slog.String("bot_id", req.BotID), slog.Any("error", err)) + } else { + for _, item := range items { + payload := map[string]any{} + for k, v := range item.Config { + payload[k] = v + } + payload["name"] = item.Name + payload["type"] = item.Type + mcpConnections = append(mcpConnections, payload) + } + } + } + payload := gatewayRequest{ Model: gatewayModelConfig{ ModelID: chatModel.ModelID, @@ -252,6 +276,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex Channels: nonNilStrings(req.Channels), CurrentChannel: req.CurrentChannel, AllowedActions: req.AllowedActions, + MCPConnections: mcpConnections, Messages: nonNilMessages(messages), Skills: nonNilStrings(skills), UsableSkills: usableSkills, @@ -327,6 +352,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc Channels: rc.payload.Channels, CurrentChannel: rc.payload.CurrentChannel, AllowedActions: rc.payload.AllowedActions, + MCPConnections: rc.payload.MCPConnections, Messages: rc.payload.Messages, Skills: rc.payload.Skills, UsableSkills: rc.payload.UsableSkills, diff --git a/internal/db/sqlc/mcp.sql.go b/internal/db/sqlc/mcp.sql.go new file mode 100644 index 00000000..3b738171 --- /dev/null +++ b/internal/db/sqlc/mcp.sql.go @@ -0,0 +1,170 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: mcp.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createMCPConnection = `-- name: CreateMCPConnection :one +INSERT INTO mcp_connections (bot_id, name, type, config, is_active) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, bot_id, name, type, config, is_active, created_at, updated_at +` + +type CreateMCPConnectionParams struct { + BotID pgtype.UUID `json:"bot_id"` + Name string `json:"name"` + Type string `json:"type"` + Config []byte `json:"config"` + IsActive bool `json:"is_active"` +} + +func (q *Queries) CreateMCPConnection(ctx context.Context, arg CreateMCPConnectionParams) (McpConnection, error) { + row := q.db.QueryRow(ctx, createMCPConnection, + arg.BotID, + arg.Name, + arg.Type, + arg.Config, + arg.IsActive, + ) + var i McpConnection + err := row.Scan( + &i.ID, + &i.BotID, + &i.Name, + &i.Type, + &i.Config, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteMCPConnection = `-- name: DeleteMCPConnection :exec +DELETE FROM mcp_connections +WHERE bot_id = $1 AND id = $2 +` + +type DeleteMCPConnectionParams struct { + BotID pgtype.UUID `json:"bot_id"` + ID pgtype.UUID `json:"id"` +} + +func (q *Queries) DeleteMCPConnection(ctx context.Context, arg DeleteMCPConnectionParams) error { + _, err := q.db.Exec(ctx, deleteMCPConnection, arg.BotID, arg.ID) + return err +} + +const getMCPConnectionByID = `-- name: GetMCPConnectionByID :one +SELECT id, bot_id, name, type, config, is_active, created_at, updated_at +FROM mcp_connections +WHERE bot_id = $1 AND id = $2 +LIMIT 1 +` + +type GetMCPConnectionByIDParams struct { + BotID pgtype.UUID `json:"bot_id"` + ID pgtype.UUID `json:"id"` +} + +func (q *Queries) GetMCPConnectionByID(ctx context.Context, arg GetMCPConnectionByIDParams) (McpConnection, error) { + row := q.db.QueryRow(ctx, getMCPConnectionByID, arg.BotID, arg.ID) + var i McpConnection + err := row.Scan( + &i.ID, + &i.BotID, + &i.Name, + &i.Type, + &i.Config, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listMCPConnectionsByBotID = `-- name: ListMCPConnectionsByBotID :many +SELECT id, bot_id, name, type, config, is_active, created_at, updated_at +FROM mcp_connections +WHERE bot_id = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListMCPConnectionsByBotID(ctx context.Context, botID pgtype.UUID) ([]McpConnection, error) { + rows, err := q.db.Query(ctx, listMCPConnectionsByBotID, botID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []McpConnection + for rows.Next() { + var i McpConnection + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Name, + &i.Type, + &i.Config, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateMCPConnection = `-- name: UpdateMCPConnection :one +UPDATE mcp_connections +SET name = $3, + type = $4, + config = $5, + is_active = $6, + updated_at = now() +WHERE bot_id = $1 AND id = $2 +RETURNING id, bot_id, name, type, config, is_active, created_at, updated_at +` + +type UpdateMCPConnectionParams struct { + BotID pgtype.UUID `json:"bot_id"` + ID pgtype.UUID `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Config []byte `json:"config"` + IsActive bool `json:"is_active"` +} + +func (q *Queries) UpdateMCPConnection(ctx context.Context, arg UpdateMCPConnectionParams) (McpConnection, error) { + row := q.db.QueryRow(ctx, updateMCPConnection, + arg.BotID, + arg.ID, + arg.Name, + arg.Type, + arg.Config, + arg.IsActive, + ) + var i McpConnection + err := row.Scan( + &i.ID, + &i.BotID, + &i.Name, + &i.Type, + &i.Config, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 44b30a0a..0341bf03 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -169,6 +169,17 @@ type LlmProvider struct { UpdatedAt pgtype.Timestamptz `json:"updated_at"` } +type McpConnection struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Name string `json:"name"` + Type string `json:"type"` + Config []byte `json:"config"` + IsActive bool `json:"is_active"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + type Model struct { ID pgtype.UUID `json:"id"` ModelID string `json:"model_id"` diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index 1e194ae7..417b9789 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -32,15 +32,17 @@ import ( ) type ContainerdHandler struct { - service ctr.Service - cfg config.MCPConfig - namespace string - logger *slog.Logger - mcpMu sync.Mutex - mcpSess map[string]*mcpSession - botService *bots.Service - userService *users.Service - queries *dbsqlc.Queries + service ctr.Service + cfg config.MCPConfig + namespace string + logger *slog.Logger + mcpMu sync.Mutex + mcpSess map[string]*mcpSession + mcpStdioMu sync.Mutex + mcpStdioSess map[string]*mcpStdioSession + botService *bots.Service + userService *users.Service + queries *dbsqlc.Queries } type CreateContainerRequest struct { @@ -94,14 +96,15 @@ type ListSnapshotsResponse struct { func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler { return &ContainerdHandler{ - service: service, - cfg: cfg, - namespace: namespace, - logger: log.With(slog.String("handler", "containerd")), - mcpSess: make(map[string]*mcpSession), - botService: botService, - userService: userService, - queries: queries, + service: service, + cfg: cfg, + namespace: namespace, + logger: log.With(slog.String("handler", "containerd")), + mcpSess: make(map[string]*mcpSession), + mcpStdioSess: make(map[string]*mcpStdioSession), + botService: botService, + userService: userService, + queries: queries, } } @@ -118,6 +121,10 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { group.POST("/skills", h.UpsertSkills) group.DELETE("/skills", h.DeleteSkills) group.POST("/fs", h.HandleMCPFS) + + root := e.Group("/bots/:bot_id") + root.POST("/mcp-stdio", h.CreateMCPStdio) + root.POST("/mcp-stdio/:session_id", h.HandleMCPStdio) } // CreateContainer godoc diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go new file mode 100644 index 00000000..bfc5d11d --- /dev/null +++ b/internal/handlers/mcp.go @@ -0,0 +1,250 @@ +package handlers + +import ( + "context" + "errors" + "log/slog" + "net/http" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/bots" + "github.com/memohai/memoh/internal/identity" + "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/users" +) + +type MCPHandler struct { + service *mcp.ConnectionService + botService *bots.Service + userService *users.Service + logger *slog.Logger +} + +func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService *bots.Service, userService *users.Service) *MCPHandler { + return &MCPHandler{ + service: service, + botService: botService, + userService: userService, + logger: log.With(slog.String("handler", "mcp")), + } +} + +func (h *MCPHandler) Register(e *echo.Echo) { + group := e.Group("/bots/:bot_id/mcp") + group.GET("", h.List) + group.POST("", h.Create) + group.GET("/:id", h.Get) + group.PUT("/:id", h.Update) + group.DELETE("/:id", h.Delete) +} + +// List godoc +// @Summary List MCP connections +// @Description List MCP connections for a bot +// @Tags mcp +// @Success 200 {object} mcp.ListResponse +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp [get] +func (h *MCPHandler) List(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + items, err := h.service.ListByBot(c.Request().Context(), botID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, mcp.ListResponse{Items: items}) +} + +// Create godoc +// @Summary Create MCP connection +// @Description Create a MCP connection for a bot +// @Tags mcp +// @Param payload body mcp.UpsertRequest true "MCP payload" +// @Success 201 {object} mcp.Connection +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp [post] +func (h *MCPHandler) Create(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + var req mcp.UpsertRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + resp, err := h.service.Create(c.Request().Context(), botID, req) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return c.JSON(http.StatusCreated, resp) +} + +// Get godoc +// @Summary Get MCP connection +// @Description Get a MCP connection by ID +// @Tags mcp +// @Param id path string true "MCP ID" +// @Success 200 {object} mcp.Connection +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp/{id} [get] +func (h *MCPHandler) Get(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + resp, err := h.service.Get(c.Request().Context(), botID, id) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return echo.NewHTTPError(http.StatusNotFound, "mcp connection not found") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// Update godoc +// @Summary Update MCP connection +// @Description Update a MCP connection by ID +// @Tags mcp +// @Param id path string true "MCP ID" +// @Param payload body mcp.UpsertRequest true "MCP payload" +// @Success 200 {object} mcp.Connection +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp/{id} [put] +func (h *MCPHandler) Update(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + var req mcp.UpsertRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + resp, err := h.service.Update(c.Request().Context(), botID, id, req) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return echo.NewHTTPError(http.StatusNotFound, "mcp connection not found") + } + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// Delete godoc +// @Summary Delete MCP connection +// @Description Delete a MCP connection by ID +// @Tags mcp +// @Param id path string true "MCP ID" +// @Success 204 "No Content" +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp/{id} [delete] +func (h *MCPHandler) Delete(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + if err := h.service.Delete(c.Request().Context(), botID, id); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +func (h *MCPHandler) requireUserID(c echo.Context) (string, error) { + userID, err := auth.UserIDFromContext(c) + if err != nil { + return "", err + } + if err := identity.ValidateUserID(userID); err != nil { + return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return userID, nil +} + +func (h *MCPHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { + if h.botService == nil || h.userService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.userService.IsAdmin(ctx, actorID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go new file mode 100644 index 00000000..a262adaf --- /dev/null +++ b/internal/handlers/mcp_stdio.go @@ -0,0 +1,388 @@ +package handlers + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "os/exec" + "runtime" + "sort" + "strings" + "time" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + + ctr "github.com/memohai/memoh/internal/containerd" + mcptools "github.com/memohai/memoh/internal/mcp" +) + +type MCPStdioRequest struct { + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + Cwd string `json:"cwd"` +} + +type MCPStdioResponse struct { + SessionID string `json:"session_id"` + URL string `json:"url"` + Tools []string `json:"tools,omitempty"` +} + +type mcpStdioSession struct { + id string + botID string + containerID string + name string + createdAt time.Time + lastUsedAt time.Time + session *mcpSession +} + +// CreateMCPStdio godoc +// @Summary Create MCP stdio proxy +// @Description Start a stdio MCP process in the bot container and expose it as MCP HTTP endpoint. +// @Tags containerd +// @Param bot_id path string true "Bot ID" +// @Param payload body MCPStdioRequest true "Stdio MCP payload" +// @Success 200 {object} MCPStdioResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp-stdio [post] +func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } + var req MCPStdioRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if strings.TrimSpace(req.Command) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "command is required") + } + ctx := c.Request().Context() + containerID, err := h.botContainerID(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") + } + if err := h.validateMCPContainer(ctx, containerID, botID); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if err := h.ensureTaskRunning(ctx, containerID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + sess, err := h.startContainerdMCPCommandSession(ctx, containerID, req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + tools := h.probeMCPTools(ctx, sess, botID, strings.TrimSpace(req.Name)) + sessionID := uuid.NewString() + record := &mcpStdioSession{ + id: sessionID, + botID: botID, + containerID: containerID, + name: strings.TrimSpace(req.Name), + createdAt: time.Now().UTC(), + lastUsedAt: time.Now().UTC(), + session: sess, + } + sess.onClose = func() { + h.mcpStdioMu.Lock() + if current, ok := h.mcpStdioSess[sessionID]; ok && current == record { + delete(h.mcpStdioSess, sessionID) + } + h.mcpStdioMu.Unlock() + } + h.mcpStdioMu.Lock() + h.mcpStdioSess[sessionID] = record + h.mcpStdioMu.Unlock() + + return c.JSON(http.StatusOK, MCPStdioResponse{ + SessionID: sessionID, + URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, sessionID), + Tools: tools, + }) +} + +// HandleMCPStdio godoc +// @Summary MCP stdio proxy (JSON-RPC) +// @Description Proxies MCP JSON-RPC requests to a stdio MCP process in the container. +// @Tags containerd +// @Param bot_id path string true "Bot ID" +// @Param session_id path string true "Session ID" +// @Param payload body object true "JSON-RPC request" +// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}" +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp-stdio/{session_id} [post] +func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error { + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } + sessionID := strings.TrimSpace(c.Param("session_id")) + if sessionID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") + } + h.mcpStdioMu.Lock() + session := h.mcpStdioSess[sessionID] + h.mcpStdioMu.Unlock() + if session == nil || session.session == nil || session.botID != botID { + return echo.NewHTTPError(http.StatusNotFound, "mcp session not found") + } + select { + case <-session.session.closed: + return echo.NewHTTPError(http.StatusNotFound, "mcp session closed") + default: + } + + var req mcptools.JSONRPCRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if req.JSONRPC != "" && req.JSONRPC != "2.0" { + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32600, "invalid jsonrpc version")) + } + if strings.TrimSpace(req.Method) == "" { + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32601, "method not found")) + } + session.lastUsedAt = time.Now().UTC() + if mcptools.IsNotification(req) { + if err := session.session.notify(c.Request().Context(), req); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusAccepted) + } + payload, err := session.session.call(c.Request().Context(), req) + if err != nil { + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) + } + return c.JSON(http.StatusOK, payload) +} + +func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context, containerID string, req MCPStdioRequest) (*mcpSession, error) { + if runtime.GOOS == "darwin" { + return h.startLimaMCPCommandSession(containerID, req) + } + args := append([]string{strings.TrimSpace(req.Command)}, req.Args...) + env := buildEnvPairs(req.Env) + execSession, err := h.service.ExecTaskStreaming(ctx, containerID, ctr.ExecTaskRequest{ + Args: args, + Env: env, + WorkDir: strings.TrimSpace(req.Cwd), + }) + if err != nil { + return nil, err + } + + sess := &mcpSession{ + stdin: execSession.Stdin, + stdout: execSession.Stdout, + stderr: execSession.Stderr, + pending: make(map[string]chan mcptools.JSONRPCResponse), + closed: make(chan struct{}), + } + h.startMCPStderrLogger(execSession.Stderr, containerID) + go sess.readLoop() + go func() { + _, err := execSession.Wait() + if err != nil { + h.logger.Error("mcp stdio session exited", slog.Any("error", err), slog.String("container_id", containerID)) + sess.closeWithError(err) + } else { + sess.closeWithError(io.EOF) + } + }() + return sess, nil +} + +func buildEnvPairs(env map[string]string) []string { + if len(env) == 0 { + return nil + } + keys := make([]string, 0, len(env)) + for k := range env { + if strings.TrimSpace(k) != "" { + keys = append(keys, k) + } + } + sort.Strings(keys) + out := make([]string, 0, len(keys)) + for _, k := range keys { + out = append(out, fmt.Sprintf("%s=%s", k, env[k])) + } + return out +} + +func (h *ContainerdHandler) probeMCPTools(ctx context.Context, sess *mcpSession, botID, name string) []string { + if sess == nil { + return nil + } + probeCtx, cancel := context.WithTimeout(ctx, 8*time.Second) + defer cancel() + payload, err := sess.call(probeCtx, mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("probe-tools"), + Method: "tools/list", + }) + if err != nil { + h.logger.Warn("mcp stdio tools probe failed", + slog.String("bot_id", botID), + slog.String("name", name), + slog.Any("error", err), + ) + return nil + } + tools := extractToolNames(payload) + if len(tools) == 0 { + h.logger.Warn("mcp stdio tools empty", + slog.String("bot_id", botID), + slog.String("name", name), + ) + } else { + h.logger.Info("mcp stdio tools loaded", + slog.String("bot_id", botID), + slog.String("name", name), + slog.Int("count", len(tools)), + ) + } + return tools +} + +func extractToolNames(payload map[string]any) []string { + result, ok := payload["result"].(map[string]any) + if !ok { + return nil + } + rawTools, ok := result["tools"].([]any) + if !ok { + return nil + } + names := make([]string, 0, len(rawTools)) + for _, raw := range rawTools { + item, ok := raw.(map[string]any) + if !ok { + continue + } + name, _ := item["name"].(string) + name = strings.TrimSpace(name) + if name == "" { + continue + } + names = append(names, name) + } + sort.Strings(names) + return names +} + +func (h *ContainerdHandler) startLimaMCPCommandSession(containerID string, req MCPStdioRequest) (*mcpSession, error) { + execID := fmt.Sprintf("mcp-stdio-%d", time.Now().UnixNano()) + cmdline := buildShellCommand(req) + cmd := exec.Command( + "limactl", + "shell", + "--tty=false", + "default", + "--", + "sudo", + "-n", + "ctr", + "-n", + "default", + "tasks", + "exec", + "--exec-id", + execID, + containerID, + "/bin/sh", + "-lc", + cmdline, + ) + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + stdout, err := cmd.StdoutPipe() + if err != nil { + _ = stdin.Close() + return nil, err + } + stderr, err := cmd.StderrPipe() + if err != nil { + _ = stdin.Close() + _ = stdout.Close() + return nil, err + } + if err := cmd.Start(); err != nil { + _ = stdin.Close() + _ = stdout.Close() + _ = stderr.Close() + return nil, err + } + + sess := &mcpSession{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + cmd: cmd, + pending: make(map[string]chan mcptools.JSONRPCResponse), + closed: make(chan struct{}), + } + + h.startMCPStderrLogger(stderr, containerID) + go sess.readLoop() + go func() { + if err := cmd.Wait(); err != nil { + h.logger.Error("mcp stdio session exited", slog.Any("error", err), slog.String("container_id", containerID)) + sess.closeWithError(err) + } else { + sess.closeWithError(io.EOF) + } + }() + + return sess, nil +} + +func buildShellCommand(req MCPStdioRequest) string { + cmd := strings.TrimSpace(req.Command) + if cmd == "" { + return "" + } + parts := make([]string, 0, len(req.Args)+1) + parts = append(parts, escapeShellArg(cmd)) + for _, arg := range req.Args { + parts = append(parts, escapeShellArg(arg)) + } + command := strings.Join(parts, " ") + + assignments := []string{} + for _, pair := range buildEnvPairs(req.Env) { + assignments = append(assignments, escapeShellArg(pair)) + } + if len(assignments) > 0 { + command = strings.Join(assignments, " ") + " " + command + } + if strings.TrimSpace(req.Cwd) != "" { + command = "cd " + escapeShellArg(req.Cwd) + " && " + command + } + return command +} + +func escapeShellArg(value string) string { + if value == "" { + return "''" + } + if !strings.ContainsAny(value, " \t\n'\"\\$&;|<>*?()[]{}!`") { + return value + } + return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'" +} diff --git a/internal/mcp/connections.go b/internal/mcp/connections.go new file mode 100644 index 00000000..f08ba91d --- /dev/null +++ b/internal/mcp/connections.go @@ -0,0 +1,272 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +// Connection represents a stored MCP connection for a bot. +type Connection struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Name string `json:"name"` + Type string `json:"type"` + Config map[string]any `json:"config"` + Active bool `json:"active"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// UpsertRequest is the payload for creating or updating MCP connections. +type UpsertRequest struct { + Name string `json:"name"` + Type string `json:"type,omitempty"` + Config map[string]any `json:"config"` + Active *bool `json:"active,omitempty"` +} + +// ListResponse wraps MCP connection list responses. +type ListResponse struct { + Items []Connection `json:"items"` +} + +// ConnectionService handles CRUD operations for MCP connections. +type ConnectionService struct { + queries *sqlc.Queries + logger *slog.Logger +} + +// NewConnectionService creates a ConnectionService backed by sqlc queries. +func NewConnectionService(log *slog.Logger, queries *sqlc.Queries) *ConnectionService { + if log == nil { + log = slog.Default() + } + return &ConnectionService{ + queries: queries, + logger: log.With(slog.String("service", "mcp_connections")), + } +} + +// ListByBot returns all MCP connections for a bot. +func (s *ConnectionService) ListByBot(ctx context.Context, botID string) ([]Connection, error) { + if s.queries == nil { + return nil, fmt.Errorf("mcp queries not configured") + } + pgBotID, err := db.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMCPConnectionsByBotID(ctx, pgBotID) + if err != nil { + return nil, err + } + items := make([]Connection, 0, len(rows)) + for _, row := range rows { + item, err := normalizeMCPConnection(row) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, nil +} + +// ListActiveByBot returns active MCP connections for a bot. +func (s *ConnectionService) ListActiveByBot(ctx context.Context, botID string) ([]Connection, error) { + items, err := s.ListByBot(ctx, botID) + if err != nil { + return nil, err + } + active := make([]Connection, 0, len(items)) + for _, item := range items { + if item.Active { + active = append(active, item) + } + } + return active, nil +} + +// Get returns a specific MCP connection for a bot. +func (s *ConnectionService) Get(ctx context.Context, botID, id string) (Connection, error) { + if s.queries == nil { + return Connection{}, fmt.Errorf("mcp queries not configured") + } + pgBotID, err := db.ParseUUID(botID) + if err != nil { + return Connection{}, err + } + pgID, err := db.ParseUUID(id) + if err != nil { + return Connection{}, err + } + row, err := s.queries.GetMCPConnectionByID(ctx, sqlc.GetMCPConnectionByIDParams{ + BotID: pgBotID, + ID: pgID, + }) + if err != nil { + return Connection{}, err + } + return normalizeMCPConnection(row) +} + +// Create inserts a new MCP connection for a bot. +func (s *ConnectionService) Create(ctx context.Context, botID string, req UpsertRequest) (Connection, error) { + if s.queries == nil { + return Connection{}, fmt.Errorf("mcp queries not configured") + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return Connection{}, err + } + name := strings.TrimSpace(req.Name) + if name == "" { + return Connection{}, fmt.Errorf("name is required") + } + mcpType, config, err := normalizeMCPType(req) + if err != nil { + return Connection{}, err + } + configPayload, err := json.Marshal(config) + if err != nil { + return Connection{}, err + } + active := true + if req.Active != nil { + active = *req.Active + } + row, err := s.queries.CreateMCPConnection(ctx, sqlc.CreateMCPConnectionParams{ + BotID: botUUID, + Name: name, + Type: mcpType, + Config: configPayload, + IsActive: active, + }) + if err != nil { + return Connection{}, err + } + return normalizeMCPConnection(row) +} + +// Update modifies an existing MCP connection. +func (s *ConnectionService) Update(ctx context.Context, botID, id string, req UpsertRequest) (Connection, error) { + if s.queries == nil { + return Connection{}, fmt.Errorf("mcp queries not configured") + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return Connection{}, err + } + connUUID, err := db.ParseUUID(id) + if err != nil { + return Connection{}, err + } + name := strings.TrimSpace(req.Name) + if name == "" { + return Connection{}, fmt.Errorf("name is required") + } + mcpType, config, err := normalizeMCPType(req) + if err != nil { + return Connection{}, err + } + active := true + if req.Active != nil { + active = *req.Active + } + configPayload, err := json.Marshal(config) + if err != nil { + return Connection{}, err + } + row, err := s.queries.UpdateMCPConnection(ctx, sqlc.UpdateMCPConnectionParams{ + BotID: botUUID, + ID: connUUID, + Name: name, + Type: mcpType, + Config: configPayload, + IsActive: active, + }) + if err != nil { + return Connection{}, err + } + return normalizeMCPConnection(row) +} + +// Delete removes an MCP connection. +func (s *ConnectionService) Delete(ctx context.Context, botID, id string) error { + if s.queries == nil { + return fmt.Errorf("mcp queries not configured") + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return err + } + connUUID, err := db.ParseUUID(id) + if err != nil { + return err + } + return s.queries.DeleteMCPConnection(ctx, sqlc.DeleteMCPConnectionParams{ + BotID: botUUID, + ID: connUUID, + }) +} + +func normalizeMCPConnection(row sqlc.McpConnection) (Connection, error) { + config, err := decodeMCPConfig(row.Config) + if err != nil { + return Connection{}, err + } + return Connection{ + ID: db.UUIDToString(row.ID), + BotID: db.UUIDToString(row.BotID), + Name: strings.TrimSpace(row.Name), + Type: strings.TrimSpace(row.Type), + Config: config, + Active: row.IsActive, + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), + }, nil +} + +func decodeMCPConfig(raw []byte) (map[string]any, error) { + if len(raw) == 0 { + return map[string]any{}, nil + } + var payload map[string]any + if err := json.Unmarshal(raw, &payload); err != nil { + return nil, err + } + if payload == nil { + payload = map[string]any{} + } + return payload, nil +} + +func normalizeMCPType(req UpsertRequest) (string, map[string]any, error) { + config := req.Config + if config == nil { + config = map[string]any{} + } + mcpType := strings.TrimSpace(req.Type) + if mcpType == "" { + if raw, ok := config["type"].(string); ok { + mcpType = strings.TrimSpace(raw) + } + } + mcpType = strings.ToLower(strings.TrimSpace(mcpType)) + if mcpType == "" { + return "", nil, fmt.Errorf("type is required") + } + switch mcpType { + case "stdio", "http", "sse": + default: + return "", nil, fmt.Errorf("unsupported mcp type: %s", mcpType) + } + config["type"] = mcpType + return mcpType, config, nil +} diff --git a/internal/server/server.go b/internal/server/server.go index 639cc97d..b451fa90 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,7 @@ type Server struct { logger *slog.Logger } -func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, contactsHandler *handlers.ContactsHandler, preauthHandler *handlers.PreauthHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { +func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, contactsHandler *handlers.ContactsHandler, preauthHandler *handlers.PreauthHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { if addr == "" { addr = ":8080" } @@ -102,6 +102,9 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han if usersHandler != nil { usersHandler.Register(e) } + if mcpHandler != nil { + mcpHandler.Register(e) + } if cliHandler != nil { cliHandler.Register(e) }