reactor(cli): move memoh cli to tui

1. Split the oversized `cmd/agent` entrypoint into a multi-file package and update dev/build scripts to use the package path instead of compiling `main.go` directly.
2. Add a new `memoh` terminal UI for local bot chat, with Bubble Tea
This commit is contained in:
晨苒
2026-04-14 00:39:34 +08:00
parent 8c9f222783
commit d50eeea114
32 changed files with 2140 additions and 1962 deletions
+265 -439
View File
@@ -15,10 +15,7 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"go.uber.org/fx"
"go.uber.org/fx/fxevent"
"golang.org/x/crypto/bcrypt"
"github.com/memohai/memoh/internal/accounts"
@@ -26,7 +23,6 @@ import (
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/agent/background"
agenttools "github.com/memohai/memoh/internal/agent/tools"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bind"
"github.com/memohai/memoh/internal/boot"
"github.com/memohai/memoh/internal/bots"
@@ -96,129 +92,6 @@ import (
"github.com/memohai/memoh/internal/workspace"
)
func runServe() {
fx.New(
fx.Provide(
provideConfig,
boot.ProvideRuntimeConfig,
provideLogger,
provideContainerService,
provideDBConn,
provideDBQueries,
provideWorkspaceManager,
provideMemoryLLM,
memprovider.NewService,
provideMemoryProviderRegistry,
models.NewService,
bots.NewService,
accounts.NewService,
acl.NewService,
settings.NewService,
provideProvidersService,
searchproviders.NewService,
policy.NewService,
mcp.NewConnectionService,
conversation.NewService,
identities.NewService,
bind.NewService,
event.NewHub,
provideTtsRegistry,
ttspkg.NewService,
provideTtsTempStore,
provideEmailRegistry,
emailpkg.NewService,
emailpkg.NewOutboxService,
provideEmailChatGateway,
provideEmailTrigger,
emailpkg.NewManager,
providePipeline,
provideEventStore,
provideDiscussDriver,
provideRouteService,
provideSessionService,
provideMessageService,
provideMediaService,
local.NewRouteHub,
provideChannelRegistry,
channel.NewStore,
provideChannelRouter,
provideChannelManager,
provideChannelLifecycleService,
provideAgent,
provideChatResolver,
browsercontexts.NewService,
provideScheduleTriggerer,
provideHeartbeatSessionCreator,
provideScheduleSessionCreator,
schedule.NewService,
provideHeartbeatTriggerer,
heartbeat.NewService,
compaction.NewService,
provideContainerdHandler,
provideFederationGateway,
provideToolGatewayService,
provideBackgroundManager,
provideToolProviders,
provideServerHandler(handlers.NewPingHandler),
provideServerHandler(provideMemohAuthHandler),
provideServerHandler(provideMemoryHandler),
provideServerHandler(provideMessageHandler),
provideServerHandler(provideSessionHandler),
provideServerHandler(handlers.NewSwaggerHandler),
provideServerHandler(handlers.NewProvidersHandler),
provideServerHandler(handlers.NewProviderOAuthHandler),
provideServerHandler(handlers.NewSearchProvidersHandler),
provideServerHandler(handlers.NewModelsHandler),
provideServerHandler(handlers.NewSettingsHandler),
provideServerHandler(handlers.NewACLHandler),
provideServerHandler(handlers.NewBindHandler),
provideServerHandler(handlers.NewScheduleHandler),
provideServerHandler(handlers.NewHeartbeatHandler),
provideServerHandler(handlers.NewCompactionHandler),
provideServerHandler(handlers.NewChannelHandler),
provideServerHandler(channel.NewWebhookServerHandler),
provideServerHandler(weixin.NewQRServerHandler),
provideServerHandler(provideUsersHandler),
provideServerHandler(handlers.NewMemoryProvidersHandler),
provideServerHandler(handlers.NewSpeechHandler),
provideServerHandler(handlers.NewBotTtsHandler),
provideServerHandler(handlers.NewEmailProvidersHandler),
provideServerHandler(handlers.NewEmailBindingsHandler),
provideServerHandler(handlers.NewEmailOutboxHandler),
provideServerHandler(handlers.NewEmailWebhookHandler),
provideServerHandler(provideEmailOAuthHandler),
emailpkg.NewDBOAuthTokenStore,
provideServerHandler(handlers.NewMCPHandler),
provideServerHandler(handlers.NewMCPOAuthHandler),
provideOAuthService,
provideServerHandler(handlers.NewTokenUsageHandler),
provideServerHandler(handlers.NewSessionInfoHandler),
provideServerHandler(handlers.NewBrowserContextsHandler),
provideServerHandler(provideWebHandler),
provideServerHandler(handlers.NewEmbeddedWebHandler),
provideServer,
),
fx.Invoke(
injectToolProviders,
startRegistrySync,
startMemoryProviderBootstrap,
startSearchProviderBootstrap,
startScheduleService,
startHeartbeatService,
startChannelManager,
startEmailManager,
startContainerReconciliation,
startBackgroundTaskCleanup,
startTtsTempStoreCleanup,
startServer,
),
fx.WithLogger(func(logger *slog.Logger) fxevent.Logger {
return &fxevent.SlogLogger{Logger: logger.With(slog.String("component", "fx"))}
}),
).Run()
}
func provideServerHandler(fn any) any {
return fx.Annotate(
fn,
@@ -227,15 +100,6 @@ func provideServerHandler(fn any) any {
)
}
func provideConfig() (config.Config, error) {
cfgPath := os.Getenv("CONFIG_PATH")
cfg, err := config.Load(cfgPath)
if err != nil {
return config.Config{}, fmt.Errorf("load config: %w", err)
}
return cfg, nil
}
func provideLogger(cfg config.Config) *slog.Logger {
logger.Init(cfg.Log.Level, cfg.Log.Format)
return logger.L
@@ -246,7 +110,12 @@ func provideContainerService(lc fx.Lifecycle, log *slog.Logger, cfg config.Confi
if err != nil {
return nil, err
}
lc.Append(fx.Hook{OnStop: func(_ context.Context) error { cleanup(); return nil }})
lc.Append(fx.Hook{
OnStop: func(_ context.Context) error {
cleanup()
return nil
},
})
return svc, nil
}
@@ -255,25 +124,39 @@ func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) {
if err != nil {
return nil, fmt.Errorf("db connect: %w", err)
}
lc.Append(fx.Hook{OnStop: func(_ context.Context) error { conn.Close(); return nil }})
lc.Append(fx.Hook{
OnStop: func(_ context.Context) error {
conn.Close()
return nil
},
})
return conn, nil
}
func provideDBQueries(conn *pgxpool.Pool) *dbsqlc.Queries { return dbsqlc.New(conn) }
func provideDBQueries(conn *pgxpool.Pool) *dbsqlc.Queries {
return dbsqlc.New(conn)
}
func provideWorkspaceManager(log *slog.Logger, service ctr.Service, cfg config.Config, conn *pgxpool.Pool) *workspace.Manager {
return workspace.NewManager(log, service, cfg.Workspace, cfg.Containerd.Namespace, conn)
}
func provideMemoryLLM(modelsService *models.Service, settingsService *settings.Service, queries *dbsqlc.Queries, log *slog.Logger) memprovider.LLM {
return &lazyLLMClient{modelsService: modelsService, settingsService: settingsService, queries: queries, timeout: 30 * time.Second, logger: log}
return &lazyLLMClient{
modelsService: modelsService,
settingsService: settingsService,
queries: queries,
timeout: 30 * time.Second,
logger: log,
}
}
func provideMemoryProviderRegistry(log *slog.Logger, llm memprovider.LLM, chatService *conversation.Service, accountService *accounts.Service, manager *workspace.Manager, queries *dbsqlc.Queries, cfg config.Config) *memprovider.Registry {
registry := memprovider.NewRegistry(log)
builtinRuntime := handlers.NewBuiltinMemoryRuntime(manager)
fileRuntime := handlers.NewBuiltinMemoryRuntime(manager)
fileStore := storefs.New(log, manager)
registry.RegisterFactory(string(memprovider.ProviderBuiltin), func(_ string, providerConfig map[string]any) (memprovider.Provider, error) {
runtime, err := membuiltin.NewBuiltinRuntimeFromConfig(log, providerConfig, builtinRuntime, fileStore, queries, cfg)
runtime, err := membuiltin.NewBuiltinRuntimeFromConfig(log, providerConfig, fileRuntime, fileStore, queries, cfg)
if err != nil {
return nil, err
}
@@ -282,64 +165,18 @@ func provideMemoryProviderRegistry(log *slog.Logger, llm memprovider.LLM, chatSe
p.ApplyProviderConfig(providerConfig)
return p, nil
})
registry.RegisterFactory(string(memprovider.ProviderMem0), func(_ string, config map[string]any) (memprovider.Provider, error) {
return memmem0.NewMem0Provider(log, config, fileStore)
registry.RegisterFactory(string(memprovider.ProviderMem0), func(_ string, providerConfig map[string]any) (memprovider.Provider, error) {
return memmem0.NewMem0Provider(log, providerConfig, fileStore)
})
registry.RegisterFactory(string(memprovider.ProviderOpenViking), func(_ string, config map[string]any) (memprovider.Provider, error) {
return memopenviking.NewOpenVikingProvider(log, config)
registry.RegisterFactory(string(memprovider.ProviderOpenViking), func(_ string, providerConfig map[string]any) (memprovider.Provider, error) {
return memopenviking.NewOpenVikingProvider(log, providerConfig)
})
defaultProvider := membuiltin.NewBuiltinProvider(log, builtinRuntime, chatService, accountService)
defaultProvider := membuiltin.NewBuiltinProvider(log, fileRuntime, chatService, accountService)
defaultProvider.SetLLM(llm)
registry.Register("__builtin_default__", defaultProvider)
return registry
}
func startRegistrySync(lc fx.Lifecycle, log *slog.Logger, cfg config.Config, queries *dbsqlc.Queries) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
defs, err := registry.Load(log, cfg.Registry.ProvidersPath())
if err != nil {
log.Warn("registry: failed to load provider definitions", slog.Any("error", err))
return nil
}
if len(defs) == 0 {
return nil
}
return registry.Sync(ctx, log, queries, defs)
},
})
}
func startMemoryProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, mpService *memprovider.Service, registry *memprovider.Registry) {
mpService.SetRegistry(registry)
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
resp, err := mpService.EnsureDefault(ctx)
if err != nil {
log.Warn("failed to ensure default memory provider", slog.Any("error", err))
return nil
}
if _, regErr := registry.Instantiate(resp.ID, resp.Provider, resp.Config); regErr != nil {
log.Warn("failed to instantiate default memory provider", slog.Any("error", regErr))
} else {
log.Info("default memory provider ready", slog.String("id", resp.ID), slog.String("provider", resp.Provider))
}
return nil
},
})
}
func startSearchProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, spService *searchproviders.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := spService.EnsureDefaults(ctx); err != nil {
log.Warn("failed to ensure default search providers", slog.Any("error", err))
}
return nil
},
})
}
func providePipeline() *pipelinepkg.Pipeline {
return pipelinepkg.NewPipeline(pipelinepkg.RenderParams{})
}
@@ -439,37 +276,75 @@ func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *mod
func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService *media.Service) *channel.Registry {
registry := channel.NewRegistry()
tgAdapter := telegram.NewTelegramAdapter(log)
tgAdapter.SetAssetOpener(mediaService)
registry.MustRegister(tgAdapter)
discordAdapter := discord.NewDiscordAdapter(log)
discordAdapter.SetAssetOpener(mediaService)
registry.MustRegister(discordAdapter)
qqAdapter := qq.NewQQAdapter(log)
qqAdapter.SetAssetOpener(mediaService)
registry.MustRegister(qqAdapter)
matrixAdapter := matrix.NewMatrixAdapter(log)
matrixAdapter.SetAssetOpener(mediaService)
registry.MustRegister(matrixAdapter)
feishuAdapter := feishu.NewFeishuAdapter(log)
feishuAdapter.SetAssetOpener(mediaService)
registry.MustRegister(feishuAdapter)
registry.MustRegister(wecom.NewWeComAdapter(log))
dingTalkAdapter := dingtalk.NewDingTalkAdapter(log)
registry.MustRegister(dingTalkAdapter)
registry.MustRegister(wechatoa.NewWeChatOAAdapter(log))
weixinAdapter := weixin.NewWeixinAdapter(log)
weixinAdapter.SetAssetOpener(mediaService)
registry.MustRegister(weixinAdapter)
registry.MustRegister(local.NewWebAdapter(hub))
// Misskey
registry.MustRegister(misskey.NewMisskeyAdapter(log))
return registry
}
func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *local.RouteHub, routeService *route.DBService, sessionService *sessionpkg.Service, msgService *message.DBService, resolver *flow.Resolver, identityService *identities.Service, botService *bots.Service, aclService *acl.Service, policyService *policy.Service, bindService *bind.Service, mediaService *media.Service, ttsService *ttspkg.Service, settingsService *settings.Service, scheduleService *schedule.Service, mcpConnService *mcp.ConnectionService, modelsService *models.Service, providersService *providers.Service, memProvService *memprovider.Service, searchProvService *searchproviders.Service, browserCtxService *browsercontexts.Service, emailService *emailpkg.Service, emailOutboxService *emailpkg.OutboxService, heartbeatService *heartbeat.Service, queries *dbsqlc.Queries, containerdHandler *handlers.ContainerdHandler, manager *workspace.Manager, pipeline *pipelinepkg.Pipeline, eventStore *pipelinepkg.EventStore, discussDriver *pipelinepkg.DiscussDriver, rc *boot.RuntimeConfig) *inbound.ChannelInboundProcessor {
func provideChannelRouter(
log *slog.Logger,
registry *channel.Registry,
hub *local.RouteHub,
routeService *route.DBService,
sessionService *sessionpkg.Service,
msgService *message.DBService,
resolver *flow.Resolver,
identityService *identities.Service,
botService *bots.Service,
aclService *acl.Service,
policyService *policy.Service,
bindService *bind.Service,
mediaService *media.Service,
ttsService *ttspkg.Service,
settingsService *settings.Service,
scheduleService *schedule.Service,
mcpConnService *mcp.ConnectionService,
modelsService *models.Service,
providersService *providers.Service,
memProvService *memprovider.Service,
searchProvService *searchproviders.Service,
browserCtxService *browsercontexts.Service,
emailService *emailpkg.Service,
emailOutboxService *emailpkg.OutboxService,
heartbeatService *heartbeat.Service,
queries *dbsqlc.Queries,
containerdHandler *handlers.ContainerdHandler,
manager *workspace.Manager,
pipeline *pipelinepkg.Pipeline,
eventStore *pipelinepkg.EventStore,
discussDriver *pipelinepkg.DiscussDriver,
rc *boot.RuntimeConfig,
) *inbound.ChannelInboundProcessor {
adapter, ok := registry.Get(qq.Type)
if !ok {
panic("qq adapter not registered")
@@ -480,6 +355,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc
}
qqAdapter.SetChannelIdentityResolver(identityService)
qqAdapter.SetRouteResolver(routeService)
processor := inbound.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, policyService, bindService, rc.JwtSecret, 5*time.Minute)
processor.SetSessionEnsurer(&sessionEnsurerAdapter{svc: sessionService})
processor.SetPipeline(pipeline, eventStore, discussDriver)
@@ -548,7 +424,7 @@ func provideOAuthService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.C
if strings.HasPrefix(host, ":") {
host = "localhost" + host
}
callbackURL := "http://" + host + "/oauth/mcp/callback"
callbackURL := "http://" + host + "/api/oauth/mcp/callback"
return mcp.NewOAuthService(log, queries, callbackURL)
}
@@ -597,8 +473,8 @@ func provideMemoryHandler(log *slog.Logger, botService *bots.Service, accountSer
return h
}
func provideMemohAuthHandler(log *slog.Logger, accountService *accounts.Service, rc *boot.RuntimeConfig) *memohAuthHandler {
return &memohAuthHandler{inner: handlers.NewAuthHandler(log, accountService, rc.JwtSecret, rc.JwtExpiresIn)}
func provideAuthHandler(log *slog.Logger, accountService *accounts.Service, rc *boot.RuntimeConfig) *handlers.AuthHandler {
return handlers.NewAuthHandler(log, accountService, rc.JwtSecret, rc.JwtExpiresIn)
}
func provideMessageHandler(log *slog.Logger, chatService *conversation.Service, msgService *message.DBService, mediaService *media.Service, botService *bots.Service, accountService *accounts.Service, hub *event.Hub) *handlers.MessageHandler {
@@ -611,13 +487,6 @@ func provideSessionHandler(log *slog.Logger, sessionService *sessionpkg.Service,
return handlers.NewSessionHandler(log, sessionService, botService, accountService)
}
type memohAuthHandler struct{ inner *handlers.AuthHandler }
func (h *memohAuthHandler) Register(e *echo.Echo) {
e.POST("/api/auth/login", h.inner.Login)
e.POST("/api/auth/refresh", h.inner.Refresh)
}
func provideMediaService(log *slog.Logger, manager *workspace.Manager, cfg config.Config) *media.Service {
primary := containerfs.New(manager)
dataRoot := cfg.Workspace.DataRoot
@@ -641,222 +510,6 @@ func provideWebHandler(channelManager *channel.Manager, channelStore *channel.St
return h
}
type serverParams struct {
fx.In
Logger *slog.Logger
RuntimeConfig *boot.RuntimeConfig
Config config.Config
ServerHandlers []server.Handler `group:"server_handlers"`
ContainerdHandler *handlers.ContainerdHandler
}
type memohServer struct {
echo *echo.Echo
addr string
}
var (
memohJWTExactSkipPaths = map[string]struct{}{
"/": {},
"/ping": {},
"/health": {},
"/api/swagger.json": {},
"/api/auth/login": {},
"/logo.png": {},
"/channels/telegram.webp": {},
"/channels/feishu.png": {},
}
memohJWTPrefixSkipPaths = []string{
"/assets/",
"/api/docs",
"/channels/feishu/webhook/",
"/email/mailgun/webhook/",
"/email/oauth/callback",
}
memohSPABackendPrefixes = []string{
"/api",
"/auth",
"/channels",
"/containers",
"/users",
"/bots",
"/models",
"/providers",
"/search_providers",
"/email-providers",
"/email",
"/settings",
"/memory",
"/message",
"/mcp",
"/schedule",
"/bind",
"/preauth",
"/ping",
"/health",
}
memohAPIRewriteBypassExact = map[string]struct{}{
"/api/swagger.json": {},
}
memohAPIRewriteBypassPrefixes = []string{
"/api/docs",
"/api/auth/",
}
)
func (s *memohServer) Start() error { return s.echo.Start(s.addr) }
func (s *memohServer) Stop(ctx context.Context) error { return s.echo.Shutdown(ctx) }
func provideServer(params serverParams) *memohServer {
allHandlers := make([]server.Handler, 0, len(params.ServerHandlers)+1)
allHandlers = append(allHandlers, params.ServerHandlers...)
allHandlers = append(allHandlers, params.ContainerdHandler)
addr := params.RuntimeConfig.ServerAddr
if addr == "" {
addr = ":8080"
}
e := echo.New()
e.HideBanner = true
e.Pre(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
rewriteAPIPathForMemoh(c.Request())
return next(c)
}
})
e.Use(middleware.Recover())
e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
LogStatus: true,
LogURI: true,
LogMethod: true,
LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
params.Logger.Info("request",
slog.String("method", v.Method),
slog.String("uri", v.URI),
slog.Int("status", v.Status),
slog.Duration("latency", v.Latency),
slog.String("remote_ip", c.RealIP()),
)
return nil
},
}))
e.Use(auth.JWTMiddleware(params.Config.Auth.JWTSecret, func(c echo.Context) bool {
return shouldSkipJWTForMemoh(c.Request().URL.Path)
}))
for _, h := range allHandlers {
if h != nil {
h.Register(e)
}
}
return &memohServer{echo: e, addr: addr}
}
func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service) {
lc.Append(fx.Hook{OnStart: func(ctx context.Context) error { return scheduleService.Bootstrap(ctx) }})
}
func startHeartbeatService(lc fx.Lifecycle, heartbeatService *heartbeat.Service) {
lc.Append(fx.Hook{OnStart: func(ctx context.Context) error { return heartbeatService.Bootstrap(ctx) }})
}
func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager) {
ctx, cancel := context.WithCancel(context.Background())
lc.Append(fx.Hook{
OnStart: func(_ context.Context) error { channelManager.Start(ctx); return nil },
OnStop: func(stopCtx context.Context) error { cancel(); return channelManager.Shutdown(stopCtx) },
})
}
func startContainerReconciliation(lc fx.Lifecycle, manager *workspace.Manager, _ *handlers.ContainerdHandler, _ *mcp.ToolGatewayService) {
lc.Append(fx.Hook{OnStart: func(ctx context.Context) error { go manager.ReconcileContainers(ctx); return nil }})
}
func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *memohServer, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, _ *handlers.ContainerdHandler, manager *workspace.Manager, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService, channelManager *channel.Manager, modelsService *models.Service) {
fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo())
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := ensureAdminUser(ctx, logger, queries, cfg); err != nil {
return err
}
botService.SetContainerLifecycle(manager)
botService.SetContainerReachability(func(ctx context.Context, botID string) error {
_, err := manager.MCPClient(ctx, botID)
return err
})
botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(mcpchecker.NewChecker(logger, mcpConnService, toolGateway)))
botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(channelchecker.NewChecker(logger, channelManager)))
botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(modelchecker.NewChecker(logger, modelchecker.NewQueriesLookup(queries), modelsService)))
go func() {
if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("server failed", slog.Any("error", err))
_ = shutdowner.Shutdown()
}
}()
return nil
},
OnStop: func(ctx context.Context) error {
if err := srv.Stop(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server stop: %w", err)
}
return nil
},
})
}
func shouldSkipJWTForMemoh(path string) bool {
if _, ok := memohJWTExactSkipPaths[path]; ok {
return true
}
if hasAnyPrefix(path, memohJWTPrefixSkipPaths) {
return true
}
// Treat non-backend, extension-less paths as SPA routes (e.g. /chat, /settings/profile).
return shouldServeSPARouteForMemoh(path)
}
func shouldServeSPARouteForMemoh(path string) bool {
if path == "" || path == "/" {
return true
}
if strings.Contains(path, ".") {
return false
}
if hasAnyPrefix(path, memohSPABackendPrefixes) {
return false
}
return true
}
func rewriteAPIPathForMemoh(r *http.Request) {
if r == nil || r.URL == nil {
return
}
path := r.URL.Path
if !strings.HasPrefix(path, "/api/") {
return
}
if _, ok := memohAPIRewriteBypassExact[path]; ok {
return
}
if hasAnyPrefix(path, memohAPIRewriteBypassPrefixes) {
return
}
rewritten := strings.TrimPrefix(path, "/api")
if rewritten == "" {
rewritten = "/"
}
r.URL.Path = rewritten
}
func hasAnyPrefix(path string, prefixes []string) bool {
for _, prefix := range prefixes {
if strings.HasPrefix(path, prefix) {
return true
}
}
return false
}
func provideTtsRegistry(log *slog.Logger) *ttspkg.Registry {
reg := ttspkg.NewRegistry()
reg.Register(ttsedge.NewEdgeAdapter(log))
@@ -895,8 +548,6 @@ func startBackgroundTaskCleanup(lc fx.Lifecycle, mgr *background.Manager) {
})
}
// settingsTtsModelResolver adapts settings.Service to the ttsModelResolver interface
// expected by ChannelInboundProcessor and LocalChannelHandler.
type sessionEnsurerAdapter struct {
svc *sessionpkg.Service
}
@@ -945,8 +596,7 @@ func provideEmailRegistry(log *slog.Logger, tokenStore *emailpkg.DBOAuthTokenSto
return reg
}
func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *providers.Service {
_ = cfg
func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, _ config.Config) *providers.Service {
return providers.NewService(log, queries, defaultProviderOAuthCallbackURL())
}
@@ -986,7 +636,162 @@ func startEmailManager(lc fx.Lifecycle, emailManager *emailpkg.Manager) {
}()
return nil
},
OnStop: func(stopCtx context.Context) error { cancel(); emailManager.Stop(stopCtx); return nil },
OnStop: func(stopCtx context.Context) error {
cancel()
emailManager.Stop(stopCtx)
return nil
},
})
}
type serverParams struct {
fx.In
Logger *slog.Logger
RuntimeConfig *boot.RuntimeConfig
Config config.Config
ServerHandlers []server.Handler `group:"server_handlers"`
ContainerdHandler *handlers.ContainerdHandler
}
func provideServer(params serverParams) *server.Server {
allHandlers := make([]server.Handler, 0, len(params.ServerHandlers)+1)
allHandlers = append(allHandlers, params.ServerHandlers...)
allHandlers = append(allHandlers, params.ContainerdHandler)
return server.NewServer(params.Logger, params.RuntimeConfig.ServerAddr, params.Config.Auth.JWTSecret, allHandlers...)
}
func startRegistrySync(lc fx.Lifecycle, log *slog.Logger, cfg config.Config, queries *dbsqlc.Queries) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
defs, err := registry.Load(log, cfg.Registry.ProvidersPath())
if err != nil {
log.Warn("registry: failed to load provider definitions", slog.Any("error", err))
return nil
}
if len(defs) == 0 {
return nil
}
return registry.Sync(ctx, log, queries, defs)
},
})
}
func startMemoryProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, mpService *memprovider.Service, registry *memprovider.Registry) {
mpService.SetRegistry(registry)
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
resp, err := mpService.EnsureDefault(ctx)
if err != nil {
log.Warn("failed to ensure default memory provider", slog.Any("error", err))
return nil
}
if _, regErr := registry.Instantiate(resp.ID, resp.Provider, resp.Config); regErr != nil {
log.Warn("failed to instantiate default memory provider", slog.Any("error", regErr))
} else {
log.Info("default memory provider ready", slog.String("id", resp.ID), slog.String("provider", resp.Provider))
}
return nil
},
})
}
func startSearchProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, spService *searchproviders.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := spService.EnsureDefaults(ctx); err != nil {
log.Warn("failed to ensure default search providers", slog.Any("error", err))
}
return nil
},
})
}
func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return scheduleService.Bootstrap(ctx)
},
})
}
func startHeartbeatService(lc fx.Lifecycle, heartbeatService *heartbeat.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return heartbeatService.Bootstrap(ctx)
},
})
}
func wireResolverOutbound(resolver *flow.Resolver, channelManager *channel.Manager) {
resolver.SetOutboundFn(func(ctx context.Context, botID, channelType, target, text string) error {
return channelManager.Send(ctx, botID, channel.ChannelType(channelType), channel.SendRequest{
Target: target,
Message: channel.Message{Text: text},
})
})
}
func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager) {
ctx, cancel := context.WithCancel(context.Background())
lc.Append(fx.Hook{
OnStart: func(_ context.Context) error {
channelManager.Start(ctx)
return nil
},
OnStop: func(stopCtx context.Context) error {
cancel()
return channelManager.Shutdown(stopCtx)
},
})
}
func startContainerReconciliation(lc fx.Lifecycle, manager *workspace.Manager, _ *handlers.ContainerdHandler, _ *mcp.ToolGatewayService) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
go manager.ReconcileContainers(ctx)
return nil
},
})
}
func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, _ *handlers.ContainerdHandler, manager *workspace.Manager, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService, channelManager *channel.Manager, modelsService *models.Service) {
fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo())
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := ensureAdminUser(ctx, logger, queries, cfg); err != nil {
return err
}
botService.SetContainerLifecycle(manager)
botService.SetContainerReachability(func(ctx context.Context, botID string) error {
_, err := manager.MCPClient(ctx, botID)
return err
})
botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(
mcpchecker.NewChecker(logger, mcpConnService, toolGateway),
))
botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(
channelchecker.NewChecker(logger, channelManager),
))
botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(
modelchecker.NewChecker(logger, modelchecker.NewQueriesLookup(queries), modelsService),
))
go func() {
if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("server failed", slog.Any("error", err))
_ = shutdowner.Shutdown()
}
}()
return nil
},
OnStop: func(ctx context.Context) error {
if err := srv.Stop(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server stop: %w", err)
}
return nil
},
})
}
@@ -1001,6 +806,7 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer
if count > 0 {
return nil
}
username := strings.TrimSpace(cfg.Admin.Username)
password := strings.TrimSpace(cfg.Admin.Password)
email := strings.TrimSpace(cfg.Admin.Email)
@@ -1010,24 +816,37 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer
if password == "change-your-password-here" {
log.Warn("admin password uses default placeholder; please update config.toml")
}
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user, err := queries.CreateUser(ctx, dbsqlc.CreateUserParams{IsActive: true, Metadata: []byte("{}")})
user, err := queries.CreateUser(ctx, dbsqlc.CreateUserParams{
IsActive: true,
Metadata: []byte("{}"),
})
if err != nil {
return fmt.Errorf("create admin user: %w", err)
}
emailValue := pgtype.Text{Valid: false}
if email != "" {
emailValue = pgtype.Text{String: email, Valid: true}
}
displayName := pgtype.Text{String: username, Valid: true}
dataRoot := pgtype.Text{String: cfg.Workspace.DataRoot, Valid: cfg.Workspace.DataRoot != ""}
_, err = queries.CreateAccount(ctx, dbsqlc.CreateAccountParams{
UserID: user.ID, Username: pgtype.Text{String: username, Valid: true}, Email: emailValue,
PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, Role: "admin",
DisplayName: displayName, AvatarUrl: pgtype.Text{Valid: false}, IsActive: true, DataRoot: dataRoot,
UserID: user.ID,
Username: pgtype.Text{String: username, Valid: true},
Email: emailValue,
PasswordHash: pgtype.Text{String: string(hashed), Valid: true},
Role: "admin",
DisplayName: displayName,
AvatarUrl: pgtype.Text{Valid: false},
IsActive: true,
DataRoot: dataRoot,
})
if err != nil {
return err
@@ -1081,11 +900,9 @@ func (c *lazyLLMClient) resolve(ctx context.Context, botID string) (memprovider.
return nil, errors.New("models service not configured")
}
// Try to use the bot's configured chat model for memory operations.
chatModelID := ""
if c.settingsService != nil && strings.TrimSpace(botID) != "" {
if botSettings, err := c.settingsService.GetBot(ctx, botID); err == nil {
// Prefer compaction model (smaller/cheaper), then chat model.
if id := strings.TrimSpace(botSettings.CompactionModelID); id != "" {
chatModelID = id
} else if id := strings.TrimSpace(botSettings.ChatModelID); id != "" {
@@ -1107,7 +924,9 @@ func (c *lazyLLMClient) resolve(ctx context.Context, botID string) (memprovider.
}), nil
}
type skillLoaderAdapter struct{ handler *handlers.ContainerdHandler }
type skillLoaderAdapter struct {
handler *handlers.ContainerdHandler
}
func (a *skillLoaderAdapter) LoadSkills(ctx context.Context, botID string) ([]flow.SkillEntry, error) {
items, err := a.handler.LoadSkills(ctx, botID)
@@ -1116,12 +935,19 @@ func (a *skillLoaderAdapter) LoadSkills(ctx context.Context, botID string) ([]fl
}
entries := make([]flow.SkillEntry, len(items))
for i, item := range items {
entries[i] = flow.SkillEntry{Name: item.Name, Description: item.Description, Content: item.Content, Metadata: item.Metadata}
entries[i] = flow.SkillEntry{
Name: item.Name,
Description: item.Description,
Content: item.Content,
Metadata: item.Metadata,
}
}
return entries, nil
}
type mediaAssetResolverAdapter struct{ media *media.Service }
type mediaAssetResolverAdapter struct {
media *media.Service
}
func (a *mediaAssetResolverAdapter) Stat(ctx context.Context, botID, contentHash string) (media.Asset, error) {
if a == nil || a.media == nil {
@@ -1165,7 +991,9 @@ func (a *mediaAssetResolverAdapter) IngestContainerFile(ctx context.Context, bot
return a.media.IngestContainerFile(ctx, botID, containerPath)
}
type gatewayAssetLoaderAdapter struct{ media *media.Service }
type gatewayAssetLoaderAdapter struct {
media *media.Service
}
func (a *gatewayAssetLoaderAdapter) OpenForGateway(ctx context.Context, botID, contentHash string) (io.ReadCloser, string, error) {
if a == nil || a.media == nil {
@@ -1178,7 +1006,6 @@ func (a *gatewayAssetLoaderAdapter) OpenForGateway(ctx context.Context, botID, c
return reader, strings.TrimSpace(asset.Mime), nil
}
// commandSkillLoaderAdapter bridges handlers.ContainerdHandler to command.SkillLoader.
type commandSkillLoaderAdapter struct {
handler *handlers.ContainerdHandler
}
@@ -1195,7 +1022,6 @@ func (a *commandSkillLoaderAdapter) LoadSkills(ctx context.Context, botID string
return skills, nil
}
// commandContainerFSAdapter bridges workspace.Manager to command.ContainerFS.
type commandContainerFSAdapter struct {
manager *workspace.Manager
}
+5 -1300
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -4,5 +4,5 @@ dir = "{{cwd}}"
[tasks.start]
alias = "dev"
description = "Start server"
run = "go run cmd/agent/main.go"
run = "go run cmd/agent"
depends = ["//:go-install"]
+160
View File
@@ -0,0 +1,160 @@
package main
import (
"log/slog"
"go.uber.org/fx"
"go.uber.org/fx/fxevent"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/acl"
"github.com/memohai/memoh/internal/bind"
"github.com/memohai/memoh/internal/boot"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/browsercontexts"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/adapters/local"
"github.com/memohai/memoh/internal/channel/adapters/weixin"
"github.com/memohai/memoh/internal/channel/identities"
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
emailpkg "github.com/memohai/memoh/internal/email"
"github.com/memohai/memoh/internal/handlers"
"github.com/memohai/memoh/internal/heartbeat"
"github.com/memohai/memoh/internal/mcp"
memprovider "github.com/memohai/memoh/internal/memory/adapters"
"github.com/memohai/memoh/internal/message/event"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/policy"
"github.com/memohai/memoh/internal/schedule"
"github.com/memohai/memoh/internal/searchproviders"
"github.com/memohai/memoh/internal/settings"
ttspkg "github.com/memohai/memoh/internal/tts"
)
func runServe() {
fx.New(options()).Run()
}
func options() fx.Option {
return fx.Options(
fx.Provide(
provideConfig,
boot.ProvideRuntimeConfig,
provideLogger,
provideContainerService,
provideDBConn,
provideDBQueries,
provideWorkspaceManager,
provideMemoryLLM,
memprovider.NewService,
provideMemoryProviderRegistry,
models.NewService,
bots.NewService,
accounts.NewService,
acl.NewService,
settings.NewService,
provideProvidersService,
searchproviders.NewService,
browsercontexts.NewService,
policy.NewService,
mcp.NewConnectionService,
conversation.NewService,
identities.NewService,
bind.NewService,
event.NewHub,
provideTtsRegistry,
ttspkg.NewService,
provideTtsTempStore,
emailpkg.NewDBOAuthTokenStore,
provideEmailRegistry,
emailpkg.NewService,
emailpkg.NewOutboxService,
provideEmailChatGateway,
provideEmailTrigger,
emailpkg.NewManager,
provideRouteService,
provideSessionService,
provideMessageService,
provideMediaService,
providePipeline,
provideEventStore,
provideDiscussDriver,
local.NewRouteHub,
provideChannelRegistry,
channel.NewStore,
provideChannelRouter,
provideChannelManager,
provideChannelLifecycleService,
provideAgent,
provideChatResolver,
provideScheduleTriggerer,
provideHeartbeatSessionCreator,
provideScheduleSessionCreator,
schedule.NewService,
provideHeartbeatTriggerer,
heartbeat.NewService,
compaction.NewService,
provideContainerdHandler,
provideFederationGateway,
provideToolGatewayService,
provideBackgroundManager,
provideToolProviders,
provideServerHandler(handlers.NewPingHandler),
provideServerHandler(provideAuthHandler),
provideServerHandler(provideMemoryHandler),
provideServerHandler(provideMessageHandler),
provideServerHandler(provideSessionHandler),
provideServerHandler(handlers.NewSwaggerHandler),
provideServerHandler(handlers.NewProvidersHandler),
provideServerHandler(handlers.NewProviderOAuthHandler),
provideServerHandler(handlers.NewSearchProvidersHandler),
provideServerHandler(handlers.NewModelsHandler),
provideServerHandler(handlers.NewSettingsHandler),
provideServerHandler(handlers.NewACLHandler),
provideServerHandler(handlers.NewBindHandler),
provideServerHandler(handlers.NewScheduleHandler),
provideServerHandler(handlers.NewHeartbeatHandler),
provideServerHandler(handlers.NewCompactionHandler),
provideServerHandler(handlers.NewChannelHandler),
provideServerHandler(channel.NewWebhookServerHandler),
provideServerHandler(weixin.NewQRServerHandler),
provideServerHandler(provideUsersHandler),
provideServerHandler(handlers.NewMemoryProvidersHandler),
provideServerHandler(handlers.NewSpeechHandler),
provideServerHandler(handlers.NewBotTtsHandler),
provideServerHandler(handlers.NewEmailProvidersHandler),
provideServerHandler(handlers.NewEmailBindingsHandler),
provideServerHandler(handlers.NewEmailOutboxHandler),
provideServerHandler(handlers.NewEmailWebhookHandler),
provideServerHandler(provideEmailOAuthHandler),
provideServerHandler(handlers.NewMCPHandler),
provideServerHandler(handlers.NewMCPOAuthHandler),
provideOAuthService,
provideServerHandler(handlers.NewTokenUsageHandler),
provideServerHandler(handlers.NewSessionInfoHandler),
provideServerHandler(handlers.NewBrowserContextsHandler),
provideServerHandler(handlers.NewSupermarketHandler),
provideServerHandler(provideWebHandler),
provideServer,
),
fx.Invoke(
injectToolProviders,
startRegistrySync,
startMemoryProviderBootstrap,
startSearchProviderBootstrap,
startScheduleService,
startHeartbeatService,
wireResolverOutbound,
startChannelManager,
startEmailManager,
startContainerReconciliation,
startBackgroundTaskCleanup,
startTtsTempStoreCleanup,
startServer,
),
fx.WithLogger(func(logger *slog.Logger) fxevent.Logger {
return &fxevent.SlogLogger{Logger: logger.With(slog.String("component", "fx"))}
}),
)
}
+58
View File
@@ -0,0 +1,58 @@
package main
import (
"fmt"
"io/fs"
"log/slog"
"os"
dbembed "github.com/memohai/memoh/db"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/logger"
"github.com/memohai/memoh/internal/version"
)
func provideConfig() (config.Config, error) {
cfgPath := os.Getenv("CONFIG_PATH")
cfg, err := config.Load(cfgPath)
if err != nil {
return config.Config{}, fmt.Errorf("load config: %w", err)
}
return cfg, nil
}
func migrationsFS() fs.FS {
sub, err := fs.Sub(dbembed.MigrationsFS, "migrations")
if err != nil {
panic(fmt.Sprintf("embedded migrations: %v", err))
}
return sub
}
func runMigrateCommand(args []string) error {
cfg, err := provideConfig()
if err != nil {
return fmt.Errorf("config: %w", err)
}
logger.Init(cfg.Log.Level, cfg.Log.Format)
log := logger.L
migrateCmd := args[0]
var migrateArgs []string
if len(args) > 1 {
migrateArgs = args[1:]
}
if err := db.RunMigrate(log, cfg.Postgres, migrationsFS(), migrateCmd, migrateArgs); err != nil {
log.Error("migration failed", slog.Any("error", err))
return err
}
return nil
}
func runVersion() error {
fmt.Printf("memoh-server %s\n", version.GetInfo())
return nil
}
+62
View File
@@ -0,0 +1,62 @@
package main
import (
"context"
"fmt"
"time"
"github.com/spf13/cobra"
"github.com/memohai/memoh/internal/tui"
)
func newChatCommand(ctx *cliContext) *cobra.Command {
var botID string
var sessionID string
var message string
cmd := &cobra.Command{
Use: "chat",
Short: "Send one chat message and stream the reply",
RunE: func(_ *cobra.Command, _ []string) error {
client := tui.NewClient(ctx.state.ServerURL, ctx.state.Token)
if sessionID == "" {
requestCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
sess, err := client.CreateSession(requestCtx, botID, message)
if err != nil {
return err
}
sessionID = sess.ID
fmt.Printf("session: %s\n", sessionID)
}
streamCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
return client.StreamChat(streamCtx, tui.ChatRequest{
BotID: botID,
SessionID: sessionID,
Text: message,
}, func(event tui.ChatEvent) error {
switch event.Type {
case "start":
fmt.Println("[start]")
case "message":
fmt.Println(tui.RenderUIMessage(event.Data))
case "error":
fmt.Println("[error]", event.Message)
case "end":
fmt.Println("[end]")
}
return nil
})
},
}
cmd.Flags().StringVar(&botID, "bot", "", "Target bot ID")
cmd.Flags().StringVar(&sessionID, "session", "", "Existing session ID")
cmd.Flags().StringVar(&message, "message", "", "User message text")
_ = cmd.MarkFlagRequired("bot")
_ = cmd.MarkFlagRequired("message")
return cmd
}
+51
View File
@@ -0,0 +1,51 @@
package main
import (
"context"
"fmt"
"time"
"github.com/spf13/cobra"
"github.com/memohai/memoh/internal/tui"
)
func newLoginCommand(ctx *cliContext) *cobra.Command {
var username string
var password string
cmd := &cobra.Command{
Use: "login",
Short: "Authenticate and persist a local access token",
RunE: func(_ *cobra.Command, _ []string) error {
client := tui.NewClient(ctx.state.ServerURL, "")
requestCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
resp, err := client.Login(requestCtx, username, password)
if err != nil {
return err
}
next := ctx.state
next.ServerURL = client.BaseURL
next.Token = resp.AccessToken
next.Username = resp.Username
if parsed, err := time.Parse(time.RFC3339, resp.ExpiresAt); err == nil {
next.ExpiresAt = parsed
}
if err := tui.SaveState(next); err != nil {
return err
}
fmt.Printf("Logged in as %s against %s\n", resp.Username, client.BaseURL)
return nil
},
}
cmd.Flags().StringVar(&username, "username", "", "Account username")
cmd.Flags().StringVar(&password, "password", "", "Account password")
_ = cmd.MarkFlagRequired("username")
_ = cmd.MarkFlagRequired("password")
return cmd
}
+1 -38
View File
@@ -2,47 +2,10 @@ package main
import (
"os"
"github.com/spf13/cobra"
)
func main() {
rootCmd := &cobra.Command{
Use: "memoh",
Short: "Memoh unified binary",
RunE: func(_ *cobra.Command, _ []string) error {
runServe()
return nil
},
}
rootCmd.AddCommand(&cobra.Command{
Use: "serve",
Short: "Start the server",
RunE: func(_ *cobra.Command, _ []string) error {
runServe()
return nil
},
})
rootCmd.AddCommand(&cobra.Command{
Use: "migrate <up|down|version|force N>",
Short: "Run database migrations",
Args: cobra.MinimumNArgs(1),
RunE: func(_ *cobra.Command, args []string) error {
return runMigrate(args)
},
})
rootCmd.AddCommand(&cobra.Command{
Use: "version",
Short: "Print version information",
RunE: func(_ *cobra.Command, _ []string) error {
return runVersion()
},
})
if err := rootCmd.Execute(); err != nil {
if err := newRootCommand().Execute(); err != nil {
os.Exit(1)
}
}
+9 -36
View File
@@ -1,41 +1,14 @@
package main
import (
"fmt"
"io/fs"
"log/slog"
import "github.com/spf13/cobra"
dbembed "github.com/memohai/memoh/db"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/logger"
)
func migrationsFS() fs.FS {
sub, err := fs.Sub(dbembed.MigrationsFS, "migrations")
if err != nil {
panic(fmt.Sprintf("embedded migrations: %v", err))
func newMigrateCommand() *cobra.Command {
return &cobra.Command{
Use: "migrate <up|down|version|force N>",
Short: "Run database migrations",
Args: cobra.MinimumNArgs(1),
RunE: func(_ *cobra.Command, args []string) error {
return runMigrate(args)
},
}
return sub
}
func runMigrate(args []string) error {
cfg, err := provideConfig()
if err != nil {
return fmt.Errorf("config: %w", err)
}
logger.Init(cfg.Log.Level, cfg.Log.Format)
log := logger.L
migrateCmd := args[0]
var migrateArgs []string
if len(args) > 1 {
migrateArgs = args[1:]
}
if err := db.RunMigrate(log, cfg.Postgres, migrationsFS(), migrateCmd, migrateArgs); err != nil {
log.Error("migration failed", slog.Any("error", err))
return err
}
return nil
}
+69
View File
@@ -0,0 +1,69 @@
package main
import (
"fmt"
tea "github.com/charmbracelet/bubbletea"
"github.com/spf13/cobra"
"github.com/memohai/memoh/internal/tui"
)
type cliContext struct {
state tui.State
server string
}
func newRootCommand() *cobra.Command {
ctx := &cliContext{}
rootCmd := &cobra.Command{
Use: "memoh",
Short: "Memoh terminal operator CLI",
RunE: func(_ *cobra.Command, _ []string) error {
return runTUI(ctx)
},
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
state, err := tui.LoadState()
if err != nil {
return err
}
ctx.state = state
if ctx.server != "" {
ctx.state.ServerURL = tui.NormalizeServerURL(ctx.server)
}
return nil
},
}
rootCmd.PersistentFlags().StringVar(&ctx.server, "server", "", "Memoh server URL")
rootCmd.AddCommand(newMigrateCommand())
rootCmd.AddCommand(newLoginCommand(ctx))
rootCmd.AddCommand(newChatCommand(ctx))
rootCmd.AddCommand(&cobra.Command{
Use: "tui",
Short: "Open the terminal UI",
RunE: func(_ *cobra.Command, _ []string) error {
return runTUI(ctx)
},
})
rootCmd.AddCommand(&cobra.Command{
Use: "version",
Short: "Print version information",
RunE: func(_ *cobra.Command, _ []string) error {
return runVersion()
},
})
return rootCmd
}
func runTUI(ctx *cliContext) error {
model := tui.NewTUIModel(ctx.state)
program := tea.NewProgram(model, tea.WithAltScreen())
if _, err := program.Run(); err != nil {
return fmt.Errorf("run tui: %w", err)
}
return nil
}
+63
View File
@@ -0,0 +1,63 @@
package main
import (
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
dbembed "github.com/memohai/memoh/db"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/logger"
"github.com/memohai/memoh/internal/version"
)
func provideConfig() (config.Config, error) {
cfgPath := os.Getenv("CONFIG_PATH")
cfg, err := config.Load(cfgPath)
if err != nil {
return config.Config{}, fmt.Errorf("load config: %w", err)
}
return cfg, nil
}
func migrationsFS() fs.FS {
sub, err := fs.Sub(dbembed.MigrationsFS, "migrations")
if err != nil {
panic(fmt.Sprintf("embedded migrations: %v", err))
}
return sub
}
func runMigrate(args []string) error {
if len(args) == 0 {
return errors.New("usage: migrate <up|down|version|force N>")
}
cfg, err := provideConfig()
if err != nil {
return fmt.Errorf("config: %w", err)
}
logger.Init(cfg.Log.Level, cfg.Log.Format)
log := logger.L
migrateCmd := args[0]
var migrateArgs []string
if len(args) > 1 {
migrateArgs = args[1:]
}
if err := db.RunMigrate(log, cfg.Postgres, migrationsFS(), migrateCmd, migrateArgs); err != nil {
log.Error("migration failed", slog.Any("error", err))
return err
}
return nil
}
func runVersion() error {
fmt.Printf("memoh %s\n", version.GetInfo())
return nil
}
-12
View File
@@ -1,12 +0,0 @@
package main
import (
"fmt"
"github.com/memohai/memoh/internal/version"
)
func runVersion() error {
fmt.Printf("memoh %s\n", version.GetInfo())
return nil
}