Files
Memoh/internal/handlers/mcp_federation_gateway.go
T
Ran 6acdd191c7 Squashed commit of the following:
commit bcdb026ae43e4f95d0b2c4f9bd440a2df9d6b514
Author: Ran <16112591+chen-ran@users.noreply.github.com>
Date:   Thu Feb 12 17:10:32 2026 +0800

    chore: update DEVELOPMENT.md

commit 30281742ef
Merge: ca5c6a1 5b05f13
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Thu Feb 12 15:49:17 2026 +0800

    merge(github/main): integrate fx dependency injection framework

    Merge upstream fx refactor and adapt all services to use go.uber.org/fx
    for dependency injection. Resolve conflicts in main.go, server.go,
    and service constructors while preserving our domain model changes.

    - Fix telegram adapter panic on shutdown (double close channel)
    - Fix feishu adapter processing messages after stop
    - Increase directory lookup timeout from 2s to 5s

commit ca5c6a1866
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Thu Feb 12 15:33:09 2026 +0800

    refactor(core): restructure conversation, channel and message domains

    - Rename chat module to conversation with flow-based architecture
    - Move channelidentities into channel/identities subpackage
    - Add channel/route for routing logic
    - Add message service with event hub
    - Add MCP providers: container, directory, schedule
    - Refactor Feishu/Telegram adapters with directory and stream support
    - Add platform management page and channel badges in web UI
    - Update database schema for conversations, messages and channel routes
    - Add @memoh/shared package for cross-package type definitions

commit 75e2ef0467
Merge: d99ba38 01cb6c8
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Thu Feb 12 14:45:49 2026 +0800

    merge(github): merge github/main, resolve index.ts URL conflict

    Keep our defensive absolute-URL check in createAuthFetcher.

commit d99ba38b7d
Merge: 860e20f 35ce7d1
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Thu Feb 12 05:20:18 2026 +0800

    merge(github): merge github/main, keep our code and docs/spec

commit 860e20fe70
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Wed Feb 11 22:13:27 2026 +0800

    docs(docs): add concepts and style guides for VitePress site

    - Add concepts: identity-and-binding, index (en/zh)
    - Add style: terminology (en/zh)
    - Update index and zh/index
    - Update .vitepress/config.ts

commit a75fdb8040
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Wed Feb 11 17:37:16 2026 +0800

    refactor(mcp): standardize unified tool gateway on go-sdk

    Split business executors from federation sources and migrate unified tool/federation transports to the official go-sdk for stricter MCP compliance and safer session lifecycle handling. Add targeted regression tests for accept compatibility, initialization retries, pending cleanup, and include updated swagger artifacts.

commit 02b33c8e85
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Wed Feb 11 15:42:21 2026 +0800

    refactor(core): finalize user-centric identity and policy cleanup

    Unify auth and chat identity semantics around user_id, enforce personal-bot owner-only authorization, and remove legacy compatibility branches in integration tests.

commit 06e8619a37
Author: BBQ <bbq@BBQdeMacBook-Air.local>
Date:   Wed Feb 11 14:47:03 2026 +0800

    refactor(core): migrate channel identity and binding across app

    Align channel identity and bind flow across backend and app-facing layers, including generated swagger artifacts and package lock updates while excluding docs content changes.
2026-02-12 17:13:03 +08:00

481 lines
12 KiB
Go

package handlers
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
mcpgw "github.com/memohai/memoh/internal/mcp"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
)
// MCPFederationGateway lists and calls tools on external MCP servers reached
// over Streamable HTTP, legacy HTTP+SSE, or stdio. Stdio servers are launched
// inside the bot's container through the containerd handler.
type MCPFederationGateway struct {
	// handler provides container access for stdio connections; stdio
	// operations return an error when it is nil.
	handler *ContainerdHandler
	// logger is tagged with gateway=mcp_federation by NewMCPFederationGateway.
	logger *slog.Logger
	// client is the base HTTP client for HTTP/SSE transports; a nil client is
	// replaced by a 30-second-timeout client in connectionHTTPClient.
	client *http.Client
}
// NewMCPFederationGateway builds a gateway that uses handler for stdio
// connections. A nil log falls back to slog.Default(); either way the logger
// is tagged with gateway=mcp_federation. HTTP and SSE traffic goes through a
// shared client with a 30-second timeout.
func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPFederationGateway {
	baseLogger := log
	if baseLogger == nil {
		baseLogger = slog.Default()
	}
	gw := &MCPFederationGateway{handler: handler}
	gw.logger = baseLogger.With(slog.String("gateway", "mcp_federation"))
	gw.client = &http.Client{Timeout: 30 * time.Second}
	return gw
}
// ListHTTPConnectionTools opens a Streamable HTTP MCP session for connection
// and returns every tool the server advertises, following tools/list
// pagination until the server stops returning a cursor.
func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
	session, err := g.connectStreamableSession(ctx, connection)
	if err != nil {
		return nil, err
	}
	// The session is single-use; a close failure after the result has been
	// computed is deliberately ignored.
	defer func() { _ = session.Close() }()
	return listFederatedSessionTools(ctx, session)
}

