refactor: multi-provider memory adapters with scan-based builtin (#227)

* refactor: restructure memory into multi-provider adapters, remove manifest.json dependency

- Rename internal/memory/provider to internal/memory/adapters with per-provider subdirectories (builtin, mem0, openviking)
- Replace manifest.json-based delete/update with scan-based index from daily files
- Add mem0 and openviking provider adapters with HTTP client, chat hooks, MCP tools, and CRUD
- Wire provider lifecycle into registry (auto-instantiate on create, evict on update/delete)
- Split docker-compose into base stack + optional overlays (qdrant, browser, mem0, openviking)
- Update admin UI to support dynamic provider config schema rendering

* chore(lint): fix all golangci-lint issues for clean CI

* refactor(docker): replace compose overlay files with profiles

* feat(memory): add multiple modes for built-in memory

* fix(ci): golangci lint

* feat(memory): edit built-in memory sparse design
This commit is contained in:
晨苒
2026-03-14 06:04:13 +08:00
committed by GitHub
parent 27607d582d
commit 627b673a5c
75 changed files with 8253 additions and 2107 deletions
@@ -1,4 +1,4 @@
package provider
package builtin
import (
"context"
@@ -9,6 +9,7 @@ import (
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/mcp"
adapters "github.com/memohai/memoh/internal/memory/adapters"
)
const (
@@ -36,15 +37,18 @@ type BuiltinProvider struct {
// It is intentionally defined as an interface to decouple provider wiring from
// concrete service structs in the memory package.
type memoryRuntime interface {
Add(ctx context.Context, req AddRequest) (SearchResponse, error)
Search(ctx context.Context, req SearchRequest) (SearchResponse, error)
GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error)
Update(ctx context.Context, req UpdateRequest) (MemoryItem, error)
Delete(ctx context.Context, memoryID string) (DeleteResponse, error)
DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteResponse, error)
DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error)
Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error)
Usage(ctx context.Context, filters map[string]any) (UsageResponse, error)
Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error)
Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error)
GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error)
Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error)
Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error)
DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error)
DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error)
Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (adapters.CompactResult, error)
Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error)
Mode() string
Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error)
Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error)
}
// AdminChecker checks whether a channel identity has admin privileges.
@@ -69,7 +73,7 @@ func (*BuiltinProvider) Type() string { return BuiltinType }
// --- Conversation Hooks ---
func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatRequest) (*BeforeChatResult, error) {
func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req adapters.BeforeChatRequest) (*adapters.BeforeChatResult, error) {
if p.service == nil {
return nil, nil
}
@@ -77,7 +81,7 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
return nil, nil
}
resp, err := p.service.Search(ctx, SearchRequest{
resp, err := p.service.Search(ctx, adapters.SearchRequest{
Query: req.Query,
BotID: req.BotID,
Limit: memoryContextLimitPerScope,
@@ -96,7 +100,7 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
seen := map[string]struct{}{}
type contextItem struct {
namespace string
item MemoryItem
item adapters.MemoryItem
}
results := make([]contextItem, 0, memoryContextLimitPerScope)
for _, item := range resp.Results {
@@ -134,7 +138,7 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
sb.WriteString("- [")
sb.WriteString(entry.namespace)
sb.WriteString("] ")
sb.WriteString(truncateSnippet(text, memoryContextItemMaxChars))
sb.WriteString(adapters.TruncateSnippet(text, memoryContextItemMaxChars))
sb.WriteString("\n")
}
sb.WriteString("</memory-context>")
@@ -142,10 +146,10 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
if payload == "" {
return nil, nil
}
return &BeforeChatResult{ContextText: payload}, nil
return &adapters.BeforeChatResult{ContextText: payload}, nil
}
func (p *BuiltinProvider) OnAfterChat(ctx context.Context, req AfterChatRequest) error {
func (p *BuiltinProvider) OnAfterChat(ctx context.Context, req adapters.AfterChatRequest) error {
if p.service == nil {
return nil
}
@@ -161,9 +165,11 @@ func (p *BuiltinProvider) OnAfterChat(ctx context.Context, req AfterChatRequest)
"scopeId": botID,
"bot_id": botID,
}
if _, err := p.service.Add(ctx, AddRequest{
metadata := adapters.BuildProfileMetadata(req.UserID, req.ChannelIdentityID, req.DisplayName)
if _, err := p.service.Add(ctx, adapters.AddRequest{
Messages: req.Messages,
BotID: botID,
Metadata: metadata,
Filters: filters,
}); err != nil {
p.logger.Warn("store memory failed", slog.String("bot_id", botID), slog.Any("error", err))
@@ -256,7 +262,7 @@ func (p *BuiltinProvider) CallTool(ctx context.Context, session mcp.ToolSessionC
}
}
resp, err := p.service.Search(ctx, SearchRequest{
resp, err := p.service.Search(ctx, adapters.SearchRequest{
Query: query,
BotID: botID,
Limit: limit,
@@ -271,7 +277,7 @@ func (p *BuiltinProvider) CallTool(ctx context.Context, session mcp.ToolSessionC
return mcp.BuildToolErrorResult("memory search failed"), nil
}
allResults := deduplicateItems(resp.Results)
allResults := adapters.DeduplicateItems(resp.Results)
sort.Slice(allResults, func(i, j int) bool {
return allResults[i].Score > allResults[j].Score
})
@@ -313,99 +319,79 @@ func (p *BuiltinProvider) canAccessChat(ctx context.Context, chatID, channelIden
// --- CRUD ---
func (p *BuiltinProvider) Add(ctx context.Context, req AddRequest) (SearchResponse, error) {
func (p *BuiltinProvider) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
if p.service == nil {
return SearchResponse{}, errors.New("memory runtime not configured")
return adapters.SearchResponse{}, errors.New("memory runtime not configured")
}
return p.service.Add(ctx, req)
}
func (p *BuiltinProvider) Search(ctx context.Context, req SearchRequest) (SearchResponse, error) {
func (p *BuiltinProvider) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
if p.service == nil {
return SearchResponse{}, errors.New("memory runtime not configured")
return adapters.SearchResponse{}, errors.New("memory runtime not configured")
}
return p.service.Search(ctx, req)
}
func (p *BuiltinProvider) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error) {
func (p *BuiltinProvider) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
if p.service == nil {
return SearchResponse{}, errors.New("memory runtime not configured")
return adapters.SearchResponse{}, errors.New("memory runtime not configured")
}
return p.service.GetAll(ctx, req)
}
func (p *BuiltinProvider) Update(ctx context.Context, req UpdateRequest) (MemoryItem, error) {
func (p *BuiltinProvider) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
if p.service == nil {
return MemoryItem{}, errors.New("memory runtime not configured")
return adapters.MemoryItem{}, errors.New("memory runtime not configured")
}
return p.service.Update(ctx, req)
}
func (p *BuiltinProvider) Delete(ctx context.Context, memoryID string) (DeleteResponse, error) {
func (p *BuiltinProvider) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
if p.service == nil {
return DeleteResponse{}, errors.New("memory runtime not configured")
return adapters.DeleteResponse{}, errors.New("memory runtime not configured")
}
return p.service.Delete(ctx, memoryID)
}
func (p *BuiltinProvider) DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteResponse, error) {
func (p *BuiltinProvider) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
if p.service == nil {
return DeleteResponse{}, errors.New("memory runtime not configured")
return adapters.DeleteResponse{}, errors.New("memory runtime not configured")
}
return p.service.DeleteBatch(ctx, memoryIDs)
}
func (p *BuiltinProvider) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error) {
func (p *BuiltinProvider) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
if p.service == nil {
return DeleteResponse{}, errors.New("memory runtime not configured")
return adapters.DeleteResponse{}, errors.New("memory runtime not configured")
}
return p.service.DeleteAll(ctx, req)
}
func (p *BuiltinProvider) Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error) {
func (p *BuiltinProvider) Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (adapters.CompactResult, error) {
if p.service == nil {
return CompactResult{}, errors.New("memory runtime not configured")
return adapters.CompactResult{}, errors.New("memory runtime not configured")
}
return p.service.Compact(ctx, filters, ratio, decayDays)
}
func (p *BuiltinProvider) Usage(ctx context.Context, filters map[string]any) (UsageResponse, error) {
func (p *BuiltinProvider) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
if p.service == nil {
return UsageResponse{}, errors.New("memory runtime not configured")
return adapters.UsageResponse{}, errors.New("memory runtime not configured")
}
return p.service.Usage(ctx, filters)
}
// --- helpers ---
func truncateSnippet(s string, n int) string {
trimmed := strings.TrimSpace(s)
runes := []rune(trimmed)
if len(runes) <= n {
return trimmed
func (p *BuiltinProvider) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
if p.service == nil {
return adapters.MemoryStatusResponse{}, errors.New("memory runtime not configured")
}
return strings.TrimSpace(string(runes[:n])) + "..."
return p.service.Status(ctx, botID)
}
func deduplicateItems(items []MemoryItem) []MemoryItem {
if len(items) == 0 {
return items
func (p *BuiltinProvider) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
if p.service == nil {
return adapters.RebuildResult{}, errors.New("memory runtime not configured")
}
seen := make(map[string]struct{}, len(items))
result := make([]MemoryItem, 0, len(items))
for _, item := range items {
id := strings.TrimSpace(item.ID)
if id == "" {
id = strings.TrimSpace(item.Memory)
}
if id == "" {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
result = append(result, item)
}
return result
return p.service.Rebuild(ctx, botID)
}
@@ -1,13 +1,15 @@
package provider
package builtin
import (
"testing"
"unicode/utf8"
adapters "github.com/memohai/memoh/internal/memory/adapters"
)
func TestTruncateSnippet_ASCII(t *testing.T) {
t.Parallel()
got := truncateSnippet("hello world", 5)
got := adapters.TruncateSnippet("hello world", 5)
if got != "hello..." {
t.Fatalf("expected %q, got %q", "hello...", got)
}
@@ -15,7 +17,7 @@ func TestTruncateSnippet_ASCII(t *testing.T) {
func TestTruncateSnippet_NoTruncation(t *testing.T) {
t.Parallel()
got := truncateSnippet("short", 100)
got := adapters.TruncateSnippet("short", 100)
if got != "short" {
t.Fatalf("expected %q, got %q", "short", got)
}
@@ -24,7 +26,7 @@ func TestTruncateSnippet_NoTruncation(t *testing.T) {
func TestTruncateSnippet_CJK(t *testing.T) {
t.Parallel()
// 5 CJK characters (15 bytes in UTF-8), truncate to 3 runes.
got := truncateSnippet("你好世界啊", 3)
got := adapters.TruncateSnippet("你好世界啊", 3)
if !utf8.ValidString(got) {
t.Fatalf("result is not valid UTF-8: %q", got)
}
@@ -36,7 +38,7 @@ func TestTruncateSnippet_CJK(t *testing.T) {
func TestTruncateSnippet_Emoji(t *testing.T) {
t.Parallel()
// Emoji are 4 bytes each in UTF-8.
got := truncateSnippet("😀😁😂🤣😃", 2)
got := adapters.TruncateSnippet("😀😁😂🤣😃", 2)
if !utf8.ValidString(got) {
t.Fatalf("result is not valid UTF-8: %q", got)
}
@@ -47,7 +49,7 @@ func TestTruncateSnippet_Emoji(t *testing.T) {
func TestTruncateSnippet_TrimWhitespace(t *testing.T) {
t.Parallel()
got := truncateSnippet(" hello ", 100)
got := adapters.TruncateSnippet(" hello ", 100)
if got != "hello" {
t.Fatalf("expected %q, got %q", "hello", got)
}
@@ -0,0 +1,736 @@
package builtin
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"sort"
"strings"
"time"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/db"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
adapters "github.com/memohai/memoh/internal/memory/adapters"
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
storefs "github.com/memohai/memoh/internal/memory/storefs"
)
// denseRuntime implements the builtin memory runtime in "dense" mode:
// memories are persisted as files via storefs (the source of truth) and
// mirrored into a Qdrant collection of dense embedding vectors for
// similarity search.
type denseRuntime struct {
	qdrant     *qdrantclient.Client  // vector index client
	store      *storefs.Service      // file-backed source of truth
	embedder   *denseEmbeddingClient // embedding HTTP API client
	collection string                // Qdrant collection name
}

// denseEmbeddingClient calls an OpenAI-style embeddings HTTP API
// (POST /embeddings with {"model","input"}, optional bearer auth).
type denseEmbeddingClient struct {
	baseURL    string // provider base URL, trailing slash trimmed
	apiKey     string // optional bearer token; empty means no auth header
	modelID    string // embedding model identifier sent in requests
	dimensions int    // vector dimensionality from the model row
	httpClient *http.Client
}

// denseEmbeddingResponse is the subset of the embeddings API response body
// that this runtime consumes.
type denseEmbeddingResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
		Index     int       `json:"index"`
	} `json:"data"`
}

// denseModelSpec is the resolved configuration of an embedding model row
// (model ID plus its LLM provider's endpoint and credentials).
type denseModelSpec struct {
	modelID    string
	baseURL    string
	apiKey     string
	dimensions int
}
// newDenseRuntime builds a denseRuntime from provider configuration.
//
// providerConfig must contain "embedding_model_id" (a model row UUID or a
// model_id string, resolved against the models table); "qdrant_collection"
// is optional and defaults to "memory_dense". Qdrant host/port come from
// cfg.Qdrant.BaseURL, falling back to localhost:6334 (presumably Qdrant's
// gRPC port — confirm against the client).
//
// NOTE(review): model resolution uses context.Background() because this
// constructor takes no ctx — confirm an uncancellable DB lookup is
// acceptable at provider-instantiation time.
func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg config.Config, store *storefs.Service) (*denseRuntime, error) {
	if queries == nil {
		return nil, errors.New("dense runtime: queries are required")
	}
	if store == nil {
		return nil, errors.New("dense runtime: memory store is required")
	}
	modelRef := strings.TrimSpace(adapters.StringFromConfig(providerConfig, "embedding_model_id"))
	if modelRef == "" {
		return nil, errors.New("dense runtime: embedding_model_id is required")
	}
	modelSpec, err := resolveDenseEmbeddingModel(context.Background(), queries, modelRef)
	if err != nil {
		return nil, err
	}
	host, port := parseQdrantHostPort(cfg.Qdrant.BaseURL)
	if host == "" {
		host = "localhost"
	}
	if port == 0 {
		port = 6334
	}
	collection := adapters.StringFromConfig(providerConfig, "qdrant_collection")
	if strings.TrimSpace(collection) == "" {
		collection = "memory_dense"
	}
	qClient, err := qdrantclient.NewClient(host, port, cfg.Qdrant.APIKey, collection)
	if err != nil {
		return nil, fmt.Errorf("dense runtime: %w", err)
	}
	return &denseRuntime{
		qdrant: qClient,
		store:  store,
		embedder: &denseEmbeddingClient{
			baseURL:    strings.TrimRight(modelSpec.baseURL, "/"),
			apiKey:     modelSpec.apiKey,
			modelID:    modelSpec.modelID,
			dimensions: modelSpec.dimensions,
			httpClient: &http.Client{Timeout: 30 * time.Second},
		},
		collection: collection,
	}, nil
}
// Add persists a single memory entry to the file store and mirrors it into
// the Qdrant dense index.
//
// The caller must supply text via req.Message or req.Messages (flattened by
// sparseRuntimeText) and a resolvable bot ID (explicit or via req.Filters).
// Returns the stored item wrapped in a SearchResponse.
func (r *denseRuntime) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	text := sparseRuntimeText(req.Message, req.Messages)
	if text == "" {
		return adapters.SearchResponse{}, errors.New("dense runtime: message is required")
	}
	// Capture one timestamp so the generated memory ID and the
	// created/updated fields cannot disagree across a clock tick.
	createdAt := time.Now().UTC()
	now := createdAt.Format(time.RFC3339)
	item := adapters.MemoryItem{
		ID:        sparseRuntimeMemoryID(botID, createdAt),
		Memory:    text,
		Hash:      denseRuntimeHash(text),
		Metadata:  req.Metadata,
		BotID:     botID,
		CreatedAt: now,
		UpdatedAt: now,
	}
	// Convert once and reuse for both the file write and the index upsert.
	stored := denseStoreItemFromMemoryItem(item)
	if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{stored}, req.Filters); err != nil {
		return adapters.SearchResponse{}, err
	}
	if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{stored}); err != nil {
		return adapters.SearchResponse{}, err
	}
	return adapters.SearchResponse{Results: []adapters.MemoryItem{item}}, nil
}
// Search embeds req.Query and runs a dense similarity search scoped to the
// resolved bot, returning up to req.Limit results (default 10).
func (r *denseRuntime) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	// Creating the collection lazily keeps first-search working on a fresh
	// Qdrant instance.
	if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
		return adapters.SearchResponse{}, err
	}
	topK := req.Limit
	if topK <= 0 {
		topK = 10
	}
	queryVec, err := r.embedder.EmbedQuery(ctx, req.Query)
	if err != nil {
		return adapters.SearchResponse{}, fmt.Errorf("dense embed query: %w", err)
	}
	hits, err := r.qdrant.SearchDense(ctx, qdrantclient.DenseVector{Values: queryVec}, botID, topK)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	out := make([]adapters.MemoryItem, len(hits))
	for i, hit := range hits {
		out[i] = denseResultToItem(hit)
	}
	return adapters.SearchResponse{Results: out}, nil
}
// GetAll lists every stored memory for the resolved bot, newest first,
// truncated to req.Limit entries when the limit is positive.
func (r *denseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	stored, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	all := make([]adapters.MemoryItem, 0, len(stored))
	for _, entry := range stored {
		converted := denseMemoryItemFromStore(entry)
		converted.BotID = botID
		all = append(all, converted)
	}
	// RFC 3339 strings compare chronologically, so a string sort suffices.
	sort.Slice(all, func(a, b int) bool { return all[a].UpdatedAt > all[b].UpdatedAt })
	if limit := req.Limit; limit > 0 && len(all) > limit {
		all = all[:limit]
	}
	return adapters.SearchResponse{Results: all}, nil
}
// Update replaces the text of an existing memory entry.
//
// The owning bot is recovered from the memory ID itself
// (sparseRuntimeBotIDFromMemoryID), so req.MemoryID fully identifies the
// entry. The entry's hash and UpdatedAt are refreshed, the file store is
// rewritten, and the vector index is re-upserted. Returns the updated item,
// or an error when the entry cannot be found.
func (r *denseRuntime) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
	memoryID := strings.TrimSpace(req.MemoryID)
	if memoryID == "" {
		return adapters.MemoryItem{}, errors.New("dense runtime: memory_id is required")
	}
	text := strings.TrimSpace(req.Memory)
	if text == "" {
		return adapters.MemoryItem{}, errors.New("dense runtime: memory is required")
	}
	botID := sparseRuntimeBotIDFromMemoryID(memoryID)
	if botID == "" {
		return adapters.MemoryItem{}, errors.New("dense runtime: invalid memory_id")
	}
	// Linear scan over all persisted memories for this bot; entries are
	// matched on their whitespace-trimmed ID.
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryItem{}, err
	}
	var existing *storefs.MemoryItem
	for i := range items {
		if strings.TrimSpace(items[i].ID) == memoryID {
			// Copy the element out so we don't mutate the slice in place.
			item := items[i]
			existing = &item
			break
		}
	}
	if existing == nil {
		return adapters.MemoryItem{}, errors.New("dense runtime: memory not found")
	}
	existing.Memory = text
	existing.Hash = denseRuntimeHash(text)
	existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	// NOTE(review): filters are nil here, unlike Add — confirm
	// PersistMemories still routes the update to the correct file without them.
	if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{*existing}, nil); err != nil {
		return adapters.MemoryItem{}, err
	}
	if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{*existing}); err != nil {
		return adapters.MemoryItem{}, err
	}
	item := denseMemoryItemFromStore(*existing)
	item.BotID = botID
	return item, nil
}
// Delete removes a single memory entry. It delegates to DeleteBatch so file
// removal and index cleanup share one code path.
func (r *denseRuntime) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
	return r.DeleteBatch(ctx, []string{memoryID})
}
// DeleteBatch removes the given memory entries from both the file store and
// the Qdrant index.
//
// IDs are trimmed and grouped by the bot encoded in each memory ID; blank or
// unparseable IDs are silently skipped (matching the best-effort semantics of
// bulk delete). File removal happens per bot first, then all index points are
// deleted in one call. Always returns a success message when no backend call
// fails, even if nothing matched.
func (r *denseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
	grouped := map[string][]string{}
	pointIDs := make([]string, 0, len(memoryIDs))
	for _, rawID := range memoryIDs {
		memoryID := strings.TrimSpace(rawID)
		if memoryID == "" {
			continue
		}
		botID := sparseRuntimeBotIDFromMemoryID(memoryID)
		if botID == "" {
			continue
		}
		grouped[botID] = append(grouped[botID], memoryID)
		pointIDs = append(pointIDs, sparsePointID(botID, memoryID))
	}
	for botID, ids := range grouped {
		if err := r.store.RemoveMemories(ctx, botID, ids); err != nil {
			return adapters.DeleteResponse{}, err
		}
	}
	// Skip the remote call entirely when nothing valid was parsed.
	if len(pointIDs) > 0 {
		if err := r.qdrant.DeleteByIDs(ctx, pointIDs); err != nil {
			return adapters.DeleteResponse{}, err
		}
	}
	return adapters.DeleteResponse{Message: "Memories deleted successfully!"}, nil
}
// DeleteAll wipes every memory for the resolved bot from both the file store
// and the Qdrant index. Files are removed first; the index can be rebuilt
// from files, so a failure between the two steps is recoverable via Rebuild.
func (r *denseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.store.RemoveAllMemories(ctx, botID); err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.qdrant.DeleteByBotID(ctx, botID); err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "All memories deleted successfully!"}, nil
}
// Compact shrinks a bot's memory set to roughly ratio of its current size,
// keeping the most recently updated entries, then rewrites the on-disk files
// and re-syncs the Qdrant index from the survivors.
//
// filters must resolve to a bot ID; ratio must lie in (0, 1]. The decay-days
// parameter is accepted for interface compatibility but unused by the dense
// runtime. Returns before/after counts and the kept items.
func (r *denseRuntime) Compact(ctx context.Context, filters map[string]any, ratio float64, _ int) (adapters.CompactResult, error) {
	botID, err := sparseRuntimeBotID("", filters)
	if err != nil {
		return adapters.CompactResult{}, err
	}
	if ratio <= 0 || ratio > 1 {
		return adapters.CompactResult{}, errors.New("ratio must be in range (0, 1]")
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.CompactResult{}, err
	}
	before := len(items)
	if before == 0 {
		return adapters.CompactResult{BeforeCount: 0, AfterCount: 0, Ratio: ratio, Results: []adapters.MemoryItem{}}, nil
	}
	// Newest-first so truncation keeps the most recently updated memories;
	// always keep at least one.
	sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt > items[j].UpdatedAt })
	target := int(float64(before) * ratio)
	if target < 1 {
		target = 1
	}
	if target > before {
		target = before
	}
	keptStore := append([]storefs.MemoryItem(nil), items[:target]...)
	if err := r.store.RebuildFiles(ctx, botID, keptStore, filters); err != nil {
		return adapters.CompactResult{}, err
	}
	// Re-sync the vector index against the rewritten files.
	if _, err := r.Rebuild(ctx, botID); err != nil {
		return adapters.CompactResult{}, err
	}
	kept := make([]adapters.MemoryItem, 0, len(keptStore))
	for _, item := range keptStore {
		mem := denseMemoryItemFromStore(item)
		// Stamp the owning bot so results match GetAll/Update output.
		mem.BotID = botID
		kept = append(kept, mem)
	}
	return adapters.CompactResult{
		BeforeCount: before,
		AfterCount:  len(kept),
		Ratio:       ratio,
		Results:     kept,
	}, nil
}
// Usage summarizes memory storage for the resolved bot: entry count, total
// and average memory-text bytes, and an estimated storage footprint (equal
// to the total text bytes).
func (r *denseRuntime) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
	botID, err := sparseRuntimeBotID("", filters)
	if err != nil {
		return adapters.UsageResponse{}, err
	}
	entries, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.UsageResponse{}, err
	}
	var totalBytes int64
	for _, entry := range entries {
		totalBytes += int64(len(entry.Memory))
	}
	out := adapters.UsageResponse{
		Count:                 len(entries),
		TotalTextBytes:        totalBytes,
		EstimatedStorageBytes: totalBytes,
	}
	if out.Count > 0 {
		out.AvgTextBytes = totalBytes / int64(out.Count)
	}
	return out, nil
}
// Mode reports the runtime's memory mode ("dense").
func (*denseRuntime) Mode() string {
	return string(ModeDense)
}
// Status reports health and size information for the dense memory backend.
//
// It gathers, in order: markdown file count, source entries on disk,
// embedding endpoint health, and Qdrant collection existence plus indexed
// point count. Encoder and Qdrant failures are recorded inside the response
// (OK/Error fields) rather than returned as errors, so callers always get a
// partial status; only file-store read failures abort with an error.
func (r *denseRuntime) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
	fileCount, err := r.store.CountMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	status := adapters.MemoryStatusResponse{
		ProviderType:      BuiltinType,
		MemoryMode:        string(ModeDense),
		CanManualSync:     true,
		SourceDir:         path.Join(config.DefaultDataMount, "memory"),
		OverviewPath:      path.Join(config.DefaultDataMount, "MEMORY.md"),
		MarkdownFileCount: fileCount,
		SourceCount:       len(items),
		QdrantCollection:  r.collection,
	}
	// Embedding endpoint health is advisory: record the error and continue.
	if err := r.embedder.Health(ctx); err != nil {
		status.Encoder.Error = err.Error()
	} else {
		status.Encoder.OK = true
	}
	exists, err := r.qdrant.CollectionExists(ctx)
	if err != nil {
		status.Qdrant.Error = err.Error()
		return status, nil // degraded status, deliberately not an error
	}
	status.Qdrant.OK = true
	if exists {
		count, err := r.qdrant.Count(ctx, botID)
		if err != nil {
			status.Qdrant.OK = false
			status.Qdrant.Error = err.Error()
			return status, nil
		}
		status.IndexedCount = count
	}
	return status, nil
}
// Rebuild re-reads all memory files for the bot, refreshes the overview
// document, and reconciles the Qdrant index against the files (upserting
// missing/stale points and deleting orphans).
func (r *denseRuntime) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	if err := r.store.SyncOverview(ctx, botID); err != nil {
		return adapters.RebuildResult{}, err
	}
	return r.syncSourceItems(ctx, botID, items)
}
// syncSourceItems reconciles the Qdrant index with the on-disk memory items
// for one bot: it re-embeds and upserts entries that are missing from the
// index or whose payload drifted, and deletes points whose source entry no
// longer exists on disk.
//
// Returns the disk item count, the indexed point count after syncing, the
// number of entries missing from the index, and the number (re)written.
func (r *denseRuntime) syncSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) (adapters.RebuildResult, error) {
	if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
		return adapters.RebuildResult{}, err
	}
	// NOTE(review): a single scroll capped at 10000 points — bots with more
	// indexed memories would be reconciled incompletely; confirm the cap or
	// page the scroll.
	existing, err := r.qdrant.Scroll(ctx, botID, 10000)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	// Index existing points by source entry ID, falling back to the point ID
	// for points lacking the payload field.
	existingBySource := make(map[string]qdrantclient.SearchResult, len(existing))
	for _, item := range existing {
		sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
		if sourceID == "" {
			sourceID = strings.TrimSpace(item.ID)
		}
		if sourceID != "" {
			existingBySource[sourceID] = item
		}
	}
	sourceIDs := make(map[string]struct{}, len(items))
	toUpsert := make([]storefs.MemoryItem, 0, len(items))
	missingCount := 0
	restoredCount := 0
	for _, item := range items {
		item = denseCanonicalStoreItem(item)
		if item.ID == "" || item.Memory == "" {
			// Unidentifiable or empty entries cannot be indexed.
			continue
		}
		sourceIDs[item.ID] = struct{}{}
		payload := densePayload(botID, item)
		existingItem, ok := existingBySource[item.ID]
		if !ok {
			// Not indexed at all: needs a fresh embedding.
			missingCount++
			restoredCount++
			toUpsert = append(toUpsert, item)
			continue
		}
		if !densePayloadMatches(existingItem.Payload, payload) {
			// Indexed but stale (text/hash/timestamps differ): re-embed.
			restoredCount++
			toUpsert = append(toUpsert, item)
		}
	}
	// Points whose source entry vanished from disk are stale and removed.
	stale := make([]string, 0)
	for _, item := range existing {
		sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
		if sourceID == "" {
			sourceID = strings.TrimSpace(item.ID)
		}
		if _, ok := sourceIDs[sourceID]; ok {
			continue
		}
		if strings.TrimSpace(item.ID) != "" {
			stale = append(stale, item.ID)
		}
	}
	if len(stale) > 0 {
		if err := r.qdrant.DeleteByIDs(ctx, stale); err != nil {
			return adapters.RebuildResult{}, err
		}
	}
	if err := r.upsertSourceItems(ctx, botID, toUpsert); err != nil {
		return adapters.RebuildResult{}, err
	}
	count, err := r.qdrant.Count(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	return adapters.RebuildResult{
		FsCount:       len(items),
		StorageCount:  count,
		MissingCount:  missingCount,
		RestoredCount: restoredCount,
	}, nil
}
// upsertSourceItems embeds the given store items and writes them into the
// Qdrant dense collection, one point per item keyed by sparsePointID.
//
// Items with an empty ID or empty text are skipped; the call is a no-op for
// an empty input. Errors if the embedding API returns a different number of
// vectors than inputs.
//
// NOTE(review): points are upserted one call per item — if the client
// supports batch upsert, that would cut round trips for large rebuilds.
func (r *denseRuntime) upsertSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) error {
	if len(items) == 0 {
		return nil
	}
	if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
		return err
	}
	// Normalize first so the texts slice stays index-aligned with canonical.
	canonical := make([]storefs.MemoryItem, 0, len(items))
	texts := make([]string, 0, len(items))
	for _, item := range items {
		item = denseCanonicalStoreItem(item)
		if item.ID == "" || item.Memory == "" {
			continue
		}
		canonical = append(canonical, item)
		texts = append(texts, item.Memory)
	}
	if len(canonical) == 0 {
		return nil
	}
	vectors, err := r.embedder.EmbedDocuments(ctx, texts)
	if err != nil {
		return fmt.Errorf("dense embed documents: %w", err)
	}
	if len(vectors) != len(canonical) {
		return fmt.Errorf("dense embed documents: expected %d vectors, got %d", len(canonical), len(vectors))
	}
	for i, item := range canonical {
		if err := r.qdrant.UpsertDense(ctx, sparsePointID(botID, item.ID), qdrantclient.DenseVector{
			Values: vectors[i],
		}, densePayload(botID, item)); err != nil {
			return err
		}
	}
	return nil
}
// resolveDenseEmbeddingModel looks up an embedding model by reference and
// returns its wire configuration (model ID, provider base URL/API key, and
// vector dimensions).
//
// modelRef may be a model row UUID or a model_id string: UUID lookup is
// tried first; otherwise the first row matching the model_id is used. The
// resolved row must have type "embedding", a linked LLM provider, and a
// positive dimensions value — anything else is an error.
func resolveDenseEmbeddingModel(ctx context.Context, queries *dbsqlc.Queries, modelRef string) (denseModelSpec, error) {
	modelRef = strings.TrimSpace(modelRef)
	if modelRef == "" {
		return denseModelSpec{}, errors.New("dense runtime: embedding_model_id is required")
	}
	var row dbsqlc.Model
	// UUID lookup failures fall through to the model_id path below.
	if parsed, err := db.ParseUUID(modelRef); err == nil {
		dbModel, err := queries.GetModelByID(ctx, parsed)
		if err == nil {
			row = dbModel
		}
	}
	if !row.ID.Valid {
		rows, err := queries.ListModelsByModelID(ctx, modelRef)
		if err != nil || len(rows) == 0 {
			return denseModelSpec{}, fmt.Errorf("dense runtime: embedding model not found: %s", modelRef)
		}
		row = rows[0]
	}
	if row.Type != "embedding" {
		return denseModelSpec{}, fmt.Errorf("dense runtime: model %s is not an embedding model", modelRef)
	}
	if !row.LlmProviderID.Valid {
		return denseModelSpec{}, fmt.Errorf("dense runtime: model %s has no provider", modelRef)
	}
	provider, err := queries.GetLlmProviderByID(ctx, row.LlmProviderID)
	if err != nil {
		return denseModelSpec{}, fmt.Errorf("dense runtime: get embedding provider: %w", err)
	}
	if !row.Dimensions.Valid || row.Dimensions.Int32 <= 0 {
		return denseModelSpec{}, fmt.Errorf("dense runtime: embedding model %s missing dimensions", modelRef)
	}
	return denseModelSpec{
		modelID:    strings.TrimSpace(row.ModelID),
		baseURL:    strings.TrimSpace(provider.BaseUrl),
		apiKey:     strings.TrimSpace(provider.ApiKey),
		dimensions: int(row.Dimensions.Int32),
	}, nil
}
// joinDenseEmbeddingEndpointURL appends endpointPath to the embedding
// provider's base URL after validating the base (http/https scheme and a
// non-empty host are required).
//
// The path is appended with url.JoinPath so a base that carries its own path
// prefix (e.g. https://api.example.com/v1) keeps that prefix — a
// ResolveReference with an absolute path like "/embeddings" would drop it.
func joinDenseEmbeddingEndpointURL(baseURL, endpointPath string) (string, error) {
	baseURL = strings.TrimSpace(baseURL)
	if baseURL == "" {
		return "", errors.New("dense embedding base URL is required")
	}
	base, err := url.Parse(baseURL)
	if err != nil {
		return "", fmt.Errorf("invalid dense embedding base URL: %w", err)
	}
	if base.Scheme != "http" && base.Scheme != "https" {
		return "", fmt.Errorf("invalid dense embedding base URL scheme: %q", base.Scheme)
	}
	if base.Host == "" {
		return "", errors.New("invalid dense embedding base URL: host is required")
	}
	joined, err := url.JoinPath(base.String(), endpointPath)
	if err != nil {
		return "", fmt.Errorf("invalid dense embedding path: %w", err)
	}
	return joined, nil
}
// Health probes the embedding provider by issuing a GET to its /models
// endpoint with the configured credentials. A transport error or any
// non-2xx status is returned as an error (including the response body for
// diagnostics); nil means the endpoint answered successfully.
func (c *denseEmbeddingClient) Health(ctx context.Context) error {
	endpoint, err := joinDenseEmbeddingEndpointURL(c.baseURL, "/models")
	if err != nil {
		return err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return err
	}
	if c.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured embedding provider base URL
	if err != nil {
		return fmt.Errorf("dense embedding health check failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("dense embedding health error (status %d): %s", resp.StatusCode, string(body))
	}
	return nil
}
// EmbedQuery embeds a single query string by delegating to EmbedDocuments
// and returning the first (and only) vector.
func (c *denseEmbeddingClient) EmbedQuery(ctx context.Context, text string) ([]float32, error) {
	vecs, err := c.EmbedDocuments(ctx, []string{text})
	switch {
	case err != nil:
		return nil, err
	case len(vecs) == 0:
		return nil, errors.New("dense embed query: empty embedding response")
	default:
		return vecs[0], nil
	}
}
// EmbedDocuments embeds texts via the provider's /embeddings endpoint and
// returns one vector per input, aligned with texts by the response's index
// field.
//
// Any out-of-range index or missing embedding is reported as an error (the
// previous best-effort collection silently dropped gaps and relied on the
// caller to notice a count mismatch), so a vector can never be silently
// paired with the wrong document.
func (c *denseEmbeddingClient) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
	body, err := json.Marshal(map[string]any{
		"model": c.modelID,
		"input": texts,
	})
	if err != nil {
		return nil, err
	}
	endpoint, err := joinDenseEmbeddingEndpointURL(c.baseURL, "/embeddings")
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured embedding provider base URL
	if err != nil {
		return nil, fmt.Errorf("dense embed request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("dense embed read response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("dense embed api error %d: %s", resp.StatusCode, string(respBody))
	}
	var parsed denseEmbeddingResponse
	if err := json.Unmarshal(respBody, &parsed); err != nil {
		return nil, fmt.Errorf("dense embed decode response: %w", err)
	}
	// Place each embedding at the slot named by its index so output order
	// matches input order regardless of response ordering.
	vectors := make([][]float32, len(texts))
	for _, item := range parsed.Data {
		if item.Index < 0 || item.Index >= len(vectors) {
			return nil, fmt.Errorf("dense embed: response index %d out of range for %d inputs", item.Index, len(texts))
		}
		vectors[item.Index] = item.Embedding
	}
	for i, vector := range vectors {
		if len(vector) == 0 {
			return nil, fmt.Errorf("dense embed: missing embedding for input %d", i)
		}
	}
	return vectors, nil
}
// denseCanonicalStoreItem normalizes a store item in place: trims the ID and
// memory text, and backfills the content hash when text is present but the
// hash is missing.
func denseCanonicalStoreItem(item storefs.MemoryItem) storefs.MemoryItem {
	item.ID = strings.TrimSpace(item.ID)
	item.Memory = strings.TrimSpace(item.Memory)
	switch {
	case item.Memory == "":
		// Nothing to hash.
	case strings.TrimSpace(item.Hash) != "":
		// Keep the caller-provided hash.
	default:
		item.Hash = denseRuntimeHash(item.Memory)
	}
	return item
}
// densePayload builds the Qdrant point payload for a memory item. Timestamp
// fields are optional and omitted entirely when empty.
func densePayload(botID string, item storefs.MemoryItem) map[string]string {
	canonical := denseCanonicalStoreItem(item)
	payload := map[string]string{
		"memory":          canonical.Memory,
		"bot_id":          strings.TrimSpace(botID),
		"source_entry_id": canonical.ID,
		"hash":            canonical.Hash,
	}
	for key, value := range map[string]string{
		"created_at": canonical.CreatedAt,
		"updated_at": canonical.UpdatedAt,
	} {
		if value != "" {
			payload[key] = value
		}
	}
	return payload
}
// densePayloadMatches reports whether every expected key is present in the
// existing payload with the same value after whitespace trimming. Keys that
// exist only in `existing` are ignored (subset match).
func densePayloadMatches(existing, expected map[string]string) bool {
	matches := true
	for key, want := range expected {
		got := strings.TrimSpace(existing[key])
		if got != strings.TrimSpace(want) {
			matches = false
			break
		}
	}
	return matches
}
// denseStoreItemFromMemoryItem converts an adapter-level memory item into its
// canonical storefs representation (ID/text trimmed, hash backfilled) for
// persistence in the markdown store.
func denseStoreItemFromMemoryItem(item adapters.MemoryItem) storefs.MemoryItem {
	return denseCanonicalStoreItem(storefs.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	})
}
// denseMemoryItemFromStore converts a storefs item back into the adapter-level
// representation, canonicalizing it first so callers always see trimmed fields
// and a populated hash.
func denseMemoryItemFromStore(item storefs.MemoryItem) adapters.MemoryItem {
	item = denseCanonicalStoreItem(item)
	return adapters.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	}
}
// denseResultToItem maps a Qdrant search hit to an adapter memory item,
// preferring the source entry ID stored in the payload over the derived
// point ID so callers address memories by their canonical IDs.
func denseResultToItem(r qdrantclient.SearchResult) adapters.MemoryItem {
	item := adapters.MemoryItem{
		ID:    r.ID,
		Score: r.Score,
	}
	if r.Payload != nil {
		if sourceID := strings.TrimSpace(r.Payload["source_entry_id"]); sourceID != "" {
			item.ID = sourceID
		}
		item.Memory = r.Payload["memory"]
		item.Hash = r.Payload["hash"]
		item.BotID = r.Payload["bot_id"]
		item.CreatedAt = r.Payload["created_at"]
		item.UpdatedAt = r.Payload["updated_at"]
	}
	return item
}
// denseRuntimeHash returns the hex-encoded SHA-256 digest of the
// whitespace-trimmed text; used to detect content changes cheaply.
func denseRuntimeHash(text string) string {
	hasher := sha256.New()
	hasher.Write([]byte(strings.TrimSpace(text)))
	return hex.EncodeToString(hasher.Sum(nil))
}
@@ -0,0 +1,98 @@
package builtin
import (
"log/slog"
"strconv"
"strings"
"github.com/memohai/memoh/internal/config"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
adapters "github.com/memohai/memoh/internal/memory/adapters"
storefs "github.com/memohai/memoh/internal/memory/storefs"
)
// BuiltinMemoryMode represents the operating mode of the built-in memory provider.
type BuiltinMemoryMode string

