Files

280 lines
8.0 KiB
Go

package federation
import (
"context"
"fmt"
"log/slog"
"sort"
"strconv"
"strings"
"sync"
"time"
mcpgw "github.com/memohai/memoh/internal/mcp"
)
// cacheTTL is how long a bot's federated tool list and route table (see
// cacheEntry) are reused before ListTools rebuilds them from the upstream
// MCP connections.
const cacheTTL = 5 * time.Second
// ConnectionLister supplies the MCP connections whose tools are federated for
// a bot. buildToolsAndRoutes calls it once per cache rebuild; if it returns an
// error, that rebuild produces no federated tools for the bot.
type ConnectionLister interface {
	// ListActiveByBot returns the connections to federate for botID.
	// NOTE(review): presumably only active/enabled connections, per the
	// method name — confirm against the implementation.
	ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error)
}
// Gateway performs the transport-specific MCP operations that Source
// delegates to. Source uses one List/Call pair per connection type ("http",
// "sse", "stdio"); the stdio variants additionally receive the bot ID.
type Gateway interface {
	// ListHTTPConnectionTools lists the tools exposed by an "http" connection.
	ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error)
	// CallHTTPConnectionTool invokes toolName (the upstream, un-aliased name)
	// on an "http" connection and returns the raw payload.
	CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error)
	// ListSSEConnectionTools lists the tools exposed by an "sse" connection.
	ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error)
	// CallSSEConnectionTool invokes toolName (the upstream, un-aliased name)
	// on an "sse" connection and returns the raw payload.
	CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error)
	// ListStdioConnectionTools lists the tools exposed by a "stdio"
	// connection belonging to botID.
	ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error)
	// CallStdioConnectionTool invokes toolName (the upstream, un-aliased
	// name) on a "stdio" connection belonging to botID.
	CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error)
}
// toolRoute records where a federated tool alias must be dispatched.
type toolRoute struct {
	// sourceType is the connection type, lower-cased and trimmed
	// ("http", "sse" or "stdio"); CallTool switches on it.
	sourceType string
	// originalName is the tool name as reported by the upstream connection
	// (trimmed), i.e. the name sent to the Gateway call.
	originalName string
	// connection is the MCP connection that exposed the tool.
	connection mcpgw.Connection
}
// cacheEntry is one bot's cached federation state.
type cacheEntry struct {
	// expiresAt is when the entry stops being served (now + cacheTTL at build time).
	expiresAt time.Time
	// routes maps each published tool alias to its dispatch target.
	routes map[string]toolRoute
	// tools holds the published descriptors in listing order; callers
	// receive copies via cloneTools.
	tools []mcpgw.ToolDescriptor
}
// Source publishes the tools of a bot's MCP connections as one federated tool
// set and routes tool calls back to the connection that owns each tool.
// Construct it with NewSource: the zero value has a nil logger and a nil cache
// map.
type Source struct {
	logger      *slog.Logger
	gateway     Gateway
	connections ConnectionLister
	// mu guards cache.
	mu sync.Mutex
	// cache is keyed by trimmed bot ID.
	cache map[string]cacheEntry
}
// NewSource constructs a federated tool Source.
//
// A nil log falls back to slog.Default(). Every record the Source emits is
// tagged with tool_source=federated_mcp_tool. gateway and connections may be
// nil. With a nil gateway, ListTools returns no tools and CallTool returns an
// error result. With a nil connections lister, no tools are discovered.
func NewSource(log *slog.Logger, gateway Gateway, connections ConnectionLister) *Source {
	base := log
	if base == nil {
		base = slog.Default()
	}
	src := &Source{
		gateway:     gateway,
		connections: connections,
		cache:       make(map[string]cacheEntry),
	}
	src.logger = base.With(slog.String("tool_source", "federated_mcp_tool"))
	return src
}
// ListTools returns the federated tool descriptors for the session's bot.
//
// A blank bot ID or a nil gateway yields an empty list. Otherwise the bot's
// cached listing is served while it is younger than cacheTTL. On a miss the
// listing is rebuilt with buildToolsAndRoutes. Concurrent misses may each
// rebuild, because no lock is held across the rebuild. The returned slice is
// always a fresh copy made by cloneTools.
//
// The error result is currently always nil. Per-connection failures are
// logged during the rebuild and the affected tools are simply omitted.
func (s *Source) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) {
	botID := strings.TrimSpace(session.BotID)
	if botID == "" || s.gateway == nil {
		return []mcpgw.ToolDescriptor{}, nil
	}
	if cached, ok := s.getCache(botID); ok {
		return cloneTools(cached.tools), nil
	}
	tools, routes := s.buildToolsAndRoutes(ctx, botID)
	// Cache only listings built under a live context. If this caller's
	// context was cancelled or timed out mid-build, the remaining connection
	// lookups failed and the listing is incomplete. Caching it would hide those
	// tools from every other session of this bot until the entry expires.
	if ctx.Err() == nil {
		s.setCache(botID, cacheEntry{
			expiresAt: time.Now().Add(cacheTTL),
			routes:    routes,
			tools:     tools,
		})
	}
	return cloneTools(tools), nil
}
// CallTool executes a federated tool on behalf of the session's bot.
//
// toolName is an alias published by ListTools. After trimming, it is resolved
// to the originating connection and the upstream tool name. If the alias
// cannot be resolved, ListTools is invoked and the lookup is retried once;
// ListTools rebuilds the cache only when this bot's entry is missing or
// expired. If the retry also fails, mcpgw.ErrToolNotFound is returned. A nil
// arguments map is forwarded as an empty map.
//
// These cases are reported as tool error results with a nil Go error:
//   - a nil gateway;
//   - a blank bot ID;
//   - an unsupported route type;
//   - a gateway failure;
//   - an error embedded in the payload.
//
// On success, a map-valued "result" field of the payload is returned as-is.
// Otherwise the whole payload is wrapped with mcpgw.BuildToolSuccessResult.
func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) {
	if s.gateway == nil {
		return mcpgw.BuildToolErrorResult("federation gateway not available"), nil
	}
	botID := strings.TrimSpace(session.BotID)
	if botID == "" {
		return mcpgw.BuildToolErrorResult("bot_id is required"), nil
	}

	route, found := s.getRoute(botID, toolName)
	if !found {
		// ListTools is called only for its cache-refresh side effect; the
		// tool list it returns is not needed here.
		if _, refreshErr := s.ListTools(ctx, session); refreshErr != nil {
			s.logger.Warn("federation: refresh tools cache failed", slog.Any("error", refreshErr))
		}
		if route, found = s.getRoute(botID, toolName); !found {
			return nil, mcpgw.ErrToolNotFound
		}
	}

	args := arguments
	if args == nil {
		args = map[string]any{}
	}

	var (
		payload map[string]any
		callErr error
	)
	switch route.sourceType {
	case "http":
		payload, callErr = s.gateway.CallHTTPConnectionTool(ctx, route.connection, route.originalName, args)
	case "sse":
		payload, callErr = s.gateway.CallSSEConnectionTool(ctx, route.connection, route.originalName, args)
	case "stdio":
		payload, callErr = s.gateway.CallStdioConnectionTool(ctx, botID, route.connection, route.originalName, args)
	default:
		return mcpgw.BuildToolErrorResult("unsupported federated source"), nil
	}

	// Transport-level failures first, then errors reported inside the payload.
	if callErr != nil {
		return mcpgw.BuildToolErrorResult(callErr.Error()), nil
	}
	if payloadErr := mcpgw.PayloadError(payload); payloadErr != nil {
		return mcpgw.BuildToolErrorResult(payloadErr.Error()), nil
	}

	result, isMap := payload["result"].(map[string]any)
	if !isMap {
		return mcpgw.BuildToolSuccessResult(payload), nil
	}
	return result, nil
}
// buildToolsAndRoutes queries every connection returned by the
// ConnectionLister for botID. It assembles the federated tool list together
// with the alias -> route table used by CallTool.
//
// Connections are visited in (Name, ID) order so alias assignment is
// deterministic. Each tool is published as
// "<sanitizePrefix(connection name)>_<tool name>", and its description is
// prefixed with "[<connection name>] " (the tool name stands in for a blank
// description). When an alias collides with one already registered, the newer
// tool gets a numeric suffix ("_2", "_3", ...) after dots are replaced by
// underscores. Tools with a blank name are skipped.
//
// The build is best-effort. These failures are logged and skipped:
//   - a failed connection listing;
//   - an unsupported connection type;
//   - a failed per-connection tool listing.
//
// Iteration stops early once ctx is done. The results may therefore be
// partial, but they are never nil.
func (s *Source) buildToolsAndRoutes(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, map[string]toolRoute) {
	routes := map[string]toolRoute{}
	tools := make([]mcpgw.ToolDescriptor, 0, 16)

	// addTool publishes descriptor under a name that is unique within routes.
	addTool := func(descriptor mcpgw.ToolDescriptor, route toolRoute) {
		name := strings.TrimSpace(descriptor.Name)
		if name == "" {
			return
		}
		finalName := name
		if _, exists := routes[finalName]; exists {
			seed := strings.ReplaceAll(finalName, ".", "_")
			if seed == "" {
				seed = "tool"
			}
			for i := 2; ; i++ {
				candidate := seed + "_" + strconv.Itoa(i)
				if _, taken := routes[candidate]; !taken {
					finalName = candidate
					break
				}
			}
		}
		descriptor.Name = finalName
		routes[finalName] = route
		tools = append(tools, descriptor)
	}

	if s.connections == nil {
		return tools, routes
	}
	items, err := s.connections.ListActiveByBot(ctx, botID)
	if err != nil {
		s.logger.Warn("list mcp connections failed", slog.String("bot_id", botID), slog.Any("error", err))
		return tools, routes
	}
	sort.Slice(items, func(i, j int) bool {
		if items[i].Name == items[j].Name {
			return items[i].ID < items[j].ID
		}
		return items[i].Name < items[j].Name
	})

	for _, connection := range items {
		if ctxErr := ctx.Err(); ctxErr != nil {
			// The caller gave up. Every remaining gateway call would fail the
			// same way, so stop rather than log one warning per connection.
			s.logger.Warn("list tools from connections aborted", slog.String("bot_id", botID), slog.Any("error", ctxErr))
			break
		}

		sourceType := strings.ToLower(strings.TrimSpace(connection.Type))
		var (
			connTools []mcpgw.ToolDescriptor
			listErr   error
		)
		switch sourceType {
		case "http":
			connTools, listErr = s.gateway.ListHTTPConnectionTools(ctx, connection)
		case "sse":
			connTools, listErr = s.gateway.ListSSEConnectionTools(ctx, connection)
		case "stdio":
			connTools, listErr = s.gateway.ListStdioConnectionTools(ctx, botID, connection)
		default:
			s.logger.Warn("unsupported mcp connection type", slog.String("connection_id", connection.ID), slog.String("type", connection.Type))
			continue
		}
		if listErr != nil {
			s.logger.Warn("list tools from connection failed", slog.String("connection_id", connection.ID), slog.String("name", connection.Name), slog.Any("error", listErr))
			continue
		}

		label := strings.TrimSpace(connection.Name)
		prefix := sanitizePrefix(connection.Name)
		for _, tool := range connTools {
			origin := strings.TrimSpace(tool.Name)
			if origin == "" {
				// A nameless tool cannot be invoked upstream. Without this guard
				// the prefix turns it into a routable alias such as "github_"
				// whose calls forward an empty tool name.
				continue
			}
			alias := origin
			if prefix != "" {
				alias = prefix + "_" + origin
			}
			tool.Name = alias
			if strings.TrimSpace(tool.Description) != "" {
				tool.Description = "[" + label + "] " + tool.Description
			} else {
				tool.Description = "[" + label + "] " + origin
			}
			addTool(tool, toolRoute{
				sourceType:   sourceType,
				originalName: origin,
				connection:   connection,
			})
		}
	}
	return tools, routes
}
// sanitizePrefix turns a connection name into a tool-alias prefix.
//
// The name is lower-cased and trimmed. Every rune outside [a-z0-9_-] is
// replaced with '_', and leading/trailing '.', '_' and '-' are stripped. If
// nothing survives, the fallback "mcp" is returned, so the result is never
// empty.
func sanitizePrefix(raw string) string {
	const fallback = "mcp"
	keepOrUnderscore := func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', '0' <= r && r <= '9', r == '_', r == '-':
			return r
		default:
			return '_'
		}
	}
	cleaned := strings.Map(keepOrUnderscore, strings.TrimSpace(strings.ToLower(raw)))
	if trimmed := strings.Trim(cleaned, "._-"); trimmed != "" {
		return trimmed
	}
	return fallback
}
// cloneTools returns a new slice holding copies of each descriptor's Name,
// Description and InputSchema. InputSchema is copied by assignment, so it is
// shared with the source (a shallow copy). The result is never nil; an empty
// input yields an empty, non-nil slice.
func cloneTools(items []mcpgw.ToolDescriptor) []mcpgw.ToolDescriptor {
	out := make([]mcpgw.ToolDescriptor, len(items))
	for i := range items {
		src := &items[i]
		out[i] = mcpgw.ToolDescriptor{
			Name:        src.Name,
			Description: src.Description,
			InputSchema: src.InputSchema,
		}
	}
	return out
}
// getCache returns botID's cache entry if it exists and has not yet expired.
// Otherwise it returns the zero entry and false.
func (s *Source) getCache(botID string) (cacheEntry, bool) {
	s.mu.Lock()
	entry, found := s.cache[botID]
	s.mu.Unlock()
	if found && !time.Now().After(entry.expiresAt) {
		return entry, true
	}
	return cacheEntry{}, false
}
// setCache stores entry as botID's cached federation state.
//
// It also drops every other entry whose expiresAt has passed. Without this,
// bots that stop calling ListTools would keep their route tables (and the
// connection configs inside them) in the map forever, because expired entries
// are otherwise only ever overwritten, never removed.
func (s *Source) setCache(botID string, entry cacheEntry) {
	now := time.Now()
	s.mu.Lock()
	defer s.mu.Unlock()
	// Deleting map entries while ranging over the map is permitted in Go.
	for id, cached := range s.cache {
		if now.After(cached.expiresAt) {
			delete(s.cache, id)
		}
	}
	s.cache[botID] = entry
}
// getRoute looks up the (trimmed) tool alias in botID's cache entry. It
// reports false when the entry is missing, the entry has expired, or the alias
// is unknown.
func (s *Source) getRoute(botID, toolName string) (toolRoute, bool) {
	alias := strings.TrimSpace(toolName)
	s.mu.Lock()
	defer s.mu.Unlock()
	if entry, found := s.cache[botID]; found && !time.Now().After(entry.expiresAt) {
		route, exists := entry.routes[alias]
		return route, exists
	}
	return toolRoute{}, false
}
// String identifies this Source instance by its pointer address, e.g.
// "FederationSource(0xc000123450)".
func (s *Source) String() string {
	addr := fmt.Sprintf("%p", s)
	return "FederationSource(" + addr + ")"
}