// CallHTTPConnectionTool opens a Streamable HTTP MCP session for connection,
// invokes toolName (whitespace-trimmed) with args, and returns the SDK result
// wrapped as {"result": ...} by wrapSDKToolResult.
func (g *MCPFederationGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
	session, err := g.connectStreamableSession(ctx, connection)
	if err != nil {
		return nil, err
	}
	defer func() { _ = session.Close() }()
	return callFederatedSessionTool(ctx, session, toolName, args)
}

// ListSSEConnectionTools opens a legacy HTTP+SSE MCP session for connection
// and returns every tool the server advertises, following tools/list
// pagination until the server stops returning a cursor.
func (g *MCPFederationGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
	session, err := g.connectSSESession(ctx, connection)
	if err != nil {
		return nil, err
	}
	defer func() { _ = session.Close() }()
	return listFederatedSessionTools(ctx, session)
}

// CallSSEConnectionTool opens a legacy HTTP+SSE MCP session for connection,
// invokes toolName (whitespace-trimmed) with args, and returns the SDK result
// wrapped as {"result": ...} by wrapSDKToolResult.
func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
	session, err := g.connectSSESession(ctx, connection)
	if err != nil {
		return nil, err
	}
	defer func() { _ = session.Close() }()
	return callFederatedSessionTool(ctx, session, toolName, args)
}

// listFederatedSessionTools collects the tools from every tools/list page of
// session and converts them to descriptors. MCP servers may paginate
// tools/list via nextCursor; stopping after the first page would silently drop
// tools. A cursor that repeats is reported as an error so a misbehaving server
// cannot keep the loop running forever.
func listFederatedSessionTools(ctx context.Context, session *sdkmcp.ClientSession) ([]mcpgw.ToolDescriptor, error) {
	var collected []*sdkmcp.Tool
	seenCursors := map[string]struct{}{}
	params := &sdkmcp.ListToolsParams{}
	for {
		page, err := session.ListTools(ctx, params)
		if err != nil {
			return nil, err
		}
		if page == nil {
			break
		}
		collected = append(collected, page.Tools...)
		next := page.NextCursor
		if next == "" {
			break
		}
		if _, repeated := seenCursors[next]; repeated {
			return nil, fmt.Errorf("tools/list pagination repeated cursor %q", next)
		}
		seenCursors[next] = struct{}{}
		params = &sdkmcp.ListToolsParams{Cursor: next}
	}
	return convertSDKTools(collected), nil
}

