Files
Memoh/internal/memory/adapters/builtin/sparse_runtime_test.go
T
2026-03-24 06:18:16 +08:00

414 lines
12 KiB
Go

package builtin
import (
"context"
"log/slog"
"sort"
"strings"
"testing"
adapters "github.com/memohai/memoh/internal/memory/adapters"
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
"github.com/memohai/memoh/internal/memory/sparse"
storefs "github.com/memohai/memoh/internal/memory/storefs"
)
// fakeSparseStore is an in-memory stand-in for the memory source store used by
// the tests in this file (it is assigned to sparseRuntime.store). It keeps
// everything in a single map and never returns an error.
type fakeSparseStore struct {
// items holds the stored memory items keyed by their ID.
items map[string]storefs.MemoryItem
}
// newFakeSparseStore returns a fakeSparseStore seeded with the given items,
// each indexed by its ID. When two items share an ID the later one wins.
func newFakeSparseStore(items ...storefs.MemoryItem) *fakeSparseStore {
	seeded := make(map[string]storefs.MemoryItem, len(items))
	for idx := range items {
		seeded[items[idx].ID] = items[idx]
	}
	return &fakeSparseStore{items: seeded}
}
// PersistMemories records every item under its ID, replacing any entry that
// already uses that ID. All other arguments are ignored and the returned error
// is always nil.
func (s *fakeSparseStore) PersistMemories(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
	for idx := range items {
		entry := items[idx]
		s.items[entry.ID] = entry
	}
	return nil
}
// ReadAllMemoryFiles returns a copy of every stored item, ordered by ascending
// ID so callers see a deterministic order despite map iteration being random.
// The context and string arguments are ignored; the error is always nil.
func (s *fakeSparseStore) ReadAllMemoryFiles(_ context.Context, _ string) ([]storefs.MemoryItem, error) {
	snapshot := make([]storefs.MemoryItem, 0, len(s.items))
	for _, stored := range s.items {
		snapshot = append(snapshot, stored)
	}
	byID := func(a, b int) bool {
		return snapshot[a].ID < snapshot[b].ID
	}
	sort.Slice(snapshot, byID)
	return snapshot, nil
}
// RemoveMemories deletes the entries whose keys match the given ids after
// surrounding whitespace is trimmed from each id. Unknown ids are silently
// skipped. The context and string arguments are ignored; the error is always nil.
func (s *fakeSparseStore) RemoveMemories(_ context.Context, _ string, ids []string) error {
	for idx := range ids {
		key := strings.TrimSpace(ids[idx])
		delete(s.items, key)
	}
	return nil
}
// RemoveAllMemories drops every stored item by swapping in a fresh empty map.
// The context and string arguments are ignored, so the whole store is cleared
// regardless of what is passed. The error is always nil.
func (s *fakeSparseStore) RemoveAllMemories(_ context.Context, _ string) error {
	s.items = make(map[string]storefs.MemoryItem)
	return nil
}
// RebuildFiles replaces the entire store contents with the given items, keyed
// by ID. Anything previously stored is discarded. All other arguments are
// ignored and the returned error is always nil.
func (s *fakeSparseStore) RebuildFiles(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
	rebuilt := make(map[string]storefs.MemoryItem, len(items))
	for idx := range items {
		rebuilt[items[idx].ID] = items[idx]
	}
	s.items = rebuilt
	return nil
}
// SyncOverview is a no-op in the fake; it ignores its arguments and always
// returns nil.
func (*fakeSparseStore) SyncOverview(_ context.Context, _ string) error {
	return nil
}
// CountMemoryFiles reports 0 when the store is empty and 1 otherwise; it does
// not return the number of stored items.
// NOTE(review): presumably the fake models all items as living in one file —
// confirm this is intended rather than len(s.items).
func (s *fakeSparseStore) CountMemoryFiles(_ context.Context, _ string) (int, error) {
	if len(s.items) > 0 {
		return 1, nil
	}
	return 0, nil
}
// fakeSparseEncoder is a deterministic stand-in for the sparse encoder. It
// returns fixed vectors and remembers the most recent query text so that
// fakeSparseIndex.Search can match on it.
type fakeSparseEncoder struct {
// lastQuery is the text most recently passed to EncodeQuery.
lastQuery string
}
// EncodeDocument returns the same fixed sparse vector (indices 1,2,3 with
// values 1,3,2) for any input text. The error is always nil.
func (*fakeSparseEncoder) EncodeDocument(_ context.Context, _ string) (*sparse.SparseVector, error) {
	vec := sparse.SparseVector{
		Indices: []uint32{1, 2, 3},
		Values:  []float32{1, 3, 2},
	}
	return &vec, nil
}
// EncodeDocuments returns one fixed sparse vector (indices 1,2,3 with values
// 1,3,2) per input text, in input order. The text contents are not inspected;
// only the number of texts matters. The error is always nil.
func (*fakeSparseEncoder) EncodeDocuments(_ context.Context, texts []string) ([]sparse.SparseVector, error) {
	out := make([]sparse.SparseVector, 0, len(texts))
	// Range without an iteration variable: the texts themselves are unused.
	for range texts {
		out = append(out, sparse.SparseVector{
			Indices: []uint32{1, 2, 3},
			Values:  []float32{1, 3, 2},
		})
	}
	return out, nil
}
// EncodeQuery remembers text in lastQuery and returns a fixed single-entry
// sparse vector (index 9, value 1). The error is always nil.
func (e *fakeSparseEncoder) EncodeQuery(_ context.Context, text string) (*sparse.SparseVector, error) {
	e.lastQuery = text
	vec := sparse.SparseVector{
		Indices: []uint32{9},
		Values:  []float32{1},
	}
	return &vec, nil
}
// Health always reports the fake encoder as healthy by returning nil.
func (*fakeSparseEncoder) Health(_ context.Context) error {
	return nil
}
// fakeSparseIndex is an in-memory stand-in for the qdrant-backed sparse index
// (it is assigned to sparseRuntime.qdrant in the tests below). Points live in
// a map and searches are answered by substring matching rather than by vector
// similarity.
type fakeSparseIndex struct {
// encoder supplies lastQuery, which Search uses as the text to match.
encoder *fakeSparseEncoder
// collection is the name returned by CollectionName.
collection string
// exists is reported by CollectionExists; EnsureCollection and Upsert set it.
exists bool
// points holds the stored points keyed by point ID.
points map[string]qdrantclient.SearchResult
}
// newFakeSparseIndex returns an empty fakeSparseIndex bound to encoder, using
// the collection name "memory_sparse_test". The collection starts out as not
// existing.
func newFakeSparseIndex(encoder *fakeSparseEncoder) *fakeSparseIndex {
	idx := new(fakeSparseIndex)
	idx.encoder = encoder
	idx.collection = "memory_sparse_test"
	idx.points = make(map[string]qdrantclient.SearchResult)
	return idx
}
// CollectionName returns the collection name the fake was created with.
func (i *fakeSparseIndex) CollectionName() string {
	return i.collection
}
// CollectionExists reports the current value of the exists flag. The error is
// always nil.
func (i *fakeSparseIndex) CollectionExists(_ context.Context) (bool, error) {
	return i.exists, nil
}
// EnsureCollection marks the collection as existing. It always returns nil.
func (i *fakeSparseIndex) EnsureCollection(_ context.Context) error {
	i.exists = true
	return nil
}
// Upsert stores (or overwrites) the point for id with the given payload and a
// score of 1, and marks the collection as existing. The sparse vector argument
// is ignored. The error is always nil.
func (i *fakeSparseIndex) Upsert(_ context.Context, id string, _ qdrantclient.SparseVector, payload map[string]string) error {
	i.exists = true
	point := qdrantclient.SearchResult{ID: id, Score: 1, Payload: payload}
	i.points[id] = point
	return nil
}
// Search returns the points belonging to botID whose "memory" payload contains
// the encoder's last query text, compared case-insensitively. An empty query
// matches every point of the bot. The sparse vector argument is ignored; the
// text is taken from encoder.lastQuery instead.
//
// Bot IDs are compared after trimming whitespace. Every hit gets a score of 1,
// hits are ordered by ascending ID, and a positive limit caps the result
// count. The error is always nil.
func (i *fakeSparseIndex) Search(_ context.Context, _ qdrantclient.SparseVector, botID string, limit int) ([]qdrantclient.SearchResult, error) {
	needle := strings.ToLower(strings.TrimSpace(i.encoder.lastQuery))
	wantBot := strings.TrimSpace(botID)

	matches := make([]qdrantclient.SearchResult, 0, len(i.points))
	for _, candidate := range i.points {
		if strings.TrimSpace(candidate.Payload["bot_id"]) != wantBot {
			continue
		}
		if needle != "" {
			haystack := strings.ToLower(candidate.Payload["memory"])
			if !strings.Contains(haystack, needle) {
				continue
			}
		}
		// candidate is a copy of the map value, so the stored point is untouched.
		candidate.Score = 1
		matches = append(matches, candidate)
	}

	sort.Slice(matches, func(a, b int) bool {
		return matches[a].ID < matches[b].ID
	})
	if limit > 0 && limit < len(matches) {
		matches = matches[:limit]
	}
	return matches, nil
}
// Scroll returns the points whose "bot_id" payload matches botID (both sides
// whitespace-trimmed), ordered by ascending ID. A positive limit caps the
// number of results. The error is always nil.
func (i *fakeSparseIndex) Scroll(_ context.Context, botID string, limit int) ([]qdrantclient.SearchResult, error) {
	wantBot := strings.TrimSpace(botID)

	owned := make([]qdrantclient.SearchResult, 0, len(i.points))
	for _, candidate := range i.points {
		if strings.TrimSpace(candidate.Payload["bot_id"]) == wantBot {
			owned = append(owned, candidate)
		}
	}

	sort.Slice(owned, func(a, b int) bool {
		return owned[a].ID < owned[b].ID
	})
	if limit > 0 && limit < len(owned) {
		owned = owned[:limit]
	}
	return owned, nil
}
// Count returns how many stored points have a "bot_id" payload equal to botID,
// comparing both sides after trimming whitespace. The error is always nil.
func (i *fakeSparseIndex) Count(_ context.Context, botID string) (int, error) {
	wantBot := strings.TrimSpace(botID)
	total := 0
	for _, candidate := range i.points {
		if strings.TrimSpace(candidate.Payload["bot_id"]) != wantBot {
			continue
		}
		total++
	}
	return total, nil
}
// DeleteByIDs removes the points whose keys match the given ids after
// surrounding whitespace is trimmed from each id. Unknown ids are silently
// skipped. The error is always nil.
func (i *fakeSparseIndex) DeleteByIDs(_ context.Context, ids []string) error {
	for idx := range ids {
		key := strings.TrimSpace(ids[idx])
		delete(i.points, key)
	}
	return nil
}
// DeleteByBotID removes every point whose "bot_id" payload matches botID (both
// sides whitespace-trimmed). The error is always nil.
func (i *fakeSparseIndex) DeleteByBotID(_ context.Context, botID string) error {
	wantBot := strings.TrimSpace(botID)
	// Deleting from a map while ranging over it is well-defined in Go.
	for key, candidate := range i.points {
		if strings.TrimSpace(candidate.Payload["bot_id"]) != wantBot {
			continue
		}
		delete(i.points, key)
	}
	return nil
}
// TestSparseRuntimeAddWritesSourceAndSupportsRecall checks that Add writes the
// memory to the source store, upserts an index point whose payload carries the
// source entry ID, and that a later Search for the same text returns that
// memory. Neither the add result nor the search result is expected to carry
// explain stats (TopKBuckets / CDFCurve).
func TestSparseRuntimeAddWritesSourceAndSupportsRecall(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	enc := &fakeSparseEncoder{}
	idx := newFakeSparseIndex(enc)
	src := newFakeSparseStore()
	rt := &sparseRuntime{
		qdrant:  idx,
		encoder: enc,
		store:   src,
	}

	// Add a single memory and inspect what ended up in the store and index.
	addReq := adapters.AddRequest{
		BotID:   "bot-1",
		Message: "Ran likes oolong tea",
		Filters: map[string]any{"scopeId": "bot-1"},
	}
	added, err := rt.Add(ctx, addReq)
	if err != nil {
		t.Fatalf("Add() error = %v", err)
	}
	if n := len(added.Results); n != 1 {
		t.Fatalf("expected 1 add result, got %d", n)
	}

	item := added.Results[0]
	if item.ID == "" {
		t.Fatal("expected source memory id to be populated")
	}
	if _, stored := src.items[item.ID]; !stored {
		t.Fatalf("expected memory %q to be written to markdown source", item.ID)
	}

	point, indexed := idx.points[runtimePointID("bot-1", item.ID)]
	if !indexed {
		t.Fatalf("expected qdrant point for source memory %q", item.ID)
	}
	if got := point.Payload["source_entry_id"]; got != item.ID {
		t.Fatalf("expected source_entry_id payload %q, got %q", item.ID, got)
	}
	if len(item.TopKBuckets) > 0 || len(item.CDFCurve) > 0 {
		t.Fatalf("expected add response to skip explain stats, got %#v", item)
	}

	// The memory must be recallable through Search.
	searched, err := rt.Search(ctx, adapters.SearchRequest{
		BotID: "bot-1",
		Query: "oolong tea",
	})
	if err != nil {
		t.Fatalf("Search() error = %v", err)
	}
	if n := len(searched.Results); n != 1 {
		t.Fatalf("expected 1 search result, got %d", n)
	}

	hit := searched.Results[0]
	if hit.ID != item.ID {
		t.Fatalf("expected search result id %q, got %q", item.ID, hit.ID)
	}
	if len(hit.TopKBuckets) > 0 || len(hit.CDFCurve) > 0 {
		t.Fatalf("expected search result to skip explain stats, got %#v", hit)
	}
}
// TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints seeds the source
// store with two memories and the index with one outdated point plus one
// point that has no source entry. It then checks the counts reported by
// Rebuild and that the orphaned point has been removed from the index.
func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
	t.Parallel()

	enc := &fakeSparseEncoder{}
	idx := newFakeSparseIndex(enc)
	src := newFakeSparseStore(
		storefs.MemoryItem{
			ID:        "bot-1:mem_1",
			Memory:    "Ran likes tea",
			Hash:      runtimeHash("Ran likes tea"),
			CreatedAt: "2026-03-13T09:00:00Z",
			UpdatedAt: "2026-03-13T09:00:00Z",
		},
		storefs.MemoryItem{
			ID:        "bot-1:mem_2",
			Memory:    "Ran works in Berlin",
			Hash:      runtimeHash("Ran works in Berlin"),
			CreatedAt: "2026-03-13T10:00:00Z",
			UpdatedAt: "2026-03-13T10:00:00Z",
		},
	)
	rt := &sparseRuntime{
		qdrant:  idx,
		encoder: enc,
		store:   src,
	}

	// Point for mem_1 exists in the index but carries an outdated hash.
	outdatedID := runtimePointID("bot-1", "bot-1:mem_1")
	idx.points[outdatedID] = qdrantclient.SearchResult{
		ID: outdatedID,
		Payload: map[string]string{
			"bot_id":          "bot-1",
			"memory":          "Ran likes tea",
			"source_entry_id": "bot-1:mem_1",
			"hash":            "outdated",
			"created_at":      "2026-03-13T09:00:00Z",
			"updated_at":      "2026-03-13T09:00:00Z",
		},
	}

	// Point with no matching entry in the source store.
	staleID := runtimePointID("bot-1", "bot-1:stale")
	idx.points[staleID] = qdrantclient.SearchResult{
		ID: staleID,
		Payload: map[string]string{
			"bot_id":          "bot-1",
			"memory":          "stale memory",
			"source_entry_id": "bot-1:stale",
		},
	}

	result, err := rt.Rebuild(context.Background(), "bot-1")
	if err != nil {
		t.Fatalf("Rebuild() error = %v", err)
	}
	if got := result.FsCount; got != 2 {
		t.Fatalf("expected fs_count=2, got %d", got)
	}
	if got := result.StorageCount; got != 2 {
		t.Fatalf("expected storage_count=2, got %d", got)
	}
	if got := result.MissingCount; got != 1 {
		t.Fatalf("expected missing_count=1, got %d", got)
	}
	if got := result.RestoredCount; got != 2 {
		t.Fatalf("expected restored_count=2, got %d", got)
	}
	if _, present := idx.points[staleID]; present {
		t.Fatal("expected stale qdrant point to be removed")
	}
}
// TestSparseRuntimeGetAllIncludesExplainStats checks that GetAll returns the
// single seeded memory together with explain stats: a non-empty TopKBuckets
// whose first bucket has index 2, and a non-empty CDFCurve whose final
// cumulative value is 1.
func TestSparseRuntimeGetAllIncludesExplainStats(t *testing.T) {
	t.Parallel()

	enc := &fakeSparseEncoder{}
	idx := newFakeSparseIndex(enc)
	src := newFakeSparseStore(storefs.MemoryItem{
		ID:        "bot-1:mem_1",
		Memory:    "Ran likes tea",
		Hash:      runtimeHash("Ran likes tea"),
		CreatedAt: "2026-03-13T09:00:00Z",
		UpdatedAt: "2026-03-13T09:00:00Z",
	})
	rt := &sparseRuntime{
		qdrant:  idx,
		encoder: enc,
		store:   src,
	}

	listed, err := rt.GetAll(context.Background(), adapters.GetAllRequest{BotID: "bot-1"})
	if err != nil {
		t.Fatalf("GetAll() error = %v", err)
	}
	if n := len(listed.Results); n != 1 {
		t.Fatalf("expected 1 result, got %d", n)
	}

	first := listed.Results[0]
	if len(first.TopKBuckets) == 0 || len(first.CDFCurve) == 0 {
		t.Fatalf("expected get all result to include explain stats, got %#v", first)
	}
	// The fake encoder's document vector gives index 2 the largest value (3).
	if got := first.TopKBuckets[0].Index; got != 2 {
		t.Fatalf("expected top bucket index 2, got %d", got)
	}
	// NOTE(review): exact equality on the final cumulative value — assumes the
	// runtime yields exactly 1 for the last CDF point; confirm.
	lastPoint := first.CDFCurve[len(first.CDFCurve)-1]
	if got := lastPoint.Cumulative; got != 1 {
		t.Fatalf("expected final CDF cumulative to be 1, got %v", got)
	}
}
// TestBuiltinProviderMultiTurnRecallUsesSparseSourceRuntime feeds two chat
// rounds through a builtin provider backed by the sparse runtime fakes, then
// checks that OnBeforeChat recalls context mentioning "berlin" for a Berlin
// query and "oolong tea" for a tea query.
func TestBuiltinProviderMultiTurnRecallUsesSparseSourceRuntime(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	enc := &fakeSparseEncoder{}
	idx := newFakeSparseIndex(enc)
	src := newFakeSparseStore()
	rt := &sparseRuntime{
		qdrant:  idx,
		encoder: enc,
		store:   src,
	}
	provider := NewBuiltinProvider(slog.Default(), rt, nil, nil)

	// Round 1: the user states a tea preference.
	teaRound := adapters.AfterChatRequest{
		BotID: "bot-1",
		Messages: []adapters.Message{
			{Role: "user", Content: "I like oolong tea."},
			{Role: "assistant", Content: "Noted, you like oolong tea."},
		},
	}
	if err := provider.OnAfterChat(ctx, teaRound); err != nil {
		t.Fatalf("OnAfterChat round 1 error = %v", err)
	}

	// Round 2: the user states a location.
	berlinRound := adapters.AfterChatRequest{
		BotID: "bot-1",
		Messages: []adapters.Message{
			{Role: "user", Content: "I am based in Berlin."},
			{Role: "assistant", Content: "Understood, you are based in Berlin."},
		},
	}
	if err := provider.OnAfterChat(ctx, berlinRound); err != nil {
		t.Fatalf("OnAfterChat round 2 error = %v", err)
	}

	// Recall by location.
	recalled, err := provider.OnBeforeChat(ctx, adapters.BeforeChatRequest{
		BotID: "bot-1",
		Query: "berlin",
	})
	if err != nil {
		t.Fatalf("OnBeforeChat() error = %v", err)
	}
	if recalled == nil || !strings.Contains(strings.ToLower(recalled.ContextText), "berlin") {
		t.Fatalf("expected recalled context to mention berlin, got %#v", recalled)
	}

	// Recall by tea preference.
	recalled, err = provider.OnBeforeChat(ctx, adapters.BeforeChatRequest{
		BotID: "bot-1",
		Query: "oolong tea",
	})
	if err != nil {
		t.Fatalf("OnBeforeChat() tea error = %v", err)
	}
	if recalled == nil || !strings.Contains(strings.ToLower(recalled.ContextText), "oolong tea") {
		t.Fatalf("expected recalled context to mention oolong tea, got %#v", recalled)
	}
}