mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor: use sparse vector for memory
This commit is contained in:
@@ -0,0 +1,252 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log/slog"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/blevesearch/bleve/v2/registry"
|
||||
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/analyzer/standard"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ar"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/bg"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ca"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/cjk"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ckb"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/da"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/de"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/el"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/en"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/es"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/eu"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fa"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fi"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fr"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ga"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/gl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hi"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hr"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hu"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hy"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/id"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/it"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/nl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/no"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/pl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/pt"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ro"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ru"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/sv"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/tr"
|
||||
)
|
||||
|
||||
// BM25 scoring parameters and the dimensionality of the hashed sparse space.
const (
	// defaultBM25K1 controls term-frequency saturation (common BM25 default).
	defaultBM25K1 = 1.2
	// defaultBM25B controls document-length normalization (common BM25 default).
	defaultBM25B = 0.75
	// sparseDimBits is the number of hash bits kept for sparse vector indices.
	sparseDimBits = 20
	// sparseDimSize is the size of the hashed sparse vector space (2^20).
	sparseDimSize = 1 << sparseDimBits
	// sparseDimMask masks a 32-bit term hash down to [0, sparseDimSize).
	sparseDimMask = sparseDimSize - 1
)
|
||||
|
||||
// BM25Indexer turns analyzed text into BM25-weighted sparse vectors,
// maintaining per-language corpus statistics used for the idf and
// length-normalization terms.
type BM25Indexer struct {
	cache  *registry.Cache // bleve analyzer registry cache
	logger *slog.Logger
	k1     float64 // BM25 k1 parameter (term-frequency saturation)
	b      float64 // BM25 b parameter (document-length normalization)

	mu    sync.RWMutex          // guards stats
	stats map[string]*bm25Stats // corpus statistics keyed by analyzer name
}
|
||||
|
||||
// bm25Stats holds the corpus statistics for a single language/analyzer
// needed to compute BM25 weights.
type bm25Stats struct {
	DocCount  int            // number of documents currently indexed
	AvgDocLen float64        // mean document length in tokens
	DocFreq   map[string]int // number of documents containing each term
}
|
||||
|
||||
func NewBM25Indexer(log *slog.Logger) *BM25Indexer {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
return &BM25Indexer{
|
||||
cache: registry.NewCache(),
|
||||
logger: log.With(slog.String("indexer", "bm25")),
|
||||
k1: defaultBM25K1,
|
||||
b: defaultBM25B,
|
||||
stats: map[string]*bm25Stats{},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) TermFrequencies(lang, text string) (map[string]int, int, error) {
|
||||
analyzerName, err := b.normalizeAnalyzer(lang)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
analyzer, err := b.cache.AnalyzerNamed(analyzerName)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("bm25 analyzer %s: %w", analyzerName, err)
|
||||
}
|
||||
tokens := analyzer.Analyze([]byte(text))
|
||||
freq := map[string]int{}
|
||||
docLen := 0
|
||||
for _, token := range tokens {
|
||||
term := strings.TrimSpace(string(token.Term))
|
||||
if term == "" {
|
||||
continue
|
||||
}
|
||||
freq[term]++
|
||||
docLen++
|
||||
}
|
||||
return freq, docLen, nil
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) AddDocument(lang string, termFreq map[string]int, docLen int) (indices []uint32, values []float32) {
|
||||
b.mu.Lock()
|
||||
stats := b.ensureStatsLocked(lang)
|
||||
b.updateStatsAddLocked(stats, termFreq, docLen)
|
||||
indices, values = b.buildDocVectorLocked(stats, termFreq, docLen)
|
||||
b.mu.Unlock()
|
||||
return indices, values
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) RemoveDocument(lang string, termFreq map[string]int, docLen int) {
|
||||
b.mu.Lock()
|
||||
stats := b.ensureStatsLocked(lang)
|
||||
b.updateStatsRemoveLocked(stats, termFreq, docLen)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) BuildQueryVector(lang string, termFreq map[string]int) (indices []uint32, values []float32) {
|
||||
b.mu.RLock()
|
||||
stats := b.ensureStatsLocked(lang)
|
||||
indices, values = b.buildQueryVectorLocked(stats, termFreq)
|
||||
b.mu.RUnlock()
|
||||
return indices, values
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) normalizeAnalyzer(lang string) (string, error) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(lang))
|
||||
switch normalized {
|
||||
case "":
|
||||
return "standard", nil
|
||||
case "in":
|
||||
normalized = "id"
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// ensureStatsLocked returns the stats bucket for lang's analyzer name,
// creating an empty one on first use. Because it may insert into
// b.stats, callers must hold b.mu for WRITING.
func (b *BM25Indexer) ensureStatsLocked(lang string) *bm25Stats {
	// normalizeAnalyzer never returns an error today; ignore it.
	name, _ := b.normalizeAnalyzer(lang)
	stats := b.stats[name]
	if stats == nil {
		stats = &bm25Stats{
			DocFreq: map[string]int{},
		}
		b.stats[name] = stats
	}
	return stats
}
|
||||
|
||||
func (b *BM25Indexer) updateStatsAddLocked(stats *bm25Stats, termFreq map[string]int, docLen int) {
|
||||
totalDocs := stats.DocCount
|
||||
stats.DocCount++
|
||||
totalLen := stats.AvgDocLen * float64(totalDocs)
|
||||
stats.AvgDocLen = (totalLen + float64(docLen)) / float64(stats.DocCount)
|
||||
for term := range termFreq {
|
||||
stats.DocFreq[term]++
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) updateStatsRemoveLocked(stats *bm25Stats, termFreq map[string]int, docLen int) {
|
||||
if stats.DocCount <= 0 {
|
||||
return
|
||||
}
|
||||
totalDocs := stats.DocCount
|
||||
totalLen := stats.AvgDocLen * float64(totalDocs)
|
||||
stats.DocCount--
|
||||
if stats.DocCount > 0 {
|
||||
stats.AvgDocLen = (totalLen - float64(docLen)) / float64(stats.DocCount)
|
||||
} else {
|
||||
stats.AvgDocLen = 0
|
||||
}
|
||||
for term := range termFreq {
|
||||
if stats.DocFreq[term] > 1 {
|
||||
stats.DocFreq[term]--
|
||||
} else {
|
||||
delete(stats.DocFreq, term)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildDocVectorLocked computes the BM25-weighted sparse vector for a
// document with the given term frequencies and token length. Returns
// nil slices for an empty corpus or empty document. Caller must hold
// b.mu (the document's own terms must already be in stats.DocFreq).
func (b *BM25Indexer) buildDocVectorLocked(stats *bm25Stats, termFreq map[string]int, docLen int) ([]uint32, []float32) {
	if stats.DocCount == 0 || docLen == 0 {
		return nil, nil
	}
	avgDocLen := stats.AvgDocLen
	if avgDocLen <= 0 {
		// Guard against division by zero on degenerate stats.
		avgDocLen = 1
	}
	weights := map[uint32]float32{}
	for term, tf := range termFreq {
		df := stats.DocFreq[term]
		if df == 0 {
			continue
		}
		// BM25 idf with +1 inside the log so the result stays positive.
		idf := math.Log(1 + (float64(stats.DocCount)-float64(df)+0.5)/(float64(df)+0.5))
		// Saturated, length-normalized term frequency (classic BM25).
		numerator := float64(tf) * (b.k1 + 1)
		denominator := float64(tf) + b.k1*(1-b.b+b.b*float64(docLen)/avgDocLen)
		tfNorm := numerator / denominator
		weight := float32(tfNorm * idf)
		if weight == 0 {
			continue
		}
		// Hash collisions accumulate additively into the same index.
		index := termHash(term)
		weights[index] += weight
	}
	return sparseWeightsToVector(weights)
}
|
||||
|
||||
func (b *BM25Indexer) buildQueryVectorLocked(stats *bm25Stats, termFreq map[string]int) ([]uint32, []float32) {
|
||||
if stats.DocCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
weights := map[uint32]float32{}
|
||||
for term, tf := range termFreq {
|
||||
if stats.DocFreq[term] == 0 {
|
||||
continue
|
||||
}
|
||||
weight := float32(tf)
|
||||
if weight == 0 {
|
||||
continue
|
||||
}
|
||||
index := termHash(term)
|
||||
weights[index] += weight
|
||||
}
|
||||
return sparseWeightsToVector(weights)
|
||||
}
|
||||
|
||||
// sparseWeightsToVector flattens a weight map into parallel slices of
// ascending indices and their weights. Empty input yields nil slices.
func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) {
	if len(weights) == 0 {
		return nil, nil
	}
	indices := make([]uint32, 0, len(weights))
	for index := range weights {
		indices = append(indices, index)
	}
	sort.Slice(indices, func(a, b int) bool { return indices[a] < indices[b] })
	values := make([]float32, len(indices))
	for pos, index := range indices {
		values[pos] = weights[index]
	}
	return indices, values
}
|
||||
|
||||
func termHash(term string) uint32 {
|
||||
hasher := fnv.New32a()
|
||||
_, _ = hasher.Write([]byte(term))
|
||||
return hasher.Sum32() & sparseDimMask
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.
|
||||
}
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
if model == "" {
|
||||
model = "gpt-4.1-nano-2025-04-14"
|
||||
model = "gpt-4.1-nano"
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
@@ -119,6 +119,31 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon
|
||||
return DecideResponse{Actions: actions}, nil
|
||||
}
|
||||
|
||||
func (c *LLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "", fmt.Errorf("text is required")
|
||||
}
|
||||
systemPrompt, userPrompt := getLanguageDetectionMessages(text)
|
||||
content, err := c.callChat(ctx, []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parsed struct {
|
||||
Language string `json:"language"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
lang := strings.ToLower(strings.TrimSpace(parsed.Language))
|
||||
if !isAllowedLanguageCode(lang) {
|
||||
return "", fmt.Errorf("unsupported language code: %s", lang)
|
||||
}
|
||||
return lang, nil
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
@@ -247,3 +272,14 @@ func normalizeMemoryItems(value interface{}) []map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// isAllowedLanguageCode reports whether code (case-insensitive, trimmed)
// is one of the analyzer language codes this package supports.
func isAllowedLanguageCode(code string) bool {
	normalized := strings.ToLower(strings.TrimSpace(code))
	switch normalized {
	case "ar", "bg", "ca", "cjk", "ckb", "da", "de", "el", "en", "es",
		"eu", "fa", "fi", "fr", "ga", "gl", "hi", "hr", "hu", "hy",
		"id", "in", "it", "nl", "no", "pl", "pt", "ro", "ru", "sv", "tr":
		return true
	}
	return false
}
|
||||
|
||||
@@ -106,6 +106,20 @@ Follow the instruction mentioned below:
|
||||
Do not return anything except the JSON format.`, toJSON(retrievedOldMemory), toJSON(newRetrievedFacts), "```json", "```")
|
||||
}
|
||||
|
||||
// getLanguageDetectionMessages builds the (system, user) prompt pair for
// LLM-based language classification. The system prompt constrains the
// model to emit a single-key JSON object using only the allowed analyzer
// codes (CJK languages collapse to "cjk").
func getLanguageDetectionMessages(text string) (string, string) {
	systemPrompt := `You are a language classifier for the given input text.
Return a JSON object with a single key "language" whose value is one of the allowed codes.
Allowed codes: ar, bg, ca, cjk, ckb, da, de, el, en, es, eu, fa, fi, fr, ga, gl, hi, hr, hu, hy, id, in, it, nl, no, pl, pt, ro, ru, sv, tr.
Use "cjk" for Chinese/Japanese/Korean text, ckb=Kurdish(Sorani), ga=Irish(Gaelic), gl=Galician, eu=Basque, hy=Armenian, fa=Persian, hr=Croatian, hu=Hungarian, ro=Romanian, bg=Bulgarian. If unsure between id/in, use id.
If multiple languages appear, choose the dominant language.
Do not include any extra keys, comments, or formatting. Output must be valid JSON only.
If the text is Chinese, Japanese, or Korean, output exactly {"language":"cjk"}.
Never output "zh", "zh-cn", "zh-tw", "ja", "ko", or any code not in the allowed list.
Before finalizing, verify the value is one of the allowed codes.`
	userPrompt := fmt.Sprintf("Text:\n%s", text)
	return systemPrompt, userPrompt
}
|
||||
|
||||
// parseMessages joins conversation messages into one newline-separated
// block of text.
func parseMessages(messages []string) string {
	var joined strings.Builder
	for i, message := range messages {
		if i > 0 {
			joined.WriteString("\n")
		}
		joined.WriteString(message)
	}
	return joined.String()
}
|
||||
|
||||
+226
-52
@@ -12,34 +12,47 @@ import (
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
)
|
||||
|
||||
const (
|
||||
sparseHashVectorName = "sparse_hash"
|
||||
sparseVocabVectorName = "sparse_vocab"
|
||||
)
|
||||
|
||||
type QdrantStore struct {
|
||||
client *qdrant.Client
|
||||
collection string
|
||||
dimension int
|
||||
baseURL string
|
||||
apiKey string
|
||||
timeout time.Duration
|
||||
logger *slog.Logger
|
||||
vectorNames map[string]int
|
||||
usesNamedVectors bool
|
||||
client *qdrant.Client
|
||||
collection string
|
||||
dimension int
|
||||
baseURL string
|
||||
apiKey string
|
||||
timeout time.Duration
|
||||
logger *slog.Logger
|
||||
vectorNames map[string]int
|
||||
usesNamedVectors bool
|
||||
sparseVectorName string
|
||||
usesSparseVectors bool
|
||||
}
|
||||
|
||||
type qdrantPoint struct {
|
||||
ID string `json:"id"`
|
||||
Vector []float32 `json:"vector"`
|
||||
VectorName string `json:"vector_name,omitempty"`
|
||||
Payload map[string]interface{} `json:"payload,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Vector []float32 `json:"vector"`
|
||||
VectorName string `json:"vector_name,omitempty"`
|
||||
SparseIndices []uint32 `json:"sparse_indices,omitempty"`
|
||||
SparseValues []float32 `json:"sparse_values,omitempty"`
|
||||
SparseVectorName string `json:"sparse_vector_name,omitempty"`
|
||||
Payload map[string]interface{} `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, timeout time.Duration) (*QdrantStore, error) {
|
||||
func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) {
|
||||
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(sparseVectorName) == "" {
|
||||
sparseVectorName = sparseHashVectorName
|
||||
}
|
||||
if collection == "" {
|
||||
collection = "memory"
|
||||
}
|
||||
if dimension <= 0 {
|
||||
if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" {
|
||||
dimension = 1536
|
||||
}
|
||||
|
||||
@@ -55,13 +68,15 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
|
||||
}
|
||||
|
||||
store := &QdrantStore{
|
||||
client: client,
|
||||
collection: collection,
|
||||
dimension: dimension,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
client: client,
|
||||
collection: collection,
|
||||
dimension: dimension,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
sparseVectorName: strings.TrimSpace(sparseVectorName),
|
||||
usesSparseVectors: strings.TrimSpace(sparseVectorName) != "",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
|
||||
@@ -73,14 +88,17 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
|
||||
}
|
||||
|
||||
func (s *QdrantStore) NewSibling(collection string, dimension int) (*QdrantStore, error) {
|
||||
return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.timeout)
|
||||
return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.sparseVectorName, s.timeout)
|
||||
}
|
||||
|
||||
func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, timeout time.Duration) (*QdrantStore, error) {
|
||||
func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) {
|
||||
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(sparseVectorName) == "" {
|
||||
sparseVectorName = sparseHashVectorName
|
||||
}
|
||||
if collection == "" {
|
||||
collection = "memory"
|
||||
}
|
||||
@@ -100,14 +118,16 @@ func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection str
|
||||
}
|
||||
|
||||
store := &QdrantStore{
|
||||
client: client,
|
||||
collection: collection,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
vectorNames: vectors,
|
||||
usesNamedVectors: true,
|
||||
client: client,
|
||||
collection: collection,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
vectorNames: vectors,
|
||||
usesNamedVectors: true,
|
||||
sparseVectorName: strings.TrimSpace(sparseVectorName),
|
||||
usesSparseVectors: strings.TrimSpace(sparseVectorName) != "",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
|
||||
@@ -129,12 +149,31 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error {
|
||||
return err
|
||||
}
|
||||
var vectors *qdrant.Vectors
|
||||
if point.VectorName != "" && s.usesNamedVectors {
|
||||
vectors = qdrant.NewVectorsMap(map[string]*qdrant.Vector{
|
||||
point.VectorName: qdrant.NewVectorDense(point.Vector),
|
||||
})
|
||||
} else {
|
||||
vectors = qdrant.NewVectorsDense(point.Vector)
|
||||
vectorMap := map[string]*qdrant.Vector{}
|
||||
if len(point.Vector) > 0 {
|
||||
if point.VectorName != "" && s.usesNamedVectors {
|
||||
vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector)
|
||||
} else if !s.usesNamedVectors && len(point.SparseIndices) == 0 {
|
||||
vectors = qdrant.NewVectorsDense(point.Vector)
|
||||
} else if point.VectorName != "" {
|
||||
vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector)
|
||||
}
|
||||
}
|
||||
if len(point.SparseIndices) > 0 && len(point.SparseValues) > 0 {
|
||||
sparseName := strings.TrimSpace(point.SparseVectorName)
|
||||
if sparseName == "" {
|
||||
sparseName = s.sparseVectorName
|
||||
}
|
||||
if sparseName == "" {
|
||||
return fmt.Errorf("sparse vector name is required")
|
||||
}
|
||||
vectorMap[sparseName] = qdrant.NewVectorSparse(point.SparseIndices, point.SparseValues)
|
||||
}
|
||||
if vectors == nil {
|
||||
if len(vectorMap) == 0 {
|
||||
return fmt.Errorf("no vector data provided for point %s", point.ID)
|
||||
}
|
||||
vectors = qdrant.NewVectorsMap(vectorMap)
|
||||
}
|
||||
qPoints = append(qPoints, &qdrant.PointStruct{
|
||||
Id: qdrant.NewIDUUID(point.ID),
|
||||
@@ -183,6 +222,41 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f
|
||||
return points, scores, nil
|
||||
}
|
||||
|
||||
// SearchSparse queries the collection's configured sparse vector with the
// given parallel indices/values slices. A non-positive limit defaults to
// 10; an empty query returns no results without contacting Qdrant.
// Returned points carry IDs and payloads only (no vectors).
func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}) ([]qdrantPoint, []float64, error) {
	if limit <= 0 {
		limit = 10
	}
	if len(indices) == 0 || len(values) == 0 {
		// Nothing to match against; treat as an empty result set.
		return nil, nil, nil
	}
	if s.sparseVectorName == "" {
		return nil, nil, fmt.Errorf("sparse vector name not configured")
	}
	filter := buildQdrantFilter(filters)
	// Select the named sparse vector for this query.
	using := qdrant.PtrOf(s.sparseVectorName)
	results, err := s.client.Query(ctx, &qdrant.QueryPoints{
		CollectionName: s.collection,
		Query:          qdrant.NewQuerySparse(indices, values),
		Using:          using,
		Limit:          qdrant.PtrOf(uint64(limit)),
		Filter:         filter,
		WithPayload:    qdrant.NewWithPayload(true),
	})
	if err != nil {
		return nil, nil, err
	}
	points := make([]qdrantPoint, 0, len(results))
	scores := make([]float64, 0, len(results))
	for _, scored := range results {
		points = append(points, qdrantPoint{
			ID:      pointIDToString(scored.GetId()),
			Payload: valueMapToInterface(scored.GetPayload()),
		})
		scores = append(scores, float64(scored.GetScore()))
	}
	return points, scores, nil
}
|
||||
|
||||
func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) {
|
||||
pointsBySource := make(map[string][]qdrantPoint, len(sources))
|
||||
scoresBySource := make(map[string][]float64, len(sources))
|
||||
@@ -204,6 +278,27 @@ func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, lim
|
||||
return pointsBySource, scoresBySource, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}, sources []string) (map[string][]qdrantPoint, map[string][]float64, error) {
|
||||
pointsBySource := make(map[string][]qdrantPoint, len(sources))
|
||||
scoresBySource := make(map[string][]float64, len(sources))
|
||||
if len(sources) == 0 {
|
||||
return pointsBySource, scoresBySource, nil
|
||||
}
|
||||
for _, source := range sources {
|
||||
merged := cloneFilters(filters)
|
||||
if source != "" {
|
||||
merged["source"] = source
|
||||
}
|
||||
points, scores, err := s.SearchSparse(ctx, indices, values, limit, merged)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pointsBySource[source] = points
|
||||
scoresBySource[source] = scores
|
||||
}
|
||||
return pointsBySource, scoresBySource, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) {
|
||||
result, err := s.client.Get(ctx, &qdrant.GetPoints{
|
||||
CollectionName: s.collection,
|
||||
@@ -257,6 +352,31 @@ func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Scroll pages through points matching filters, returning one batch and
// the offset to resume from on the next call (nil when exhausted). A
// non-positive limit defaults to 100. Returned points carry IDs and
// payloads only.
func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]interface{}, offset *qdrant.PointId) ([]qdrantPoint, *qdrant.PointId, error) {
	if limit <= 0 {
		limit = 100
	}
	filter := buildQdrantFilter(filters)
	points, nextOffset, err := s.client.ScrollAndOffset(ctx, &qdrant.ScrollPoints{
		CollectionName: s.collection,
		Limit:          qdrant.PtrOf(uint32(limit)),
		Filter:         filter,
		Offset:         offset,
		WithPayload:    qdrant.NewWithPayload(true),
	})
	if err != nil {
		return nil, nil, err
	}
	result := make([]qdrantPoint, 0, len(points))
	for _, point := range points {
		result = append(result, qdrantPoint{
			ID:      pointIDToString(point.GetId()),
			Payload: valueMapToInterface(point.GetPayload()),
		})
	}
	return result, nextOffset, nil
}
|
||||
|
||||
func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]interface{}) error {
|
||||
filter := buildQdrantFilter(filters)
|
||||
if filter == nil {
|
||||
@@ -278,6 +398,7 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i
|
||||
if exists {
|
||||
return s.refreshCollectionSchema(ctx, vectors)
|
||||
}
|
||||
var vectorsConfig *qdrant.VectorsConfig
|
||||
if len(vectors) > 0 {
|
||||
params := make(map[string]*qdrant.VectorParams, len(vectors))
|
||||
for name, dim := range vectors {
|
||||
@@ -286,17 +407,24 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i
|
||||
Distance: qdrant.Distance_Cosine,
|
||||
}
|
||||
}
|
||||
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
|
||||
CollectionName: s.collection,
|
||||
VectorsConfig: qdrant.NewVectorsConfigMap(params),
|
||||
vectorsConfig = qdrant.NewVectorsConfigMap(params)
|
||||
} else if s.dimension > 0 {
|
||||
vectorsConfig = qdrant.NewVectorsConfig(&qdrant.VectorParams{
|
||||
Size: uint64(s.dimension),
|
||||
Distance: qdrant.Distance_Cosine,
|
||||
})
|
||||
}
|
||||
var sparseConfig *qdrant.SparseVectorConfig
|
||||
if s.sparseVectorName != "" {
|
||||
sparseConfig = qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{
|
||||
s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
})
|
||||
}
|
||||
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
|
||||
CollectionName: s.collection,
|
||||
VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
|
||||
Size: uint64(s.dimension),
|
||||
Distance: qdrant.Distance_Cosine,
|
||||
}),
|
||||
CollectionName: s.collection,
|
||||
VectorsConfig: vectorsConfig,
|
||||
SparseVectorsConfig: sparseConfig,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -306,11 +434,12 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
return err
|
||||
}
|
||||
config := info.GetConfig()
|
||||
if config == nil || config.GetParams() == nil || config.GetParams().GetVectorsConfig() == nil {
|
||||
if config == nil || config.GetParams() == nil {
|
||||
return nil
|
||||
}
|
||||
vectorsConfig := config.GetParams().GetVectorsConfig()
|
||||
if vectorsConfig.GetParamsMap() != nil {
|
||||
params := config.GetParams()
|
||||
vectorsConfig := params.GetVectorsConfig()
|
||||
if vectorsConfig != nil && vectorsConfig.GetParamsMap() != nil {
|
||||
s.usesNamedVectors = true
|
||||
s.vectorNames = map[string]int{}
|
||||
for name, vec := range vectorsConfig.GetParamsMap().GetMap() {
|
||||
@@ -319,7 +448,7 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
}
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return nil
|
||||
goto sparseCheck
|
||||
}
|
||||
for name, dim := range vectors {
|
||||
if existing, ok := s.vectorNames[name]; ok && existing == dim {
|
||||
@@ -327,13 +456,58 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
}
|
||||
return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim)
|
||||
}
|
||||
}
|
||||
if vectorsConfig == nil || vectorsConfig.GetParamsMap() == nil {
|
||||
s.usesNamedVectors = false
|
||||
s.vectorNames = nil
|
||||
}
|
||||
|
||||
sparseCheck:
|
||||
sparseConfig := params.GetSparseVectorsConfig()
|
||||
if s.sparseVectorName != "" {
|
||||
needsUpdate := false
|
||||
if sparseConfig == nil || len(sparseConfig.GetMap()) == 0 {
|
||||
needsUpdate = true
|
||||
} else {
|
||||
if _, ok := sparseConfig.GetMap()[s.sparseVectorName]; !ok {
|
||||
needsUpdate = true
|
||||
}
|
||||
if _, ok := sparseConfig.GetMap()[sparseVocabVectorName]; !ok {
|
||||
needsUpdate = true
|
||||
}
|
||||
}
|
||||
if needsUpdate {
|
||||
if err := s.ensureSparseVectors(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.usesSparseVectors = true
|
||||
return nil
|
||||
}
|
||||
s.usesNamedVectors = false
|
||||
s.vectorNames = nil
|
||||
if sparseConfig != nil && len(sparseConfig.GetMap()) > 0 {
|
||||
s.usesSparseVectors = true
|
||||
for name := range sparseConfig.GetMap() {
|
||||
s.sparseVectorName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) ensureSparseVectors(ctx context.Context) error {
|
||||
if s.sparseVectorName == "" {
|
||||
return nil
|
||||
}
|
||||
err := s.client.UpdateCollection(ctx, &qdrant.UpdateCollection{
|
||||
CollectionName: s.collection,
|
||||
SparseVectorsConfig: qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{
|
||||
s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
}),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func parseQdrantEndpoint(endpoint string) (string, int, bool, error) {
|
||||
if endpoint == "" {
|
||||
return "127.0.0.1", 6334, false, nil
|
||||
|
||||
+322
-78
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
|
||||
"github.com/memohai/memoh/internal/embeddings"
|
||||
)
|
||||
@@ -21,17 +22,19 @@ type Service struct {
|
||||
embedder embeddings.Embedder
|
||||
store *QdrantStore
|
||||
resolver *embeddings.Resolver
|
||||
bm25 *BM25Indexer
|
||||
logger *slog.Logger
|
||||
defaultTextModelID string
|
||||
defaultMultimodalModelID string
|
||||
}
|
||||
|
||||
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, bm25 *BM25Indexer, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
return &Service{
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
store: store,
|
||||
resolver: resolver,
|
||||
bm25: bm25,
|
||||
logger: log.With(slog.String("service", "memory")),
|
||||
defaultTextModelID: defaultTextModelID,
|
||||
defaultMultimodalModelID: defaultMultimodalModelID,
|
||||
@@ -42,15 +45,16 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
if req.Message == "" && len(req.Messages) == 0 {
|
||||
return SearchResponse{}, fmt.Errorf("message or messages is required")
|
||||
}
|
||||
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
|
||||
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
if req.UserID == "" {
|
||||
return SearchResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
|
||||
messages := normalizeMessages(req)
|
||||
filters := buildFilters(req)
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
if req.Infer != nil && !*req.Infer {
|
||||
return s.addRawMessages(ctx, messages, filters, req.Metadata)
|
||||
return s.addRawMessages(ctx, messages, filters, req.Metadata, embeddingEnabled)
|
||||
}
|
||||
|
||||
extractResp, err := s.llm.Extract(ctx, ExtractRequest{
|
||||
@@ -95,7 +99,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
for _, action := range actions {
|
||||
switch strings.ToUpper(action.Event) {
|
||||
case "ADD":
|
||||
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata)
|
||||
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -104,7 +108,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
})
|
||||
results = append(results, item)
|
||||
case "UPDATE":
|
||||
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata)
|
||||
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -134,19 +138,19 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
return SearchResponse{}, fmt.Errorf("query is required")
|
||||
}
|
||||
if s.store == nil {
|
||||
return SearchResponse{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
filters := buildSearchFilters(req)
|
||||
modality := ""
|
||||
if raw, ok := filters["modality"].(string); ok {
|
||||
modality = strings.ToLower(strings.TrimSpace(raw))
|
||||
}
|
||||
|
||||
var (
|
||||
vector []float32
|
||||
store *QdrantStore
|
||||
vectorName string
|
||||
err error
|
||||
)
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
if modality == embeddings.TypeMultimodal {
|
||||
if !embeddingEnabled {
|
||||
return SearchResponse{}, fmt.Errorf("embedding is disabled")
|
||||
}
|
||||
if s.resolver == nil {
|
||||
return SearchResponse{}, fmt.Errorf("embeddings resolver not configured")
|
||||
}
|
||||
@@ -159,24 +163,79 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
vector = result.Embedding
|
||||
store = s.store
|
||||
vectorName = s.vectorNameForMultimodal()
|
||||
} else {
|
||||
vector, err = s.embedder.Embed(ctx, req.Query)
|
||||
vectorName := s.vectorNameForMultimodal()
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.Search(ctx, result.Embedding, req.Limit, filters, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
if idx < len(scores) {
|
||||
item.Score = scores[idx]
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, result.Embedding, req.Limit, filters, req.Sources, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
store = s.store
|
||||
vectorName = s.vectorNameForText()
|
||||
results := fuseByRankFusion(pointsBySource, scoresBySource)
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := store.Search(ctx, vector, req.Limit, filters, vectorName)
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return SearchResponse{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
vectorName := s.vectorNameForText()
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.Search(ctx, vector, req.Limit, filters, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
if idx < len(scores) {
|
||||
item.Score = scores[idx]
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := fuseByRankFusion(pointsBySource, scoresBySource)
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
if s.bm25 == nil {
|
||||
return SearchResponse{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
termFreq, _, err := s.bm25.TermFrequencies(lang, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.SearchSparse(ctx, indices, values, req.Limit, filters)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
@@ -187,8 +246,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
pointsBySource, scoresBySource, err := store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
|
||||
pointsBySource, scoresBySource, err := s.store.SearchSparseBySources(ctx, indices, values, req.Limit, filters, req.Sources)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -200,8 +258,8 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
|
||||
if s.resolver == nil {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("embeddings resolver not configured")
|
||||
}
|
||||
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
if req.UserID == "" {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
req.Type = strings.TrimSpace(req.Type)
|
||||
req.Provider = strings.TrimSpace(req.Provider)
|
||||
@@ -264,6 +322,12 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if strings.TrimSpace(req.Memory) == "" {
|
||||
return MemoryItem{}, fmt.Errorf("memory is required")
|
||||
}
|
||||
if s.store == nil {
|
||||
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
|
||||
existing, err := s.store.Get(ctx, req.MemoryID)
|
||||
if err != nil {
|
||||
@@ -272,22 +336,58 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if existing == nil {
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
|
||||
payload := existing.Payload
|
||||
payload["data"] = req.Memory
|
||||
payload["hash"] = hashMemory(req.Memory)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
|
||||
vector, err := s.embedder.Embed(ctx, req.Memory)
|
||||
newLang, err := s.detectLanguage(ctx, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: req.MemoryID,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
|
||||
|
||||
payload["data"] = req.Memory
|
||||
payload["hash"] = hashMemory(req.Memory)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
point := qdrantPoint{
|
||||
ID: req.MemoryID,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(req.MemoryID, payload), nil
|
||||
@@ -312,14 +412,11 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
return SearchResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
|
||||
points, err := s.store.List(ctx, req.Limit, filters)
|
||||
@@ -348,14 +445,11 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return DeleteResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
return DeleteResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
if err := s.store.DeleteAll(ctx, filters); err != nil {
|
||||
return DeleteResponse{}, err
|
||||
@@ -363,10 +457,46 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
|
||||
return DeleteResponse{Message: "Memories deleted successfully!"}, nil
|
||||
}
|
||||
|
||||
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}) (SearchResponse, error) {
|
||||
func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error {
|
||||
if s.bm25 == nil || s.store == nil {
|
||||
return nil
|
||||
}
|
||||
var offset *qdrant.PointId
|
||||
for {
|
||||
points, next, err := s.store.Scroll(ctx, batchSize, nil, offset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(points) == 0 {
|
||||
break
|
||||
}
|
||||
for _, point := range points {
|
||||
text := fmt.Sprint(point.Payload["data"])
|
||||
if strings.TrimSpace(text) == "" {
|
||||
continue
|
||||
}
|
||||
lang := fmt.Sprint(point.Payload["lang"])
|
||||
if lang == "" {
|
||||
lang = fallbackLanguageCode(text)
|
||||
}
|
||||
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s.bm25.AddDocument(lang, termFreq, docLen)
|
||||
}
|
||||
if next == nil {
|
||||
break
|
||||
}
|
||||
offset = next
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (SearchResponse, error) {
|
||||
results := make([]MemoryItem, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
item, err := s.applyAdd(ctx, message.Content, filters, metadata)
|
||||
item, err := s.applyAdd(ctx, message.Content, filters, metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -381,11 +511,19 @@ func (s *Service) addRawMessages(ctx context.Context, messages []Message, filter
|
||||
func (s *Service) collectCandidates(ctx context.Context, facts []string, filters map[string]interface{}) ([]CandidateMemory, error) {
|
||||
unique := map[string]CandidateMemory{}
|
||||
for _, fact := range facts {
|
||||
vector, err := s.embedder.Embed(ctx, fact)
|
||||
if s.bm25 == nil {
|
||||
return nil, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, fact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
points, _, err := s.store.Search(ctx, vector, 5, filters, s.vectorNameForText())
|
||||
termFreq, _, err := s.bm25.TermFrequencies(lang, fact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
|
||||
points, _, err := s.store.SearchSparse(ctx, indices, values, 5, filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -406,25 +544,50 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
|
||||
if s.store == nil {
|
||||
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(lang, termFreq, docLen)
|
||||
id := uuid.NewString()
|
||||
payload := buildPayload(text, filters, metadata, "")
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
payload["lang"] = lang
|
||||
point := qdrantPoint{
|
||||
ID: id,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(id, payload), nil
|
||||
}
|
||||
|
||||
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
|
||||
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return MemoryItem{}, fmt.Errorf("update action missing id")
|
||||
}
|
||||
@@ -437,25 +600,55 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
|
||||
}
|
||||
|
||||
payload := existing.Payload
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
newLang, err := s.detectLanguage(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
|
||||
payload["data"] = text
|
||||
payload["hash"] = hashMemory(text)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
if metadata != nil {
|
||||
payload["metadata"] = mergeMetadata(payload["metadata"], metadata)
|
||||
}
|
||||
if filters != nil {
|
||||
applyFiltersToPayload(payload, filters)
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
point := qdrantPoint{
|
||||
ID: id,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(id, payload), nil
|
||||
@@ -473,6 +666,19 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
item := payloadToMemoryItem(id, existing.Payload)
|
||||
if s.bm25 != nil {
|
||||
oldText := fmt.Sprint(existing.Payload["data"])
|
||||
oldLang := fmt.Sprint(existing.Payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.store.Delete(ctx, id); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
@@ -486,6 +692,56 @@ func normalizeMessages(req AddRequest) []Message {
|
||||
return []Message{{Role: "user", Content: req.Message}}
|
||||
}
|
||||
|
||||
func (s *Service) detectLanguage(ctx context.Context, text string) (string, error) {
|
||||
if s.llm == nil {
|
||||
return "", fmt.Errorf("language detector not configured")
|
||||
}
|
||||
lang, err := s.llm.DetectLanguage(ctx, text)
|
||||
if err == nil && lang != "" {
|
||||
return lang, nil
|
||||
}
|
||||
fallback := fallbackLanguageCode(text)
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("language detection failed; using fallback", slog.Any("error", err), slog.String("fallback", fallback))
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func fallbackLanguageCode(text string) string {
|
||||
for _, r := range text {
|
||||
if isCJKRune(r) {
|
||||
return "cjk"
|
||||
}
|
||||
}
|
||||
return "en"
|
||||
}
|
||||
|
||||
func isCJKRune(r rune) bool {
|
||||
switch {
|
||||
case r >= 0x4E00 && r <= 0x9FFF: // CJK Unified Ideographs
|
||||
return true
|
||||
case r >= 0x3400 && r <= 0x4DBF: // CJK Unified Ideographs Extension A
|
||||
return true
|
||||
case r >= 0x20000 && r <= 0x2A6DF: // CJK Unified Ideographs Extension B
|
||||
return true
|
||||
case r >= 0x2A700 && r <= 0x2B73F: // CJK Unified Ideographs Extension C
|
||||
return true
|
||||
case r >= 0x2B740 && r <= 0x2B81F: // CJK Unified Ideographs Extension D
|
||||
return true
|
||||
case r >= 0x2B820 && r <= 0x2CEAF: // CJK Unified Ideographs Extension E
|
||||
return true
|
||||
case r >= 0x2CEB0 && r <= 0x2EBEF: // CJK Unified Ideographs Extension F
|
||||
return true
|
||||
case r >= 0x3000 && r <= 0x303F: // CJK Symbols and Punctuation
|
||||
return true
|
||||
case r >= 0x3040 && r <= 0x30FF: // Hiragana/Katakana
|
||||
return true
|
||||
case r >= 0xAC00 && r <= 0xD7AF: // Hangul Syllables
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildFilters(req AddRequest) map[string]interface{} {
|
||||
filters := map[string]interface{}{}
|
||||
for key, value := range req.Filters {
|
||||
@@ -494,9 +750,6 @@ func buildFilters(req AddRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -511,9 +764,6 @@ func buildSearchFilters(req SearchRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -528,9 +778,6 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -621,9 +868,6 @@ func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem {
|
||||
if v, ok := payload["userId"].(string); ok {
|
||||
item.UserID = v
|
||||
}
|
||||
if v, ok := payload["agentId"].(string); ok {
|
||||
item.AgentID = v
|
||||
}
|
||||
if v, ok := payload["runId"].(string); ok {
|
||||
item.RunID = v
|
||||
}
|
||||
|
||||
+24
-26
@@ -6,6 +6,7 @@ import "context"
|
||||
type LLM interface {
|
||||
Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
|
||||
Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
|
||||
DetectLanguage(ctx context.Context, text string) (string, error)
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
@@ -14,42 +15,41 @@ type Message struct {
|
||||
}
|
||||
|
||||
type AddRequest struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Infer *bool `json:"infer,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Infer *bool `json:"infer,omitempty"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type SearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Sources []string `json:"sources,omitempty"`
|
||||
Query string `json:"query"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Sources []string `json:"sources,omitempty"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
MemoryID string `json:"memory_id"`
|
||||
Memory string `json:"memory"`
|
||||
MemoryID string `json:"memory_id"`
|
||||
Memory string `json:"memory"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type GetAllRequest struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
type DeleteAllRequest struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
}
|
||||
|
||||
type EmbedInput struct {
|
||||
@@ -65,7 +65,6 @@ type EmbedUpsertRequest struct {
|
||||
Input EmbedInput `json:"input"`
|
||||
Source string `json:"source,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
@@ -87,7 +86,6 @@ type MemoryItem struct {
|
||||
Score float64 `json:"score,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
AgentID string `json:"agentId,omitempty"`
|
||||
RunID string `json:"runId,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user