const (
	// ModeOff disables indexed retrieval; the plain file runtime is used.
	ModeOff BuiltinMemoryMode = "off"
	// ModeSparse indexes memories as sparse vectors in Qdrant.
	ModeSparse BuiltinMemoryMode = "sparse"
	// ModeDense indexes memories as dense embeddings in Qdrant.
	ModeDense BuiltinMemoryMode = "dense"
)
// NewBuiltinRuntimeFromConfig returns the appropriate memoryRuntime based on the
// provider's persisted config (memory_mode field). Falls back to the file runtime for "off" or unknown.
// Initialization failures of the sparse/dense runtimes are logged (when log is
// non-nil) and degrade gracefully to fileRuntime rather than erroring out, so a
// misconfigured index never blocks the provider from starting.
func NewBuiltinRuntimeFromConfig(log *slog.Logger, providerConfig map[string]any, fileRuntime any, store *storefs.Service, queries *dbsqlc.Queries, cfg config.Config) (any, error) {
	mode := BuiltinMemoryMode(strings.TrimSpace(adapters.StringFromConfig(providerConfig, "memory_mode")))
	switch mode {
	case ModeSparse:
		// Qdrant base URLs are HTTP; derive the gRPC host/port from them.
		host, port := parseQdrantHostPort(cfg.Qdrant.BaseURL)
		if host == "" {
			host = "localhost"
		}
		if port == 0 {
			port = 6334
		}
		collection := adapters.StringFromConfig(providerConfig, "qdrant_collection")
		if collection == "" {
			collection = "memory_sparse"
		}
		rt, err := newSparseRuntime(
			host,
			port,
			cfg.Qdrant.APIKey,
			collection,
			strings.TrimSpace(cfg.Sparse.BaseURL),
			store,
		)
		if err != nil {
			if log != nil {
				log.Warn("sparse runtime init failed, falling back to file runtime", slog.Any("error", err))
			}
			return fileRuntime, nil
		}
		return rt, nil
	case ModeDense:
		rt, err := newDenseRuntime(providerConfig, queries, cfg, store)
		if err != nil {
			if log != nil {
				log.Warn("dense runtime init failed, falling back to file runtime", slog.Any("error", err))
			}
			return fileRuntime, nil
		}
		return rt, nil
	default:
		// ModeOff or an unrecognized value: plain markdown-file runtime.
		return fileRuntime, nil
	}
}
// parseQdrantHostPort extracts host and gRPC port from a Qdrant base URL.
// Qdrant base URLs are typically HTTP (port 6333), but the gRPC port is 6334,
// so the default HTTP port is remapped to its gRPC sibling. Any path suffix
// (e.g. a trailing slash) is stripped from both host and port. Returns
// ("", 0) for an empty URL; the port defaults to 6334 when absent or
// unparsable.
func parseQdrantHostPort(baseURL string) (string, int) {
	baseURL = strings.TrimSpace(baseURL)
	if baseURL == "" {
		return "", 0
	}
	baseURL = strings.TrimPrefix(baseURL, "http://")
	baseURL = strings.TrimPrefix(baseURL, "https://")
	// Drop any path component so "host/..." never leaks into the host name.
	if slash := strings.IndexByte(baseURL, '/'); slash >= 0 {
		baseURL = baseURL[:slash]
	}
	host, portStr, hasPort := strings.Cut(baseURL, ":")
	if hasPort {
		if httpPort, err := strconv.Atoi(portStr); err == nil {
			if httpPort == 6333 {
				// Remap the default HTTP port to the gRPC port.
				return host, 6334
			}
			// Common case: operator already configured the intended gRPC port.
			return host, httpPort
		}
	}
	return host, 6334
}
@@ -0,0 +1,742 @@
package builtin
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"path"
"sort"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/memohai/memoh/internal/config"
adapters "github.com/memohai/memoh/internal/memory/adapters"
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
"github.com/memohai/memoh/internal/memory/sparse"
storefs "github.com/memohai/memoh/internal/memory/storefs"
)
// sparseEncoder abstracts the sparse embedding service that turns text into
// sparse vectors for indexing (documents) and retrieval (queries).
type sparseEncoder interface {
	EncodeDocument(ctx context.Context, text string) (*sparse.SparseVector, error)
	EncodeDocuments(ctx context.Context, texts []string) ([]sparse.SparseVector, error)
	EncodeQuery(ctx context.Context, text string) (*sparse.SparseVector, error)
	Health(ctx context.Context) error
}

// sparseIndex abstracts the Qdrant collection that holds the derived sparse
// index; points are filtered per bot via the bot_id payload field.
type sparseIndex interface {
	CollectionName() string
	CollectionExists(ctx context.Context) (bool, error)
	EnsureCollection(ctx context.Context) error
	Upsert(ctx context.Context, id string, vec qdrantclient.SparseVector, payload map[string]string) error
	Search(ctx context.Context, vec qdrantclient.SparseVector, botID string, limit int) ([]qdrantclient.SearchResult, error)
	Scroll(ctx context.Context, botID string, limit int) ([]qdrantclient.SearchResult, error)
	Count(ctx context.Context, botID string) (int, error)
	DeleteByIDs(ctx context.Context, ids []string) error
	DeleteByBotID(ctx context.Context, botID string) error
}

// sparseMemoryStore abstracts the markdown file store that acts as the
// source of truth; the Qdrant index is always rebuildable from it.
type sparseMemoryStore interface {
	PersistMemories(ctx context.Context, botID string, items []storefs.MemoryItem, filters map[string]any) error
	ReadAllMemoryFiles(ctx context.Context, botID string) ([]storefs.MemoryItem, error)
	RemoveMemories(ctx context.Context, botID string, ids []string) error
	RemoveAllMemories(ctx context.Context, botID string) error
	RebuildFiles(ctx context.Context, botID string, items []storefs.MemoryItem, filters map[string]any) error
	SyncOverview(ctx context.Context, botID string) error
	CountMemoryFiles(ctx context.Context, botID string) (int, error)
}
// sparseRuntime implements memoryRuntime with markdown files as the source of
// truth and Qdrant as a derived sparse index used for retrieval.
type sparseRuntime struct {
	qdrant  sparseIndex       // derived retrieval index
	encoder sparseEncoder     // text -> sparse vector service
	store   sparseMemoryStore // markdown source of truth
}

