refactor(memory): fix naming, dead code, error handling and AgentID support

Normalize JSON tags to snake_case, remove dead fusion branches and unused
struct fields, add proper error logging for BM25 operations, wire AgentID
through filters/store/retrieval, and replace goto with structured control flow.
This commit is contained in:
BBQ
2026-02-12 23:43:00 +08:00
parent c53d35740e
commit 57dd75ff52
6 changed files with 84 additions and 116 deletions
+1 -1
View File
@@ -247,6 +247,6 @@ func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) {
func termHash(term string) uint32 {
hasher := fnv.New32a()
_, _ = hasher.Write([]byte(term))
hasher.Write([]byte(term)) //nolint:errcheck // hash.Write never returns error
return hasher.Sum32() & sparseDimMask
}
+2 -10
View File
@@ -51,7 +51,7 @@ func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractRes
if len(req.Messages) == 0 {
return ExtractResponse{}, fmt.Errorf("messages is required")
}
parsedMessages := parseMessages(formatMessages(req.Messages))
parsedMessages := strings.Join(formatMessages(req.Messages), "\n")
systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages)
content, err := c.callChat(ctx, []chatMessage{
{Role: "system", Content: systemPrompt},
@@ -122,7 +122,7 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon
actions = append(actions, DecisionAction{
Event: event,
ID: normalizeID(item["id"]),
ID: asString(item["id"]),
Text: text,
OldMemory: asString(item["old_memory"]),
})
@@ -242,14 +242,6 @@ func asString(value any) string {
}
}
func normalizeID(value any) string {
id := asString(value)
if id == "" {
return ""
}
return id
}
func normalizeMemoryItems(value any) []map[string]any {
switch typed := value.(type) {
case []any:
-4
View File
@@ -120,10 +120,6 @@ Before finalizing, verify the value is one of the allowed codes.`
return systemPrompt, userPrompt
}
func parseMessages(messages []string) string {
return strings.Join(messages, "\n")
}
func removeCodeBlocks(text string) string {
return strings.ReplaceAll(strings.ReplaceAll(text, "```json", ""), "```", "")
}
+9 -12
View File
@@ -55,7 +55,7 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
collection = "memory"
}
if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" {
dimension = 1536
return nil, fmt.Errorf("embedding dimension is required")
}
cfg := &qdrant.Config{
@@ -455,22 +455,19 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
s.vectorNames[name] = int(vec.GetSize())
}
}
if len(vectors) == 0 {
goto sparseCheck
}
for name, dim := range vectors {
if existing, ok := s.vectorNames[name]; ok && existing == dim {
continue
if len(vectors) > 0 {
for name, dim := range vectors {
if existing, ok := s.vectorNames[name]; ok && existing == dim {
continue
}
return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim)
}
return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim)
}
}
if vectorsConfig == nil || vectorsConfig.GetParamsMap() == nil {
} else {
s.usesNamedVectors = false
s.vectorNames = nil
}
sparseCheck:
sparseConfig := params.GetSparseVectorsConfig()
if s.sparseVectorName != "" {
needsUpdate := false
@@ -506,7 +503,7 @@ func (s *QdrantStore) ensurePayloadIndexes(ctx context.Context) error {
if s.client == nil {
return nil
}
fields := []string{"botId", "runId"}
fields := []string{"bot_id", "run_id"}
wait := true
for _, field := range fields {
_, err := s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{
+67 -84
View File
@@ -6,7 +6,6 @@ import (
"encoding/hex"
"fmt"
"log/slog"
"math"
"sort"
"strings"
"time"
@@ -287,7 +286,7 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
}
vectorName := ""
if s.store != nil && s.store.usesNamedVectors {
if s.store.usesNamedVectors {
vectorName = result.Model
}
@@ -336,19 +335,22 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
if existing == nil {
return MemoryItem{}, fmt.Errorf("memory not found")
}
if s.bm25 == nil {
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
}
payload := existing.Payload
oldText := fmt.Sprint(payload["data"])
oldLang := fmt.Sprint(payload["lang"])
if oldLang == "" && strings.TrimSpace(oldText) != "" {
oldLang, _ = s.detectLanguage(ctx, oldText)
var detectErr error
oldLang, detectErr = s.detectLanguage(ctx, oldText)
if detectErr != nil {
s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr))
}
}
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
if err == nil {
if err != nil {
s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err))
} else {
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
}
}
@@ -365,7 +367,7 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
payload["data"] = req.Memory
payload["hash"] = hashMemory(req.Memory)
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
payload["updated_at"] = time.Now().UTC().Format(time.RFC3339)
payload["lang"] = newLang
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
@@ -413,10 +415,13 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse
filters[k] = v
}
if req.BotID != "" {
filters["botId"] = req.BotID
filters["bot_id"] = req.BotID
}
if req.AgentID != "" {
filters["agent_id"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
filters["run_id"] = req.RunID
}
if len(filters) == 0 {
return SearchResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required")
@@ -449,10 +454,13 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
filters[k] = v
}
if req.BotID != "" {
filters["botId"] = req.BotID
filters["bot_id"] = req.BotID
}
if req.AgentID != "" {
filters["agent_id"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
filters["run_id"] = req.RunID
}
if len(filters) == 0 {
return DeleteResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required")
@@ -487,6 +495,7 @@ func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error {
}
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
if err != nil {
s.logger.Warn("bm25 warmup: term frequencies failed", slog.String("id", point.ID), slog.Any("error", err))
continue
}
s.bm25.AddDocument(lang, termFreq, docLen)
@@ -609,11 +618,17 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
oldText := fmt.Sprint(payload["data"])
oldLang := fmt.Sprint(payload["lang"])
if oldLang == "" && strings.TrimSpace(oldText) != "" {
oldLang, _ = s.detectLanguage(ctx, oldText)
var detectErr error
oldLang, detectErr = s.detectLanguage(ctx, oldText)
if detectErr != nil {
s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr))
}
}
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
if err == nil {
if err != nil {
s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err))
} else {
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
}
}
@@ -628,7 +643,7 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
payload["data"] = text
payload["hash"] = hashMemory(text)
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
payload["updated_at"] = time.Now().UTC().Format(time.RFC3339)
payload["lang"] = newLang
if metadata != nil {
payload["metadata"] = mergeMetadata(payload["metadata"], metadata)
@@ -676,11 +691,17 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error
oldText := fmt.Sprint(existing.Payload["data"])
oldLang := fmt.Sprint(existing.Payload["lang"])
if oldLang == "" && strings.TrimSpace(oldText) != "" {
oldLang, _ = s.detectLanguage(ctx, oldText)
var detectErr error
oldLang, detectErr = s.detectLanguage(ctx, oldText)
if detectErr != nil {
s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr))
}
}
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
if err == nil {
if err != nil {
s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err))
} else {
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
}
}
@@ -754,10 +775,13 @@ func buildFilters(req AddRequest) map[string]any {
filters[key] = value
}
if req.BotID != "" {
filters["botId"] = req.BotID
filters["bot_id"] = req.BotID
}
if req.AgentID != "" {
filters["agent_id"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
filters["run_id"] = req.RunID
}
return filters
}
@@ -768,10 +792,13 @@ func buildSearchFilters(req SearchRequest) map[string]any {
filters[key] = value
}
if req.BotID != "" {
filters["botId"] = req.BotID
filters["bot_id"] = req.BotID
}
if req.AgentID != "" {
filters["agent_id"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
filters["run_id"] = req.RunID
}
return filters
}
@@ -782,10 +809,13 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]any {
filters[key] = value
}
if req.BotID != "" {
filters["botId"] = req.BotID
filters["bot_id"] = req.BotID
}
if req.AgentID != "" {
filters["agent_id"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
filters["run_id"] = req.RunID
}
return filters
}
@@ -840,9 +870,9 @@ func buildPayload(text string, filters map[string]any, metadata map[string]any,
createdAt = time.Now().UTC().Format(time.RFC3339)
}
payload := map[string]any{
"data": text,
"hash": hashMemory(text),
"createdAt": createdAt,
"data": text,
"hash": hashMemory(text),
"created_at": createdAt,
}
if metadata != nil {
payload["metadata"] = metadata
@@ -865,16 +895,19 @@ func payloadToMemoryItem(id string, payload map[string]any) MemoryItem {
if v, ok := payload["hash"].(string); ok {
item.Hash = v
}
if v, ok := payload["createdAt"].(string); ok {
if v, ok := payload["created_at"].(string); ok {
item.CreatedAt = v
}
if v, ok := payload["updatedAt"].(string); ok {
if v, ok := payload["updated_at"].(string); ok {
item.UpdatedAt = v
}
if v, ok := payload["botId"].(string); ok {
if v, ok := payload["bot_id"].(string); ok {
item.BotID = v
}
if v, ok := payload["runId"].(string); ok {
if v, ok := payload["agent_id"].(string); ok {
item.AgentID = v
}
if v, ok := payload["run_id"].(string); ok {
item.RunID = v
}
if meta, ok := payload["metadata"].(map[string]any); ok {
@@ -924,76 +957,33 @@ func mergeMetadata(base any, extra map[string]any) map[string]any {
type rerankCandidate struct {
ID string
Payload map[string]any
Score float64
Source string
Rank int
}
const (
fusionModeRRF = "rrf"
fusionModeCombMNZ = "combmnz"
fusionMode = fusionModeRRF
rrfK = 60.0
rrfK = 60.0
)
func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, scoresBySource map[string][]float64) []MemoryItem {
func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, _ map[string][]float64) []MemoryItem {
candidates := map[string]*rerankCandidate{}
rrfScores := map[string]float64{}
combScores := map[string]float64{}
combCounts := map[string]int{}
for source, points := range pointsBySource {
scores := scoresBySource[source]
minScore := math.MaxFloat64
maxScore := -math.MaxFloat64
for _, points := range pointsBySource {
for idx, point := range points {
if idx >= len(scores) {
continue
}
score := scores[idx]
if score < minScore {
minScore = score
}
if score > maxScore {
maxScore = score
}
if _, ok := candidates[point.ID]; !ok {
candidates[point.ID] = &rerankCandidate{
ID: point.ID,
Payload: point.Payload,
}
}
}
if minScore == math.MaxFloat64 {
minScore = 0
}
if maxScore == -math.MaxFloat64 {
maxScore = minScore
}
for idx, point := range points {
if idx >= len(scores) {
continue
}
score := scores[idx]
rank := float64(idx + 1)
rrfScores[point.ID] += 1.0 / (rrfK + rank)
scoreNorm := normalizeScore(score, minScore, maxScore)
combScores[point.ID] += scoreNorm
combCounts[point.ID]++
}
}
items := make([]MemoryItem, 0, len(candidates))
for id, candidate := range candidates {
item := payloadToMemoryItem(candidate.ID, candidate.Payload)
switch fusionMode {
case fusionModeCombMNZ:
item.Score = combScores[id] * float64(combCounts[id])
default:
item.Score = rrfScores[id]
}
item.Score = rrfScores[id]
items = append(items, item)
}
@@ -1002,10 +992,3 @@ func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, scoresBySource ma
})
return items
}
func normalizeScore(score, minScore, maxScore float64) float64 {
if maxScore <= minScore {
return 1
}
return (score - minScore) / (maxScore - minScore)
}
+5 -5
View File
@@ -88,13 +88,13 @@ type MemoryItem struct {
ID string `json:"id"`
Memory string `json:"memory"`
Hash string `json:"hash,omitempty"`
CreatedAt string `json:"createdAt,omitempty"`
UpdatedAt string `json:"updatedAt,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
Score float64 `json:"score,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
BotID string `json:"botId,omitempty"`
AgentID string `json:"agentId,omitempty"`
RunID string `json:"runId,omitempty"`
BotID string `json:"bot_id,omitempty"`
AgentID string `json:"agent_id,omitempty"`
RunID string `json:"run_id,omitempty"`
}
type SearchResponse struct {