diff --git a/cmd/agent/main.go b/cmd/agent/main.go index bee71fee..705daca4 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -2,12 +2,16 @@ package main import ( "context" + "errors" "fmt" "log/slog" + "net/http" "os" "strings" "time" + containerd "github.com/containerd/containerd/v2/client" + "github.com/memohai/memoh/internal/boot" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/adapters/feishu" @@ -36,164 +40,317 @@ import ( "github.com/memohai/memoh/internal/subagent" "github.com/memohai/memoh/internal/users" "github.com/memohai/memoh/internal/version" + "go.uber.org/fx" + "go.uber.org/fx/fxevent" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/crypto/bcrypt" ) -func main() { - fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo()) - ctx := context.Background() +func provideConfig() (config.Config, error) { cfgPath := os.Getenv("CONFIG_PATH") cfg, err := config.Load(cfgPath) if err != nil { - fmt.Fprintf(os.Stderr, "load config: %v\n", err) - os.Exit(1) + return config.Config{}, fmt.Errorf("load config: %v\n", err) } + return cfg, nil +} +func provideLogger(cfg config.Config) *slog.Logger { logger.Init(cfg.Log.Level, cfg.Log.Format) + return logger.L +} - if strings.TrimSpace(cfg.Auth.JWTSecret) == "" { - logger.Error("jwt secret is required") - os.Exit(1) - } - jwtExpiresIn, err := time.ParseDuration(cfg.Auth.JWTExpiresIn) +func provideContainerdClient(lc fx.Lifecycle, runtimeConfig *boot.RuntimeConfig) (*containerd.Client, error) { + factory := ctr.DefaultClientFactory{SocketPath: runtimeConfig.ContainerdSocketPath} + client, err := factory.New(context.Background()) if err != nil { - logger.Error("invalid jwt expires in", slog.Any("error", err)) - os.Exit(1) + return nil, fmt.Errorf("connect containerd: %w", err) } - addr := cfg.Server.Addr - if value := os.Getenv("HTTP_ADDR"); value != "" { - addr = value - } + lc.Append(fx.Hook{ + OnStop: func(ctx context.Context) error { + if err := client.Close(); err != nil { + return fmt.Errorf("close containerd client: %w", err) + } + return nil + }, + }) + return client, nil +} - socketPath := cfg.Containerd.SocketPath - if value := os.Getenv("CONTAINERD_SOCKET"); value != "" { - socketPath = value - } - factory := ctr.DefaultClientFactory{SocketPath: socketPath} - client, err := factory.New(ctx) - if err != nil { - logger.Error("connect containerd", slog.Any("error", err)) - os.Exit(1) - } - defer client.Close() +func main() { + fx.New( + fx.Provide( + provideConfig, + boot.ProvideRuntimeConfig, + provideLogger, - service := ctr.NewDefaultService(logger.L, client, cfg.Containerd.Namespace) - manager := mcp.NewManager(logger.L, service, cfg.MCP) + // misc + provideContainerdClient, + provideDBConn, + provideDBQueries, - pingHandler := handlers.NewPingHandler(logger.L) - // containerdHandler is created later after DB services are initialized + fx.Annotate(ctr.NewDefaultService, fx.As(new(ctr.Service))), + mcp.NewManager, + + provideMemoryLLM, + provideEmbeddingsResolver, + provideEmbeddingSetup, + provideTextEmbedderForMemory, + provideQdrantStore, + memory.NewBM25Indexer, + provideChatResolver, + local.NewSessionHub, + provideChannelRegistry, + + provideChannelRouter, + provideChannelManager, + + chat.NewScheduleGateway, + fx.Annotate(func(scheduleGateway *chat.ScheduleGateway) schedule.Triggerer { + return scheduleGateway + }, fx.As(new(schedule.Triggerer))), + + models.NewService, + bots.NewService, + users.NewService, + providers.NewService, + settings.NewService, + history.NewService, + contacts.NewService, + preauth.NewService, + mcp.NewConnectionService, + subagent.NewService, + schedule.NewService, + channel.NewService, + policy.NewService, + provideMemoryService, + + provideServerHandler(handlers.NewPingHandler), + provideServerHandler(handlers.NewAuthHandler), + provideServerHandler(handlers.NewMemoryHandler), + provideServerHandler(handlers.NewEmbeddingsHandler), + provideServerHandler(handlers.NewChatHandler), + provideServerHandler(handlers.NewSwaggerHandler), + provideServerHandler(handlers.NewProvidersHandler), + provideServerHandler(handlers.NewModelsHandler), + provideServerHandler(handlers.NewSettingsHandler), + provideServerHandler(handlers.NewHistoryHandler), + provideServerHandler(handlers.NewContactsHandler), + provideServerHandler(handlers.NewPreauthHandler), + provideServerHandler(handlers.NewScheduleHandler), + provideServerHandler(handlers.NewSubagentHandler), + handlers.NewContainerdHandler, + provideServerHandler(handlers.NewContainerdHandler), + provideServerHandler(handlers.NewChannelHandler), + provideServerHandler(handlers.NewUsersHandler), + provideServerHandler(handlers.NewMCPHandler), + provideServerHandler(provideCLIHandler), + provideServerHandler(provideWebHandler), + + provideServer, + ), + fx.Invoke( + startMemoryWarmup, + startScheduleService, + startChannelManager, + startServer, + ), + fx.WithLogger(func(logger *slog.Logger) fxevent.Logger { + l := &fxevent.SlogLogger{Logger: logger.With(slog.String("component", "fx"))} + // l.UseLogLevel(slog.LevelInfo) + return l + }), + ).Run() +} + +func provideServerHandler(fn any) any { + return fx.Annotate( + fn, + fx.As(new(server.Handler)), + fx.ResultTags(`group:"server_handlers"`), + ) +} + +func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) { + ctx := context.Background() // TODO: use timeout context conn, err := db.Open(ctx, cfg.Postgres) if err != nil { - logger.Error("db connect", slog.Any("error", err)) - os.Exit(1) + return nil, fmt.Errorf("db connect: %w", err) } - defer conn.Close() - manager.WithDB(conn) - queries := dbsqlc.New(conn) - modelsService := models.NewService(logger.L, queries) - botService := bots.NewService(logger.L, queries) - usersService := users.NewService(logger.L, queries) + lc.Append(fx.Hook{ + OnStop: func(ctx context.Context) error { + conn.Close() + return nil + }, + }) + return conn, nil +} - containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, usersService, queries) - botService.SetContainerLifecycle(containerdHandler) +func provideDBQueries(conn *pgxpool.Pool) *dbsqlc.Queries { + return dbsqlc.New(conn) +} - if err := ensureAdminUser(ctx, logger.L, queries, cfg); err != nil { - logger.Error("ensure admin user", slog.Any("error", err)) - os.Exit(1) - } +func provideEmbeddingsResolver(log *slog.Logger, modelsService *models.Service, queries *dbsqlc.Queries) *embeddings.Resolver { + return embeddings.NewResolver(log, modelsService, queries, 10*time.Second) +} - authHandler := handlers.NewAuthHandler(logger.L, usersService, cfg.Auth.JWTSecret, jwtExpiresIn) +type embeddingSetup struct { + Vectors map[string]int + TextModel models.GetResponse + MultimodalModel models.GetResponse + HasEmbeddingModels bool +} - // Initialize chat resolver after memory service is configured. - var chatResolver *chat.Resolver +func provideEmbeddingSetup(log *slog.Logger, modelsService *models.Service) (embeddingSetup, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - // Create LLM client for memory operations (deferred model/provider selection). - var llmClient memory.LLM = &lazyLLMClient{ - modelsService: modelsService, - queries: queries, - timeout: 30 * time.Second, - logger: logger.L, - } - - resolver := embeddings.NewResolver(logger.L, modelsService, queries, 10*time.Second) vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService) if err != nil { - logger.Error("embedding models", slog.Any("error", err)) - os.Exit(1) + return embeddingSetup{}, fmt.Errorf("embedding models: %w", err) } - - textEmbedder := buildTextEmbedder(resolver, textModel, hasEmbeddingModels, logger.L) if hasEmbeddingModels && multimodalModel.ModelID == "" { - logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.") + log.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.") } - store := buildQdrantStore(logger.L, cfg.Qdrant, vectors, hasEmbeddingModels, textModel.Dimensions) + return embeddingSetup{ + Vectors: vectors, + TextModel: textModel, + MultimodalModel: multimodalModel, + HasEmbeddingModels: hasEmbeddingModels, + }, nil +} - bm25Indexer := memory.NewBM25Indexer(logger.L) - memoryService := memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, bm25Indexer, textModel.ModelID, multimodalModel.ModelID) - memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService, botService, usersService) - go func() { - if err := memoryService.WarmupBM25(ctx, 200); err != nil { - logger.Warn("bm25 warmup failed", slog.Any("error", err)) - } - }() +func provideTextEmbedderForMemory(resolver *embeddings.Resolver, setup embeddingSetup, log *slog.Logger) embeddings.Embedder { + return buildTextEmbedder(resolver, setup.TextModel, setup.HasEmbeddingModels, log) +} - // Initialize providers and models handlers - providersService := providers.NewService(logger.L, queries) - providersHandler := handlers.NewProvidersHandler(logger.L, providersService, modelsService) - settingsService := settings.NewService(logger.L, queries) - settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService, botService, usersService) - modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService) - policyService := policy.NewService(logger.L, botService, settingsService) - historyService := history.NewService(logger.L, queries) - historyHandler := handlers.NewHistoryHandler(logger.L, historyService, botService, usersService) - contactsService := contacts.NewService(queries) - contactsHandler := handlers.NewContactsHandler(contactsService, botService, usersService) - preauthService := preauth.NewService(queries) - preauthHandler := handlers.NewPreauthHandler(preauthService, botService, usersService) - mcpConnectionsService := mcp.NewConnectionService(logger.L, queries) - mcpHandler := handlers.NewMCPHandler(logger.L, mcpConnectionsService, botService, usersService) +func provideMemoryService(log *slog.Logger, llm memory.LLM, embedder embeddings.Embedder, store *memory.QdrantStore, resolver *embeddings.Resolver, bm25Indexer *memory.BM25Indexer, setup embeddingSetup) *memory.Service { + return memory.NewService(log, llm, embedder, store, resolver, bm25Indexer, setup.TextModel.ModelID, setup.MultimodalModel.ModelID) +} - chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, historyService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) +func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, historyService *history.Service, settingsService *settings.Service, mcpConnectionsService *mcp.ConnectionService, containerdHandler *handlers.ContainerdHandler) *chat.Resolver { + chatResolver := chat.NewResolver(log, modelsService, queries, memoryService, historyService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) chatResolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) - embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries) - swaggerHandler := handlers.NewSwaggerHandler(logger.L) - chatHandler := handlers.NewChatHandler(logger.L, chatResolver, botService, usersService) - channelRegistry := channel.NewRegistry() - sessionHub := local.NewSessionHub() - channelRegistry.MustRegister(telegram.NewTelegramAdapter(logger.L)) - channelRegistry.MustRegister(feishu.NewFeishuAdapter(logger.L)) - channelRegistry.MustRegister(local.NewCLIAdapter(sessionHub)) - channelRegistry.MustRegister(local.NewWebAdapter(sessionHub)) - channelService := channel.NewService(queries, channelRegistry) - channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute) - channelManager := channel.NewManager(logger.L, channelRegistry, channelService, channelRouter) + return chatResolver +} + +func provideChannelRegistry(log *slog.Logger, sessionHub *local.SessionHub) *channel.Registry { + registry := channel.NewRegistry() + registry.MustRegister(telegram.NewTelegramAdapter(log)) + registry.MustRegister(feishu.NewFeishuAdapter(log)) + registry.MustRegister(local.NewCLIAdapter(sessionHub)) + registry.MustRegister(local.NewWebAdapter(sessionHub)) + return registry +} + +func provideChannelRouter(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, chatResolver *chat.Resolver, contactsService *contacts.Service, policyService *policy.Service, preauthService *preauth.Service, cfg config.Config) *router.ChannelInboundProcessor { + return router.NewChannelInboundProcessor(log, registry, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute) +} + +func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, channelRouter *router.ChannelInboundProcessor) *channel.Manager { + channelManager := channel.NewManager(log, registry, channelService, channelRouter) if mw := channelRouter.IdentityMiddleware(); mw != nil { channelManager.Use(mw) } - channelManager.Start(ctx) - channelHandler := handlers.NewChannelHandler(channelService, channelRegistry) - usersHandler := handlers.NewUsersHandler(logger.L, usersService, botService, channelService, channelManager, channelRegistry) - cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService) - webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService) - scheduleGateway := chat.NewScheduleGateway(chatResolver) - scheduleService := schedule.NewService(logger.L, queries, scheduleGateway, cfg.Auth.JWTSecret) - if err := scheduleService.Bootstrap(ctx); err != nil { - logger.Error("schedule bootstrap", slog.Any("error", err)) - os.Exit(1) - } - scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, usersService) - subagentService := subagent.NewService(logger.L, queries) - subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, usersService) - srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, contactsHandler, preauthHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler) + return channelManager +} - if err := srv.Start(); err != nil { - logger.Error("server failed", slog.Any("error", err)) - os.Exit(1) - } +func provideCLIHandler(channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, usersService *users.Service) *handlers.LocalChannelHandler { + return handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService) +} + +func provideWebHandler(channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, usersService *users.Service) *handlers.LocalChannelHandler { + return handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService) +} + +type serverParams struct { + fx.In + + Logger *slog.Logger + RuntimeConfig *boot.RuntimeConfig + Config config.Config + ServerHandlers []server.Handler `group:"server_handlers"` +} + +func provideServer(params serverParams) *server.Server { + return server.NewServer(params.Logger, params.RuntimeConfig.ServerAddr, params.Config.Auth.JWTSecret, params.ServerHandlers...) +} + +func startMemoryWarmup(lc fx.Lifecycle, memoryService *memory.Service, logger *slog.Logger) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + go func() { + if err := memoryService.WarmupBM25(context.Background(), 200); err != nil { + logger.Warn("bm25 warmup failed", slog.Any("error", err)) + } + }() + return nil + }, + }) + +} + +func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager, logger *slog.Logger) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + channelManager.Start(ctx) + return nil + }, + }) +} + +func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service, logger *slog.Logger) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return scheduleService.Bootstrap(ctx) + }, + }) +} + +func startServer( + lc fx.Lifecycle, + logger *slog.Logger, + srv *server.Server, + shutdowner fx.Shutdowner, + cfg config.Config, + queries *dbsqlc.Queries, + scheduleService *schedule.Service, + channelManager *channel.Manager, + botService *bots.Service, + containerdHandler *handlers.ContainerdHandler, +) { + fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo()) + + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + + if err := ensureAdminUser(ctx, logger, queries, cfg); err != nil { + return err + } + + botService.SetContainerLifecycle(containerdHandler) + + go func() { + if err := srv.Start(); err != nil { // block until server is stopped + logger.Error("server failed", slog.Any("error", err)) + _ = shutdowner.Shutdown() // shutdown the application if the server fails to start + } + }() + + return nil + }, + OnStop: func(ctx context.Context) error { + // graceful shutdown + if err := srv.Stop(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("server stop: %w", err) + } + return nil + }, + }) } func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder { @@ -211,38 +368,37 @@ func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetRespon } } -func buildQdrantStore(log *slog.Logger, cfg config.QdrantConfig, vectors map[string]int, hasModels bool, textDims int) *memory.QdrantStore { +func provideQdrantStore(log *slog.Logger, cfgAll config.Config, setup embeddingSetup) (*memory.QdrantStore, error) { + cfg := cfgAll.Qdrant timeout := time.Duration(cfg.TimeoutSeconds) * time.Second - if hasModels && len(vectors) > 0 { + if setup.HasEmbeddingModels && len(setup.Vectors) > 0 { store, err := memory.NewQdrantStoreWithVectors( log, cfg.BaseURL, cfg.APIKey, cfg.Collection, - vectors, + setup.Vectors, "sparse_hash", timeout, ) if err != nil { - log.Error("qdrant named vectors init", slog.Any("error", err)) - os.Exit(1) + return nil, fmt.Errorf("qdrant named vectors init: %w", err) } - return store + return store, nil } store, err := memory.NewQdrantStore( log, cfg.BaseURL, cfg.APIKey, cfg.Collection, - textDims, + setup.TextModel.Dimensions, "sparse_hash", timeout, ) if err != nil { - log.Error("qdrant init", slog.Any("error", err)) - os.Exit(1) + return nil, fmt.Errorf("qdrant init: %w", err) } - return store + return store, nil } func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error { @@ -296,6 +452,15 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer return nil } +func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memory.LLM { + return &lazyLLMClient{ + modelsService: modelsService, + queries: queries, + timeout: 30 * time.Second, + logger: log, + } +} + type lazyLLMClient struct { modelsService *models.Service queries *dbsqlc.Queries diff --git a/go.mod b/go.mod index 2005a394..a3f6bbe3 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/containerd/containerd/api v1.10.0 github.com/containerd/containerd/v2 v2.2.1 github.com/containerd/errdefs v1.0.0 + github.com/containerd/go-cni v1.1.13 github.com/containerd/platforms v1.0.0-rc.2 github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 github.com/golang-jwt/jwt/v5 v5.3.1 @@ -24,7 +25,10 @@ require ( github.com/robfig/cron/v3 v3.0.1 github.com/stretchr/testify v1.11.1 github.com/swaggo/swag v1.16.6 + go.uber.org/fx v1.24.0 golang.org/x/crypto v0.47.0 + google.golang.org/grpc v1.78.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -45,7 +49,6 @@ require ( github.com/containerd/continuity v0.4.5 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/fifo v1.1.0 // indirect - github.com/containerd/go-cni v1.1.13 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/plugin v1.0.0 // indirect github.com/containerd/ttrpc v1.2.7 // indirect @@ -103,6 +106,9 @@ require ( go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect go.opentelemetry.io/otel/trace v1.39.0 // indirect + go.uber.org/dig v1.19.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.uber.org/zap v1.26.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/mod v0.32.0 // indirect golang.org/x/net v0.49.0 // indirect @@ -113,7 +119,5 @@ require ( golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.41.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect - google.golang.org/grpc v1.78.0 // indirect google.golang.org/protobuf v1.36.11 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8864560c..6283effb 100644 --- a/go.sum +++ b/go.sum @@ -109,6 +109,8 @@ github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxE github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg= github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc= github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -142,6 +144,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -198,6 +202,10 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/onsi/ginkgo/v2 v2.20.1 h1:YlVIbqct+ZmnEph770q9Q7NVAz4wwIiVNahee6JyUzo= +github.com/onsi/ginkgo/v2 v2.20.1/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -229,6 +237,8 @@ github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC4 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -242,6 +252,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -262,6 +274,16 @@ go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2W go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4= +go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= +go.uber.org/fx v1.24.0 h1:wE8mruvpg2kiiL1Vqd0CC+tr0/24XIB10Iwp2lLWzkg= +go.uber.org/fx v1.24.0/go.mod h1:AmDeGyS+ZARGKM4tlH4FY2Jr63VjbEDJHtqXTGP5hbo= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -270,6 +292,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= diff --git a/internal/boot/runtime.go b/internal/boot/runtime.go new file mode 100644 index 00000000..0ef1acbc --- /dev/null +++ b/internal/boot/runtime.go @@ -0,0 +1,45 @@ +package boot + +import ( + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/memohai/memoh/internal/config" +) + +type RuntimeConfig struct { + JwtSecret string + JwtExpiresIn time.Duration + ServerAddr string + ContainerdSocketPath string +} + +func ProvideRuntimeConfig(cfg config.Config) (*RuntimeConfig, error) { + if strings.TrimSpace(cfg.Auth.JWTSecret) == "" { + return nil, errors.New("jwt secret is required") + } + + jwtExpiresIn, err := time.ParseDuration(cfg.Auth.JWTExpiresIn) + if err != nil { + return nil, fmt.Errorf("invalid jwt expires in: %w", err) + } + + ret := &RuntimeConfig{ + JwtSecret: cfg.Auth.JWTSecret, + JwtExpiresIn: jwtExpiresIn, + ServerAddr: cfg.Server.Addr, + ContainerdSocketPath: cfg.Containerd.SocketPath, + } + + if value := os.Getenv("HTTP_ADDR"); value != "" { + ret.ServerAddr = value + } + + if value := os.Getenv("CONTAINERD_SOCKET"); value != "" { + ret.ContainerdSocketPath = value + } + return ret, nil +} diff --git a/internal/containerd/service.go b/internal/containerd/service.go index 44ee8633..c825905a 100644 --- a/internal/containerd/service.go +++ b/internal/containerd/service.go @@ -26,6 +26,7 @@ import ( "github.com/containerd/containerd/v2/pkg/oci" "github.com/containerd/errdefs" "github.com/containerd/platforms" + "github.com/memohai/memoh/internal/config" "github.com/opencontainers/go-digest" "github.com/opencontainers/image-spec/identity" "github.com/opencontainers/runtime-spec/specs-go" @@ -147,7 +148,8 @@ type DefaultService struct { logger *slog.Logger } -func NewDefaultService(log *slog.Logger, client *containerd.Client, namespace string) *DefaultService { +func NewDefaultService(log *slog.Logger, client *containerd.Client, cfg config.Config) *DefaultService { + namespace := cfg.Containerd.Namespace if namespace == "" { namespace = DefaultNamespace } diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index 4f1e8ad1..fb908109 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -10,6 +10,7 @@ import ( "github.com/labstack/echo/v4" "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/boot" "github.com/memohai/memoh/internal/users" ) @@ -35,11 +36,11 @@ type LoginResponse struct { Username string `json:"username" validate:"required"` } -func NewAuthHandler(log *slog.Logger, userService *users.Service, jwtSecret string, expiresIn time.Duration) *AuthHandler { +func NewAuthHandler(log *slog.Logger, userService *users.Service, runtimeConfig *boot.RuntimeConfig) *AuthHandler { return &AuthHandler{ userService: userService, - jwtSecret: jwtSecret, - expiresIn: expiresIn, + jwtSecret: runtimeConfig.JwtSecret, + expiresIn: runtimeConfig.JwtExpiresIn, logger: log.With(slog.String("handler", "auth")), } } diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index bd2aaef7..7c736e42 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -95,11 +95,11 @@ type ListSnapshotsResponse struct { Snapshots []SnapshotInfo `json:"snapshots"` } -func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler { +func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.Config, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler { return &ContainerdHandler{ service: service, - cfg: cfg, - namespace: namespace, + cfg: cfg.MCP, + namespace: cfg.Containerd.Namespace, logger: log.With(slog.String("handler", "containerd")), mcpSess: make(map[string]*mcpSession), mcpStdioSess: make(map[string]*mcpStdioSession), diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index e2a46218..d492db1d 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -47,10 +47,12 @@ type Manager struct { logger *slog.Logger } -func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig) *Manager { +func NewManager(log *slog.Logger, service ctr.Service, cfg config.Config, db *pgxpool.Pool) *Manager { return &Manager{ + db: db, + queries: dbsqlc.New(db), service: service, - cfg: cfg, + cfg: cfg.MCP, logger: log.With(slog.String("component", "mcp")), containerID: func(botID string) string { return ContainerPrefix + botID @@ -58,12 +60,6 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig) *Ma } } -func (m *Manager) WithDB(db *pgxpool.Pool) *Manager { - m.db = db - m.queries = dbsqlc.New(db) - return m -} - func (m *Manager) Init(ctx context.Context) error { image := m.cfg.BusyboxImage if image == "" { diff --git a/internal/schedule/service.go b/internal/schedule/service.go index 5af044b2..42aff905 100644 --- a/internal/schedule/service.go +++ b/internal/schedule/service.go @@ -15,6 +15,7 @@ import ( "github.com/robfig/cron/v3" "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/boot" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -29,7 +30,7 @@ type Service struct { jobs map[string]cron.EntryID } -func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jwtSecret string) *Service { +func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, runtimeConfig *boot.RuntimeConfig) *Service { parser := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) c := cron.New(cron.WithParser(parser)) service := &Service{ @@ -37,7 +38,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jw cron: c, parser: parser, triggerer: triggerer, - jwtSecret: jwtSecret, + jwtSecret: runtimeConfig.JwtSecret, logger: log.With(slog.String("service", "schedule")), jobs: map[string]cron.EntryID{}, } diff --git a/internal/server/server.go b/internal/server/server.go index cb8113d2..f2e64937 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "log/slog" "strings" @@ -8,7 +9,6 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/handlers" ) type Server struct { @@ -17,7 +17,13 @@ type Server struct { logger *slog.Logger } -func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, contactsHandler *handlers.ContactsHandler, preauthHandler *handlers.PreauthHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { +type Handler interface { + Register(e *echo.Echo) +} + +func NewServer(log *slog.Logger, addr string, jwtSecret string, + handlers ...Handler, +) *Server { if addr == "" { addr = ":8080" } @@ -51,65 +57,10 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han return false })) - if pingHandler != nil { - pingHandler.Register(e) - } - if authHandler != nil { - authHandler.Register(e) - } - if memoryHandler != nil { - memoryHandler.Register(e) - } - if embeddingsHandler != nil { - embeddingsHandler.Register(e) - } - if chatHandler != nil { - chatHandler.Register(e) - } - if swaggerHandler != nil { - swaggerHandler.Register(e) - } - if settingsHandler != nil { - settingsHandler.Register(e) - } - if historyHandler != nil { - historyHandler.Register(e) - } - if contactsHandler != nil { - contactsHandler.Register(e) - } - if preauthHandler != nil { - preauthHandler.Register(e) - } - if scheduleHandler != nil { - scheduleHandler.Register(e) - } - if subagentHandler != nil { - subagentHandler.Register(e) - } - if providersHandler != nil { - providersHandler.Register(e) - } - if modelsHandler != nil { - modelsHandler.Register(e) - } - if containerdHandler != nil { - containerdHandler.Register(e) - } - if channelHandler != nil { - channelHandler.Register(e) - } - if usersHandler != nil { - usersHandler.Register(e) - } - if mcpHandler != nil { - mcpHandler.Register(e) - } - if cliHandler != nil { - cliHandler.Register(e) - } - if webHandler != nil { - webHandler.Register(e) + for _, h := range handlers { + if h != nil { + h.Register(e) + } } return &Server{ @@ -122,3 +73,7 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han func (s *Server) Start() error { return s.echo.Start(s.addr) } + +func (s *Server) Stop(ctx context.Context) error { + return s.echo.Shutdown(ctx) +}