const (
	// sparseExplainTopKLimit caps how many top-weighted dimensions are
	// reported per item in explain stats.
	sparseExplainTopKLimit = 24
)
// newSparseRuntime wires a sparseRuntime from its external dependencies.
// It validates required inputs (Qdrant host, encoder base URL, store) and
// constructs the Qdrant client; the collection itself is created lazily on
// first use via ensureCollection.
func newSparseRuntime(qdrantHost string, qdrantPort int, qdrantAPIKey, collection, encoderBaseURL string, store *storefs.Service) (*sparseRuntime, error) {
	if strings.TrimSpace(qdrantHost) == "" {
		return nil, errors.New("sparse runtime: qdrant host is required")
	}
	if strings.TrimSpace(encoderBaseURL) == "" {
		return nil, errors.New("sparse runtime: sparse.base_url is required")
	}
	if store == nil {
		return nil, errors.New("sparse runtime: memory store is required")
	}
	qClient, err := qdrantclient.NewClient(qdrantHost, qdrantPort, qdrantAPIKey, collection)
	if err != nil {
		return nil, fmt.Errorf("sparse runtime: %w", err)
	}
	return &sparseRuntime{
		qdrant:  qClient,
		encoder: sparse.NewClient(encoderBaseURL),
		store:   store,
	}, nil
}
// ensureCollection lazily creates the Qdrant collection if it does not exist.
func (r *sparseRuntime) ensureCollection(ctx context.Context) error {
	return r.qdrant.EnsureCollection(ctx)
}
// Mode identifies this runtime as the sparse built-in memory mode.
func (*sparseRuntime) Mode() string {
	return string(ModeSparse)
}
// Add records a new memory for a bot: the text is written to the markdown
// source of truth first, then upserted into the Qdrant sparse index.
// The text comes from req.Message or, failing that, a flattened transcript
// of req.Messages.
func (r *sparseRuntime) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	text := sparseRuntimeText(req.Message, req.Messages)
	if text == "" {
		return adapters.SearchResponse{}, errors.New("sparse runtime: message is required")
	}
	// Capture a single timestamp so the nanosecond-based memory ID and the
	// created/updated strings cannot drift apart between two clock reads.
	now := time.Now().UTC()
	nowStr := now.Format(time.RFC3339)
	item := adapters.MemoryItem{
		ID:        sparseRuntimeMemoryID(botID, now),
		Memory:    text,
		Hash:      sparseRuntimeHash(text),
		Metadata:  req.Metadata,
		BotID:     botID,
		CreatedAt: nowStr,
		UpdatedAt: nowStr,
	}
	// Convert once and reuse for both the markdown write and the index upsert.
	storeItem := sparseStoreItemFromMemoryItem(item)
	if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{storeItem}, req.Filters); err != nil {
		return adapters.SearchResponse{}, err
	}
	if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{storeItem}); err != nil {
		return adapters.SearchResponse{}, err
	}
	return adapters.SearchResponse{Results: []adapters.MemoryItem{item}}, nil
}
// Search encodes the query into a sparse vector and retrieves the closest
// points for the bot from Qdrant. Limit defaults to 10 when unset.
func (r *sparseRuntime) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	if err := r.ensureCollection(ctx); err != nil {
		return adapters.SearchResponse{}, err
	}
	limit := req.Limit
	if limit <= 0 {
		limit = 10
	}
	vec, err := r.encoder.EncodeQuery(ctx, req.Query)
	if err != nil {
		return adapters.SearchResponse{}, fmt.Errorf("sparse encode query: %w", err)
	}
	results, err := r.qdrant.Search(ctx, qdrantclient.SparseVector{
		Indices: vec.Indices,
		Values:  vec.Values,
	}, botID, limit)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	items := make([]adapters.MemoryItem, 0, len(results))
	// Named `res` rather than `r`, which would shadow the method receiver.
	for _, res := range results {
		items = append(items, sparseResultToItem(res))
	}
	return adapters.SearchResponse{Results: items}, nil
}
// GetAll lists a bot's memories from the markdown source, newest first,
// optionally truncated to req.Limit, with per-item explain stats attached.
func (r *sparseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	result := make([]adapters.MemoryItem, 0, len(items))
	for _, item := range items {
		mem := sparseMemoryItemFromStore(item)
		mem.BotID = botID
		result = append(result, mem)
	}
	// Sort newest-first (RFC3339 strings compare chronologically) and
	// truncate BEFORE computing explain stats, so the encoder is not paid
	// for items the caller never sees.
	sort.Slice(result, func(i, j int) bool { return result[i].UpdatedAt > result[j].UpdatedAt })
	if req.Limit > 0 && len(result) > req.Limit {
		result = result[:req.Limit]
	}
	r.populateExplainStats(ctx, sparseMemoryItemPointers(result))
	return adapters.SearchResponse{Results: result}, nil
}
// Update rewrites the text of an existing memory identified by its
// "<botID>:<memID>" ID, refreshes its hash and updated_at timestamp, then
// persists the change to the markdown source and re-upserts the Qdrant point.
func (r *sparseRuntime) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
	memoryID := strings.TrimSpace(req.MemoryID)
	if memoryID == "" {
		return adapters.MemoryItem{}, errors.New("sparse runtime: memory_id is required")
	}
	text := strings.TrimSpace(req.Memory)
	if text == "" {
		return adapters.MemoryItem{}, errors.New("sparse runtime: memory is required")
	}
	// The owning bot is encoded as the prefix of the memory ID.
	botID := sparseRuntimeBotIDFromMemoryID(memoryID)
	if botID == "" {
		return adapters.MemoryItem{}, errors.New("sparse runtime: invalid memory_id")
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryItem{}, err
	}
	var existing *storefs.MemoryItem
	for i := range items {
		if strings.TrimSpace(items[i].ID) == memoryID {
			// Copy before taking the address so we do not alias the slice.
			item := items[i]
			existing = &item
			break
		}
	}
	if existing == nil {
		return adapters.MemoryItem{}, errors.New("sparse runtime: memory not found")
	}
	existing.Memory = text
	existing.Hash = sparseRuntimeHash(text)
	existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{*existing}, nil); err != nil {
		return adapters.MemoryItem{}, err
	}
	if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{*existing}); err != nil {
		return adapters.MemoryItem{}, err
	}
	item := sparseMemoryItemFromStore(*existing)
	item.BotID = botID
	return item, nil
}
// Delete removes a single memory; it is a convenience wrapper over DeleteBatch.
func (r *sparseRuntime) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
	return r.DeleteBatch(ctx, []string{memoryID})
}
// DeleteBatch removes the given memories from both the markdown source and
// the Qdrant index. IDs that are blank or lack the "<botID>:" prefix are
// skipped. Deletion still reports success if nothing valid remains, but in
// that case no Qdrant round-trips are issued.
func (r *sparseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
	grouped := map[string][]string{}
	pointIDs := make([]string, 0, len(memoryIDs))
	for _, rawID := range memoryIDs {
		memoryID := strings.TrimSpace(rawID)
		if memoryID == "" {
			continue
		}
		botID := sparseRuntimeBotIDFromMemoryID(memoryID)
		if botID == "" {
			continue
		}
		grouped[botID] = append(grouped[botID], memoryID)
		pointIDs = append(pointIDs, sparsePointID(botID, memoryID))
	}
	// Nothing resolved to a valid ID: avoid creating the collection or
	// issuing an empty delete against Qdrant.
	if len(pointIDs) == 0 {
		return adapters.DeleteResponse{Message: "Memories deleted successfully!"}, nil
	}
	// Files first: the markdown store is the source of truth, and a stale
	// index point can always be reconciled away by Rebuild.
	for botID, ids := range grouped {
		if err := r.store.RemoveMemories(ctx, botID, ids); err != nil {
			return adapters.DeleteResponse{}, err
		}
	}
	if err := r.ensureCollection(ctx); err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.qdrant.DeleteByIDs(ctx, pointIDs); err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "Memories deleted successfully!"}, nil
}
// DeleteAll wipes every memory for the bot resolved from the request,
// removing the markdown files first and then all matching Qdrant points.
func (r *sparseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.store.RemoveAllMemories(ctx, botID); err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.ensureCollection(ctx); err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.qdrant.DeleteByBotID(ctx, botID); err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "All memories deleted successfully!"}, nil
}
// Compact shrinks a bot's memory set to roughly `ratio` of its current size,
// keeping the most recently updated items. The markdown files are rebuilt
// from the kept set and the Qdrant index is fully re-synced via Rebuild.
// At least one item is always kept; the third parameter (a max-count hint)
// is currently ignored.
func (r *sparseRuntime) Compact(ctx context.Context, filters map[string]any, ratio float64, _ int) (adapters.CompactResult, error) {
	botID, err := sparseRuntimeBotID("", filters)
	if err != nil {
		return adapters.CompactResult{}, err
	}
	all, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.CompactResult{}, err
	}
	before := len(all)
	if before == 0 {
		return adapters.CompactResult{Ratio: ratio}, nil
	}
	// Newest first, so truncation keeps the most recent memories.
	sort.Slice(all, func(i, j int) bool {
		return all[i].UpdatedAt > all[j].UpdatedAt
	})
	// Clamp the target into [1, before] so a degenerate ratio never deletes
	// everything or asks for more than exists.
	target := int(float64(before) * ratio)
	if target < 1 {
		target = 1
	}
	if target > before {
		target = before
	}
	keptStore := append([]storefs.MemoryItem(nil), all[:target]...)
	if err := r.store.RebuildFiles(ctx, botID, keptStore, filters); err != nil {
		return adapters.CompactResult{}, err
	}
	if _, err := r.Rebuild(ctx, botID); err != nil {
		return adapters.CompactResult{}, err
	}
	kept := make([]adapters.MemoryItem, 0, len(keptStore))
	for _, item := range keptStore {
		kept = append(kept, sparseMemoryItemFromStore(item))
	}
	return adapters.CompactResult{
		BeforeCount: before,
		AfterCount:  len(kept),
		Ratio:       ratio,
		Results:     kept,
	}, nil
}
// Usage reports aggregate size statistics for the bot's stored memories,
// derived entirely from the markdown source of truth.
func (r *sparseRuntime) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
	botID, err := sparseRuntimeBotID("", filters)
	if err != nil {
		return adapters.UsageResponse{}, err
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.UsageResponse{}, err
	}
	var totalBytes int64
	for _, item := range items {
		totalBytes += int64(len(item.Memory))
	}
	usage := adapters.UsageResponse{
		Count:          len(items),
		TotalTextBytes: totalBytes,
	}
	if len(items) > 0 {
		usage.AvgTextBytes = totalBytes / int64(len(items))
	}
	// Text is the only payload persisted, so the estimate equals the total.
	usage.EstimatedStorageBytes = totalBytes
	return usage, nil
}
// Status reports the health of the sparse pipeline for one bot: file and
// source counts from the markdown store, encoder reachability, and Qdrant
// collection state plus indexed point count. Backend failures are reported
// inside the response rather than as an error, so the admin UI can render
// partial health.
func (r *sparseRuntime) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
	fileCount, err := r.store.CountMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	status := adapters.MemoryStatusResponse{
		ProviderType:      BuiltinType,
		MemoryMode:        string(ModeSparse),
		CanManualSync:     true,
		SourceDir:         path.Join(config.DefaultDataMount, "memory"),
		OverviewPath:      path.Join(config.DefaultDataMount, "MEMORY.md"),
		MarkdownFileCount: fileCount,
		SourceCount:       len(items),
		QdrantCollection:  r.qdrant.CollectionName(),
	}
	if err := r.encoder.Health(ctx); err != nil {
		status.Encoder.Error = err.Error()
	} else {
		status.Encoder.OK = true
	}
	exists, err := r.qdrant.CollectionExists(ctx)
	if err != nil {
		// Qdrant unreachable: surface the error in the status, not as a failure.
		status.Qdrant.Error = err.Error()
		return status, nil
	}
	status.Qdrant.OK = true
	if exists {
		count, err := r.qdrant.Count(ctx, botID)
		if err != nil {
			status.Qdrant.OK = false
			status.Qdrant.Error = err.Error()
			return status, nil
		}
		status.IndexedCount = count
	}
	return status, nil
}
// Rebuild re-derives the Qdrant index from the markdown source of truth for
// one bot, refreshing the MEMORY.md overview along the way, and returns
// reconciliation counts.
func (r *sparseRuntime) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	if err := r.store.SyncOverview(ctx, botID); err != nil {
		return adapters.RebuildResult{}, err
	}
	return r.syncSourceItems(ctx, botID, items)
}
// --- helpers ---
// syncSourceItems reconciles the Qdrant index with the markdown source of
// truth for one bot: it upserts points that are missing or whose payload has
// drifted, deletes points whose source entry no longer exists, and reports
// counts for observability.
func (r *sparseRuntime) syncSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) (adapters.RebuildResult, error) {
	if err := r.ensureCollection(ctx); err != nil {
		return adapters.RebuildResult{}, err
	}
	// NOTE(review): Scroll is capped at 10000 points; bots with more indexed
	// points would not be fully reconciled — confirm the cap is acceptable.
	existing, err := r.qdrant.Scroll(ctx, botID, 10000)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	// Index existing points by source entry ID, falling back to the raw
	// point ID for points lacking the payload field.
	existingBySource := make(map[string]qdrantclient.SearchResult, len(existing))
	for _, item := range existing {
		sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
		if sourceID == "" {
			sourceID = strings.TrimSpace(item.ID)
		}
		if sourceID == "" {
			continue
		}
		existingBySource[sourceID] = item
	}
	canonical := make([]storefs.MemoryItem, 0, len(items))
	sourceIDs := make(map[string]struct{}, len(items))
	toUpsert := make([]storefs.MemoryItem, 0, len(items))
	missingCount := 0
	restoredCount := 0
	for _, item := range items {
		item = sparseCanonicalStoreItem(item)
		if item.ID == "" || item.Memory == "" {
			continue
		}
		canonical = append(canonical, item)
		sourceIDs[item.ID] = struct{}{}
		payload := sparsePayload(botID, item)
		existingItem, ok := existingBySource[item.ID]
		if !ok {
			// Point absent from the index: restore it.
			missingCount++
			restoredCount++
			toUpsert = append(toUpsert, item)
			continue
		}
		if !sparsePayloadMatches(existingItem.Payload, payload) {
			// Point present but stale: re-upsert with the current payload.
			restoredCount++
			toUpsert = append(toUpsert, item)
		}
	}
	// Points whose source entry no longer exists on disk are stale.
	stalePointIDs := make([]string, 0)
	for _, item := range existing {
		sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
		if sourceID == "" {
			sourceID = strings.TrimSpace(item.ID)
		}
		if _, ok := sourceIDs[sourceID]; ok {
			continue
		}
		if strings.TrimSpace(item.ID) != "" {
			stalePointIDs = append(stalePointIDs, item.ID)
		}
	}
	if len(stalePointIDs) > 0 {
		if err := r.qdrant.DeleteByIDs(ctx, stalePointIDs); err != nil {
			return adapters.RebuildResult{}, err
		}
	}
	if err := r.upsertSourceItems(ctx, botID, toUpsert); err != nil {
		return adapters.RebuildResult{}, err
	}
	count, err := r.qdrant.Count(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	return adapters.RebuildResult{
		FsCount:       len(canonical),
		StorageCount:  count,
		MissingCount:  missingCount,
		RestoredCount: restoredCount,
	}, nil
}
// upsertSourceItems encodes the given store items and writes them into the
// Qdrant index under deterministic per-(bot, source) point IDs. Items with
// an empty ID or empty text are skipped; a no-op input returns nil without
// touching Qdrant.
func (r *sparseRuntime) upsertSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) error {
	if len(items) == 0 {
		return nil
	}
	if err := r.ensureCollection(ctx); err != nil {
		return err
	}
	texts := make([]string, 0, len(items))
	canonical := make([]storefs.MemoryItem, 0, len(items))
	for _, item := range items {
		item = sparseCanonicalStoreItem(item)
		if item.ID == "" || item.Memory == "" {
			continue
		}
		canonical = append(canonical, item)
		texts = append(texts, item.Memory)
	}
	if len(canonical) == 0 {
		return nil
	}
	vectors, err := r.encoder.EncodeDocuments(ctx, texts)
	if err != nil {
		return fmt.Errorf("sparse encode documents: %w", err)
	}
	// The encoder must return one vector per text, in order; anything else
	// would silently pair vectors with the wrong memories.
	if len(vectors) != len(canonical) {
		return fmt.Errorf("sparse encode documents: expected %d vectors, got %d", len(canonical), len(vectors))
	}
	for i, item := range canonical {
		vec := vectors[i]
		if err := r.qdrant.Upsert(ctx, sparsePointID(botID, item.ID), qdrantclient.SparseVector{
			Indices: vec.Indices,
			Values:  vec.Values,
		}, sparsePayload(botID, item)); err != nil {
			return err
		}
	}
	return nil
}
// sparseResultToItem maps a Qdrant search hit to an adapter memory item,
// preferring the original source entry ID stored in the payload over the
// derived point ID.
func sparseResultToItem(r qdrantclient.SearchResult) adapters.MemoryItem {
	item := adapters.MemoryItem{ID: r.ID, Score: r.Score}
	payload := r.Payload
	if payload == nil {
		return item
	}
	if sourceID := strings.TrimSpace(payload["source_entry_id"]); sourceID != "" {
		item.ID = sourceID
	}
	item.Memory = payload["memory"]
	item.Hash = payload["hash"]
	item.BotID = payload["bot_id"]
	item.CreatedAt = payload["created_at"]
	item.UpdatedAt = payload["updated_at"]
	return item
}
// populateExplainStats best-effort attaches top-K bucket and CDF explain
// statistics to each non-empty item by re-encoding its text. Encoder errors
// or count mismatches leave the items untouched — stats are advisory only.
func (r *sparseRuntime) populateExplainStats(ctx context.Context, items []*adapters.MemoryItem) {
	if len(items) == 0 {
		return
	}
	texts := make([]string, 0, len(items))
	targets := make([]*adapters.MemoryItem, 0, len(items))
	for _, item := range items {
		if item == nil || strings.TrimSpace(item.Memory) == "" {
			continue
		}
		texts = append(texts, item.Memory)
		targets = append(targets, item)
	}
	if len(texts) == 0 {
		return
	}
	vectors, err := r.encoder.EncodeDocuments(ctx, texts)
	if err != nil || len(vectors) != len(targets) {
		// Best-effort: skip stats rather than failing the caller.
		return
	}
	for i := range targets {
		topK, cdf := sparseExplainStats(vectors[i])
		targets[i].TopKBuckets = topK
		targets[i].CDFCurve = cdf
	}
}
// sparseExplainStats summarizes a sparse vector for the admin UI: the top
// dimensions by weight (capped at sparseExplainTopKLimit) and the cumulative
// distribution of total weight over the ranked dimensions. Non-positive
// weights and values without a matching index are ignored; both results are
// nil when the vector has no usable entries.
func sparseExplainStats(vec sparse.SparseVector) ([]adapters.TopKBucket, []adapters.CDFPoint) {
	type pair struct {
		index uint32
		value float32
	}
	pairs := make([]pair, 0, len(vec.Values))
	for i, value := range vec.Values {
		// Guard against ragged Indices/Values and drop non-positive weights.
		if i >= len(vec.Indices) || value <= 0 {
			continue
		}
		pairs = append(pairs, pair{index: vec.Indices[i], value: value})
	}
	if len(pairs) == 0 {
		return nil, nil
	}
	// Sort by weight descending; ties break on index for determinism.
	sort.Slice(pairs, func(i, j int) bool {
		if pairs[i].value == pairs[j].value {
			return pairs[i].index < pairs[j].index
		}
		return pairs[i].value > pairs[j].value
	})
	topN := len(pairs)
	if topN > sparseExplainTopKLimit {
		topN = sparseExplainTopKLimit
	}
	topK := make([]adapters.TopKBucket, 0, topN)
	total := 0.0
	for _, pair := range pairs {
		total += float64(pair.value)
	}
	for _, pair := range pairs[:topN] {
		topK = append(topK, adapters.TopKBucket{
			Index: pair.index,
			Value: pair.value,
		})
	}
	cdf := make([]adapters.CDFPoint, 0, len(pairs))
	if total <= 0 {
		// No positive mass: return top-K with an empty (non-nil) CDF.
		return topK, cdf
	}
	running := 0.0
	for i, pair := range pairs {
		running += float64(pair.value)
		cdf = append(cdf, adapters.CDFPoint{
			K:          i + 1,
			Cumulative: running / total,
		})
	}
	return topK, cdf
}
// sparseMemoryItemPointers returns pointers to the elements of items so that
// callers can mutate them in place; nil for an empty slice.
func sparseMemoryItemPointers(items []adapters.MemoryItem) []*adapters.MemoryItem {
	if len(items) == 0 {
		return nil
	}
	out := make([]*adapters.MemoryItem, len(items))
	for i := range items {
		out[i] = &items[i]
	}
	return out
}
// sparseCanonicalStoreItem normalizes a store item: trims the ID and memory
// text, and backfills the content hash when text is present but unhashed.
func sparseCanonicalStoreItem(item storefs.MemoryItem) storefs.MemoryItem {
	item.ID = strings.TrimSpace(item.ID)
	item.Memory = strings.TrimSpace(item.Memory)
	switch {
	case item.Memory == "":
		// Nothing to hash.
	case strings.TrimSpace(item.Hash) != "":
		// Keep the caller-provided hash.
	default:
		item.Hash = sparseRuntimeHash(item.Memory)
	}
	return item
}
// sparsePayload builds the Qdrant point payload for a memory item. Timestamp
// fields are optional and omitted entirely when empty.
func sparsePayload(botID string, item storefs.MemoryItem) map[string]string {
	canonical := sparseCanonicalStoreItem(item)
	payload := map[string]string{
		"memory":          canonical.Memory,
		"bot_id":          strings.TrimSpace(botID),
		"source_entry_id": canonical.ID,
		"hash":            canonical.Hash,
	}
	for key, value := range map[string]string{
		"created_at": canonical.CreatedAt,
		"updated_at": canonical.UpdatedAt,
	} {
		if value != "" {
			payload[key] = value
		}
	}
	return payload
}
// sparsePayloadMatches reports whether every expected key is present in the
// existing payload with the same value after whitespace trimming. Extra keys
// in `existing` are ignored (subset match).
func sparsePayloadMatches(existing, expected map[string]string) bool {
	matches := true
	for key, want := range expected {
		got := strings.TrimSpace(existing[key])
		if got != strings.TrimSpace(want) {
			matches = false
			break
		}
	}
	return matches
}
// sparseMemoryItemFromStore converts a storefs item into the adapter-level
// representation, canonicalizing it first so callers always see trimmed
// fields and a populated hash.
func sparseMemoryItemFromStore(item storefs.MemoryItem) adapters.MemoryItem {
	item = sparseCanonicalStoreItem(item)
	return adapters.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	}
}
// sparseStoreItemFromMemoryItem converts an adapter-level memory item into
// its canonical storefs representation for persistence in the markdown store.
func sparseStoreItemFromMemoryItem(item adapters.MemoryItem) storefs.MemoryItem {
	return sparseCanonicalStoreItem(storefs.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	})
}
// sparseRuntimeText resolves the memory text: the explicit message wins, and
// otherwise the chat transcript is flattened into "[ROLE] content" lines.
// Blank-content messages are skipped; a missing role becomes "MESSAGE".
func sparseRuntimeText(message string, messages []adapters.Message) string {
	if text := strings.TrimSpace(message); text != "" {
		return text
	}
	lines := make([]string, 0, len(messages))
	for _, msg := range messages {
		content := strings.TrimSpace(msg.Content)
		if content == "" {
			continue
		}
		role := strings.ToUpper(strings.TrimSpace(msg.Role))
		if role == "" {
			role = "MESSAGE"
		}
		lines = append(lines, "["+role+"] "+content)
	}
	return strings.TrimSpace(strings.Join(lines, "\n"))
}
func sparseRuntimeMemoryID(botID string, now time.Time) string {
return botID + ":" + "mem_" + strconv.FormatInt(now.UnixNano(), 10)
}
// sparseRuntimeHash returns the hex-encoded SHA-256 digest of the
// whitespace-trimmed text; used to detect content drift between the
// markdown source and the index payload.
func sparseRuntimeHash(text string) string {
	hasher := sha256.New()
	hasher.Write([]byte(strings.TrimSpace(text)))
	return hex.EncodeToString(hasher.Sum(nil))
}
// sparseRuntimeBotID resolves the owning bot ID: the explicit argument wins,
// then the "bot_id" filter key, then the legacy "scopeId" key. Non-string
// filter values are stringified; all values are whitespace-trimmed. Errors
// when no source yields a non-empty ID.
func sparseRuntimeBotID(botID string, filters map[string]any) (string, error) {
	resolved := strings.TrimSpace(botID)
	if resolved == "" {
		for _, key := range []string{"bot_id", "scopeId"} {
			value, ok := filters[key]
			if !ok || value == nil {
				continue
			}
			if resolved = strings.TrimSpace(fmt.Sprint(value)); resolved != "" {
				break
			}
		}
	}
	if resolved == "" {
		return "", errors.New("bot_id is required")
	}
	return resolved, nil
}
// sparseRuntimeBotIDFromMemoryID extracts the bot prefix from a
// "<botID>:<memID>" identifier, returning "" when no separator is present.
func sparseRuntimeBotIDFromMemoryID(memoryID string) string {
	botID, _, found := strings.Cut(strings.TrimSpace(memoryID), ":")
	if !found {
		return ""
	}
	return strings.TrimSpace(botID)
}
// sparseRuntimeAny returns the trimmed string form of m[key], or "" when the
// map, the key, or the value is absent/nil. Reading a nil map is safe in Go,
// so no explicit nil-map guard is needed.
func sparseRuntimeAny(m map[string]any, key string) string {
	value, ok := m[key]
	if !ok || value == nil {
		return ""
	}
	return strings.TrimSpace(fmt.Sprint(value))
}
// sparsePointID derives a deterministic UUID (SHA-1 based, namespaced) from
// the bot and source entry IDs, so re-upserting the same memory overwrites
// its existing Qdrant point instead of creating a duplicate.
func sparsePointID(botID, sourceID string) string {
	return uuid.NewSHA1(uuid.NameSpaceURL, []byte(strings.TrimSpace(botID)+"\n"+strings.TrimSpace(sourceID))).String()
}
@@ -0,0 +1,413 @@
package builtin
import (
"context"
"log/slog"
"sort"
"strings"
"testing"
adapters "github.com/memohai/memoh/internal/memory/adapters"
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
"github.com/memohai/memoh/internal/memory/sparse"
storefs "github.com/memohai/memoh/internal/memory/storefs"
)
// fakeSparseStore is an in-memory sparseMemoryStore test double keyed by
// memory ID.
type fakeSparseStore struct {
	items map[string]storefs.MemoryItem
}

// newFakeSparseStore builds a store pre-populated with the given items.
func newFakeSparseStore(items ...storefs.MemoryItem) *fakeSparseStore {
	store := &fakeSparseStore{items: map[string]storefs.MemoryItem{}}
	for _, item := range items {
		store.items[item.ID] = item
	}
	return store
}

// PersistMemories upserts items by ID; bot scoping and filters are ignored.
func (s *fakeSparseStore) PersistMemories(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
	for _, item := range items {
		s.items[item.ID] = item
	}
	return nil
}

// ReadAllMemoryFiles returns all items sorted by ID for determinism.
func (s *fakeSparseStore) ReadAllMemoryFiles(_ context.Context, _ string) ([]storefs.MemoryItem, error) {
	out := make([]storefs.MemoryItem, 0, len(s.items))
	for _, item := range s.items {
		out = append(out, item)
	}
	sort.Slice(out, func(i, j int) bool { return out[i].ID < out[j].ID })
	return out, nil
}

// RemoveMemories deletes the given IDs (trimmed); unknown IDs are no-ops.
func (s *fakeSparseStore) RemoveMemories(_ context.Context, _ string, ids []string) error {
	for _, id := range ids {
		delete(s.items, strings.TrimSpace(id))
	}
	return nil
}

// RemoveAllMemories clears the store.
func (s *fakeSparseStore) RemoveAllMemories(_ context.Context, _ string) error {
	s.items = map[string]storefs.MemoryItem{}
	return nil
}

// RebuildFiles replaces the full contents with the given items.
func (s *fakeSparseStore) RebuildFiles(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
	s.items = map[string]storefs.MemoryItem{}
	for _, item := range items {
		s.items[item.ID] = item
	}
	return nil
}

// SyncOverview is a no-op in the fake.
func (*fakeSparseStore) SyncOverview(context.Context, string) error { return nil }

// CountMemoryFiles reports 1 when any items exist (the fake models a single
// markdown file), 0 otherwise.
func (s *fakeSparseStore) CountMemoryFiles(_ context.Context, _ string) (int, error) {
	if len(s.items) == 0 {
		return 0, nil
	}
	return 1, nil
}
// fakeSparseEncoder is a sparseEncoder test double that returns fixed vectors
// and records the last encoded query so the fake index can match against it.
type fakeSparseEncoder struct {
	lastQuery string
}

// EncodeDocument returns a fixed three-dimension vector regardless of input.
func (*fakeSparseEncoder) EncodeDocument(_ context.Context, _ string) (*sparse.SparseVector, error) {
	return &sparse.SparseVector{Indices: []uint32{1, 2, 3}, Values: []float32{1, 3, 2}}, nil
}

// EncodeDocuments returns one fixed vector per input text.
func (*fakeSparseEncoder) EncodeDocuments(_ context.Context, texts []string) ([]sparse.SparseVector, error) {
	out := make([]sparse.SparseVector, 0, len(texts))
	for _, text := range texts {
		_ = text
		out = append(out, sparse.SparseVector{Indices: []uint32{1, 2, 3}, Values: []float32{1, 3, 2}})
	}
	return out, nil
}

// EncodeQuery records the query text for substring matching in the fake index.
func (e *fakeSparseEncoder) EncodeQuery(_ context.Context, text string) (*sparse.SparseVector, error) {
	e.lastQuery = text
	return &sparse.SparseVector{Indices: []uint32{9}, Values: []float32{1}}, nil
}

// Health always reports the encoder as reachable.
func (*fakeSparseEncoder) Health(context.Context) error { return nil }
// fakeSparseIndex is a sparseIndex test double backed by an in-memory point
// map. Search ignores the vector and instead substring-matches the memory
// text against the encoder's last recorded query.
type fakeSparseIndex struct {
	encoder    *fakeSparseEncoder
	collection string
	exists     bool
	points     map[string]qdrantclient.SearchResult
}

// newFakeSparseIndex builds an empty index bound to the given fake encoder.
func newFakeSparseIndex(encoder *fakeSparseEncoder) *fakeSparseIndex {
	return &fakeSparseIndex{
		encoder:    encoder,
		collection: "memory_sparse_test",
		points:     map[string]qdrantclient.SearchResult{},
	}
}

func (i *fakeSparseIndex) CollectionName() string                         { return i.collection }
func (i *fakeSparseIndex) CollectionExists(context.Context) (bool, error) { return i.exists, nil }

// EnsureCollection marks the collection as created.
func (i *fakeSparseIndex) EnsureCollection(context.Context) error {
	i.exists = true
	return nil
}

// Upsert stores the point (implicitly creating the collection) with score 1.
func (i *fakeSparseIndex) Upsert(_ context.Context, id string, _ qdrantclient.SparseVector, payload map[string]string) error {
	i.exists = true
	i.points[id] = qdrantclient.SearchResult{
		ID:      id,
		Score:   1,
		Payload: payload,
	}
	return nil
}

// Search filters points by bot and by case-insensitive substring match of the
// last encoded query against the memory text, sorted by point ID.
func (i *fakeSparseIndex) Search(_ context.Context, _ qdrantclient.SparseVector, botID string, limit int) ([]qdrantclient.SearchResult, error) {
	query := strings.ToLower(strings.TrimSpace(i.encoder.lastQuery))
	results := make([]qdrantclient.SearchResult, 0, len(i.points))
	for _, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) != strings.TrimSpace(botID) {
			continue
		}
		text := strings.ToLower(point.Payload["memory"])
		if query != "" && !strings.Contains(text, query) {
			continue
		}
		point.Score = 1
		results = append(results, point)
	}
	sort.Slice(results, func(a, b int) bool { return results[a].ID < results[b].ID })
	if limit > 0 && len(results) > limit {
		results = results[:limit]
	}
	return results, nil
}

// Scroll returns all of a bot's points sorted by point ID, up to limit.
func (i *fakeSparseIndex) Scroll(_ context.Context, botID string, limit int) ([]qdrantclient.SearchResult, error) {
	results := make([]qdrantclient.SearchResult, 0, len(i.points))
	for _, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) != strings.TrimSpace(botID) {
			continue
		}
		results = append(results, point)
	}
	sort.Slice(results, func(a, b int) bool { return results[a].ID < results[b].ID })
	if limit > 0 && len(results) > limit {
		results = results[:limit]
	}
	return results, nil
}

// Count tallies the bot's points.
func (i *fakeSparseIndex) Count(_ context.Context, botID string) (int, error) {
	count := 0
	for _, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) == strings.TrimSpace(botID) {
			count++
		}
	}
	return count, nil
}

// DeleteByIDs removes the given point IDs (trimmed); unknown IDs are no-ops.
func (i *fakeSparseIndex) DeleteByIDs(_ context.Context, ids []string) error {
	for _, id := range ids {
		delete(i.points, strings.TrimSpace(id))
	}
	return nil
}

// DeleteByBotID removes every point owned by the bot.
func (i *fakeSparseIndex) DeleteByBotID(_ context.Context, botID string) error {
	for id, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) == strings.TrimSpace(botID) {
			delete(i.points, id)
		}
	}
	return nil
}
// TestSparseRuntimeAddWritesSourceAndSupportsRecall checks that Add persists
// the memory to the markdown source store, indexes a qdrant point that links
// back to the source entry, omits explain stats from the Add response, and
// that the stored memory can then be recalled via Search.
func TestSparseRuntimeAddWritesSourceAndSupportsRecall(t *testing.T) {
	t.Parallel()
	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore()
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}
	resp, err := runtime.Add(context.Background(), adapters.AddRequest{
		BotID:   "bot-1",
		Message: "Ran likes oolong tea",
		Filters: map[string]any{"scopeId": "bot-1"},
	})
	if err != nil {
		t.Fatalf("Add() error = %v", err)
	}
	if len(resp.Results) != 1 {
		t.Fatalf("expected 1 add result, got %d", len(resp.Results))
	}
	item := resp.Results[0]
	if item.ID == "" {
		t.Fatal("expected source memory id to be populated")
	}
	// The new memory must land in the markdown source store...
	if _, ok := store.items[item.ID]; !ok {
		t.Fatalf("expected memory %q to be written to markdown source", item.ID)
	}
	// ...and in the sparse index, with a payload linking back to the source entry.
	point, ok := index.points[sparsePointID("bot-1", item.ID)]
	if !ok {
		t.Fatalf("expected qdrant point for source memory %q", item.ID)
	}
	if point.Payload["source_entry_id"] != item.ID {
		t.Fatalf("expected source_entry_id payload %q, got %q", item.ID, point.Payload["source_entry_id"])
	}
	// Add responses are expected to be lean: no explain statistics attached.
	if len(item.TopKBuckets) != 0 || len(item.CDFCurve) != 0 {
		t.Fatalf("expected add response to skip explain stats, got %#v", item)
	}
	searchResp, err := runtime.Search(context.Background(), adapters.SearchRequest{
		BotID: "bot-1",
		Query: "oolong tea",
	})
	if err != nil {
		t.Fatalf("Search() error = %v", err)
	}
	if len(searchResp.Results) != 1 {
		t.Fatalf("expected 1 search result, got %d", len(searchResp.Results))
	}
	if searchResp.Results[0].ID != item.ID {
		t.Fatalf("expected search result id %q, got %q", item.ID, searchResp.Results[0].ID)
	}
	// Search responses should also omit explain statistics.
	if len(searchResp.Results[0].TopKBuckets) != 0 || len(searchResp.Results[0].CDFCurve) != 0 {
		t.Fatalf("expected search result to skip explain stats, got %#v", searchResp.Results[0])
	}
}
// TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints seeds the markdown
// store with two entries, the index with one point carrying an outdated hash
// and one point whose source entry no longer exists, then verifies Rebuild
// reports the expected counts and removes the stale point.
func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
	t.Parallel()
	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore(
		storefs.MemoryItem{
			ID:        "bot-1:mem_1",
			Memory:    "Ran likes tea",
			Hash:      sparseRuntimeHash("Ran likes tea"),
			CreatedAt: "2026-03-13T09:00:00Z",
			UpdatedAt: "2026-03-13T09:00:00Z",
		},
		storefs.MemoryItem{
			ID:        "bot-1:mem_2",
			Memory:    "Ran works in Berlin",
			Hash:      sparseRuntimeHash("Ran works in Berlin"),
			CreatedAt: "2026-03-13T10:00:00Z",
			UpdatedAt: "2026-03-13T10:00:00Z",
		},
	)
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}
	// Existing point for mem_1 with an outdated hash: should be restored (re-indexed).
	index.points[sparsePointID("bot-1", "bot-1:mem_1")] = qdrantclient.SearchResult{
		ID: sparsePointID("bot-1", "bot-1:mem_1"),
		Payload: map[string]string{
			"bot_id":          "bot-1",
			"memory":          "Ran likes tea",
			"source_entry_id": "bot-1:mem_1",
			"hash":            "outdated",
			"created_at":      "2026-03-13T09:00:00Z",
			"updated_at":      "2026-03-13T09:00:00Z",
		},
	}
	// Point with no backing source entry: should be deleted by Rebuild.
	index.points[sparsePointID("bot-1", "bot-1:stale")] = qdrantclient.SearchResult{
		ID: sparsePointID("bot-1", "bot-1:stale"),
		Payload: map[string]string{
			"bot_id":          "bot-1",
			"memory":          "stale memory",
			"source_entry_id": "bot-1:stale",
		},
	}
	result, err := runtime.Rebuild(context.Background(), "bot-1")
	if err != nil {
		t.Fatalf("Rebuild() error = %v", err)
	}
	if result.FsCount != 2 {
		t.Fatalf("expected fs_count=2, got %d", result.FsCount)
	}
	if result.StorageCount != 2 {
		t.Fatalf("expected storage_count=2, got %d", result.StorageCount)
	}
	// mem_2 had no point at all, so exactly one entry is "missing".
	if result.MissingCount != 1 {
		t.Fatalf("expected missing_count=1, got %d", result.MissingCount)
	}
	// Both the missing entry and the outdated-hash entry get restored.
	if result.RestoredCount != 2 {
		t.Fatalf("expected restored_count=2, got %d", result.RestoredCount)
	}
	if _, ok := index.points[sparsePointID("bot-1", "bot-1:stale")]; ok {
		t.Fatal("expected stale qdrant point to be removed")
	}
}
// TestSparseRuntimeGetAllIncludesExplainStats verifies that GetAll — unlike
// Add and Search — attaches explain statistics (top-k buckets and a CDF curve
// whose final cumulative value is 1) to each returned item.
func TestSparseRuntimeGetAllIncludesExplainStats(t *testing.T) {
	t.Parallel()
	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore(
		storefs.MemoryItem{
			ID:        "bot-1:mem_1",
			Memory:    "Ran likes tea",
			Hash:      sparseRuntimeHash("Ran likes tea"),
			CreatedAt: "2026-03-13T09:00:00Z",
			UpdatedAt: "2026-03-13T09:00:00Z",
		},
	)
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}
	resp, err := runtime.GetAll(context.Background(), adapters.GetAllRequest{BotID: "bot-1"})
	if err != nil {
		t.Fatalf("GetAll() error = %v", err)
	}
	if len(resp.Results) != 1 {
		t.Fatalf("expected 1 result, got %d", len(resp.Results))
	}
	if len(resp.Results[0].TopKBuckets) == 0 || len(resp.Results[0].CDFCurve) == 0 {
		t.Fatalf("expected get all result to include explain stats, got %#v", resp.Results[0])
	}
	// NOTE(review): the expected bucket index 2 depends on the fake encoder's
	// token mapping for "Ran likes tea" — confirm against fakeSparseEncoder.
	if resp.Results[0].TopKBuckets[0].Index != 2 {
		t.Fatalf("expected top bucket index 2, got %d", resp.Results[0].TopKBuckets[0].Index)
	}
	if got := resp.Results[0].CDFCurve[len(resp.Results[0].CDFCurve)-1].Cumulative; got != 1 {
		t.Fatalf("expected final CDF cumulative to be 1, got %v", got)
	}
}
// TestBuiltinProviderMultiTurnRecallUsesSparseSourceRuntime drives the builtin
// provider through two OnAfterChat rounds (storing two distinct facts), then
// checks OnBeforeChat recalls the matching fact for each distinct query.
func TestBuiltinProviderMultiTurnRecallUsesSparseSourceRuntime(t *testing.T) {
	t.Parallel()
	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore()
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}
	provider := NewBuiltinProvider(slog.Default(), runtime, nil, nil)
	// Round 1: store a preference.
	err := provider.OnAfterChat(context.Background(), adapters.AfterChatRequest{
		BotID: "bot-1",
		Messages: []adapters.Message{
			{Role: "user", Content: "I like oolong tea."},
			{Role: "assistant", Content: "Noted, you like oolong tea."},
		},
	})
	if err != nil {
		t.Fatalf("OnAfterChat round 1 error = %v", err)
	}
	// Round 2: store a location fact.
	err = provider.OnAfterChat(context.Background(), adapters.AfterChatRequest{
		BotID: "bot-1",
		Messages: []adapters.Message{
			{Role: "user", Content: "I am based in Berlin."},
			{Role: "assistant", Content: "Understood, you are based in Berlin."},
		},
	})
	if err != nil {
		t.Fatalf("OnAfterChat round 2 error = %v", err)
	}
	// Query 1 should recall the Berlin fact.
	before, err := provider.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
		BotID: "bot-1",
		Query: "berlin",
	})
	if err != nil {
		t.Fatalf("OnBeforeChat() error = %v", err)
	}
	if before == nil || !strings.Contains(strings.ToLower(before.ContextText), "berlin") {
		t.Fatalf("expected recalled context to mention berlin, got %#v", before)
	}
	// Query 2 should recall the tea preference.
	before, err = provider.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
		BotID: "bot-1",
		Query: "oolong tea",
	})
	if err != nil {
		t.Fatalf("OnBeforeChat() tea error = %v", err)
	}
	if before == nil || !strings.Contains(strings.ToLower(before.ContextText), "oolong tea") {
		t.Fatalf("expected recalled context to mention oolong tea, got %#v", before)
	}
}
+93
View File
@@ -0,0 +1,93 @@
package adapters
import (
"fmt"
"strings"
)
// TruncateSnippet trims surrounding whitespace from s and shortens it to at
// most n runes; a truncated result carries an "..." suffix.
func TruncateSnippet(s string, n int) string {
	clean := []rune(strings.TrimSpace(s))
	if len(clean) <= n {
		return string(clean)
	}
	return strings.TrimSpace(string(clean[:n])) + "..."
}
// DeduplicateItems returns items with duplicates removed, keeping the first
// occurrence. Identity is the trimmed ID, falling back to the trimmed memory
// text when the ID is blank; items with neither are dropped entirely.
func DeduplicateItems(items []MemoryItem) []MemoryItem {
	if len(items) == 0 {
		return items
	}
	seen := make(map[string]struct{}, len(items))
	kept := make([]MemoryItem, 0, len(items))
	for _, candidate := range items {
		key := strings.TrimSpace(candidate.ID)
		if key == "" {
			key = strings.TrimSpace(candidate.Memory)
		}
		if key == "" {
			continue
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		kept = append(kept, candidate)
	}
	return kept
}
// StringFromConfig extracts a trimmed string value from a config map. It
// returns "" when the map is nil, the key is absent, or the value is not a
// string. (Reading from a nil map is safe and yields not-found.)
func StringFromConfig(config map[string]any, key string) string {
	raw, found := config[key]
	if !found {
		return ""
	}
	text, isString := raw.(string)
	if !isString {
		return ""
	}
	return strings.TrimSpace(text)
}
// MergeMetadata combines base and extra into a fresh map, with extra winning
// on key collisions. Returns nil when both inputs are empty.
func MergeMetadata(base map[string]any, extra map[string]any) map[string]any {
	if len(base)+len(extra) == 0 {
		return nil
	}
	merged := make(map[string]any, len(base)+len(extra))
	for key, value := range base {
		merged[key] = value
	}
	for key, value := range extra {
		merged[key] = value
	}
	return merged
}
// BuildProfileMetadata assembles profile-related metadata keys from the given
// identifiers. The profile_ref prefers the user ID over the channel identity.
// Returns nil when every input is blank.
func BuildProfileMetadata(userID, channelIdentityID, displayName string) map[string]any {
	userID = strings.TrimSpace(userID)
	channelIdentityID = strings.TrimSpace(channelIdentityID)
	displayName = strings.TrimSpace(displayName)
	if userID == "" && channelIdentityID == "" && displayName == "" {
		return nil
	}
	meta := make(map[string]any)
	switch {
	case userID != "":
		meta["profile_user_id"] = userID
		meta["profile_ref"] = fmt.Sprintf("user:%s", userID)
	case channelIdentityID != "":
		meta["profile_ref"] = fmt.Sprintf("channel_identity:%s", channelIdentityID)
	}
	if channelIdentityID != "" {
		meta["profile_channel_identity_id"] = channelIdentityID
	}
	if displayName != "" {
		meta["profile_display_name"] = displayName
	}
	return meta
}
+418
View File
@@ -0,0 +1,418 @@
package mem0
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
adapters "github.com/memohai/memoh/internal/memory/adapters"
)
const (
	// mem0DefaultBaseURL is used when the admin config omits base_url.
	mem0DefaultBaseURL = "https://api.mem0.ai"
	// mem0OutputFormatV11 selects the v1.1 response envelope for add/list calls.
	mem0OutputFormatV11 = "v1.1"
	// mem0VersionV2 selects v2 API behavior for add/search requests.
	mem0VersionV2 = "v2"
	// mem0BatchDeleteMaxSize is the client-side cap on IDs per batch delete request.
	mem0BatchDeleteMaxSize = 1000
	// mem0ListPageSize is the page size used when paginating an agent's memories.
	mem0ListPageSize = 1000
)
// mem0Client is a thin HTTP client for the Mem0 REST API.
type mem0Client struct {
	baseURL    string // API root, stored without a trailing slash
	apiKey     string // sent as "Authorization: Token <key>"
	orgID      string // optional; attached to request bodies, headers, and query strings when set
	projectID  string // optional; attached alongside orgID when set
	httpClient *http.Client
}
// newMem0Client builds a client from admin-supplied provider config. api_key
// is mandatory; base_url defaults to the hosted Mem0 SaaS endpoint.
func newMem0Client(config map[string]any) (*mem0Client, error) {
	apiKey := adapters.StringFromConfig(config, "api_key")
	if apiKey == "" {
		return nil, errors.New("mem0: api_key is required for SaaS")
	}
	baseURL := adapters.StringFromConfig(config, "base_url")
	if baseURL == "" {
		baseURL = mem0DefaultBaseURL
	}
	client := &mem0Client{
		baseURL:    strings.TrimRight(baseURL, "/"),
		apiKey:     apiKey,
		orgID:      adapters.StringFromConfig(config, "org_id"),
		projectID:  adapters.StringFromConfig(config, "project_id"),
		httpClient: &http.Client{Timeout: 30 * time.Second},
	}
	return client, nil
}
// mem0AddRequest is the payload for POST /v1/memories/.
type mem0AddRequest struct {
	Messages     []adapters.Message `json:"messages,omitempty"`
	UserID       string             `json:"user_id,omitempty"`
	AgentID      string             `json:"agent_id,omitempty"`
	RunID        string             `json:"run_id,omitempty"`
	Metadata     map[string]any     `json:"metadata,omitempty"`
	Infer        *bool              `json:"infer,omitempty"`      // nil leaves the API default in effect
	AsyncMode    *bool              `json:"async_mode,omitempty"` // nil leaves the API default in effect
	OutputFormat string             `json:"output_format,omitempty"`
	Version      string             `json:"version,omitempty"`
	OrgID        string             `json:"org_id,omitempty"`
	ProjectID    string             `json:"project_id,omitempty"`
}

