mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor(memory): fix naming, dead code, error handling and AgentID support
Normalize JSON tags to snake_case, remove dead fusion branches and unused struct fields, add proper error logging for BM25 operations, wire AgentID through filters/store/retrieval, and replace goto with structured control flow.
This commit is contained in:
@@ -247,6 +247,6 @@ func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) {
|
||||
|
||||
func termHash(term string) uint32 {
|
||||
hasher := fnv.New32a()
|
||||
_, _ = hasher.Write([]byte(term))
|
||||
hasher.Write([]byte(term)) //nolint:errcheck // hash.Write never returns error
|
||||
return hasher.Sum32() & sparseDimMask
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractRes
|
||||
if len(req.Messages) == 0 {
|
||||
return ExtractResponse{}, fmt.Errorf("messages is required")
|
||||
}
|
||||
parsedMessages := parseMessages(formatMessages(req.Messages))
|
||||
parsedMessages := strings.Join(formatMessages(req.Messages), "\n")
|
||||
systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages)
|
||||
content, err := c.callChat(ctx, []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
@@ -122,7 +122,7 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon
|
||||
|
||||
actions = append(actions, DecisionAction{
|
||||
Event: event,
|
||||
ID: normalizeID(item["id"]),
|
||||
ID: asString(item["id"]),
|
||||
Text: text,
|
||||
OldMemory: asString(item["old_memory"]),
|
||||
})
|
||||
@@ -242,14 +242,6 @@ func asString(value any) string {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeID(value any) string {
|
||||
id := asString(value)
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func normalizeMemoryItems(value any) []map[string]any {
|
||||
switch typed := value.(type) {
|
||||
case []any:
|
||||
|
||||
@@ -120,10 +120,6 @@ Before finalizing, verify the value is one of the allowed codes.`
|
||||
return systemPrompt, userPrompt
|
||||
}
|
||||
|
||||
// parseMessages concatenates the given message lines into a single string,
// placing a newline between consecutive entries. A nil or empty slice yields
// the empty string; no trailing newline is added.
func parseMessages(messages []string) string {
	var joined strings.Builder
	for i, message := range messages {
		if i > 0 {
			joined.WriteByte('\n')
		}
		joined.WriteString(message)
	}
	return joined.String()
}
|
||||
|
||||
// removeCodeBlocks strips Markdown code-fence markers from text.
//
// It works in two passes: first every "```json" opener is removed, then every
// remaining "```" fence. The content between the fences is left untouched.
func removeCodeBlocks(text string) string {
	// Pass 1: drop the language-tagged openers so the "json" tag does not
	// survive once the bare fences are removed.
	withoutJSONFences := strings.ReplaceAll(text, "```json", "")
	// Pass 2: drop whatever plain fences are left.
	withoutFences := strings.ReplaceAll(withoutJSONFences, "```", "")
	return withoutFences
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
|
||||
collection = "memory"
|
||||
}
|
||||
if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" {
|
||||
dimension = 1536
|
||||
return nil, fmt.Errorf("embedding dimension is required")
|
||||
}
|
||||
|
||||
cfg := &qdrant.Config{
|
||||
@@ -455,9 +455,7 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
s.vectorNames[name] = int(vec.GetSize())
|
||||
}
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
goto sparseCheck
|
||||
}
|
||||
if len(vectors) > 0 {
|
||||
for name, dim := range vectors {
|
||||
if existing, ok := s.vectorNames[name]; ok && existing == dim {
|
||||
continue
|
||||
@@ -465,12 +463,11 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim)
|
||||
}
|
||||
}
|
||||
if vectorsConfig == nil || vectorsConfig.GetParamsMap() == nil {
|
||||
} else {
|
||||
s.usesNamedVectors = false
|
||||
s.vectorNames = nil
|
||||
}
|
||||
|
||||
sparseCheck:
|
||||
sparseConfig := params.GetSparseVectorsConfig()
|
||||
if s.sparseVectorName != "" {
|
||||
needsUpdate := false
|
||||
@@ -506,7 +503,7 @@ func (s *QdrantStore) ensurePayloadIndexes(ctx context.Context) error {
|
||||
if s.client == nil {
|
||||
return nil
|
||||
}
|
||||
fields := []string{"botId", "runId"}
|
||||
fields := []string{"bot_id", "run_id"}
|
||||
wait := true
|
||||
for _, field := range fields {
|
||||
_, err := s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{
|
||||
|
||||
+63
-80
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -287,7 +286,7 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
|
||||
}
|
||||
|
||||
vectorName := ""
|
||||
if s.store != nil && s.store.usesNamedVectors {
|
||||
if s.store.usesNamedVectors {
|
||||
vectorName = result.Model
|
||||
}
|
||||
|
||||
@@ -336,19 +335,22 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if existing == nil {
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
|
||||
payload := existing.Payload
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
var detectErr error
|
||||
oldLang, detectErr = s.detectLanguage(ctx, oldText)
|
||||
if detectErr != nil {
|
||||
s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr))
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err))
|
||||
} else {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
@@ -365,7 +367,7 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
|
||||
payload["data"] = req.Memory
|
||||
payload["hash"] = hashMemory(req.Memory)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["updated_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
@@ -413,10 +415,13 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse
|
||||
filters[k] = v
|
||||
}
|
||||
if req.BotID != "" {
|
||||
filters["botId"] = req.BotID
|
||||
filters["bot_id"] = req.BotID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agent_id"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
filters["run_id"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return SearchResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required")
|
||||
@@ -449,10 +454,13 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
|
||||
filters[k] = v
|
||||
}
|
||||
if req.BotID != "" {
|
||||
filters["botId"] = req.BotID
|
||||
filters["bot_id"] = req.BotID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agent_id"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
filters["run_id"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return DeleteResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required")
|
||||
@@ -487,6 +495,7 @@ func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error {
|
||||
}
|
||||
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
|
||||
if err != nil {
|
||||
s.logger.Warn("bm25 warmup: term frequencies failed", slog.String("id", point.ID), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
s.bm25.AddDocument(lang, termFreq, docLen)
|
||||
@@ -609,11 +618,17 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
var detectErr error
|
||||
oldLang, detectErr = s.detectLanguage(ctx, oldText)
|
||||
if detectErr != nil {
|
||||
s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr))
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err))
|
||||
} else {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
@@ -628,7 +643,7 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
|
||||
payload["data"] = text
|
||||
payload["hash"] = hashMemory(text)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["updated_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
if metadata != nil {
|
||||
payload["metadata"] = mergeMetadata(payload["metadata"], metadata)
|
||||
@@ -676,11 +691,17 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error
|
||||
oldText := fmt.Sprint(existing.Payload["data"])
|
||||
oldLang := fmt.Sprint(existing.Payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
var detectErr error
|
||||
oldLang, detectErr = s.detectLanguage(ctx, oldText)
|
||||
if detectErr != nil {
|
||||
s.logger.Warn("detect language failed for old text", slog.Any("error", detectErr))
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
s.logger.Warn("bm25 term frequencies failed", slog.String("lang", oldLang), slog.Any("error", err))
|
||||
} else {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
@@ -754,10 +775,13 @@ func buildFilters(req AddRequest) map[string]any {
|
||||
filters[key] = value
|
||||
}
|
||||
if req.BotID != "" {
|
||||
filters["botId"] = req.BotID
|
||||
filters["bot_id"] = req.BotID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agent_id"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
filters["run_id"] = req.RunID
|
||||
}
|
||||
return filters
|
||||
}
|
||||
@@ -768,10 +792,13 @@ func buildSearchFilters(req SearchRequest) map[string]any {
|
||||
filters[key] = value
|
||||
}
|
||||
if req.BotID != "" {
|
||||
filters["botId"] = req.BotID
|
||||
filters["bot_id"] = req.BotID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agent_id"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
filters["run_id"] = req.RunID
|
||||
}
|
||||
return filters
|
||||
}
|
||||
@@ -782,10 +809,13 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]any {
|
||||
filters[key] = value
|
||||
}
|
||||
if req.BotID != "" {
|
||||
filters["botId"] = req.BotID
|
||||
filters["bot_id"] = req.BotID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agent_id"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
filters["run_id"] = req.RunID
|
||||
}
|
||||
return filters
|
||||
}
|
||||
@@ -842,7 +872,7 @@ func buildPayload(text string, filters map[string]any, metadata map[string]any,
|
||||
payload := map[string]any{
|
||||
"data": text,
|
||||
"hash": hashMemory(text),
|
||||
"createdAt": createdAt,
|
||||
"created_at": createdAt,
|
||||
}
|
||||
if metadata != nil {
|
||||
payload["metadata"] = metadata
|
||||
@@ -865,16 +895,19 @@ func payloadToMemoryItem(id string, payload map[string]any) MemoryItem {
|
||||
if v, ok := payload["hash"].(string); ok {
|
||||
item.Hash = v
|
||||
}
|
||||
if v, ok := payload["createdAt"].(string); ok {
|
||||
if v, ok := payload["created_at"].(string); ok {
|
||||
item.CreatedAt = v
|
||||
}
|
||||
if v, ok := payload["updatedAt"].(string); ok {
|
||||
if v, ok := payload["updated_at"].(string); ok {
|
||||
item.UpdatedAt = v
|
||||
}
|
||||
if v, ok := payload["botId"].(string); ok {
|
||||
if v, ok := payload["bot_id"].(string); ok {
|
||||
item.BotID = v
|
||||
}
|
||||
if v, ok := payload["runId"].(string); ok {
|
||||
if v, ok := payload["agent_id"].(string); ok {
|
||||
item.AgentID = v
|
||||
}
|
||||
if v, ok := payload["run_id"].(string); ok {
|
||||
item.RunID = v
|
||||
}
|
||||
if meta, ok := payload["metadata"].(map[string]any); ok {
|
||||
@@ -924,76 +957,33 @@ func mergeMetadata(base any, extra map[string]any) map[string]any {
|
||||
// rerankCandidate is a single point collected while fusing result lists from
// several retrieval sources.
type rerankCandidate struct {
	// ID is the point identifier; candidates are keyed by it during fusion.
	ID string
	// Payload is the stored payload of the point, later converted into a
	// MemoryItem.
	Payload map[string]any
	// Score is a numeric score for the candidate.
	// NOTE(review): not populated by the visible fusion code; confirm usage.
	Score float64
	// Source presumably names the retrieval source the candidate came from.
	// NOTE(review): not populated by the visible fusion code; confirm usage.
	Source string
	// Rank presumably holds the candidate's position within its source list.
	// NOTE(review): not populated by the visible fusion code; confirm usage.
	Rank int
}
|
||||
|
||||
// Result-fusion configuration used when merging ranked lists from several
// retrieval sources.
const (
	// fusionModeRRF selects reciprocal rank fusion.
	fusionModeRRF = "rrf"
	// fusionModeCombMNZ selects CombMNZ: the sum of normalized scores
	// multiplied by the number of lists the item appeared in.
	fusionModeCombMNZ = "combmnz"
	// fusionMode is the fusion strategy currently in effect.
	fusionMode = fusionModeRRF
	// rrfK is the constant added to the 1-based rank in the reciprocal rank
	// term 1 / (rrfK + rank).
	rrfK = 60.0
)
|
||||
|
||||
func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, scoresBySource map[string][]float64) []MemoryItem {
|
||||
func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, _ map[string][]float64) []MemoryItem {
|
||||
candidates := map[string]*rerankCandidate{}
|
||||
rrfScores := map[string]float64{}
|
||||
combScores := map[string]float64{}
|
||||
combCounts := map[string]int{}
|
||||
|
||||
for source, points := range pointsBySource {
|
||||
scores := scoresBySource[source]
|
||||
minScore := math.MaxFloat64
|
||||
maxScore := -math.MaxFloat64
|
||||
for _, points := range pointsBySource {
|
||||
for idx, point := range points {
|
||||
if idx >= len(scores) {
|
||||
continue
|
||||
}
|
||||
score := scores[idx]
|
||||
if score < minScore {
|
||||
minScore = score
|
||||
}
|
||||
if score > maxScore {
|
||||
maxScore = score
|
||||
}
|
||||
if _, ok := candidates[point.ID]; !ok {
|
||||
candidates[point.ID] = &rerankCandidate{
|
||||
ID: point.ID,
|
||||
Payload: point.Payload,
|
||||
}
|
||||
}
|
||||
}
|
||||
if minScore == math.MaxFloat64 {
|
||||
minScore = 0
|
||||
}
|
||||
if maxScore == -math.MaxFloat64 {
|
||||
maxScore = minScore
|
||||
}
|
||||
|
||||
for idx, point := range points {
|
||||
if idx >= len(scores) {
|
||||
continue
|
||||
}
|
||||
score := scores[idx]
|
||||
rank := float64(idx + 1)
|
||||
rrfScores[point.ID] += 1.0 / (rrfK + rank)
|
||||
|
||||
scoreNorm := normalizeScore(score, minScore, maxScore)
|
||||
combScores[point.ID] += scoreNorm
|
||||
combCounts[point.ID]++
|
||||
}
|
||||
}
|
||||
|
||||
items := make([]MemoryItem, 0, len(candidates))
|
||||
for id, candidate := range candidates {
|
||||
item := payloadToMemoryItem(candidate.ID, candidate.Payload)
|
||||
switch fusionMode {
|
||||
case fusionModeCombMNZ:
|
||||
item.Score = combScores[id] * float64(combCounts[id])
|
||||
default:
|
||||
item.Score = rrfScores[id]
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
@@ -1002,10 +992,3 @@ func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, scoresBySource ma
|
||||
})
|
||||
return items
|
||||
}
|
||||
|
||||
// normalizeScore rescales score linearly relative to the range
// [minScore, maxScore], so that minScore maps to 0 and maxScore maps to 1.
//
// When the range is empty or inverted (maxScore <= minScore) there is nothing
// to scale against, and 1 is returned.
func normalizeScore(score, minScore, maxScore float64) float64 {
	switch {
	case maxScore <= minScore:
		return 1
	default:
		offset := score - minScore
		span := maxScore - minScore
		return offset / span
	}
}
|
||||
|
||||
@@ -88,13 +88,13 @@ type MemoryItem struct {
|
||||
ID string `json:"id"`
|
||||
Memory string `json:"memory"`
|
||||
Hash string `json:"hash,omitempty"`
|
||||
CreatedAt string `json:"createdAt,omitempty"`
|
||||
UpdatedAt string `json:"updatedAt,omitempty"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
Score float64 `json:"score,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
BotID string `json:"botId,omitempty"`
|
||||
AgentID string `json:"agentId,omitempty"`
|
||||
RunID string `json:"runId,omitempty"`
|
||||
BotID string `json:"bot_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
}
|
||||
|
||||
type SearchResponse struct {
|
||||
|
||||
Reference in New Issue
Block a user