mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
483 lines
12 KiB
Go
483 lines
12 KiB
Go
package handlers
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os/exec"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/containerd/containerd/v2/pkg/namespaces"
|
|
"github.com/containerd/errdefs"
|
|
"github.com/labstack/echo/v4"
|
|
|
|
ctr "github.com/memohai/memoh/internal/containerd"
|
|
mcptools "github.com/memohai/memoh/internal/mcp"
|
|
)
|
|
|
|
// HandleMCPFS godoc
|
|
// @Summary MCP filesystem tools (JSON-RPC)
|
|
// @Description Forwards MCP JSON-RPC requests to the MCP server inside the container.
|
|
// @Description Required:
|
|
// @Description - container task is running
|
|
// @Description - container has data mount (default /data) bound to <data_root>/users/<user_id>
|
|
// @Description - container image contains the "mcp" binary
|
|
// @Description Auth: Bearer JWT is used to determine user_id (sub or user_id).
|
|
// @Description Paths must be relative (no leading slash) and must not contain "..".
|
|
// @Description
|
|
// @Description Example: tools/list
|
|
// @Description {"jsonrpc":"2.0","id":1,"method":"tools/list"}
|
|
// @Description
|
|
// @Description Example: tools/call (fs.read)
|
|
// @Description {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}}
|
|
// @Tags containerd
|
|
// @Param Authorization header string true "Bearer <token>"
|
|
// @Param bot_id path string true "Bot ID"
|
|
// @Param payload body object true "JSON-RPC request"
|
|
// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}"
|
|
// @Failure 400 {object} ErrorResponse
|
|
// @Failure 404 {object} ErrorResponse
|
|
// @Failure 500 {object} ErrorResponse
|
|
// @Router /bots/{bot_id}/container/fs [post]
|
|
func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error {
|
|
botID, err := h.requireBotAccess(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ctx := c.Request().Context()
|
|
containerID, err := h.botContainerID(ctx, botID)
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusNotFound, "container not found for bot")
|
|
}
|
|
|
|
var req mcptools.JSONRPCRequest
|
|
if err := c.Bind(&req); err != nil {
|
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
|
}
|
|
if req.JSONRPC != "" && req.JSONRPC != "2.0" {
|
|
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
|
JSONRPC: "2.0",
|
|
ID: req.ID,
|
|
Error: &mcptools.JSONRPCError{Code: -32600, Message: "invalid jsonrpc version"},
|
|
})
|
|
}
|
|
|
|
if err := h.validateMCPContainer(c.Request().Context(), containerID, botID); err != nil {
|
|
h.logger.Error("mcp fs validate failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
|
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
|
JSONRPC: "2.0",
|
|
ID: req.ID,
|
|
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
|
|
})
|
|
}
|
|
if err := h.ensureTaskRunning(c.Request().Context(), containerID); err != nil {
|
|
h.logger.Error("mcp fs ensure task failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
|
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
|
JSONRPC: "2.0",
|
|
ID: req.ID,
|
|
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
|
|
})
|
|
}
|
|
|
|
if strings.TrimSpace(req.Method) == "" {
|
|
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
|
JSONRPC: "2.0",
|
|
ID: req.ID,
|
|
Error: &mcptools.JSONRPCError{Code: -32601, Message: "method not found"},
|
|
})
|
|
}
|
|
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
|
|
if err != nil {
|
|
h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
|
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
|
JSONRPC: "2.0",
|
|
ID: req.ID,
|
|
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
|
|
})
|
|
}
|
|
return c.JSON(http.StatusOK, payload)
|
|
}
|
|
|
|
func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, botID string) error {
|
|
if strings.TrimSpace(botID) == "" {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
|
}
|
|
container, err := h.service.GetContainer(ctx, containerID)
|
|
if err != nil {
|
|
if errdefs.IsNotFound(err) {
|
|
return echo.NewHTTPError(http.StatusNotFound, "container not found")
|
|
}
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
|
|
infoCtx := ctx
|
|
if strings.TrimSpace(h.namespace) != "" {
|
|
infoCtx = namespaces.WithNamespace(ctx, h.namespace)
|
|
}
|
|
info, err := container.Info(infoCtx)
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
labelBotID := strings.TrimSpace(info.Labels[mcptools.BotLabelKey])
|
|
if labelBotID != "" && labelBotID != botID {
|
|
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *ContainerdHandler) callMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) (map[string]any, error) {
|
|
session, err := h.getMCPSession(ctx, containerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return session.call(ctx, req)
|
|
}
|
|
|
|
type mcpSession struct {
|
|
stdin io.WriteCloser
|
|
stdout io.ReadCloser
|
|
stderr io.ReadCloser
|
|
cmd *exec.Cmd
|
|
initOnce sync.Once
|
|
writeMu sync.Mutex
|
|
pendingMu sync.Mutex
|
|
pending map[string]chan mcptools.JSONRPCResponse
|
|
closed chan struct{}
|
|
closeOnce sync.Once
|
|
closeErr error
|
|
onClose func()
|
|
}
|
|
|
|
func (h *ContainerdHandler) getMCPSession(ctx context.Context, containerID string) (*mcpSession, error) {
|
|
h.mcpMu.Lock()
|
|
if sess, ok := h.mcpSess[containerID]; ok {
|
|
h.mcpMu.Unlock()
|
|
return sess, nil
|
|
}
|
|
h.mcpMu.Unlock()
|
|
|
|
var sess *mcpSession
|
|
var err error
|
|
if runtime.GOOS == "darwin" {
|
|
sess, err = h.startLimaMCPSession(containerID)
|
|
}
|
|
if err != nil || sess == nil {
|
|
sess, err = h.startContainerdMCPSession(ctx, containerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
h.mcpMu.Lock()
|
|
h.mcpSess[containerID] = sess
|
|
h.mcpMu.Unlock()
|
|
|
|
sess.onClose = func() {
|
|
h.mcpMu.Lock()
|
|
if current, ok := h.mcpSess[containerID]; ok && current == sess {
|
|
delete(h.mcpSess, containerID)
|
|
}
|
|
h.mcpMu.Unlock()
|
|
}
|
|
|
|
return sess, nil
|
|
}
|
|
|
|
func (h *ContainerdHandler) startContainerdMCPSession(ctx context.Context, containerID string) (*mcpSession, error) {
|
|
execSession, err := h.service.ExecTaskStreaming(ctx, containerID, ctr.ExecTaskRequest{
|
|
Args: []string{"/mcp"},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sess := &mcpSession{
|
|
stdin: execSession.Stdin,
|
|
stdout: execSession.Stdout,
|
|
stderr: execSession.Stderr,
|
|
pending: make(map[string]chan mcptools.JSONRPCResponse),
|
|
closed: make(chan struct{}),
|
|
}
|
|
|
|
h.startMCPStderrLogger(execSession.Stderr, containerID)
|
|
go sess.readLoop()
|
|
go func() {
|
|
_, err := execSession.Wait()
|
|
if err != nil {
|
|
h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID))
|
|
sess.closeWithError(err)
|
|
} else {
|
|
sess.closeWithError(io.EOF)
|
|
}
|
|
}()
|
|
|
|
return sess, nil
|
|
}
|
|
|
|
func (h *ContainerdHandler) startLimaMCPSession(containerID string) (*mcpSession, error) {
|
|
execID := fmt.Sprintf("mcp-%d", time.Now().UnixNano())
|
|
cmd := exec.Command(
|
|
"limactl",
|
|
"shell",
|
|
"--tty=false",
|
|
"default",
|
|
"--",
|
|
"sudo",
|
|
"-n",
|
|
"ctr",
|
|
"-n",
|
|
"default",
|
|
"tasks",
|
|
"exec",
|
|
"--exec-id",
|
|
execID,
|
|
containerID,
|
|
"/mcp",
|
|
)
|
|
|
|
stdin, err := cmd.StdinPipe()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
stdout, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
_ = stdin.Close()
|
|
return nil, err
|
|
}
|
|
stderr, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
_ = stdin.Close()
|
|
_ = stdout.Close()
|
|
return nil, err
|
|
}
|
|
if err := cmd.Start(); err != nil {
|
|
_ = stdin.Close()
|
|
_ = stdout.Close()
|
|
_ = stderr.Close()
|
|
return nil, err
|
|
}
|
|
|
|
sess := &mcpSession{
|
|
stdin: stdin,
|
|
stdout: stdout,
|
|
stderr: stderr,
|
|
cmd: cmd,
|
|
pending: make(map[string]chan mcptools.JSONRPCResponse),
|
|
closed: make(chan struct{}),
|
|
}
|
|
|
|
h.startMCPStderrLogger(stderr, containerID)
|
|
go sess.readLoop()
|
|
go func() {
|
|
if err := cmd.Wait(); err != nil {
|
|
h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID))
|
|
sess.closeWithError(err)
|
|
} else {
|
|
sess.closeWithError(io.EOF)
|
|
}
|
|
}()
|
|
|
|
return sess, nil
|
|
}
|
|
|
|
func (s *mcpSession) closeWithError(err error) {
|
|
s.closeOnce.Do(func() {
|
|
s.closeErr = err
|
|
close(s.closed)
|
|
s.pendingMu.Lock()
|
|
for _, ch := range s.pending {
|
|
close(ch)
|
|
}
|
|
s.pending = map[string]chan mcptools.JSONRPCResponse{}
|
|
s.pendingMu.Unlock()
|
|
_ = s.stdin.Close()
|
|
_ = s.stdout.Close()
|
|
_ = s.stderr.Close()
|
|
if s.cmd != nil && s.cmd.Process != nil {
|
|
_ = s.cmd.Process.Kill()
|
|
}
|
|
if s.onClose != nil {
|
|
s.onClose()
|
|
}
|
|
})
|
|
}
|
|
|
|
func (h *ContainerdHandler) startMCPStderrLogger(stderr io.ReadCloser, containerID string) {
|
|
if stderr == nil {
|
|
return
|
|
}
|
|
go func() {
|
|
scanner := bufio.NewScanner(stderr)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" {
|
|
continue
|
|
}
|
|
h.logger.Warn("mcp stderr", slog.String("container_id", containerID), slog.String("message", line))
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
h.logger.Error("mcp stderr read failed", slog.Any("error", err), slog.String("container_id", containerID))
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *mcpSession) readLoop() {
|
|
scanner := bufio.NewScanner(s.stdout)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" {
|
|
continue
|
|
}
|
|
var resp mcptools.JSONRPCResponse
|
|
if err := json.Unmarshal([]byte(line), &resp); err != nil {
|
|
continue
|
|
}
|
|
id := strings.TrimSpace(string(resp.ID))
|
|
if id == "" {
|
|
continue
|
|
}
|
|
s.pendingMu.Lock()
|
|
ch, ok := s.pending[id]
|
|
if ok {
|
|
delete(s.pending, id)
|
|
}
|
|
s.pendingMu.Unlock()
|
|
if ok {
|
|
ch <- resp
|
|
close(ch)
|
|
}
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
s.closeWithError(err)
|
|
} else {
|
|
s.closeWithError(io.EOF)
|
|
}
|
|
}
|
|
|
|
func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) {
|
|
payloads, targetID, err := buildMCPPayloads(req, &s.initOnce)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
target := strings.TrimSpace(string(targetID))
|
|
if target == "" {
|
|
return nil, fmt.Errorf("missing request id")
|
|
}
|
|
|
|
respCh := make(chan mcptools.JSONRPCResponse, 1)
|
|
s.pendingMu.Lock()
|
|
s.pending[target] = respCh
|
|
s.pendingMu.Unlock()
|
|
|
|
s.writeMu.Lock()
|
|
for _, payload := range payloads {
|
|
if _, err := s.stdin.Write([]byte(payload + "\n")); err != nil {
|
|
s.writeMu.Unlock()
|
|
return nil, err
|
|
}
|
|
}
|
|
s.writeMu.Unlock()
|
|
|
|
select {
|
|
case resp, ok := <-respCh:
|
|
if !ok {
|
|
if s.closeErr != nil {
|
|
return nil, s.closeErr
|
|
}
|
|
return nil, io.EOF
|
|
}
|
|
if resp.Error != nil {
|
|
return map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": resp.ID,
|
|
"error": map[string]any{
|
|
"code": resp.Error.Code,
|
|
"message": resp.Error.Message,
|
|
},
|
|
}, nil
|
|
}
|
|
return map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": resp.ID,
|
|
"result": resp.Result,
|
|
}, nil
|
|
case <-s.closed:
|
|
if s.closeErr != nil {
|
|
return nil, s.closeErr
|
|
}
|
|
return nil, io.EOF
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
func buildMCPPayloads(req mcptools.JSONRPCRequest, initOnce *sync.Once) ([]string, json.RawMessage, error) {
|
|
if req.JSONRPC == "" {
|
|
req.JSONRPC = "2.0"
|
|
}
|
|
targetID := req.ID
|
|
payloads := []string{}
|
|
shouldInit := req.Method != "initialize" && req.Method != "notifications/initialized"
|
|
if initOnce != nil {
|
|
ran := false
|
|
initOnce.Do(func() {
|
|
ran = true
|
|
})
|
|
if ran {
|
|
// This is the first call on the session.
|
|
} else {
|
|
shouldInit = false
|
|
}
|
|
}
|
|
if shouldInit {
|
|
initReq := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": "init-1",
|
|
"method": "initialize",
|
|
"params": map[string]any{
|
|
"protocolVersion": "2025-06-18",
|
|
"capabilities": map[string]any{
|
|
"roots": map[string]any{
|
|
"listChanged": false,
|
|
},
|
|
},
|
|
"clientInfo": map[string]any{
|
|
"name": "memoh-http-proxy",
|
|
"version": "v0",
|
|
},
|
|
},
|
|
}
|
|
initBytes, err := json.Marshal(initReq)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
payloads = append(payloads, string(initBytes))
|
|
|
|
initialized := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"method": "notifications/initialized",
|
|
}
|
|
initializedBytes, err := json.Marshal(initialized)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
payloads = append(payloads, string(initializedBytes))
|
|
}
|
|
|
|
reqBytes, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
payloads = append(payloads, string(reqBytes))
|
|
return payloads, targetID, nil
|
|
}
|