mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
501 lines
13 KiB
Go
501 lines
13 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
mcpgw "github.com/memohai/memoh/internal/mcp"
|
|
)
|
|
|
|
type MCPFederationGateway struct {
|
|
handler *ContainerdHandler
|
|
logger *slog.Logger
|
|
client *http.Client
|
|
oauthService *mcpgw.OAuthService
|
|
}
|
|
|
|
func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPFederationGateway {
|
|
if log == nil {
|
|
log = slog.Default()
|
|
}
|
|
return &MCPFederationGateway{
|
|
handler: handler,
|
|
logger: log.With(slog.String("gateway", "mcp_federation")),
|
|
client: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
// SetOAuthService injects the OAuth service for token-based authentication.
|
|
func (g *MCPFederationGateway) SetOAuthService(svc *mcpgw.OAuthService) {
|
|
g.oauthService = svc
|
|
}
|
|
|
|
func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
|
|
session, err := g.connectStreamableSession(ctx, connection)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = session.Close() }()
|
|
result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return convertSDKTools(result.Tools), nil
|
|
}
|
|
|
|
func (g *MCPFederationGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
|
|
session, err := g.connectStreamableSession(ctx, connection)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = session.Close() }()
|
|
result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{
|
|
Name: strings.TrimSpace(toolName),
|
|
Arguments: args,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return wrapSDKToolResult(result)
|
|
}
|
|
|
|
func (g *MCPFederationGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
|
|
session, err := g.connectSSESession(ctx, connection)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = session.Close() }()
|
|
result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return convertSDKTools(result.Tools), nil
|
|
}
|
|
|
|
func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
|
|
session, err := g.connectSSESession(ctx, connection)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = session.Close() }()
|
|
result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{
|
|
Name: strings.TrimSpace(toolName),
|
|
Arguments: args,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return wrapSDKToolResult(result)
|
|
}
|
|
|
|
func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) {
|
|
url := strings.TrimSpace(anyToString(connection.Config["url"]))
|
|
if url == "" {
|
|
return nil, errors.New("http mcp url is required")
|
|
}
|
|
client := sdkmcp.NewClient(&sdkmcp.Implementation{
|
|
Name: "memoh-federation-client",
|
|
Version: "v1",
|
|
}, nil)
|
|
transport := &sdkmcp.StreamableClientTransport{
|
|
Endpoint: url,
|
|
HTTPClient: g.connectionHTTPClient(ctx, connection),
|
|
MaxRetries: -1,
|
|
}
|
|
return client.Connect(ctx, transport, nil)
|
|
}
|
|
|
|
func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) {
|
|
endpoints := resolveSSEEndpointCandidates(connection.Config)
|
|
if len(endpoints) == 0 {
|
|
return nil, errors.New("sse mcp url is required")
|
|
}
|
|
var lastErr error
|
|
for _, endpoint := range endpoints {
|
|
client := sdkmcp.NewClient(&sdkmcp.Implementation{
|
|
Name: "memoh-federation-client",
|
|
Version: "v1",
|
|
}, nil)
|
|
transport := &sdkmcp.SSEClientTransport{
|
|
Endpoint: endpoint,
|
|
HTTPClient: g.connectionHTTPClient(ctx, connection),
|
|
}
|
|
session, err := client.Connect(ctx, transport, nil)
|
|
if err == nil {
|
|
return session, nil
|
|
}
|
|
lastErr = err
|
|
}
|
|
if lastErr == nil {
|
|
lastErr = errors.New("no sse endpoint candidate available")
|
|
}
|
|
return nil, fmt.Errorf("connect sse mcp failed: %w", lastErr)
|
|
}
|
|
|
|
func resolveSSEEndpointCandidates(config map[string]any) []string {
|
|
if config == nil {
|
|
return []string{}
|
|
}
|
|
|
|
seen := map[string]struct{}{}
|
|
out := make([]string, 0, 4)
|
|
appendEndpoint := func(value string) {
|
|
value = strings.TrimSpace(value)
|
|
if value == "" {
|
|
return
|
|
}
|
|
if _, ok := seen[value]; ok {
|
|
return
|
|
}
|
|
seen[value] = struct{}{}
|
|
out = append(out, value)
|
|
}
|
|
|
|
for _, key := range []string{"sse_url", "sseUrl"} {
|
|
appendEndpoint(anyToString(config[key]))
|
|
}
|
|
|
|
baseURL := strings.TrimSpace(anyToString(config["url"]))
|
|
appendEndpoint(baseURL)
|
|
|
|
var messageURL string
|
|
for _, key := range []string{"message_url", "messageUrl"} {
|
|
if value := strings.TrimSpace(anyToString(config[key])); value != "" {
|
|
messageURL = value
|
|
break
|
|
}
|
|
}
|
|
if messageURL != "" {
|
|
normalized := strings.TrimSuffix(messageURL, "/")
|
|
if strings.HasSuffix(normalized, "/message") {
|
|
appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse")
|
|
}
|
|
appendEndpoint(messageURL)
|
|
}
|
|
|
|
if baseURL != "" {
|
|
normalized := strings.TrimSuffix(baseURL, "/")
|
|
if strings.HasSuffix(normalized, "/message") {
|
|
appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse")
|
|
}
|
|
}
|
|
|
|
return out
|
|
}
|
|
|
|
func (g *MCPFederationGateway) connectionHTTPClient(ctx context.Context, connection mcpgw.Connection) *http.Client {
|
|
base := g.client
|
|
if base == nil {
|
|
base = &http.Client{Timeout: 30 * time.Second}
|
|
}
|
|
headers := normalizeHeaderMap(connection.Config["headers"])
|
|
|
|
if strings.TrimSpace(connection.AuthType) == "oauth" && g.oauthService != nil {
|
|
token, err := g.oauthService.GetValidToken(ctx, connection.ID)
|
|
if err != nil {
|
|
g.logger.Warn("failed to get OAuth token for connection",
|
|
slog.String("connection_id", connection.ID),
|
|
slog.Any("error", err))
|
|
} else if token != "" {
|
|
if headers == nil {
|
|
headers = map[string]string{}
|
|
}
|
|
headers["Authorization"] = "Bearer " + token
|
|
}
|
|
}
|
|
|
|
if len(headers) == 0 {
|
|
return base
|
|
}
|
|
transport := base.Transport
|
|
if transport == nil {
|
|
transport = http.DefaultTransport
|
|
}
|
|
return &http.Client{
|
|
Timeout: base.Timeout,
|
|
CheckRedirect: base.CheckRedirect,
|
|
Jar: base.Jar,
|
|
Transport: &staticHeaderRoundTripper{
|
|
next: transport,
|
|
headers: headers,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (g *MCPFederationGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
|
|
sess, err := g.startStdioConnectionSession(ctx, botID, connection)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer sess.closeWithError(io.EOF)
|
|
|
|
payload, err := sess.call(ctx, mcpgw.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcpgw.RawStringID("federated-stdio-tools-list"),
|
|
Method: "tools/list",
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return parseGatewayToolsListPayload(payload)
|
|
}
|
|
|
|
func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
|
|
sess, err := g.startStdioConnectionSession(ctx, botID, connection)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer sess.closeWithError(io.EOF)
|
|
|
|
params, err := json.Marshal(map[string]any{
|
|
"name": toolName,
|
|
"arguments": args,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return sess.call(ctx, mcpgw.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcpgw.RawStringID("federated-stdio-tools-call"),
|
|
Method: "tools/call",
|
|
Params: params,
|
|
})
|
|
}
|
|
|
|
func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, botID string, connection mcpgw.Connection) (*mcpSession, error) {
|
|
if g.handler == nil {
|
|
return nil, errors.New("containerd handler not configured")
|
|
}
|
|
containerID, err := g.handler.botContainerID(ctx, botID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := g.handler.ensureContainerAndTask(ctx, containerID, botID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
command := strings.TrimSpace(anyToString(connection.Config["command"]))
|
|
if command == "" {
|
|
return nil, errors.New("stdio mcp command is required")
|
|
}
|
|
request := MCPStdioRequest{
|
|
Name: strings.TrimSpace(connection.Name),
|
|
Command: command,
|
|
Args: normalizeStringSlice(connection.Config["args"]),
|
|
Env: normalizeStringMap(connection.Config["env"]),
|
|
Cwd: strings.TrimSpace(anyToString(connection.Config["cwd"])),
|
|
}
|
|
return g.handler.startContainerdMCPCommandSession(ctx, containerID, request)
|
|
}
|
|
|
|
func parseGatewayToolsListPayload(payload map[string]any) ([]mcpgw.ToolDescriptor, error) {
|
|
if err := mcpgw.PayloadError(payload); err != nil {
|
|
return nil, err
|
|
}
|
|
result, ok := payload["result"].(map[string]any)
|
|
if !ok {
|
|
return nil, errors.New("invalid tools/list result")
|
|
}
|
|
rawTools, ok := result["tools"].([]any)
|
|
if !ok {
|
|
return nil, errors.New("invalid tools/list tools field")
|
|
}
|
|
tools := make([]mcpgw.ToolDescriptor, 0, len(rawTools))
|
|
for _, rawTool := range rawTools {
|
|
item, ok := rawTool.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
name := strings.TrimSpace(anyToString(item["name"]))
|
|
if name == "" {
|
|
continue
|
|
}
|
|
description := strings.TrimSpace(anyToString(item["description"]))
|
|
inputSchema, _ := item["inputSchema"].(map[string]any)
|
|
if inputSchema == nil {
|
|
inputSchema = map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
tools = append(tools, mcpgw.ToolDescriptor{
|
|
Name: name,
|
|
Description: description,
|
|
InputSchema: inputSchema,
|
|
})
|
|
}
|
|
return tools, nil
|
|
}
|
|
|
|
func convertSDKTools(items []*sdkmcp.Tool) []mcpgw.ToolDescriptor {
|
|
if len(items) == 0 {
|
|
return []mcpgw.ToolDescriptor{}
|
|
}
|
|
tools := make([]mcpgw.ToolDescriptor, 0, len(items))
|
|
for _, item := range items {
|
|
if item == nil {
|
|
continue
|
|
}
|
|
name := strings.TrimSpace(item.Name)
|
|
if name == "" {
|
|
continue
|
|
}
|
|
tools = append(tools, mcpgw.ToolDescriptor{
|
|
Name: name,
|
|
Description: strings.TrimSpace(item.Description),
|
|
InputSchema: normalizeToolInputSchema(item.InputSchema),
|
|
})
|
|
}
|
|
return tools
|
|
}
|
|
|
|
func normalizeToolInputSchema(raw any) map[string]any {
|
|
if schema, ok := raw.(map[string]any); ok && schema != nil {
|
|
return schema
|
|
}
|
|
if raw != nil {
|
|
payload, err := json.Marshal(raw)
|
|
if err == nil {
|
|
var schema map[string]any
|
|
if err := json.Unmarshal(payload, &schema); err == nil && schema != nil {
|
|
return schema
|
|
}
|
|
}
|
|
}
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func wrapSDKToolResult(result *sdkmcp.CallToolResult) (map[string]any, error) {
|
|
if result == nil {
|
|
return map[string]any{
|
|
"result": mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}),
|
|
}, nil
|
|
}
|
|
payload, err := json.Marshal(result)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(payload, &parsed); err != nil {
|
|
return nil, err
|
|
}
|
|
if parsed == nil {
|
|
parsed = map[string]any{}
|
|
}
|
|
return map[string]any{"result": parsed}, nil
|
|
}
|
|
|
|
func normalizeHeaderMap(raw any) map[string]string {
|
|
switch value := raw.(type) {
|
|
case map[string]string:
|
|
return value
|
|
case map[string]any:
|
|
out := make(map[string]string, len(value))
|
|
for k, v := range value {
|
|
key := strings.TrimSpace(k)
|
|
val := strings.TrimSpace(anyToString(v))
|
|
if key == "" || val == "" {
|
|
continue
|
|
}
|
|
out[key] = val
|
|
}
|
|
return out
|
|
default:
|
|
return map[string]string{}
|
|
}
|
|
}
|
|
|
|
func normalizeStringSlice(raw any) []string {
|
|
switch value := raw.(type) {
|
|
case []string:
|
|
out := make([]string, 0, len(value))
|
|
for _, item := range value {
|
|
item = strings.TrimSpace(item)
|
|
if item != "" {
|
|
out = append(out, item)
|
|
}
|
|
}
|
|
return out
|
|
case []any:
|
|
out := make([]string, 0, len(value))
|
|
for _, item := range value {
|
|
val := strings.TrimSpace(anyToString(item))
|
|
if val != "" {
|
|
out = append(out, val)
|
|
}
|
|
}
|
|
return out
|
|
default:
|
|
return []string{}
|
|
}
|
|
}
|
|
|
|
func normalizeStringMap(raw any) map[string]string {
|
|
switch value := raw.(type) {
|
|
case map[string]string:
|
|
return value
|
|
case map[string]any:
|
|
out := make(map[string]string, len(value))
|
|
for k, v := range value {
|
|
key := strings.TrimSpace(k)
|
|
val := strings.TrimSpace(anyToString(v))
|
|
if key == "" {
|
|
continue
|
|
}
|
|
out[key] = val
|
|
}
|
|
return out
|
|
default:
|
|
return map[string]string{}
|
|
}
|
|
}
|
|
|
|
func anyToString(v any) string {
|
|
if v == nil {
|
|
return ""
|
|
}
|
|
switch value := v.(type) {
|
|
case string:
|
|
return value
|
|
default:
|
|
return fmt.Sprintf("%v", v)
|
|
}
|
|
}
|
|
|
|
type staticHeaderRoundTripper struct {
|
|
next http.RoundTripper
|
|
headers map[string]string
|
|
}
|
|
|
|
func (t *staticHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
next := t.next
|
|
if next == nil {
|
|
next = http.DefaultTransport
|
|
}
|
|
clone := req.Clone(req.Context())
|
|
clone.Header = req.Header.Clone()
|
|
for key, value := range t.headers {
|
|
headerKey := strings.TrimSpace(key)
|
|
headerVal := strings.TrimSpace(value)
|
|
if headerKey == "" || headerVal == "" {
|
|
continue
|
|
}
|
|
clone.Header.Set(headerKey, headerVal)
|
|
}
|
|
return next.RoundTrip(clone)
|
|
}
|