Files
Memoh/internal/handlers/mcp_tools.go

245 lines
7.1 KiB
Go

package handlers
import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/memohai/memoh/internal/auth"
mcpgw "github.com/memohai/memoh/internal/mcp"
)
// Request header names consumed by buildToolSessionContext to populate the
// tool session attached to each MCP tools request.
const (
// headerChannelIdentityID carries the caller's channel identity ID; when it
// is empty the authenticated user ID from the request context is used.
headerChannelIdentityID = "X-Memoh-Channel-Identity-Id"
// headerSessionToken carries the session token copied into the tool session.
headerSessionToken = "X-Memoh-Session-Token" //nolint:gosec // G101: this is an HTTP header name, not a hardcoded credential
// headerCurrentPlatform carries the value stored as the session's CurrentPlatform.
headerCurrentPlatform = "X-Memoh-Current-Platform"
// headerReplyTarget carries the value stored as the session's ReplyTarget.
headerReplyTarget = "X-Memoh-Reply-Target"
// headerIsSubagent marks the session as a subagent when its value equals
// "true" (compared case-insensitively).
headerIsSubagent = "X-Memoh-Is-Subagent"
)
// SetToolGatewayService installs the tool gateway used by the MCP tools
// endpoint. While it is unset (nil), HandleMCPTools answers with
// 503 Service Unavailable.
func (h *ContainerdHandler) SetToolGatewayService(service *mcpgw.ToolGatewayService) {
h.toolGateway = service
}
// HandleMCPTools godoc
// @Summary Unified MCP tools gateway
// @Description MCP endpoint for tool discovery and invocation.
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param payload body object true "JSON-RPC request"
// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}"
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/tools [post].
func (h *ContainerdHandler) HandleMCPTools(c echo.Context) error {
	// Without a configured gateway there is nothing to dispatch to.
	if h.toolGateway == nil {
		return echo.NewHTTPError(http.StatusServiceUnavailable, "tool gateway not configured")
	}

	// Resolve the bot ID and enforce access (guest access included); any
	// failure is returned to echo unchanged.
	botID, accessErr := h.requireBotAccessWithGuest(c)
	if accessErr == nil {
		return h.handleMCPToolsWithBotID(c, botID)
	}
	return accessErr
}
// handleMCPToolsWithBotID serves one MCP request for the given bot.
//
// It derives the tool session from the incoming request, normalises the
// Accept header, stores the session in the request context under
// toolSessionContextKey, and hands the request to a freshly built streamable
// HTTP handler. The handler writes directly to the underlying response
// writer, so this method always returns nil.
func (h *ContainerdHandler) handleMCPToolsWithBotID(c echo.Context, botID string) error {
	session := h.buildToolSessionContext(c, botID)

	req := c.Request()
	ensureStreamableAcceptHeader(req)
	req = req.WithContext(context.WithValue(req.Context(), toolSessionContextKey{}, session))

	// The server is built per request from the request context, which is
	// where the session was stored above.
	serverFor := func(r *http.Request) *sdkmcp.Server {
		return h.buildToolMCPServer(r.Context())
	}
	opts := &sdkmcp.StreamableHTTPOptions{
		Stateless:    true,
		JSONResponse: true,
		Logger:       h.logger,
	}

	sdkmcp.NewStreamableHTTPHandler(serverFor, opts).ServeHTTP(c.Response().Writer, req)
	return nil
}
// ensureStreamableAcceptHeader rewrites req's Accept header, when needed, so
// that it advertises both a JSON and an event-stream media type.
//
// Detection is a case-insensitive substring match over all Accept values
// joined with ",": JSON counts as accepted when the header mentions
// "application/json", "application/*" or "*/*"; event streams when it mentions
// "text/event-stream", "text/*" or "*/*". If both are already accepted the
// header is left untouched. Otherwise the original value (if any) is kept and
// the missing media types are appended, separated by ", ".
//
// A nil request is ignored. A request whose Header map is nil gets a fresh
// header map instead of panicking on the write.
func ensureStreamableAcceptHeader(req *http.Request) {
	if req == nil {
		return
	}

	raw := strings.Join(req.Header.Values("Accept"), ",")
	lowered := strings.ToLower(raw)

	acceptsAny := strings.Contains(lowered, "*/*")
	hasJSON := acceptsAny ||
		strings.Contains(lowered, "application/json") ||
		strings.Contains(lowered, "application/*")
	hasStream := acceptsAny ||
		strings.Contains(lowered, "text/event-stream") ||
		strings.Contains(lowered, "text/*")
	if hasJSON && hasStream {
		return
	}

	// At least one media type is missing past this point, so parts is never
	// empty when it is joined below.
	parts := make([]string, 0, 3)
	if base := strings.TrimSpace(raw); base != "" {
		parts = append(parts, base)
	}
	if !hasJSON {
		parts = append(parts, "application/json")
	}
	if !hasStream {
		parts = append(parts, "text/event-stream")
	}

	// Header.Set writes into the map, which panics on a nil map.
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	req.Header.Set("Accept", strings.Join(parts, ", "))
}
// toolSessionContextKey is the context key under which the per-request
// mcpgw.ToolSessionContext is stored by handleMCPToolsWithBotID and read back
// by buildToolMCPServer.
type toolSessionContextKey struct{}
// buildToolMCPServer creates the MCP server that answers a single request.
//
// It returns nil when no tool gateway is configured or when ctx does not
// carry a mcpgw.ToolSessionContext under toolSessionContextKey. Otherwise it
// returns a server named "memoh-tools-gateway" that declares the tools
// capability and has the gateway middleware for that session installed.
func (h *ContainerdHandler) buildToolMCPServer(ctx context.Context) *sdkmcp.Server {
	if h.toolGateway == nil {
		return nil
	}
	session, found := ctx.Value(toolSessionContextKey{}).(mcpgw.ToolSessionContext)
	if !found {
		return nil
	}

	impl := &sdkmcp.Implementation{
		Name:    "memoh-tools-gateway",
		Version: "1.0.0",
	}
	opts := &sdkmcp.ServerOptions{
		Capabilities: &sdkmcp.ServerCapabilities{
			Tools: &sdkmcp.ToolCapabilities{ListChanged: false},
		},
	}

	srv := sdkmcp.NewServer(impl, opts)
	srv.AddReceivingMiddleware(h.toolGatewayMiddleware(session))
	return srv
}
// toolGatewayMiddleware returns receiving middleware that answers
// "tools/list" and "tools/call" from the tool gateway on behalf of session.
// Every other method is forwarded to the next handler unchanged.
func (h *ContainerdHandler) toolGatewayMiddleware(session mcpgw.ToolSessionContext) sdkmcp.Middleware {
	return func(next sdkmcp.MethodHandler) sdkmcp.MethodHandler {
		return func(ctx context.Context, method string, req sdkmcp.Request) (sdkmcp.Result, error) {
			name := strings.TrimSpace(method)

			if name == "tools/list" {
				descriptors, listErr := h.toolGateway.ListTools(ctx, session)
				if listErr != nil {
					return nil, listErr
				}
				return &sdkmcp.ListToolsResult{Tools: convertGatewayToolsToSDK(descriptors)}, nil
			}

			if name != "tools/call" {
				// Not a tools method: let the rest of the chain handle it.
				return next(ctx, method, req)
			}

			// tools/call: the request must carry raw call parameters.
			call, isCall := req.(*sdkmcp.ServerRequest[*sdkmcp.CallToolParamsRaw])
			if !isCall || call == nil || call.Params == nil {
				return nil, errors.New("tools/call params is required")
			}
			payload, payloadErr := buildToolCallPayloadFromRaw(call.Params)
			if payloadErr != nil {
				return nil, payloadErr
			}
			outcome, callErr := h.toolGateway.CallTool(ctx, session, payload)
			if callErr != nil {
				return nil, callErr
			}
			return convertGatewayCallResultToSDK(outcome)
		}
	}
}
// buildToolCallPayloadFromRaw converts raw tools/call parameters into the
// gateway's payload type.
//
// The tool name is whitespace-trimmed and must be non-empty. Arguments are
// decoded from JSON into a map; absent arguments, or a decode that leaves the
// map nil, yield an empty non-nil map. It returns an error when params is nil,
// the name is empty, or the arguments fail to decode.
func buildToolCallPayloadFromRaw(params *sdkmcp.CallToolParamsRaw) (mcpgw.ToolCallPayload, error) {
	var none mcpgw.ToolCallPayload

	if params == nil {
		return none, errors.New("tools/call params is required")
	}
	toolName := strings.TrimSpace(params.Name)
	if toolName == "" {
		return none, errors.New("tools/call name is required")
	}

	var args map[string]any
	if len(params.Arguments) != 0 {
		if err := json.Unmarshal(params.Arguments, &args); err != nil {
			return none, err
		}
	}
	// Callers always receive a usable map, even for missing or null arguments.
	if args == nil {
		args = map[string]any{}
	}

	return mcpgw.ToolCallPayload{Name: toolName, Arguments: args}, nil
}
// convertGatewayToolsToSDK maps gateway tool descriptors to SDK tool values.
//
// Names and descriptions are whitespace-trimmed, descriptors whose trimmed
// name is empty are skipped, and a nil input schema is replaced by an empty
// object schema. The returned slice is never nil.
func convertGatewayToolsToSDK(items []mcpgw.ToolDescriptor) []*sdkmcp.Tool {
	converted := make([]*sdkmcp.Tool, 0, len(items))
	for i := range items {
		toolName := strings.TrimSpace(items[i].Name)
		if toolName == "" {
			continue
		}

		schema := items[i].InputSchema
		if schema == nil {
			// Default: an object schema with no declared properties.
			schema = map[string]any{
				"type":       "object",
				"properties": map[string]any{},
			}
		}

		converted = append(converted, &sdkmcp.Tool{
			Name:        toolName,
			Description: strings.TrimSpace(items[i].Description),
			InputSchema: schema,
		})
	}
	return converted
}
// convertGatewayCallResultToSDK turns a gateway call result into an SDK
// CallToolResult by round-tripping it through JSON. A nil result is replaced
// by the gateway's success result for {"ok": true}. Marshal and unmarshal
// errors are returned as-is.
func convertGatewayCallResultToSDK(result map[string]any) (*sdkmcp.CallToolResult, error) {
	if result == nil {
		result = mcpgw.BuildToolSuccessResult(map[string]any{"ok": true})
	}

	encoded, err := json.Marshal(result)
	if err != nil {
		return nil, err
	}

	converted := new(sdkmcp.CallToolResult)
	if err = json.Unmarshal(encoded, converted); err != nil {
		return nil, err
	}
	return converted, nil
}
// buildToolSessionContext assembles the tool session for a request.
//
// All values are whitespace-trimmed. The channel identity comes from the
// X-Memoh-Channel-Identity-Id header and falls back to the authenticated user
// ID from the echo context when the header is empty and that lookup succeeds.
// The bot ID is used for both BotID and ChatID. IsSubagent is set when the
// X-Memoh-Is-Subagent header equals "true", ignoring case.
func (*ContainerdHandler) buildToolSessionContext(c echo.Context, botID string) mcpgw.ToolSessionContext {
	headers := c.Request().Header
	trimmedHeader := func(name string) string {
		return strings.TrimSpace(headers.Get(name))
	}

	identityID := trimmedHeader(headerChannelIdentityID)
	if identityID == "" {
		if userID, err := auth.UserIDFromContext(c); err == nil {
			identityID = strings.TrimSpace(userID)
		}
	}

	trimmedBotID := strings.TrimSpace(botID)
	return mcpgw.ToolSessionContext{
		BotID:             trimmedBotID,
		ChatID:            trimmedBotID,
		ChannelIdentityID: identityID,
		SessionToken:      trimmedHeader(headerSessionToken),
		CurrentPlatform:   trimmedHeader(headerCurrentPlatform),
		ReplyTarget:       trimmedHeader(headerReplyTarget),
		IsSubagent:        strings.EqualFold(trimmedHeader(headerIsSubagent), "true"),
	}
}