package modules

import (
	"context"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/memohai/memoh/internal/config"
	dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
	"github.com/memohai/memoh/internal/embeddings"
	"github.com/memohai/memoh/internal/memory"
	"github.com/memohai/memoh/internal/models"
	"go.uber.org/fx"
)
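
// MemoryModule wires the memory subsystem into the fx graph: the lazily
// resolved memory LLM, embedding setup and resolution, the Qdrant vector
// store, the BM25 indexer, and the memory service, plus a startup hook that
// warms the BM25 index.
//
// A minimal sketch of how an application might mount it (hypothetical
// wiring, not taken from this repository):
//
//	app := fx.New(
//		modules.MemoryModule,
//		// ... logger, config, db, and models providers ...
//	)
//	app.Run()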
var MemoryModule = fx.Module(
	"memory",
	fx.Provide(
		provideMemoryLLM,
		provideEmbeddingSetup,
		provideEmbeddingsResolver,
		provideTextEmbedderForMemory,
		provideQdrantStore,
		memory.NewBM25Indexer,
		provideMemoryService,
	),
	fx.Invoke(startMemoryWarmup),
)

// ---------------------------------------------------------------------------
// memory providers
// ---------------------------------------------------------------------------
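
// provideMemoryLLM returns a memory.LLM backed by lazyLLMClient, which
// defers model and provider resolution to call time, so construction
// succeeds even before a memory model is configured.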
func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memory.LLM {
	return &lazyLLMClient{
		modelsService: modelsService,
		queries:       queries,
		timeout:       30 * time.Second,
		logger:        log,
	}
}
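
// provideEmbeddingsResolver constructs the shared embeddings resolver with a
// 10-second per-request timeout.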
func provideEmbeddingsResolver(log *slog.Logger, modelsService *models.Service, queries *dbsqlc.Queries) *embeddings.Resolver {
	return embeddings.NewResolver(log, modelsService, queries, 10*time.Second)
}
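
// embeddingSetup captures the embedding configuration discovered at startup:
// named-vector dimensions, the selected text and multimodal models, and
// whether any embedding models are configured at all.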
type embeddingSetup struct {
	Vectors            map[string]int
	TextModel          models.GetResponse
	MultimodalModel    models.GetResponse
	HasEmbeddingModels bool
}
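
// provideEmbeddingSetup collects the configured embedding models and their
// vector dimensions, bounding the lookup to 10 seconds. A missing multimodal
// model is a warning, not an error.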
func provideEmbeddingSetup(log *slog.Logger, modelsService *models.Service) (embeddingSetup, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
	if err != nil {
		return embeddingSetup{}, fmt.Errorf("embedding models: %w", err)
	}
	if hasEmbeddingModels && multimodalModel.ModelID == "" {
		log.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
	}
	return embeddingSetup{
		Vectors:            vectors,
		TextModel:          textModel,
		MultimodalModel:    multimodalModel,
		HasEmbeddingModels: hasEmbeddingModels,
	}, nil
}
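
// provideTextEmbedderForMemory adapts the selected text embedding model into
// the embeddings.Embedder consumed by the memory service. It may be nil; see
// buildTextEmbedder.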
func provideTextEmbedderForMemory(resolver *embeddings.Resolver, setup embeddingSetup, log *slog.Logger) embeddings.Embedder {
	return buildTextEmbedder(resolver, setup.TextModel, setup.HasEmbeddingModels, log)
}
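
// provideQdrantStore builds the Qdrant-backed vector store. With named
// vectors available it initializes the store with the full vector map;
// otherwise it falls back to a single dense vector sized to the text model.
// Both paths share the "sparse_hash" sparse-vector name and the configured
// timeout.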
func provideQdrantStore(log *slog.Logger, cfg config.Config, setup embeddingSetup) (*memory.QdrantStore, error) {
	qcfg := cfg.Qdrant
	timeout := time.Duration(qcfg.TimeoutSeconds) * time.Second
	if setup.HasEmbeddingModels && len(setup.Vectors) > 0 {
		store, err := memory.NewQdrantStoreWithVectors(log, qcfg.BaseURL, qcfg.APIKey, qcfg.Collection, setup.Vectors, "sparse_hash", timeout)
		if err != nil {
			return nil, fmt.Errorf("qdrant named vectors init: %w", err)
		}
		return store, nil
	}
	store, err := memory.NewQdrantStore(log, qcfg.BaseURL, qcfg.APIKey, qcfg.Collection, setup.TextModel.Dimensions, "sparse_hash", timeout)
	if err != nil {
		return nil, fmt.Errorf("qdrant init: %w", err)
	}
	return store, nil
}
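
// provideMemoryService assembles the memory service, pinning the text and
// multimodal embedding model IDs chosen during embedding setup.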
func provideMemoryService(log *slog.Logger, llm memory.LLM, embedder embeddings.Embedder, store *memory.QdrantStore, resolver *embeddings.Resolver, bm25 *memory.BM25Indexer, setup embeddingSetup) *memory.Service {
	return memory.NewService(log, llm, embedder, store, resolver, bm25, setup.TextModel.ModelID, setup.MultimodalModel.ModelID)
}
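
// buildTextEmbedder returns a resolver-backed text embedder, or nil when no
// embedding models exist or the text model lacks a model ID or a positive
// dimension count; callers are expected to handle a nil Embedder.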
func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder {
	if !hasModels {
		return nil
	}
	if textModel.ModelID == "" || textModel.Dimensions <= 0 {
		log.Warn("No text embedding model configured. Text embedding features will be limited.")
		return nil
	}
	return &embeddings.ResolverTextEmbedder{
		Resolver: resolver,
		ModelID:  textModel.ModelID,
		Dims:     textModel.Dimensions,
	}
}
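
// startMemoryWarmup registers an OnStart hook that warms the BM25 index in a
// background goroutine so startup is not blocked. The goroutine uses
// context.Background() rather than the short-lived OnStart context, and a
// warmup failure is logged as a warning instead of failing startup.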
func startMemoryWarmup(lc fx.Lifecycle, memoryService *memory.Service, logger *slog.Logger) {
	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			go func() {
				if err := memoryService.WarmupBM25(context.Background(), 200); err != nil {
					logger.Warn("bm25 warmup failed", slog.Any("error", err))
				}
			}()
			return nil
		},
	})
}

// ---------------------------------------------------------------------------
// lazy LLM client
// ---------------------------------------------------------------------------
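
// lazyLLMClient implements memory.LLM by resolving the configured memory
// model and provider on every call, so a missing or changed configuration is
// picked up without rebuilding the dependency graph.
//
// The memory.LLM surface, as reconstructed from the delegating methods below
// (not copied from the memory package), appears to be:
//
//	type LLM interface {
//		Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
//		Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
//		Compact(ctx context.Context, req CompactRequest) (CompactResponse, error)
//		DetectLanguage(ctx context.Context, text string) (string, error)
//	}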
type lazyLLMClient struct {
	modelsService *models.Service
	queries       *dbsqlc.Queries
	timeout       time.Duration
	logger        *slog.Logger
}
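
// Each memory.LLM method resolves a fresh concrete client and then delegates
// to it, so model-selection errors surface on the call that needed the model.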
func (c *lazyLLMClient) Extract(ctx context.Context, req memory.ExtractRequest) (memory.ExtractResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.ExtractResponse{}, err
	}
	return client.Extract(ctx, req)
}

func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (memory.DecideResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.DecideResponse{}, err
	}
	return client.Decide(ctx, req)
}

func (c *lazyLLMClient) Compact(ctx context.Context, req memory.CompactRequest) (memory.CompactResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.CompactResponse{}, err
	}
	return client.Compact(ctx, req)
}

func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return "", err
	}
	return client.DetectLanguage(ctx, text)
}
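
// resolve selects the memory model and provider via models.SelectMemoryModel
// and constructs an OpenAI-compatible chat client for it; providers with any
// other client type are rejected.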
func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
	if c.modelsService == nil || c.queries == nil {
		return nil, fmt.Errorf("models service not configured")
	}
	memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
	if err != nil {
		return nil, err
	}
	clientType := strings.ToLower(strings.TrimSpace(memoryProvider.ClientType))
	switch clientType {
	case "openai", "openai-compat", "azure", "mistral", "xai", "ollama", "dashscope":
		// These providers expose an OpenAI-compatible /chat/completions endpoint.
	default:
		return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
	}
	return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout)
}