// callFederatedSessionTool issues tools/call on session for the trimmed
// toolName with args and wraps the result via wrapSDKToolResult.
func callFederatedSessionTool(ctx context.Context, session *sdkmcp.ClientSession, toolName string, args map[string]any) (map[string]any, error) {
	result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{
		Name:      strings.TrimSpace(toolName),
		Arguments: args,
	})
	if err != nil {
		return nil, err
	}
	return wrapSDKToolResult(result)
}
// connectStreamableSession dials the Streamable HTTP MCP endpoint configured
// under connection.Config["url"] (whitespace-trimmed) and runs the MCP
// handshake. MaxRetries is set to -1; per the go-sdk documentation a negative
// value disables the transport's reconnect retries. Sessions opened here serve
// a single list/call and are closed right after.
func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) {
	endpoint := strings.TrimSpace(anyToString(connection.Config["url"]))
	if endpoint == "" {
		return nil, fmt.Errorf("http mcp url is required")
	}
	identity := &sdkmcp.Implementation{Name: "memoh-federation-client", Version: "v1"}
	mcpClient := sdkmcp.NewClient(identity, nil)
	return mcpClient.Connect(ctx, &sdkmcp.StreamableClientTransport{
		Endpoint:   endpoint,
		HTTPClient: g.connectionHTTPClient(connection),
		MaxRetries: -1,
	}, nil)
}
// connectSSESession opens a legacy HTTP+SSE MCP session for connection.
//
// Candidate endpoints come from resolveSSEEndpointCandidates. They are tried
// in order, and the first successful handshake wins. If every candidate fails,
// the returned error combines every candidate's failure. This keeps a
// meaningful error on an early candidate (auth, TLS, ...) from being masked by
// a later fallback's "not found". The combined error is built with multi-%w, so
// errors.Is/As still see each cause. Once ctx is done, the remaining
// candidates are skipped.
func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) {
	endpoints := resolveSSEEndpointCandidates(connection.Config)
	if len(endpoints) == 0 {
		return nil, fmt.Errorf("sse mcp url is required")
	}
	// The HTTP client depends only on connection, so build it once.
	httpClient := g.connectionHTTPClient(connection)

	var failure error
	record := func(err error) {
		if failure == nil {
			failure = err
			return
		}
		failure = fmt.Errorf("%w; %w", failure, err)
	}

	for i, endpoint := range endpoints {
		if ctxErr := ctx.Err(); ctxErr != nil {
			record(ctxErr)
			break
		}
		client := sdkmcp.NewClient(&sdkmcp.Implementation{
			Name:    "memoh-federation-client",
			Version: "v1",
		}, nil)
		transport := &sdkmcp.SSEClientTransport{
			Endpoint:   endpoint,
			HTTPClient: httpClient,
		}
		session, err := client.Connect(ctx, transport, nil)
		if err == nil {
			return session, nil
		}
		// Candidates are identified by position rather than URL so configured
		// URLs (which may carry tokens in their query) are not echoed here.
		record(fmt.Errorf("candidate %d: %w", i+1, err))
	}
	return nil, fmt.Errorf("connect sse mcp failed: %w", failure)
}
// resolveSSEEndpointCandidates returns the endpoints to try when opening an
// SSE MCP connection, in priority order and without duplicates:
//
//  1. "sse_url", then "sseUrl";
//  2. "url";
//  3. for the first non-blank of "message_url" / "messageUrl": its "/sse"
//     sibling when it ends in "/message" (trailing slash ignored), followed
//     by the value itself;
//  4. the "/sse" sibling of "url" when it ends in "/message".
//
// Values are whitespace-trimmed and blanks are skipped. A nil config yields an
// empty, non-nil slice.
func resolveSSEEndpointCandidates(config map[string]any) []string {
	if config == nil {
		return []string{}
	}
	candidates := make([]string, 0, 4)
	known := make(map[string]struct{})
	add := func(raw string) {
		endpoint := strings.TrimSpace(raw)
		if endpoint == "" {
			return
		}
		if _, dup := known[endpoint]; !dup {
			known[endpoint] = struct{}{}
			candidates = append(candidates, endpoint)
		}
	}
	// sseSibling rewrites ".../message" (optionally with one trailing slash)
	// to ".../sse".
	sseSibling := func(endpoint string) (string, bool) {
		trimmed := strings.TrimSuffix(endpoint, "/")
		if !strings.HasSuffix(trimmed, "/message") {
			return "", false
		}
		return strings.TrimSuffix(trimmed, "/message") + "/sse", true
	}

	add(anyToString(config["sse_url"]))
	add(anyToString(config["sseUrl"]))

	base := strings.TrimSpace(anyToString(config["url"]))
	add(base)

	message := strings.TrimSpace(anyToString(config["message_url"]))
	if message == "" {
		message = strings.TrimSpace(anyToString(config["messageUrl"]))
	}
	if message != "" {
		if sibling, ok := sseSibling(message); ok {
			add(sibling)
		}
		add(message)
	}

	if sibling, ok := sseSibling(base); ok {
		add(sibling)
	}
	return candidates
}
// connectionHTTPClient returns the HTTP client for a federated MCP connection.
// Without configured headers it is the gateway's shared client, or a fresh
// 30-second-timeout client when none is set. When connection.Config["headers"]
// normalizes to a non-empty map, it returns a copy of that client (same
// timeout, redirect policy and cookie jar) whose transport injects those
// headers into every request.
func (g *MCPFederationGateway) connectionHTTPClient(connection mcpgw.Connection) *http.Client {
	shared := g.client
	if shared == nil {
		shared = &http.Client{Timeout: 30 * time.Second}
	}
	extraHeaders := normalizeHeaderMap(connection.Config["headers"])
	if len(extraHeaders) == 0 {
		return shared
	}
	inner := shared.Transport
	if inner == nil {
		inner = http.DefaultTransport
	}
	injector := &staticHeaderRoundTripper{next: inner, headers: extraHeaders}
	return &http.Client{
		Transport:     injector,
		CheckRedirect: shared.CheckRedirect,
		Jar:           shared.Jar,
		Timeout:       shared.Timeout,
	}
}
// ListStdioConnectionTools starts the stdio MCP command configured on
// connection inside botID's container and sends one tools/list request. The
// reply is converted with parseGatewayToolsListPayload. The session is closed
// with io.EOF before returning.
//
// NOTE(review): only the first tools/list page is requested; a nextCursor in
// the reply is not followed.
func (g *MCPFederationGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) {
	session, err := g.startStdioConnectionSession(ctx, botID, connection)
	if err != nil {
		return nil, err
	}
	defer session.closeWithError(io.EOF)

	listRequest := mcpgw.JSONRPCRequest{
		JSONRPC: "2.0",
		ID:      mcpgw.RawStringID("federated-stdio-tools-list"),
		Method:  "tools/list",
	}
	reply, err := session.call(ctx, listRequest)
	if err != nil {
		return nil, err
	}
	return parseGatewayToolsListPayload(reply)
}
// CallStdioConnectionTool starts the stdio MCP command configured on
// connection inside botID's container and sends one tools/call request for
// toolName with args. The raw JSON-RPC reply map is returned without an
// mcpgw.PayloadError check, so a JSON-RPC "error" member in it is left for the
// caller to inspect. The session is closed with io.EOF before returning.
func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) {
	sess, err := g.startStdioConnectionSession(ctx, botID, connection)
	if err != nil {
		return nil, err
	}
	defer sess.closeWithError(io.EOF)
	// Trim the tool name exactly like the HTTP/SSE call paths do, so a name
	// resolves identically regardless of transport.
	params, err := json.Marshal(map[string]any{
		"name":      strings.TrimSpace(toolName),
		"arguments": args,
	})
	if err != nil {
		return nil, err
	}
	return sess.call(ctx, mcpgw.JSONRPCRequest{
		JSONRPC: "2.0",
		ID:      mcpgw.RawStringID("federated-stdio-tools-call"),
		Method:  "tools/call",
		Params:  params,
	})
}
// startStdioConnectionSession launches the stdio MCP server described by
// connection.Config ("command", "args", "env", "cwd") inside the container
// resolved for botID, and returns the resulting command session.
//
// The connection config is validated and the request is built before any
// container work (botContainerID / validateMCPContainer /
// ensureContainerAndTask). A connection with no command therefore fails fast,
// without touching the bot's container.
func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, botID string, connection mcpgw.Connection) (*mcpSession, error) {
	if g.handler == nil {
		return nil, fmt.Errorf("containerd handler not configured")
	}
	command := strings.TrimSpace(anyToString(connection.Config["command"]))
	if command == "" {
		return nil, fmt.Errorf("stdio mcp command is required")
	}
	request := MCPStdioRequest{
		Name:    strings.TrimSpace(connection.Name),
		Command: command,
		Args:    normalizeStringSlice(connection.Config["args"]),
		Env:     normalizeStringMap(connection.Config["env"]),
		Cwd:     strings.TrimSpace(anyToString(connection.Config["cwd"])),
	}

	containerID, err := g.handler.botContainerID(ctx, botID)
	if err != nil {
		return nil, err
	}
	if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil {
		return nil, err
	}
	if err := g.handler.ensureContainerAndTask(ctx, containerID, botID); err != nil {
		return nil, err
	}
	return g.handler.startContainerdMCPCommandSession(ctx, containerID, request)
}
// parseGatewayToolsListPayload converts a decoded tools/list JSON-RPC reply
// into tool descriptors.
//
// Any error reported by mcpgw.PayloadError is returned as-is. Otherwise
// "result" must be an object and "result.tools" an array. Array entries that
// are not objects, or whose trimmed "name" is empty, are skipped. A missing or
// non-object "inputSchema" becomes an empty object schema.
func parseGatewayToolsListPayload(payload map[string]any) ([]mcpgw.ToolDescriptor, error) {
	if err := mcpgw.PayloadError(payload); err != nil {
		return nil, err
	}
	result, isObject := payload["result"].(map[string]any)
	if !isObject {
		return nil, fmt.Errorf("invalid tools/list result")
	}
	entries, isArray := result["tools"].([]any)
	if !isArray {
		return nil, fmt.Errorf("invalid tools/list tools field")
	}

	toDescriptor := func(entry any) (mcpgw.ToolDescriptor, bool) {
		fields, ok := entry.(map[string]any)
		if !ok {
			return mcpgw.ToolDescriptor{}, false
		}
		toolName := strings.TrimSpace(anyToString(fields["name"]))
		if toolName == "" {
			return mcpgw.ToolDescriptor{}, false
		}
		schema, _ := fields["inputSchema"].(map[string]any)
		if schema == nil {
			schema = map[string]any{
				"type":       "object",
				"properties": map[string]any{},
			}
		}
		return mcpgw.ToolDescriptor{
			Name:        toolName,
			Description: strings.TrimSpace(anyToString(fields["description"])),
			InputSchema: schema,
		}, true
	}

	descriptors := make([]mcpgw.ToolDescriptor, 0, len(entries))
	for _, entry := range entries {
		if descriptor, ok := toDescriptor(entry); ok {
			descriptors = append(descriptors, descriptor)
		}
	}
	return descriptors, nil
}
// convertSDKTools maps go-sdk tool definitions to gateway descriptors. Nil
// entries and tools whose trimmed name is empty are dropped. Names and
// descriptions are trimmed, and input schemas go through
// normalizeToolInputSchema. The result is never nil.
func convertSDKTools(items []*sdkmcp.Tool) []mcpgw.ToolDescriptor {
	converted := make([]mcpgw.ToolDescriptor, 0, len(items))
	for i := range items {
		tool := items[i]
		if tool == nil {
			continue
		}
		if toolName := strings.TrimSpace(tool.Name); toolName != "" {
			converted = append(converted, mcpgw.ToolDescriptor{
				Name:        toolName,
				Description: strings.TrimSpace(tool.Description),
				InputSchema: normalizeToolInputSchema(tool.InputSchema),
			})
		}
	}
	return converted
}
// normalizeToolInputSchema coerces a tool's input schema into a JSON object
// map.
//
// A non-nil map[string]any is returned as-is, without copying. Any other
// non-nil value is round-tripped through encoding/json and used if it decodes
// to a JSON object. Everything else, including nil and values that fail to
// marshal or do not decode to an object, yields
// {"type": "object", "properties": {}}.
func normalizeToolInputSchema(raw any) map[string]any {
	fallback := func() map[string]any {
		return map[string]any{
			"type":       "object",
			"properties": map[string]any{},
		}
	}
	if raw == nil {
		return fallback()
	}
	if direct, ok := raw.(map[string]any); ok && direct != nil {
		return direct
	}
	encoded, err := json.Marshal(raw)
	if err != nil {
		return fallback()
	}
	var decoded map[string]any
	if json.Unmarshal(encoded, &decoded) != nil || decoded == nil {
		return fallback()
	}
	return decoded
}
// wrapSDKToolResult turns a go-sdk tools/call result into the gateway's
// {"result": ...} envelope. The result is converted to a generic map through a
// JSON round trip, and a JSON null becomes an empty map. A nil result is
// reported as the success payload built by mcpgw.BuildToolSuccessResult with
// {"ok": true}. Marshal/unmarshal errors are returned unchanged.
func wrapSDKToolResult(result *sdkmcp.CallToolResult) (map[string]any, error) {
	if result == nil {
		okPayload := mcpgw.BuildToolSuccessResult(map[string]any{"ok": true})
		return map[string]any{"result": okPayload}, nil
	}
	encoded, marshalErr := json.Marshal(result)
	if marshalErr != nil {
		return nil, marshalErr
	}
	var generic map[string]any
	if unmarshalErr := json.Unmarshal(encoded, &generic); unmarshalErr != nil {
		return nil, unmarshalErr
	}
	envelope := map[string]any{"result": generic}
	if generic == nil {
		envelope["result"] = map[string]any{}
	}
	return envelope, nil
}
// normalizeHeaderMap converts a connection's "headers" config value into a
// fresh name->value map.
//
// Both map[string]string and map[string]any inputs are treated the same way:
// names and values are whitespace-trimmed, and entries with a blank name or
// value are dropped. The result is always a new map, so later mutation by
// either side cannot leak between the config and the HTTP transport. Any other
// input (including nil) yields an empty, non-nil map.
func normalizeHeaderMap(raw any) map[string]string {
	var out map[string]string
	add := func(name, value string) {
		key := strings.TrimSpace(name)
		val := strings.TrimSpace(value)
		if key == "" || val == "" {
			return
		}
		out[key] = val
	}
	switch value := raw.(type) {
	case map[string]string:
		out = make(map[string]string, len(value))
		for k, v := range value {
			add(k, v)
		}
	case map[string]any:
		out = make(map[string]string, len(value))
		for k, v := range value {
			add(k, anyToString(v))
		}
	default:
		out = map[string]string{}
	}
	return out
}
// normalizeStringSlice converts a config value into a list of trimmed,
// non-blank strings, preserving order. []string and []any inputs are
// supported; []any elements are rendered with anyToString. Any other input
// (including nil) yields an empty, non-nil slice.
func normalizeStringSlice(raw any) []string {
	keep := func(dst []string, candidate string) []string {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return append(dst, trimmed)
		}
		return dst
	}
	if strs, ok := raw.([]string); ok {
		cleaned := make([]string, 0, len(strs))
		for _, s := range strs {
			cleaned = keep(cleaned, s)
		}
		return cleaned
	}
	if items, ok := raw.([]any); ok {
		cleaned := make([]string, 0, len(items))
		for _, item := range items {
			cleaned = keep(cleaned, anyToString(item))
		}
		return cleaned
	}
	return []string{}
}
// normalizeStringMap converts a config value (such as a stdio "env" map) into
// a fresh key->value map.
//
// Both map[string]string and map[string]any inputs are treated the same way:
// keys and values are whitespace-trimmed, blank keys are dropped, and empty
// values are kept. The result is always a new map, so it never aliases the
// caller's config. Any other input (including nil) yields an empty, non-nil
// map.
func normalizeStringMap(raw any) map[string]string {
	var out map[string]string
	put := func(rawKey, rawValue string) {
		key := strings.TrimSpace(rawKey)
		if key == "" {
			return
		}
		out[key] = strings.TrimSpace(rawValue)
	}
	switch value := raw.(type) {
	case map[string]string:
		out = make(map[string]string, len(value))
		for k, v := range value {
			put(k, v)
		}
	case map[string]any:
		out = make(map[string]string, len(value))
		for k, v := range value {
			put(k, anyToString(v))
		}
	default:
		out = map[string]string{}
	}
	return out
}
// anyToString renders v as a string. A nil interface becomes "", and strings
// pass through unchanged (no trimming). Anything else uses fmt's default %v
// formatting, so a typed nil pointer renders as "<nil>".
func anyToString(v any) string {
	switch value := v.(type) {
	case nil:
		return ""
	case string:
		return value
	}
	return fmt.Sprint(v)
}
// staticHeaderRoundTripper is an http.RoundTripper that sets a fixed group of
// headers on every outgoing request before delegating to next. Configured
// values overwrite request headers of the same name.
type staticHeaderRoundTripper struct {
	// next performs the actual round trip; nil means http.DefaultTransport.
	next http.RoundTripper
	// headers maps header names to values; entries whose trimmed name or
	// value is empty are skipped at send time.
	headers map[string]string
}

// RoundTrip sends a copy of req with the configured headers applied. The
// caller's request is not modified, as the http.RoundTripper contract
// requires. A request built without a Header map (Header == nil) gets a fresh
// one instead of panicking on assignment.
func (t *staticHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	next := t.next
	if next == nil {
		next = http.DefaultTransport
	}
	// Request.Clone deep-copies the Header map, so setting headers on the clone
	// leaves req untouched. A nil Header stays nil, hence the guard below.
	clone := req.Clone(req.Context())
	if clone.Header == nil {
		clone.Header = make(http.Header, len(t.headers))
	}
	for key, value := range t.headers {
		name := strings.TrimSpace(key)
		val := strings.TrimSpace(value)
		if name == "" || val == "" {
			continue
		}
		clone.Header.Set(name, val)
	}
	return next.RoundTrip(clone)
}