Files
Memoh/internal/handlers/mcp_federation_gateway.go
T

501 lines
13 KiB
Go

package handlers
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
mcpgw "github.com/memohai/memoh/internal/mcp"
)
type MCPFederationGateway struct {
handler *ContainerdHandler
logger *slog.Logger
client *http.Client
oauthService *mcpgw.OAuthService
}
func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPFederationGateway {
if log == nil {
log = slog.Default()
}
return &MCPFederationGateway{
handler: handler,
logger: log.With(slog.String("gateway", "mcp_federation")),
client: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// SetOAuthService injects the OAuth service for token-based authentication.
func (g *MCPFederationGateway) SetOAuthService(svc *mcpgw.OAuthService) {
g.oauthService = svc
}
func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
session, err := g.connectStreamableSession(ctx, connection)
if err != nil {
return nil, err
}
defer func() { _ = session.Close() }()
result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{})
if err != nil {
return nil, err
}
return convertSDKTools(result.Tools), nil
}
func (g *MCPFederationGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
session, err := g.connectStreamableSession(ctx, connection)
if err != nil {
return nil, err
}
defer func() { _ = session.Close() }()
result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{
Name: strings.TrimSpace(toolName),
Arguments: args,
})
if err != nil {
return nil, err
}
return wrapSDKToolResult(result)
}
func (g *MCPFederationGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
session, err := g.connectSSESession(ctx, connection)
if err != nil {
return nil, err
}
defer func() { _ = session.Close() }()
result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{})
if err != nil {
return nil, err
}
return convertSDKTools(result.Tools), nil
}
func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
session, err := g.connectSSESession(ctx, connection)
if err != nil {
return nil, err
}
defer func() { _ = session.Close() }()
result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{
Name: strings.TrimSpace(toolName),
Arguments: args,
})
if err != nil {
return nil, err
}
return wrapSDKToolResult(result)
}
func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) {
url := strings.TrimSpace(anyToString(connection.Config["url"]))
if url == "" {
return nil, errors.New("http mcp url is required")
}
client := sdkmcp.NewClient(&sdkmcp.Implementation{
Name: "memoh-federation-client",
Version: "v1",
}, nil)
transport := &sdkmcp.StreamableClientTransport{
Endpoint: url,
HTTPClient: g.connectionHTTPClient(ctx, connection),
MaxRetries: -1,
}
return client.Connect(ctx, transport, nil)
}
func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) {
endpoints := resolveSSEEndpointCandidates(connection.Config)
if len(endpoints) == 0 {
return nil, errors.New("sse mcp url is required")
}
var lastErr error
for _, endpoint := range endpoints {
client := sdkmcp.NewClient(&sdkmcp.Implementation{
Name: "memoh-federation-client",
Version: "v1",
}, nil)
transport := &sdkmcp.SSEClientTransport{
Endpoint: endpoint,
HTTPClient: g.connectionHTTPClient(ctx, connection),
}
session, err := client.Connect(ctx, transport, nil)
if err == nil {
return session, nil
}
lastErr = err
}
if lastErr == nil {
lastErr = errors.New("no sse endpoint candidate available")
}
return nil, fmt.Errorf("connect sse mcp failed: %w", lastErr)
}
func resolveSSEEndpointCandidates(config map[string]any) []string {
if config == nil {
return []string{}
}
seen := map[string]struct{}{}
out := make([]string, 0, 4)
appendEndpoint := func(value string) {
value = strings.TrimSpace(value)
if value == "" {
return
}
if _, ok := seen[value]; ok {
return
}
seen[value] = struct{}{}
out = append(out, value)
}
for _, key := range []string{"sse_url", "sseUrl"} {
appendEndpoint(anyToString(config[key]))
}
baseURL := strings.TrimSpace(anyToString(config["url"]))
appendEndpoint(baseURL)
var messageURL string
for _, key := range []string{"message_url", "messageUrl"} {
if value := strings.TrimSpace(anyToString(config[key])); value != "" {
messageURL = value
break
}
}
if messageURL != "" {
normalized := strings.TrimSuffix(messageURL, "/")
if strings.HasSuffix(normalized, "/message") {
appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse")
}
appendEndpoint(messageURL)
}
if baseURL != "" {
normalized := strings.TrimSuffix(baseURL, "/")
if strings.HasSuffix(normalized, "/message") {
appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse")
}
}
return out
}
func (g *MCPFederationGateway) connectionHTTPClient(ctx context.Context, connection mcpgw.Connection) *http.Client {
base := g.client
if base == nil {
base = &http.Client{Timeout: 30 * time.Second}
}
headers := normalizeHeaderMap(connection.Config["headers"])
if strings.TrimSpace(connection.AuthType) == "oauth" && g.oauthService != nil {
token, err := g.oauthService.GetValidToken(ctx, connection.ID)
if err != nil {
g.logger.Warn("failed to get OAuth token for connection",
slog.String("connection_id", connection.ID),
slog.Any("error", err))
} else if token != "" {
if headers == nil {
headers = map[string]string{}
}
headers["Authorization"] = "Bearer " + token
}
}
if len(headers) == 0 {
return base
}
transport := base.Transport
if transport == nil {
transport = http.DefaultTransport
}
return &http.Client{
Timeout: base.Timeout,
CheckRedirect: base.CheckRedirect,
Jar: base.Jar,
Transport: &staticHeaderRoundTripper{
next: transport,
headers: headers,
},
}
}
func (g *MCPFederationGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
sess, err := g.startStdioConnectionSession(ctx, botID, connection)
if err != nil {
return nil, err
}
defer sess.closeWithError(io.EOF)
payload, err := sess.call(ctx, mcpgw.JSONRPCRequest{
JSONRPC: "2.0",
ID: mcpgw.RawStringID("federated-stdio-tools-list"),
Method: "tools/list",
})
if err != nil {
return nil, err
}
return parseGatewayToolsListPayload(payload)
}
func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
sess, err := g.startStdioConnectionSession(ctx, botID, connection)
if err != nil {
return nil, err
}
defer sess.closeWithError(io.EOF)
params, err := json.Marshal(map[string]any{
"name": toolName,
"arguments": args,
})
if err != nil {
return nil, err
}
return sess.call(ctx, mcpgw.JSONRPCRequest{
JSONRPC: "2.0",
ID: mcpgw.RawStringID("federated-stdio-tools-call"),
Method: "tools/call",
Params: params,
})
}
func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, botID string, connection mcpgw.Connection) (*mcpSession, error) {
if g.handler == nil {
return nil, errors.New("containerd handler not configured")
}
containerID, err := g.handler.botContainerID(ctx, botID)
if err != nil {
return nil, err
}
if err := g.handler.ensureContainerAndTask(ctx, containerID, botID); err != nil {
return nil, err
}
command := strings.TrimSpace(anyToString(connection.Config["command"]))
if command == "" {
return nil, errors.New("stdio mcp command is required")
}
request := MCPStdioRequest{
Name: strings.TrimSpace(connection.Name),
Command: command,
Args: normalizeStringSlice(connection.Config["args"]),
Env: normalizeStringMap(connection.Config["env"]),
Cwd: strings.TrimSpace(anyToString(connection.Config["cwd"])),
}
return g.handler.startContainerdMCPCommandSession(ctx, containerID, request)
}
func parseGatewayToolsListPayload(payload map[string]any) ([]mcpgw.ToolDescriptor, error) {
if err := mcpgw.PayloadError(payload); err != nil {
return nil, err
}
result, ok := payload["result"].(map[string]any)
if !ok {
return nil, errors.New("invalid tools/list result")
}
rawTools, ok := result["tools"].([]any)
if !ok {
return nil, errors.New("invalid tools/list tools field")
}
tools := make([]mcpgw.ToolDescriptor, 0, len(rawTools))
for _, rawTool := range rawTools {
item, ok := rawTool.(map[string]any)
if !ok {
continue
}
name := strings.TrimSpace(anyToString(item["name"]))
if name == "" {
continue
}
description := strings.TrimSpace(anyToString(item["description"]))
inputSchema, _ := item["inputSchema"].(map[string]any)
if inputSchema == nil {
inputSchema = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
tools = append(tools, mcpgw.ToolDescriptor{
Name: name,
Description: description,
InputSchema: inputSchema,
})
}
return tools, nil
}
func convertSDKTools(items []*sdkmcp.Tool) []mcpgw.ToolDescriptor {
if len(items) == 0 {
return []mcpgw.ToolDescriptor{}
}
tools := make([]mcpgw.ToolDescriptor, 0, len(items))
for _, item := range items {
if item == nil {
continue
}
name := strings.TrimSpace(item.Name)
if name == "" {
continue
}
tools = append(tools, mcpgw.ToolDescriptor{
Name: name,
Description: strings.TrimSpace(item.Description),
InputSchema: normalizeToolInputSchema(item.InputSchema),
})
}
return tools
}
func normalizeToolInputSchema(raw any) map[string]any {
if schema, ok := raw.(map[string]any); ok && schema != nil {
return schema
}
if raw != nil {
payload, err := json.Marshal(raw)
if err == nil {
var schema map[string]any
if err := json.Unmarshal(payload, &schema); err == nil && schema != nil {
return schema
}
}
}
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func wrapSDKToolResult(result *sdkmcp.CallToolResult) (map[string]any, error) {
if result == nil {
return map[string]any{
"result": mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}),
}, nil
}
payload, err := json.Marshal(result)
if err != nil {
return nil, err
}
var parsed map[string]any
if err := json.Unmarshal(payload, &parsed); err != nil {
return nil, err
}
if parsed == nil {
parsed = map[string]any{}
}
return map[string]any{"result": parsed}, nil
}
func normalizeHeaderMap(raw any) map[string]string {
switch value := raw.(type) {
case map[string]string:
return value
case map[string]any:
out := make(map[string]string, len(value))
for k, v := range value {
key := strings.TrimSpace(k)
val := strings.TrimSpace(anyToString(v))
if key == "" || val == "" {
continue
}
out[key] = val
}
return out
default:
return map[string]string{}
}
}
func normalizeStringSlice(raw any) []string {
switch value := raw.(type) {
case []string:
out := make([]string, 0, len(value))
for _, item := range value {
item = strings.TrimSpace(item)
if item != "" {
out = append(out, item)
}
}
return out
case []any:
out := make([]string, 0, len(value))
for _, item := range value {
val := strings.TrimSpace(anyToString(item))
if val != "" {
out = append(out, val)
}
}
return out
default:
return []string{}
}
}
func normalizeStringMap(raw any) map[string]string {
switch value := raw.(type) {
case map[string]string:
return value
case map[string]any:
out := make(map[string]string, len(value))
for k, v := range value {
key := strings.TrimSpace(k)
val := strings.TrimSpace(anyToString(v))
if key == "" {
continue
}
out[key] = val
}
return out
default:
return map[string]string{}
}
}
func anyToString(v any) string {
if v == nil {
return ""
}
switch value := v.(type) {
case string:
return value
default:
return fmt.Sprintf("%v", v)
}
}
type staticHeaderRoundTripper struct {
next http.RoundTripper
headers map[string]string
}
func (t *staticHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
next := t.next
if next == nil {
next = http.DefaultTransport
}
clone := req.Clone(req.Context())
clone.Header = req.Header.Clone()
for key, value := range t.headers {
headerKey := strings.TrimSpace(key)
headerVal := strings.TrimSpace(value)
if headerKey == "" || headerVal == "" {
continue
}
clone.Header.Set(headerKey, headerVal)
}
return next.RoundTrip(clone)
}