From 3f8cb3292c82c4acfe31dd66fbb24f77bc2d9a79 Mon Sep 17 00:00:00 2001 From: Ran <16112591+chen-ran@users.noreply.github.com> Date: Sun, 8 Feb 2026 01:45:53 +0800 Subject: [PATCH] chore: optimize code structure --- internal/handlers/fs.go | 134 ++++++---------------------------------- internal/mcp/jsonrpc.go | 95 ++++++++++++++++++++++++++++ internal/mcp/manager.go | 44 +++++++------ 3 files changed, 139 insertions(+), 134 deletions(-) create mode 100644 internal/mcp/jsonrpc.go diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go index 60db0dc1..8dbac89e 100644 --- a/internal/handlers/fs.go +++ b/internal/handlers/fs.go @@ -62,53 +62,33 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } if req.JSONRPC != "" && req.JSONRPC != "2.0" { - return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcptools.JSONRPCError{Code: -32600, Message: "invalid jsonrpc version"}, - }) + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32600, "invalid jsonrpc version")) } - if err := h.validateMCPContainer(c.Request().Context(), containerID, botID); err != nil { + if err := h.validateMCPContainer(ctx, containerID, botID); err != nil { h.logger.Error("mcp fs validate failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()}, - }) + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) } - if err := h.ensureTaskRunning(c.Request().Context(), containerID); err != nil { + if err := h.ensureTaskRunning(ctx, containerID); err != nil { h.logger.Error("mcp fs ensure task failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()}, - }) + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) } if strings.TrimSpace(req.Method) == "" { - return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcptools.JSONRPCError{Code: -32601, Message: "method not found"}, - }) + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32601, "method not found")) } - if len(req.ID) == 0 && strings.HasPrefix(req.Method, "notifications/") { - if err := h.notifyMCPServer(c.Request().Context(), containerID, req); err != nil { + if mcptools.IsNotification(req) { + if err := h.notifyMCPServer(ctx, containerID, req); err != nil { h.logger.Error("mcp fs notify failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID)) return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } // MCP Streamable HTTP spec: notifications must be answered with 202 Accepted and no body. return c.NoContent(http.StatusAccepted) } - payload, err := h.callMCPServer(c.Request().Context(), containerID, req) + payload, err := h.callMCPServer(ctx, containerID, req) if err != nil { h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()}, - }) + return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) } return c.JSON(http.StatusOK, payload) } @@ -380,7 +360,7 @@ func (s *mcpSession) readLoop() { } func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { - payloads, targetID, err := buildMCPPayloads(req, &s.initOnce) + payloads, targetID, err := mcptools.BuildPayloads(req, &s.initOnce) if err != nil { return nil, err } @@ -394,14 +374,9 @@ func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map s.pending[target] = respCh s.pendingMu.Unlock() - s.writeMu.Lock() - for _, payload := range payloads { - if _, err := s.stdin.Write([]byte(payload + "\n")); err != nil { - s.writeMu.Unlock() - return nil, err - } + if err := s.writePayloads(payloads); err != nil { + return nil, err } - s.writeMu.Unlock() select { case resp, ok := <-respCh: @@ -437,92 +412,21 @@ func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map } func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error { - payloads, err := buildMCPNotificationPayloads(req) + payloads, err := mcptools.BuildNotificationPayloads(req) if err != nil { return err } + return s.writePayloads(payloads) +} + +func (s *mcpSession) writePayloads(payloads []string) error { s.writeMu.Lock() + defer s.writeMu.Unlock() for _, payload := range payloads { if _, err := s.stdin.Write([]byte(payload + "\n")); err != nil { - s.writeMu.Unlock() return err } } - s.writeMu.Unlock() return nil } -func buildMCPPayloads(req mcptools.JSONRPCRequest, initOnce *sync.Once) ([]string, json.RawMessage, error) { - if req.JSONRPC == "" { - req.JSONRPC = "2.0" - } - targetID := req.ID - payloads := []string{} - shouldInit := req.Method != "initialize" && req.Method != "notifications/initialized" - if initOnce != nil { - ran := false - initOnce.Do(func() { - ran = true - }) - if ran { - // This is the first call on the session. - } else { - shouldInit = false - } - } - if shouldInit { - initReq := map[string]any{ - "jsonrpc": "2.0", - "id": "init-1", - "method": "initialize", - "params": map[string]any{ - "protocolVersion": "2025-06-18", - "capabilities": map[string]any{ - "roots": map[string]any{ - "listChanged": false, - }, - }, - "clientInfo": map[string]any{ - "name": "memoh-http-proxy", - "version": "v0", - }, - }, - } - initBytes, err := json.Marshal(initReq) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(initBytes)) - - initialized := map[string]any{ - "jsonrpc": "2.0", - "method": "notifications/initialized", - } - initializedBytes, err := json.Marshal(initialized) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(initializedBytes)) - } - - reqBytes, err := json.Marshal(req) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(reqBytes)) - return payloads, targetID, nil -} - -func buildMCPNotificationPayloads(req mcptools.JSONRPCRequest) ([]string, error) { - if req.JSONRPC == "" { - req.JSONRPC = "2.0" - } - if strings.TrimSpace(req.Method) == "" { - return nil, fmt.Errorf("missing method") - } - reqBytes, err := json.Marshal(req) - if err != nil { - return nil, err - } - return []string{string(reqBytes)}, nil -} diff --git a/internal/mcp/jsonrpc.go b/internal/mcp/jsonrpc.go new file mode 100644 index 00000000..d6d4933e --- /dev/null +++ b/internal/mcp/jsonrpc.go @@ -0,0 +1,95 @@ +package mcp + +import ( + "encoding/json" + "fmt" + "strings" + "sync" +) + +func IsNotification(req JSONRPCRequest) bool { + return len(req.ID) == 0 && strings.HasPrefix(req.Method, "notifications/") +} + +func JSONRPCErrorResponse(id json.RawMessage, code int, message string) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: "2.0", + ID: id, + Error: &JSONRPCError{Code: code, Message: message}, + } +} + +func BuildPayloads(req JSONRPCRequest, initOnce *sync.Once) ([]string, json.RawMessage, error) { + if req.JSONRPC == "" { + req.JSONRPC = "2.0" + } + targetID := req.ID + payloads := []string{} + shouldInit := req.Method != "initialize" && req.Method != "notifications/initialized" + if initOnce != nil { + ran := false + initOnce.Do(func() { + ran = true + }) + if ran { + // This is the first call on the session. + } else { + shouldInit = false + } + } + if shouldInit { + initReq := map[string]any{ + "jsonrpc": "2.0", + "id": "init-1", + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{ + "roots": map[string]any{ + "listChanged": false, + }, + }, + "clientInfo": map[string]any{ + "name": "memoh-http-proxy", + "version": "v0", + }, + }, + } + initBytes, err := json.Marshal(initReq) + if err != nil { + return nil, nil, err + } + payloads = append(payloads, string(initBytes)) + + initialized := map[string]any{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + } + initializedBytes, err := json.Marshal(initialized) + if err != nil { + return nil, nil, err + } + payloads = append(payloads, string(initializedBytes)) + } + + reqBytes, err := json.Marshal(req) + if err != nil { + return nil, nil, err + } + payloads = append(payloads, string(reqBytes)) + return payloads, targetID, nil +} + +func BuildNotificationPayloads(req JSONRPCRequest) ([]string, error) { + if req.JSONRPC == "" { + req.JSONRPC = "2.0" + } + if strings.TrimSpace(req.Method) == "" { + return nil, fmt.Errorf("missing method") + } + reqBytes, err := json.Marshal(req) + if err != nil { + return nil, err + } + return []string{string(reqBytes)}, nil +} diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index 06e7a925..845b46fe 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -88,15 +88,8 @@ func (m *Manager) EnsureBot(ctx context.Context, botID string) error { return err } - dataMount := m.cfg.DataMount - if dataMount == "" { - dataMount = config.DefaultDataMount - } - - image := m.cfg.BusyboxImage - if image == "" { - image = config.DefaultBusyboxImg - } + dataMount := m.dataMount() + image := m.imageRef() specOpts := []oci.SpecOpts{ oci.WithMounts([]specs.Mount{ @@ -235,25 +228,38 @@ func (m *Manager) DataDir(botID string) (string, error) { return "", err } - root := m.cfg.DataRoot - if root == "" { - root = config.DefaultDataRoot - } - return filepath.Join(root, "bots", botID), nil + return filepath.Join(m.dataRoot(), "bots", botID), nil } func (m *Manager) ensureBotDir(botID string) (string, error) { - root := m.cfg.DataRoot - if root == "" { - root = config.DefaultDataRoot - } - dir := filepath.Join(root, "bots", botID) + dir := filepath.Join(m.dataRoot(), "bots", botID) if err := os.MkdirAll(dir, 0o755); err != nil { return "", err } return dir, nil } +func (m *Manager) dataRoot() string { + if m.cfg.DataRoot == "" { + return config.DefaultDataRoot + } + return m.cfg.DataRoot +} + +func (m *Manager) dataMount() string { + if m.cfg.DataMount == "" { + return config.DefaultDataMount + } + return m.cfg.DataMount +} + +func (m *Manager) imageRef() string { + if m.cfg.BusyboxImage == "" { + return config.DefaultBusyboxImg + } + return m.cfg.BusyboxImage +} + func validateBotID(botID string) error { return identity.ValidateUserID(botID) }