diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 625e44aa..1159117d 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -460,7 +460,7 @@ func startContainerReconciliation(lc fx.Lifecycle, containerdHandler *handlers.C }) } -func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler) { +func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService) { fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo()) lc.Append(fx.Hook{ @@ -469,6 +469,7 @@ func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutd return err } botService.SetContainerLifecycle(containerdHandler) + botService.AddRuntimeChecker(mcp.NewConnectionChecker(logger, mcpConnService, toolGateway)) go func() { if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { diff --git a/internal/bots/service.go b/internal/bots/service.go index 751e94d2..82702a93 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -23,6 +23,7 @@ type Service struct { queries *sqlc.Queries logger *slog.Logger containerLifecycle ContainerLifecycle + checkers []RuntimeChecker } const ( @@ -56,6 +57,13 @@ func (s *Service) SetContainerLifecycle(lc ContainerLifecycle) { s.containerLifecycle = lc } +// AddRuntimeChecker registers an additional runtime checker. +func (s *Service) AddRuntimeChecker(c RuntimeChecker) { + if c != nil { + s.checkers = append(s.checkers, c) + } +} + // AuthorizeAccess checks whether userID may access the given bot. func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { if s.queries == nil { @@ -836,6 +844,11 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot) ([]BotCh } checks = append(checks, dataCheck) + botID := uuid.UUID(row.ID.Bytes).String() + for _, checker := range s.checkers { + checks = append(checks, checker.CheckBot(ctx, botID)...) + } + return checks, nil } diff --git a/internal/bots/types.go b/internal/bots/types.go index e002524c..3f1024a2 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -87,6 +87,11 @@ type ContainerLifecycle interface { CleanupBotContainer(ctx context.Context, botID string) error } +// RuntimeChecker produces runtime check items for a bot. +type RuntimeChecker interface { + CheckBot(ctx context.Context, botID string) []BotCheck +} + const ( BotTypePersonal = "personal" BotTypePublic = "public" diff --git a/internal/mcp/checker.go b/internal/mcp/checker.go new file mode 100644 index 00000000..28d5eafc --- /dev/null +++ b/internal/mcp/checker.go @@ -0,0 +1,126 @@ +package mcp + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/memohai/memoh/internal/bots" +) + +const ( + mcpCheckTimeout = 8 * time.Second +) + +// ConnectionChecker implements bots.RuntimeChecker for MCP connections. +type ConnectionChecker struct { + logger *slog.Logger + connections *ConnectionService + gateway *ToolGatewayService +} + +// NewConnectionChecker creates an MCP runtime checker. +func NewConnectionChecker(log *slog.Logger, connections *ConnectionService, gateway *ToolGatewayService) *ConnectionChecker { + if log == nil { + log = slog.Default() + } + return &ConnectionChecker{ + logger: log.With(slog.String("checker", "mcp")), + connections: connections, + gateway: gateway, + } +} + +// CheckBot probes each active MCP connection for a bot and returns check results. +func (c *ConnectionChecker) CheckBot(ctx context.Context, botID string) []bots.BotCheck { + if c.connections == nil { + return nil + } + items, err := c.connections.ListActiveByBot(ctx, botID) + if err != nil { + c.logger.Warn("mcp checker: list connections failed", + slog.String("bot_id", botID), slog.Any("error", err)) + return nil + } + if len(items) == 0 { + return nil + } + + checks := make([]bots.BotCheck, 0, len(items)) + for _, conn := range items { + check := c.probeConnection(ctx, botID, conn) + checks = append(checks, check) + } + return checks +} + +func (c *ConnectionChecker) probeConnection(ctx context.Context, botID string, conn Connection) bots.BotCheck { + checkKey := "mcp." + sanitizeCheckKey(conn.Name) + check := bots.BotCheck{ + CheckKey: checkKey, + Status: bots.BotCheckStatusUnknown, + Summary: fmt.Sprintf("MCP server %q is being checked.", conn.Name), + Metadata: map[string]any{ + "connection_id": conn.ID, + "name": conn.Name, + "type": conn.Type, + }, + } + + if c.gateway == nil { + check.Status = bots.BotCheckStatusWarn + check.Summary = fmt.Sprintf("MCP server %q cannot be checked.", conn.Name) + check.Detail = "tool gateway not available" + return check + } + + probeCtx, cancel := context.WithTimeout(ctx, mcpCheckTimeout) + defer cancel() + + session := ToolSessionContext{BotID: botID} + tools, err := c.gateway.ListTools(probeCtx, session) + if err != nil { + check.Status = bots.BotCheckStatusError + check.Summary = fmt.Sprintf("MCP server %q is not reachable.", conn.Name) + check.Detail = err.Error() + return check + } + + // Count tools belonging to this connection (prefixed with connection name). + prefix := sanitizeCheckKey(conn.Name) + "." + toolCount := 0 + for _, t := range tools { + if strings.HasPrefix(t.Name, prefix) || t.Name == conn.Name { + toolCount++ + } + } + + if toolCount > 0 { + check.Status = bots.BotCheckStatusOK + check.Summary = fmt.Sprintf("MCP server %q is healthy (%d tools).", conn.Name, toolCount) + check.Metadata["tool_count"] = toolCount + } else { + check.Status = bots.BotCheckStatusWarn + check.Summary = fmt.Sprintf("MCP server %q is reachable but no tools found.", conn.Name) + check.Detail = "The server responded but exposed no tools." + } + return check +} + +func sanitizeCheckKey(raw string) string { + raw = strings.TrimSpace(strings.ToLower(raw)) + if raw == "" { + return "unknown" + } + b := strings.Builder{} + for _, ch := range raw { + if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' { + b.WriteRune(ch) + } else { + b.WriteRune('_') + } + } + return strings.Trim(b.String(), "_-") +}