mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
280 lines
8.0 KiB
Go
280 lines
8.0 KiB
Go
package federation
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
mcpgw "github.com/memohai/memoh/internal/mcp"
|
|
)
|
|
|
|
// cacheTTL is how long a bot's federated tool list and routing table are
// served from the cache before the bot's MCP connections are queried again.
const cacheTTL time.Duration = 5 * time.Second
|
|
|
|
type ConnectionLister interface {
|
|
ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error)
|
|
}
|
|
|
|
type Gateway interface {
|
|
ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error)
|
|
CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error)
|
|
|
|
ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error)
|
|
CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error)
|
|
|
|
ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error)
|
|
CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error)
|
|
}
|
|
|
|
type toolRoute struct {
|
|
sourceType string
|
|
originalName string
|
|
connection mcpgw.Connection
|
|
}
|
|
|
|
type cacheEntry struct {
|
|
expiresAt time.Time
|
|
routes map[string]toolRoute
|
|
tools []mcpgw.ToolDescriptor
|
|
}
|
|
|
|
type Source struct {
|
|
logger *slog.Logger
|
|
gateway Gateway
|
|
connections ConnectionLister
|
|
|
|
mu sync.Mutex
|
|
cache map[string]cacheEntry
|
|
}
|
|
|
|
func NewSource(log *slog.Logger, gateway Gateway, connections ConnectionLister) *Source {
|
|
if log == nil {
|
|
log = slog.Default()
|
|
}
|
|
return &Source{
|
|
logger: log.With(slog.String("source", "federated_mcp_tool")),
|
|
gateway: gateway,
|
|
connections: connections,
|
|
cache: map[string]cacheEntry{},
|
|
}
|
|
}
|
|
|
|
func (s *Source) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) {
|
|
botID := strings.TrimSpace(session.BotID)
|
|
if botID == "" || s.gateway == nil {
|
|
return []mcpgw.ToolDescriptor{}, nil
|
|
}
|
|
if cached, ok := s.getCache(botID); ok {
|
|
return cloneTools(cached.tools), nil
|
|
}
|
|
tools, routes := s.buildToolsAndRoutes(ctx, botID)
|
|
s.setCache(botID, cacheEntry{
|
|
expiresAt: time.Now().Add(cacheTTL),
|
|
routes: routes,
|
|
tools: tools,
|
|
})
|
|
return cloneTools(tools), nil
|
|
}
|
|
|
|
func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) {
|
|
if s.gateway == nil {
|
|
return mcpgw.BuildToolErrorResult("federation gateway not available"), nil
|
|
}
|
|
botID := strings.TrimSpace(session.BotID)
|
|
if botID == "" {
|
|
return mcpgw.BuildToolErrorResult("bot_id is required"), nil
|
|
}
|
|
|
|
route, ok := s.getRoute(botID, toolName)
|
|
if !ok {
|
|
// Refresh route cache; result intentionally discarded.
|
|
if _, err := s.ListTools(ctx, session); err != nil {
|
|
s.logger.Warn("federation: refresh tools cache failed", slog.Any("error", err))
|
|
}
|
|
route, ok = s.getRoute(botID, toolName)
|
|
if !ok {
|
|
return nil, mcpgw.ErrToolNotFound
|
|
}
|
|
}
|
|
if arguments == nil {
|
|
arguments = map[string]any{}
|
|
}
|
|
|
|
var (
|
|
payload map[string]any
|
|
err error
|
|
)
|
|
switch route.sourceType {
|
|
case "http":
|
|
payload, err = s.gateway.CallHTTPConnectionTool(ctx, route.connection, route.originalName, arguments)
|
|
case "sse":
|
|
payload, err = s.gateway.CallSSEConnectionTool(ctx, route.connection, route.originalName, arguments)
|
|
case "stdio":
|
|
payload, err = s.gateway.CallStdioConnectionTool(ctx, botID, route.connection, route.originalName, arguments)
|
|
default:
|
|
return mcpgw.BuildToolErrorResult("unsupported federated source"), nil
|
|
}
|
|
if err != nil {
|
|
return mcpgw.BuildToolErrorResult(err.Error()), nil
|
|
}
|
|
if err := mcpgw.PayloadError(payload); err != nil {
|
|
return mcpgw.BuildToolErrorResult(err.Error()), nil
|
|
}
|
|
if result, ok := payload["result"].(map[string]any); ok {
|
|
return result, nil
|
|
}
|
|
return mcpgw.BuildToolSuccessResult(payload), nil
|
|
}
|
|
|
|
func (s *Source) buildToolsAndRoutes(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, map[string]toolRoute) {
|
|
routes := map[string]toolRoute{}
|
|
tools := make([]mcpgw.ToolDescriptor, 0, 16)
|
|
|
|
addTool := func(descriptor mcpgw.ToolDescriptor, route toolRoute) {
|
|
name := strings.TrimSpace(descriptor.Name)
|
|
if name == "" {
|
|
return
|
|
}
|
|
finalName := name
|
|
if _, exists := routes[finalName]; exists {
|
|
seed := strings.ReplaceAll(finalName, ".", "_")
|
|
if seed == "" {
|
|
seed = "tool"
|
|
}
|
|
for i := 2; ; i++ {
|
|
candidate := seed + "_" + strconv.Itoa(i)
|
|
if _, ok := routes[candidate]; ok {
|
|
continue
|
|
}
|
|
finalName = candidate
|
|
break
|
|
}
|
|
}
|
|
descriptor.Name = finalName
|
|
routes[finalName] = route
|
|
tools = append(tools, descriptor)
|
|
}
|
|
|
|
if s.connections != nil {
|
|
items, err := s.connections.ListActiveByBot(ctx, botID)
|
|
if err != nil {
|
|
s.logger.Warn("list mcp connections failed", slog.String("bot_id", botID), slog.Any("error", err))
|
|
} else {
|
|
sort.Slice(items, func(i, j int) bool {
|
|
if items[i].Name == items[j].Name {
|
|
return items[i].ID < items[j].ID
|
|
}
|
|
return items[i].Name < items[j].Name
|
|
})
|
|
for _, connection := range items {
|
|
var connTools []mcpgw.ToolDescriptor
|
|
switch strings.ToLower(strings.TrimSpace(connection.Type)) {
|
|
case "http":
|
|
connTools, err = s.gateway.ListHTTPConnectionTools(ctx, connection)
|
|
case "sse":
|
|
connTools, err = s.gateway.ListSSEConnectionTools(ctx, connection)
|
|
case "stdio":
|
|
connTools, err = s.gateway.ListStdioConnectionTools(ctx, botID, connection)
|
|
default:
|
|
s.logger.Warn("unsupported mcp connection type", slog.String("connection_id", connection.ID), slog.String("type", connection.Type))
|
|
continue
|
|
}
|
|
if err != nil {
|
|
s.logger.Warn("list tools from connection failed", slog.String("connection_id", connection.ID), slog.String("name", connection.Name), slog.Any("error", err))
|
|
continue
|
|
}
|
|
prefix := sanitizePrefix(connection.Name)
|
|
for _, tool := range connTools {
|
|
origin := strings.TrimSpace(tool.Name)
|
|
alias := origin
|
|
if prefix != "" {
|
|
alias = prefix + "_" + origin
|
|
}
|
|
tool.Name = alias
|
|
if strings.TrimSpace(tool.Description) != "" {
|
|
tool.Description = "[" + strings.TrimSpace(connection.Name) + "] " + tool.Description
|
|
} else {
|
|
tool.Description = "[" + strings.TrimSpace(connection.Name) + "] " + origin
|
|
}
|
|
addTool(tool, toolRoute{
|
|
sourceType: strings.ToLower(strings.TrimSpace(connection.Type)),
|
|
originalName: origin,
|
|
connection: connection,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return tools, routes
|
|
}
|
|
|
|
// sanitizePrefix derives a tool-name prefix from a connection name:
//   - The name is lower-cased and trimmed.
//   - Every rune outside [a-z0-9_-] (including non-ASCII letters and
//     invalid UTF-8) becomes '_'.
//   - Leading and trailing '.', '_' and '-' are stripped.
//
// "mcp" is returned when nothing remains.
func sanitizePrefix(raw string) string {
	const fallback = "mcp"

	cleaned := strings.Map(func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', '0' <= r && r <= '9', r == '_', r == '-':
			return r
		default:
			return '_'
		}
	}, strings.TrimSpace(strings.ToLower(raw)))

	if trimmed := strings.Trim(cleaned, "._-"); trimmed != "" {
		return trimmed
	}
	return fallback
}
|
|
|
|
func cloneTools(items []mcpgw.ToolDescriptor) []mcpgw.ToolDescriptor {
|
|
if len(items) == 0 {
|
|
return []mcpgw.ToolDescriptor{}
|
|
}
|
|
out := make([]mcpgw.ToolDescriptor, 0, len(items))
|
|
for _, item := range items {
|
|
out = append(out, mcpgw.ToolDescriptor{
|
|
Name: item.Name,
|
|
Description: item.Description,
|
|
InputSchema: item.InputSchema,
|
|
})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (s *Source) getCache(botID string) (cacheEntry, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
cached, ok := s.cache[botID]
|
|
if !ok || time.Now().After(cached.expiresAt) {
|
|
return cacheEntry{}, false
|
|
}
|
|
return cached, true
|
|
}
|
|
|
|
func (s *Source) setCache(botID string, entry cacheEntry) {
|
|
s.mu.Lock()
|
|
s.cache[botID] = entry
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *Source) getRoute(botID, toolName string) (toolRoute, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
cached, ok := s.cache[botID]
|
|
if !ok || time.Now().After(cached.expiresAt) {
|
|
return toolRoute{}, false
|
|
}
|
|
route, exists := cached.routes[strings.TrimSpace(toolName)]
|
|
return route, exists
|
|
}
|
|
|
|
func (s *Source) String() string {
|
|
return fmt.Sprintf("FederationSource(%p)", s)
|
|
}
|