// Memoh/cmd/agent/modules/memory.go

package modules
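
// This file wires the memory subsystem: the lazily resolved memory LLM,
// embedding discovery and resolution, the Qdrant vector store, the BM25
// indexer, and the memory service built on top of them.
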
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/memohai/memoh/internal/config"
	dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
	"github.com/memohai/memoh/internal/embeddings"
	"github.com/memohai/memoh/internal/memory"
	"github.com/memohai/memoh/internal/models"
	"go.uber.org/fx"
)

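// MemoryModule groups the memory subsystem's constructors and registers the
// background BM25 warmup to run at application start.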
var MemoryModule = fx.Module(
	"memory",
	fx.Provide(
		provideMemoryLLM,
		provideEmbeddingSetup,
		provideEmbeddingsResolver,
		provideTextEmbedderForMemory,
		provideQdrantStore,
		memory.NewBM25Indexer,
		provideMemoryService,
	),
	fx.Invoke(startMemoryWarmup),
)

// ---------------------------------------------------------------------------
// memory providers
// ---------------------------------------------------------------------------
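
// provideMemoryLLM returns a memory.LLM backed by lazyLLMClient, which defers
// model resolution until the first call so that startup does not require a
// configured memory model. The resolved client is given a 30-second timeout.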
func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memory.LLM {
	return &lazyLLMClient{
		modelsService: modelsService,
		queries:       queries,
		timeout:       30 * time.Second,
		logger:        log,
	}
}

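// provideEmbeddingsResolver constructs the resolver used to look up embedding
// providers, with a 10-second timeout for resolution.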
func provideEmbeddingsResolver(log *slog.Logger, modelsService *models.Service, queries *dbsqlc.Queries) *embeddings.Resolver {
	return embeddings.NewResolver(log, modelsService, queries, 10*time.Second)
}

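// embeddingSetup captures the embedding configuration discovered at startup:
// per-model vector dimensions, the selected text and multimodal models, and
// whether any embedding models are configured at all.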
type embeddingSetup struct {
	Vectors            map[string]int
	TextModel          models.GetResponse
	MultimodalModel    models.GetResponse
	HasEmbeddingModels bool
}

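// provideEmbeddingSetup collects the configured embedding models and their
// vector sizes within a 10-second budget. A missing multimodal model is
// logged as a warning rather than treated as fatal.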
func provideEmbeddingSetup(log *slog.Logger, modelsService *models.Service) (embeddingSetup, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
	if err != nil {
		return embeddingSetup{}, fmt.Errorf("embedding models: %w", err)
	}
	if hasEmbeddingModels && multimodalModel.ModelID == "" {
		log.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
	}
	return embeddingSetup{
		Vectors:            vectors,
		TextModel:          textModel,
		MultimodalModel:    multimodalModel,
		HasEmbeddingModels: hasEmbeddingModels,
	}, nil
}

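// provideTextEmbedderForMemory adapts the startup embedding setup into the
// text embedder consumed by the memory service. The result may be nil; see
// buildTextEmbedder.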
func provideTextEmbedderForMemory(resolver *embeddings.Resolver, setup embeddingSetup, log *slog.Logger) embeddings.Embedder {
	return buildTextEmbedder(resolver, setup.TextModel, setup.HasEmbeddingModels, log)
}

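// provideQdrantStore builds the Qdrant-backed vector store. When embedding
// models are configured and per-model vector sizes are known, it creates the
// collection with named dense vectors; otherwise it falls back to a single
// dense vector sized to the text model. "sparse_hash" names the sparse
// vector stored alongside the dense ones, presumably for BM25-style hybrid
// retrieval.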
func provideQdrantStore(log *slog.Logger, cfg config.Config, setup embeddingSetup) (*memory.QdrantStore, error) {
	qcfg := cfg.Qdrant
	timeout := time.Duration(qcfg.TimeoutSeconds) * time.Second
	if setup.HasEmbeddingModels && len(setup.Vectors) > 0 {
		store, err := memory.NewQdrantStoreWithVectors(log, qcfg.BaseURL, qcfg.APIKey, qcfg.Collection, setup.Vectors, "sparse_hash", timeout)
		if err != nil {
			return nil, fmt.Errorf("qdrant named vectors init: %w", err)
		}
		return store, nil
	}
	store, err := memory.NewQdrantStore(log, qcfg.BaseURL, qcfg.APIKey, qcfg.Collection, setup.TextModel.Dimensions, "sparse_hash", timeout)
	if err != nil {
		return nil, fmt.Errorf("qdrant init: %w", err)
	}
	return store, nil
}

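// provideMemoryService assembles the memory service from the LLM, embedder,
// vector store, resolver, and BM25 indexer, recording which text and
// multimodal embedding models it should use.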
func provideMemoryService(log *slog.Logger, llm memory.LLM, embedder embeddings.Embedder, store *memory.QdrantStore, resolver *embeddings.Resolver, bm25 *memory.BM25Indexer, setup embeddingSetup) *memory.Service {
	return memory.NewService(log, llm, embedder, store, resolver, bm25, setup.TextModel.ModelID, setup.MultimodalModel.ModelID)
}

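// buildTextEmbedder returns a resolver-backed text embedder, or nil when no
// embedding models exist or the text model lacks a model ID or a positive
// dimension count. Downstream consumers therefore need to tolerate a nil
// embedder.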
func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder {
	if !hasModels {
		return nil
	}
	if textModel.ModelID == "" || textModel.Dimensions <= 0 {
		log.Warn("No text embedding model configured. Text embedding features will be limited.")
		return nil
	}
	return &embeddings.ResolverTextEmbedder{
		Resolver: resolver,
		ModelID:  textModel.ModelID,
		Dims:     textModel.Dimensions,
	}
}

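// startMemoryWarmup registers an fx lifecycle hook that warms the BM25 index
// in a background goroutine. It uses context.Background() rather than the
// OnStart context so the warmup outlives startup, and a failure is only
// logged, never allowed to block application start.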
func startMemoryWarmup(lc fx.Lifecycle, memoryService *memory.Service, logger *slog.Logger) {
	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			go func() {
				if err := memoryService.WarmupBM25(context.Background(), 200); err != nil {
					logger.Warn("bm25 warmup failed", slog.Any("error", err))
				}
			}()
			return nil
		},
	})
}

// ---------------------------------------------------------------------------
// lazy LLM client
// ---------------------------------------------------------------------------
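
// lazyLLMClient implements memory.LLM by resolving the configured memory
// model on every call rather than at construction time. This keeps wiring
// independent of database state and picks up model configuration changes
// without a restart, at the cost of one resolution lookup per request.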
type lazyLLMClient struct {
	modelsService *models.Service
	queries       *dbsqlc.Queries
	timeout       time.Duration
	logger        *slog.Logger
}

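// Extract, Decide, Compact, and DetectLanguage all follow the same shape:
// resolve a concrete client for the current configuration, then delegate.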
func (c *lazyLLMClient) Extract(ctx context.Context, req memory.ExtractRequest) (memory.ExtractResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.ExtractResponse{}, err
	}
	return client.Extract(ctx, req)
}

func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (memory.DecideResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.DecideResponse{}, err
	}
	return client.Decide(ctx, req)
}

func (c *lazyLLMClient) Compact(ctx context.Context, req memory.CompactRequest) (memory.CompactResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.CompactResponse{}, err
	}
	return client.Compact(ctx, req)
}

func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return "", err
	}
	return client.DetectLanguage(ctx, text)
}

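// resolve selects the configured memory model and provider, verifies that the
// provider speaks an OpenAI-compatible chat API, and constructs a concrete
// LLM client for it.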
func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
	if c.modelsService == nil || c.queries == nil {
		return nil, errors.New("models service not configured")
	}
	memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
	if err != nil {
		return nil, err
	}
	clientType := strings.ToLower(strings.TrimSpace(memoryProvider.ClientType))
	switch clientType {
	case "openai", "openai-compat", "azure", "mistral", "xai", "ollama", "dashscope":
		// These providers expose an OpenAI-compatible /chat/completions endpoint.
	default:
		return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
	}
	return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout)
}
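
// A minimal sketch of mounting this module in an fx application, assuming the
// surrounding app already provides what the constructors above consume
// (*slog.Logger, config.Config, *models.Service, *dbsqlc.Queries); the
// consumer function below is purely illustrative:
//
//	app := fx.New(
//		// ... providers for logger, config, models service, queries ...
//		MemoryModule,
//		fx.Invoke(func(svc *memory.Service) {
//			// use the memory service
//		}),
//	)
//	app.Run()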