mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: embedding router
This commit is contained in:
+97
-17
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -10,12 +11,68 @@ import (
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/embeddings"
|
||||
"github.com/memohai/memoh/internal/handlers"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
"github.com/memohai/memoh/internal/memory"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/server"
|
||||
)
|
||||
|
||||
type resolverTextEmbedder struct {
|
||||
resolver *embeddings.Resolver
|
||||
modelID string
|
||||
dims int
|
||||
}
|
||||
|
||||
func (e *resolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) {
|
||||
result, err := e.resolver.Embed(ctx, embeddings.Request{
|
||||
Type: embeddings.TypeText,
|
||||
Model: e.modelID,
|
||||
Input: embeddings.Input{Text: input},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Embedding, nil
|
||||
}
|
||||
|
||||
func (e *resolverTextEmbedder) Dimensions() int {
|
||||
return e.dims
|
||||
}
|
||||
|
||||
func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, error) {
|
||||
candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return nil, models.GetResponse{}, models.GetResponse{}, err
|
||||
}
|
||||
vectors := map[string]int{}
|
||||
var textModel models.GetResponse
|
||||
var multimodalModel models.GetResponse
|
||||
for _, model := range candidates {
|
||||
if model.Dimensions > 0 && model.ModelID != "" {
|
||||
vectors[model.ModelID] = model.Dimensions
|
||||
}
|
||||
if model.IsMultimodal {
|
||||
if multimodalModel.ModelID == "" {
|
||||
multimodalModel = model
|
||||
}
|
||||
continue
|
||||
}
|
||||
if textModel.ModelID == "" {
|
||||
textModel = model
|
||||
}
|
||||
}
|
||||
if textModel.ModelID == "" {
|
||||
return vectors, textModel, multimodalModel, fmt.Errorf("no text embedding model configured")
|
||||
}
|
||||
if multimodalModel.ModelID == "" {
|
||||
return vectors, textModel, multimodalModel, fmt.Errorf("no multimodal embedding model configured")
|
||||
}
|
||||
return vectors, textModel, multimodalModel, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
cfgPath := os.Getenv("CONFIG_PATH")
|
||||
@@ -53,6 +110,8 @@ func main() {
|
||||
}
|
||||
defer conn.Close()
|
||||
manager.WithDB(conn)
|
||||
queries := dbsqlc.New(conn)
|
||||
modelsService := models.NewService(queries)
|
||||
|
||||
pingHandler := handlers.NewPingHandler()
|
||||
authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn)
|
||||
@@ -62,28 +121,49 @@ func main() {
|
||||
cfg.Memory.Model,
|
||||
time.Duration(cfg.Memory.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
embedder := memory.NewOpenAIEmbedder(
|
||||
cfg.Embeddings.OpenAIAPIKey,
|
||||
cfg.Embeddings.OpenAIBaseURL,
|
||||
cfg.Embeddings.Model,
|
||||
cfg.Embeddings.Dimensions,
|
||||
time.Duration(cfg.Embeddings.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
store, err := memory.NewQdrantStore(
|
||||
cfg.Qdrant.BaseURL,
|
||||
cfg.Qdrant.APIKey,
|
||||
cfg.Qdrant.Collection,
|
||||
cfg.Embeddings.Dimensions,
|
||||
time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
|
||||
vectors, textModel, multimodalModel, err := collectEmbeddingVectors(ctx, modelsService)
|
||||
if err != nil {
|
||||
log.Fatalf("qdrant init: %v", err)
|
||||
log.Fatalf("embedding models: %v", err)
|
||||
}
|
||||
memoryService := memory.NewService(llmClient, embedder, store)
|
||||
if textModel.Dimensions <= 0 {
|
||||
log.Fatalf("text embedding dimensions not configured")
|
||||
}
|
||||
textEmbedder := &resolverTextEmbedder{
|
||||
resolver: resolver,
|
||||
modelID: textModel.ModelID,
|
||||
dims: textModel.Dimensions,
|
||||
}
|
||||
var store *memory.QdrantStore
|
||||
if len(vectors) > 0 {
|
||||
store, err = memory.NewQdrantStoreWithVectors(
|
||||
cfg.Qdrant.BaseURL,
|
||||
cfg.Qdrant.APIKey,
|
||||
cfg.Qdrant.Collection,
|
||||
vectors,
|
||||
time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("qdrant named vectors init: %v", err)
|
||||
}
|
||||
} else {
|
||||
store, err = memory.NewQdrantStore(
|
||||
cfg.Qdrant.BaseURL,
|
||||
cfg.Qdrant.APIKey,
|
||||
cfg.Qdrant.Collection,
|
||||
textModel.Dimensions,
|
||||
time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("qdrant init: %v", err)
|
||||
}
|
||||
}
|
||||
memoryService := memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
|
||||
memoryHandler := handlers.NewMemoryHandler(memoryService)
|
||||
embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries)
|
||||
fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace)
|
||||
swaggerHandler := handlers.NewSwaggerHandler()
|
||||
srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, fsHandler, swaggerHandler)
|
||||
srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, fsHandler, swaggerHandler)
|
||||
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatalf("server failed: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user