diff --git a/internal/memory/indexer.go b/internal/memory/indexer.go index 86180a3f..534f4ef1 100644 --- a/internal/memory/indexer.go +++ b/internal/memory/indexer.go @@ -247,6 +247,6 @@ func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) { func termHash(term string) uint32 { hasher := fnv.New32a() - _, _ = hasher.Write([]byte(term)) + hasher.Write([]byte(term)) //nolint:errcheck // hash.Write never returns error return hasher.Sum32() & sparseDimMask } diff --git a/internal/memory/llm_client.go b/internal/memory/llm_client.go index 3b91c65a..2c2e9526 100644 --- a/internal/memory/llm_client.go +++ b/internal/memory/llm_client.go @@ -51,7 +51,7 @@ func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractRes if len(req.Messages) == 0 { return ExtractResponse{}, fmt.Errorf("messages is required") } - parsedMessages := parseMessages(formatMessages(req.Messages)) + parsedMessages := strings.Join(formatMessages(req.Messages), "\n") systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages) content, err := c.callChat(ctx, []chatMessage{ {Role: "system", Content: systemPrompt}, @@ -122,7 +122,7 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon actions = append(actions, DecisionAction{ Event: event, - ID: normalizeID(item["id"]), + ID: asString(item["id"]), Text: text, OldMemory: asString(item["old_memory"]), }) @@ -242,14 +242,6 @@ func asString(value any) string { } } -func normalizeID(value any) string { - id := asString(value) - if id == "" { - return "" - } - return id -} - func normalizeMemoryItems(value any) []map[string]any { switch typed := value.(type) { case []any: diff --git a/internal/memory/prompts.go b/internal/memory/prompts.go index d3240ddd..3ad461c5 100644 --- a/internal/memory/prompts.go +++ b/internal/memory/prompts.go @@ -120,10 +120,6 @@ Before finalizing, verify the value is one of the allowed codes.` return systemPrompt, userPrompt } -func parseMessages(messages []string) string { - return strings.Join(messages, "\n") -} - func removeCodeBlocks(text string) string { return strings.ReplaceAll(strings.ReplaceAll(text, "```json", ""), "```", "") } diff --git a/internal/memory/qdrant_store.go b/internal/memory/qdrant_store.go index 2aada92c..1a8f3241 100644 --- a/internal/memory/qdrant_store.go +++ b/internal/memory/qdrant_store.go @@ -55,7 +55,7 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens collection = "memory" } if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" { - dimension = 1536 + return nil, fmt.Errorf("embedding dimension is required") } cfg := &qdrant.Config{ @@ -455,22 +455,19 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s s.vectorNames[name] = int(vec.GetSize()) } } - if len(vectors) == 0 { - goto sparseCheck - } - for name, dim := range vectors { - if existing, ok := s.vectorNames[name]; ok && existing == dim { - continue + if len(vectors) > 0 { + for name, dim := range vectors { + if existing, ok := s.vectorNames[name]; ok && existing == dim { + continue + } + return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim) } - return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim) } - } - if vectorsConfig == nil || vectorsConfig.GetParamsMap() == nil { + } else { s.usesNamedVectors = false s.vectorNames = nil } -sparseCheck: sparseConfig := params.GetSparseVectorsConfig() if s.sparseVectorName != "" { needsUpdate := false @@ -506,7 +503,7 @@ func (s *QdrantStore) ensurePayloadIndexes(ctx context.Context) error { if s.client == nil { return nil } - fields := []string{"botId", "runId"} + fields := []string{"bot_id", "run_id"} wait := true for _, field := range fields { _, err := s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{ diff --git a/internal/memory/service.go b/internal/memory/service.go index 3ff076e2..4e46dfaf 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "fmt" "log/slog" - "math" "sort" "strings" "time" @@ -287,7 +286,7 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe } vectorName := "" - if s.store != nil && s.store.usesNamedVectors { + if s.store.usesNamedVectors { vectorName = result.Model } @@ -336,19 +335,22 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er if existing == nil { return MemoryItem{}, fmt.Errorf("memory not found") } - if s.bm25 == nil { - return MemoryItem{}, fmt.Errorf("bm25 indexer not configured") - } payload := existing.Payload oldText := fmt.Sprint(payload["data"]) oldLang := fmt.Sprint(payload["lang"]) if oldLang == "" && strings.TrimSpace(oldText) != "" { - oldLang, _ = s.detectLanguage(ctx, oldText) + var detectErr error + oldLang, detectErr = s.detectLanguage(ctx, oldText) + if detectErr != nil { + s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr)) + } } if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" { oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText) - if err == nil { + if err != nil { + s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err)) + } else { s.bm25.RemoveDocument(oldLang, oldFreq, oldLen) } } @@ -365,7 +367,7 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er payload["data"] = req.Memory payload["hash"] = hashMemory(req.Memory) - payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339) + payload["updated_at"] = time.Now().UTC().Format(time.RFC3339) payload["lang"] = newLang embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled @@ -413,10 +415,13 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse filters[k] = v } if req.BotID != "" { - filters["botId"] = req.BotID + filters["bot_id"] = req.BotID + } + if req.AgentID != "" { + filters["agent_id"] = req.AgentID } if req.RunID != "" { - filters["runId"] = req.RunID + filters["run_id"] = req.RunID } if len(filters) == 0 { return SearchResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required") @@ -449,10 +454,13 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe filters[k] = v } if req.BotID != "" { - filters["botId"] = req.BotID + filters["bot_id"] = req.BotID + } + if req.AgentID != "" { + filters["agent_id"] = req.AgentID } if req.RunID != "" { - filters["runId"] = req.RunID + filters["run_id"] = req.RunID } if len(filters) == 0 { return DeleteResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required") @@ -487,6 +495,7 @@ func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error { } termFreq, docLen, err := s.bm25.TermFrequencies(lang, text) if err != nil { + s.logger.Warn("bm25 warmup: term frequencies failed", slog.String("id", point.ID), slog.Any("error", err)) continue } s.bm25.AddDocument(lang, termFreq, docLen) @@ -609,11 +618,17 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[ oldText := fmt.Sprint(payload["data"]) oldLang := fmt.Sprint(payload["lang"]) if oldLang == "" && strings.TrimSpace(oldText) != "" { - oldLang, _ = s.detectLanguage(ctx, oldText) + var detectErr error + oldLang, detectErr = s.detectLanguage(ctx, oldText) + if detectErr != nil { + s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr)) + } } if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" { oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText) - if err == nil { + if err != nil { + s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err)) + } else { s.bm25.RemoveDocument(oldLang, oldFreq, oldLen) } } @@ -628,7 +643,7 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[ sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen) payload["data"] = text payload["hash"] = hashMemory(text) - payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339) + payload["updated_at"] = time.Now().UTC().Format(time.RFC3339) payload["lang"] = newLang if metadata != nil { payload["metadata"] = mergeMetadata(payload["metadata"], metadata) @@ -676,11 +691,17 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error oldText := fmt.Sprint(existing.Payload["data"]) oldLang := fmt.Sprint(existing.Payload["lang"]) if oldLang == "" && strings.TrimSpace(oldText) != "" { - oldLang, _ = s.detectLanguage(ctx, oldText) + var detectErr error + oldLang, detectErr = s.detectLanguage(ctx, oldText) + if detectErr != nil { + s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr)) + } } if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" { oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText) - if err == nil { + if err != nil { + s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err)) + } else { s.bm25.RemoveDocument(oldLang, oldFreq, oldLen) } } @@ -754,10 +775,13 @@ func buildFilters(req AddRequest) map[string]any { filters[key] = value } if req.BotID != "" { - filters["botId"] = req.BotID + filters["bot_id"] = req.BotID + } + if req.AgentID != "" { + filters["agent_id"] = req.AgentID } if req.RunID != "" { - filters["runId"] = req.RunID + filters["run_id"] = req.RunID } return filters } @@ -768,10 +792,13 @@ func buildSearchFilters(req SearchRequest) map[string]any { filters[key] = value } if req.BotID != "" { - filters["botId"] = req.BotID + filters["bot_id"] = req.BotID + } + if req.AgentID != "" { + filters["agent_id"] = req.AgentID } if req.RunID != "" { - filters["runId"] = req.RunID + filters["run_id"] = req.RunID } return filters } @@ -782,10 +809,13 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]any { filters[key] = value } if req.BotID != "" { - filters["botId"] = req.BotID + filters["bot_id"] = req.BotID + } + if req.AgentID != "" { + filters["agent_id"] = req.AgentID } if req.RunID != "" { - filters["runId"] = req.RunID + filters["run_id"] = req.RunID } return filters } @@ -840,9 +870,9 @@ func buildPayload(text string, filters map[string]any, metadata map[string]any, createdAt = time.Now().UTC().Format(time.RFC3339) } payload := map[string]any{ - "data": text, - "hash": hashMemory(text), - "createdAt": createdAt, + "data": text, + "hash": hashMemory(text), + "created_at": createdAt, } if metadata != nil { payload["metadata"] = metadata @@ -865,16 +895,19 @@ func payloadToMemoryItem(id string, payload map[string]any) MemoryItem { if v, ok := payload["hash"].(string); ok { item.Hash = v } - if v, ok := payload["createdAt"].(string); ok { + if v, ok := payload["created_at"].(string); ok { item.CreatedAt = v } - if v, ok := payload["updatedAt"].(string); ok { + if v, ok := payload["updated_at"].(string); ok { item.UpdatedAt = v } - if v, ok := payload["botId"].(string); ok { + if v, ok := payload["bot_id"].(string); ok { item.BotID = v } - if v, ok := payload["runId"].(string); ok { + if v, ok := payload["agent_id"].(string); ok { + item.AgentID = v + } + if v, ok := payload["run_id"].(string); ok { item.RunID = v } if meta, ok := payload["metadata"].(map[string]any); ok { @@ -924,76 +957,33 @@ func mergeMetadata(base any, extra map[string]any) map[string]any { type rerankCandidate struct { ID string Payload map[string]any - Score float64 - Source string - Rank int } const ( - fusionModeRRF = "rrf" - fusionModeCombMNZ = "combmnz" - fusionMode = fusionModeRRF - rrfK = 60.0 + rrfK = 60.0 ) -func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, scoresBySource map[string][]float64) []MemoryItem { +func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, _ map[string][]float64) []MemoryItem { candidates := map[string]*rerankCandidate{} rrfScores := map[string]float64{} - combScores := map[string]float64{} - combCounts := map[string]int{} - for source, points := range pointsBySource { - scores := scoresBySource[source] - minScore := math.MaxFloat64 - maxScore := -math.MaxFloat64 + for _, points := range pointsBySource { for idx, point := range points { - if idx >= len(scores) { - continue - } - score := scores[idx] - if score < minScore { - minScore = score - } - if score > maxScore { - maxScore = score - } if _, ok := candidates[point.ID]; !ok { candidates[point.ID] = &rerankCandidate{ ID: point.ID, Payload: point.Payload, } } - } - if minScore == math.MaxFloat64 { - minScore = 0 - } - if maxScore == -math.MaxFloat64 { - maxScore = minScore - } - - for idx, point := range points { - if idx >= len(scores) { - continue - } - score := scores[idx] rank := float64(idx + 1) rrfScores[point.ID] += 1.0 / (rrfK + rank) - - scoreNorm := normalizeScore(score, minScore, maxScore) - combScores[point.ID] += scoreNorm - combCounts[point.ID]++ } } items := make([]MemoryItem, 0, len(candidates)) for id, candidate := range candidates { item := payloadToMemoryItem(candidate.ID, candidate.Payload) - switch fusionMode { - case fusionModeCombMNZ: - item.Score = combScores[id] * float64(combCounts[id]) - default: - item.Score = rrfScores[id] - } + item.Score = rrfScores[id] items = append(items, item) } @@ -1002,10 +992,3 @@ func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, scoresBySource ma }) return items } - -func normalizeScore(score, minScore, maxScore float64) float64 { - if maxScore <= minScore { - return 1 - } - return (score - minScore) / (maxScore - minScore) -} diff --git a/internal/memory/types.go b/internal/memory/types.go index 606f102a..dc262726 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -88,13 +88,13 @@ type MemoryItem struct { ID string `json:"id"` Memory string `json:"memory"` Hash string `json:"hash,omitempty"` - CreatedAt string `json:"createdAt,omitempty"` - UpdatedAt string `json:"updatedAt,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` Score float64 `json:"score,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` - BotID string `json:"botId,omitempty"` - AgentID string `json:"agentId,omitempty"` - RunID string `json:"runId,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` } type SearchResponse struct {