// mem0AddEvent is one entry of an event-shaped add response; the created
// memory text may arrive under data.memory or data.text.
type mem0AddEvent struct {
	ID    string `json:"id"`
	Event string `json:"event,omitempty"`
	Data  struct {
		Memory string `json:"memory,omitempty"`
		Text   string `json:"text,omitempty"`
	} `json:"data"`
}

// mem0Memory is a single memory record as returned by the API. Some endpoints
// populate Text instead of Memory; see normalizeMem0Memory.
type mem0Memory struct {
	ID        string         `json:"id"`
	Memory    string         `json:"memory"`
	Text      string         `json:"text,omitempty"`
	Hash      string         `json:"hash,omitempty"`
	CreatedAt string         `json:"created_at,omitempty"`
	UpdatedAt string         `json:"updated_at,omitempty"`
	Score     float64        `json:"score,omitempty"`
	UserID    string         `json:"user_id,omitempty"`
	AgentID   string         `json:"agent_id,omitempty"`
	RunID     string         `json:"run_id,omitempty"`
	Metadata  map[string]any `json:"metadata,omitempty"`
}

// mem0SearchRequest is the payload for POST /v2/memories/search/.
type mem0SearchRequest struct {
	Query     string         `json:"query"`
	Version   string         `json:"version,omitempty"`
	Filters   map[string]any `json:"filters,omitempty"`
	TopK      int            `json:"top_k,omitempty"`
	OrgID     string         `json:"org_id,omitempty"`
	ProjectID string         `json:"project_id,omitempty"`
}

// mem0GetAllRequest is the payload for POST /v2/memories/ (list).
type mem0GetAllRequest struct {
	Filters      map[string]any `json:"filters"`
	Page         int            `json:"page,omitempty"`
	PageSize     int            `json:"page_size,omitempty"`
	OutputFormat string         `json:"output_format,omitempty"`
	OrgID        string         `json:"org_id,omitempty"`
	ProjectID    string         `json:"project_id,omitempty"`
}

// mem0UpdateRequest is the payload for PUT /v1/memories/{id}/.
type mem0UpdateRequest struct {
	Text     string         `json:"text"`
	Metadata map[string]any `json:"metadata,omitempty"`
}

// mem0MemoriesResponse covers the envelope shapes the API may return for
// memory lists; either Results or Memories is populated.
type mem0MemoriesResponse struct {
	Results   []mem0Memory `json:"results"`
	Memories  []mem0Memory `json:"memories"`
	Total     int          `json:"total,omitempty"`
	Relations []any        `json:"relations,omitempty"`
}

