From 9dd71358208c2a131f693a59600c3aeba92aea58 Mon Sep 17 00:00:00 2001 From: BBQ Date: Fri, 13 Feb 2026 00:25:42 +0800 Subject: [PATCH] refactor(mcp): standard mcpServers input format with type inference - Accept standard mcpServers item format (command/args/env/url/headers) - Auto-infer connection type: command -> stdio, url -> http/sse - Add PUT /bots/:bot_id/mcp/import for batch import from mcpServers dict - Add GET /bots/:bot_id/mcp/export for standard format export - Add UpsertMCPConnectionByName SQL for import upsert by name - Preserve is_active state on import upsert --- db/queries/mcp.sql | 9 ++ internal/db/sqlc/mcp.sql.go | 38 ++++++ internal/handlers/mcp.go | 63 +++++++++ internal/mcp/connections.go | 219 +++++++++++++++++++++++++++---- internal/mcp/connections_test.go | 173 ++++++++++++++++++++++++ 5 files changed, 476 insertions(+), 26 deletions(-) create mode 100644 internal/mcp/connections_test.go diff --git a/db/queries/mcp.sql b/db/queries/mcp.sql index 9ccc3ee1..de6621d3 100644 --- a/db/queries/mcp.sql +++ b/db/queries/mcp.sql @@ -28,3 +28,12 @@ RETURNING id, bot_id, name, type, config, is_active, created_at, updated_at; -- name: DeleteMCPConnection :exec DELETE FROM mcp_connections WHERE bot_id = $1 AND id = $2; + +-- name: UpsertMCPConnectionByName :one +INSERT INTO mcp_connections (bot_id, name, type, config) +VALUES ($1, $2, $3, $4) +ON CONFLICT (bot_id, name) +DO UPDATE SET type = EXCLUDED.type, + config = EXCLUDED.config, + updated_at = now() +RETURNING id, bot_id, name, type, config, is_active, created_at, updated_at; diff --git a/internal/db/sqlc/mcp.sql.go b/internal/db/sqlc/mcp.sql.go index 3b738171..5a96c150 100644 --- a/internal/db/sqlc/mcp.sql.go +++ b/internal/db/sqlc/mcp.sql.go @@ -168,3 +168,41 @@ func (q *Queries) UpdateMCPConnection(ctx context.Context, arg UpdateMCPConnecti ) return i, err } + +const upsertMCPConnectionByName = `-- name: UpsertMCPConnectionByName :one +INSERT INTO mcp_connections (bot_id, name, type, config) +VALUES ($1, $2, $3, $4) +ON CONFLICT (bot_id, name) +DO UPDATE SET type = EXCLUDED.type, + config = EXCLUDED.config, + updated_at = now() +RETURNING id, bot_id, name, type, config, is_active, created_at, updated_at +` + +type UpsertMCPConnectionByNameParams struct { + BotID pgtype.UUID `json:"bot_id"` + Name string `json:"name"` + Type string `json:"type"` + Config []byte `json:"config"` +} + +func (q *Queries) UpsertMCPConnectionByName(ctx context.Context, arg UpsertMCPConnectionByNameParams) (McpConnection, error) { + row := q.db.QueryRow(ctx, upsertMCPConnectionByName, + arg.BotID, + arg.Name, + arg.Type, + arg.Config, + ) + var i McpConnection + err := row.Scan( + &i.ID, + &i.BotID, + &i.Name, + &i.Type, + &i.Config, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go index 0d1f0b68..f69d863b 100644 --- a/internal/handlers/mcp.go +++ b/internal/handlers/mcp.go @@ -35,6 +35,8 @@ func (h *MCPHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/mcp") group.GET("", h.List) group.POST("", h.Create) + group.PUT("/import", h.Import) + group.GET("/export", h.Export) group.GET("/:id", h.Get) group.PUT("/:id", h.Update) group.DELETE("/:id", h.Delete) @@ -215,6 +217,67 @@ func (h *MCPHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } +// Import godoc +// @Summary Import MCP connections +// @Description Batch import MCP connections from standard mcpServers format. Existing connections (matched by name) get config updated with is_active preserved. New connections are created as active. +// @Tags mcp +// @Param payload body mcp.ImportRequest true "mcpServers dict" +// @Success 200 {object} mcp.ListResponse +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp/import [put] +func (h *MCPHandler) Import(c echo.Context) error { + userID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + var req mcp.ImportRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + items, err := h.service.Import(c.Request().Context(), botID, req) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return c.JSON(http.StatusOK, mcp.ListResponse{Items: items}) +} + +// Export godoc +// @Summary Export MCP connections +// @Description Export all MCP connections for a bot in standard mcpServers format. +// @Tags mcp +// @Success 200 {object} mcp.ExportResponse +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/mcp/export [get] +func (h *MCPHandler) Export(c echo.Context) error { + userID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + resp, err := h.service.ExportByBot(c.Request().Context(), botID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/mcp/connections.go b/internal/mcp/connections.go index fff2af65..01c9f82b 100644 --- a/internal/mcp/connections.go +++ b/internal/mcp/connections.go @@ -24,12 +24,34 @@ type Connection struct { UpdatedAt time.Time `json:"updated_at"` } -// UpsertRequest is the payload for creating or updating MCP connections. +// UpsertRequest accepts standard mcpServers item format. +// Type is auto-inferred: command present -> stdio, url present -> http (default) or sse (if transport:"sse"). type UpsertRequest struct { - Name string `json:"name"` - Type string `json:"type,omitempty"` - Config map[string]any `json:"config"` - Active *bool `json:"is_active,omitempty"` + Name string `json:"name"` + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` + Cwd string `json:"cwd,omitempty"` + URL string `json:"url,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + Transport string `json:"transport,omitempty"` + Active *bool `json:"is_active,omitempty"` +} + +// ImportRequest accepts a standard mcpServers dict for batch import. +type ImportRequest struct { + MCPServers map[string]MCPServerEntry `json:"mcpServers"` +} + +// MCPServerEntry is one entry in the standard mcpServers dict. +type MCPServerEntry struct { + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` + Cwd string `json:"cwd,omitempty"` + URL string `json:"url,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + Transport string `json:"transport,omitempty"` } // ListResponse wraps MCP connection list responses. @@ -37,6 +59,11 @@ type ListResponse struct { Items []Connection `json:"items"` } +// ExportResponse returns connections in standard mcpServers format. +type ExportResponse struct { + MCPServers map[string]MCPServerEntry `json:"mcpServers"` +} + // ConnectionService handles CRUD operations for MCP connections. type ConnectionService struct { queries *sqlc.Queries @@ -129,7 +156,7 @@ func (s *ConnectionService) Create(ctx context.Context, botID string, req Upsert if name == "" { return Connection{}, fmt.Errorf("name is required") } - mcpType, config, err := normalizeMCPType(req) + mcpType, config, err := inferTypeAndConfig(req) if err != nil { return Connection{}, err } @@ -171,7 +198,7 @@ func (s *ConnectionService) Update(ctx context.Context, botID, id string, req Up if name == "" { return Connection{}, fmt.Errorf("name is required") } - mcpType, config, err := normalizeMCPType(req) + mcpType, config, err := inferTypeAndConfig(req) if err != nil { return Connection{}, err } @@ -197,6 +224,67 @@ func (s *ConnectionService) Update(ctx context.Context, botID, id string, req Up return normalizeMCPConnection(row) } +// Import performs a declarative sync from a standard mcpServers dict. +// Existing connections (matched by name) get config updated but is_active preserved. +// New connections are created with is_active=true. +// Connections not in the input are left untouched. +func (s *ConnectionService) Import(ctx context.Context, botID string, req ImportRequest) ([]Connection, error) { + if s.queries == nil { + return nil, fmt.Errorf("mcp queries not configured") + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return nil, err + } + if len(req.MCPServers) == 0 { + return []Connection{}, nil + } + results := make([]Connection, 0, len(req.MCPServers)) + for name, entry := range req.MCPServers { + name = strings.TrimSpace(name) + if name == "" { + continue + } + upsert := entryToUpsertRequest(name, entry) + mcpType, config, err := inferTypeAndConfig(upsert) + if err != nil { + return nil, fmt.Errorf("server %q: %w", name, err) + } + configPayload, err := json.Marshal(config) + if err != nil { + return nil, err + } + row, err := s.queries.UpsertMCPConnectionByName(ctx, sqlc.UpsertMCPConnectionByNameParams{ + BotID: botUUID, + Name: name, + Type: mcpType, + Config: configPayload, + }) + if err != nil { + return nil, fmt.Errorf("server %q: %w", name, err) + } + conn, err := normalizeMCPConnection(row) + if err != nil { + return nil, err + } + results = append(results, conn) + } + return results, nil +} + +// ExportByBot returns all connections for a bot in standard mcpServers format. +func (s *ConnectionService) ExportByBot(ctx context.Context, botID string) (ExportResponse, error) { + items, err := s.ListByBot(ctx, botID) + if err != nil { + return ExportResponse{}, err + } + servers := make(map[string]MCPServerEntry, len(items)) + for _, conn := range items { + servers[conn.Name] = connectionToExportEntry(conn) + } + return ExportResponse{MCPServers: servers}, nil +} + // Delete removes an MCP connection. func (s *ConnectionService) Delete(ctx context.Context, botID, id string) error { if s.queries == nil { @@ -247,26 +335,105 @@ func decodeMCPConfig(raw []byte) (map[string]any, error) { return payload, nil } -func normalizeMCPType(req UpsertRequest) (string, map[string]any, error) { - config := req.Config - if config == nil { - config = map[string]any{} +// inferTypeAndConfig builds internal type + config from a standard mcpServers item. +func inferTypeAndConfig(req UpsertRequest) (string, map[string]any, error) { + hasCommand := strings.TrimSpace(req.Command) != "" + hasURL := strings.TrimSpace(req.URL) != "" + + if !hasCommand && !hasURL { + return "", nil, fmt.Errorf("command or url is required") } - mcpType := strings.TrimSpace(req.Type) - if mcpType == "" { - if raw, ok := config["type"].(string); ok { - mcpType = strings.TrimSpace(raw) + if hasCommand && hasURL { + return "", nil, fmt.Errorf("command and url are mutually exclusive") + } + + config := map[string]any{} + + if hasCommand { + config["command"] = strings.TrimSpace(req.Command) + if len(req.Args) > 0 { + config["args"] = req.Args + } + if len(req.Env) > 0 { + config["env"] = req.Env + } + if strings.TrimSpace(req.Cwd) != "" { + config["cwd"] = strings.TrimSpace(req.Cwd) + } + return "stdio", config, nil + } + + config["url"] = strings.TrimSpace(req.URL) + if len(req.Headers) > 0 { + config["headers"] = req.Headers + } + transport := strings.ToLower(strings.TrimSpace(req.Transport)) + if transport == "sse" { + return "sse", config, nil + } + return "http", config, nil +} + +// entryToUpsertRequest converts a named MCPServerEntry to an UpsertRequest. +func entryToUpsertRequest(name string, entry MCPServerEntry) UpsertRequest { + return UpsertRequest{ + Name: name, + Command: entry.Command, + Args: entry.Args, + Env: entry.Env, + Cwd: entry.Cwd, + URL: entry.URL, + Headers: entry.Headers, + Transport: entry.Transport, + } +} + +// connectionToExportEntry converts a stored connection to standard mcpServers entry. +func connectionToExportEntry(conn Connection) MCPServerEntry { + entry := MCPServerEntry{} + switch conn.Type { + case "stdio": + entry.Command, _ = conn.Config["command"].(string) + if rawArgs, ok := conn.Config["args"]; ok { + switch v := rawArgs.(type) { + case []any: + for _, a := range v { + if s, ok := a.(string); ok { + entry.Args = append(entry.Args, s) + } + } + case []string: + entry.Args = v + } + } + if rawEnv, ok := conn.Config["env"]; ok { + if m, ok := rawEnv.(map[string]any); ok { + entry.Env = make(map[string]string, len(m)) + for k, v := range m { + if s, ok := v.(string); ok { + entry.Env[k] = s + } + } + } + } + if cwd, ok := conn.Config["cwd"].(string); ok && cwd != "" { + entry.Cwd = cwd + } + case "http", "sse": + entry.URL, _ = conn.Config["url"].(string) + if rawHeaders, ok := conn.Config["headers"]; ok { + if m, ok := rawHeaders.(map[string]any); ok { + entry.Headers = make(map[string]string, len(m)) + for k, v := range m { + if s, ok := v.(string); ok { + entry.Headers[k] = s + } + } + } + } + if conn.Type == "sse" { + entry.Transport = "sse" } } - mcpType = strings.ToLower(strings.TrimSpace(mcpType)) - if mcpType == "" { - return "", nil, fmt.Errorf("type is required") - } - switch mcpType { - case "stdio", "http", "sse": - default: - return "", nil, fmt.Errorf("unsupported mcp type: %s", mcpType) - } - config["type"] = mcpType - return mcpType, config, nil + return entry } diff --git a/internal/mcp/connections_test.go b/internal/mcp/connections_test.go new file mode 100644 index 00000000..01590ee5 --- /dev/null +++ b/internal/mcp/connections_test.go @@ -0,0 +1,173 @@ +package mcp + +import ( + "testing" +) + +func TestInferTypeAndConfig_Stdio(t *testing.T) { + req := UpsertRequest{ + Name: "filesystem", + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-filesystem", "/path"}, + Env: map[string]string{"TOKEN": "abc"}, + Cwd: "/workspace", + } + typ, config, err := inferTypeAndConfig(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if typ != "stdio" { + t.Fatalf("expected type stdio, got %s", typ) + } + if config["command"] != "npx" { + t.Fatalf("expected command npx, got %v", config["command"]) + } + args, ok := config["args"].([]string) + if !ok || len(args) != 3 { + t.Fatalf("expected 3 args, got %v", config["args"]) + } + env, ok := config["env"].(map[string]string) + if !ok || env["TOKEN"] != "abc" { + t.Fatalf("expected env TOKEN=abc, got %v", config["env"]) + } + if config["cwd"] != "/workspace" { + t.Fatalf("expected cwd /workspace, got %v", config["cwd"]) + } +} + +func TestInferTypeAndConfig_HTTP(t *testing.T) { + req := UpsertRequest{ + Name: "remote", + URL: "https://example.com/mcp", + Headers: map[string]string{"Authorization": "Bearer sk-xxx"}, + } + typ, config, err := inferTypeAndConfig(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if typ != "http" { + t.Fatalf("expected type http, got %s", typ) + } + if config["url"] != "https://example.com/mcp" { + t.Fatalf("expected url, got %v", config["url"]) + } + headers, ok := config["headers"].(map[string]string) + if !ok || headers["Authorization"] != "Bearer sk-xxx" { + t.Fatalf("expected headers, got %v", config["headers"]) + } +} + +func TestInferTypeAndConfig_SSE(t *testing.T) { + req := UpsertRequest{ + Name: "sse-server", + URL: "https://example.com/sse", + Transport: "sse", + } + typ, _, err := inferTypeAndConfig(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if typ != "sse" { + t.Fatalf("expected type sse, got %s", typ) + } +} + +func TestInferTypeAndConfig_NoCommandNoURL(t *testing.T) { + req := UpsertRequest{Name: "empty"} + _, _, err := inferTypeAndConfig(req) + if err == nil { + t.Fatal("expected error for missing command and url") + } +} + +func TestInferTypeAndConfig_BothCommandAndURL(t *testing.T) { + req := UpsertRequest{ + Name: "conflict", + Command: "npx", + URL: "https://example.com", + } + _, _, err := inferTypeAndConfig(req) + if err == nil { + t.Fatal("expected error for both command and url") + } +} + +func TestConnectionToExportEntry_Stdio(t *testing.T) { + conn := Connection{ + Name: "fs", + Type: "stdio", + Config: map[string]any{ + "command": "npx", + "args": []any{"-y", "server"}, + "env": map[string]any{"KEY": "val"}, + "cwd": "/work", + }, + } + entry := connectionToExportEntry(conn) + if entry.Command != "npx" { + t.Fatalf("expected command npx, got %s", entry.Command) + } + if len(entry.Args) != 2 { + t.Fatalf("expected 2 args, got %v", entry.Args) + } + if entry.Env["KEY"] != "val" { + t.Fatalf("expected env KEY=val, got %v", entry.Env) + } + if entry.Cwd != "/work" { + t.Fatalf("expected cwd /work, got %s", entry.Cwd) + } + if entry.URL != "" { + t.Fatalf("expected empty url, got %s", entry.URL) + } +} + +func TestConnectionToExportEntry_HTTP(t *testing.T) { + conn := Connection{ + Name: "remote", + Type: "http", + Config: map[string]any{ + "url": "https://example.com/mcp", + "headers": map[string]any{"Authorization": "Bearer xxx"}, + }, + } + entry := connectionToExportEntry(conn) + if entry.URL != "https://example.com/mcp" { + t.Fatalf("expected url, got %s", entry.URL) + } + if entry.Headers["Authorization"] != "Bearer xxx" { + t.Fatalf("expected headers, got %v", entry.Headers) + } + if entry.Transport != "" { + t.Fatalf("expected empty transport for http, got %s", entry.Transport) + } +} + +func TestConnectionToExportEntry_SSE(t *testing.T) { + conn := Connection{ + Name: "sse", + Type: "sse", + Config: map[string]any{"url": "https://example.com/sse"}, + } + entry := connectionToExportEntry(conn) + if entry.Transport != "sse" { + t.Fatalf("expected transport sse, got %s", entry.Transport) + } +} + +func TestEntryToUpsertRequest(t *testing.T) { + entry := MCPServerEntry{ + Command: "npx", + Args: []string{"-y", "server"}, + Env: map[string]string{"KEY": "val"}, + } + req := entryToUpsertRequest("test-server", entry) + if req.Name != "test-server" { + t.Fatalf("expected name test-server, got %s", req.Name) + } + if req.Command != "npx" { + t.Fatalf("expected command npx, got %s", req.Command) + } + if len(req.Args) != 2 { + t.Fatalf("expected 2 args, got %v", req.Args) + } +}