mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor: use sparse vector for memory
This commit is contained in:
+322
-78
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
|
||||
"github.com/memohai/memoh/internal/embeddings"
|
||||
)
|
||||
@@ -21,17 +22,19 @@ type Service struct {
|
||||
embedder embeddings.Embedder
|
||||
store *QdrantStore
|
||||
resolver *embeddings.Resolver
|
||||
bm25 *BM25Indexer
|
||||
logger *slog.Logger
|
||||
defaultTextModelID string
|
||||
defaultMultimodalModelID string
|
||||
}
|
||||
|
||||
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, bm25 *BM25Indexer, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
return &Service{
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
store: store,
|
||||
resolver: resolver,
|
||||
bm25: bm25,
|
||||
logger: log.With(slog.String("service", "memory")),
|
||||
defaultTextModelID: defaultTextModelID,
|
||||
defaultMultimodalModelID: defaultMultimodalModelID,
|
||||
@@ -42,15 +45,16 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
if req.Message == "" && len(req.Messages) == 0 {
|
||||
return SearchResponse{}, fmt.Errorf("message or messages is required")
|
||||
}
|
||||
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
|
||||
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
if req.UserID == "" {
|
||||
return SearchResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
|
||||
messages := normalizeMessages(req)
|
||||
filters := buildFilters(req)
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
if req.Infer != nil && !*req.Infer {
|
||||
return s.addRawMessages(ctx, messages, filters, req.Metadata)
|
||||
return s.addRawMessages(ctx, messages, filters, req.Metadata, embeddingEnabled)
|
||||
}
|
||||
|
||||
extractResp, err := s.llm.Extract(ctx, ExtractRequest{
|
||||
@@ -95,7 +99,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
for _, action := range actions {
|
||||
switch strings.ToUpper(action.Event) {
|
||||
case "ADD":
|
||||
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata)
|
||||
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -104,7 +108,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
})
|
||||
results = append(results, item)
|
||||
case "UPDATE":
|
||||
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata)
|
||||
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -134,19 +138,19 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
return SearchResponse{}, fmt.Errorf("query is required")
|
||||
}
|
||||
if s.store == nil {
|
||||
return SearchResponse{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
filters := buildSearchFilters(req)
|
||||
modality := ""
|
||||
if raw, ok := filters["modality"].(string); ok {
|
||||
modality = strings.ToLower(strings.TrimSpace(raw))
|
||||
}
|
||||
|
||||
var (
|
||||
vector []float32
|
||||
store *QdrantStore
|
||||
vectorName string
|
||||
err error
|
||||
)
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
if modality == embeddings.TypeMultimodal {
|
||||
if !embeddingEnabled {
|
||||
return SearchResponse{}, fmt.Errorf("embedding is disabled")
|
||||
}
|
||||
if s.resolver == nil {
|
||||
return SearchResponse{}, fmt.Errorf("embeddings resolver not configured")
|
||||
}
|
||||
@@ -159,24 +163,79 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
vector = result.Embedding
|
||||
store = s.store
|
||||
vectorName = s.vectorNameForMultimodal()
|
||||
} else {
|
||||
vector, err = s.embedder.Embed(ctx, req.Query)
|
||||
vectorName := s.vectorNameForMultimodal()
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.Search(ctx, result.Embedding, req.Limit, filters, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
if idx < len(scores) {
|
||||
item.Score = scores[idx]
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, result.Embedding, req.Limit, filters, req.Sources, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
store = s.store
|
||||
vectorName = s.vectorNameForText()
|
||||
results := fuseByRankFusion(pointsBySource, scoresBySource)
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := store.Search(ctx, vector, req.Limit, filters, vectorName)
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return SearchResponse{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
vectorName := s.vectorNameForText()
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.Search(ctx, vector, req.Limit, filters, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
if idx < len(scores) {
|
||||
item.Score = scores[idx]
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := fuseByRankFusion(pointsBySource, scoresBySource)
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
if s.bm25 == nil {
|
||||
return SearchResponse{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
termFreq, _, err := s.bm25.TermFrequencies(lang, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.SearchSparse(ctx, indices, values, req.Limit, filters)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
@@ -187,8 +246,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
pointsBySource, scoresBySource, err := store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
|
||||
pointsBySource, scoresBySource, err := s.store.SearchSparseBySources(ctx, indices, values, req.Limit, filters, req.Sources)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -200,8 +258,8 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
|
||||
if s.resolver == nil {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("embeddings resolver not configured")
|
||||
}
|
||||
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
if req.UserID == "" {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
req.Type = strings.TrimSpace(req.Type)
|
||||
req.Provider = strings.TrimSpace(req.Provider)
|
||||
@@ -264,6 +322,12 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if strings.TrimSpace(req.Memory) == "" {
|
||||
return MemoryItem{}, fmt.Errorf("memory is required")
|
||||
}
|
||||
if s.store == nil {
|
||||
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
|
||||
existing, err := s.store.Get(ctx, req.MemoryID)
|
||||
if err != nil {
|
||||
@@ -272,22 +336,58 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if existing == nil {
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
|
||||
payload := existing.Payload
|
||||
payload["data"] = req.Memory
|
||||
payload["hash"] = hashMemory(req.Memory)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
|
||||
vector, err := s.embedder.Embed(ctx, req.Memory)
|
||||
newLang, err := s.detectLanguage(ctx, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: req.MemoryID,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
|
||||
|
||||
payload["data"] = req.Memory
|
||||
payload["hash"] = hashMemory(req.Memory)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
point := qdrantPoint{
|
||||
ID: req.MemoryID,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(req.MemoryID, payload), nil
|
||||
@@ -312,14 +412,11 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
return SearchResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
|
||||
points, err := s.store.List(ctx, req.Limit, filters)
|
||||
@@ -348,14 +445,11 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return DeleteResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
return DeleteResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
if err := s.store.DeleteAll(ctx, filters); err != nil {
|
||||
return DeleteResponse{}, err
|
||||
@@ -363,10 +457,46 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
|
||||
return DeleteResponse{Message: "Memories deleted successfully!"}, nil
|
||||
}
|
||||
|
||||
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}) (SearchResponse, error) {
|
||||
func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error {
|
||||
if s.bm25 == nil || s.store == nil {
|
||||
return nil
|
||||
}
|
||||
var offset *qdrant.PointId
|
||||
for {
|
||||
points, next, err := s.store.Scroll(ctx, batchSize, nil, offset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(points) == 0 {
|
||||
break
|
||||
}
|
||||
for _, point := range points {
|
||||
text := fmt.Sprint(point.Payload["data"])
|
||||
if strings.TrimSpace(text) == "" {
|
||||
continue
|
||||
}
|
||||
lang := fmt.Sprint(point.Payload["lang"])
|
||||
if lang == "" {
|
||||
lang = fallbackLanguageCode(text)
|
||||
}
|
||||
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s.bm25.AddDocument(lang, termFreq, docLen)
|
||||
}
|
||||
if next == nil {
|
||||
break
|
||||
}
|
||||
offset = next
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (SearchResponse, error) {
|
||||
results := make([]MemoryItem, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
item, err := s.applyAdd(ctx, message.Content, filters, metadata)
|
||||
item, err := s.applyAdd(ctx, message.Content, filters, metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -381,11 +511,19 @@ func (s *Service) addRawMessages(ctx context.Context, messages []Message, filter
|
||||
func (s *Service) collectCandidates(ctx context.Context, facts []string, filters map[string]interface{}) ([]CandidateMemory, error) {
|
||||
unique := map[string]CandidateMemory{}
|
||||
for _, fact := range facts {
|
||||
vector, err := s.embedder.Embed(ctx, fact)
|
||||
if s.bm25 == nil {
|
||||
return nil, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, fact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
points, _, err := s.store.Search(ctx, vector, 5, filters, s.vectorNameForText())
|
||||
termFreq, _, err := s.bm25.TermFrequencies(lang, fact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
|
||||
points, _, err := s.store.SearchSparse(ctx, indices, values, 5, filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -406,25 +544,50 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
|
||||
if s.store == nil {
|
||||
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(lang, termFreq, docLen)
|
||||
id := uuid.NewString()
|
||||
payload := buildPayload(text, filters, metadata, "")
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
payload["lang"] = lang
|
||||
point := qdrantPoint{
|
||||
ID: id,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(id, payload), nil
|
||||
}
|
||||
|
||||
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
|
||||
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return MemoryItem{}, fmt.Errorf("update action missing id")
|
||||
}
|
||||
@@ -437,25 +600,55 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
|
||||
}
|
||||
|
||||
payload := existing.Payload
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
newLang, err := s.detectLanguage(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
|
||||
payload["data"] = text
|
||||
payload["hash"] = hashMemory(text)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
if metadata != nil {
|
||||
payload["metadata"] = mergeMetadata(payload["metadata"], metadata)
|
||||
}
|
||||
if filters != nil {
|
||||
applyFiltersToPayload(payload, filters)
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
point := qdrantPoint{
|
||||
ID: id,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(id, payload), nil
|
||||
@@ -473,6 +666,19 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
item := payloadToMemoryItem(id, existing.Payload)
|
||||
if s.bm25 != nil {
|
||||
oldText := fmt.Sprint(existing.Payload["data"])
|
||||
oldLang := fmt.Sprint(existing.Payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.store.Delete(ctx, id); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
@@ -486,6 +692,56 @@ func normalizeMessages(req AddRequest) []Message {
|
||||
return []Message{{Role: "user", Content: req.Message}}
|
||||
}
|
||||
|
||||
func (s *Service) detectLanguage(ctx context.Context, text string) (string, error) {
|
||||
if s.llm == nil {
|
||||
return "", fmt.Errorf("language detector not configured")
|
||||
}
|
||||
lang, err := s.llm.DetectLanguage(ctx, text)
|
||||
if err == nil && lang != "" {
|
||||
return lang, nil
|
||||
}
|
||||
fallback := fallbackLanguageCode(text)
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("language detection failed; using fallback", slog.Any("error", err), slog.String("fallback", fallback))
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func fallbackLanguageCode(text string) string {
|
||||
for _, r := range text {
|
||||
if isCJKRune(r) {
|
||||
return "cjk"
|
||||
}
|
||||
}
|
||||
return "en"
|
||||
}
|
||||
|
||||
func isCJKRune(r rune) bool {
|
||||
switch {
|
||||
case r >= 0x4E00 && r <= 0x9FFF: // CJK Unified Ideographs
|
||||
return true
|
||||
case r >= 0x3400 && r <= 0x4DBF: // CJK Unified Ideographs Extension A
|
||||
return true
|
||||
case r >= 0x20000 && r <= 0x2A6DF: // CJK Unified Ideographs Extension B
|
||||
return true
|
||||
case r >= 0x2A700 && r <= 0x2B73F: // CJK Unified Ideographs Extension C
|
||||
return true
|
||||
case r >= 0x2B740 && r <= 0x2B81F: // CJK Unified Ideographs Extension D
|
||||
return true
|
||||
case r >= 0x2B820 && r <= 0x2CEAF: // CJK Unified Ideographs Extension E
|
||||
return true
|
||||
case r >= 0x2CEB0 && r <= 0x2EBEF: // CJK Unified Ideographs Extension F
|
||||
return true
|
||||
case r >= 0x3000 && r <= 0x303F: // CJK Symbols and Punctuation
|
||||
return true
|
||||
case r >= 0x3040 && r <= 0x30FF: // Hiragana/Katakana
|
||||
return true
|
||||
case r >= 0xAC00 && r <= 0xD7AF: // Hangul Syllables
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildFilters(req AddRequest) map[string]interface{} {
|
||||
filters := map[string]interface{}{}
|
||||
for key, value := range req.Filters {
|
||||
@@ -494,9 +750,6 @@ func buildFilters(req AddRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -511,9 +764,6 @@ func buildSearchFilters(req SearchRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -528,9 +778,6 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -621,9 +868,6 @@ func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem {
|
||||
if v, ok := payload["userId"].(string); ok {
|
||||
item.UserID = v
|
||||
}
|
||||
if v, ok := payload["agentId"].(string); ok {
|
||||
item.AgentID = v
|
||||
}
|
||||
if v, ok := payload["runId"].(string); ok {
|
||||
item.RunID = v
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user