// mem0AddEventsResponse is the event envelope some add responses use.
type mem0AddEventsResponse struct {
	Results []mem0AddEvent `json:"results"`
}
// Add creates memories from messages via POST /v1/memories/, defaulting the
// output format and API version, and parses either memory- or event-shaped
// responses.
func (c *mem0Client) Add(ctx context.Context, req mem0AddRequest) ([]mem0Memory, error) {
	if req.OutputFormat == "" {
		req.OutputFormat = mem0OutputFormatV11
	}
	if req.Version == "" {
		req.Version = mem0VersionV2
	}
	req.OrgID, req.ProjectID = c.orgID, c.projectID
	raw, err := c.doJSONRaw(ctx, http.MethodPost, "/v1/memories/", req)
	if err != nil {
		return nil, fmt.Errorf("mem0 add: %w", err)
	}
	parsed, err := parseMem0AddMemories(raw)
	if err != nil {
		return nil, fmt.Errorf("mem0 add: %w", err)
	}
	return parsed, nil
}
// Search runs a scoped v2 search via POST /v2/memories/search/.
func (c *mem0Client) Search(ctx context.Context, req mem0SearchRequest) ([]mem0Memory, error) {
	if req.Version == "" {
		req.Version = mem0VersionV2
	}
	req.OrgID, req.ProjectID = c.orgID, c.projectID
	raw, err := c.doJSONRaw(ctx, http.MethodPost, "/v2/memories/search/", req)
	if err != nil {
		return nil, fmt.Errorf("mem0 search: %w", err)
	}
	parsed, err := parseMem0Memories(raw)
	if err != nil {
		return nil, fmt.Errorf("mem0 search: %w", err)
	}
	return parsed, nil
}
// GetAll lists memories matching req.Filters via POST /v2/memories/.
func (c *mem0Client) GetAll(ctx context.Context, req mem0GetAllRequest) ([]mem0Memory, error) {
	if req.OutputFormat == "" {
		req.OutputFormat = mem0OutputFormatV11
	}
	req.OrgID, req.ProjectID = c.orgID, c.projectID
	raw, err := c.doJSONRaw(ctx, http.MethodPost, "/v2/memories/", req)
	if err != nil {
		return nil, fmt.Errorf("mem0 get all: %w", err)
	}
	parsed, err := parseMem0Memories(raw)
	if err != nil {
		return nil, fmt.Errorf("mem0 get all: %w", err)
	}
	return parsed, nil
}
// ListAllByAgent pages through every memory scoped to agentID, deduplicating
// by memory ID. Pagination stops at the first page shorter than the page size.
func (c *mem0Client) ListAllByAgent(ctx context.Context, agentID string) ([]mem0Memory, error) {
	agentID = strings.TrimSpace(agentID)
	if agentID == "" {
		return nil, errors.New("agent_id is required")
	}
	collected := make([]mem0Memory, 0)
	seen := make(map[string]struct{})
	for page := 1; ; page++ {
		batch, err := c.GetAll(ctx, mem0GetAllRequest{
			Filters:  map[string]any{"agent_id": agentID},
			Page:     page,
			PageSize: mem0ListPageSize,
		})
		if err != nil {
			return nil, err
		}
		for _, memory := range batch {
			id := strings.TrimSpace(memory.ID)
			if id == "" {
				continue
			}
			if _, dup := seen[id]; dup {
				continue
			}
			seen[id] = struct{}{}
			collected = append(collected, memory)
		}
		if len(batch) < mem0ListPageSize {
			return collected, nil
		}
	}
}
// Update overwrites a memory's text (and optionally metadata) via PUT and
// returns the normalized updated record.
func (c *mem0Client) Update(ctx context.Context, memoryID string, text string, metadata map[string]any) (*mem0Memory, error) {
	payload := mem0UpdateRequest{Text: text, Metadata: metadata}
	var updated mem0Memory
	if err := c.doJSON(ctx, http.MethodPut, "/v1/memories/"+memoryID+"/", payload, &updated); err != nil {
		return nil, fmt.Errorf("mem0 update: %w", err)
	}
	updated = normalizeMem0Memory(updated)
	return &updated, nil
}
// Delete removes a single memory by ID.
func (c *mem0Client) Delete(ctx context.Context, memoryID string) error {
	err := c.doJSON(ctx, http.MethodDelete, "/v1/memories/"+memoryID+"/", nil, nil)
	if err != nil {
		return fmt.Errorf("mem0 delete: %w", err)
	}
	return nil
}
// DeleteAll removes every memory scoped to agentID, forwarding org/project
// scoping as query parameters when configured.
func (c *mem0Client) DeleteAll(ctx context.Context, agentID string) error {
	params := url.Values{}
	params.Set("agent_id", agentID)
	if c.orgID != "" {
		params.Set("org_id", c.orgID)
	}
	if c.projectID != "" {
		params.Set("project_id", c.projectID)
	}
	endpoint := "/v1/memories/?" + params.Encode()
	if err := c.doJSON(ctx, http.MethodDelete, endpoint, nil, nil); err != nil {
		return fmt.Errorf("mem0 delete all: %w", err)
	}
	return nil
}
// BatchDelete removes up to mem0BatchDeleteMaxSize memories in one request.
// Blank IDs are skipped; an effectively empty set is a no-op.
func (c *mem0Client) BatchDelete(ctx context.Context, memoryIDs []string) error {
	if len(memoryIDs) == 0 {
		return nil
	}
	if len(memoryIDs) > mem0BatchDeleteMaxSize {
		return fmt.Errorf("mem0 batch delete: maximum %d memories allowed", mem0BatchDeleteMaxSize)
	}
	ids := make([]string, 0, len(memoryIDs))
	entries := make([]map[string]string, 0, len(memoryIDs))
	for _, raw := range memoryIDs {
		id := strings.TrimSpace(raw)
		if id == "" {
			continue
		}
		ids = append(ids, id)
		entries = append(entries, map[string]string{"memory_id": id})
	}
	if len(ids) == 0 {
		return nil
	}
	// NOTE(review): both "memory_ids" and per-item "memories" shapes are sent,
	// presumably for compatibility across API revisions — confirm before removing either.
	payload := map[string]any{
		"memory_ids": ids,
		"memories":   entries,
	}
	if err := c.doJSON(ctx, http.MethodDelete, "/v1/batch/", payload, nil); err != nil {
		return fmt.Errorf("mem0 batch delete: %w", err)
	}
	return nil
}
// doJSON performs a request via doJSONRaw and, when result is non-nil and the
// response body is non-empty, unmarshals the body into result.
func (c *mem0Client) doJSON(ctx context.Context, method, urlPath string, body any, result any) error {
	raw, err := c.doJSONRaw(ctx, method, urlPath, body)
	if err != nil {
		return err
	}
	if result == nil || len(raw) == 0 {
		return nil
	}
	if err := json.Unmarshal(raw, result); err != nil {
		return fmt.Errorf("unmarshal response: %w", err)
	}
	return nil
}
// doJSONRaw marshals body (when non-nil), performs the HTTP request with auth
// and optional org/project headers attached, and returns the raw response
// body. Non-2xx statuses become errors carrying a truncated body excerpt.
func (c *mem0Client) doJSONRaw(ctx context.Context, method, urlPath string, body any) ([]byte, error) {
	var bodyReader io.Reader
	if body != nil {
		data, err := json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("marshal request: %w", err)
		}
		bodyReader = bytes.NewReader(data)
	}
	req, err := http.NewRequestWithContext(ctx, method, c.baseURL+urlPath, bodyReader)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	// Mem0 uses "Token" (not "Bearer") as the auth scheme.
	req.Header.Set("Authorization", "Token "+c.apiKey)
	if c.orgID != "" {
		req.Header.Set("X-Org-Id", c.orgID)
	}
	if c.projectID != "" {
		req.Header.Set("X-Project-Id", c.projectID)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // URL is from admin-configured base_url
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	// Any 2xx counts as success; everything else surfaces status plus body excerpt.
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("mem0 API error %d: %s", resp.StatusCode, truncateBody(respBody))
	}
	return respBody, nil
}
// parseMem0AddMemories decodes an add response, which the API may return in
// several shapes: a plain/enveloped memory list, an event envelope
// ({"results":[...events]}), or a bare event array. Event shapes are
// converted into memories. If the memory-list parse succeeded but produced no
// concrete memory text and neither event shape matched, the (possibly empty)
// memory list is returned as-is.
func parseMem0AddMemories(body []byte) ([]mem0Memory, error) {
	memories, err := parseMem0Memories(body)
	if err == nil && hasConcreteMem0Memories(memories) {
		return memories, nil
	}
	// Fallback 1: envelope of add events.
	var envelope mem0AddEventsResponse
	if err := json.Unmarshal(body, &envelope); err == nil && len(envelope.Results) > 0 {
		return mem0EventsToMemories(envelope.Results), nil
	}
	// Fallback 2: bare array of add events.
	var events []mem0AddEvent
	if err := json.Unmarshal(body, &events); err == nil && len(events) > 0 {
		return mem0EventsToMemories(events), nil
	}
	// err is still the result of the initial parseMem0Memories call; the
	// fallback errors above were deliberately shadowed.
	if err == nil {
		return memories, nil
	}
	return nil, err
}
// parseMem0Memories decodes either a bare memory array or an enveloped
// response ({"results": ...} / {"memories": ...}); any other shape is an error.
func parseMem0Memories(body []byte) ([]mem0Memory, error) {
	var direct []mem0Memory
	if json.Unmarshal(body, &direct) == nil {
		return normalizeMem0Memories(direct), nil
	}
	var envelope mem0MemoriesResponse
	if json.Unmarshal(body, &envelope) != nil {
		return nil, errors.New("unsupported mem0 response shape")
	}
	if len(envelope.Results) > 0 {
		return normalizeMem0Memories(envelope.Results), nil
	}
	if len(envelope.Memories) > 0 {
		return normalizeMem0Memories(envelope.Memories), nil
	}
	return []mem0Memory{}, nil
}
// mem0EventsToMemories converts add events into memories, preferring each
// event's memory text over its raw text field.
func mem0EventsToMemories(events []mem0AddEvent) []mem0Memory {
	converted := make([]mem0Memory, 0, len(events))
	for _, ev := range events {
		text := strings.TrimSpace(ev.Data.Memory)
		if text == "" {
			text = strings.TrimSpace(ev.Data.Text)
		}
		converted = append(converted, mem0Memory{
			ID:     strings.TrimSpace(ev.ID),
			Memory: text,
		})
	}
	return converted
}
// normalizeMem0Memories normalizes every entry in place and returns the slice.
func normalizeMem0Memories(memories []mem0Memory) []mem0Memory {
	for idx := range memories {
		memories[idx] = normalizeMem0Memory(memories[idx])
	}
	return memories
}
// normalizeMem0Memory falls back to the trimmed Text field when Memory is blank.
func normalizeMem0Memory(memory mem0Memory) mem0Memory {
	if strings.TrimSpace(memory.Memory) != "" {
		return memory
	}
	memory.Memory = strings.TrimSpace(memory.Text)
	return memory
}
// hasConcreteMem0Memories reports whether any entry carries non-blank memory text.
func hasConcreteMem0Memories(memories []mem0Memory) bool {
	for _, memory := range memories {
		if strings.TrimSpace(memory.Memory) != "" {
			return true
		}
	}
	return false
}
// truncateBody shortens an error-response body for inclusion in error
// messages. Truncation happens on a rune boundary (300 runes) so multi-byte
// UTF-8 sequences are never split mid-character — the previous byte-index cut
// could produce invalid UTF-8 in the resulting error string.
func truncateBody(b []byte) string {
	const maxRunes = 300
	s := string(b)
	runes := []rune(s)
	if len(runes) <= maxRunes {
		return s
	}
	return string(runes[:maxRunes]) + "..."
}
@@ -0,0 +1,78 @@
package mem0
import "testing"
// TestNewMem0ClientDefaultsToSaaS verifies that omitting base_url from the
// config falls back to the hosted Mem0 SaaS endpoint.
func TestNewMem0ClientDefaultsToSaaS(t *testing.T) {
	t.Parallel()
	client, err := newMem0Client(map[string]any{
		"api_key": "test-key",
	})
	if err != nil {
		t.Fatalf("newMem0Client() error = %v", err)
	}
	if client.baseURL != mem0DefaultBaseURL {
		t.Fatalf("baseURL = %q, want %q", client.baseURL, mem0DefaultBaseURL)
	}
}
// TestParseMem0AddMemoriesSupportsEventResponses verifies that a bare array of
// ADD events is converted into memories carrying the event id and memory text.
func TestParseMem0AddMemoriesSupportsEventResponses(t *testing.T) {
	t.Parallel()
	body := []byte(`[
	{
	"id": "mem_123",
	"event": "ADD",
	"data": {
	"memory": "The user likes oolong tea."
	}
	}
	]`)
	memories, err := parseMem0AddMemories(body)
	if err != nil {
		t.Fatalf("parseMem0AddMemories() error = %v", err)
	}
	if len(memories) != 1 {
		t.Fatalf("len(memories) = %d, want 1", len(memories))
	}
	if memories[0].ID != "mem_123" {
		t.Fatalf("memory id = %q, want %q", memories[0].ID, "mem_123")
	}
	if memories[0].Memory != "The user likes oolong tea." {
		t.Fatalf("memory text = %q", memories[0].Memory)
	}
}
// TestParseMem0MemoriesSupportsEnvelopeResponses verifies that an enveloped
// {"results": [...]} response decodes and preserves score and scoping fields.
func TestParseMem0MemoriesSupportsEnvelopeResponses(t *testing.T) {
	t.Parallel()
	body := []byte(`{
	"results": [
	{
	"id": "mem_456",
	"memory": "The user lives in Shanghai.",
	"score": 0.92,
	"agent_id": "bot-1",
	"run_id": "run-1",
	"created_at": "2026-01-01T00:00:00Z",
	"updated_at": "2026-01-02T00:00:00Z"
	}
	],
	"total": 1
	}`)
	memories, err := parseMem0Memories(body)
	if err != nil {
		t.Fatalf("parseMem0Memories() error = %v", err)
	}
	if len(memories) != 1 {
		t.Fatalf("len(memories) = %d, want 1", len(memories))
	}
	if memories[0].Score != 0.92 {
		t.Fatalf("score = %v, want 0.92", memories[0].Score)
	}
	if memories[0].AgentID != "bot-1" {
		t.Fatalf("agent_id = %q, want %q", memories[0].AgentID, "bot-1")
	}
}
+546
View File
@@ -0,0 +1,546 @@
package mem0
import (
"context"
"errors"
"fmt"
"log/slog"
"path"
"sort"
"strings"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/mcp"
adapters "github.com/memohai/memoh/internal/memory/adapters"
storefs "github.com/memohai/memoh/internal/memory/storefs"
)
const (
	// Mem0Type is the provider type identifier for this adapter.
	Mem0Type = "mem0"
	// mem0ToolSearchMemory is the single MCP tool exposed by this provider.
	mem0ToolSearchMemory = "search_memory"
	// mem0DefaultLimit and mem0MaxLimit bound search result counts.
	mem0DefaultLimit = 8
	mem0MaxLimit     = 50
	// mem0ContextMaxItems and mem0ContextMaxChars bound the chat context injection.
	mem0ContextMaxItems = 8
	mem0ContextMaxChars = 220
	// Metadata keys linking a remote mem0 memory back to its markdown source entry.
	mem0SyncMetadataKeySourceEntryID = "source_entry_id"
	mem0SyncMetadataKeySourceHash    = "source_hash"
	mem0SyncMetadataKeySourceBotID   = "source_bot_id"
	mem0SyncMetadataKeySourceManaged = "source_managed"
)
// Mem0Provider implements adapters.Provider by delegating to the Mem0 SaaS API.
type Mem0Provider struct {
	client *mem0Client      // HTTP client for the Mem0 REST API
	logger *slog.Logger     // tagged with provider=mem0
	store  *storefs.Service // optional local markdown source store; nil disables rebuild/sync
}
// NewMem0Provider constructs a Mem0-backed provider from admin config. A nil
// logger falls back to slog.Default; store may be nil.
func NewMem0Provider(log *slog.Logger, config map[string]any, store *storefs.Service) (*Mem0Provider, error) {
	if log == nil {
		log = slog.Default()
	}
	client, err := newMem0Client(config)
	if err != nil {
		return nil, err
	}
	provider := &Mem0Provider{
		client: client,
		logger: log.With(slog.String("provider", Mem0Type)),
		store:  store,
	}
	return provider, nil
}
// Type returns the provider type identifier.
func (*Mem0Provider) Type() string { return Mem0Type }
// --- Conversation Hooks ---
// OnBeforeChat searches mem0 for memories relevant to the incoming query and,
// when any are found, returns them wrapped in a <memory-context> block for
// prompt injection. Search failures are logged and swallowed so chat never
// blocks on the memory backend.
//
// Fix: snippets are collected before building the block, so a result set
// whose entries all have blank memory text no longer injects a header-only
// <memory-context> block into the prompt.
func (p *Mem0Provider) OnBeforeChat(ctx context.Context, req adapters.BeforeChatRequest) (*adapters.BeforeChatResult, error) {
	query := strings.TrimSpace(req.Query)
	botID := strings.TrimSpace(req.BotID)
	if query == "" || botID == "" {
		return nil, nil
	}
	memories, err := p.searchMemories(ctx, query, botID, mem0ContextMaxItems)
	if err != nil {
		p.logger.Warn("mem0 search for context failed", slog.Any("error", err))
		return nil, nil
	}
	// Gather non-blank, truncated snippets first.
	snippets := make([]string, 0, mem0ContextMaxItems)
	for _, mem := range memories {
		if len(snippets) >= mem0ContextMaxItems {
			break
		}
		text := strings.TrimSpace(mem.Memory)
		if text == "" {
			continue
		}
		snippets = append(snippets, adapters.TruncateSnippet(text, mem0ContextMaxChars))
	}
	if len(snippets) == 0 {
		return nil, nil
	}
	var sb strings.Builder
	sb.WriteString("<memory-context>\nRelevant memory context (use when helpful):\n")
	for _, snippet := range snippets {
		sb.WriteString("- ")
		sb.WriteString(snippet)
		sb.WriteString("\n")
	}
	sb.WriteString("</memory-context>")
	return &adapters.BeforeChatResult{ContextText: sb.String()}, nil
}
// OnAfterChat stores the finished exchange in mem0 keyed by bot ID. Storage
// failures are logged but never surfaced to the chat pipeline.
func (p *Mem0Provider) OnAfterChat(ctx context.Context, req adapters.AfterChatRequest) error {
	botID := strings.TrimSpace(req.BotID)
	if botID == "" || len(req.Messages) == 0 {
		return nil
	}
	if _, err := p.client.Add(ctx, mem0AddRequest{
		Messages: req.Messages,
		AgentID:  botID,
	}); err != nil {
		p.logger.Warn("mem0 store memory failed", slog.String("bot_id", botID), slog.Any("error", err))
	}
	return nil
}
// --- MCP Tools ---
// ListTools exposes the provider's MCP tools: a single search_memory tool
// with a required query string and an optional integer result limit.
func (*Mem0Provider) ListTools(_ context.Context, _ mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error) {
	return []mcp.ToolDescriptor{
		{
			Name:        mem0ToolSearchMemory,
			Description: "Search for memories relevant to the current chat",
			InputSchema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{
						"type":        "string",
						"description": "The query to search memories",
					},
					"limit": map[string]any{
						"type":        "integer",
						"description": "Maximum number of memory results",
					},
				},
				"required": []string{"query"},
			},
		},
	}, nil
}
// CallTool handles MCP tool invocations. Only search_memory is supported.
// Argument and search problems are returned as tool-error results (nil Go
// error) so the calling model can see and correct them.
func (p *Mem0Provider) CallTool(ctx context.Context, session mcp.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) {
	if toolName != mem0ToolSearchMemory {
		return nil, mcp.ErrToolNotFound
	}
	query := mcp.StringArg(arguments, "query")
	if query == "" {
		return mcp.BuildToolErrorResult("query is required"), nil
	}
	botID := strings.TrimSpace(session.BotID)
	if botID == "" {
		return mcp.BuildToolErrorResult("bot_id is required"), nil
	}
	// Clamp the optional limit: default when absent or non-positive, cap at mem0MaxLimit.
	limit := mem0DefaultLimit
	if value, ok, err := mcp.IntArg(arguments, "limit"); err != nil {
		return mcp.BuildToolErrorResult(err.Error()), nil
	} else if ok {
		limit = value
	}
	if limit <= 0 {
		limit = mem0DefaultLimit
	}
	if limit > mem0MaxLimit {
		limit = mem0MaxLimit
	}
	memories, err := p.searchMemories(ctx, query, botID, limit)
	if err != nil {
		return mcp.BuildToolErrorResult("memory search failed"), nil
	}
	results := make([]map[string]any, 0, len(memories))
	for _, mem := range memories {
		results = append(results, map[string]any{
			"id":     mem.ID,
			"memory": mem.Memory,
			"score":  mem.Score,
		})
	}
	return mcp.BuildToolSuccessResult(map[string]any{
		"query":   query,
		"total":   len(results),
		"results": results,
	}), nil
}
// --- CRUD ---
// Add creates one or more memories scoped to the bot/agent. A plain Message
// is wrapped as a single user message; otherwise req.Messages is forwarded verbatim.
func (p *Mem0Provider) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
	agentID := mem0ScopeID(req.BotID, req.AgentID)
	if agentID == "" {
		return adapters.SearchResponse{}, errors.New("bot_id or agent_id is required")
	}
	addReq := mem0AddRequest{
		AgentID:  agentID,
		RunID:    req.RunID,
		Infer:    req.Infer,
		Metadata: req.Metadata,
	}
	if req.Message != "" {
		addReq.Messages = []adapters.Message{{Role: "user", Content: req.Message}}
	} else {
		addReq.Messages = req.Messages
	}
	memories, err := p.client.Add(ctx, addReq)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	return adapters.SearchResponse{Results: mem0ToItems(memories)}, nil
}
// Search performs a scored mem0 search scoped to the bot/agent, clamping the
// limit into (0, mem0MaxLimit], and returns deduplicated results sorted by
// descending score.
func (p *Mem0Provider) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
	agentID := mem0ScopeID(req.BotID, req.AgentID)
	if agentID == "" {
		return adapters.SearchResponse{}, errors.New("bot_id or agent_id is required")
	}
	limit := req.Limit
	switch {
	case limit <= 0:
		limit = mem0DefaultLimit
	case limit > mem0MaxLimit:
		limit = mem0MaxLimit
	}
	memories, err := p.client.Search(ctx, mem0SearchRequest{
		Query:   req.Query,
		TopK:    limit,
		Filters: mem0AgentFilter(agentID, req.RunID),
	})
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	items := mem0ToItems(memories)
	sort.Slice(items, func(i, j int) bool { return items[i].Score > items[j].Score })
	return adapters.SearchResponse{Results: adapters.DeduplicateItems(items)}, nil
}
// GetAll lists every memory for the bot/agent scope, most recently updated first.
func (p *Mem0Provider) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
	agentID := mem0ScopeID(req.BotID, req.AgentID)
	if agentID == "" {
		return adapters.SearchResponse{}, errors.New("bot_id or agent_id is required")
	}
	listReq := mem0GetAllRequest{Filters: mem0AgentFilter(agentID, req.RunID)}
	if req.Limit > 0 {
		listReq.PageSize = req.Limit
	}
	memories, err := p.client.GetAll(ctx, listReq)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	items := mem0ToItems(memories)
	// Timestamps are compared lexicographically, which orders RFC 3339 strings correctly.
	sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt > items[j].UpdatedAt })
	return adapters.SearchResponse{Results: items}, nil
}
// Update rewrites a memory's text by ID.
func (p *Mem0Provider) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
	memoryID := strings.TrimSpace(req.MemoryID)
	if memoryID == "" {
		return adapters.MemoryItem{}, errors.New("memory_id is required")
	}
	updated, err := p.client.Update(ctx, memoryID, req.Memory, nil)
	if err != nil {
		return adapters.MemoryItem{}, err
	}
	return mem0ToItem(*updated), nil
}
// Delete removes a single memory by (trimmed) ID.
func (p *Mem0Provider) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
	err := p.client.Delete(ctx, strings.TrimSpace(memoryID))
	if err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "Memory deleted successfully"}, nil
}
// DeleteBatch removes the given memory IDs in a single batch call.
func (p *Mem0Provider) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
	err := p.client.BatchDelete(ctx, memoryIDs)
	if err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "Memories deleted successfully"}, nil
}
// DeleteAll removes every memory in the bot/agent scope.
func (p *Mem0Provider) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
	agentID := mem0ScopeID(req.BotID, req.AgentID)
	if agentID == "" {
		return adapters.DeleteResponse{}, errors.New("bot_id or agent_id is required")
	}
	if err := p.client.DeleteAll(ctx, agentID); err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "All memories deleted"}, nil
}
// --- Lifecycle ---
// Compact is not supported by the mem0 provider.
func (*Mem0Provider) Compact(_ context.Context, _ map[string]any, _ float64, _ int) (adapters.CompactResult, error) {
	return adapters.CompactResult{}, errors.New("compact is not supported by mem0 provider")
}
// Usage is not supported by the mem0 provider.
func (*Mem0Provider) Usage(_ context.Context, _ map[string]any) (adapters.UsageResponse, error) {
	return adapters.UsageResponse{}, errors.New("usage is not supported by mem0 provider")
}
// Status reports sync health for the bot: the local markdown file and source
// counts (when a store is configured) versus the count of memories indexed on
// the remote mem0 side.
func (p *Mem0Provider) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
	status := adapters.MemoryStatusResponse{
		ProviderType:  Mem0Type,
		CanManualSync: p.store != nil,
		SourceDir:     path.Join(config.DefaultDataMount, "memory"),
		OverviewPath:  path.Join(config.DefaultDataMount, "MEMORY.md"),
	}
	// Without a local store we can only describe the provider, not count sources.
	if p.store == nil {
		return status, nil
	}
	fileCount, err := p.store.CountMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	items, err := p.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	status.MarkdownFileCount = fileCount
	// Only the canonical subset of store items (as selected by
	// mem0CanonicalStoreItems) counts toward SourceCount.
	status.SourceCount = len(mem0CanonicalStoreItems(items))
	remote, err := p.client.ListAllByAgent(ctx, strings.TrimSpace(botID))
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	status.IndexedCount = len(remote)
	return status, nil
}
// Rebuild re-syncs the remote mem0 store from the local markdown source: it
// reads all entries, refreshes the overview document, then reconciles the
// remote side against the canonical local items. Requires a configured store.
func (p *Mem0Provider) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
	if p.store == nil {
		return adapters.RebuildResult{}, errors.New("memory filesystem not configured")
	}
	items, err := p.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	if err := p.store.SyncOverview(ctx, botID); err != nil {
		return adapters.RebuildResult{}, err
	}
	return p.syncSourceItems(ctx, strings.TrimSpace(botID), items)
}
// --- helpers ---
// mem0ToItems converts a slice of API memories into adapter items.
func mem0ToItems(memories []mem0Memory) []adapters.MemoryItem {
	converted := make([]adapters.MemoryItem, 0, len(memories))
	for _, memory := range memories {
		converted = append(converted, mem0ToItem(memory))
	}
	return converted
}
// mem0ToItem maps an API memory onto the adapter item shape. Bots are scoped
// as mem0 agents (see OnAfterChat), so the agent ID is mirrored into both
// BotID and AgentID.
func mem0ToItem(m mem0Memory) adapters.MemoryItem {
	return adapters.MemoryItem{
		ID:        m.ID,
		Memory:    m.Memory,
		Hash:      m.Hash,
		CreatedAt: m.CreatedAt,
		UpdatedAt: m.UpdatedAt,
		Score:     m.Score,
		Metadata:  m.Metadata,
		BotID:     m.AgentID, // mem0 has no bot field; agent ID stands in
		AgentID:   m.AgentID,
		RunID:     m.RunID,
	}
}
// syncSourceItems reconciles the remote mem0 index with the canonical
// filesystem items for one bot:
//  1. inserts remote memories missing for a source item,
//  2. updates remote memories whose text/metadata drifted from the source,
//  3. batch-deletes managed remote memories whose source item disappeared.
//
// Only memories stamped with a source-entry ID in their metadata are touched;
// unmanaged memories are left alone.
func (p *Mem0Provider) syncSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) (adapters.RebuildResult, error) {
	canonical := mem0CanonicalStoreItems(items)
	existing, err := p.client.ListAllByAgent(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	// Index managed remote memories by the source entry ID recorded in their
	// metadata.
	existingBySource := make(map[string]mem0Memory, len(existing))
	for _, memory := range existing {
		sourceID := mem0SourceEntryID(memory.Metadata)
		if sourceID == "" {
			continue
		}
		existingBySource[sourceID] = memory
	}
	sourceIDs := make(map[string]struct{}, len(canonical))
	missingCount := 0
	restoredCount := 0
	for _, item := range canonical {
		sourceIDs[item.ID] = struct{}{}
		metadata := mem0SourceMetadata(botID, item)
		existingMemory, ok := existingBySource[item.ID]
		if !ok {
			// Not indexed remotely yet: insert verbatim (Infer=false) and
			// synchronously (AsyncMode=false) so post-sync counts are accurate.
			missingCount++
			restoredCount++
			if _, err := p.client.Add(ctx, mem0AddRequest{
				Messages:  []adapters.Message{{Role: "system", Content: item.Memory}},
				AgentID:   botID,
				RunID:     item.RunID,
				Metadata:  metadata,
				Infer:     mem0BoolPtr(false),
				AsyncMode: mem0BoolPtr(false),
			}); err != nil {
				return adapters.RebuildResult{}, err
			}
			continue
		}
		if mem0SourceMemoryMatches(existingMemory, item, botID) {
			continue // remote copy already up to date
		}
		restoredCount++
		if _, err := p.client.Update(ctx, existingMemory.ID, item.Memory, metadata); err != nil {
			return adapters.RebuildResult{}, err
		}
	}
	// Collect managed remote memories whose source item no longer exists.
	staleIDs := make([]string, 0)
	for _, memory := range existing {
		sourceID := mem0SourceEntryID(memory.Metadata)
		if sourceID == "" {
			continue
		}
		if _, ok := sourceIDs[sourceID]; ok {
			continue
		}
		if strings.TrimSpace(memory.ID) != "" {
			staleIDs = append(staleIDs, memory.ID)
		}
	}
	// Delete stale memories in API-sized chunks.
	for len(staleIDs) > 0 {
		chunkSize := mem0BatchDeleteMaxSize
		if len(staleIDs) < chunkSize {
			chunkSize = len(staleIDs)
		}
		if err := p.client.BatchDelete(ctx, staleIDs[:chunkSize]); err != nil {
			return adapters.RebuildResult{}, err
		}
		staleIDs = staleIDs[chunkSize:]
	}
	// Re-list to report the post-sync remote count.
	remote, err := p.client.ListAllByAgent(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	return adapters.RebuildResult{
		FsCount:       len(canonical),
		StorageCount:  len(remote),
		MissingCount:  missingCount,
		RestoredCount: restoredCount,
	}, nil
}
// searchMemories queries mem0 for the agent, orders results by descending
// score, and removes duplicate entries.
func (p *Mem0Provider) searchMemories(ctx context.Context, query, agentID string, limit int) ([]adapters.MemoryItem, error) {
	memories, err := p.client.Search(ctx, mem0SearchRequest{
		Query:   query,
		TopK:    limit,
		Filters: mem0AgentFilter(agentID, ""),
	})
	if err != nil {
		return nil, err
	}
	results := mem0ToItems(memories)
	sort.Slice(results, func(a, b int) bool { return results[a].Score > results[b].Score })
	return adapters.DeduplicateItems(results), nil
}
// mem0ScopeID picks the identifier used to scope mem0 data: the trimmed bot
// ID when non-empty, otherwise the trimmed agent ID.
func mem0ScopeID(botID, agentID string) string {
	bot := strings.TrimSpace(botID)
	if bot == "" {
		return strings.TrimSpace(agentID)
	}
	return bot
}
// mem0AgentFilter builds the mem0 search filter map for an agent, optionally
// narrowed to a single run.
func mem0AgentFilter(agentID, runID string) map[string]any {
	out := map[string]any{"agent_id": strings.TrimSpace(agentID)}
	if run := strings.TrimSpace(runID); run != "" {
		out["run_id"] = run
	}
	return out
}
// mem0CanonicalStoreItems normalizes store items and drops entries whose ID
// or content is blank after trimming.
func mem0CanonicalStoreItems(items []storefs.MemoryItem) []storefs.MemoryItem {
	canonical := make([]storefs.MemoryItem, 0, len(items))
	for _, candidate := range items {
		candidate.ID = strings.TrimSpace(candidate.ID)
		candidate.Memory = strings.TrimSpace(candidate.Memory)
		if candidate.ID != "" && candidate.Memory != "" {
			canonical = append(canonical, candidate)
		}
	}
	return canonical
}
// mem0SourceMetadata clones an item's metadata and stamps it with the sync
// bookkeeping keys that tie the remote memory back to its source entry.
func mem0SourceMetadata(botID string, item storefs.MemoryItem) map[string]any {
	out := make(map[string]any, len(item.Metadata)+4)
	for k, v := range item.Metadata {
		out[k] = v
	}
	out[mem0SyncMetadataKeySourceEntryID] = item.ID
	out[mem0SyncMetadataKeySourceHash] = strings.TrimSpace(item.Hash)
	out[mem0SyncMetadataKeySourceBotID] = strings.TrimSpace(botID)
	out[mem0SyncMetadataKeySourceManaged] = true
	return out
}
// mem0SourceEntryID extracts the canonical source entry ID recorded in a
// memory's sync metadata, or "" when absent.
func mem0SourceEntryID(metadata map[string]any) string {
	id := mem0MetadataString(metadata, mem0SyncMetadataKeySourceEntryID)
	return strings.TrimSpace(id)
}
// mem0SourceMemoryMatches reports whether a remote mem0 memory is already an
// up-to-date, managed mirror of the given source item for this bot.
func mem0SourceMemoryMatches(memory mem0Memory, item storefs.MemoryItem, botID string) bool {
	meta := memory.Metadata
	switch {
	case strings.TrimSpace(memory.Memory) != strings.TrimSpace(item.Memory):
		return false
	case mem0SourceEntryID(meta) != strings.TrimSpace(item.ID):
		return false
	case mem0MetadataString(meta, mem0SyncMetadataKeySourceHash) != strings.TrimSpace(item.Hash):
		return false
	case mem0MetadataString(meta, mem0SyncMetadataKeySourceBotID) != strings.TrimSpace(botID):
		return false
	}
	return mem0MetadataBool(meta, mem0SyncMetadataKeySourceManaged)
}
// mem0MetadataString reads a metadata value as a trimmed string. Non-string
// values are formatted with %v and stripped of surrounding double quotes so
// JSON round-tripped values still compare cleanly.
func mem0MetadataString(metadata map[string]any, key string) string {
	if metadata == nil {
		return ""
	}
	raw, present := metadata[key]
	if !present || raw == nil {
		return ""
	}
	if s, isString := raw.(string); isString {
		return strings.TrimSpace(s)
	}
	formatted := fmt.Sprintf("%v", raw)
	return strings.TrimSpace(strings.Trim(formatted, "\""))
}
// mem0MetadataBool reads a metadata value as a bool; anything missing, nil,
// or non-bool counts as false.
func mem0MetadataBool(metadata map[string]any, key string) bool {
	if metadata == nil {
		return false
	}
	value, ok := metadata[key].(bool)
	return ok && value
}
// mem0BoolPtr returns a pointer to v, for optional bool API fields.
func mem0BoolPtr(v bool) *bool {
	out := v
	return &out
}
@@ -0,0 +1,163 @@
package openviking
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
adapters "github.com/memohai/memoh/internal/memory/adapters"
)
// openVikingClient is a minimal HTTP client for the OpenViking memory API.
// Request paths are resolved relative to baseURL; apiKey, when set, is sent
// as a bearer token.
type openVikingClient struct {
	baseURL    string       // normalized API root, no trailing slash
	apiKey     string       // optional bearer token
	httpClient *http.Client // shared client with a 30s timeout
}
// newOpenVikingClient builds a client from the provider config map.
// base_url is mandatory; api_key is optional.
func newOpenVikingClient(config map[string]any) (*openVikingClient, error) {
	rawBase := adapters.StringFromConfig(config, "base_url")
	if rawBase == "" {
		return nil, errors.New("openviking: base_url is required")
	}
	client := &openVikingClient{
		baseURL: strings.TrimRight(rawBase, "/"),
		apiKey:  adapters.StringFromConfig(config, "api_key"),
		httpClient: &http.Client{
			Timeout: 30 * time.Second,
		},
	}
	return client, nil
}
// ovMemory is the wire representation of one OpenViking memory record.
type ovMemory struct {
	ID        string         `json:"id"`
	Content   string         `json:"content"`
	CreatedAt string         `json:"created_at,omitempty"`
	UpdatedAt string         `json:"updated_at,omitempty"`
	Metadata  map[string]any `json:"metadata,omitempty"`
	Score     float64        `json:"score,omitempty"` // relevance, populated by search
}

// ovAddRequest is the payload for creating a memory.
type ovAddRequest struct {
	AgentID string `json:"agent_id"`
	Content string `json:"content"`
}

// ovSearchRequest is the payload for semantic search over an agent's memories.
type ovSearchRequest struct {
	Query   string `json:"query"`
	AgentID string `json:"agent_id"`
	Limit   int    `json:"limit,omitempty"`
}

// ovUpdateRequest is the payload for replacing a memory's content.
type ovUpdateRequest struct {
	Content string `json:"content"`
}
// Add stores a new memory for the agent and returns the created record.
func (c *openVikingClient) Add(ctx context.Context, agentID, content string) (*ovMemory, error) {
	payload := ovAddRequest{AgentID: agentID, Content: content}
	var created ovMemory
	if err := c.doJSON(ctx, http.MethodPost, "/memories", payload, &created); err != nil {
		return nil, fmt.Errorf("openviking add: %w", err)
	}
	return &created, nil
}
// Search runs semantic search scoped to the agent.
func (c *openVikingClient) Search(ctx context.Context, agentID, query string, limit int) ([]ovMemory, error) {
	payload := ovSearchRequest{
		Query:   query,
		AgentID: agentID,
		Limit:   limit,
	}
	var found []ovMemory
	if err := c.doJSON(ctx, http.MethodPost, "/memories/search", payload, &found); err != nil {
		return nil, fmt.Errorf("openviking search: %w", err)
	}
	return found, nil
}
// GetAll lists the agent's memories; limit <= 0 means the server default.
func (c *openVikingClient) GetAll(ctx context.Context, agentID string, limit int) ([]ovMemory, error) {
	endpoint := "/memories?agent_id=" + url.QueryEscape(agentID)
	if limit > 0 {
		endpoint = fmt.Sprintf("%s&limit=%d", endpoint, limit)
	}
	var memories []ovMemory
	if err := c.doJSON(ctx, http.MethodGet, endpoint, nil, &memories); err != nil {
		return nil, fmt.Errorf("openviking get all: %w", err)
	}
	return memories, nil
}
// Update replaces the content of an existing memory.
//
// Fix: the memory ID is now path-escaped so IDs containing '/', '?', or '#'
// cannot rewrite the request route.
func (c *openVikingClient) Update(ctx context.Context, memoryID, content string) (*ovMemory, error) {
	var result ovMemory
	if err := c.doJSON(ctx, http.MethodPut, "/memories/"+url.PathEscape(memoryID), ovUpdateRequest{Content: content}, &result); err != nil {
		return nil, fmt.Errorf("openviking update: %w", err)
	}
	return &result, nil
}
// Delete removes a single memory by ID.
//
// Fix: the ID is path-escaped so special characters cannot alter the route.
func (c *openVikingClient) Delete(ctx context.Context, memoryID string) error {
	if err := c.doJSON(ctx, http.MethodDelete, "/memories/"+url.PathEscape(memoryID), nil, nil); err != nil {
		return fmt.Errorf("openviking delete: %w", err)
	}
	return nil
}
// DeleteAll removes every memory belonging to the agent.
func (c *openVikingClient) DeleteAll(ctx context.Context, agentID string) error {
	endpoint := "/memories?agent_id=" + url.QueryEscape(agentID)
	delErr := c.doJSON(ctx, http.MethodDelete, endpoint, nil, nil)
	if delErr != nil {
		return fmt.Errorf("openviking delete all: %w", delErr)
	}
	return nil
}
// doJSON issues one JSON request against the OpenViking API. body, when
// non-nil, is marshalled as the request payload; result, when non-nil, is
// unmarshalled from the response body. Non-2xx statuses become errors that
// embed a truncated copy of the response body.
func (c *openVikingClient) doJSON(ctx context.Context, method, urlPath string, body any, result any) error {
	var bodyReader io.Reader
	if body != nil {
		data, err := json.Marshal(body)
		if err != nil {
			return fmt.Errorf("marshal request: %w", err)
		}
		bodyReader = bytes.NewReader(data)
	}
	req, err := http.NewRequestWithContext(ctx, method, c.baseURL+urlPath, bodyReader)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // URL is from admin-configured base_url
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("openviking API error %d: %s", resp.StatusCode, truncateBody(respBody))
	}
	// An empty body on a 2xx status is treated as success with no payload.
	if result != nil && len(respBody) > 0 {
		if err := json.Unmarshal(respBody, result); err != nil {
			return fmt.Errorf("unmarshal response: %w", err)
		}
	}
	return nil
}
func truncateBody(b []byte) string {
if len(b) > 300 {
return string(b[:300]) + "..."
}
return string(b)
}
@@ -0,0 +1,309 @@
package openviking
import (
"context"
"errors"
"log/slog"
"sort"
"strings"
"github.com/memohai/memoh/internal/mcp"
adapters "github.com/memohai/memoh/internal/memory/adapters"
)
const (
OpenVikingType = "openviking"
ovToolSearchMemory = "search_memory"
ovDefaultLimit = 10
ovMaxLimit = 50
ovContextMaxItems = 8
ovContextMaxChars = 220
)
// OpenVikingProvider implements adapters.Provider by delegating to an OpenViking API (self-hosted or SaaS).
type OpenVikingProvider struct {
	client *openVikingClient // HTTP client for the OpenViking API
	logger *slog.Logger      // scoped with provider=openviking
}
// NewOpenVikingProvider wires an OpenViking-backed memory provider from the
// admin-supplied config. A nil logger falls back to slog.Default().
func NewOpenVikingProvider(log *slog.Logger, config map[string]any) (*OpenVikingProvider, error) {
	if log == nil {
		log = slog.Default()
	}
	client, err := newOpenVikingClient(config)
	if err != nil {
		return nil, err
	}
	provider := &OpenVikingProvider{
		client: client,
		logger: log.With(slog.String("provider", OpenVikingType)),
	}
	return provider, nil
}
func (*OpenVikingProvider) Type() string { return OpenVikingType }
// --- Conversation Hooks ---
// OnBeforeChat searches OpenViking for memories relevant to the incoming
// query and returns them as a context block to inject into the conversation.
// Search failures are logged and swallowed so memory never blocks a chat.
//
// Fix: previously a "<memory-context>" wrapper was emitted even when every
// hit had blank content; an empty block is no longer injected. The item cap
// now counts entries actually written instead of the raw result index.
func (p *OpenVikingProvider) OnBeforeChat(ctx context.Context, req adapters.BeforeChatRequest) (*adapters.BeforeChatResult, error) {
	query := strings.TrimSpace(req.Query)
	botID := strings.TrimSpace(req.BotID)
	if query == "" || botID == "" {
		return nil, nil
	}
	memories, err := p.client.Search(ctx, botID, query, ovContextMaxItems)
	if err != nil {
		p.logger.Warn("openviking search for context failed", slog.Any("error", err))
		return nil, nil
	}
	var sb strings.Builder
	sb.WriteString("<memory-context>\nRelevant memory context (use when helpful):\n")
	written := 0
	for _, mem := range memories {
		if written >= ovContextMaxItems {
			break
		}
		text := strings.TrimSpace(mem.Content)
		if text == "" {
			continue
		}
		sb.WriteString("- ")
		sb.WriteString(adapters.TruncateSnippet(text, ovContextMaxChars))
		sb.WriteString("\n")
		written++
	}
	if written == 0 {
		return nil, nil
	}
	sb.WriteString("</memory-context>")
	return &adapters.BeforeChatResult{ContextText: sb.String()}, nil
}
// OnAfterChat persists the finished exchange to OpenViking as one memory,
// each message prefixed with its upper-cased role. Store failures are logged
// but never propagated so memory cannot break the chat flow.
func (p *OpenVikingProvider) OnAfterChat(ctx context.Context, req adapters.AfterChatRequest) error {
	botID := strings.TrimSpace(req.BotID)
	if botID == "" || len(req.Messages) == 0 {
		return nil
	}
	lines := make([]string, 0, len(req.Messages))
	for _, msg := range req.Messages {
		content := strings.TrimSpace(msg.Content)
		if content == "" {
			continue
		}
		role := strings.ToUpper(strings.TrimSpace(msg.Role))
		if role == "" {
			role = "MESSAGE"
		}
		lines = append(lines, "["+role+"] "+content)
	}
	if len(lines) == 0 {
		return nil
	}
	if _, err := p.client.Add(ctx, botID, strings.Join(lines, "\n")); err != nil {
		p.logger.Warn("openviking store memory failed", slog.String("bot_id", botID), slog.Any("error", err))
	}
	return nil
}
// --- MCP Tools ---
// ListTools advertises the MCP tools exposed by this provider: a single
// search_memory tool taking a required query and an optional result limit.
func (*OpenVikingProvider) ListTools(_ context.Context, _ mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error) {
	return []mcp.ToolDescriptor{
		{
			Name:        ovToolSearchMemory,
			Description: "Search for memories relevant to the current chat",
			InputSchema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{
						"type":        "string",
						"description": "The query to search memories",
					},
					"limit": map[string]any{
						"type":        "integer",
						"description": "Maximum number of memory results",
					},
				},
				"required": []string{"query"},
			},
		},
	}, nil
}
// CallTool executes the search_memory tool. Argument problems and search
// failures are reported as tool-error results (nil Go error) so the MCP
// session continues; only an unknown tool name returns mcp.ErrToolNotFound.
func (p *OpenVikingProvider) CallTool(ctx context.Context, session mcp.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) {
	if toolName != ovToolSearchMemory {
		return nil, mcp.ErrToolNotFound
	}
	query := mcp.StringArg(arguments, "query")
	if query == "" {
		return mcp.BuildToolErrorResult("query is required"), nil
	}
	botID := strings.TrimSpace(session.BotID)
	if botID == "" {
		return mcp.BuildToolErrorResult("bot_id is required"), nil
	}
	// Clamp the requested limit to [1, ovMaxLimit]; non-positive or missing
	// values fall back to ovDefaultLimit.
	limit := ovDefaultLimit
	if value, ok, err := mcp.IntArg(arguments, "limit"); err != nil {
		return mcp.BuildToolErrorResult(err.Error()), nil
	} else if ok {
		limit = value
	}
	if limit <= 0 {
		limit = ovDefaultLimit
	}
	if limit > ovMaxLimit {
		limit = ovMaxLimit
	}
	memories, err := p.client.Search(ctx, botID, query, limit)
	if err != nil {
		return mcp.BuildToolErrorResult("memory search failed"), nil
	}
	results := make([]map[string]any, 0, len(memories))
	for _, mem := range memories {
		results = append(results, map[string]any{
			"id":     mem.ID,
			"memory": mem.Content,
			"score":  mem.Score,
		})
	}
	return mcp.BuildToolSuccessResult(map[string]any{
		"query":   query,
		"total":   len(results),
		"results": results,
	}), nil
}
// --- CRUD ---
// Add stores one memory for the bot. The text comes from req.Message when
// set; otherwise req.Messages is flattened, one "[ROLE] content" line per
// non-empty message.
func (p *OpenVikingProvider) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
	botID := strings.TrimSpace(req.BotID)
	if botID == "" {
		return adapters.SearchResponse{}, errors.New("bot_id is required")
	}
	text := strings.TrimSpace(req.Message)
	if text == "" && len(req.Messages) > 0 {
		lines := make([]string, 0, len(req.Messages))
		for _, msg := range req.Messages {
			content := strings.TrimSpace(msg.Content)
			if content == "" {
				continue
			}
			role := strings.ToUpper(strings.TrimSpace(msg.Role))
			if role == "" {
				role = "MESSAGE"
			}
			lines = append(lines, "["+role+"] "+content)
		}
		text = strings.Join(lines, "\n")
	}
	if text == "" {
		return adapters.SearchResponse{}, errors.New("message is required")
	}
	created, err := p.client.Add(ctx, botID, text)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	return adapters.SearchResponse{Results: []adapters.MemoryItem{ovToItem(*created)}}, nil
}
// Search runs semantic search over the bot's memories, clamping the limit to
// [1, ovMaxLimit] with ovDefaultLimit as the fallback.
func (p *OpenVikingProvider) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
	botID := strings.TrimSpace(req.BotID)
	if botID == "" {
		return adapters.SearchResponse{}, errors.New("bot_id is required")
	}
	limit := req.Limit
	switch {
	case limit <= 0:
		limit = ovDefaultLimit
	case limit > ovMaxLimit:
		limit = ovMaxLimit
	}
	memories, err := p.client.Search(ctx, botID, req.Query, limit)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	return adapters.SearchResponse{Results: ovToItems(memories)}, nil
}
// GetAll returns all memories for the bot, most recently updated first.
func (p *OpenVikingProvider) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
	botID := strings.TrimSpace(req.BotID)
	if botID == "" {
		return adapters.SearchResponse{}, errors.New("bot_id is required")
	}
	memories, err := p.client.GetAll(ctx, botID, req.Limit)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	items := ovToItems(memories)
	// Lexicographic sort on the timestamp string — assumes the API returns a
	// sortable format such as RFC 3339; TODO confirm against the server.
	sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt > items[j].UpdatedAt })
	return adapters.SearchResponse{Results: items}, nil
}
// Update replaces the stored text of an existing memory.
func (p *OpenVikingProvider) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
	memoryID := strings.TrimSpace(req.MemoryID)
	if memoryID == "" {
		return adapters.MemoryItem{}, errors.New("memory_id is required")
	}
	updated, err := p.client.Update(ctx, memoryID, req.Memory)
	if err != nil {
		return adapters.MemoryItem{}, err
	}
	return ovToItem(*updated), nil
}
// Delete removes one memory by ID.
//
// Fix: an empty ID previously produced "DELETE /memories/", which the server
// could route as a collection-level delete; it is now rejected up front.
func (p *OpenVikingProvider) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
	id := strings.TrimSpace(memoryID)
	if id == "" {
		return adapters.DeleteResponse{}, errors.New("memory_id is required")
	}
	if err := p.client.Delete(ctx, id); err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "Memory deleted successfully"}, nil
}
// DeleteBatch removes the given memories one by one (the OpenViking API has
// no bulk endpoint), stopping at the first failure.
//
// Fix: blank IDs are skipped instead of being sent as "DELETE /memories/",
// which the server could misroute.
func (p *OpenVikingProvider) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
	for _, raw := range memoryIDs {
		id := strings.TrimSpace(raw)
		if id == "" {
			continue
		}
		if err := p.client.Delete(ctx, id); err != nil {
			return adapters.DeleteResponse{}, err
		}
	}
	return adapters.DeleteResponse{Message: "Memories deleted successfully"}, nil
}
// DeleteAll wipes every memory stored for the bot.
func (p *OpenVikingProvider) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
	botID := strings.TrimSpace(req.BotID)
	if botID == "" {
		return adapters.DeleteResponse{}, errors.New("bot_id is required")
	}
	delErr := p.client.DeleteAll(ctx, botID)
	if delErr != nil {
		return adapters.DeleteResponse{}, delErr
	}
	return adapters.DeleteResponse{Message: "All memories deleted"}, nil
}
// --- Lifecycle ---
// Compact is unsupported: the OpenViking backend manages its own storage.
func (*OpenVikingProvider) Compact(_ context.Context, _ map[string]any, _ float64, _ int) (adapters.CompactResult, error) {
	err := errors.New("compact is not supported by openviking provider")
	return adapters.CompactResult{}, err
}
// Usage is unsupported: the OpenViking API exposes no usage endpoint.
func (*OpenVikingProvider) Usage(_ context.Context, _ map[string]any) (adapters.UsageResponse, error) {
	err := errors.New("usage is not supported by openviking provider")
	return adapters.UsageResponse{}, err
}
// --- helpers ---
// ovToItems converts OpenViking wire memories into adapter items.
func ovToItems(memories []ovMemory) []adapters.MemoryItem {
	out := make([]adapters.MemoryItem, 0, len(memories))
	for i := range memories {
		out = append(out, ovToItem(memories[i]))
	}
	return out
}
// ovToItem maps one OpenViking memory onto the adapter item shape.
func ovToItem(m ovMemory) adapters.MemoryItem {
	item := adapters.MemoryItem{
		ID:     m.ID,
		Memory: m.Content,
		Score:  m.Score,
	}
	item.CreatedAt = m.CreatedAt
	item.UpdatedAt = m.UpdatedAt
	item.Metadata = m.Metadata
	return item
}
@@ -1,4 +1,4 @@
package provider
package adapters
import (
"context"
@@ -7,7 +7,7 @@ import (
)
// Provider is the unified interface for memory systems. Each provider type
// (builtin, mem0, openmemory, etc.) implements this independently with its
// (builtin, mem0, openviking, etc.) implements this independently with its
// own storage, retrieval, and tool logic.
type Provider interface {
// Type returns the provider type identifier (e.g. "builtin", "mem0").
@@ -15,20 +15,12 @@ type Provider interface {
// --- Conversation Hooks ---
// OnBeforeChat is called before sending to the agent gateway.
// It returns memory context to inject into the conversation, or nil if none.
OnBeforeChat(ctx context.Context, req BeforeChatRequest) (*BeforeChatResult, error)
// OnAfterChat is called after receiving the gateway response.
// It extracts facts from the conversation and stores them.
OnAfterChat(ctx context.Context, req AfterChatRequest) error
// --- MCP Tools ---
// ListTools returns MCP tool descriptors provided by this memory provider.
ListTools(ctx context.Context, session mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error)
// CallTool executes an MCP tool owned by this memory provider.
CallTool(ctx context.Context, session mcp.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error)
// --- CRUD ---
@@ -46,3 +38,10 @@ type Provider interface {
Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error)
Usage(ctx context.Context, filters map[string]any) (UsageResponse, error)
}
// SourceSyncProvider is implemented by providers that can report runtime status
// and rebuild derived storage from a canonical source of truth.
type SourceSyncProvider interface {
Status(ctx context.Context, botID string) (MemoryStatusResponse, error)
Rebuild(ctx context.Context, botID string) (RebuildResult, error)
}
@@ -1,4 +1,4 @@
package provider
package adapters
import (
"errors"
+348
View File
@@ -0,0 +1,348 @@
package adapters
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
)
// Service manages memory-provider records (CRUD, status) and keeps the
// runtime registry in sync with them.
type Service struct {
	queries  *sqlc.Queries // DB access layer
	registry *Registry     // runtime provider instances; optional, set via SetRegistry
	logger   *slog.Logger
	cfg      config.Config // used for Qdrant status checks
}
// NewService constructs the memory-provider admin service. The runtime
// registry is attached later via SetRegistry.
func NewService(log *slog.Logger, queries *sqlc.Queries, cfg config.Config) *Service {
	svc := &Service{
		queries: queries,
		logger:  log.With(slog.String("service", "memory_providers")),
		cfg:     cfg,
	}
	return svc
}
// SetRegistry attaches the runtime registry so that CRUD operations can
// instantiate and evict provider instances automatically.
func (s *Service) SetRegistry(registry *Registry) {
	s.registry = registry
}
// ListMeta returns the static catalog of supported provider types together
// with the config-field schema the admin UI renders for each one.
func (*Service) ListMeta(_ context.Context) []ProviderMeta {
	return []ProviderMeta{
		// Built-in provider: local file store, optional Qdrant vector index.
		{
			Provider:    string(ProviderBuiltin),
			DisplayName: "Built-in",
			ConfigSchema: ProviderConfigSchema{
				Fields: map[string]ProviderFieldSchema{
					"memory_mode": {
						Type:        "select",
						Title:       "Memory Mode",
						Description: "off = file-based, sparse = Qdrant sparse vectors, dense = embedding API + Qdrant dense vectors",
						Required:    false,
					},
					"embedding_model_id": {
						Type:        "model_select",
						Title:       "Embedding Model",
						Description: "Embedding model for dense vector search (dense mode only)",
						Required:    false,
					},
					"qdrant_collection": {
						Type:        "string",
						Title:       "Qdrant Collection",
						Description: "Qdrant collection name for sparse mode. Defaults to memory_sparse.",
						Required:    false,
						Example:     "memory_sparse",
					},
				},
			},
		},
		// Mem0 SaaS adapter: remote API keyed by api_key.
		{
			Provider:    string(ProviderMem0),
			DisplayName: "Mem0",
			ConfigSchema: ProviderConfigSchema{
				Fields: map[string]ProviderFieldSchema{
					"base_url": {
						Type:        "string",
						Title:       "Base URL",
						Description: "Mem0 SaaS API base URL. Defaults to https://api.mem0.ai when empty.",
						Required:    false,
						Example:     "https://api.mem0.ai",
					},
					"api_key": {
						Type:        "string",
						Title:       "API Key",
						Description: "API key for Mem0 SaaS authentication",
						Required:    true,
						Secret:      true,
					},
					"org_id": {
						Type:        "string",
						Title:       "Organization ID",
						Description: "Organization ID for Mem0 SaaS workspace context",
					},
					"project_id": {
						Type:        "string",
						Title:       "Project ID",
						Description: "Project ID for Mem0 SaaS workspace context",
					},
				},
			},
		},
		// OpenViking adapter: self-hosted or SaaS HTTP API.
		{
			Provider:    string(ProviderOpenViking),
			DisplayName: "OpenViking",
			ConfigSchema: ProviderConfigSchema{
				Fields: map[string]ProviderFieldSchema{
					"base_url": {
						Type:        "string",
						Title:       "Base URL",
						Description: "OpenViking API base URL (self-hosted or SaaS)",
						Required:    true,
						Example:     "http://openviking:8088",
					},
					"api_key": {
						Type:        "string",
						Title:       "API Key",
						Description: "API key for OpenViking authentication",
						Secret:      true,
					},
				},
			},
		},
	}
}
// Create validates and persists a new provider record, then best-effort
// instantiates its runtime instance.
func (s *Service) Create(ctx context.Context, req ProviderCreateRequest) (ProviderGetResponse, error) {
	if !isValidProviderType(req.Provider) {
		return ProviderGetResponse{}, fmt.Errorf("invalid provider type: %s", req.Provider)
	}
	configJSON, err := json.Marshal(req.Config)
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", err)
	}
	params := sqlc.CreateMemoryProviderParams{
		Name:      strings.TrimSpace(req.Name),
		Provider:  string(req.Provider),
		Config:    configJSON,
		IsDefault: false,
	}
	row, err := s.queries.CreateMemoryProvider(ctx, params)
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("create memory provider: %w", err)
	}
	resp := s.toGetResponse(row)
	s.tryInstantiate(resp.ID, resp.Provider, resp.Config)
	return resp, nil
}
// Get fetches one provider record by its UUID.
func (s *Service) Get(ctx context.Context, id string) (ProviderGetResponse, error) {
	pgID, err := db.ParseUUID(id)
	if err != nil {
		return ProviderGetResponse{}, err
	}
	row, fetchErr := s.queries.GetMemoryProviderByID(ctx, pgID)
	if fetchErr != nil {
		return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", fetchErr)
	}
	return s.toGetResponse(row), nil
}
// Status reports runtime health for a provider record. Only the builtin
// provider exposes collection-level Qdrant detail; other types return just
// the provider type.
//
// Fix: parseQdrantHostPort was called once per collection although the
// endpoint never changes; the loop-invariant parse is hoisted out.
func (s *Service) Status(ctx context.Context, id string) (ProviderStatusResponse, error) {
	resp, err := s.Get(ctx, id)
	if err != nil {
		return ProviderStatusResponse{}, err
	}
	status := ProviderStatusResponse{
		ProviderType: resp.Provider,
	}
	if resp.Provider != string(ProviderBuiltin) {
		return status, nil
	}
	status.MemoryMode = StringFromConfig(resp.Config, "memory_mode")
	status.EmbeddingModelID = StringFromConfig(resp.Config, "embedding_model_id")
	// Same Qdrant endpoint for every collection.
	host, port := parseQdrantHostPort(s.cfg.Qdrant.BaseURL)
	collections := []string{"memory_sparse", "memory_dense"}
	status.Collections = make([]ProviderCollectionStatus, 0, len(collections))
	for _, collection := range collections {
		collStatus := ProviderCollectionStatus{Name: collection}
		client, err := qdrantclient.NewClient(host, port, s.cfg.Qdrant.APIKey, collection)
		if err != nil {
			collStatus.Qdrant.Error = err.Error()
			status.Collections = append(status.Collections, collStatus)
			continue
		}
		exists, err := client.CollectionExists(ctx)
		if err != nil {
			collStatus.Qdrant.Error = err.Error()
			status.Collections = append(status.Collections, collStatus)
			continue
		}
		collStatus.Qdrant.OK = true
		collStatus.Exists = exists
		if exists {
			points, err := client.CountAll(ctx)
			if err != nil {
				collStatus.Qdrant.OK = false
				collStatus.Qdrant.Error = err.Error()
			} else {
				collStatus.Points = points
			}
		}
		status.Collections = append(status.Collections, collStatus)
	}
	return status, nil
}
// List returns every provider record.
func (s *Service) List(ctx context.Context) ([]ProviderGetResponse, error) {
	rows, err := s.queries.ListMemoryProviders(ctx)
	if err != nil {
		return nil, fmt.Errorf("list memory providers: %w", err)
	}
	out := make([]ProviderGetResponse, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.toGetResponse(row))
	}
	return out, nil
}
// Update applies a partial update (name and/or config) to a provider record
// and recycles its runtime instance with the new config.
func (s *Service) Update(ctx context.Context, id string, req ProviderUpdateRequest) (ProviderGetResponse, error) {
	pgID, err := db.ParseUUID(id)
	if err != nil {
		return ProviderGetResponse{}, err
	}
	current, err := s.queries.GetMemoryProviderByID(ctx, pgID)
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
	}
	// Start from the stored values; only overwrite fields present in req.
	name := current.Name
	if req.Name != nil {
		name = strings.TrimSpace(*req.Name)
	}
	config := current.Config
	if req.Config != nil {
		encoded, marshalErr := json.Marshal(req.Config)
		if marshalErr != nil {
			return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", marshalErr)
		}
		config = encoded
	}
	updated, err := s.queries.UpdateMemoryProvider(ctx, sqlc.UpdateMemoryProviderParams{
		ID:     pgID,
		Name:   name,
		Config: config,
	})
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("update memory provider: %w", err)
	}
	resp := s.toGetResponse(updated)
	s.tryEvictAndReinstantiate(resp.ID, resp.Provider, resp.Config)
	return resp, nil
}
// Delete removes a provider record and evicts its runtime instance.
//
// Fix: the delete error is now wrapped with context, matching the error
// style of Create/Get/List/Update.
func (s *Service) Delete(ctx context.Context, id string) error {
	pgID, err := db.ParseUUID(id)
	if err != nil {
		return err
	}
	if err := s.queries.DeleteMemoryProvider(ctx, pgID); err != nil {
		return fmt.Errorf("delete memory provider: %w", err)
	}
	if s.registry != nil {
		s.registry.Remove(id)
	}
	return nil
}
// EnsureDefault creates a default builtin provider if none exists and
// returns it (or the existing default).
//
// Fix: the previous code silently discarded a json.Marshal error with `_`;
// an empty config is simply the literal empty JSON object.
func (s *Service) EnsureDefault(ctx context.Context) (ProviderGetResponse, error) {
	row, err := s.queries.GetDefaultMemoryProvider(ctx)
	if err == nil {
		return s.toGetResponse(row), nil
	}
	created, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
		Name:      "Built-in Memory",
		Provider:  string(ProviderBuiltin),
		Config:    []byte("{}"),
		IsDefault: true,
	})
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("create default memory provider: %w", err)
	}
	// NOTE(review): unlike Create, the default provider is not auto-instantiated
	// here — confirm the registry picks it up during startup wiring.
	return s.toGetResponse(created), nil
}
// toGetResponse converts a DB row into the API response shape, decoding the
// JSON config column; a decode failure is logged and yields a nil config.
func (s *Service) toGetResponse(row sqlc.MemoryProvider) ProviderGetResponse {
	var cfg map[string]any
	if len(row.Config) > 0 {
		if err := json.Unmarshal(row.Config, &cfg); err != nil {
			s.logger.Warn("memory provider config unmarshal failed", slog.String("id", row.ID.String()), slog.Any("error", err))
		}
	}
	resp := ProviderGetResponse{
		ID:        row.ID.String(),
		Name:      row.Name,
		Provider:  row.Provider,
		Config:    cfg,
		IsDefault: row.IsDefault,
		CreatedAt: row.CreatedAt.Time,
		UpdatedAt: row.UpdatedAt.Time,
	}
	return resp
}
// tryInstantiate best-effort creates a runtime instance for the record;
// failures are logged, never returned.
func (s *Service) tryInstantiate(id, providerType string, config map[string]any) {
	if s.registry == nil {
		return
	}
	_, err := s.registry.Instantiate(id, providerType, config)
	if err != nil {
		s.logger.Warn("auto-instantiate memory provider failed",
			slog.String("id", id), slog.String("provider", providerType), slog.Any("error", err))
	}
}
// tryEvictAndReinstantiate drops any cached runtime instance for the record
// and re-creates it with the new config.
func (s *Service) tryEvictAndReinstantiate(id, providerType string, config map[string]any) {
	if s.registry == nil {
		return
	}
	s.registry.Remove(id)
	s.tryInstantiate(id, providerType, config)
}
// isValidProviderType reports whether t names a supported provider type.
func isValidProviderType(t ProviderType) bool {
	return t == ProviderBuiltin || t == ProviderMem0 || t == ProviderOpenViking
}
// parseQdrantHostPort extracts a host/port pair from a configured Qdrant base
// URL. The HTTP port 6333 is mapped to the gRPC port 6334 (which the client
// speaks); any other explicit port is used as-is, and 6334 is the default.
// An empty URL yields ("", 0).
//
// Fix: a path component ("http://host:6333/api" or "host/path") previously
// leaked into the host or silently broke port parsing; it is now stripped
// before splitting host and port.
func parseQdrantHostPort(baseURL string) (string, int) {
	baseURL = strings.TrimSpace(baseURL)
	if baseURL == "" {
		return "", 0
	}
	baseURL = strings.TrimPrefix(baseURL, "http://")
	baseURL = strings.TrimPrefix(baseURL, "https://")
	// Drop any path component so "host:port/path" parses cleanly.
	if authority, _, found := strings.Cut(baseURL, "/"); found {
		baseURL = authority
	}
	host, portStr, hasPort := strings.Cut(baseURL, ":")
	if hasPort {
		if httpPort, err := strconv.Atoi(portStr); err == nil {
			if httpPort == 6333 || httpPort == 6334 {
				return host, 6334
			}
			return host, httpPort
		}
	}
	return host, 6334
}
@@ -1,4 +1,4 @@
package provider
package adapters
import (
"context"
@@ -19,8 +19,11 @@ type BeforeChatResult struct {
// AfterChatRequest is passed to OnAfterChat after receiving the gateway response.
type AfterChatRequest struct {
BotID string
Messages []Message
BotID string
Messages []Message
UserID string
ChannelIdentityID string
DisplayName string
}
// LLM is the interface for LLM operations needed by memory service.
@@ -205,16 +208,37 @@ type UsageResponse struct {
type RebuildResult struct {
FsCount int `json:"fs_count"`
QdrantCount int `json:"qdrant_count"`
StorageCount int `json:"storage_count"`
MissingCount int `json:"missing_count"`
RestoredCount int `json:"restored_count"`
}
type HealthStatus struct {
OK bool `json:"ok"`
Error string `json:"error,omitempty"`
}
type MemoryStatusResponse struct {
ProviderType string `json:"provider_type,omitempty"`
MemoryMode string `json:"memory_mode,omitempty"`
CanManualSync bool `json:"can_manual_sync"`
SourceDir string `json:"source_dir,omitempty"`
OverviewPath string `json:"overview_path,omitempty"`
MarkdownFileCount int `json:"markdown_file_count,omitempty"`
SourceCount int `json:"source_count,omitempty"`
IndexedCount int `json:"indexed_count,omitempty"`
QdrantCollection string `json:"qdrant_collection,omitempty"`
Encoder HealthStatus `json:"encoder"`
Qdrant HealthStatus `json:"qdrant"`
}
// Memory provider admin types.
type ProviderType string
const (
ProviderBuiltin ProviderType = "builtin"
ProviderBuiltin ProviderType = "builtin"
ProviderMem0 ProviderType = "mem0"
ProviderOpenViking ProviderType = "openviking"
)
type ProviderCreateRequest struct {
@@ -247,6 +271,7 @@ type ProviderFieldSchema struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
Secret bool `json:"secret,omitempty"`
Example any `json:"example,omitempty"`
}
@@ -255,3 +280,17 @@ type ProviderMeta struct {
DisplayName string `json:"display_name"`
ConfigSchema ProviderConfigSchema `json:"config_schema"`
}
type ProviderCollectionStatus struct {
Name string `json:"name"`
Exists bool `json:"exists"`
Points int `json:"points"`
Qdrant HealthStatus `json:"qdrant"`
}
type ProviderStatusResponse struct {
ProviderType string `json:"provider_type"`
MemoryMode string `json:"memory_mode,omitempty"`
EmbeddingModelID string `json:"embedding_model_id,omitempty"`
Collections []ProviderCollectionStatus `json:"collections,omitempty"`
}
-179
View File
@@ -1,179 +0,0 @@
package provider
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
// Service manages memory provider records in the database.
type Service struct {
	queries *sqlc.Queries
	logger  *slog.Logger
}

// NewService constructs a provider Service backed by the given sqlc queries.
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
	return &Service{
		queries: queries,
		logger:  log.With(slog.String("service", "memory_providers")),
	}
}

// ListMeta returns static metadata (display name and config schema) for
// every provider type this service knows how to configure. Only the
// builtin provider is described here.
func (*Service) ListMeta(_ context.Context) []ProviderMeta {
	return []ProviderMeta{
		{
			Provider:    string(ProviderBuiltin),
			DisplayName: "Built-in",
			ConfigSchema: ProviderConfigSchema{
				Fields: map[string]ProviderFieldSchema{
					"memory_model_id": {
						Type:        "model_select",
						Title:       "Memory Model",
						Description: "LLM model used for memory extraction and decision",
						Required:    false,
					},
					"embedding_model_id": {
						Type:        "model_select",
						Title:       "Embedding Model",
						Description: "Embedding model for dense vector search",
						Required:    false,
					},
				},
			},
		},
	}
}
// Create persists a new memory provider record. The provider type must be
// one this service accepts (see isValidProviderType); the config map is
// stored as JSON. Newly created providers are never marked as default.
func (s *Service) Create(ctx context.Context, req ProviderCreateRequest) (ProviderGetResponse, error) {
	if !isValidProviderType(req.Provider) {
		return ProviderGetResponse{}, fmt.Errorf("invalid provider type: %s", req.Provider)
	}
	configJSON, err := json.Marshal(req.Config)
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", err)
	}
	row, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
		Name:      strings.TrimSpace(req.Name),
		Provider:  string(req.Provider),
		Config:    configJSON,
		IsDefault: false,
	})
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("create memory provider: %w", err)
	}
	return s.toGetResponse(row), nil
}

