mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
a75fdb8040
Split business executors from federation sources and migrate unified tool/federation transports to the official go-sdk for stricter MCP compliance and safer session lifecycle handling. Add targeted regression tests for accept compatibility, initialization retries, pending cleanup, and include updated swagger artifacts.
395 lines
15 KiB
Go
395 lines
15 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/memohai/memoh/internal/accounts"
|
|
"github.com/memohai/memoh/internal/bind"
|
|
"github.com/memohai/memoh/internal/bots"
|
|
"github.com/memohai/memoh/internal/channel"
|
|
"github.com/memohai/memoh/internal/channel/adapters/feishu"
|
|
"github.com/memohai/memoh/internal/channel/adapters/local"
|
|
"github.com/memohai/memoh/internal/channel/adapters/telegram"
|
|
"github.com/memohai/memoh/internal/channelidentities"
|
|
"github.com/memohai/memoh/internal/chat"
|
|
"github.com/memohai/memoh/internal/config"
|
|
ctr "github.com/memohai/memoh/internal/containerd"
|
|
"github.com/memohai/memoh/internal/db"
|
|
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
|
"github.com/memohai/memoh/internal/embeddings"
|
|
"github.com/memohai/memoh/internal/handlers"
|
|
"github.com/memohai/memoh/internal/logger"
|
|
"github.com/memohai/memoh/internal/mcp"
|
|
mcpmemory "github.com/memohai/memoh/internal/mcp/providers/memory"
|
|
mcpmessage "github.com/memohai/memoh/internal/mcp/providers/message"
|
|
mcpschedule "github.com/memohai/memoh/internal/mcp/providers/schedule"
|
|
mcpfederation "github.com/memohai/memoh/internal/mcp/sources/federation"
|
|
"github.com/memohai/memoh/internal/memory"
|
|
"github.com/memohai/memoh/internal/models"
|
|
"github.com/memohai/memoh/internal/policy"
|
|
"github.com/memohai/memoh/internal/preauth"
|
|
"github.com/memohai/memoh/internal/providers"
|
|
"github.com/memohai/memoh/internal/router"
|
|
"github.com/memohai/memoh/internal/schedule"
|
|
"github.com/memohai/memoh/internal/server"
|
|
"github.com/memohai/memoh/internal/settings"
|
|
"github.com/memohai/memoh/internal/subagent"
|
|
"github.com/memohai/memoh/internal/version"
|
|
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
func main() {
|
|
fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo())
|
|
ctx := context.Background()
|
|
cfgPath := os.Getenv("CONFIG_PATH")
|
|
cfg, err := config.Load(cfgPath)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "load config: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
logger.Init(cfg.Log.Level, cfg.Log.Format)
|
|
|
|
if strings.TrimSpace(cfg.Auth.JWTSecret) == "" {
|
|
logger.Error("jwt secret is required")
|
|
os.Exit(1)
|
|
}
|
|
jwtExpiresIn, err := time.ParseDuration(cfg.Auth.JWTExpiresIn)
|
|
if err != nil {
|
|
logger.Error("invalid jwt expires in", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
|
|
addr := cfg.Server.Addr
|
|
if value := os.Getenv("HTTP_ADDR"); value != "" {
|
|
addr = value
|
|
}
|
|
|
|
socketPath := cfg.Containerd.SocketPath
|
|
if value := os.Getenv("CONTAINERD_SOCKET"); value != "" {
|
|
socketPath = value
|
|
}
|
|
factory := ctr.DefaultClientFactory{SocketPath: socketPath}
|
|
client, err := factory.New(ctx)
|
|
if err != nil {
|
|
logger.Error("connect containerd", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
defer client.Close()
|
|
|
|
service := ctr.NewDefaultService(logger.L, client, cfg.Containerd.Namespace)
|
|
manager := mcp.NewManager(logger.L, service, cfg.MCP)
|
|
|
|
pingHandler := handlers.NewPingHandler(logger.L)
|
|
// containerdHandler is created later after DB services are initialized
|
|
|
|
conn, err := db.Open(ctx, cfg.Postgres)
|
|
if err != nil {
|
|
logger.Error("db connect", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
defer conn.Close()
|
|
manager.WithDB(conn)
|
|
queries := dbsqlc.New(conn)
|
|
modelsService := models.NewService(logger.L, queries)
|
|
botService := bots.NewService(logger.L, queries)
|
|
accountService := accounts.NewService(logger.L, queries)
|
|
|
|
containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, accountService, queries)
|
|
botService.SetContainerLifecycle(containerdHandler)
|
|
|
|
if err := ensureAdminUser(ctx, logger.L, queries, cfg); err != nil {
|
|
logger.Error("ensure admin user", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
|
|
authHandler := handlers.NewAuthHandler(logger.L, accountService, cfg.Auth.JWTSecret, jwtExpiresIn)
|
|
|
|
// Initialize chat resolver after memory service is configured.
|
|
var chatResolver *chat.Resolver
|
|
|
|
// Create LLM client for memory operations (deferred model/provider selection).
|
|
var llmClient memory.LLM = &lazyLLMClient{
|
|
modelsService: modelsService,
|
|
queries: queries,
|
|
timeout: 30 * time.Second,
|
|
logger: logger.L,
|
|
}
|
|
|
|
resolver := embeddings.NewResolver(logger.L, modelsService, queries, 10*time.Second)
|
|
vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
|
|
if err != nil {
|
|
logger.Error("embedding models", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
|
|
textEmbedder := buildTextEmbedder(resolver, textModel, hasEmbeddingModels, logger.L)
|
|
if hasEmbeddingModels && multimodalModel.ModelID == "" {
|
|
logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
|
|
}
|
|
store := buildQdrantStore(logger.L, cfg.Qdrant, vectors, hasEmbeddingModels, textModel.Dimensions)
|
|
|
|
bm25Indexer := memory.NewBM25Indexer(logger.L)
|
|
memoryService := memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, bm25Indexer, textModel.ModelID, multimodalModel.ModelID)
|
|
go func() {
|
|
if err := memoryService.WarmupBM25(ctx, 200); err != nil {
|
|
logger.Warn("bm25 warmup failed", slog.Any("error", err))
|
|
}
|
|
}()
|
|
|
|
// Initialize providers and models handlers
|
|
providersService := providers.NewService(logger.L, queries)
|
|
providersHandler := handlers.NewProvidersHandler(logger.L, providersService, modelsService)
|
|
settingsService := settings.NewService(logger.L, queries)
|
|
settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService, botService, accountService)
|
|
modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService)
|
|
policyService := policy.NewService(logger.L, botService, settingsService)
|
|
chatService := chat.NewService(logger.L, queries)
|
|
memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService, chatService, accountService)
|
|
actorService := channelidentities.NewService(logger.L, queries)
|
|
preauthService := preauth.NewService(queries)
|
|
preauthHandler := handlers.NewPreauthHandler(preauthService, botService, accountService)
|
|
bindService := bind.NewService(logger.L, conn, queries)
|
|
bindHandler := handlers.NewBindHandler(logger.L, bindService)
|
|
mcpConnectionsService := mcp.NewConnectionService(logger.L, queries)
|
|
mcpHandler := handlers.NewMCPHandler(logger.L, mcpConnectionsService, botService, accountService)
|
|
chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, chatService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second)
|
|
chatResolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler})
|
|
embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries)
|
|
swaggerHandler := handlers.NewSwaggerHandler(logger.L)
|
|
chatHandler := handlers.NewChatHandler(logger.L, chatResolver, chatService, botService, accountService)
|
|
channelRegistry := channel.NewRegistry()
|
|
sessionHub := local.NewSessionHub()
|
|
channelRegistry.MustRegister(telegram.NewTelegramAdapter(logger.L))
|
|
channelRegistry.MustRegister(feishu.NewFeishuAdapter(logger.L))
|
|
channelRegistry.MustRegister(local.NewCLIAdapter(sessionHub))
|
|
channelRegistry.MustRegister(local.NewWebAdapter(sessionHub))
|
|
channelService := channel.NewService(queries, channelRegistry)
|
|
channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, chatService, chatResolver, actorService, botService, policyService, preauthService, bindService, cfg.Auth.JWTSecret, 5*time.Minute)
|
|
channelManager := channel.NewManager(logger.L, channelRegistry, channelService, channelRouter)
|
|
if mw := channelRouter.IdentityMiddleware(); mw != nil {
|
|
channelManager.Use(mw)
|
|
}
|
|
channelManager.Start(ctx)
|
|
channelHandler := handlers.NewChannelHandler(channelService, channelRegistry)
|
|
usersHandler := handlers.NewUsersHandler(logger.L, accountService, actorService, botService, chatService, channelService, channelManager, channelRegistry)
|
|
cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, chatService, sessionHub, botService, accountService)
|
|
webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, chatService, sessionHub, botService, accountService)
|
|
scheduleGateway := chat.NewScheduleGateway(chatResolver)
|
|
scheduleService := schedule.NewService(logger.L, queries, scheduleGateway, cfg.Auth.JWTSecret)
|
|
if err := scheduleService.Bootstrap(ctx); err != nil {
|
|
logger.Error("schedule bootstrap", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, accountService)
|
|
subagentService := subagent.NewService(logger.L, queries)
|
|
subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, accountService)
|
|
messageToolExecutor := mcpmessage.NewExecutor(logger.L, channelManager, channelRegistry)
|
|
scheduleToolExecutor := mcpschedule.NewExecutor(logger.L, scheduleService)
|
|
memoryToolExecutor := mcpmemory.NewExecutor(logger.L, memoryService, chatService, accountService)
|
|
federationGateway := handlers.NewMCPFederationGateway(logger.L, containerdHandler)
|
|
federatedToolSource := mcpfederation.NewSource(logger.L, federationGateway, mcpConnectionsService)
|
|
toolGatewayService := mcp.NewToolGatewayService(
|
|
logger.L,
|
|
[]mcp.ToolExecutor{
|
|
messageToolExecutor,
|
|
scheduleToolExecutor,
|
|
memoryToolExecutor,
|
|
},
|
|
[]mcp.ToolSource{
|
|
federatedToolSource,
|
|
},
|
|
)
|
|
containerdHandler.SetToolGatewayService(toolGatewayService)
|
|
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, preauthHandler, bindHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler)
|
|
|
|
if err := srv.Start(); err != nil {
|
|
logger.Error("server failed", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder {
|
|
if !hasModels {
|
|
return nil
|
|
}
|
|
if textModel.ModelID == "" || textModel.Dimensions <= 0 {
|
|
log.Warn("No text embedding model configured. Text embedding features will be limited.")
|
|
return nil
|
|
}
|
|
return &embeddings.ResolverTextEmbedder{
|
|
Resolver: resolver,
|
|
ModelID: textModel.ModelID,
|
|
Dims: textModel.Dimensions,
|
|
}
|
|
}
|
|
|
|
func buildQdrantStore(log *slog.Logger, cfg config.QdrantConfig, vectors map[string]int, hasModels bool, textDims int) *memory.QdrantStore {
|
|
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
|
if hasModels && len(vectors) > 0 {
|
|
store, err := memory.NewQdrantStoreWithVectors(
|
|
log,
|
|
cfg.BaseURL,
|
|
cfg.APIKey,
|
|
cfg.Collection,
|
|
vectors,
|
|
"sparse_hash",
|
|
timeout,
|
|
)
|
|
if err != nil {
|
|
log.Error("qdrant named vectors init", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
return store
|
|
}
|
|
store, err := memory.NewQdrantStore(
|
|
log,
|
|
cfg.BaseURL,
|
|
cfg.APIKey,
|
|
cfg.Collection,
|
|
textDims,
|
|
"sparse_hash",
|
|
timeout,
|
|
)
|
|
if err != nil {
|
|
log.Error("qdrant init", slog.Any("error", err))
|
|
os.Exit(1)
|
|
}
|
|
return store
|
|
}
|
|
|
|
func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error {
|
|
if queries == nil {
|
|
return fmt.Errorf("db queries not configured")
|
|
}
|
|
count, err := queries.CountAccounts(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
return nil
|
|
}
|
|
|
|
username := strings.TrimSpace(cfg.Admin.Username)
|
|
password := strings.TrimSpace(cfg.Admin.Password)
|
|
email := strings.TrimSpace(cfg.Admin.Email)
|
|
if username == "" || password == "" {
|
|
return fmt.Errorf("admin username/password required in config.toml")
|
|
}
|
|
if password == "change-your-password-here" {
|
|
log.Warn("admin password uses default placeholder; please update config.toml")
|
|
}
|
|
|
|
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user, err := queries.CreateUser(ctx, dbsqlc.CreateUserParams{
|
|
IsActive: true,
|
|
Metadata: []byte("{}"),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("create admin user: %w", err)
|
|
}
|
|
|
|
emailValue := pgtype.Text{Valid: false}
|
|
if email != "" {
|
|
emailValue = pgtype.Text{String: email, Valid: true}
|
|
}
|
|
displayName := pgtype.Text{String: username, Valid: true}
|
|
dataRoot := pgtype.Text{String: cfg.MCP.DataRoot, Valid: cfg.MCP.DataRoot != ""}
|
|
|
|
_, err = queries.CreateAccount(ctx, dbsqlc.CreateAccountParams{
|
|
UserID: user.ID,
|
|
Username: pgtype.Text{String: username, Valid: true},
|
|
Email: emailValue,
|
|
PasswordHash: pgtype.Text{String: string(hashed), Valid: true},
|
|
Role: "admin",
|
|
DisplayName: displayName,
|
|
AvatarUrl: pgtype.Text{Valid: false},
|
|
IsActive: true,
|
|
DataRoot: dataRoot,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Info("Admin user created", slog.String("username", username))
|
|
return nil
|
|
}
|
|
|
|
type lazyLLMClient struct {
|
|
modelsService *models.Service
|
|
queries *dbsqlc.Queries
|
|
timeout time.Duration
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func (c *lazyLLMClient) Extract(ctx context.Context, req memory.ExtractRequest) (memory.ExtractResponse, error) {
|
|
client, err := c.resolve(ctx)
|
|
if err != nil {
|
|
return memory.ExtractResponse{}, err
|
|
}
|
|
return client.Extract(ctx, req)
|
|
}
|
|
|
|
func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (memory.DecideResponse, error) {
|
|
client, err := c.resolve(ctx)
|
|
if err != nil {
|
|
return memory.DecideResponse{}, err
|
|
}
|
|
return client.Decide(ctx, req)
|
|
}
|
|
|
|
func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
|
|
client, err := c.resolve(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return client.DetectLanguage(ctx, text)
|
|
}
|
|
|
|
func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
|
|
if c.modelsService == nil || c.queries == nil {
|
|
return nil, fmt.Errorf("models service not configured")
|
|
}
|
|
memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clientType := strings.ToLower(strings.TrimSpace(memoryProvider.ClientType))
|
|
if clientType != "openai" && clientType != "openai-compat" {
|
|
return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
|
|
}
|
|
return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout)
|
|
}
|
|
|
|
// skillLoaderAdapter bridges handlers.ContainerdHandler to chat.SkillLoader.
|
|
type skillLoaderAdapter struct {
|
|
handler *handlers.ContainerdHandler
|
|
}
|
|
|
|
func (a *skillLoaderAdapter) LoadSkills(ctx context.Context, botID string) ([]chat.SkillEntry, error) {
|
|
items, err := a.handler.LoadSkills(ctx, botID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
entries := make([]chat.SkillEntry, len(items))
|
|
for i, item := range items {
|
|
entries[i] = chat.SkillEntry{
|
|
Name: item.Name,
|
|
Description: item.Description,
|
|
Content: item.Content,
|
|
Metadata: item.Metadata,
|
|
}
|
|
}
|
|
return entries, nil
|
|
}
|