feat: mcp (#31)

* feat: add mcp connections table and related crud api

* feat: mcp-stdio api
This commit is contained in:
Acbox Liu
2026-02-09 20:07:40 +08:00
committed by GitHub
parent 92838ef8da
commit 8ea779779e
17 changed files with 2486 additions and 32 deletions
+24 -17
View File
@@ -32,15 +32,17 @@ import (
)
type ContainerdHandler struct {
service ctr.Service
cfg config.MCPConfig
namespace string
logger *slog.Logger
mcpMu sync.Mutex
mcpSess map[string]*mcpSession
botService *bots.Service
userService *users.Service
queries *dbsqlc.Queries
service ctr.Service
cfg config.MCPConfig
namespace string
logger *slog.Logger
mcpMu sync.Mutex
mcpSess map[string]*mcpSession
mcpStdioMu sync.Mutex
mcpStdioSess map[string]*mcpStdioSession
botService *bots.Service
userService *users.Service
queries *dbsqlc.Queries
}
type CreateContainerRequest struct {
@@ -94,14 +96,15 @@ type ListSnapshotsResponse struct {
func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler {
return &ContainerdHandler{
service: service,
cfg: cfg,
namespace: namespace,
logger: log.With(slog.String("handler", "containerd")),
mcpSess: make(map[string]*mcpSession),
botService: botService,
userService: userService,
queries: queries,
service: service,
cfg: cfg,
namespace: namespace,
logger: log.With(slog.String("handler", "containerd")),
mcpSess: make(map[string]*mcpSession),
mcpStdioSess: make(map[string]*mcpStdioSession),
botService: botService,
userService: userService,
queries: queries,
}
}
@@ -118,6 +121,10 @@ func (h *ContainerdHandler) Register(e *echo.Echo) {
group.POST("/skills", h.UpsertSkills)
group.DELETE("/skills", h.DeleteSkills)
group.POST("/fs", h.HandleMCPFS)
root := e.Group("/bots/:bot_id")
root.POST("/mcp-stdio", h.CreateMCPStdio)
root.POST("/mcp-stdio/:session_id", h.HandleMCPStdio)
}
// CreateContainer godoc
+250
View File
@@ -0,0 +1,250 @@
package handlers
import (
"context"
"errors"
"log/slog"
"net/http"
"strings"
"github.com/jackc/pgx/v5"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/mcp"
"github.com/memohai/memoh/internal/users"
)
type MCPHandler struct {
service *mcp.ConnectionService
botService *bots.Service
userService *users.Service
logger *slog.Logger
}
func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService *bots.Service, userService *users.Service) *MCPHandler {
return &MCPHandler{
service: service,
botService: botService,
userService: userService,
logger: log.With(slog.String("handler", "mcp")),
}
}
func (h *MCPHandler) Register(e *echo.Echo) {
group := e.Group("/bots/:bot_id/mcp")
group.GET("", h.List)
group.POST("", h.Create)
group.GET("/:id", h.Get)
group.PUT("/:id", h.Update)
group.DELETE("/:id", h.Delete)
}
// List godoc
// @Summary List MCP connections
// @Description List MCP connections for a bot
// @Tags mcp
// @Success 200 {object} mcp.ListResponse
// @Failure 400 {object} ErrorResponse
// @Failure 403 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp [get]
func (h *MCPHandler) List(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
botID := strings.TrimSpace(c.Param("bot_id"))
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
items, err := h.service.ListByBot(c.Request().Context(), botID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, mcp.ListResponse{Items: items})
}
// Create godoc
// @Summary Create MCP connection
// @Description Create a MCP connection for a bot
// @Tags mcp
// @Param payload body mcp.UpsertRequest true "MCP payload"
// @Success 201 {object} mcp.Connection
// @Failure 400 {object} ErrorResponse
// @Failure 403 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp [post]
func (h *MCPHandler) Create(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
botID := strings.TrimSpace(c.Param("bot_id"))
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
var req mcp.UpsertRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
resp, err := h.service.Create(c.Request().Context(), botID, req)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusCreated, resp)
}
// Get godoc
// @Summary Get MCP connection
// @Description Get a MCP connection by ID
// @Tags mcp
// @Param id path string true "MCP ID"
// @Success 200 {object} mcp.Connection
// @Failure 400 {object} ErrorResponse
// @Failure 403 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp/{id} [get]
func (h *MCPHandler) Get(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
botID := strings.TrimSpace(c.Param("bot_id"))
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
resp, err := h.service.Get(c.Request().Context(), botID, id)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return echo.NewHTTPError(http.StatusNotFound, "mcp connection not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, resp)
}
// Update godoc
// @Summary Update MCP connection
// @Description Update a MCP connection by ID
// @Tags mcp
// @Param id path string true "MCP ID"
// @Param payload body mcp.UpsertRequest true "MCP payload"
// @Success 200 {object} mcp.Connection
// @Failure 400 {object} ErrorResponse
// @Failure 403 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp/{id} [put]
func (h *MCPHandler) Update(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
botID := strings.TrimSpace(c.Param("bot_id"))
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
var req mcp.UpsertRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
resp, err := h.service.Update(c.Request().Context(), botID, id, req)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return echo.NewHTTPError(http.StatusNotFound, "mcp connection not found")
}
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusOK, resp)
}
// Delete godoc
// @Summary Delete MCP connection
// @Description Delete a MCP connection by ID
// @Tags mcp
// @Param id path string true "MCP ID"
// @Success 204 "No Content"
// @Failure 400 {object} ErrorResponse
// @Failure 403 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp/{id} [delete]
func (h *MCPHandler) Delete(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
botID := strings.TrimSpace(c.Param("bot_id"))
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
if err := h.service.Delete(c.Request().Context(), botID, id); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusNoContent)
}
func (h *MCPHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateUserID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return userID, nil
}
func (h *MCPHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) {
if h.botService == nil || h.userService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.userService.IsAdmin(ctx, actorID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
}
+388
View File
@@ -0,0 +1,388 @@
package handlers
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"os/exec"
"runtime"
"sort"
"strings"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
ctr "github.com/memohai/memoh/internal/containerd"
mcptools "github.com/memohai/memoh/internal/mcp"
)
type MCPStdioRequest struct {
Name string `json:"name"`
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
Cwd string `json:"cwd"`
}
type MCPStdioResponse struct {
SessionID string `json:"session_id"`
URL string `json:"url"`
Tools []string `json:"tools,omitempty"`
}
type mcpStdioSession struct {
id string
botID string
containerID string
name string
createdAt time.Time
lastUsedAt time.Time
session *mcpSession
}
// CreateMCPStdio godoc
// @Summary Create MCP stdio proxy
// @Description Start a stdio MCP process in the bot container and expose it as MCP HTTP endpoint.
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param payload body MCPStdioRequest true "Stdio MCP payload"
// @Success 200 {object} MCPStdioResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp-stdio [post]
func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
return err
}
var req MCPStdioRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if strings.TrimSpace(req.Command) == "" {
return echo.NewHTTPError(http.StatusBadRequest, "command is required")
}
ctx := c.Request().Context()
containerID, err := h.botContainerID(ctx, botID)
if err != nil {
return echo.NewHTTPError(http.StatusNotFound, "container not found for bot")
}
if err := h.validateMCPContainer(ctx, containerID, botID); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if err := h.ensureTaskRunning(ctx, containerID); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
sess, err := h.startContainerdMCPCommandSession(ctx, containerID, req)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
tools := h.probeMCPTools(ctx, sess, botID, strings.TrimSpace(req.Name))
sessionID := uuid.NewString()
record := &mcpStdioSession{
id: sessionID,
botID: botID,
containerID: containerID,
name: strings.TrimSpace(req.Name),
createdAt: time.Now().UTC(),
lastUsedAt: time.Now().UTC(),
session: sess,
}
sess.onClose = func() {
h.mcpStdioMu.Lock()
if current, ok := h.mcpStdioSess[sessionID]; ok && current == record {
delete(h.mcpStdioSess, sessionID)
}
h.mcpStdioMu.Unlock()
}
h.mcpStdioMu.Lock()
h.mcpStdioSess[sessionID] = record
h.mcpStdioMu.Unlock()
return c.JSON(http.StatusOK, MCPStdioResponse{
SessionID: sessionID,
URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, sessionID),
Tools: tools,
})
}
// HandleMCPStdio godoc
// @Summary MCP stdio proxy (JSON-RPC)
// @Description Proxies MCP JSON-RPC requests to a stdio MCP process in the container.
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param session_id path string true "Session ID"
// @Param payload body object true "JSON-RPC request"
// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}"
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp-stdio/{session_id} [post]
func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
return err
}
sessionID := strings.TrimSpace(c.Param("session_id"))
if sessionID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "session_id is required")
}
h.mcpStdioMu.Lock()
session := h.mcpStdioSess[sessionID]
h.mcpStdioMu.Unlock()
if session == nil || session.session == nil || session.botID != botID {
return echo.NewHTTPError(http.StatusNotFound, "mcp session not found")
}
select {
case <-session.session.closed:
return echo.NewHTTPError(http.StatusNotFound, "mcp session closed")
default:
}
var req mcptools.JSONRPCRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if req.JSONRPC != "" && req.JSONRPC != "2.0" {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32600, "invalid jsonrpc version"))
}
if strings.TrimSpace(req.Method) == "" {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32601, "method not found"))
}
session.lastUsedAt = time.Now().UTC()
if mcptools.IsNotification(req) {
if err := session.session.notify(c.Request().Context(), req); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusAccepted)
}
payload, err := session.session.call(c.Request().Context(), req)
if err != nil {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error()))
}
return c.JSON(http.StatusOK, payload)
}
func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context, containerID string, req MCPStdioRequest) (*mcpSession, error) {
if runtime.GOOS == "darwin" {
return h.startLimaMCPCommandSession(containerID, req)
}
args := append([]string{strings.TrimSpace(req.Command)}, req.Args...)
env := buildEnvPairs(req.Env)
execSession, err := h.service.ExecTaskStreaming(ctx, containerID, ctr.ExecTaskRequest{
Args: args,
Env: env,
WorkDir: strings.TrimSpace(req.Cwd),
})
if err != nil {
return nil, err
}
sess := &mcpSession{
stdin: execSession.Stdin,
stdout: execSession.Stdout,
stderr: execSession.Stderr,
pending: make(map[string]chan mcptools.JSONRPCResponse),
closed: make(chan struct{}),
}
h.startMCPStderrLogger(execSession.Stderr, containerID)
go sess.readLoop()
go func() {
_, err := execSession.Wait()
if err != nil {
h.logger.Error("mcp stdio session exited", slog.Any("error", err), slog.String("container_id", containerID))
sess.closeWithError(err)
} else {
sess.closeWithError(io.EOF)
}
}()
return sess, nil
}
func buildEnvPairs(env map[string]string) []string {
if len(env) == 0 {
return nil
}
keys := make([]string, 0, len(env))
for k := range env {
if strings.TrimSpace(k) != "" {
keys = append(keys, k)
}
}
sort.Strings(keys)
out := make([]string, 0, len(keys))
for _, k := range keys {
out = append(out, fmt.Sprintf("%s=%s", k, env[k]))
}
return out
}
func (h *ContainerdHandler) probeMCPTools(ctx context.Context, sess *mcpSession, botID, name string) []string {
if sess == nil {
return nil
}
probeCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
defer cancel()
payload, err := sess.call(probeCtx, mcptools.JSONRPCRequest{
JSONRPC: "2.0",
ID: mcptools.RawStringID("probe-tools"),
Method: "tools/list",
})
if err != nil {
h.logger.Warn("mcp stdio tools probe failed",
slog.String("bot_id", botID),
slog.String("name", name),
slog.Any("error", err),
)
return nil
}
tools := extractToolNames(payload)
if len(tools) == 0 {
h.logger.Warn("mcp stdio tools empty",
slog.String("bot_id", botID),
slog.String("name", name),
)
} else {
h.logger.Info("mcp stdio tools loaded",
slog.String("bot_id", botID),
slog.String("name", name),
slog.Int("count", len(tools)),
)
}
return tools
}
func extractToolNames(payload map[string]any) []string {
result, ok := payload["result"].(map[string]any)
if !ok {
return nil
}
rawTools, ok := result["tools"].([]any)
if !ok {
return nil
}
names := make([]string, 0, len(rawTools))
for _, raw := range rawTools {
item, ok := raw.(map[string]any)
if !ok {
continue
}
name, _ := item["name"].(string)
name = strings.TrimSpace(name)
if name == "" {
continue
}
names = append(names, name)
}
sort.Strings(names)
return names
}
func (h *ContainerdHandler) startLimaMCPCommandSession(containerID string, req MCPStdioRequest) (*mcpSession, error) {
execID := fmt.Sprintf("mcp-stdio-%d", time.Now().UnixNano())
cmdline := buildShellCommand(req)
cmd := exec.Command(
"limactl",
"shell",
"--tty=false",
"default",
"--",
"sudo",
"-n",
"ctr",
"-n",
"default",
"tasks",
"exec",
"--exec-id",
execID,
containerID,
"/bin/sh",
"-lc",
cmdline,
)
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}
stdout, err := cmd.StdoutPipe()
if err != nil {
_ = stdin.Close()
return nil, err
}
stderr, err := cmd.StderrPipe()
if err != nil {
_ = stdin.Close()
_ = stdout.Close()
return nil, err
}
if err := cmd.Start(); err != nil {
_ = stdin.Close()
_ = stdout.Close()
_ = stderr.Close()
return nil, err
}
sess := &mcpSession{
stdin: stdin,
stdout: stdout,
stderr: stderr,
cmd: cmd,
pending: make(map[string]chan mcptools.JSONRPCResponse),
closed: make(chan struct{}),
}
h.startMCPStderrLogger(stderr, containerID)
go sess.readLoop()
go func() {
if err := cmd.Wait(); err != nil {
h.logger.Error("mcp stdio session exited", slog.Any("error", err), slog.String("container_id", containerID))
sess.closeWithError(err)
} else {
sess.closeWithError(io.EOF)
}
}()
return sess, nil
}
func buildShellCommand(req MCPStdioRequest) string {
cmd := strings.TrimSpace(req.Command)
if cmd == "" {
return ""
}
parts := make([]string, 0, len(req.Args)+1)
parts = append(parts, escapeShellArg(cmd))
for _, arg := range req.Args {
parts = append(parts, escapeShellArg(arg))
}
command := strings.Join(parts, " ")
assignments := []string{}
for _, pair := range buildEnvPairs(req.Env) {
assignments = append(assignments, escapeShellArg(pair))
}
if len(assignments) > 0 {
command = strings.Join(assignments, " ") + " " + command
}
if strings.TrimSpace(req.Cwd) != "" {
command = "cd " + escapeShellArg(req.Cwd) + " && " + command
}
return command
}
func escapeShellArg(value string) string {
if value == "" {
return "''"
}
if !strings.ContainsAny(value, " \t\n'\"\\$&;|<>*?()[]{}!`") {
return value
}
return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'"
}