// Get fetches a single provider record by its UUID string.
func (s *Service) Get(ctx context.Context, id string) (ProviderGetResponse, error) {
	pgID, err := db.ParseUUID(id)
	if err != nil {
		return ProviderGetResponse{}, err
	}
	row, err := s.queries.GetMemoryProviderByID(ctx, pgID)
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
	}
	return s.toGetResponse(row), nil
}

// List returns all configured memory providers.
func (s *Service) List(ctx context.Context) ([]ProviderGetResponse, error) {
	rows, err := s.queries.ListMemoryProviders(ctx)
	if err != nil {
		return nil, fmt.Errorf("list memory providers: %w", err)
	}
	items := make([]ProviderGetResponse, 0, len(rows))
	for _, row := range rows {
		items = append(items, s.toGetResponse(row))
	}
	return items, nil
}
// Update applies a partial update to a provider record: only fields set in
// req (non-nil Name/Config) are changed; unset fields keep their current
// database values.
func (s *Service) Update(ctx context.Context, id string, req ProviderUpdateRequest) (ProviderGetResponse, error) {
	pgID, err := db.ParseUUID(id)
	if err != nil {
		return ProviderGetResponse{}, err
	}
	// Load the current row so untouched fields can be carried over.
	current, err := s.queries.GetMemoryProviderByID(ctx, pgID)
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
	}
	name := current.Name
	if req.Name != nil {
		name = strings.TrimSpace(*req.Name)
	}
	config := current.Config
	if req.Config != nil {
		configJSON, marshalErr := json.Marshal(req.Config)
		if marshalErr != nil {
			return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", marshalErr)
		}
		config = configJSON
	}
	updated, err := s.queries.UpdateMemoryProvider(ctx, sqlc.UpdateMemoryProviderParams{
		ID:     pgID,
		Name:   name,
		Config: config,
	})
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("update memory provider: %w", err)
	}
	return s.toGetResponse(updated), nil
}

// Delete removes a provider record by its UUID string.
func (s *Service) Delete(ctx context.Context, id string) error {
	pgID, err := db.ParseUUID(id)
	if err != nil {
		return err
	}
	return s.queries.DeleteMemoryProvider(ctx, pgID)
}

// EnsureDefault creates a default builtin provider if none exists.
func (s *Service) EnsureDefault(ctx context.Context) (ProviderGetResponse, error) {
	// If a default already exists, return it as-is.
	row, err := s.queries.GetDefaultMemoryProvider(ctx)
	if err == nil {
		return s.toGetResponse(row), nil
	}
	// Marshal of an empty map cannot fail, so the error is ignored.
	configJSON, _ := json.Marshal(map[string]any{})
	created, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
		Name:      "Built-in Memory",
		Provider:  string(ProviderBuiltin),
		Config:    configJSON,
		IsDefault: true,
	})
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("create default memory provider: %w", err)
	}
	return s.toGetResponse(created), nil
}
// toGetResponse maps a database row to the API response shape.
// A config that fails to unmarshal is logged and returned as nil rather
// than failing the whole request.
func (s *Service) toGetResponse(row sqlc.MemoryProvider) ProviderGetResponse {
	var cfg map[string]any
	if len(row.Config) > 0 {
		if err := json.Unmarshal(row.Config, &cfg); err != nil {
			s.logger.Warn("memory provider config unmarshal failed", slog.String("id", row.ID.String()), slog.Any("error", err))
		}
	}
	return ProviderGetResponse{
		ID:        row.ID.String(),
		Name:      row.Name,
		Provider:  row.Provider,
		Config:    cfg,
		IsDefault: row.IsDefault,
		CreatedAt: row.CreatedAt.Time,
		UpdatedAt: row.UpdatedAt.Time,
	}
}
// isValidProviderType reports whether t is a provider type this service
// can persist. Only the builtin provider is accepted here.
func isValidProviderType(t ProviderType) bool {
	return t == ProviderBuiltin
}
+423
View File
@@ -0,0 +1,423 @@
// Package qdrant wraps the official github.com/qdrant/go-client SDK,
// providing a thin facade for sparse-vector memory operations.
package qdrant
import (
"context"
"fmt"
"math"
"strconv"
"strings"
pb "github.com/qdrant/go-client/qdrant"
)
const (
	// sparseVectorName is the named vector slot used for sparse embeddings.
	sparseVectorName = "sparse"
)

// Client wraps the official Qdrant gRPC client with sparse-memory-specific helpers.
type Client struct {
	inner      *pb.Client
	collection string
}

// NewClient creates a Qdrant client connected via gRPC.
// host should be a bare hostname/IP; port is the gRPC port (default 6334).
func NewClient(host string, port int, apiKey, collection string) (*Client, error) {
	// Fall back to sensible defaults for any unset connection parameter.
	if host == "" {
		host = "localhost"
	}
	if port == 0 {
		port = 6334
	}
	if collection == "" {
		collection = "memory"
	}
	cfg := &pb.Config{
		Host: host,
		Port: port,
	}
	if apiKey != "" {
		cfg.APIKey = apiKey
	}
	inner, err := pb.NewClient(cfg)
	if err != nil {
		return nil, fmt.Errorf("qdrant: connect: %w", err)
	}
	return &Client{inner: inner, collection: collection}, nil
}

// Close closes the underlying gRPC connection.
func (c *Client) Close() error {
	return c.inner.Close()
}

// CollectionName returns the collection this client operates on.
func (c *Client) CollectionName() string {
	return c.collection
}

// CollectionExists reports whether the configured collection exists in Qdrant.
func (c *Client) CollectionExists(ctx context.Context) (bool, error) {
	exists, err := c.inner.CollectionExists(ctx, c.collection)
	if err != nil {
		return false, fmt.Errorf("qdrant: check collection: %w", err)
	}
	return exists, nil
}
// EnsureCollection creates the collection with a named sparse vector config if it does not exist.
// Idempotent: an existing collection is left untouched.
func (c *Client) EnsureCollection(ctx context.Context) error {
	exists, err := c.CollectionExists(ctx)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}
	err = c.inner.CreateCollection(ctx, &pb.CreateCollection{
		CollectionName: c.collection,
		SparseVectorsConfig: pb.NewSparseVectorsConfig(map[string]*pb.SparseVectorParams{
			// Empty params: Qdrant defaults for the sparse index.
			sparseVectorName: {},
		}),
	})
	if err != nil {
		return fmt.Errorf("qdrant: create collection: %w", err)
	}
	return nil
}

