Files
Memoh/internal/handlers/mcp_stdio.go
T

813 lines
20 KiB
Go

package handlers
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"sort"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
mcptools "github.com/memohai/memoh/internal/mcp"
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
)
// MCPStdioRequest represents a request to create an MCP stdio session.
type MCPStdioRequest struct {
Name string `json:"name"`
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
Cwd string `json:"cwd"`
}
// MCPStdioResponse represents the response from creating an MCP stdio session.
type MCPStdioResponse struct {
ConnectionID string `json:"connection_id"`
URL string `json:"url"`
Tools []string `json:"tools,omitempty"`
}
// mcpSession represents an MCP session over stdio.
type mcpSession struct {
stdin io.WriteCloser
stdout io.ReadCloser
stderr io.ReadCloser
readCtx context.Context
cancelRead context.CancelFunc
initMu sync.Mutex
initState mcpSessionInitState
initWait chan struct{}
pendingMu sync.Mutex
pending map[string]chan *sdkjsonrpc.Response
conn sdkmcp.Connection
closed chan struct{}
closeOnce sync.Once
closeErr error
onClose func()
}
type mcpSessionInitState uint8
const (
mcpSessionInitStateNone mcpSessionInitState = iota
mcpSessionInitStateInitializing
mcpSessionInitStateInitialized
mcpSessionInitStateReady
)
func (s *mcpSession) closeWithError(err error) {
s.closeOnce.Do(func() {
s.closeErr = err
close(s.closed)
if s.cancelRead != nil {
s.cancelRead()
}
s.pendingMu.Lock()
for _, ch := range s.pending {
close(ch)
}
s.pending = map[string]chan *sdkjsonrpc.Response{}
s.pendingMu.Unlock()
if s.conn != nil {
_ = s.conn.Close()
}
if s.stdin != nil {
_ = s.stdin.Close()
}
if s.stdout != nil {
_ = s.stdout.Close()
}
if s.stderr != nil {
_ = s.stderr.Close()
}
if s.onClose != nil {
s.onClose()
}
})
}
func (s *mcpSession) readLoop() {
if s.conn == nil {
s.closeWithError(io.EOF)
return
}
for {
msg, err := s.conn.Read(s.readCtx)
if err != nil {
if errors.Is(err, io.EOF) {
s.closeWithError(io.EOF)
return
}
s.closeWithError(err)
return
}
resp, ok := msg.(*sdkjsonrpc.Response)
if !ok || !resp.ID.IsValid() {
continue
}
id := sdkIDKey(resp.ID)
if id == "" {
continue
}
s.pendingMu.Lock()
ch, ok := s.pending[id]
if ok {
delete(s.pending, id)
}
s.pendingMu.Unlock()
if ok {
ch <- resp
close(ch)
}
}
}
func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) {
method := strings.TrimSpace(req.Method)
if method == "initialize" {
payload, err := s.callRaw(ctx, req)
if err != nil {
return nil, err
}
// If the server accepted our initialize, advance state so
// ensureInitialized will only send notifications/initialized next time.
if _, hasError := payload["error"]; !hasError {
s.initMu.Lock()
if s.initState < mcpSessionInitStateInitialized {
s.initState = mcpSessionInitStateInitialized
}
s.initMu.Unlock()
}
return payload, nil
}
if method != "notifications/initialized" {
if err := s.ensureInitialized(ctx); err != nil {
return nil, err
}
}
return s.callRaw(ctx, req)
}
func (s *mcpSession) callRaw(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) {
targetID, err := parseRawJSONRPCID(req.ID)
if err != nil {
return nil, err
}
target := sdkIDKey(targetID)
if target == "" {
return nil, errors.New("missing request id")
}
respCh := make(chan *sdkjsonrpc.Response, 1)
s.pendingMu.Lock()
s.pending[target] = respCh
s.pendingMu.Unlock()
callReq := &sdkjsonrpc.Request{
ID: targetID,
Method: req.Method,
Params: req.Params,
}
if err := s.conn.Write(ctx, callReq); err != nil {
s.pendingMu.Lock()
delete(s.pending, target)
s.pendingMu.Unlock()
return nil, err
}
select {
case resp, ok := <-respCh:
if !ok {
if s.closeErr != nil {
return nil, s.closeErr
}
return nil, io.EOF
}
return sdkResponsePayload(resp)
case <-s.closed:
if s.closeErr != nil {
return nil, s.closeErr
}
return nil, io.EOF
case <-ctx.Done():
s.pendingMu.Lock()
delete(s.pending, target)
s.pendingMu.Unlock()
return nil, ctx.Err()
}
}
// sdkResponsePayload wraps an SDK JSON-RPC response into a standard JSON-RPC
// envelope ({"jsonrpc":"2.0","id":...,"result":...} or "error":...).
func sdkResponsePayload(resp *sdkjsonrpc.Response) (map[string]any, error) {
if resp == nil {
return nil, io.EOF
}
if resp.Error != nil {
code := int64(-32603)
message := strings.TrimSpace(resp.Error.Error())
wireErr := &sdkjsonrpc.Error{}
if errors.As(resp.Error, &wireErr) {
code = wireErr.Code
message = strings.TrimSpace(wireErr.Message)
}
if message == "" {
message = "internal error"
}
return map[string]any{
"jsonrpc": "2.0",
"id": sdkIDRaw(resp.ID),
"error": map[string]any{
"code": code,
"message": message,
},
}, nil
}
var result any
if len(resp.Result) > 0 {
if err := json.Unmarshal(resp.Result, &result); err != nil {
return nil, err
}
}
return map[string]any{
"jsonrpc": "2.0",
"id": sdkIDRaw(resp.ID),
"result": result,
}, nil
}
func sdkIDRaw(id sdkjsonrpc.ID) any {
if !id.IsValid() {
return nil
}
return id.Raw()
}
func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error {
if s.conn == nil {
return io.EOF
}
return s.conn.Write(ctx, &sdkjsonrpc.Request{
Method: req.Method,
Params: req.Params,
})
}
func (s *mcpSession) ensureInitialized(ctx context.Context) error {
for {
s.initMu.Lock()
state := s.initState
switch state {
case mcpSessionInitStateReady:
s.initMu.Unlock()
return nil
case mcpSessionInitStateInitializing:
waitCh := s.initWait
s.initMu.Unlock()
if waitCh == nil {
continue
}
select {
case <-waitCh:
continue
case <-ctx.Done():
return ctx.Err()
case <-s.closed:
if s.closeErr != nil {
return s.closeErr
}
return io.EOF
}
case mcpSessionInitStateInitialized:
waitCh := make(chan struct{})
s.initState = mcpSessionInitStateInitializing
s.initWait = waitCh
s.initMu.Unlock()
err := s.sendInitializedNotification(ctx)
s.initMu.Lock()
if err == nil {
s.initState = mcpSessionInitStateReady
} else {
s.initState = mcpSessionInitStateInitialized
}
s.initWait = nil
close(waitCh)
s.initMu.Unlock()
if err != nil {
return err
}
return nil
default:
waitCh := make(chan struct{})
s.initState = mcpSessionInitStateInitializing
s.initWait = waitCh
s.initMu.Unlock()
nextState, err := s.initializeHandshake(ctx)
s.initMu.Lock()
s.initState = nextState
s.initWait = nil
close(waitCh)
s.initMu.Unlock()
if err != nil {
return err
}
if nextState == mcpSessionInitStateReady {
return nil
}
}
}
}
func (s *mcpSession) initializeHandshake(ctx context.Context) (mcpSessionInitState, error) {
initID, _ := sdkjsonrpc.MakeID("init")
params, _ := json.Marshal(map[string]any{
"protocolVersion": "2025-06-18",
"capabilities": map[string]any{},
"clientInfo": map[string]any{
"name": "memoh",
"version": "1.0.0",
},
})
initResp, err := s.invokeCall(ctx, &sdkjsonrpc.Request{
ID: initID,
Method: "initialize",
Params: params,
})
if err != nil {
return mcpSessionInitStateNone, err
}
if initResp.Error != nil {
return mcpSessionInitStateNone, initResp.Error
}
if err := s.sendInitializedNotification(ctx); err != nil {
return mcpSessionInitStateInitialized, err
}
return mcpSessionInitStateReady, nil
}
func (s *mcpSession) sendInitializedNotification(ctx context.Context) error {
if s.conn == nil {
return io.EOF
}
return s.conn.Write(ctx, &sdkjsonrpc.Request{
Method: "notifications/initialized",
})
}
func (s *mcpSession) invokeCall(ctx context.Context, req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) {
if s.conn == nil {
return nil, io.EOF
}
if req == nil || !req.ID.IsValid() {
return nil, errors.New("missing request id")
}
key := sdkIDKey(req.ID)
if key == "" {
return nil, errors.New("invalid request id")
}
respCh := make(chan *sdkjsonrpc.Response, 1)
s.pendingMu.Lock()
s.pending[key] = respCh
s.pendingMu.Unlock()
if err := s.conn.Write(ctx, req); err != nil {
s.removePending(key)
return nil, err
}
select {
case resp, ok := <-respCh:
if !ok {
if s.closeErr != nil {
return nil, s.closeErr
}
return nil, io.EOF
}
return resp, nil
case <-s.closed:
if s.closeErr != nil {
return nil, s.closeErr
}
return nil, io.EOF
case <-ctx.Done():
s.removePending(key)
return nil, ctx.Err()
}
}
func (s *mcpSession) removePending(key string) {
if strings.TrimSpace(key) == "" {
return
}
s.pendingMu.Lock()
delete(s.pending, key)
s.pendingMu.Unlock()
}
func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) {
if len(raw) == 0 {
return sdkjsonrpc.ID{}, errors.New("missing request id")
}
var idValue any
if err := json.Unmarshal(raw, &idValue); err != nil {
return sdkjsonrpc.ID{}, err
}
id, err := sdkjsonrpc.MakeID(idValue)
if err != nil {
return sdkjsonrpc.ID{}, err
}
if !id.IsValid() {
return sdkjsonrpc.ID{}, errors.New("missing request id")
}
return id, nil
}
func sdkIDKey(id sdkjsonrpc.ID) string {
if !id.IsValid() {
return ""
}
raw, _ := json.Marshal(id.Raw())
return string(raw)
}
func startMCPStderrLogger(stderr io.ReadCloser, containerID string, logger *slog.Logger) {
if stderr == nil {
return
}
go func() {
scanner := bufio.NewScanner(stderr)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
logger.Warn("mcp stderr", slog.String("container_id", containerID), slog.String("message", line))
}
if err := scanner.Err(); err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "closed pipe") {
return
}
logger.Error("mcp stderr read failed", slog.Any("error", err), slog.String("container_id", containerID))
}
}()
}
func extractToolNames(payload map[string]any) []string {
result, ok := payload["result"].(map[string]any)
if !ok {
return nil
}
rawTools, ok := result["tools"].([]any)
if !ok {
return nil
}
names := make([]string, 0, len(rawTools))
for _, raw := range rawTools {
item, ok := raw.(map[string]any)
if !ok {
continue
}
name, _ := item["name"].(string)
name = strings.TrimSpace(name)
if name == "" {
continue
}
names = append(names, name)
}
sort.Strings(names)
return names
}
func buildShellCommand(req MCPStdioRequest) string {
cmd := strings.TrimSpace(req.Command)
if cmd == "" {
return ""
}
parts := make([]string, 0, len(req.Args)+1)
parts = append(parts, escapeShellArg(cmd))
for _, arg := range req.Args {
parts = append(parts, escapeShellArg(arg))
}
command := strings.Join(parts, " ")
assignments := []string{}
for _, pair := range buildEnvPairs(req.Env) {
assignments = append(assignments, escapeShellArg(pair))
}
if len(assignments) > 0 {
command = strings.Join(assignments, " ") + " " + command
}
if strings.TrimSpace(req.Cwd) != "" {
command = "cd " + escapeShellArg(req.Cwd) + " && " + command
}
return command
}
func escapeShellArg(value string) string {
if value == "" {
return "''"
}
if !strings.ContainsAny(value, " \t\n'\"\\$&;|<>*?()[]{}!`") {
return value
}
return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'"
}
func buildEnvPairs(env map[string]string) []string {
if len(env) == 0 {
return nil
}
keys := make([]string, 0, len(env))
for k := range env {
if strings.TrimSpace(k) != "" {
keys = append(keys, k)
}
}
sort.Strings(keys)
out := make([]string, 0, len(keys))
for _, k := range keys {
out = append(out, fmt.Sprintf("%s=%s", k, env[k]))
}
return out
}
// ---------- MCP Stdio Handlers ----------
type mcpStdioSession struct {
id string
botID string
containerID string
name string
createdAt time.Time
lastUsedAt time.Time
session *mcpSession
}
// CreateMCPStdio godoc
// @Summary Create MCP stdio proxy
// @Description Start a stdio MCP process in the bot container and expose it as MCP HTTP endpoint.
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param payload body MCPStdioRequest true "Stdio MCP payload"
// @Success 200 {object} MCPStdioResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp-stdio [post].
func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
return err
}
var req MCPStdioRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if strings.TrimSpace(req.Command) == "" {
return echo.NewHTTPError(http.StatusBadRequest, "command is required")
}
ctx := c.Request().Context()
if err := h.manager.EnsureRunning(ctx, botID); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
containerID, err := h.manager.ContainerID(ctx, botID)
if err != nil {
return echo.NewHTTPError(http.StatusNotFound, "container not found for bot")
}
sess, err := h.startContainerdMCPCommandSession(ctx, botID, containerID, req)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
tools := h.probeMCPTools(ctx, sess, botID, strings.TrimSpace(req.Name))
connectionID := uuid.NewString()
record := &mcpStdioSession{
id: connectionID,
botID: botID,
containerID: containerID,
name: strings.TrimSpace(req.Name),
createdAt: time.Now().UTC(),
lastUsedAt: time.Now().UTC(),
session: sess,
}
sess.onClose = func() {
h.mcpStdioMu.Lock()
if current, ok := h.mcpStdioSess[connectionID]; ok && current == record {
delete(h.mcpStdioSess, connectionID)
}
h.mcpStdioMu.Unlock()
}
h.mcpStdioMu.Lock()
h.mcpStdioSess[connectionID] = record
h.mcpStdioMu.Unlock()
return c.JSON(http.StatusOK, MCPStdioResponse{
ConnectionID: connectionID,
URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, connectionID),
Tools: tools,
})
}
// HandleMCPStdio godoc
// @Summary MCP stdio proxy (JSON-RPC)
// @Description Proxies MCP JSON-RPC requests to a stdio MCP process in the container.
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param connection_id path string true "Connection ID"
// @Param payload body object true "JSON-RPC request"
// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}"
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/mcp-stdio/{connection_id} [post].
func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
return err
}
connectionID := strings.TrimSpace(c.Param("connection_id"))
if connectionID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "connection_id is required")
}
h.mcpStdioMu.Lock()
session := h.mcpStdioSess[connectionID]
h.mcpStdioMu.Unlock()
if session == nil || session.session == nil || session.botID != botID {
return echo.NewHTTPError(http.StatusNotFound, "mcp connection not found")
}
select {
case <-session.session.closed:
return echo.NewHTTPError(http.StatusNotFound, "mcp connection closed")
default:
}
var req mcptools.JSONRPCRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if req.JSONRPC != "" && req.JSONRPC != "2.0" {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32600, "invalid jsonrpc version"))
}
if strings.TrimSpace(req.Method) == "" {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32601, "method not found"))
}
session.lastUsedAt = time.Now().UTC()
if mcptools.IsNotification(req) {
if err := session.session.notify(c.Request().Context(), req); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusAccepted)
}
payload, err := session.session.call(c.Request().Context(), req)
if err != nil {
return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error()))
}
return c.JSON(http.StatusOK, payload)
}
func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context, botID, containerID string, req MCPStdioRequest) (*mcpSession, error) {
// Get gRPC client for the bot container via manager
client, err := h.manager.MCPClient(ctx, botID)
if err != nil {
return nil, fmt.Errorf("get container client: %w", err)
}
command := buildShellCommand(req)
// Create bidirectional exec stream
execStream, err := client.ExecStream(ctx, command, strings.TrimSpace(req.Cwd), 0)
if err != nil {
return nil, err
}
// Create pipes for stdin/stdout/stderr
stdinR, stdinW := io.Pipe()
stdoutR, stdoutW := io.Pipe()
stderrR, stderrW := io.Pipe()
readCtx, cancelRead := context.WithCancel(context.Background()) //nolint:gosec // G118: cancelRead is stored in sess.cancelRead
sess := &mcpSession{
stdin: stdinW,
stdout: stdoutR,
stderr: stderrR,
readCtx: readCtx,
cancelRead: cancelRead,
pending: make(map[string]chan *sdkjsonrpc.Response),
closed: make(chan struct{}),
}
// Forward stdin to gRPC stream
go func() {
buf := make([]byte, 32*1024)
for {
n, err := stdinR.Read(buf)
if n > 0 {
_ = execStream.SendStdin(buf[:n])
}
if err != nil {
break
}
}
_ = stdinR.Close()
}()
// Forward gRPC stdout/stderr to pipes
go func() {
for {
output, err := execStream.Recv()
if err != nil {
if !errors.Is(err, io.EOF) {
h.logger.Debug("exec stream recv done", slog.Any("error", err))
}
_ = stdoutW.Close()
_ = stderrW.Close()
break
}
switch output.GetStream() {
case pb.ExecOutput_STDOUT:
_, _ = stdoutW.Write(output.GetData())
case pb.ExecOutput_STDERR:
_, _ = stderrW.Write(output.GetData())
case pb.ExecOutput_EXIT:
_ = stdoutW.Close()
_ = stderrW.Close()
return
}
}
}()
transport := &sdkmcp.IOTransport{
Reader: sess.stdout,
Writer: sess.stdin,
}
conn, err := transport.Connect(ctx)
if err != nil {
sess.closeWithError(err)
return nil, err
}
sess.conn = conn
startMCPStderrLogger(sess.stderr, containerID, h.logger)
go sess.readLoop()
go func() {
<-sess.closed
_ = execStream.Close()
}()
return sess, nil
}
func (h *ContainerdHandler) probeMCPTools(ctx context.Context, sess *mcpSession, botID, name string) []string {
if sess == nil {
return nil
}
probeCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
defer cancel()
payload, err := sess.call(probeCtx, mcptools.JSONRPCRequest{
JSONRPC: "2.0",
ID: mcptools.RawStringID("probe-tools"),
Method: "tools/list",
})
if err != nil {
h.logger.Warn("mcp stdio tools probe failed",
slog.String("bot_id", botID),
slog.String("name", name),
slog.Any("error", err),
)
return nil
}
tools := extractToolNames(payload)
if len(tools) == 0 {
h.logger.Warn("mcp stdio tools empty",
slog.String("bot_id", botID),
slog.String("name", name),
)
} else {
h.logger.Info("mcp stdio tools loaded",
slog.String("bot_id", botID),
slog.String("name", name),
slog.Int("count", len(tools)),
)
}
return tools
}