refactor: using fx

This commit is contained in:
MengYX
2026-02-11 02:26:51 +08:00
parent 155c70685f
commit 6548c31597
10 changed files with 383 additions and 206 deletions
+293 -128
View File
@@ -2,12 +2,16 @@ package main
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"strings"
"time"
containerd "github.com/containerd/containerd/v2/client"
"github.com/memohai/memoh/internal/boot"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/adapters/feishu"
@@ -36,164 +40,317 @@ import (
"github.com/memohai/memoh/internal/subagent"
"github.com/memohai/memoh/internal/users"
"github.com/memohai/memoh/internal/version"
"go.uber.org/fx"
"go.uber.org/fx/fxevent"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/crypto/bcrypt"
)
func main() {
fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo())
ctx := context.Background()
func provideConfig() (config.Config, error) {
cfgPath := os.Getenv("CONFIG_PATH")
cfg, err := config.Load(cfgPath)
if err != nil {
fmt.Fprintf(os.Stderr, "load config: %v\n", err)
os.Exit(1)
return config.Config{}, fmt.Errorf("load config: %v\n", err)
}
return cfg, nil
}
func provideLogger(cfg config.Config) *slog.Logger {
logger.Init(cfg.Log.Level, cfg.Log.Format)
return logger.L
}
if strings.TrimSpace(cfg.Auth.JWTSecret) == "" {
logger.Error("jwt secret is required")
os.Exit(1)
}
jwtExpiresIn, err := time.ParseDuration(cfg.Auth.JWTExpiresIn)
func provideContainerdClient(lc fx.Lifecycle, runtimeConfig *boot.RuntimeConfig) (*containerd.Client, error) {
factory := ctr.DefaultClientFactory{SocketPath: runtimeConfig.ContainerdSocketPath}
client, err := factory.New(context.Background())
if err != nil {
logger.Error("invalid jwt expires in", slog.Any("error", err))
os.Exit(1)
return nil, fmt.Errorf("connect containerd: %w", err)
}
addr := cfg.Server.Addr
if value := os.Getenv("HTTP_ADDR"); value != "" {
addr = value
}
lc.Append(fx.Hook{
OnStop: func(ctx context.Context) error {
if err := client.Close(); err != nil {
return fmt.Errorf("close containerd client: %w", err)
}
return nil
},
})
return client, nil
}
socketPath := cfg.Containerd.SocketPath
if value := os.Getenv("CONTAINERD_SOCKET"); value != "" {
socketPath = value
}
factory := ctr.DefaultClientFactory{SocketPath: socketPath}
client, err := factory.New(ctx)
if err != nil {
logger.Error("connect containerd", slog.Any("error", err))
os.Exit(1)
}
defer client.Close()
func main() {
fx.New(
fx.Provide(
provideConfig,
boot.ProvideRuntimeConfig,
provideLogger,
service := ctr.NewDefaultService(logger.L, client, cfg.Containerd.Namespace)
manager := mcp.NewManager(logger.L, service, cfg.MCP)
// misc
provideContainerdClient,
provideDBConn,
provideDBQueries,
pingHandler := handlers.NewPingHandler(logger.L)
// containerdHandler is created later after DB services are initialized
fx.Annotate(ctr.NewDefaultService, fx.As(new(ctr.Service))),
mcp.NewManager,
provideMemoryLLM,
provideEmbeddingsResolver,
provideEmbeddingSetup,
provideTextEmbedderForMemory,
provideQdrantStore,
memory.NewBM25Indexer,
provideChatResolver,
local.NewSessionHub,
provideChannelRegistry,
provideChannelRouter,
provideChannelManager,
chat.NewScheduleGateway,
fx.Annotate(func(scheduleGateway *chat.ScheduleGateway) schedule.Triggerer {
return scheduleGateway
}, fx.As(new(schedule.Triggerer))),
models.NewService,
bots.NewService,
users.NewService,
providers.NewService,
settings.NewService,
history.NewService,
contacts.NewService,
preauth.NewService,
mcp.NewConnectionService,
subagent.NewService,
schedule.NewService,
channel.NewService,
policy.NewService,
provideMemoryService,
provideServerHandler(handlers.NewPingHandler),
provideServerHandler(handlers.NewAuthHandler),
provideServerHandler(handlers.NewMemoryHandler),
provideServerHandler(handlers.NewEmbeddingsHandler),
provideServerHandler(handlers.NewChatHandler),
provideServerHandler(handlers.NewSwaggerHandler),
provideServerHandler(handlers.NewProvidersHandler),
provideServerHandler(handlers.NewModelsHandler),
provideServerHandler(handlers.NewSettingsHandler),
provideServerHandler(handlers.NewHistoryHandler),
provideServerHandler(handlers.NewContactsHandler),
provideServerHandler(handlers.NewPreauthHandler),
provideServerHandler(handlers.NewScheduleHandler),
provideServerHandler(handlers.NewSubagentHandler),
handlers.NewContainerdHandler,
provideServerHandler(handlers.NewContainerdHandler),
provideServerHandler(handlers.NewChannelHandler),
provideServerHandler(handlers.NewUsersHandler),
provideServerHandler(handlers.NewMCPHandler),
provideServerHandler(provideCLIHandler),
provideServerHandler(provideWebHandler),
provideServer,
),
fx.Invoke(
startMemoryWarmup,
startScheduleService,
startChannelManager,
startServer,
),
fx.WithLogger(func(logger *slog.Logger) fxevent.Logger {
l := &fxevent.SlogLogger{Logger: logger.With(slog.String("component", "fx"))}
// l.UseLogLevel(slog.LevelInfo)
return l
}),
).Run()
}
func provideServerHandler(fn any) any {
return fx.Annotate(
fn,
fx.As(new(server.Handler)),
fx.ResultTags(`group:"server_handlers"`),
)
}
func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) {
ctx := context.Background() // TODO: use timeout context
conn, err := db.Open(ctx, cfg.Postgres)
if err != nil {
logger.Error("db connect", slog.Any("error", err))
os.Exit(1)
return nil, fmt.Errorf("db connect: %w", err)
}
defer conn.Close()
manager.WithDB(conn)
queries := dbsqlc.New(conn)
modelsService := models.NewService(logger.L, queries)
botService := bots.NewService(logger.L, queries)
usersService := users.NewService(logger.L, queries)
lc.Append(fx.Hook{
OnStop: func(ctx context.Context) error {
conn.Close()
return nil
},
})
return conn, nil
}
containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, usersService, queries)
botService.SetContainerLifecycle(containerdHandler)
func provideDBQueries(conn *pgxpool.Pool) *dbsqlc.Queries {
return dbsqlc.New(conn)
}
if err := ensureAdminUser(ctx, logger.L, queries, cfg); err != nil {
logger.Error("ensure admin user", slog.Any("error", err))
os.Exit(1)
}
func provideEmbeddingsResolver(log *slog.Logger, modelsService *models.Service, queries *dbsqlc.Queries) *embeddings.Resolver {
return embeddings.NewResolver(log, modelsService, queries, 10*time.Second)
}
authHandler := handlers.NewAuthHandler(logger.L, usersService, cfg.Auth.JWTSecret, jwtExpiresIn)
type embeddingSetup struct {
Vectors map[string]int
TextModel models.GetResponse
MultimodalModel models.GetResponse
HasEmbeddingModels bool
}
// Initialize chat resolver after memory service is configured.
var chatResolver *chat.Resolver
func provideEmbeddingSetup(log *slog.Logger, modelsService *models.Service) (embeddingSetup, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Create LLM client for memory operations (deferred model/provider selection).
var llmClient memory.LLM = &lazyLLMClient{
modelsService: modelsService,
queries: queries,
timeout: 30 * time.Second,
logger: logger.L,
}
resolver := embeddings.NewResolver(logger.L, modelsService, queries, 10*time.Second)
vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
if err != nil {
logger.Error("embedding models", slog.Any("error", err))
os.Exit(1)
return embeddingSetup{}, fmt.Errorf("embedding models: %w", err)
}
textEmbedder := buildTextEmbedder(resolver, textModel, hasEmbeddingModels, logger.L)
if hasEmbeddingModels && multimodalModel.ModelID == "" {
logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
log.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
}
store := buildQdrantStore(logger.L, cfg.Qdrant, vectors, hasEmbeddingModels, textModel.Dimensions)
return embeddingSetup{
Vectors: vectors,
TextModel: textModel,
MultimodalModel: multimodalModel,
HasEmbeddingModels: hasEmbeddingModels,
}, nil
}
bm25Indexer := memory.NewBM25Indexer(logger.L)
memoryService := memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, bm25Indexer, textModel.ModelID, multimodalModel.ModelID)
memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService, botService, usersService)
go func() {
if err := memoryService.WarmupBM25(ctx, 200); err != nil {
logger.Warn("bm25 warmup failed", slog.Any("error", err))
}
}()
func provideTextEmbedderForMemory(resolver *embeddings.Resolver, setup embeddingSetup, log *slog.Logger) embeddings.Embedder {
return buildTextEmbedder(resolver, setup.TextModel, setup.HasEmbeddingModels, log)
}
// Initialize providers and models handlers
providersService := providers.NewService(logger.L, queries)
providersHandler := handlers.NewProvidersHandler(logger.L, providersService, modelsService)
settingsService := settings.NewService(logger.L, queries)
settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService, botService, usersService)
modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService)
policyService := policy.NewService(logger.L, botService, settingsService)
historyService := history.NewService(logger.L, queries)
historyHandler := handlers.NewHistoryHandler(logger.L, historyService, botService, usersService)
contactsService := contacts.NewService(queries)
contactsHandler := handlers.NewContactsHandler(contactsService, botService, usersService)
preauthService := preauth.NewService(queries)
preauthHandler := handlers.NewPreauthHandler(preauthService, botService, usersService)
mcpConnectionsService := mcp.NewConnectionService(logger.L, queries)
mcpHandler := handlers.NewMCPHandler(logger.L, mcpConnectionsService, botService, usersService)
func provideMemoryService(log *slog.Logger, llm memory.LLM, embedder embeddings.Embedder, store *memory.QdrantStore, resolver *embeddings.Resolver, bm25Indexer *memory.BM25Indexer, setup embeddingSetup) *memory.Service {
return memory.NewService(log, llm, embedder, store, resolver, bm25Indexer, setup.TextModel.ModelID, setup.MultimodalModel.ModelID)
}
chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, historyService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second)
func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, historyService *history.Service, settingsService *settings.Service, mcpConnectionsService *mcp.ConnectionService, containerdHandler *handlers.ContainerdHandler) *chat.Resolver {
chatResolver := chat.NewResolver(log, modelsService, queries, memoryService, historyService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second)
chatResolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler})
embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries)
swaggerHandler := handlers.NewSwaggerHandler(logger.L)
chatHandler := handlers.NewChatHandler(logger.L, chatResolver, botService, usersService)
channelRegistry := channel.NewRegistry()
sessionHub := local.NewSessionHub()
channelRegistry.MustRegister(telegram.NewTelegramAdapter(logger.L))
channelRegistry.MustRegister(feishu.NewFeishuAdapter(logger.L))
channelRegistry.MustRegister(local.NewCLIAdapter(sessionHub))
channelRegistry.MustRegister(local.NewWebAdapter(sessionHub))
channelService := channel.NewService(queries, channelRegistry)
channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute)
channelManager := channel.NewManager(logger.L, channelRegistry, channelService, channelRouter)
return chatResolver
}
func provideChannelRegistry(log *slog.Logger, sessionHub *local.SessionHub) *channel.Registry {
registry := channel.NewRegistry()
registry.MustRegister(telegram.NewTelegramAdapter(log))
registry.MustRegister(feishu.NewFeishuAdapter(log))
registry.MustRegister(local.NewCLIAdapter(sessionHub))
registry.MustRegister(local.NewWebAdapter(sessionHub))
return registry
}
func provideChannelRouter(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, chatResolver *chat.Resolver, contactsService *contacts.Service, policyService *policy.Service, preauthService *preauth.Service, cfg config.Config) *router.ChannelInboundProcessor {
return router.NewChannelInboundProcessor(log, registry, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute)
}
func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, channelRouter *router.ChannelInboundProcessor) *channel.Manager {
channelManager := channel.NewManager(log, registry, channelService, channelRouter)
if mw := channelRouter.IdentityMiddleware(); mw != nil {
channelManager.Use(mw)
}
channelManager.Start(ctx)
channelHandler := handlers.NewChannelHandler(channelService, channelRegistry)
usersHandler := handlers.NewUsersHandler(logger.L, usersService, botService, channelService, channelManager, channelRegistry)
cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService)
webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService)
scheduleGateway := chat.NewScheduleGateway(chatResolver)
scheduleService := schedule.NewService(logger.L, queries, scheduleGateway, cfg.Auth.JWTSecret)
if err := scheduleService.Bootstrap(ctx); err != nil {
logger.Error("schedule bootstrap", slog.Any("error", err))
os.Exit(1)
}
scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, usersService)
subagentService := subagent.NewService(logger.L, queries)
subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, usersService)
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, contactsHandler, preauthHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler)
return channelManager
}
if err := srv.Start(); err != nil {
logger.Error("server failed", slog.Any("error", err))
os.Exit(1)
}
func provideCLIHandler(channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, usersService *users.Service) *handlers.LocalChannelHandler {
return handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService)
}
func provideWebHandler(channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, usersService *users.Service) *handlers.LocalChannelHandler {
return handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService)
}
type serverParams struct {
fx.In
Logger *slog.Logger
RuntimeConfig *boot.RuntimeConfig
Config config.Config
ServerHandlers []server.Handler `group:"server_handlers"`
}
func provideServer(params serverParams) *server.Server {
return server.NewServer(params.Logger, params.RuntimeConfig.ServerAddr, params.Config.Auth.JWTSecret, params.ServerHandlers...)
}
func startMemoryWarmup(lc fx.Lifecycle, memoryService *memory.Service, logger *slog.Logger) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
go func() {
if err := memoryService.WarmupBM25(context.Background(), 200); err != nil {
logger.Warn("bm25 warmup failed", slog.Any("error", err))
}
}()
return nil
},
})
}
func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager, logger *slog.Logger) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
channelManager.Start(ctx)
return nil
},
})
}
func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service, logger *slog.Logger) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return scheduleService.Bootstrap(ctx)
},
})
}
func startServer(
lc fx.Lifecycle,
logger *slog.Logger,
srv *server.Server,
shutdowner fx.Shutdowner,
cfg config.Config,
queries *dbsqlc.Queries,
scheduleService *schedule.Service,
channelManager *channel.Manager,
botService *bots.Service,
containerdHandler *handlers.ContainerdHandler,
) {
fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo())
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := ensureAdminUser(ctx, logger, queries, cfg); err != nil {
return err
}
botService.SetContainerLifecycle(containerdHandler)
go func() {
if err := srv.Start(); err != nil { // block until server is stopped
logger.Error("server failed", slog.Any("error", err))
_ = shutdowner.Shutdown() // shutdown the application if the server fails to start
}
}()
return nil
},
OnStop: func(ctx context.Context) error {
// graceful shutdown
if err := srv.Stop(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server stop: %w", err)
}
return nil
},
})
}
func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder {
@@ -211,38 +368,37 @@ func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetRespon
}
}
func buildQdrantStore(log *slog.Logger, cfg config.QdrantConfig, vectors map[string]int, hasModels bool, textDims int) *memory.QdrantStore {
func provideQdrantStore(log *slog.Logger, cfgAll config.Config, setup embeddingSetup) (*memory.QdrantStore, error) {
cfg := cfgAll.Qdrant
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
if hasModels && len(vectors) > 0 {
if setup.HasEmbeddingModels && len(setup.Vectors) > 0 {
store, err := memory.NewQdrantStoreWithVectors(
log,
cfg.BaseURL,
cfg.APIKey,
cfg.Collection,
vectors,
setup.Vectors,
"sparse_hash",
timeout,
)
if err != nil {
log.Error("qdrant named vectors init", slog.Any("error", err))
os.Exit(1)
return nil, fmt.Errorf("qdrant named vectors init: %w", err)
}
return store
return store, nil
}
store, err := memory.NewQdrantStore(
log,
cfg.BaseURL,
cfg.APIKey,
cfg.Collection,
textDims,
setup.TextModel.Dimensions,
"sparse_hash",
timeout,
)
if err != nil {
log.Error("qdrant init", slog.Any("error", err))
os.Exit(1)
return nil, fmt.Errorf("qdrant init: %w", err)
}
return store
return store, nil
}
func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error {
@@ -296,6 +452,15 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer
return nil
}
func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memory.LLM {
return &lazyLLMClient{
modelsService: modelsService,
queries: queries,
timeout: 30 * time.Second,
logger: log,
}
}
type lazyLLMClient struct {
modelsService *models.Service
queries *dbsqlc.Queries
+4
View File
@@ -103,6 +103,10 @@ require (
go.opentelemetry.io/otel v1.39.0 // indirect
go.opentelemetry.io/otel/metric v1.39.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
go.uber.org/dig v1.19.0 // indirect
go.uber.org/fx v1.24.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.uber.org/zap v1.26.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/net v0.49.0 // indirect
+8
View File
@@ -262,6 +262,14 @@ go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2W
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/fx v1.24.0 h1:wE8mruvpg2kiiL1Vqd0CC+tr0/24XIB10Iwp2lLWzkg=
go.uber.org/fx v1.24.0/go.mod h1:AmDeGyS+ZARGKM4tlH4FY2Jr63VjbEDJHtqXTGP5hbo=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+45
View File
@@ -0,0 +1,45 @@
package boot
import (
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/memohai/memoh/internal/config"
)
type RuntimeConfig struct {
JwtSecret string
JwtExpiresIn time.Duration
ServerAddr string
ContainerdSocketPath string
}
func ProvideRuntimeConfig(cfg config.Config) (*RuntimeConfig, error) {
if strings.TrimSpace(cfg.Auth.JWTSecret) == "" {
return nil, errors.New("jwt secret is required")
}
jwtExpiresIn, err := time.ParseDuration(cfg.Auth.JWTExpiresIn)
if err != nil {
return nil, fmt.Errorf("invalid jwt expires in: %w", err)
}
ret := &RuntimeConfig{
JwtSecret: cfg.Auth.JWTSecret,
JwtExpiresIn: jwtExpiresIn,
ServerAddr: cfg.Server.Addr,
ContainerdSocketPath: cfg.Containerd.SocketPath,
}
if value := os.Getenv("HTTP_ADDR"); value != "" {
ret.ServerAddr = value
}
if value := os.Getenv("CONTAINERD_SOCKET"); value != "" {
ret.ContainerdSocketPath = value
}
return ret, nil
}
+3 -1
View File
@@ -27,6 +27,7 @@ import (
"github.com/containerd/containerd/v2/pkg/oci"
"github.com/containerd/errdefs"
"github.com/containerd/platforms"
"github.com/memohai/memoh/internal/config"
"github.com/opencontainers/go-digest"
"github.com/opencontainers/image-spec/identity"
"github.com/opencontainers/runtime-spec/specs-go"
@@ -148,7 +149,8 @@ type DefaultService struct {
logger *slog.Logger
}
func NewDefaultService(log *slog.Logger, client *containerd.Client, namespace string) *DefaultService {
func NewDefaultService(log *slog.Logger, client *containerd.Client, cfg config.Config) *DefaultService {
namespace := cfg.Containerd.Namespace
if namespace == "" {
namespace = DefaultNamespace
}
+4 -3
View File
@@ -10,6 +10,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/boot"
"github.com/memohai/memoh/internal/users"
)
@@ -35,11 +36,11 @@ type LoginResponse struct {
Username string `json:"username"`
}
func NewAuthHandler(log *slog.Logger, userService *users.Service, jwtSecret string, expiresIn time.Duration) *AuthHandler {
func NewAuthHandler(log *slog.Logger, userService *users.Service, runtimeConfig *boot.RuntimeConfig) *AuthHandler {
return &AuthHandler{
userService: userService,
jwtSecret: jwtSecret,
expiresIn: expiresIn,
jwtSecret: runtimeConfig.JwtSecret,
expiresIn: runtimeConfig.JwtExpiresIn,
logger: log.With(slog.String("handler", "auth")),
}
}
+3 -3
View File
@@ -95,11 +95,11 @@ type ListSnapshotsResponse struct {
Snapshots []SnapshotInfo `json:"snapshots"`
}
func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler {
func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.Config, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler {
return &ContainerdHandler{
service: service,
cfg: cfg,
namespace: namespace,
cfg: cfg.MCP,
namespace: cfg.Containerd.Namespace,
logger: log.With(slog.String("handler", "containerd")),
mcpSess: make(map[string]*mcpSession),
mcpStdioSess: make(map[string]*mcpStdioSession),
+4 -8
View File
@@ -47,10 +47,12 @@ type Manager struct {
logger *slog.Logger
}
func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig) *Manager {
func NewManager(log *slog.Logger, service ctr.Service, cfg config.Config, db *pgxpool.Pool) *Manager {
return &Manager{
db: db,
queries: dbsqlc.New(db),
service: service,
cfg: cfg,
cfg: cfg.MCP,
logger: log.With(slog.String("component", "mcp")),
containerID: func(botID string) string {
return ContainerPrefix + botID
@@ -58,12 +60,6 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig) *Ma
}
}
func (m *Manager) WithDB(db *pgxpool.Pool) *Manager {
m.db = db
m.queries = dbsqlc.New(db)
return m
}
func (m *Manager) Init(ctx context.Context) error {
image := m.cfg.BusyboxImage
if image == "" {
+3 -2
View File
@@ -15,6 +15,7 @@ import (
"github.com/robfig/cron/v3"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/boot"
"github.com/memohai/memoh/internal/db/sqlc"
)
@@ -29,7 +30,7 @@ type Service struct {
jobs map[string]cron.EntryID
}
func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jwtSecret string) *Service {
func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, runtimeConfig *boot.RuntimeConfig) *Service {
parser := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
c := cron.New(cron.WithParser(parser))
service := &Service{
@@ -37,7 +38,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jw
cron: c,
parser: parser,
triggerer: triggerer,
jwtSecret: jwtSecret,
jwtSecret: runtimeConfig.JwtSecret,
logger: log.With(slog.String("service", "schedule")),
jobs: map[string]cron.EntryID{},
}
+16 -61
View File
@@ -1,6 +1,7 @@
package server
import (
"context"
"log/slog"
"strings"
@@ -8,7 +9,6 @@ import (
"github.com/labstack/echo/v4/middleware"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/handlers"
)
type Server struct {
@@ -17,7 +17,13 @@ type Server struct {
logger *slog.Logger
}
func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, contactsHandler *handlers.ContactsHandler, preauthHandler *handlers.PreauthHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server {
type Handler interface {
Register(e *echo.Echo)
}
func NewServer(log *slog.Logger, addr string, jwtSecret string,
handlers ...Handler,
) *Server {
if addr == "" {
addr = ":8080"
}
@@ -51,65 +57,10 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
return false
}))
if pingHandler != nil {
pingHandler.Register(e)
}
if authHandler != nil {
authHandler.Register(e)
}
if memoryHandler != nil {
memoryHandler.Register(e)
}
if embeddingsHandler != nil {
embeddingsHandler.Register(e)
}
if chatHandler != nil {
chatHandler.Register(e)
}
if swaggerHandler != nil {
swaggerHandler.Register(e)
}
if settingsHandler != nil {
settingsHandler.Register(e)
}
if historyHandler != nil {
historyHandler.Register(e)
}
if contactsHandler != nil {
contactsHandler.Register(e)
}
if preauthHandler != nil {
preauthHandler.Register(e)
}
if scheduleHandler != nil {
scheduleHandler.Register(e)
}
if subagentHandler != nil {
subagentHandler.Register(e)
}
if providersHandler != nil {
providersHandler.Register(e)
}
if modelsHandler != nil {
modelsHandler.Register(e)
}
if containerdHandler != nil {
containerdHandler.Register(e)
}
if channelHandler != nil {
channelHandler.Register(e)
}
if usersHandler != nil {
usersHandler.Register(e)
}
if mcpHandler != nil {
mcpHandler.Register(e)
}
if cliHandler != nil {
cliHandler.Register(e)
}
if webHandler != nil {
webHandler.Register(e)
for _, h := range handlers {
if h != nil {
h.Register(e)
}
}
return &Server{
@@ -122,3 +73,7 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
func (s *Server) Start() error {
return s.echo.Start(s.addr)
}
func (s *Server) Stop(ctx context.Context) error {
return s.echo.Shutdown(ctx)
}