// EnsureDenseCollection creates the collection with dense vector config if it
// does not exist. dimensions is the embedding size and must be positive;
// cosine distance is used.
func (c *Client) EnsureDenseCollection(ctx context.Context, dimensions int) error {
	if dimensions <= 0 {
		return fmt.Errorf("qdrant: dense dimensions must be positive, got %d", dimensions)
	}
	exists, err := c.CollectionExists(ctx)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}
	err = c.inner.CreateCollection(ctx, &pb.CreateCollection{
		CollectionName: c.collection,
		VectorsConfig: pb.NewVectorsConfig(&pb.VectorParams{
			Size:     uint64(dimensions),
			Distance: pb.Distance_Cosine,
		}),
	})
	if err != nil {
		return fmt.Errorf("qdrant: create dense collection: %w", err)
	}
	return nil
}
// SparseVector holds the non-zero components of a sparse text encoding.
// Indices and Values are parallel slices: Values[i] is the weight at
// vocabulary position Indices[i].
type SparseVector struct {
	Indices []uint32
	Values  []float32
}

// DenseVector holds a dense embedding vector.
type DenseVector struct {
	Values []float32
}

// SearchResult is one result from a sparse search or scroll.
// Score is zero for scroll/get results (no similarity is computed).
type SearchResult struct {
	ID      string
	Score   float64
	Payload map[string]string
}
// Upsert inserts or updates points with named sparse vectors.
// The call waits for the write to be applied before returning.
func (c *Client) Upsert(ctx context.Context, id string, vec SparseVector, payload map[string]string) error {
	wait := true
	_, err := c.inner.Upsert(ctx, &pb.UpsertPoints{
		CollectionName: c.collection,
		Wait:           &wait,
		Points: []*pb.PointStruct{
			{
				Id: pb.NewID(id),
				Vectors: pb.NewVectorsMap(map[string]*pb.Vector{
					sparseVectorName: {
						Data:    vec.Values,
						Indices: &pb.SparseIndices{Data: vec.Indices},
					},
				}),
				Payload: stringPayloadToValueMap(payload),
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant: upsert: %w", err)
	}
	return nil
}

// UpsertDense inserts or updates points with dense vectors.
// The call waits for the write to be applied before returning.
func (c *Client) UpsertDense(ctx context.Context, id string, vec DenseVector, payload map[string]string) error {
	wait := true
	_, err := c.inner.Upsert(ctx, &pb.UpsertPoints{
		CollectionName: c.collection,
		Wait:           &wait,
		Points: []*pb.PointStruct{
			{
				Id:      pb.NewID(id),
				Vectors: pb.NewVectorsDense(vec.Values),
				Payload: stringPayloadToValueMap(payload),
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant: upsert dense: %w", err)
	}
	return nil
}
// Search performs a sparse-vector query against the collection, filtered by bot_id.
// A non-positive limit defaults to 10.
func (c *Client) Search(ctx context.Context, vec SparseVector, botID string, limit int) ([]SearchResult, error) {
	if limit <= 0 {
		limit = 10
	}
	queryLimit, err := intToUint64(limit)
	if err != nil {
		return nil, fmt.Errorf("qdrant: invalid search limit: %w", err)
	}
	scored, err := c.inner.Query(ctx, &pb.QueryPoints{
		CollectionName: c.collection,
		Query:          pb.NewQuerySparse(vec.Indices, vec.Values),
		Using:          strPtr(sparseVectorName), // query the named sparse vector slot
		Filter:         botFilter(botID),
		Limit:          uint64Ptr(queryLimit),
		WithPayload:    pb.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant: search: %w", err)
	}
	return scoredPointsToResults(scored), nil
}

// SearchDense performs a dense-vector query against the collection, filtered by bot_id.
// A non-positive limit defaults to 10.
func (c *Client) SearchDense(ctx context.Context, vec DenseVector, botID string, limit int) ([]SearchResult, error) {
	if limit <= 0 {
		limit = 10
	}
	queryLimit, err := intToUint64(limit)
	if err != nil {
		return nil, fmt.Errorf("qdrant: invalid dense search limit: %w", err)
	}
	scored, err := c.inner.Query(ctx, &pb.QueryPoints{
		CollectionName: c.collection,
		Query:          pb.NewQueryDense(vec.Values),
		Filter:         botFilter(botID),
		Limit:          uint64Ptr(queryLimit),
		WithPayload:    pb.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant: dense search: %w", err)
	}
	return scoredPointsToResults(scored), nil
}
// GetByID fetches a single point by UUID.
// Returns (nil, nil) when the point does not exist.
func (c *Client) GetByID(ctx context.Context, id string) (*SearchResult, error) {
	points, err := c.inner.Get(ctx, &pb.GetPoints{
		CollectionName: c.collection,
		Ids:            []*pb.PointId{pb.NewID(id)},
		WithPayload:    pb.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant: get: %w", err)
	}
	if len(points) == 0 {
		return nil, nil
	}
	r := retrievedPointToResult(points[0])
	return &r, nil
}

// Scroll returns all points matching bot_id, up to limit.
// A non-positive limit defaults to 1000; limits above uint32 range are clamped.
func (c *Client) Scroll(ctx context.Context, botID string, limit int) ([]SearchResult, error) {
	if limit <= 0 {
		limit = 1000
	}
	if limit > math.MaxUint32 {
		limit = math.MaxUint32
	}
	l, err := intToUint32(limit)
	if err != nil {
		return nil, fmt.Errorf("qdrant: invalid scroll limit: %w", err)
	}
	points, err := c.inner.Scroll(ctx, &pb.ScrollPoints{
		CollectionName: c.collection,
		Filter:         botFilter(botID),
		Limit:          &l,
		WithPayload:    pb.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant: scroll: %w", err)
	}
	results := make([]SearchResult, 0, len(points))
	for _, p := range points {
		results = append(results, retrievedPointToResult(p))
	}
	return results, nil
}
// Count returns the number of points matching bot_id (exact count, not estimated).
func (c *Client) Count(ctx context.Context, botID string) (int, error) {
	exact := true
	n, err := c.inner.Count(ctx, &pb.CountPoints{
		CollectionName: c.collection,
		Filter:         botFilter(botID),
		Exact:          &exact,
	})
	if err != nil {
		return 0, fmt.Errorf("qdrant: count: %w", err)
	}
	// Guard the uint64 -> int narrowing on 32-bit platforms.
	if n > uint64(math.MaxInt) {
		return 0, fmt.Errorf("qdrant: count overflow: %d", n)
	}
	return int(n), nil
}

// CountAll returns the total number of points in the collection (exact count).
func (c *Client) CountAll(ctx context.Context) (int, error) {
	exact := true
	n, err := c.inner.Count(ctx, &pb.CountPoints{
		CollectionName: c.collection,
		Exact:          &exact,
	})
	if err != nil {
		return 0, fmt.Errorf("qdrant: count all: %w", err)
	}
	// Guard the uint64 -> int narrowing on 32-bit platforms.
	if n > uint64(math.MaxInt) {
		return 0, fmt.Errorf("qdrant: count overflow: %d", n)
	}
	return int(n), nil
}
// DeleteByIDs removes specific points by their UUID strings.
// Blank or whitespace-only IDs are skipped; if no valid IDs remain, the
// call is a no-op and no request is sent to Qdrant.
func (c *Client) DeleteByIDs(ctx context.Context, ids []string) error {
	pointIDs := make([]*pb.PointId, 0, len(ids))
	for _, id := range ids {
		// Trim once and reuse the result (original trimmed twice per id).
		if trimmed := strings.TrimSpace(id); trimmed != "" {
			pointIDs = append(pointIDs, pb.NewID(trimmed))
		}
	}
	if len(pointIDs) == 0 {
		return nil
	}
	wait := true
	_, err := c.inner.Delete(ctx, &pb.DeletePoints{
		CollectionName: c.collection,
		Wait:           &wait,
		Points: &pb.PointsSelector{
			PointsSelectorOneOf: &pb.PointsSelector_Points{
				Points: &pb.PointsIdsList{Ids: pointIDs},
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant: delete by ids: %w", err)
	}
	return nil
}
// DeleteByBotID removes all points for a given bot_id.
// The call waits for the delete to be applied before returning.
func (c *Client) DeleteByBotID(ctx context.Context, botID string) error {
	wait := true
	_, err := c.inner.Delete(ctx, &pb.DeletePoints{
		CollectionName: c.collection,
		Wait:           &wait,
		Points: &pb.PointsSelector{
			PointsSelectorOneOf: &pb.PointsSelector_Filter{
				Filter: botFilter(botID),
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant: delete by bot_id: %w", err)
	}
	return nil
}
// --- helpers ---

// botFilter builds a filter matching points whose bot_id payload equals botID.
func botFilter(botID string) *pb.Filter {
	return &pb.Filter{
		Must: []*pb.Condition{
			pb.NewMatch("bot_id", botID),
		},
	}
}

// stringPayloadToValueMap converts a string map to the pb payload value map.
func stringPayloadToValueMap(payload map[string]string) map[string]*pb.Value {
	m := make(map[string]*pb.Value, len(payload))
	for k, v := range payload {
		m[k] = pb.NewValueString(v)
	}
	return m
}

// valueMapToStringPayload extracts string-valued payload entries.
// Non-string and empty-string values are dropped; nil is returned for an
// empty input map.
func valueMapToStringPayload(m map[string]*pb.Value) map[string]string {
	if len(m) == 0 {
		return nil
	}
	out := make(map[string]string, len(m))
	for k, v := range m {
		if v != nil {
			if sv := v.GetStringValue(); sv != "" {
				out[k] = sv
			}
		}
	}
	return out
}

// scoredPointsToResults converts scored query hits into SearchResults.
func scoredPointsToResults(scored []*pb.ScoredPoint) []SearchResult {
	results := make([]SearchResult, 0, len(scored))
	for _, p := range scored {
		results = append(results, SearchResult{
			ID:      extractID(p.GetId()),
			Score:   float64(p.GetScore()),
			Payload: valueMapToStringPayload(p.GetPayload()),
		})
	}
	return results
}

// retrievedPointToResult converts a retrieved (unscored) point into a SearchResult.
func retrievedPointToResult(p *pb.RetrievedPoint) SearchResult {
	return SearchResult{
		ID:      extractID(p.GetId()),
		Payload: valueMapToStringPayload(p.GetPayload()),
	}
}

// extractID renders a point ID as a string: the UUID when set, otherwise
// the numeric ID in decimal, or "" for a nil ID.
func extractID(id *pb.PointId) string {
	if id == nil {
		return ""
	}
	if uuid := id.GetUuid(); uuid != "" {
		return uuid
	}
	return strconv.FormatUint(id.GetNum(), 10)
}

// strPtr returns a pointer to s (for optional pb string fields).
func strPtr(s string) *string { return &s }

// uint64Ptr returns a pointer to v (for optional pb uint64 fields).
func uint64Ptr(v uint64) *uint64 { return &v }
// intToUint64 converts a non-negative int to uint64, returning an error
// for negative values. (Replaces the previous strconv.Itoa/ParseUint
// round-trip, which allocated a string just to bounds-check an integer.)
func intToUint64(v int) (uint64, error) {
	if v < 0 {
		return 0, fmt.Errorf("value %d out of uint64 range", v)
	}
	return uint64(v), nil
}

// intToUint32 converts an int to uint32, returning an error when the value
// is negative or exceeds math.MaxUint32.
func intToUint32(v int) (uint32, error) {
	if v < 0 || uint64(v) > math.MaxUint32 {
		return 0, fmt.Errorf("value %d out of uint32 range", v)
	}
	return uint32(v), nil
}
-366
View File
@@ -1,366 +0,0 @@
package memory
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"path"
"sort"
"strings"
"time"
"github.com/memohai/memoh/internal/config"
)
const (
	// memoryDateLayout is the file-name date format for daily memory files.
	memoryDateLayout = "2006-01-02"
	// Entry delimiters embedded as HTML comments in day markdown files.
	memEntryStartPrefix = "<!-- MEMOH:ENTRY "
	memEntryStartSuffix = " -->"
	memEntryEndMarker   = "<!-- /MEMOH:ENTRY -->"
	// memFileHeaderTemplate formats the "# Memory <date>" file header.
	memFileHeaderTemplate = "# Memory %s\n\n"
)

// writeRecord is the normalized shape of one memory entry. Memory is the
// canonical body; Text and Content are accepted as input aliases only.
type writeRecord struct {
	Topic     string `json:"topic"`
	ID        string `json:"id"`
	Memory    string `json:"memory"`
	Text      string `json:"text"`
	Content   string `json:"content"`
	Hash      string `json:"hash"`
	CreatedAt string `json:"created_at"`
	UpdatedAt string `json:"updated_at"`
}

// NormalizeMemoryDayContent converts user/LLM writes to canonical memory day format.
// Non-memory-day paths are returned unchanged. Content that already carries
// canonical entry markers is passed through untouched; structured JSON input
// is parsed into records, and anything else becomes a single fallback record.
func NormalizeMemoryDayContent(containerPath, raw string) string {
	if !isMemoryDayMarkdownPath(containerPath) {
		return raw
	}
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return raw
	}
	// Already in canonical form: do not re-wrap.
	if strings.Contains(trimmed, memEntryStartPrefix) && strings.Contains(trimmed, memEntryEndMarker) {
		return raw
	}
	date := strings.TrimSuffix(path.Base(containerPath), ".md")
	records := parseStructuredRecords(trimmed)
	if len(records) == 0 {
		records = []writeRecord{buildFallbackRecord(trimmed, date, time.Now().UTC())}
	}
	return formatDayMarkdown(date, records)
}
// RenderMemoryDayForDisplay converts canonical memory day markdown into
// a user-facing timeline view. Non-memory-day paths are returned unchanged,
// as is content with no parseable canonical entries. Entries are ordered by
// timestamp (ties broken by ID).
func RenderMemoryDayForDisplay(containerPath, raw string) string {
	if !isMemoryDayMarkdownPath(containerPath) {
		return raw
	}
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return raw
	}
	date := strings.TrimSuffix(path.Base(containerPath), ".md")
	records := parseCanonicalDayRecords(trimmed)
	if len(records) == 0 {
		return raw
	}
	sort.Slice(records, func(i, j int) bool {
		ti := recordTime(records[i])
		tj := recordTime(records[j])
		if ti.Equal(tj) {
			return records[i].ID < records[j].ID
		}
		return ti.Before(tj)
	})
	var b strings.Builder
	b.WriteString("# ")
	b.WriteString(date)
	b.WriteString("\n\n")
	for idx, r := range records {
		if idx > 0 {
			b.WriteString("\n")
		}
		// One "## HH:MM AM/PM - Topic" section per record.
		b.WriteString("## ")
		b.WriteString(formatRecordTime(r))
		b.WriteString(" - ")
		b.WriteString(recordTitle(r))
		b.WriteString("\n")
		b.WriteString(formatRecordBody(r.Memory))
		b.WriteString("\n")
	}
	return strings.TrimSpace(b.String())
}

// isMemoryDayMarkdownPath reports whether containerPath is a daily memory
// markdown file: it must live under <data mount>/memory/ and its base name
// must be a YYYY-MM-DD date with a .md extension.
func isMemoryDayMarkdownPath(containerPath string) bool {
	clean := path.Clean("/" + strings.TrimSpace(containerPath))
	memoryDir := path.Clean(config.DefaultDataMount+"/memory") + "/"
	if !strings.HasPrefix(clean, memoryDir) || !strings.HasSuffix(clean, ".md") {
		return false
	}
	datePart := strings.TrimSuffix(path.Base(clean), ".md")
	_, err := time.Parse(memoryDateLayout, datePart)
	return err == nil
}
// parseStructuredRecords attempts to parse content as JSON in one of three
// accepted shapes: a record array, a single record object, or an object
// wrapping records under "items". Records failing normalization are dropped.
// Returns nil when content is not JSON in any accepted shape.
func parseStructuredRecords(content string) []writeRecord {
	now := time.Now().UTC()
	normalize := func(in []writeRecord) []writeRecord {
		out := make([]writeRecord, 0, len(in))
		for _, r := range in {
			nr, ok := normalizeRecord(r, now)
			if ok {
				out = append(out, nr)
			}
		}
		return out
	}
	var list []writeRecord
	if err := json.Unmarshal([]byte(content), &list); err == nil {
		return normalize(list)
	}
	var obj writeRecord
	if err := json.Unmarshal([]byte(content), &obj); err == nil {
		return normalize([]writeRecord{obj})
	}
	var wrapped struct {
		Items []writeRecord `json:"items"`
	}
	if err := json.Unmarshal([]byte(content), &wrapped); err == nil {
		return normalize(wrapped.Items)
	}
	return nil
}

// parseCanonicalDayRecords extracts records from canonical day markdown:
// each entry is a JSON metadata line between memEntryStartPrefix/Suffix,
// followed by body lines up to memEntryEndMarker. Entries with malformed
// metadata or a missing end marker are skipped.
func parseCanonicalDayRecords(content string) []writeRecord {
	lines := strings.Split(content, "\n")
	out := make([]writeRecord, 0, 8)
	for i := 0; i < len(lines); i++ {
		line := strings.TrimSpace(lines[i])
		if !strings.HasPrefix(line, memEntryStartPrefix) || !strings.HasSuffix(line, memEntryStartSuffix) {
			continue
		}
		metaJSON := strings.TrimSuffix(strings.TrimPrefix(line, memEntryStartPrefix), memEntryStartSuffix)
		var rec writeRecord
		if err := json.Unmarshal([]byte(metaJSON), &rec); err != nil {
			continue
		}
		// Collect body lines until the entry end marker.
		start := i + 1
		end := start
		for ; end < len(lines); end++ {
			if strings.TrimSpace(lines[end]) == memEntryEndMarker {
				break
			}
		}
		// Unterminated entry: discard and stop scanning.
		if end >= len(lines) {
			break
		}
		rec.Memory = strings.TrimSpace(strings.Join(lines[start:end], "\n"))
		out = append(out, rec)
		i = end
	}
	return out
}
// formatDayMarkdown renders records as canonical day markdown: a
// "# Memory <date>" header followed by one delimited entry per record,
// sorted by created_at (ties broken by ID). Entry metadata is embedded as a
// JSON object inside the start marker; empty metadata fields are omitted.
func formatDayMarkdown(date string, records []writeRecord) string {
	sort.Slice(records, func(i, j int) bool {
		ti := parseRFC3339OrZero(records[i].CreatedAt)
		tj := parseRFC3339OrZero(records[j].CreatedAt)
		if ti.Equal(tj) {
			return records[i].ID < records[j].ID
		}
		return ti.Before(tj)
	})
	var b strings.Builder
	fmt.Fprintf(&b, memFileHeaderTemplate, date)
	for _, r := range records {
		meta := map[string]string{"id": r.ID}
		if r.Topic != "" {
			meta["topic"] = r.Topic
		}
		if r.Hash != "" {
			meta["hash"] = r.Hash
		}
		if r.CreatedAt != "" {
			meta["created_at"] = r.CreatedAt
		}
		if r.UpdatedAt != "" {
			meta["updated_at"] = r.UpdatedAt
		}
		// Marshal of map[string]string cannot fail, so the error is ignored.
		rawMeta, _ := json.Marshal(meta)
		b.WriteString(memEntryStartPrefix)
		b.Write(rawMeta)
		b.WriteString(memEntryStartSuffix)
		b.WriteString("\n")
		b.WriteString(r.Memory)
		b.WriteString("\n")
		b.WriteString(memEntryEndMarker)
		b.WriteString("\n\n")
	}
	return b.String()
}
func parseRFC3339OrZero(raw string) time.Time {
raw = strings.TrimSpace(raw)
if raw == "" {
return time.Time{}
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return time.Time{}
}
return t.UTC()
}
// recordTime returns the best-known timestamp for r: created_at when
// parseable, otherwise updated_at, otherwise the zero time.
func recordTime(r writeRecord) time.Time {
	for _, raw := range []string{r.CreatedAt, r.UpdatedAt} {
		if t := parseRFC3339OrZero(raw); !t.IsZero() {
			return t
		}
	}
	return time.Time{}
}

// formatRecordTime renders r's timestamp as a 12-hour clock string
// ("09:40 AM"), or "--:--" when no timestamp is available.
func formatRecordTime(r writeRecord) string {
	if t := recordTime(r); !t.IsZero() {
		return t.Format("03:04 PM")
	}
	return "--:--"
}

// recordTitle returns the record's trimmed topic, defaulting to "Notes".
func recordTitle(r writeRecord) string {
	topic := strings.TrimSpace(r.Topic)
	if topic == "" {
		return "Notes"
	}
	return topic
}
// formatRecordBody normalizes a memory body into a markdown bullet list.
// Blank lines are dropped; lines that already look like list items
// ("- ", "* ", "1. ") are kept as-is, everything else is prefixed with
// "- ". An effectively empty body renders as "- (empty)".
func formatRecordBody(body string) string {
	var bullets []string
	for _, raw := range strings.Split(strings.TrimSpace(body), "\n") {
		line := strings.TrimSpace(raw)
		switch {
		case line == "":
			// skip blank lines
		case strings.HasPrefix(line, "- "),
			strings.HasPrefix(line, "* "),
			strings.HasPrefix(line, "1. "):
			bullets = append(bullets, line)
		default:
			bullets = append(bullets, "- "+line)
		}
	}
	if len(bullets) == 0 {
		return "- (empty)"
	}
	return strings.Join(bullets, "\n")
}
// buildFallbackRecord wraps unstructured content in a canonical writeRecord.
// A legacy frontmatter document is converted first (reusing its id, hash,
// and timestamps); otherwise the whole content becomes the memory body with
// a generated id, content hash, and timestamps derived from now.
//
// The original built the fallback record before even attempting the legacy
// path and guarded hash generation with an always-true check; the legacy
// check now runs first and the hash is set unconditionally.
func buildFallbackRecord(content, date string, now time.Time) writeRecord {
	if legacy, ok := parseLegacyFrontmatterRecord(content); ok {
		if normalized, ok := normalizeRecord(legacy, now); ok {
			return normalized
		}
	}
	body := sanitizeFallbackBody(content, date)
	return writeRecord{
		ID:        fmt.Sprintf("mem_%d", now.UnixNano()),
		Memory:    body,
		Hash:      generateMemoryHash("", body),
		CreatedAt: now.Format(time.RFC3339),
		UpdatedAt: now.Format(time.RFC3339),
	}
}
// parseLegacyFrontmatterRecord parses the old "---"-delimited frontmatter
// format: "key: value" lines between two "---" fences, followed by the body.
// Returns false when content is not frontmatter-shaped.
// NOTE(review): only id/hash/created_at/updated_at keys are recognized —
// a legacy "topic" key, if one ever existed, is dropped here; confirm.
func parseLegacyFrontmatterRecord(content string) (writeRecord, bool) {
	trimmed := strings.TrimSpace(content)
	if !strings.HasPrefix(trimmed, "---") {
		return writeRecord{}, false
	}
	parts := strings.SplitN(trimmed[3:], "---", 2)
	if len(parts) < 2 {
		return writeRecord{}, false
	}
	frontmatter := strings.TrimSpace(parts[0])
	body := strings.TrimSpace(parts[1])
	record := writeRecord{Memory: body}
	for _, line := range strings.Split(frontmatter, "\n") {
		// Cut at the first colon; values may themselves contain colons
		// (e.g. RFC 3339 timestamps).
		key, value, found := strings.Cut(strings.TrimSpace(line), ":")
		if !found {
			continue
		}
		key = strings.TrimSpace(key)
		value = strings.TrimSpace(value)
		switch key {
		case "id":
			record.ID = value
		case "hash":
			record.Hash = value
		case "created_at":
			record.CreatedAt = value
		case "updated_at":
			record.UpdatedAt = value
		}
	}
	return record, true
}
// sanitizeFallbackBody strips a leading "# Memory <date>" header, if
// present, and returns the trimmed remainder.
func sanitizeFallbackBody(content, date string) string {
	body := strings.TrimSpace(content)
	header := "# Memory " + strings.TrimSpace(date)
	if rest, ok := strings.CutPrefix(body, header); ok {
		return strings.TrimSpace(rest)
	}
	return body
}
// normalizeRecord fills in a record's canonical fields: the body is taken
// from memory, content, or text (in that priority order); missing id,
// timestamps, and hash are generated from now and the body. Returns false
// when no non-empty body can be found. Text/Content are cleared in the
// result — only Memory carries the body after normalization.
func normalizeRecord(r writeRecord, now time.Time) (writeRecord, bool) {
	mem := strings.TrimSpace(r.Memory)
	if mem == "" {
		mem = strings.TrimSpace(r.Content)
	}
	if mem == "" {
		mem = strings.TrimSpace(r.Text)
	}
	if mem == "" {
		return writeRecord{}, false
	}
	topic := strings.TrimSpace(r.Topic)
	id := strings.TrimSpace(r.ID)
	if id == "" {
		id = fmt.Sprintf("mem_%d", now.UnixNano())
	}
	createdAt := strings.TrimSpace(r.CreatedAt)
	if createdAt == "" {
		createdAt = now.Format(time.RFC3339)
	}
	// A record updated never earlier than it was created.
	updatedAt := strings.TrimSpace(r.UpdatedAt)
	if updatedAt == "" {
		updatedAt = createdAt
	}
	hash := strings.TrimSpace(r.Hash)
	if hash == "" {
		hash = generateMemoryHash(topic, mem)
	}
	return writeRecord{
		Topic:     topic,
		ID:        id,
		Memory:    mem,
		Hash:      hash,
		CreatedAt: createdAt,
		UpdatedAt: updatedAt,
	}, true
}
// generateMemoryHash derives a stable content hash for a memory entry.
// Topic and memory are trimmed, joined with a newline, hashed with
// SHA-256, and returned hex-encoded (64 characters).
func generateMemoryHash(topic, memory string) string {
	h := sha256.New()
	h.Write([]byte(strings.TrimSpace(topic)))
	h.Write([]byte("\n"))
	h.Write([]byte(strings.TrimSpace(memory)))
	return hex.EncodeToString(h.Sum(nil))
}
-114
View File
@@ -1,114 +0,0 @@
package memory
import (
"regexp"
"strings"
"testing"
)
// Structured JSON input must be wrapped in canonical entry markers with
// topic/hash metadata and the day header.
func TestNormalizeMemoryDayContent_StructuredJSON(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	input := `[
{
"topic": "Decision",
"memory": "Choose provider architecture."
}
]`
	out := NormalizeMemoryDayContent(path, input)
	if !strings.Contains(out, "# Memory 2026-03-01") {
		t.Fatalf("expected day header, got: %s", out)
	}
	if !strings.Contains(out, `<!-- MEMOH:ENTRY `) {
		t.Fatalf("expected entry marker, got: %s", out)
	}
	if !strings.Contains(out, `"topic":"Decision"`) {
		t.Fatalf("expected topic metadata, got: %s", out)
	}
	if !strings.Contains(out, "Choose provider architecture.") {
		t.Fatalf("expected memory body, got: %s", out)
	}
	if !strings.Contains(out, `"hash":"`) {
		t.Fatalf("expected generated hash metadata, got: %s", out)
	}
}

// Plain text that is not JSON must become a single fallback record with
// generated id and timestamps.
func TestNormalizeMemoryDayContent_FallbackPlainText(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	input := "Unstructured note from model output."
	out := NormalizeMemoryDayContent(path, input)
	if !strings.Contains(out, "# Memory 2026-03-01") {
		t.Fatalf("expected day header, got: %s", out)
	}
	if !strings.Contains(out, "Unstructured note from model output.") {
		t.Fatalf("expected original text preserved, got: %s", out)
	}
	if !strings.Contains(out, `"created_at":"`) || !strings.Contains(out, `"updated_at":"`) {
		t.Fatalf("expected timestamps, got: %s", out)
	}
	if !regexp.MustCompile(`"id":"mem_\d+"`).MatchString(out) {
		t.Fatalf("expected generated id, got: %s", out)
	}
}

// Legacy "---" frontmatter documents must have their id/hash/body carried
// into the canonical format.
func TestNormalizeMemoryDayContent_LegacyFrontmatter(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	input := `---
id: mem_legacy_1
hash: legacyhash
created_at: 2026-03-01T09:00:00Z
updated_at: 2026-03-01T10:00:00Z
---
Legacy body text.`
	out := NormalizeMemoryDayContent(path, input)
	if !strings.Contains(out, `"id":"mem_legacy_1"`) {
		t.Fatalf("expected legacy id reused, got: %s", out)
	}
	if !strings.Contains(out, `"hash":"legacyhash"`) {
		t.Fatalf("expected legacy hash reused, got: %s", out)
	}
	if !strings.Contains(out, "Legacy body text.") {
		t.Fatalf("expected legacy body reused, got: %s", out)
	}
}

// Canonical day markdown must render as a marker-free timeline with
// 12-hour timestamps, topic headings, and bulletized bodies.
func TestRenderMemoryDayForDisplay(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	raw := `# Memory 2026-03-01
<!-- MEMOH:ENTRY {"id":"mem_1","topic":"Decision","created_at":"2026-03-01T09:40:00Z"} -->
结论:采用 provider 架构
<!-- /MEMOH:ENTRY -->
<!-- MEMOH:ENTRY {"id":"mem_2","topic":"Notes","created_at":"2026-03-01T11:15:00Z"} -->
用户偏好:简短回复
<!-- /MEMOH:ENTRY -->
`
	out := RenderMemoryDayForDisplay(path, raw)
	if strings.Contains(out, "MEMOH:ENTRY") {
		t.Fatalf("display output should hide raw markers: %s", out)
	}
	if !strings.Contains(out, "# 2026-03-01") {
		t.Fatalf("expected display day header, got: %s", out)
	}
	if !strings.Contains(out, "## 09:40 AM - Decision") {
		t.Fatalf("expected timeline section, got: %s", out)
	}
	if !strings.Contains(out, "- 结论:采用 provider 架构") {
		t.Fatalf("expected bulletized body, got: %s", out)
	}
	if !strings.Contains(out, "## 11:15 AM - Notes") {
		t.Fatalf("expected second timeline section, got: %s", out)
	}
}

// Paths outside the memory day directory must pass through untouched.
func TestRenderMemoryDayForDisplay_NonMemoryPathUnchanged(t *testing.T) {
	raw := "plain content"
	out := RenderMemoryDayForDisplay("/data/notes.md", raw)
	if out != raw {
		t.Fatalf("non-memory path should be unchanged, got: %s", out)
	}
}
+159
View File
@@ -0,0 +1,159 @@
// Package sparse provides a Go client for the sparse encoding Python service.
// The Python service loads the OpenSearch neural sparse model from HuggingFace
// and exposes HTTP endpoints for text → sparse vector encoding.
package sparse
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// SparseVector holds the non-zero components of a sparse text encoding.
// Indices and Values are parallel slices produced by the Python service.
type SparseVector struct {
	Indices []uint32  `json:"indices"`
	Values  []float32 `json:"values"`
}

// Client calls the Python sparse encoding service.
type Client struct {
	baseURL string
	http    *http.Client
}

// NewClient creates a sparse encoding client pointing to the Python service.
// A 60s timeout covers slow model inference on large batches.
func NewClient(baseURL string) *Client {
	return &Client{
		baseURL: baseURL,
		http:    &http.Client{Timeout: 60 * time.Second},
	}
}

// EncodeDocument encodes a document text into a sparse vector using the neural model.
func (c *Client) EncodeDocument(ctx context.Context, text string) (*SparseVector, error) {
	return c.encode(ctx, "/encode/document", text)
}

// EncodeQuery encodes a query text into a sparse vector (IDF-weighted tokenizer lookup).
func (c *Client) EncodeQuery(ctx context.Context, text string) (*SparseVector, error) {
	return c.encode(ctx, "/encode/query", text)
}
// EncodeDocuments encodes multiple document texts in a single batch call.
// The response is a JSON array of vectors, index-aligned with texts.
func (c *Client) EncodeDocuments(ctx context.Context, texts []string) ([]SparseVector, error) {
	body, err := json.Marshal(map[string]any{"texts": texts})
	if err != nil {
		return nil, err
	}
	endpoint, err := joinEndpointURL(c.baseURL, "/encode/documents")
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.http.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured sparse encoder base URL
	if err != nil {
		return nil, fmt.Errorf("sparse encode failed: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		// Include the response body in the error for diagnostics.
		respBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("sparse encode error (status %d): %s", resp.StatusCode, string(respBody))
	}
	var vectors []SparseVector
	if err := json.NewDecoder(resp.Body).Decode(&vectors); err != nil {
		return nil, err
	}
	return vectors, nil
}
// Health checks the service's /health endpoint, returning nil on HTTP 200.
func (c *Client) Health(ctx context.Context) error {
	endpoint, err := joinEndpointURL(c.baseURL, "/health")
	if err != nil {
		return err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return err
	}
	resp, err := c.http.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured sparse encoder base URL
	if err != nil {
		return fmt.Errorf("sparse health check failed: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		// Include the response body in the error for diagnostics.
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("sparse health error (status %d): %s", resp.StatusCode, string(respBody))
	}
	return nil
}

// encode POSTs {"text": ...} to the given path and decodes a single
// sparse vector from the response. Shared by EncodeDocument and EncodeQuery.
func (c *Client) encode(ctx context.Context, path, text string) (*SparseVector, error) {
	body, err := json.Marshal(map[string]string{"text": text})
	if err != nil {
		return nil, err
	}
	endpoint, err := joinEndpointURL(c.baseURL, path)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.http.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured sparse encoder base URL
	if err != nil {
		return nil, fmt.Errorf("sparse encode failed: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		// Include the response body in the error for diagnostics.
		respBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("sparse encode error (status %d): %s", resp.StatusCode, string(respBody))
	}
	var vec SparseVector
	if err := json.NewDecoder(resp.Body).Decode(&vec); err != nil {
		return nil, err
	}
	return &vec, nil
}
// joinEndpointURL resolves path against baseURL and returns the absolute
// endpoint URL. The base must be a valid http or https URL with a host;
// a blank base or malformed path is rejected.
func joinEndpointURL(baseURL, path string) (string, error) {
	trimmed := strings.TrimSpace(baseURL)
	if trimmed == "" {
		return "", errors.New("sparse encode base URL is required")
	}
	base, err := url.Parse(trimmed)
	if err != nil {
		return "", fmt.Errorf("invalid sparse encode base URL: %w", err)
	}
	switch {
	case base.Scheme != "http" && base.Scheme != "https":
		return "", fmt.Errorf("invalid sparse encode base URL scheme: %q", base.Scheme)
	case base.Host == "":
		return "", errors.New("invalid sparse encode base URL: host is required")
	}
	ref, err := url.Parse(path)
	if err != nil {
		return "", fmt.Errorf("invalid sparse encode path: %w", err)
	}
	return base.ResolveReference(ref).String(), nil
}
+150
View File
@@ -0,0 +1,150 @@
"""Sparse encoding Flask service using OpenSearch neural sparse model."""
import json
import os
import sys
from pathlib import Path
import torch
from flask import Flask, jsonify, request
from huggingface_hub import hf_hub_download
from transformers import AutoModelForMaskedLM, AutoTokenizer
# Default HuggingFace repo for the multilingual neural sparse encoder.
DEFAULT_MODEL_REPO = "opensearch-project/opensearch-neural-sparse-encoding-multilingual-v1"
# Default HTTP port for the Flask service.
DEFAULT_PORT = 8085
# Model/tokenizer cache directory; overridable via SPARSE_CACHE_DIR.
DEFAULT_CACHE_DIR = os.environ.get(
    "SPARSE_CACHE_DIR",
    str(Path(__file__).resolve().parent / "hf-cache"),
)

model_repo = DEFAULT_MODEL_REPO
cache_dir = DEFAULT_CACHE_DIR
# Listen port; overridable via SPARSE_PORT.
port = int(os.environ.get("SPARSE_PORT", DEFAULT_PORT))

app = Flask(__name__)

# Globals populated once by _load_model().
_model = None
_tokenizer = None
_idf = None  # per-token IDF weight tensor used for query encoding
_special_token_ids: list[int] = []  # token ids zeroed out of document encodings
def _load_model() -> None:
    """Populate the module-level model, tokenizer, IDF tensor, and special-token ids.

    Downloads (or reuses) artifacts under ``cache_dir`` and puts the model in
    eval mode. Must run before any encode endpoint is served.
    """
    global _model, _tokenizer, _idf, _special_token_ids
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    _tokenizer = AutoTokenizer.from_pretrained(model_repo, cache_dir=cache_dir)
    _model = AutoModelForMaskedLM.from_pretrained(model_repo, cache_dir=cache_dir)
    _model.eval()
    _idf = _load_idf(_tokenizer)
    # Token ids for special tokens (CLS/SEP/PAD/...), used to zero their weights.
    _special_token_ids = [
        _tokenizer.vocab[token]
        for token in _tokenizer.special_tokens_map.values()
        if token in _tokenizer.vocab
    ]
def _load_idf(tokenizer):
    """Return a dense vocab-sized IDF weight tensor built from the repo's idf.json.

    Tokens absent from idf.json keep weight 0.0.
    """
    idf_path = hf_hub_download(
        repo_id=model_repo, filename="idf.json", cache_dir=cache_dir
    )
    with open(idf_path, encoding="utf-8") as fh:
        weights = json.load(fh)
    vector = [0.0] * tokenizer.vocab_size
    for token, weight in weights.items():
        # Resolve through added-vocab-aware lookup so extra tokens map correctly.
        token_id = tokenizer._convert_token_to_id_with_added_voc(token)
        vector[token_id] = weight
    return torch.tensor(vector)
@torch.no_grad()
def _encode_document(text: str) -> dict:
    """Encode a single document into a sparse ``{"indices": [...], "values": [...]}`` dict.

    Delegates to the batch encoder with a one-element batch so the single- and
    multi-document paths share one implementation and cannot drift apart
    (previously this function duplicated the whole encoding pipeline).
    """
    return _encode_documents([text])[0]
@torch.no_grad()
def _encode_documents(texts: list[str]) -> list[dict]:
    """Encode a batch of documents into sparse ``{"indices", "values"}`` dicts.

    Runs the masked-LM forward pass, max-pools token logits over the sequence
    (masked by attention), applies log(1+log(1+relu(x))) saturation, and zeroes
    special-token positions before sparsifying each row.
    """
    batch = _tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=False,
    )
    logits = _model(**batch)[0]
    pooled, _ = torch.max(logits * batch["attention_mask"].unsqueeze(-1), dim=1)
    pooled = torch.log(1 + torch.log(1 + torch.relu(pooled)))
    pooled[:, _special_token_ids] = 0  # never emit CLS/SEP/PAD as terms
    return [_sparse_to_dict(row) for row in pooled]
@torch.no_grad()
def _encode_query(text: str) -> dict:
    """Encode a query as an IDF-weighted bag-of-tokens sparse vector.

    Queries skip the model forward pass entirely: every token present in the
    tokenized query gets weight 1, then the vector is scaled by the per-token
    IDF weights. Decorated with ``@torch.no_grad()`` for consistency with the
    document encoders and to keep autograd out of the serving path.
    """
    feat = _tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=False,
    )
    input_ids = feat["input_ids"]
    batch_size = input_ids.shape[0]
    qv = torch.zeros(batch_size, _tokenizer.vocab_size)
    # Scatter 1.0 into each position whose token appears in the query.
    qv[torch.arange(batch_size).unsqueeze(-1), input_ids] = 1
    sparse_vector = qv * _idf
    return _sparse_to_dict(sparse_vector[0])
def _sparse_to_dict(vector: torch.Tensor) -> dict:
nz = torch.nonzero(vector, as_tuple=True)[0]
return {"indices": nz.tolist(), "values": vector[nz].tolist()}
@app.route("/health", methods=["GET"])
def health():
    """Health probe: report service status and whether the model is loaded.

    Previously ``model_loaded`` was hard-coded to ``True``, so the endpoint
    claimed readiness even before ``_load_model()`` ran; it now reflects the
    actual module state.
    """
    return jsonify(status="ok", model_loaded=_model is not None, model_repo=model_repo)
@app.route("/encode/document", methods=["POST"])
def encode_document():
    """POST {"text": ...}: return the sparse document encoding, or 400 if text is missing/empty."""
    payload = request.get_json(silent=True) or {}
    text = payload.get("text", "")
    if not text:
        return jsonify(error="text is required"), 400
    return jsonify(_encode_document(text))
@app.route("/encode/query", methods=["POST"])
def encode_query():
    """POST {"text": ...}: return the sparse query encoding, or 400 if text is missing/empty."""
    payload = request.get_json(silent=True) or {}
    text = payload.get("text", "")
    if not text:
        return jsonify(error="text is required"), 400
    return jsonify(_encode_query(text))
@app.route("/encode/documents", methods=["POST"])
def encode_documents():
    """POST {"texts": [...]}: return batch sparse encodings, or 400 if texts is missing/empty."""
    payload = request.get_json(silent=True) or {}
    texts = payload.get("texts", [])
    if not texts:
        return jsonify(error="texts is required"), 400
    return jsonify(_encode_documents(texts))
def main():
    """Entry point: load the encoder model, then serve the Flask app until killed."""
    def log(message: str) -> None:
        # Operational messages go to stderr so stdout stays clean for tooling.
        print(message, file=sys.stderr, flush=True)

    log(f"[sparse-service] loading model {model_repo}...")
    _load_model()
    log(f"[sparse-service] listening on port {port}")
    # Bind all interfaces; intended to run inside the compose network.
    app.run(host="0.0.0.0", port=port, threaded=True)


if __name__ == "__main__":
    main()
@@ -0,0 +1,6 @@
flask
torch
transformers
huggingface_hub
sentencepiece
protobuf
+180 -72
View File
@@ -13,15 +13,17 @@ import (
"strings"
"time"
"gopkg.in/yaml.v3"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/mcp/mcpclient"
)
const (
memoryDateLayout = "2006-01-02"
entryStartPrefix = "<!-- MEMOH:ENTRY "
entryStartSuffix = " -->"
entryEndMarker = "<!-- /MEMOH:ENTRY -->"
memoryDateLayout = "2006-01-02"
entryHeadingPrefix = "## Entry "
yamlFence = "```yaml"
codeFence = "```"
)
var ErrNotConfigured = errors.New("memory filesystem not configured")
@@ -49,6 +51,14 @@ type MemoryItem struct {
RunID string `json:"run_id,omitempty"`
}
type memoryEntryMeta struct {
ID string `yaml:"id"`
Hash string `yaml:"hash,omitempty"`
CreatedAt string `yaml:"created_at,omitempty"`
UpdatedAt string `yaml:"updated_at,omitempty"`
Metadata map[string]any `yaml:"metadata,omitempty"`
}
func New(log *slog.Logger, provider mcpclient.Provider) *Service {
if log == nil {
log = slog.Default()
@@ -123,13 +133,16 @@ func (s *Service) buildScanIndex(ctx context.Context, botID string) (map[string]
}
parsed, parseErr := parseMemoryDayMD(content)
if parseErr != nil {
legacy, legacyErr := parseLegacyMemoryMD(content)
if legacyErr != nil {
jsonItems, jsonErr := parseJSONMemoryItems(content)
if jsonErr != nil {
s.logger.Warn("buildScanIndex: failed to parse memory file",
slog.String("bot_id", botID), slog.String("path", entryPath), slog.Any("error", parseErr))
continue
}
parsed = []MemoryItem{legacy}
if err := s.writeMemoryDay(ctx, botID, entryPath, jsonItems); err != nil {
return nil, err
}
parsed = jsonItems
}
for _, item := range parsed {
id := strings.TrimSpace(item.ID)
@@ -239,11 +252,10 @@ func (s *Service) RemoveMemories(ctx context.Context, botID string, ids []string
if id == "" {
continue
}
targets := make([]string, 0, 2)
targets := make([]string, 0, 1)
if entry, ok := index[id]; ok {
targets = append(targets, entry.FilePath)
}
targets = append(targets, memoryLegacyItemPath(id))
for _, target := range targets {
if removals[target] == nil {
removals[target] = map[string]struct{}{}
@@ -295,11 +307,14 @@ func (s *Service) ReadAllMemoryFiles(ctx context.Context, botID string) ([]Memor
}
parsed, parseErr := parseMemoryDayMD(content)
if parseErr != nil {
legacy, legacyErr := parseLegacyMemoryMD(content)
if legacyErr != nil {
jsonItems, jsonErr := parseJSONMemoryItems(content)
if jsonErr != nil {
continue
}
parsed = []MemoryItem{legacy}
if err := s.writeMemoryDay(ctx, botID, entryPath, jsonItems); err != nil {
return nil, err
}
parsed = jsonItems
}
for _, item := range parsed {
if strings.TrimSpace(item.ID) == "" {
@@ -318,6 +333,31 @@ func (s *Service) ReadAllMemoryFiles(ctx context.Context, botID string) ([]Memor
return items, nil
}
// CountMemoryFiles returns the number of markdown memory files stored for the
// bot. A missing memory directory counts as zero files rather than an error;
// ErrNotConfigured is returned when no filesystem provider is wired.
func (s *Service) CountMemoryFiles(ctx context.Context, botID string) (int, error) {
	if s.provider == nil {
		return 0, ErrNotConfigured
	}
	client, err := s.client(ctx, botID)
	if err != nil {
		return 0, err
	}
	entries, err := client.ListDir(ctx, memoryDirPath(), false)
	if err != nil {
		if isNotFound(err) {
			// No directory yet means no memories have been written.
			return 0, nil
		}
		return 0, err
	}
	total := 0
	for _, entry := range entries {
		if !entry.GetIsDir() && strings.HasSuffix(entry.GetPath(), ".md") {
			total++
		}
	}
	return total, nil
}
func (s *Service) SyncOverview(ctx context.Context, botID string) error {
if s.provider == nil {
return ErrNotConfigured
@@ -341,11 +381,14 @@ func (s *Service) readMemoryDay(ctx context.Context, botID, filePath string) ([]
if parseErr == nil {
return items, nil
}
legacy, legacyErr := parseLegacyMemoryMD(content)
if legacyErr != nil {
jsonItems, jsonErr := parseJSONMemoryItems(content)
if jsonErr != nil {
return []MemoryItem{}, nil
}
return []MemoryItem{legacy}, nil
if err := s.writeMemoryDay(ctx, botID, filePath, jsonItems); err != nil {
return nil, err
}
return jsonItems, nil
}
func (s *Service) writeMemoryDay(ctx context.Context, botID, filePath string, items []MemoryItem) error {
@@ -393,10 +436,6 @@ func memoryDayPath(date string) string {
return path.Join(memoryDirPath(), strings.TrimSpace(date)+".md")
}
func memoryLegacyItemPath(id string) string {
return path.Join(memoryDirPath(), strings.TrimSpace(id)+".md")
}
// --- format / parse helpers ---
func formatMemoryDayMD(date string, items []MemoryItem) string {
@@ -417,24 +456,24 @@ func formatMemoryDayMD(date string, items []MemoryItem) string {
if item.ID == "" || item.Memory == "" {
continue
}
meta := map[string]string{"id": item.ID}
if item.Hash != "" {
meta["hash"] = item.Hash
meta := memoryEntryMeta{
ID: item.ID,
Hash: item.Hash,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
Metadata: item.Metadata,
}
if item.CreatedAt != "" {
meta["created_at"] = item.CreatedAt
}
if item.UpdatedAt != "" {
meta["updated_at"] = item.UpdatedAt
}
rawMeta, _ := json.Marshal(meta)
b.WriteString(entryStartPrefix)
b.Write(rawMeta)
b.WriteString(entryStartSuffix)
rawMeta, _ := yaml.Marshal(meta)
b.WriteString(entryHeadingPrefix)
b.WriteString(item.ID)
b.WriteString("\n\n")
b.WriteString(yamlFence)
b.WriteString("\n")
b.WriteString(strings.TrimSpace(string(rawMeta)))
b.WriteString("\n")
b.WriteString(codeFence)
b.WriteString("\n\n")
b.WriteString(item.Memory)
b.WriteString("\n")
b.WriteString(entryEndMarker)
b.WriteString("\n\n")
}
return b.String()
@@ -449,35 +488,52 @@ func parseMemoryDayMD(content string) ([]MemoryItem, error) {
items := make([]MemoryItem, 0, 8)
for i := 0; i < len(lines); i++ {
line := strings.TrimSpace(lines[i])
if !strings.HasPrefix(line, entryStartPrefix) || !strings.HasSuffix(line, entryStartSuffix) {
if !strings.HasPrefix(line, entryHeadingPrefix) {
continue
}
metaJSON := strings.TrimSuffix(strings.TrimPrefix(line, entryStartPrefix), entryStartSuffix)
var meta map[string]string
if err := json.Unmarshal([]byte(metaJSON), &meta); err != nil {
entryID := strings.TrimSpace(strings.TrimPrefix(line, entryHeadingPrefix))
j := i + 1
for ; j < len(lines) && strings.TrimSpace(lines[j]) == ""; j++ {
}
if j >= len(lines) || strings.TrimSpace(lines[j]) != yamlFence {
continue
}
start := i + 1
end := start
for ; end < len(lines); end++ {
if strings.TrimSpace(lines[end]) == entryEndMarker {
metaStart := j + 1
metaEnd := metaStart
for ; metaEnd < len(lines); metaEnd++ {
if strings.TrimSpace(lines[metaEnd]) == codeFence {
break
}
}
if end >= len(lines) {
if metaEnd >= len(lines) {
break
}
var meta memoryEntryMeta
if err := yaml.Unmarshal([]byte(strings.Join(lines[metaStart:metaEnd], "\n")), &meta); err != nil {
continue
}
bodyStart := metaEnd + 1
if bodyStart < len(lines) && strings.TrimSpace(lines[bodyStart]) == "" {
bodyStart++
}
bodyEnd := bodyStart
for ; bodyEnd < len(lines); bodyEnd++ {
if strings.HasPrefix(strings.TrimSpace(lines[bodyEnd]), entryHeadingPrefix) {
break
}
}
item := MemoryItem{
ID: strings.TrimSpace(meta["id"]),
Hash: strings.TrimSpace(meta["hash"]),
CreatedAt: strings.TrimSpace(meta["created_at"]),
UpdatedAt: strings.TrimSpace(meta["updated_at"]),
Memory: strings.TrimSpace(strings.Join(lines[start:end], "\n")),
ID: firstNonEmpty(meta.ID, entryID),
Hash: strings.TrimSpace(meta.Hash),
CreatedAt: strings.TrimSpace(meta.CreatedAt),
UpdatedAt: strings.TrimSpace(meta.UpdatedAt),
Metadata: meta.Metadata,
Memory: strings.TrimSpace(strings.Join(lines[bodyStart:bodyEnd], "\n")),
}
if item.ID != "" && item.Memory != "" {
items = append(items, item)
}
i = end
i = bodyEnd - 1
}
if len(items) == 0 {
return nil, errors.New("no memory entries found")
@@ -485,36 +541,78 @@ func parseMemoryDayMD(content string) ([]MemoryItem, error) {
return items, nil
}
func parseLegacyMemoryMD(content string) (MemoryItem, error) {
type jsonMemoryRecord struct {
Topic string `json:"topic"`
ID string `json:"id"`
Memory string `json:"memory"`
Text string `json:"text"`
Content string `json:"content"`
Hash string `json:"hash"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
func parseJSONMemoryItems(content string) ([]MemoryItem, error) {
content = strings.TrimSpace(content)
if !strings.HasPrefix(content, "---") {
return MemoryItem{}, errors.New("missing frontmatter")
if content == "" {
return nil, errors.New("empty memory file")
}
parts := strings.SplitN(content[3:], "---", 2)
if len(parts) < 2 {
return MemoryItem{}, errors.New("incomplete frontmatter")
var list []jsonMemoryRecord
if err := json.Unmarshal([]byte(content), &list); err == nil {
return normalizeJSONMemoryItems(list), nil
}
item := MemoryItem{Memory: strings.TrimSpace(parts[1])}
for _, line := range strings.Split(strings.TrimSpace(parts[0]), "\n") {
key, value, found := strings.Cut(strings.TrimSpace(line), ":")
if !found {
var obj jsonMemoryRecord
if err := json.Unmarshal([]byte(content), &obj); err == nil {
return normalizeJSONMemoryItems([]jsonMemoryRecord{obj}), nil
}
var wrapped struct {
Items []jsonMemoryRecord `json:"items"`
}
if err := json.Unmarshal([]byte(content), &wrapped); err == nil {
return normalizeJSONMemoryItems(wrapped.Items), nil
}
return nil, errors.New("not json memory format")
}
func normalizeJSONMemoryItems(records []jsonMemoryRecord) []MemoryItem {
now := time.Now().UTC()
items := make([]MemoryItem, 0, len(records))
for _, record := range records {
text := strings.TrimSpace(record.Memory)
if text == "" {
text = strings.TrimSpace(record.Content)
}
if text == "" {
text = strings.TrimSpace(record.Text)
}
if text == "" {
continue
}
switch strings.TrimSpace(key) {
case "id":
item.ID = strings.TrimSpace(value)
case "hash":
item.Hash = strings.TrimSpace(value)
case "created_at":
item.CreatedAt = strings.TrimSpace(value)
case "updated_at":
item.UpdatedAt = strings.TrimSpace(value)
item := MemoryItem{
ID: strings.TrimSpace(record.ID),
Hash: strings.TrimSpace(record.Hash),
CreatedAt: strings.TrimSpace(record.CreatedAt),
UpdatedAt: strings.TrimSpace(record.UpdatedAt),
Memory: text,
}
if item.ID == "" {
item.ID = "mem_" + strconv.FormatInt(now.UnixNano(), 10)
}
if item.CreatedAt == "" {
item.CreatedAt = now.Format(time.RFC3339)
}
if item.UpdatedAt == "" {
item.UpdatedAt = item.CreatedAt
}
if item.Hash == "" {
item.Hash = "json_" + strconv.FormatInt(now.UnixNano(), 10)
}
if topic := strings.TrimSpace(record.Topic); topic != "" {
item.Metadata = map[string]any{"topic": topic}
}
items = append(items, item)
}
if item.ID == "" {
return MemoryItem{}, errors.New("missing id in frontmatter")
}
return item, nil
return items
}
func formatMemoryOverviewMD(items []MemoryItem) string {
@@ -638,3 +736,13 @@ func memoryTime(item MemoryItem) time.Time {
}
return time.Time{}
}
// firstNonEmpty returns the first argument that is non-empty after trimming
// surrounding whitespace (the trimmed form is returned), or "" when every
// argument is blank.
func firstNonEmpty(values ...string) string {
	for _, candidate := range values {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
+59 -16
View File
@@ -12,12 +12,14 @@ func TestFormatAndParseMemoryDayMD_Roundtrip(t *testing.T) {
Memory: "second record",
Hash: "h2",
CreatedAt: "2026-03-01T11:15:00Z",
Metadata: map[string]any{"topic": "Notes"},
},
{
ID: "mem_1",
Memory: "first record",
Hash: "h1",
CreatedAt: "2026-03-01T09:40:00Z",
Metadata: map[string]any{"topic": "Decision"},
},
}
@@ -25,6 +27,12 @@ func TestFormatAndParseMemoryDayMD_Roundtrip(t *testing.T) {
if !strings.Contains(md, "# Memory 2026-03-01") {
t.Fatalf("expected header in markdown: %s", md)
}
if !strings.Contains(md, "## Entry mem_1") {
t.Fatalf("expected entry heading in markdown: %s", md)
}
if !strings.Contains(md, "```yaml") {
t.Fatalf("expected yaml block in markdown: %s", md)
}
parsed, err := parseMemoryDayMD(md)
if err != nil {
@@ -37,28 +45,63 @@ func TestFormatAndParseMemoryDayMD_Roundtrip(t *testing.T) {
if parsed[0].ID != "mem_1" || parsed[1].ID != "mem_2" {
t.Fatalf("unexpected order after roundtrip: %#v", parsed)
}
if got := parsed[0].Metadata["topic"]; got != "Decision" {
t.Fatalf("expected metadata preserved, got %#v", parsed[0].Metadata)
}
}
func TestParseLegacyMemoryMD(t *testing.T) {
legacy := `---
id: mem_legacy
hash: legacyhash
created_at: 2026-03-01T09:00:00Z
updated_at: 2026-03-01T10:00:00Z
---
legacy content`
func TestParseJSONMemoryItems(t *testing.T) {
raw := `[
{
"id": "mem_json",
"topic": "Decision",
"memory": "Choose provider architecture."
}
]`
item, err := parseLegacyMemoryMD(legacy)
items, err := parseJSONMemoryItems(raw)
if err != nil {
t.Fatalf("parseLegacyMemoryMD error: %v", err)
t.Fatalf("parseJSONMemoryItems error: %v", err)
}
if item.ID != "mem_legacy" {
t.Fatalf("unexpected id: %#v", item)
if len(items) != 1 {
t.Fatalf("expected 1 item, got %d", len(items))
}
if item.Hash != "legacyhash" {
t.Fatalf("unexpected hash: %#v", item)
if items[0].ID != "mem_json" {
t.Fatalf("unexpected id: %#v", items[0])
}
if item.Memory != "legacy content" {
t.Fatalf("unexpected memory body: %#v", item)
if got := items[0].Metadata["topic"]; got != "Decision" {
t.Fatalf("expected topic metadata, got %#v", items[0].Metadata)
}
if items[0].Memory != "Choose provider architecture." {
t.Fatalf("unexpected memory body: %#v", items[0])
}
}
func TestParseJSONMemoryItemsCanBeFormattedToCanonicalMarkdown(t *testing.T) {
raw := `[
{
"id": "mem_json",
"topic": "Decision",
"memory": "Choose provider architecture."
}
]`
items, err := parseJSONMemoryItems(raw)
if err != nil {
t.Fatalf("parseJSONMemoryItems error: %v", err)
}
md := formatMemoryDayMD("2026-03-01", items)
if !strings.Contains(md, "## Entry mem_json") {
t.Fatalf("expected canonical heading, got: %s", md)
}
if !strings.Contains(md, "topic: Decision") {
t.Fatalf("expected yaml metadata topic, got: %s", md)
}
parsed, err := parseMemoryDayMD(md)
if err != nil {
t.Fatalf("parseMemoryDayMD error: %v", err)
}
if len(parsed) != 1 || parsed[0].ID != "mem_json" {
t.Fatalf("unexpected parsed canonical items: %#v", parsed)
}
}