mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
627b673a5c
* refactor: restructure memory into multi-provider adapters, remove manifest.json dependency - Rename internal/memory/provider to internal/memory/adapters with per-provider subdirectories (builtin, mem0, openviking) - Replace manifest.json-based delete/update with scan-based index from daily files - Add mem0 and openviking provider adapters with HTTP client, chat hooks, MCP tools, and CRUD - Wire provider lifecycle into registry (auto-instantiate on create, evict on update/delete) - Split docker-compose into base stack + optional overlays (qdrant, browser, mem0, openviking) - Update admin UI to support dynamic provider config schema rendering * chore(lint): fix all golangci-lint issues for clean CI * refactor(docker): replace compose overlay files with profiles * feat(memory): add built-in memory multi modes * fix(ci): golangci lint * feat(memory): edit built-in memory sparse design
414 lines
12 KiB
Go
414 lines
12 KiB
Go
package builtin
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"sort"
|
|
"strings"
|
|
"testing"
|
|
|
|
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
|
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
|
"github.com/memohai/memoh/internal/memory/sparse"
|
|
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
|
)
|
|
|
|
type fakeSparseStore struct {
|
|
items map[string]storefs.MemoryItem
|
|
}
|
|
|
|
func newFakeSparseStore(items ...storefs.MemoryItem) *fakeSparseStore {
|
|
store := &fakeSparseStore{items: map[string]storefs.MemoryItem{}}
|
|
for _, item := range items {
|
|
store.items[item.ID] = item
|
|
}
|
|
return store
|
|
}
|
|
|
|
func (s *fakeSparseStore) PersistMemories(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
|
|
for _, item := range items {
|
|
s.items[item.ID] = item
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *fakeSparseStore) ReadAllMemoryFiles(_ context.Context, _ string) ([]storefs.MemoryItem, error) {
|
|
out := make([]storefs.MemoryItem, 0, len(s.items))
|
|
for _, item := range s.items {
|
|
out = append(out, item)
|
|
}
|
|
sort.Slice(out, func(i, j int) bool { return out[i].ID < out[j].ID })
|
|
return out, nil
|
|
}
|
|
|
|
func (s *fakeSparseStore) RemoveMemories(_ context.Context, _ string, ids []string) error {
|
|
for _, id := range ids {
|
|
delete(s.items, strings.TrimSpace(id))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *fakeSparseStore) RemoveAllMemories(_ context.Context, _ string) error {
|
|
s.items = map[string]storefs.MemoryItem{}
|
|
return nil
|
|
}
|
|
|
|
func (s *fakeSparseStore) RebuildFiles(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
|
|
s.items = map[string]storefs.MemoryItem{}
|
|
for _, item := range items {
|
|
s.items[item.ID] = item
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (*fakeSparseStore) SyncOverview(context.Context, string) error { return nil }
|
|
|
|
func (s *fakeSparseStore) CountMemoryFiles(_ context.Context, _ string) (int, error) {
|
|
if len(s.items) == 0 {
|
|
return 0, nil
|
|
}
|
|
return 1, nil
|
|
}
|
|
|
|
type fakeSparseEncoder struct {
|
|
lastQuery string
|
|
}
|
|
|
|
func (*fakeSparseEncoder) EncodeDocument(_ context.Context, _ string) (*sparse.SparseVector, error) {
|
|
return &sparse.SparseVector{Indices: []uint32{1, 2, 3}, Values: []float32{1, 3, 2}}, nil
|
|
}
|
|
|
|
func (*fakeSparseEncoder) EncodeDocuments(_ context.Context, texts []string) ([]sparse.SparseVector, error) {
|
|
out := make([]sparse.SparseVector, 0, len(texts))
|
|
for _, text := range texts {
|
|
_ = text
|
|
out = append(out, sparse.SparseVector{Indices: []uint32{1, 2, 3}, Values: []float32{1, 3, 2}})
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (e *fakeSparseEncoder) EncodeQuery(_ context.Context, text string) (*sparse.SparseVector, error) {
|
|
e.lastQuery = text
|
|
return &sparse.SparseVector{Indices: []uint32{9}, Values: []float32{1}}, nil
|
|
}
|
|
|
|
func (*fakeSparseEncoder) Health(context.Context) error { return nil }
|
|
|
|
type fakeSparseIndex struct {
|
|
encoder *fakeSparseEncoder
|
|
collection string
|
|
exists bool
|
|
points map[string]qdrantclient.SearchResult
|
|
}
|
|
|
|
func newFakeSparseIndex(encoder *fakeSparseEncoder) *fakeSparseIndex {
|
|
return &fakeSparseIndex{
|
|
encoder: encoder,
|
|
collection: "memory_sparse_test",
|
|
points: map[string]qdrantclient.SearchResult{},
|
|
}
|
|
}
|
|
|
|
func (i *fakeSparseIndex) CollectionName() string { return i.collection }
|
|
|
|
func (i *fakeSparseIndex) CollectionExists(context.Context) (bool, error) { return i.exists, nil }
|
|
|
|
func (i *fakeSparseIndex) EnsureCollection(context.Context) error {
|
|
i.exists = true
|
|
return nil
|
|
}
|
|
|
|
func (i *fakeSparseIndex) Upsert(_ context.Context, id string, _ qdrantclient.SparseVector, payload map[string]string) error {
|
|
i.exists = true
|
|
i.points[id] = qdrantclient.SearchResult{
|
|
ID: id,
|
|
Score: 1,
|
|
Payload: payload,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *fakeSparseIndex) Search(_ context.Context, _ qdrantclient.SparseVector, botID string, limit int) ([]qdrantclient.SearchResult, error) {
|
|
query := strings.ToLower(strings.TrimSpace(i.encoder.lastQuery))
|
|
results := make([]qdrantclient.SearchResult, 0, len(i.points))
|
|
for _, point := range i.points {
|
|
if strings.TrimSpace(point.Payload["bot_id"]) != strings.TrimSpace(botID) {
|
|
continue
|
|
}
|
|
text := strings.ToLower(point.Payload["memory"])
|
|
if query != "" && !strings.Contains(text, query) {
|
|
continue
|
|
}
|
|
point.Score = 1
|
|
results = append(results, point)
|
|
}
|
|
sort.Slice(results, func(a, b int) bool { return results[a].ID < results[b].ID })
|
|
if limit > 0 && len(results) > limit {
|
|
results = results[:limit]
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (i *fakeSparseIndex) Scroll(_ context.Context, botID string, limit int) ([]qdrantclient.SearchResult, error) {
|
|
results := make([]qdrantclient.SearchResult, 0, len(i.points))
|
|
for _, point := range i.points {
|
|
if strings.TrimSpace(point.Payload["bot_id"]) != strings.TrimSpace(botID) {
|
|
continue
|
|
}
|
|
results = append(results, point)
|
|
}
|
|
sort.Slice(results, func(a, b int) bool { return results[a].ID < results[b].ID })
|
|
if limit > 0 && len(results) > limit {
|
|
results = results[:limit]
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (i *fakeSparseIndex) Count(_ context.Context, botID string) (int, error) {
|
|
count := 0
|
|
for _, point := range i.points {
|
|
if strings.TrimSpace(point.Payload["bot_id"]) == strings.TrimSpace(botID) {
|
|
count++
|
|
}
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func (i *fakeSparseIndex) DeleteByIDs(_ context.Context, ids []string) error {
|
|
for _, id := range ids {
|
|
delete(i.points, strings.TrimSpace(id))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *fakeSparseIndex) DeleteByBotID(_ context.Context, botID string) error {
|
|
for id, point := range i.points {
|
|
if strings.TrimSpace(point.Payload["bot_id"]) == strings.TrimSpace(botID) {
|
|
delete(i.points, id)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSparseRuntimeAddWritesSourceAndSupportsRecall(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
encoder := &fakeSparseEncoder{}
|
|
index := newFakeSparseIndex(encoder)
|
|
store := newFakeSparseStore()
|
|
runtime := &sparseRuntime{
|
|
qdrant: index,
|
|
encoder: encoder,
|
|
store: store,
|
|
}
|
|
|
|
resp, err := runtime.Add(context.Background(), adapters.AddRequest{
|
|
BotID: "bot-1",
|
|
Message: "Ran likes oolong tea",
|
|
Filters: map[string]any{"scopeId": "bot-1"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Add() error = %v", err)
|
|
}
|
|
if len(resp.Results) != 1 {
|
|
t.Fatalf("expected 1 add result, got %d", len(resp.Results))
|
|
}
|
|
item := resp.Results[0]
|
|
if item.ID == "" {
|
|
t.Fatal("expected source memory id to be populated")
|
|
}
|
|
if _, ok := store.items[item.ID]; !ok {
|
|
t.Fatalf("expected memory %q to be written to markdown source", item.ID)
|
|
}
|
|
point, ok := index.points[sparsePointID("bot-1", item.ID)]
|
|
if !ok {
|
|
t.Fatalf("expected qdrant point for source memory %q", item.ID)
|
|
}
|
|
if point.Payload["source_entry_id"] != item.ID {
|
|
t.Fatalf("expected source_entry_id payload %q, got %q", item.ID, point.Payload["source_entry_id"])
|
|
}
|
|
if len(item.TopKBuckets) != 0 || len(item.CDFCurve) != 0 {
|
|
t.Fatalf("expected add response to skip explain stats, got %#v", item)
|
|
}
|
|
|
|
searchResp, err := runtime.Search(context.Background(), adapters.SearchRequest{
|
|
BotID: "bot-1",
|
|
Query: "oolong tea",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Search() error = %v", err)
|
|
}
|
|
if len(searchResp.Results) != 1 {
|
|
t.Fatalf("expected 1 search result, got %d", len(searchResp.Results))
|
|
}
|
|
if searchResp.Results[0].ID != item.ID {
|
|
t.Fatalf("expected search result id %q, got %q", item.ID, searchResp.Results[0].ID)
|
|
}
|
|
if len(searchResp.Results[0].TopKBuckets) != 0 || len(searchResp.Results[0].CDFCurve) != 0 {
|
|
t.Fatalf("expected search result to skip explain stats, got %#v", searchResp.Results[0])
|
|
}
|
|
}
|
|
|
|
func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
encoder := &fakeSparseEncoder{}
|
|
index := newFakeSparseIndex(encoder)
|
|
store := newFakeSparseStore(
|
|
storefs.MemoryItem{
|
|
ID: "bot-1:mem_1",
|
|
Memory: "Ran likes tea",
|
|
Hash: sparseRuntimeHash("Ran likes tea"),
|
|
CreatedAt: "2026-03-13T09:00:00Z",
|
|
UpdatedAt: "2026-03-13T09:00:00Z",
|
|
},
|
|
storefs.MemoryItem{
|
|
ID: "bot-1:mem_2",
|
|
Memory: "Ran works in Berlin",
|
|
Hash: sparseRuntimeHash("Ran works in Berlin"),
|
|
CreatedAt: "2026-03-13T10:00:00Z",
|
|
UpdatedAt: "2026-03-13T10:00:00Z",
|
|
},
|
|
)
|
|
runtime := &sparseRuntime{
|
|
qdrant: index,
|
|
encoder: encoder,
|
|
store: store,
|
|
}
|
|
|
|
index.points[sparsePointID("bot-1", "bot-1:mem_1")] = qdrantclient.SearchResult{
|
|
ID: sparsePointID("bot-1", "bot-1:mem_1"),
|
|
Payload: map[string]string{
|
|
"bot_id": "bot-1",
|
|
"memory": "Ran likes tea",
|
|
"source_entry_id": "bot-1:mem_1",
|
|
"hash": "outdated",
|
|
"created_at": "2026-03-13T09:00:00Z",
|
|
"updated_at": "2026-03-13T09:00:00Z",
|
|
},
|
|
}
|
|
index.points[sparsePointID("bot-1", "bot-1:stale")] = qdrantclient.SearchResult{
|
|
ID: sparsePointID("bot-1", "bot-1:stale"),
|
|
Payload: map[string]string{
|
|
"bot_id": "bot-1",
|
|
"memory": "stale memory",
|
|
"source_entry_id": "bot-1:stale",
|
|
},
|
|
}
|
|
|
|
result, err := runtime.Rebuild(context.Background(), "bot-1")
|
|
if err != nil {
|
|
t.Fatalf("Rebuild() error = %v", err)
|
|
}
|
|
if result.FsCount != 2 {
|
|
t.Fatalf("expected fs_count=2, got %d", result.FsCount)
|
|
}
|
|
if result.StorageCount != 2 {
|
|
t.Fatalf("expected storage_count=2, got %d", result.StorageCount)
|
|
}
|
|
if result.MissingCount != 1 {
|
|
t.Fatalf("expected missing_count=1, got %d", result.MissingCount)
|
|
}
|
|
if result.RestoredCount != 2 {
|
|
t.Fatalf("expected restored_count=2, got %d", result.RestoredCount)
|
|
}
|
|
if _, ok := index.points[sparsePointID("bot-1", "bot-1:stale")]; ok {
|
|
t.Fatal("expected stale qdrant point to be removed")
|
|
}
|
|
}
|
|
|
|
func TestSparseRuntimeGetAllIncludesExplainStats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
encoder := &fakeSparseEncoder{}
|
|
index := newFakeSparseIndex(encoder)
|
|
store := newFakeSparseStore(
|
|
storefs.MemoryItem{
|
|
ID: "bot-1:mem_1",
|
|
Memory: "Ran likes tea",
|
|
Hash: sparseRuntimeHash("Ran likes tea"),
|
|
CreatedAt: "2026-03-13T09:00:00Z",
|
|
UpdatedAt: "2026-03-13T09:00:00Z",
|
|
},
|
|
)
|
|
runtime := &sparseRuntime{
|
|
qdrant: index,
|
|
encoder: encoder,
|
|
store: store,
|
|
}
|
|
|
|
resp, err := runtime.GetAll(context.Background(), adapters.GetAllRequest{BotID: "bot-1"})
|
|
if err != nil {
|
|
t.Fatalf("GetAll() error = %v", err)
|
|
}
|
|
if len(resp.Results) != 1 {
|
|
t.Fatalf("expected 1 result, got %d", len(resp.Results))
|
|
}
|
|
if len(resp.Results[0].TopKBuckets) == 0 || len(resp.Results[0].CDFCurve) == 0 {
|
|
t.Fatalf("expected get all result to include explain stats, got %#v", resp.Results[0])
|
|
}
|
|
if resp.Results[0].TopKBuckets[0].Index != 2 {
|
|
t.Fatalf("expected top bucket index 2, got %d", resp.Results[0].TopKBuckets[0].Index)
|
|
}
|
|
if got := resp.Results[0].CDFCurve[len(resp.Results[0].CDFCurve)-1].Cumulative; got != 1 {
|
|
t.Fatalf("expected final CDF cumulative to be 1, got %v", got)
|
|
}
|
|
}
|
|
|
|
func TestBuiltinProviderMultiTurnRecallUsesSparseSourceRuntime(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
encoder := &fakeSparseEncoder{}
|
|
index := newFakeSparseIndex(encoder)
|
|
store := newFakeSparseStore()
|
|
runtime := &sparseRuntime{
|
|
qdrant: index,
|
|
encoder: encoder,
|
|
store: store,
|
|
}
|
|
provider := NewBuiltinProvider(slog.Default(), runtime, nil, nil)
|
|
|
|
err := provider.OnAfterChat(context.Background(), adapters.AfterChatRequest{
|
|
BotID: "bot-1",
|
|
Messages: []adapters.Message{
|
|
{Role: "user", Content: "I like oolong tea."},
|
|
{Role: "assistant", Content: "Noted, you like oolong tea."},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("OnAfterChat round 1 error = %v", err)
|
|
}
|
|
err = provider.OnAfterChat(context.Background(), adapters.AfterChatRequest{
|
|
BotID: "bot-1",
|
|
Messages: []adapters.Message{
|
|
{Role: "user", Content: "I am based in Berlin."},
|
|
{Role: "assistant", Content: "Understood, you are based in Berlin."},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("OnAfterChat round 2 error = %v", err)
|
|
}
|
|
|
|
before, err := provider.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
|
|
BotID: "bot-1",
|
|
Query: "berlin",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("OnBeforeChat() error = %v", err)
|
|
}
|
|
if before == nil || !strings.Contains(strings.ToLower(before.ContextText), "berlin") {
|
|
t.Fatalf("expected recalled context to mention berlin, got %#v", before)
|
|
}
|
|
|
|
before, err = provider.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
|
|
BotID: "bot-1",
|
|
Query: "oolong tea",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("OnBeforeChat() tea error = %v", err)
|
|
}
|
|
if before == nil || !strings.Contains(strings.ToLower(before.ContextText), "oolong tea") {
|
|
t.Fatalf("expected recalled context to mention oolong tea, got %#v", before)
|
|
}
|
|
}
|