Files
Memoh/internal/handlers/mcp_stdio.go
T
2026-02-12 08:23:25 +08:00

398 lines
10 KiB
Go

package handlers
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"os/exec"
"runtime"
"sort"
"strings"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
ctr "github.com/memohai/memoh/internal/containerd"
mcptools "github.com/memohai/memoh/internal/mcp"
)
type MCPStdioRequest struct {
Name string `json:"name"`
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
Cwd string `json:"cwd"`
}
type MCPStdioResponse struct {
SessionID string `json:"session_id"`
URL string `json:"url"`
Tools []string `json:"tools,omitempty"`
}
type mcpStdioSession struct {
id string
botID string
containerID string
name string
createdAt time.Time
lastUsedAt time.Time
session *mcpSession
}
// CreateMCPStdio godoc
// @Summary Create MCP stdio proxy
// @Description Start a stdio MCP process in the bot container and expose it as MCP HTTP endpoint.
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param payload body MCPStdioRequest true "Stdio MCP payload"
// @Success 200 {object} MCPStdioResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp-stdio [post]
func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
return err
}
var req MCPStdioRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if strings.TrimSpace(req.Command) == "" {
return echo.NewHTTPError(http.StatusBadRequest, "command is required")
}
ctx := c.Request().Context()
containerID, err := h.botContainerID(ctx, botID)
if err != nil {
return echo.NewHTTPError(http.StatusNotFound, "container not found for bot")
}
if err := h.validateMCPContainer(ctx, containerID, botID); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if err := h.ensureTaskRunning(ctx, containerID); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
sess, err := h.startContainerdMCPCommandSession(ctx, containerID, req)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
tools := h.probeMCPTools(ctx, sess, botID, strings.TrimSpace(req.Name))
sessionID := uuid.NewString()
record := &mcpStdioSession{
id: sessionID,
botID: botID,
containerID: containerID,
name: strings.TrimSpace(req.Name),
createdAt: time.Now().UTC(),
lastUsedAt: time.Now().UTC(),
session: sess,
}
sess.onClose = func() {
h.mcpStdioMu.Lock()
if current, ok := h.mcpStdioSess[sessionID]; ok && current == record {
delete(h.mcpStdioSess, sessionID)
}
h.mcpStdioMu.Unlock()
}
h.mcpStdioMu.Lock()
h.mcpStdioSess[sessionID] = record
h.mcpStdioMu.Unlock()
return c.JSON(http.StatusOK, MCPStdioResponse{
SessionID: sessionID,
URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, sessionID),
Tools: tools,
})
}
// HandleMCPStdio godoc
// @Summary MCP stdio proxy (JSON-RPC)
// @Description Proxies MCP JSON-RPC requests to a stdio MCP process in the container.
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param session_id path string true "Session ID"
// @Param payload body object true "JSON-RPC request"
// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}"
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp-stdio/{session_id} [post]
func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
return err
}
sessionID := strings.TrimSpace(c.Param("session_id"))
if sessionID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "session_id is required")
}
h.mcpStdioMu.Lock()
session := h.mcpStdioSess[sessionID]
h.mcpStdioMu.Unlock()
if session == nil || session.session == nil || session.botID != botID {
return echo.NewHTTPError(http.StatusNotFound, "mcp session not found")
}
select {
case <-session.session.closed:
return echo.NewHTTPError(http.StatusNotFound, "mcp session closed")
default:
}
var req mcptools.JSONRPCRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if req.JSONRPC != "" && req.JSONRPC != "2.0" {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32600, "invalid jsonrpc version"))
}
if strings.TrimSpace(req.Method) == "" {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32601, "method not found"))
}
session.lastUsedAt = time.Now().UTC()
if mcptools.IsNotification(req) {
if err := session.session.notify(c.Request().Context(), req); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusAccepted)
}
payload, err := session.session.call(c.Request().Context(), req)
if err != nil {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error()))
}
return c.JSON(http.StatusOK, payload)
}
func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context, containerID string, req MCPStdioRequest) (*mcpSession, error) {
if runtime.GOOS == "darwin" {
return h.startLimaMCPCommandSession(containerID, req)
}
args := append([]string{strings.TrimSpace(req.Command)}, req.Args...)
env := buildEnvPairs(req.Env)
execSession, err := h.service.ExecTaskStreaming(ctx, containerID, ctr.ExecTaskRequest{
Args: args,
Env: env,
WorkDir: strings.TrimSpace(req.Cwd),
FIFODir: h.mcpFIFODir(),
})
if err != nil {
return nil, err
}
sess := &mcpSession{
stdin: execSession.Stdin,
stdout: execSession.Stdout,
stderr: execSession.Stderr,
pending: make(map[string]chan mcptools.JSONRPCResponse),
closed: make(chan struct{}),
}
h.startMCPStderrLogger(execSession.Stderr, containerID)
go sess.readLoop()
go func() {
_, err := execSession.Wait()
if err != nil {
if isBenignMCPSessionExit(err) {
sess.closeWithError(io.EOF)
return
}
h.logger.Error("mcp stdio session exited", slog.Any("error", err), slog.String("container_id", containerID))
sess.closeWithError(err)
return
}
sess.closeWithError(io.EOF)
}()
return sess, nil
}
func buildEnvPairs(env map[string]string) []string {
if len(env) == 0 {
return nil
}
keys := make([]string, 0, len(env))
for k := range env {
if strings.TrimSpace(k) != "" {
keys = append(keys, k)
}
}
sort.Strings(keys)
out := make([]string, 0, len(keys))
for _, k := range keys {
out = append(out, fmt.Sprintf("%s=%s", k, env[k]))
}
return out
}
func (h *ContainerdHandler) probeMCPTools(ctx context.Context, sess *mcpSession, botID, name string) []string {
if sess == nil {
return nil
}
probeCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
defer cancel()
payload, err := sess.call(probeCtx, mcptools.JSONRPCRequest{
JSONRPC: "2.0",
ID: mcptools.RawStringID("probe-tools"),
Method: "tools/list",
})
if err != nil {
h.logger.Warn("mcp stdio tools probe failed",
slog.String("bot_id", botID),
slog.String("name", name),
slog.Any("error", err),
)
return nil
}
tools := extractToolNames(payload)
if len(tools) == 0 {
h.logger.Warn("mcp stdio tools empty",
slog.String("bot_id", botID),
slog.String("name", name),
)
} else {
h.logger.Info("mcp stdio tools loaded",
slog.String("bot_id", botID),
slog.String("name", name),
slog.Int("count", len(tools)),
)
}
return tools
}
func extractToolNames(payload map[string]any) []string {
result, ok := payload["result"].(map[string]any)
if !ok {
return nil
}
rawTools, ok := result["tools"].([]any)
if !ok {
return nil
}
names := make([]string, 0, len(rawTools))
for _, raw := range rawTools {
item, ok := raw.(map[string]any)
if !ok {
continue
}
name, _ := item["name"].(string)
name = strings.TrimSpace(name)
if name == "" {
continue
}
names = append(names, name)
}
sort.Strings(names)
return names
}
func (h *ContainerdHandler) startLimaMCPCommandSession(containerID string, req MCPStdioRequest) (*mcpSession, error) {
execID := fmt.Sprintf("mcp-stdio-%d", time.Now().UnixNano())
cmdline := buildShellCommand(req)
cmd := exec.Command(
"limactl",
"shell",
"--tty=false",
"default",
"--",
"sudo",
"-n",
"ctr",
"-n",
"default",
"tasks",
"exec",
"--exec-id",
execID,
containerID,
"/bin/sh",
"-lc",
cmdline,
)
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}
stdout, err := cmd.StdoutPipe()
if err != nil {
_ = stdin.Close()
return nil, err
}
stderr, err := cmd.StderrPipe()
if err != nil {
_ = stdin.Close()
_ = stdout.Close()
return nil, err
}
if err := cmd.Start(); err != nil {
_ = stdin.Close()
_ = stdout.Close()
_ = stderr.Close()
return nil, err
}
sess := &mcpSession{
stdin: stdin,
stdout: stdout,
stderr: stderr,
cmd: cmd,
pending: make(map[string]chan mcptools.JSONRPCResponse),
closed: make(chan struct{}),
}
h.startMCPStderrLogger(stderr, containerID)
go sess.readLoop()
go func() {
if err := cmd.Wait(); err != nil {
if isBenignMCPSessionExit(err) {
sess.closeWithError(io.EOF)
return
}
h.logger.Error("mcp stdio session exited", slog.Any("error", err), slog.String("container_id", containerID))
sess.closeWithError(err)
return
}
sess.closeWithError(io.EOF)
}()
return sess, nil
}
func buildShellCommand(req MCPStdioRequest) string {
cmd := strings.TrimSpace(req.Command)
if cmd == "" {
return ""
}
parts := make([]string, 0, len(req.Args)+1)
parts = append(parts, escapeShellArg(cmd))
for _, arg := range req.Args {
parts = append(parts, escapeShellArg(arg))
}
command := strings.Join(parts, " ")
assignments := []string{}
for _, pair := range buildEnvPairs(req.Env) {
assignments = append(assignments, escapeShellArg(pair))
}
if len(assignments) > 0 {
command = strings.Join(assignments, " ") + " " + command
}
if strings.TrimSpace(req.Cwd) != "" {
command = "cd " + escapeShellArg(req.Cwd) + " && " + command
}
return command
}
func escapeShellArg(value string) string {
if value == "" {
return "''"
}
if !strings.ContainsAny(value, " \t\n'\"\\$&;|<>*?()[]{}!`") {
return value
}
return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'"
}