mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor: multi-provider memory adapters with scan-based builtin (#227)
* refactor: restructure memory into multi-provider adapters, remove manifest.json dependency - Rename internal/memory/provider to internal/memory/adapters with per-provider subdirectories (builtin, mem0, openviking) - Replace manifest.json-based delete/update with scan-based index from daily files - Add mem0 and openviking provider adapters with HTTP client, chat hooks, MCP tools, and CRUD - Wire provider lifecycle into registry (auto-instantiate on create, evict on update/delete) - Split docker-compose into base stack + optional overlays (qdrant, browser, mem0, openviking) - Update admin UI to support dynamic provider config schema rendering * chore(lint): fix all golangci-lint issues for clean CI * refactor(docker): replace compose overlay files with profiles * feat(memory): add built-in memory multi modes * fix(ci): golangci lint * feat(memory): edit built-in memory sparse design
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package provider
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,15 +37,18 @@ type BuiltinProvider struct {
|
||||
// It is intentionally defined as an interface to decouple provider wiring from
|
||||
// concrete service structs in the memory package.
|
||||
type memoryRuntime interface {
|
||||
Add(ctx context.Context, req AddRequest) (SearchResponse, error)
|
||||
Search(ctx context.Context, req SearchRequest) (SearchResponse, error)
|
||||
GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error)
|
||||
Update(ctx context.Context, req UpdateRequest) (MemoryItem, error)
|
||||
Delete(ctx context.Context, memoryID string) (DeleteResponse, error)
|
||||
DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteResponse, error)
|
||||
DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error)
|
||||
Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error)
|
||||
Usage(ctx context.Context, filters map[string]any) (UsageResponse, error)
|
||||
Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error)
|
||||
Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error)
|
||||
GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error)
|
||||
Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error)
|
||||
Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error)
|
||||
DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error)
|
||||
DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error)
|
||||
Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (adapters.CompactResult, error)
|
||||
Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error)
|
||||
Mode() string
|
||||
Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error)
|
||||
Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error)
|
||||
}
|
||||
|
||||
// AdminChecker checks whether a channel identity has admin privileges.
|
||||
@@ -69,7 +73,7 @@ func (*BuiltinProvider) Type() string { return BuiltinType }
|
||||
|
||||
// --- Conversation Hooks ---
|
||||
|
||||
func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatRequest) (*BeforeChatResult, error) {
|
||||
func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req adapters.BeforeChatRequest) (*adapters.BeforeChatResult, error) {
|
||||
if p.service == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -77,7 +81,7 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
resp, err := p.service.Search(ctx, SearchRequest{
|
||||
resp, err := p.service.Search(ctx, adapters.SearchRequest{
|
||||
Query: req.Query,
|
||||
BotID: req.BotID,
|
||||
Limit: memoryContextLimitPerScope,
|
||||
@@ -96,7 +100,7 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
|
||||
seen := map[string]struct{}{}
|
||||
type contextItem struct {
|
||||
namespace string
|
||||
item MemoryItem
|
||||
item adapters.MemoryItem
|
||||
}
|
||||
results := make([]contextItem, 0, memoryContextLimitPerScope)
|
||||
for _, item := range resp.Results {
|
||||
@@ -134,7 +138,7 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
|
||||
sb.WriteString("- [")
|
||||
sb.WriteString(entry.namespace)
|
||||
sb.WriteString("] ")
|
||||
sb.WriteString(truncateSnippet(text, memoryContextItemMaxChars))
|
||||
sb.WriteString(adapters.TruncateSnippet(text, memoryContextItemMaxChars))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</memory-context>")
|
||||
@@ -142,10 +146,10 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req BeforeChatReques
|
||||
if payload == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return &BeforeChatResult{ContextText: payload}, nil
|
||||
return &adapters.BeforeChatResult{ContextText: payload}, nil
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) OnAfterChat(ctx context.Context, req AfterChatRequest) error {
|
||||
func (p *BuiltinProvider) OnAfterChat(ctx context.Context, req adapters.AfterChatRequest) error {
|
||||
if p.service == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -161,9 +165,11 @@ func (p *BuiltinProvider) OnAfterChat(ctx context.Context, req AfterChatRequest)
|
||||
"scopeId": botID,
|
||||
"bot_id": botID,
|
||||
}
|
||||
if _, err := p.service.Add(ctx, AddRequest{
|
||||
metadata := adapters.BuildProfileMetadata(req.UserID, req.ChannelIdentityID, req.DisplayName)
|
||||
if _, err := p.service.Add(ctx, adapters.AddRequest{
|
||||
Messages: req.Messages,
|
||||
BotID: botID,
|
||||
Metadata: metadata,
|
||||
Filters: filters,
|
||||
}); err != nil {
|
||||
p.logger.Warn("store memory failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
@@ -256,7 +262,7 @@ func (p *BuiltinProvider) CallTool(ctx context.Context, session mcp.ToolSessionC
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.service.Search(ctx, SearchRequest{
|
||||
resp, err := p.service.Search(ctx, adapters.SearchRequest{
|
||||
Query: query,
|
||||
BotID: botID,
|
||||
Limit: limit,
|
||||
@@ -271,7 +277,7 @@ func (p *BuiltinProvider) CallTool(ctx context.Context, session mcp.ToolSessionC
|
||||
return mcp.BuildToolErrorResult("memory search failed"), nil
|
||||
}
|
||||
|
||||
allResults := deduplicateItems(resp.Results)
|
||||
allResults := adapters.DeduplicateItems(resp.Results)
|
||||
sort.Slice(allResults, func(i, j int) bool {
|
||||
return allResults[i].Score > allResults[j].Score
|
||||
})
|
||||
@@ -313,99 +319,79 @@ func (p *BuiltinProvider) canAccessChat(ctx context.Context, chatID, channelIden
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
func (p *BuiltinProvider) Add(ctx context.Context, req AddRequest) (SearchResponse, error) {
|
||||
func (p *BuiltinProvider) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
|
||||
if p.service == nil {
|
||||
return SearchResponse{}, errors.New("memory runtime not configured")
|
||||
return adapters.SearchResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.Add(ctx, req)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) Search(ctx context.Context, req SearchRequest) (SearchResponse, error) {
|
||||
func (p *BuiltinProvider) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
|
||||
if p.service == nil {
|
||||
return SearchResponse{}, errors.New("memory runtime not configured")
|
||||
return adapters.SearchResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.Search(ctx, req)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error) {
|
||||
func (p *BuiltinProvider) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
|
||||
if p.service == nil {
|
||||
return SearchResponse{}, errors.New("memory runtime not configured")
|
||||
return adapters.SearchResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.GetAll(ctx, req)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) Update(ctx context.Context, req UpdateRequest) (MemoryItem, error) {
|
||||
func (p *BuiltinProvider) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
|
||||
if p.service == nil {
|
||||
return MemoryItem{}, errors.New("memory runtime not configured")
|
||||
return adapters.MemoryItem{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.Update(ctx, req)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) Delete(ctx context.Context, memoryID string) (DeleteResponse, error) {
|
||||
func (p *BuiltinProvider) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
|
||||
if p.service == nil {
|
||||
return DeleteResponse{}, errors.New("memory runtime not configured")
|
||||
return adapters.DeleteResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.Delete(ctx, memoryID)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteResponse, error) {
|
||||
func (p *BuiltinProvider) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
|
||||
if p.service == nil {
|
||||
return DeleteResponse{}, errors.New("memory runtime not configured")
|
||||
return adapters.DeleteResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.DeleteBatch(ctx, memoryIDs)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error) {
|
||||
func (p *BuiltinProvider) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
|
||||
if p.service == nil {
|
||||
return DeleteResponse{}, errors.New("memory runtime not configured")
|
||||
return adapters.DeleteResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.DeleteAll(ctx, req)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error) {
|
||||
func (p *BuiltinProvider) Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (adapters.CompactResult, error) {
|
||||
if p.service == nil {
|
||||
return CompactResult{}, errors.New("memory runtime not configured")
|
||||
return adapters.CompactResult{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.Compact(ctx, filters, ratio, decayDays)
|
||||
}
|
||||
|
||||
func (p *BuiltinProvider) Usage(ctx context.Context, filters map[string]any) (UsageResponse, error) {
|
||||
func (p *BuiltinProvider) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
|
||||
if p.service == nil {
|
||||
return UsageResponse{}, errors.New("memory runtime not configured")
|
||||
return adapters.UsageResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return p.service.Usage(ctx, filters)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func truncateSnippet(s string, n int) string {
|
||||
trimmed := strings.TrimSpace(s)
|
||||
runes := []rune(trimmed)
|
||||
if len(runes) <= n {
|
||||
return trimmed
|
||||
func (p *BuiltinProvider) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
|
||||
if p.service == nil {
|
||||
return adapters.MemoryStatusResponse{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
return strings.TrimSpace(string(runes[:n])) + "..."
|
||||
return p.service.Status(ctx, botID)
|
||||
}
|
||||
|
||||
func deduplicateItems(items []MemoryItem) []MemoryItem {
|
||||
if len(items) == 0 {
|
||||
return items
|
||||
func (p *BuiltinProvider) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
|
||||
if p.service == nil {
|
||||
return adapters.RebuildResult{}, errors.New("memory runtime not configured")
|
||||
}
|
||||
seen := make(map[string]struct{}, len(items))
|
||||
result := make([]MemoryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
id := strings.TrimSpace(item.ID)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(item.Memory)
|
||||
}
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
result = append(result, item)
|
||||
}
|
||||
return result
|
||||
return p.service.Rebuild(ctx, botID)
|
||||
}
|
||||
+8
-6
@@ -1,13 +1,15 @@
|
||||
package provider
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
func TestTruncateSnippet_ASCII(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := truncateSnippet("hello world", 5)
|
||||
got := adapters.TruncateSnippet("hello world", 5)
|
||||
if got != "hello..." {
|
||||
t.Fatalf("expected %q, got %q", "hello...", got)
|
||||
}
|
||||
@@ -15,7 +17,7 @@ func TestTruncateSnippet_ASCII(t *testing.T) {
|
||||
|
||||
func TestTruncateSnippet_NoTruncation(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := truncateSnippet("short", 100)
|
||||
got := adapters.TruncateSnippet("short", 100)
|
||||
if got != "short" {
|
||||
t.Fatalf("expected %q, got %q", "short", got)
|
||||
}
|
||||
@@ -24,7 +26,7 @@ func TestTruncateSnippet_NoTruncation(t *testing.T) {
|
||||
func TestTruncateSnippet_CJK(t *testing.T) {
|
||||
t.Parallel()
|
||||
// 5 CJK characters (15 bytes in UTF-8), truncate to 3 runes.
|
||||
got := truncateSnippet("你好世界啊", 3)
|
||||
got := adapters.TruncateSnippet("你好世界啊", 3)
|
||||
if !utf8.ValidString(got) {
|
||||
t.Fatalf("result is not valid UTF-8: %q", got)
|
||||
}
|
||||
@@ -36,7 +38,7 @@ func TestTruncateSnippet_CJK(t *testing.T) {
|
||||
func TestTruncateSnippet_Emoji(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Emoji are 4 bytes each in UTF-8.
|
||||
got := truncateSnippet("😀😁😂🤣😃", 2)
|
||||
got := adapters.TruncateSnippet("😀😁😂🤣😃", 2)
|
||||
if !utf8.ValidString(got) {
|
||||
t.Fatalf("result is not valid UTF-8: %q", got)
|
||||
}
|
||||
@@ -47,7 +49,7 @@ func TestTruncateSnippet_Emoji(t *testing.T) {
|
||||
|
||||
func TestTruncateSnippet_TrimWhitespace(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := truncateSnippet(" hello ", 100)
|
||||
got := adapters.TruncateSnippet(" hello ", 100)
|
||||
if got != "hello" {
|
||||
t.Fatalf("expected %q, got %q", "hello", got)
|
||||
}
|
||||
@@ -0,0 +1,736 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
)
|
||||
|
||||
type denseRuntime struct {
|
||||
qdrant *qdrantclient.Client
|
||||
store *storefs.Service
|
||||
embedder *denseEmbeddingClient
|
||||
collection string
|
||||
}
|
||||
|
||||
type denseEmbeddingClient struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
modelID string
|
||||
dimensions int
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type denseEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type denseModelSpec struct {
|
||||
modelID string
|
||||
baseURL string
|
||||
apiKey string
|
||||
dimensions int
|
||||
}
|
||||
|
||||
func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg config.Config, store *storefs.Service) (*denseRuntime, error) {
|
||||
if queries == nil {
|
||||
return nil, errors.New("dense runtime: queries are required")
|
||||
}
|
||||
if store == nil {
|
||||
return nil, errors.New("dense runtime: memory store is required")
|
||||
}
|
||||
|
||||
modelRef := strings.TrimSpace(adapters.StringFromConfig(providerConfig, "embedding_model_id"))
|
||||
if modelRef == "" {
|
||||
return nil, errors.New("dense runtime: embedding_model_id is required")
|
||||
}
|
||||
|
||||
modelSpec, err := resolveDenseEmbeddingModel(context.Background(), queries, modelRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, port := parseQdrantHostPort(cfg.Qdrant.BaseURL)
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
if port == 0 {
|
||||
port = 6334
|
||||
}
|
||||
collection := adapters.StringFromConfig(providerConfig, "qdrant_collection")
|
||||
if strings.TrimSpace(collection) == "" {
|
||||
collection = "memory_dense"
|
||||
}
|
||||
qClient, err := qdrantclient.NewClient(host, port, cfg.Qdrant.APIKey, collection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dense runtime: %w", err)
|
||||
}
|
||||
|
||||
return &denseRuntime{
|
||||
qdrant: qClient,
|
||||
store: store,
|
||||
embedder: &denseEmbeddingClient{
|
||||
baseURL: strings.TrimRight(modelSpec.baseURL, "/"),
|
||||
apiKey: modelSpec.apiKey,
|
||||
modelID: modelSpec.modelID,
|
||||
dimensions: modelSpec.dimensions,
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
},
|
||||
collection: collection,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
text := sparseRuntimeText(req.Message, req.Messages)
|
||||
if text == "" {
|
||||
return adapters.SearchResponse{}, errors.New("dense runtime: message is required")
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
item := adapters.MemoryItem{
|
||||
ID: sparseRuntimeMemoryID(botID, time.Now().UTC()),
|
||||
Memory: text,
|
||||
Hash: denseRuntimeHash(text),
|
||||
Metadata: req.Metadata,
|
||||
BotID: botID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{denseStoreItemFromMemoryItem(item)}, req.Filters); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{denseStoreItemFromMemoryItem(item)}); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
return adapters.SearchResponse{Results: []adapters.MemoryItem{item}}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
limit := req.Limit
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
vec, err := r.embedder.EmbedQuery(ctx, req.Query)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, fmt.Errorf("dense embed query: %w", err)
|
||||
}
|
||||
results, err := r.qdrant.SearchDense(ctx, qdrantclient.DenseVector{Values: vec}, botID, limit)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
items := make([]adapters.MemoryItem, 0, len(results))
|
||||
for _, result := range results {
|
||||
items = append(items, denseResultToItem(result))
|
||||
}
|
||||
return adapters.SearchResponse{Results: items}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
items, err := r.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
result := make([]adapters.MemoryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
mem := denseMemoryItemFromStore(item)
|
||||
mem.BotID = botID
|
||||
result = append(result, mem)
|
||||
}
|
||||
sort.Slice(result, func(i, j int) bool { return result[i].UpdatedAt > result[j].UpdatedAt })
|
||||
if req.Limit > 0 && len(result) > req.Limit {
|
||||
result = result[:req.Limit]
|
||||
}
|
||||
return adapters.SearchResponse{Results: result}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
|
||||
memoryID := strings.TrimSpace(req.MemoryID)
|
||||
if memoryID == "" {
|
||||
return adapters.MemoryItem{}, errors.New("dense runtime: memory_id is required")
|
||||
}
|
||||
text := strings.TrimSpace(req.Memory)
|
||||
if text == "" {
|
||||
return adapters.MemoryItem{}, errors.New("dense runtime: memory is required")
|
||||
}
|
||||
botID := sparseRuntimeBotIDFromMemoryID(memoryID)
|
||||
if botID == "" {
|
||||
return adapters.MemoryItem{}, errors.New("dense runtime: invalid memory_id")
|
||||
}
|
||||
items, err := r.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
}
|
||||
var existing *storefs.MemoryItem
|
||||
for i := range items {
|
||||
if strings.TrimSpace(items[i].ID) == memoryID {
|
||||
item := items[i]
|
||||
existing = &item
|
||||
break
|
||||
}
|
||||
}
|
||||
if existing == nil {
|
||||
return adapters.MemoryItem{}, errors.New("dense runtime: memory not found")
|
||||
}
|
||||
existing.Memory = text
|
||||
existing.Hash = denseRuntimeHash(text)
|
||||
existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{*existing}, nil); err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
}
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{*existing}); err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
}
|
||||
item := denseMemoryItemFromStore(*existing)
|
||||
item.BotID = botID
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
|
||||
return r.DeleteBatch(ctx, []string{memoryID})
|
||||
}
|
||||
|
||||
func (r *denseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
|
||||
grouped := map[string][]string{}
|
||||
pointIDs := make([]string, 0, len(memoryIDs))
|
||||
for _, rawID := range memoryIDs {
|
||||
memoryID := strings.TrimSpace(rawID)
|
||||
if memoryID == "" {
|
||||
continue
|
||||
}
|
||||
botID := sparseRuntimeBotIDFromMemoryID(memoryID)
|
||||
if botID == "" {
|
||||
continue
|
||||
}
|
||||
grouped[botID] = append(grouped[botID], memoryID)
|
||||
pointIDs = append(pointIDs, sparsePointID(botID, memoryID))
|
||||
}
|
||||
for botID, ids := range grouped {
|
||||
if err := r.store.RemoveMemories(ctx, botID, ids); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
}
|
||||
if err := r.qdrant.DeleteByIDs(ctx, pointIDs); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "Memories deleted successfully!"}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
if err := r.store.RemoveAllMemories(ctx, botID); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
if err := r.qdrant.DeleteByBotID(ctx, botID); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "All memories deleted successfully!"}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Compact(ctx context.Context, filters map[string]any, ratio float64, _ int) (adapters.CompactResult, error) {
|
||||
botID, err := sparseRuntimeBotID("", filters)
|
||||
if err != nil {
|
||||
return adapters.CompactResult{}, err
|
||||
}
|
||||
if ratio <= 0 || ratio > 1 {
|
||||
return adapters.CompactResult{}, errors.New("ratio must be in range (0, 1]")
|
||||
}
|
||||
items, err := r.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.CompactResult{}, err
|
||||
}
|
||||
before := len(items)
|
||||
if before == 0 {
|
||||
return adapters.CompactResult{BeforeCount: 0, AfterCount: 0, Ratio: ratio, Results: []adapters.MemoryItem{}}, nil
|
||||
}
|
||||
sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt > items[j].UpdatedAt })
|
||||
target := int(float64(before) * ratio)
|
||||
if target < 1 {
|
||||
target = 1
|
||||
}
|
||||
if target > before {
|
||||
target = before
|
||||
}
|
||||
keptStore := append([]storefs.MemoryItem(nil), items[:target]...)
|
||||
if err := r.store.RebuildFiles(ctx, botID, keptStore, filters); err != nil {
|
||||
return adapters.CompactResult{}, err
|
||||
}
|
||||
if _, err := r.Rebuild(ctx, botID); err != nil {
|
||||
return adapters.CompactResult{}, err
|
||||
}
|
||||
kept := make([]adapters.MemoryItem, 0, len(keptStore))
|
||||
for _, item := range keptStore {
|
||||
kept = append(kept, denseMemoryItemFromStore(item))
|
||||
}
|
||||
return adapters.CompactResult{
|
||||
BeforeCount: before,
|
||||
AfterCount: len(kept),
|
||||
Ratio: ratio,
|
||||
Results: kept,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
|
||||
botID, err := sparseRuntimeBotID("", filters)
|
||||
if err != nil {
|
||||
return adapters.UsageResponse{}, err
|
||||
}
|
||||
items, err := r.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.UsageResponse{}, err
|
||||
}
|
||||
var usage adapters.UsageResponse
|
||||
usage.Count = len(items)
|
||||
for _, item := range items {
|
||||
usage.TotalTextBytes += int64(len(item.Memory))
|
||||
}
|
||||
if usage.Count > 0 {
|
||||
usage.AvgTextBytes = usage.TotalTextBytes / int64(usage.Count)
|
||||
}
|
||||
usage.EstimatedStorageBytes = usage.TotalTextBytes
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (*denseRuntime) Mode() string {
|
||||
return string(ModeDense)
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
|
||||
fileCount, err := r.store.CountMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.MemoryStatusResponse{}, err
|
||||
}
|
||||
items, err := r.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.MemoryStatusResponse{}, err
|
||||
}
|
||||
status := adapters.MemoryStatusResponse{
|
||||
ProviderType: BuiltinType,
|
||||
MemoryMode: string(ModeDense),
|
||||
CanManualSync: true,
|
||||
SourceDir: path.Join(config.DefaultDataMount, "memory"),
|
||||
OverviewPath: path.Join(config.DefaultDataMount, "MEMORY.md"),
|
||||
MarkdownFileCount: fileCount,
|
||||
SourceCount: len(items),
|
||||
QdrantCollection: r.collection,
|
||||
}
|
||||
if err := r.embedder.Health(ctx); err != nil {
|
||||
status.Encoder.Error = err.Error()
|
||||
} else {
|
||||
status.Encoder.OK = true
|
||||
}
|
||||
exists, err := r.qdrant.CollectionExists(ctx)
|
||||
if err != nil {
|
||||
status.Qdrant.Error = err.Error()
|
||||
return status, nil
|
||||
}
|
||||
status.Qdrant.OK = true
|
||||
if exists {
|
||||
count, err := r.qdrant.Count(ctx, botID)
|
||||
if err != nil {
|
||||
status.Qdrant.OK = false
|
||||
status.Qdrant.Error = err.Error()
|
||||
return status, nil
|
||||
}
|
||||
status.IndexedCount = count
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
|
||||
items, err := r.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
if err := r.store.SyncOverview(ctx, botID); err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
return r.syncSourceItems(ctx, botID, items)
|
||||
}
|
||||
|
||||
func (r *denseRuntime) syncSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) (adapters.RebuildResult, error) {
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
existing, err := r.qdrant.Scroll(ctx, botID, 10000)
|
||||
if err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
existingBySource := make(map[string]qdrantclient.SearchResult, len(existing))
|
||||
for _, item := range existing {
|
||||
sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
|
||||
if sourceID == "" {
|
||||
sourceID = strings.TrimSpace(item.ID)
|
||||
}
|
||||
if sourceID != "" {
|
||||
existingBySource[sourceID] = item
|
||||
}
|
||||
}
|
||||
sourceIDs := make(map[string]struct{}, len(items))
|
||||
toUpsert := make([]storefs.MemoryItem, 0, len(items))
|
||||
missingCount := 0
|
||||
restoredCount := 0
|
||||
for _, item := range items {
|
||||
item = denseCanonicalStoreItem(item)
|
||||
if item.ID == "" || item.Memory == "" {
|
||||
continue
|
||||
}
|
||||
sourceIDs[item.ID] = struct{}{}
|
||||
payload := densePayload(botID, item)
|
||||
existingItem, ok := existingBySource[item.ID]
|
||||
if !ok {
|
||||
missingCount++
|
||||
restoredCount++
|
||||
toUpsert = append(toUpsert, item)
|
||||
continue
|
||||
}
|
||||
if !densePayloadMatches(existingItem.Payload, payload) {
|
||||
restoredCount++
|
||||
toUpsert = append(toUpsert, item)
|
||||
}
|
||||
}
|
||||
stale := make([]string, 0)
|
||||
for _, item := range existing {
|
||||
sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
|
||||
if sourceID == "" {
|
||||
sourceID = strings.TrimSpace(item.ID)
|
||||
}
|
||||
if _, ok := sourceIDs[sourceID]; ok {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(item.ID) != "" {
|
||||
stale = append(stale, item.ID)
|
||||
}
|
||||
}
|
||||
if len(stale) > 0 {
|
||||
if err := r.qdrant.DeleteByIDs(ctx, stale); err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
}
|
||||
if err := r.upsertSourceItems(ctx, botID, toUpsert); err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
count, err := r.qdrant.Count(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
return adapters.RebuildResult{
|
||||
FsCount: len(items),
|
||||
StorageCount: count,
|
||||
MissingCount: missingCount,
|
||||
RestoredCount: restoredCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// upsertSourceItems embeds the given store items and upserts them into the
// dense Qdrant collection. Items with an empty ID or memory text are skipped.
// The dense collection is created (with the embedder's dimensionality) if it
// does not exist yet.
func (r *denseRuntime) upsertSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) error {
	if len(items) == 0 {
		return nil
	}
	if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
		return err
	}
	// Canonicalize first so the texts sent for embedding line up 1:1 with the
	// items that will be upserted.
	canonical := make([]storefs.MemoryItem, 0, len(items))
	texts := make([]string, 0, len(items))
	for _, item := range items {
		item = denseCanonicalStoreItem(item)
		if item.ID == "" || item.Memory == "" {
			continue
		}
		canonical = append(canonical, item)
		texts = append(texts, item.Memory)
	}
	if len(canonical) == 0 {
		return nil
	}
	vectors, err := r.embedder.EmbedDocuments(ctx, texts)
	if err != nil {
		return fmt.Errorf("dense embed documents: %w", err)
	}
	// EmbedDocuments drops empty embeddings, so a length mismatch means the
	// provider silently failed on at least one input.
	if len(vectors) != len(canonical) {
		return fmt.Errorf("dense embed documents: expected %d vectors, got %d", len(canonical), len(vectors))
	}
	for i, item := range canonical {
		if err := r.qdrant.UpsertDense(ctx, sparsePointID(botID, item.ID), qdrantclient.DenseVector{
			Values: vectors[i],
		}, densePayload(botID, item)); err != nil {
			return err
		}
	}
	return nil
}
|
||||
|
||||
// resolveDenseEmbeddingModel resolves a model reference — either a model row
// UUID or a provider model_id string — into the connection spec needed by the
// dense embedding client. The referenced model must be of type "embedding",
// belong to an LLM provider, and declare a positive vector dimensionality.
func resolveDenseEmbeddingModel(ctx context.Context, queries *dbsqlc.Queries, modelRef string) (denseModelSpec, error) {
	modelRef = strings.TrimSpace(modelRef)
	if modelRef == "" {
		return denseModelSpec{}, errors.New("dense runtime: embedding_model_id is required")
	}
	var row dbsqlc.Model
	// First try the reference as a primary-key UUID; lookup errors fall
	// through to the model_id path below.
	if parsed, err := db.ParseUUID(modelRef); err == nil {
		dbModel, err := queries.GetModelByID(ctx, parsed)
		if err == nil {
			row = dbModel
		}
	}
	// Fallback: treat the reference as a model_id and take the first match.
	if !row.ID.Valid {
		rows, err := queries.ListModelsByModelID(ctx, modelRef)
		if err != nil || len(rows) == 0 {
			return denseModelSpec{}, fmt.Errorf("dense runtime: embedding model not found: %s", modelRef)
		}
		row = rows[0]
	}
	if row.Type != "embedding" {
		return denseModelSpec{}, fmt.Errorf("dense runtime: model %s is not an embedding model", modelRef)
	}
	if !row.LlmProviderID.Valid {
		return denseModelSpec{}, fmt.Errorf("dense runtime: model %s has no provider", modelRef)
	}
	provider, err := queries.GetLlmProviderByID(ctx, row.LlmProviderID)
	if err != nil {
		return denseModelSpec{}, fmt.Errorf("dense runtime: get embedding provider: %w", err)
	}
	// Dimensions are required up front so the Qdrant collection can be created
	// with the right vector size.
	if !row.Dimensions.Valid || row.Dimensions.Int32 <= 0 {
		return denseModelSpec{}, fmt.Errorf("dense runtime: embedding model %s missing dimensions", modelRef)
	}
	return denseModelSpec{
		modelID:    strings.TrimSpace(row.ModelID),
		baseURL:    strings.TrimSpace(provider.BaseUrl),
		apiKey:     strings.TrimSpace(provider.ApiKey),
		dimensions: int(row.Dimensions.Int32),
	}, nil
}
|
||||
|
||||
// joinDenseEmbeddingEndpointURL appends endpointPath to the operator-configured
// embedding base URL and returns the full endpoint URL.
//
// The base URL must be absolute (http/https with a host). Any path component
// already present on the base URL (e.g. "/v1" on an OpenAI-compatible
// provider) is preserved: "http://host/v1" + "/embeddings" yields
// "http://host/v1/embeddings". The previous ResolveReference-based
// implementation discarded that prefix for absolute endpoint paths.
func joinDenseEmbeddingEndpointURL(baseURL, endpointPath string) (string, error) {
	baseURL = strings.TrimSpace(baseURL)
	if baseURL == "" {
		return "", errors.New("dense embedding base URL is required")
	}

	base, err := url.Parse(baseURL)
	if err != nil {
		return "", fmt.Errorf("invalid dense embedding base URL: %w", err)
	}
	if base.Scheme != "http" && base.Scheme != "https" {
		return "", fmt.Errorf("invalid dense embedding base URL scheme: %q", base.Scheme)
	}
	if base.Host == "" {
		return "", errors.New("invalid dense embedding base URL: host is required")
	}

	// JoinPath keeps the base URL's existing path and cleans the result.
	return base.JoinPath(endpointPath).String(), nil
}
|
||||
|
||||
// Health probes the embedding provider by issuing a GET to its /models
// endpoint (relative to the configured base URL), sending the bearer token
// when an API key is configured. Any non-2xx status is returned as an error
// that includes the response body.
func (c *denseEmbeddingClient) Health(ctx context.Context) error {
	endpoint, err := joinDenseEmbeddingEndpointURL(c.baseURL, "/models")
	if err != nil {
		return err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return err
	}
	if c.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured embedding provider base URL
	if err != nil {
		return fmt.Errorf("dense embedding health check failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best-effort body read for a more actionable error message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("dense embedding health error (status %d): %s", resp.StatusCode, string(body))
	}
	return nil
}
|
||||
|
||||
func (c *denseEmbeddingClient) EmbedQuery(ctx context.Context, text string) ([]float32, error) {
|
||||
vectors, err := c.EmbedDocuments(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return nil, errors.New("dense embed query: empty embedding response")
|
||||
}
|
||||
return vectors[0], nil
|
||||
}
|
||||
|
||||
// EmbedDocuments embeds a batch of texts via the OpenAI-compatible
// /embeddings endpoint and returns the dense vectors in input order.
// Embeddings the API omits or returns empty are dropped, so the result may be
// shorter than the input; callers compare lengths to detect partial failures.
func (c *denseEmbeddingClient) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
	body, err := json.Marshal(map[string]any{
		"model": c.modelID,
		"input": texts,
	})
	if err != nil {
		return nil, err
	}
	endpoint, err := joinDenseEmbeddingEndpointURL(c.baseURL, "/embeddings")
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured embedding provider base URL
	if err != nil {
		return nil, fmt.Errorf("dense embed request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("dense embed read response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("dense embed api error %d: %s", resp.StatusCode, string(respBody))
	}
	var parsed denseEmbeddingResponse
	if err := json.Unmarshal(respBody, &parsed); err != nil {
		return nil, fmt.Errorf("dense embed decode response: %w", err)
	}
	// Place each embedding at its declared index: the API may return data
	// items out of order. Out-of-range indices are ignored.
	vectors := make([][]float32, len(parsed.Data))
	for _, item := range parsed.Data {
		if item.Index >= 0 && item.Index < len(vectors) {
			vectors[item.Index] = item.Embedding
		}
	}
	// Compact away slots left empty by missing embeddings.
	out := make([][]float32, 0, len(vectors))
	for _, vector := range vectors {
		if len(vector) > 0 {
			out = append(out, vector)
		}
	}
	return out, nil
}
|
||||
|
||||
func denseCanonicalStoreItem(item storefs.MemoryItem) storefs.MemoryItem {
|
||||
item.ID = strings.TrimSpace(item.ID)
|
||||
item.Memory = strings.TrimSpace(item.Memory)
|
||||
if item.Memory != "" && strings.TrimSpace(item.Hash) == "" {
|
||||
item.Hash = denseRuntimeHash(item.Memory)
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func densePayload(botID string, item storefs.MemoryItem) map[string]string {
|
||||
item = denseCanonicalStoreItem(item)
|
||||
payload := map[string]string{
|
||||
"memory": item.Memory,
|
||||
"bot_id": strings.TrimSpace(botID),
|
||||
"source_entry_id": item.ID,
|
||||
"hash": item.Hash,
|
||||
}
|
||||
if item.CreatedAt != "" {
|
||||
payload["created_at"] = item.CreatedAt
|
||||
}
|
||||
if item.UpdatedAt != "" {
|
||||
payload["updated_at"] = item.UpdatedAt
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// densePayloadMatches reports whether every key in expected is present in
// existing with the same (whitespace-trimmed) value. Extra keys in existing
// are ignored.
func densePayloadMatches(existing, expected map[string]string) bool {
	for k, want := range expected {
		got := strings.TrimSpace(existing[k])
		if got != strings.TrimSpace(want) {
			return false
		}
	}
	return true
}
|
||||
|
||||
// denseStoreItemFromMemoryItem converts an adapter-level memory item into its
// canonicalized storefs representation (trimmed ID/text, hash backfilled).
func denseStoreItemFromMemoryItem(item adapters.MemoryItem) storefs.MemoryItem {
	return denseCanonicalStoreItem(storefs.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	})
}
|
||||
|
||||
// denseMemoryItemFromStore converts a storefs item into the adapter-level
// memory item, canonicalizing it first so callers always see trimmed fields
// and a populated hash.
func denseMemoryItemFromStore(item storefs.MemoryItem) adapters.MemoryItem {
	item = denseCanonicalStoreItem(item)
	return adapters.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	}
}
|
||||
|
||||
// denseResultToItem maps a Qdrant search hit back to an adapter memory item.
// When the payload carries a source_entry_id, that id replaces the Qdrant
// point id so callers always address memories by their source-file identity.
func denseResultToItem(r qdrantclient.SearchResult) adapters.MemoryItem {
	item := adapters.MemoryItem{
		ID:    r.ID,
		Score: r.Score,
	}
	if r.Payload != nil {
		if sourceID := strings.TrimSpace(r.Payload["source_entry_id"]); sourceID != "" {
			item.ID = sourceID
		}
		item.Memory = r.Payload["memory"]
		item.Hash = r.Payload["hash"]
		item.BotID = r.Payload["bot_id"]
		item.CreatedAt = r.Payload["created_at"]
		item.UpdatedAt = r.Payload["updated_at"]
	}
	return item
}
|
||||
|
||||
// denseRuntimeHash returns the hex-encoded SHA-256 digest of the
// whitespace-trimmed text, used as a stable content fingerprint.
func denseRuntimeHash(text string) string {
	trimmed := strings.TrimSpace(text)
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:])
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
)
|
||||
|
||||
// BuiltinMemoryMode represents the operating mode of the built-in memory provider.
type BuiltinMemoryMode string

const (
	// ModeOff disables the vector index; the file-based runtime is used as-is.
	ModeOff BuiltinMemoryMode = "off"
	// ModeSparse indexes memories as sparse vectors in Qdrant.
	ModeSparse BuiltinMemoryMode = "sparse"
	// ModeDense indexes memories as dense embeddings in Qdrant.
	ModeDense BuiltinMemoryMode = "dense"
)
|
||||
|
||||
// NewBuiltinRuntimeFromConfig returns the appropriate memoryRuntime based on the
// provider's persisted config (memory_mode field). Falls back to the file runtime
// for "off" or unknown modes, and also when the sparse/dense runtime fails to
// initialize (the failure is logged as a warning, never returned).
func NewBuiltinRuntimeFromConfig(log *slog.Logger, providerConfig map[string]any, fileRuntime any, store *storefs.Service, queries *dbsqlc.Queries, cfg config.Config) (any, error) {
	mode := BuiltinMemoryMode(strings.TrimSpace(adapters.StringFromConfig(providerConfig, "memory_mode")))

	switch mode {
	case ModeSparse:
		// Derive the gRPC host/port from the configured (HTTP) Qdrant URL,
		// with sensible defaults for a local stack.
		host, port := parseQdrantHostPort(cfg.Qdrant.BaseURL)
		if host == "" {
			host = "localhost"
		}
		if port == 0 {
			port = 6334
		}
		collection := adapters.StringFromConfig(providerConfig, "qdrant_collection")
		if collection == "" {
			collection = "memory_sparse"
		}
		rt, err := newSparseRuntime(
			host,
			port,
			cfg.Qdrant.APIKey,
			collection,
			strings.TrimSpace(cfg.Sparse.BaseURL),
			store,
		)
		if err != nil {
			if log != nil {
				log.Warn("sparse runtime init failed, falling back to file runtime", slog.Any("error", err))
			}
			return fileRuntime, nil
		}
		return rt, nil

	case ModeDense:
		rt, err := newDenseRuntime(providerConfig, queries, cfg, store)
		if err != nil {
			if log != nil {
				log.Warn("dense runtime init failed, falling back to file runtime", slog.Any("error", err))
			}
			return fileRuntime, nil
		}
		return rt, nil

	default:
		// "off", empty, or unrecognized: keep the plain file runtime.
		return fileRuntime, nil
	}
}
|
||||
|
||||
// parseQdrantHostPort extracts host and gRPC port from a Qdrant base URL.
// Qdrant base URLs are typically HTTP (port 6333), but the gRPC port is 6334.
//
// The well-known HTTP port 6333 is translated to the gRPC port 6334; any other
// explicit port is assumed to already be the intended gRPC port. When no port
// is present (or it fails to parse), 6334 is returned. Any path component
// after the authority is discarded, so "http://localhost:6333/" and
// "http://localhost/" both yield a clean host.
func parseQdrantHostPort(baseURL string) (string, int) {
	baseURL = strings.TrimSpace(baseURL)
	if baseURL == "" {
		return "", 0
	}
	baseURL = strings.TrimPrefix(baseURL, "http://")
	baseURL = strings.TrimPrefix(baseURL, "https://")
	// Keep only the authority part; the original code left trailing path
	// segments attached to the host when no port was configured.
	hostPort, _, _ := strings.Cut(baseURL, "/")
	host, portStr, hasPort := strings.Cut(hostPort, ":")
	if !hasPort {
		return host, 6334
	}
	httpPort, err := strconv.Atoi(portStr)
	if err != nil {
		return host, 6334
	}
	switch httpPort {
	case 6333, 6334:
		// Map the HTTP port to gRPC; 6334 is already correct.
		return host, 6334
	default:
		// Common case: operator already configured the intended gRPC port.
		return host, httpPort
	}
}
|
||||
@@ -0,0 +1,742 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
"github.com/memohai/memoh/internal/memory/sparse"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
)
|
||||
|
||||
// sparseEncoder abstracts the sparse-embedding service used to turn memory
// text into sparse vectors for indexing (documents) and retrieval (queries).
type sparseEncoder interface {
	// EncodeDocument encodes a single document text.
	EncodeDocument(ctx context.Context, text string) (*sparse.SparseVector, error)
	// EncodeDocuments batch-encodes document texts, one vector per input.
	EncodeDocuments(ctx context.Context, texts []string) ([]sparse.SparseVector, error)
	// EncodeQuery encodes a search query (may use a different weighting than documents).
	EncodeQuery(ctx context.Context, text string) (*sparse.SparseVector, error)
	// Health reports whether the encoder service is reachable.
	Health(ctx context.Context) error
}
|
||||
|
||||
// sparseIndex abstracts the Qdrant operations the sparse runtime needs:
// collection lifecycle, point upsert/search/scroll, and deletion by point id
// or by bot.
type sparseIndex interface {
	CollectionName() string
	CollectionExists(ctx context.Context) (bool, error)
	EnsureCollection(ctx context.Context) error
	Upsert(ctx context.Context, id string, vec qdrantclient.SparseVector, payload map[string]string) error
	Search(ctx context.Context, vec qdrantclient.SparseVector, botID string, limit int) ([]qdrantclient.SearchResult, error)
	Scroll(ctx context.Context, botID string, limit int) ([]qdrantclient.SearchResult, error)
	Count(ctx context.Context, botID string) (int, error)
	DeleteByIDs(ctx context.Context, ids []string) error
	DeleteByBotID(ctx context.Context, botID string) error
}
|
||||
|
||||
// sparseMemoryStore abstracts the markdown file store that is the source of
// truth for memories; the Qdrant index is always derived from it.
type sparseMemoryStore interface {
	PersistMemories(ctx context.Context, botID string, items []storefs.MemoryItem, filters map[string]any) error
	ReadAllMemoryFiles(ctx context.Context, botID string) ([]storefs.MemoryItem, error)
	RemoveMemories(ctx context.Context, botID string, ids []string) error
	RemoveAllMemories(ctx context.Context, botID string) error
	RebuildFiles(ctx context.Context, botID string, items []storefs.MemoryItem, filters map[string]any) error
	SyncOverview(ctx context.Context, botID string) error
	CountMemoryFiles(ctx context.Context, botID string) (int, error)
}
|
||||
|
||||
// sparseRuntime implements memoryRuntime with markdown files as the source of
// truth and Qdrant as a derived sparse index used for retrieval.
type sparseRuntime struct {
	qdrant  sparseIndex       // derived retrieval index
	encoder sparseEncoder     // sparse embedding service client
	store   sparseMemoryStore // markdown file store (source of truth)
}
|
||||
|
||||
const (
	// sparseExplainTopKLimit caps the number of top-weight buckets reported in
	// per-item explain stats (used further down in sparseExplainStats).
	sparseExplainTopKLimit = 24
)
|
||||
|
||||
// newSparseRuntime wires a sparse runtime from a Qdrant gRPC endpoint, a
// collection name, the sparse-encoder service base URL, and the markdown
// memory store. All three dependencies are required.
func newSparseRuntime(qdrantHost string, qdrantPort int, qdrantAPIKey, collection, encoderBaseURL string, store *storefs.Service) (*sparseRuntime, error) {
	if strings.TrimSpace(qdrantHost) == "" {
		return nil, errors.New("sparse runtime: qdrant host is required")
	}
	if strings.TrimSpace(encoderBaseURL) == "" {
		return nil, errors.New("sparse runtime: sparse.base_url is required")
	}
	if store == nil {
		return nil, errors.New("sparse runtime: memory store is required")
	}
	qClient, err := qdrantclient.NewClient(qdrantHost, qdrantPort, qdrantAPIKey, collection)
	if err != nil {
		return nil, fmt.Errorf("sparse runtime: %w", err)
	}
	return &sparseRuntime{
		qdrant:  qClient,
		encoder: sparse.NewClient(encoderBaseURL),
		store:   store,
	}, nil
}
|
||||
|
||||
// ensureCollection creates the sparse Qdrant collection if it does not exist yet.
func (r *sparseRuntime) ensureCollection(ctx context.Context) error {
	return r.qdrant.EnsureCollection(ctx)
}
|
||||
|
||||
// Mode reports the built-in memory mode implemented by this runtime ("sparse").
func (*sparseRuntime) Mode() string {
	return string(ModeSparse)
}
|
||||
|
||||
// Add persists a new memory to the markdown store (source of truth) and then
// indexes it into Qdrant. The memory text comes from req.Message/req.Messages;
// an empty text is rejected. Returns the single created item.
func (r *sparseRuntime) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	text := sparseRuntimeText(req.Message, req.Messages)
	if text == "" {
		return adapters.SearchResponse{}, errors.New("sparse runtime: message is required")
	}

	now := time.Now().UTC().Format(time.RFC3339)
	item := adapters.MemoryItem{
		ID:        sparseRuntimeMemoryID(botID, time.Now().UTC()),
		Memory:    text,
		Hash:      sparseRuntimeHash(text),
		Metadata:  req.Metadata,
		BotID:     botID,
		CreatedAt: now,
		UpdatedAt: now,
	}
	// Files first: the index is derived and can always be rebuilt from them.
	if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{sparseStoreItemFromMemoryItem(item)}, req.Filters); err != nil {
		return adapters.SearchResponse{}, err
	}
	if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{sparseStoreItemFromMemoryItem(item)}); err != nil {
		return adapters.SearchResponse{}, err
	}
	return adapters.SearchResponse{Results: []adapters.MemoryItem{item}}, nil
}
|
||||
|
||||
func (r *sparseRuntime) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
if err := r.ensureCollection(ctx); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
|
||||
limit := req.Limit
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
vec, err := r.encoder.EncodeQuery(ctx, req.Query)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, fmt.Errorf("sparse encode query: %w", err)
|
||||
}
|
||||
results, err := r.qdrant.Search(ctx, qdrantclient.SparseVector{
|
||||
Indices: vec.Indices,
|
||||
Values: vec.Values,
|
||||
}, botID, limit)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
items := make([]adapters.MemoryItem, 0, len(results))
|
||||
for _, r := range results {
|
||||
items = append(items, sparseResultToItem(r))
|
||||
}
|
||||
return adapters.SearchResponse{Results: items}, nil
|
||||
}
|
||||
|
||||
// GetAll lists every memory for the bot straight from the markdown store
// (not the index), enriches items with best-effort explain stats, sorts them
// newest-first by UpdatedAt, and truncates to req.Limit when positive.
func (r *sparseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.SearchResponse{}, err
	}
	result := make([]adapters.MemoryItem, 0, len(items))
	for _, item := range items {
		mem := sparseMemoryItemFromStore(item)
		mem.BotID = botID
		result = append(result, mem)
	}
	// Best-effort: populateExplainStats silently skips items on encoder errors.
	r.populateExplainStats(ctx, sparseMemoryItemPointers(result))
	sort.Slice(result, func(i, j int) bool { return result[i].UpdatedAt > result[j].UpdatedAt })
	if req.Limit > 0 && len(result) > req.Limit {
		result = result[:req.Limit]
	}
	return adapters.SearchResponse{Results: result}, nil
}
|
||||
|
||||
// Update rewrites the text of an existing memory. The owning bot is derived
// from the memory id itself (sparseRuntimeBotIDFromMemoryID), the item is
// located by scanning the bot's markdown files, then both the file store and
// the Qdrant index are refreshed.
func (r *sparseRuntime) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
	memoryID := strings.TrimSpace(req.MemoryID)
	if memoryID == "" {
		return adapters.MemoryItem{}, errors.New("sparse runtime: memory_id is required")
	}
	text := strings.TrimSpace(req.Memory)
	if text == "" {
		return adapters.MemoryItem{}, errors.New("sparse runtime: memory is required")
	}
	botID := sparseRuntimeBotIDFromMemoryID(memoryID)
	if botID == "" {
		return adapters.MemoryItem{}, errors.New("sparse runtime: invalid memory_id")
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryItem{}, err
	}
	// Linear scan for the target entry; copy it so the slice stays untouched.
	var existing *storefs.MemoryItem
	for i := range items {
		if strings.TrimSpace(items[i].ID) == memoryID {
			item := items[i]
			existing = &item
			break
		}
	}
	if existing == nil {
		return adapters.MemoryItem{}, errors.New("sparse runtime: memory not found")
	}
	existing.Memory = text
	existing.Hash = sparseRuntimeHash(text)
	existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	// Files first, then index, mirroring Add.
	if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{*existing}, nil); err != nil {
		return adapters.MemoryItem{}, err
	}
	if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{*existing}); err != nil {
		return adapters.MemoryItem{}, err
	}
	item := sparseMemoryItemFromStore(*existing)
	item.BotID = botID
	return item, nil
}
|
||||
|
||||
// Delete removes a single memory; it is a convenience wrapper over DeleteBatch.
func (r *sparseRuntime) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
	return r.DeleteBatch(ctx, []string{memoryID})
}
|
||||
|
||||
// DeleteBatch removes the given memories from both the markdown store and the
// Qdrant index. IDs are grouped by their owning bot (derived from the id);
// blank ids and ids whose bot cannot be derived are silently skipped.
func (r *sparseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
	grouped := map[string][]string{}
	pointIDs := make([]string, 0, len(memoryIDs))
	for _, rawID := range memoryIDs {
		memoryID := strings.TrimSpace(rawID)
		if memoryID == "" {
			continue
		}
		botID := sparseRuntimeBotIDFromMemoryID(memoryID)
		if botID == "" {
			continue
		}
		grouped[botID] = append(grouped[botID], memoryID)
		pointIDs = append(pointIDs, sparsePointID(botID, memoryID))
	}
	// Remove from the source-of-truth files first, per bot.
	for botID, ids := range grouped {
		if err := r.store.RemoveMemories(ctx, botID, ids); err != nil {
			return adapters.DeleteResponse{}, err
		}
	}
	if err := r.ensureCollection(ctx); err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.qdrant.DeleteByIDs(ctx, pointIDs); err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "Memories deleted successfully!"}, nil
}
|
||||
|
||||
// DeleteAll wipes every memory for the bot: all markdown files first, then all
// of the bot's points in the Qdrant index.
func (r *sparseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
	botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
	if err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.store.RemoveAllMemories(ctx, botID); err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.ensureCollection(ctx); err != nil {
		return adapters.DeleteResponse{}, err
	}
	if err := r.qdrant.DeleteByBotID(ctx, botID); err != nil {
		return adapters.DeleteResponse{}, err
	}
	return adapters.DeleteResponse{Message: "All memories deleted successfully!"}, nil
}
|
||||
|
||||
// Compact shrinks the bot's memory set to roughly ratio * current size,
// keeping the most recently updated items (at least one is always kept).
// The markdown files are rebuilt from the kept set and the Qdrant index is
// re-synced via Rebuild. The third parameter (a count) is unused here.
func (r *sparseRuntime) Compact(ctx context.Context, filters map[string]any, ratio float64, _ int) (adapters.CompactResult, error) {
	botID, err := sparseRuntimeBotID("", filters)
	if err != nil {
		return adapters.CompactResult{}, err
	}
	all, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.CompactResult{}, err
	}
	before := len(all)
	if before == 0 {
		return adapters.CompactResult{Ratio: ratio}, nil
	}

	// Newest-first so truncation drops the oldest entries.
	sort.Slice(all, func(i, j int) bool {
		return all[i].UpdatedAt > all[j].UpdatedAt
	})
	// Clamp the target into [1, before]; a ratio >= 1 is a no-op.
	target := int(float64(before) * ratio)
	if target < 1 {
		target = 1
	}
	if target > before {
		target = before
	}
	keptStore := append([]storefs.MemoryItem(nil), all[:target]...)
	if err := r.store.RebuildFiles(ctx, botID, keptStore, filters); err != nil {
		return adapters.CompactResult{}, err
	}
	if _, err := r.Rebuild(ctx, botID); err != nil {
		return adapters.CompactResult{}, err
	}
	kept := make([]adapters.MemoryItem, 0, len(keptStore))
	for _, item := range keptStore {
		kept = append(kept, sparseMemoryItemFromStore(item))
	}
	return adapters.CompactResult{
		BeforeCount: before,
		AfterCount:  len(kept),
		Ratio:       ratio,
		Results:     kept,
	}, nil
}
|
||||
|
||||
// Usage reports memory usage for the bot computed from the markdown files:
// item count, total and average text bytes. Storage size is estimated as the
// raw text size (index overhead is not accounted for).
func (r *sparseRuntime) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
	botID, err := sparseRuntimeBotID("", filters)
	if err != nil {
		return adapters.UsageResponse{}, err
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.UsageResponse{}, err
	}
	var usage adapters.UsageResponse
	usage.Count = len(items)
	for _, item := range items {
		usage.TotalTextBytes += int64(len(item.Memory))
	}
	if usage.Count > 0 {
		usage.AvgTextBytes = usage.TotalTextBytes / int64(usage.Count)
	}
	usage.EstimatedStorageBytes = usage.TotalTextBytes
	return usage, nil
}
|
||||
|
||||
// Status reports the health of the sparse runtime for a bot: file counts from
// the markdown store, encoder and Qdrant reachability, and the indexed point
// count. Dependency failures are embedded in the response rather than
// returned as an error, so the caller always gets a status snapshot.
func (r *sparseRuntime) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
	fileCount, err := r.store.CountMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.MemoryStatusResponse{}, err
	}
	status := adapters.MemoryStatusResponse{
		ProviderType:      BuiltinType,
		MemoryMode:        string(ModeSparse),
		CanManualSync:     true,
		SourceDir:         path.Join(config.DefaultDataMount, "memory"),
		OverviewPath:      path.Join(config.DefaultDataMount, "MEMORY.md"),
		MarkdownFileCount: fileCount,
		SourceCount:       len(items),
		QdrantCollection:  r.qdrant.CollectionName(),
	}
	// Encoder health is reported, not fatal.
	if err := r.encoder.Health(ctx); err != nil {
		status.Encoder.Error = err.Error()
	} else {
		status.Encoder.OK = true
	}
	// Qdrant health likewise: errors are surfaced in the status payload.
	exists, err := r.qdrant.CollectionExists(ctx)
	if err != nil {
		status.Qdrant.Error = err.Error()
		return status, nil
	}
	status.Qdrant.OK = true
	if exists {
		count, err := r.qdrant.Count(ctx, botID)
		if err != nil {
			status.Qdrant.OK = false
			status.Qdrant.Error = err.Error()
			return status, nil
		}
		status.IndexedCount = count
	}
	return status, nil
}
|
||||
|
||||
// Rebuild re-derives the Qdrant index from the bot's markdown files and
// refreshes the MEMORY.md overview. Files remain the source of truth; the
// index is reconciled against them by syncSourceItems.
func (r *sparseRuntime) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
	items, err := r.store.ReadAllMemoryFiles(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	if err := r.store.SyncOverview(ctx, botID); err != nil {
		return adapters.RebuildResult{}, err
	}
	return r.syncSourceItems(ctx, botID, items)
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
// syncSourceItems reconciles the Qdrant index with the given file-derived
// items for one bot. It upserts items that are missing from the index or
// whose payload drifted, deletes stale points that no longer exist in the
// files, and returns counts describing the reconciliation.
//
// NOTE(review): Scroll is capped at 10000 points per bot — bots with more
// indexed memories than that would be only partially reconciled; confirm the
// limit against expected data volumes.
func (r *sparseRuntime) syncSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) (adapters.RebuildResult, error) {
	if err := r.ensureCollection(ctx); err != nil {
		return adapters.RebuildResult{}, err
	}
	existing, err := r.qdrant.Scroll(ctx, botID, 10000)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	// Index existing points by their source entry id (fallback: point id).
	existingBySource := make(map[string]qdrantclient.SearchResult, len(existing))
	for _, item := range existing {
		sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
		if sourceID == "" {
			sourceID = strings.TrimSpace(item.ID)
		}
		if sourceID == "" {
			continue
		}
		existingBySource[sourceID] = item
	}
	// Decide which file items need (re-)indexing.
	canonical := make([]storefs.MemoryItem, 0, len(items))
	sourceIDs := make(map[string]struct{}, len(items))
	toUpsert := make([]storefs.MemoryItem, 0, len(items))
	missingCount := 0
	restoredCount := 0
	for _, item := range items {
		item = sparseCanonicalStoreItem(item)
		if item.ID == "" || item.Memory == "" {
			continue
		}
		canonical = append(canonical, item)
		sourceIDs[item.ID] = struct{}{}
		payload := sparsePayload(botID, item)
		existingItem, ok := existingBySource[item.ID]
		if !ok {
			// Not indexed at all.
			missingCount++
			restoredCount++
			toUpsert = append(toUpsert, item)
			continue
		}
		if !sparsePayloadMatches(existingItem.Payload, payload) {
			// Indexed but stale payload (e.g. edited text).
			restoredCount++
			toUpsert = append(toUpsert, item)
		}
	}
	// Points whose source entry no longer exists in the files are stale.
	stalePointIDs := make([]string, 0)
	for _, item := range existing {
		sourceID := strings.TrimSpace(item.Payload["source_entry_id"])
		if sourceID == "" {
			sourceID = strings.TrimSpace(item.ID)
		}
		if _, ok := sourceIDs[sourceID]; ok {
			continue
		}
		if strings.TrimSpace(item.ID) != "" {
			stalePointIDs = append(stalePointIDs, item.ID)
		}
	}
	if len(stalePointIDs) > 0 {
		if err := r.qdrant.DeleteByIDs(ctx, stalePointIDs); err != nil {
			return adapters.RebuildResult{}, err
		}
	}
	if err := r.upsertSourceItems(ctx, botID, toUpsert); err != nil {
		return adapters.RebuildResult{}, err
	}
	count, err := r.qdrant.Count(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	return adapters.RebuildResult{
		FsCount:       len(canonical),
		StorageCount:  count,
		MissingCount:  missingCount,
		RestoredCount: restoredCount,
	}, nil
}
|
||||
|
||||
// upsertSourceItems encodes the given store items as sparse vectors and
// upserts them into Qdrant. Items with an empty ID or memory text are skipped;
// the collection is created if needed.
func (r *sparseRuntime) upsertSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) error {
	if len(items) == 0 {
		return nil
	}
	if err := r.ensureCollection(ctx); err != nil {
		return err
	}
	// Canonicalize first so texts align 1:1 with the items being upserted.
	texts := make([]string, 0, len(items))
	canonical := make([]storefs.MemoryItem, 0, len(items))
	for _, item := range items {
		item = sparseCanonicalStoreItem(item)
		if item.ID == "" || item.Memory == "" {
			continue
		}
		canonical = append(canonical, item)
		texts = append(texts, item.Memory)
	}
	if len(canonical) == 0 {
		return nil
	}
	vectors, err := r.encoder.EncodeDocuments(ctx, texts)
	if err != nil {
		return fmt.Errorf("sparse encode documents: %w", err)
	}
	// A length mismatch means the encoder dropped inputs; abort rather than
	// misalign vectors with items.
	if len(vectors) != len(canonical) {
		return fmt.Errorf("sparse encode documents: expected %d vectors, got %d", len(canonical), len(vectors))
	}
	for i, item := range canonical {
		vec := vectors[i]
		if err := r.qdrant.Upsert(ctx, sparsePointID(botID, item.ID), qdrantclient.SparseVector{
			Indices: vec.Indices,
			Values:  vec.Values,
		}, sparsePayload(botID, item)); err != nil {
			return err
		}
	}
	return nil
}
|
||||
|
||||
func sparseResultToItem(r qdrantclient.SearchResult) adapters.MemoryItem {
|
||||
item := adapters.MemoryItem{
|
||||
ID: r.ID,
|
||||
Score: r.Score,
|
||||
}
|
||||
if r.Payload != nil {
|
||||
if sourceID := strings.TrimSpace(r.Payload["source_entry_id"]); sourceID != "" {
|
||||
item.ID = sourceID
|
||||
}
|
||||
item.Memory = r.Payload["memory"]
|
||||
item.Hash = r.Payload["hash"]
|
||||
item.BotID = r.Payload["bot_id"]
|
||||
item.CreatedAt = r.Payload["created_at"]
|
||||
item.UpdatedAt = r.Payload["updated_at"]
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
// populateExplainStats re-encodes each item's memory text and attaches
// top-k bucket and CDF explain statistics in place. Failures are
// deliberately swallowed: explain stats are best-effort metadata and
// must not fail the surrounding read path.
func (r *sparseRuntime) populateExplainStats(ctx context.Context, items []*adapters.MemoryItem) {
	if len(items) == 0 {
		return
	}
	// Keep texts and their target items aligned by index, skipping nil
	// entries and blank memories.
	texts := make([]string, 0, len(items))
	targets := make([]*adapters.MemoryItem, 0, len(items))
	for _, item := range items {
		if item == nil || strings.TrimSpace(item.Memory) == "" {
			continue
		}
		texts = append(texts, item.Memory)
		targets = append(targets, item)
	}
	if len(texts) == 0 {
		return
	}
	vectors, err := r.encoder.EncodeDocuments(ctx, texts)
	if err != nil || len(vectors) != len(targets) {
		// Best-effort: leave items without explain stats on any error
		// or on a vector/item count mismatch.
		return
	}
	for i := range targets {
		topK, cdf := sparseExplainStats(vectors[i])
		targets[i].TopKBuckets = topK
		targets[i].CDFCurve = cdf
	}
}
|
||||
|
||||
// sparseExplainStats derives explain statistics from a sparse vector:
// the top-k (index, value) buckets sorted by descending weight, and a
// cumulative-distribution curve over all positive weights. Zero and
// negative weights are ignored; both results are nil when the vector
// has no positive weights.
func sparseExplainStats(vec sparse.SparseVector) ([]adapters.TopKBucket, []adapters.CDFPoint) {
	type pair struct {
		index uint32
		value float32
	}
	pairs := make([]pair, 0, len(vec.Values))
	for i, value := range vec.Values {
		// Guard against malformed vectors where Values outruns Indices.
		if i >= len(vec.Indices) || value <= 0 {
			continue
		}
		pairs = append(pairs, pair{index: vec.Indices[i], value: value})
	}
	if len(pairs) == 0 {
		return nil, nil
	}
	// Sort by descending weight; ties break on ascending index so the
	// ordering is deterministic.
	sort.Slice(pairs, func(i, j int) bool {
		if pairs[i].value == pairs[j].value {
			return pairs[i].index < pairs[j].index
		}
		return pairs[i].value > pairs[j].value
	})
	topN := len(pairs)
	if topN > sparseExplainTopKLimit {
		topN = sparseExplainTopKLimit
	}
	topK := make([]adapters.TopKBucket, 0, topN)
	total := 0.0
	for _, pair := range pairs {
		total += float64(pair.value)
	}
	for _, pair := range pairs[:topN] {
		topK = append(topK, adapters.TopKBucket{
			Index: pair.index,
			Value: pair.value,
		})
	}
	cdf := make([]adapters.CDFPoint, 0, len(pairs))
	if total <= 0 {
		// No positive mass: return the buckets with an empty curve.
		return topK, cdf
	}
	running := 0.0
	for i, pair := range pairs {
		running += float64(pair.value)
		cdf = append(cdf, adapters.CDFPoint{
			K:          i + 1,
			Cumulative: running / total,
		})
	}
	return topK, cdf
}
|
||||
|
||||
func sparseMemoryItemPointers(items []adapters.MemoryItem) []*adapters.MemoryItem {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
pointers := make([]*adapters.MemoryItem, 0, len(items))
|
||||
for i := range items {
|
||||
pointers = append(pointers, &items[i])
|
||||
}
|
||||
return pointers
|
||||
}
|
||||
|
||||
func sparseCanonicalStoreItem(item storefs.MemoryItem) storefs.MemoryItem {
|
||||
item.ID = strings.TrimSpace(item.ID)
|
||||
item.Memory = strings.TrimSpace(item.Memory)
|
||||
if item.Memory != "" && strings.TrimSpace(item.Hash) == "" {
|
||||
item.Hash = sparseRuntimeHash(item.Memory)
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
// sparsePayload builds the Qdrant payload stored with a point: memory
// text, bot scope, the originating store entry ID, the content hash,
// and timestamps when present. The item is canonicalized first so the
// payload always reflects trimmed values and a populated hash.
func sparsePayload(botID string, item storefs.MemoryItem) map[string]string {
	item = sparseCanonicalStoreItem(item)
	payload := map[string]string{
		"memory":          item.Memory,
		"bot_id":          strings.TrimSpace(botID),
		"source_entry_id": item.ID,
		"hash":            item.Hash,
	}
	if item.CreatedAt != "" {
		payload["created_at"] = item.CreatedAt
	}
	if item.UpdatedAt != "" {
		payload["updated_at"] = item.UpdatedAt
	}
	return payload
}
|
||||
|
||||
// sparsePayloadMatches reports whether every expected key/value pair
// is present in existing, comparing values with surrounding whitespace
// ignored. Keys present only in existing are not checked.
func sparsePayloadMatches(existing, expected map[string]string) bool {
	for key, want := range expected {
		got := existing[key]
		if strings.TrimSpace(got) != strings.TrimSpace(want) {
			return false
		}
	}
	return true
}
|
||||
|
||||
// sparseMemoryItemFromStore converts a canonicalized store item into
// the adapter-facing memory item, copying fields one-to-one.
func sparseMemoryItemFromStore(item storefs.MemoryItem) adapters.MemoryItem {
	item = sparseCanonicalStoreItem(item)
	return adapters.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	}
}
|
||||
|
||||
// sparseStoreItemFromMemoryItem converts an adapter memory item back
// into the canonicalized store representation, copying fields
// one-to-one (the inverse of sparseMemoryItemFromStore).
func sparseStoreItemFromMemoryItem(item adapters.MemoryItem) storefs.MemoryItem {
	return sparseCanonicalStoreItem(storefs.MemoryItem{
		ID:        item.ID,
		Memory:    item.Memory,
		Hash:      item.Hash,
		CreatedAt: item.CreatedAt,
		UpdatedAt: item.UpdatedAt,
		Score:     item.Score,
		Metadata:  item.Metadata,
		BotID:     item.BotID,
		AgentID:   item.AgentID,
		RunID:     item.RunID,
	})
}
|
||||
|
||||
func sparseRuntimeText(message string, messages []adapters.Message) string {
|
||||
text := strings.TrimSpace(message)
|
||||
if text == "" && len(messages) > 0 {
|
||||
parts := make([]string, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
role := strings.ToUpper(strings.TrimSpace(m.Role))
|
||||
if role == "" {
|
||||
role = "MESSAGE"
|
||||
}
|
||||
parts = append(parts, "["+role+"] "+content)
|
||||
}
|
||||
text = strings.Join(parts, "\n")
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func sparseRuntimeMemoryID(botID string, now time.Time) string {
|
||||
return botID + ":" + "mem_" + strconv.FormatInt(now.UnixNano(), 10)
|
||||
}
|
||||
|
||||
// sparseRuntimeHash returns the hex-encoded SHA-256 digest of the
// whitespace-trimmed text, used to detect content changes.
func sparseRuntimeHash(text string) string {
	trimmed := strings.TrimSpace(text)
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:])
}
|
||||
|
||||
func sparseRuntimeBotID(botID string, filters map[string]any) (string, error) {
|
||||
botID = strings.TrimSpace(botID)
|
||||
if botID == "" {
|
||||
botID = strings.TrimSpace(sparseRuntimeAny(filters, "bot_id"))
|
||||
}
|
||||
if botID == "" {
|
||||
botID = strings.TrimSpace(sparseRuntimeAny(filters, "scopeId"))
|
||||
}
|
||||
if botID == "" {
|
||||
return "", errors.New("bot_id is required")
|
||||
}
|
||||
return botID, nil
|
||||
}
|
||||
|
||||
// sparseRuntimeBotIDFromMemoryID extracts the bot identifier from a
// memory ID of the form "<botID>:<entry>". It returns "" when the ID
// contains no colon separator.
func sparseRuntimeBotIDFromMemoryID(memoryID string) string {
	botID, _, found := strings.Cut(strings.TrimSpace(memoryID), ":")
	if !found {
		return ""
	}
	return strings.TrimSpace(botID)
}
|
||||
|
||||
// sparseRuntimeAny reads key from the map and renders the value as a
// trimmed string via fmt.Sprint; it returns "" for nil maps, missing
// keys, and nil values.
func sparseRuntimeAny(m map[string]any, key string) string {
	if m == nil {
		return ""
	}
	value, present := m[key]
	if !present || value == nil {
		return ""
	}
	return strings.TrimSpace(fmt.Sprint(value))
}
|
||||
|
||||
// sparsePointID derives a deterministic Qdrant point UUID (SHA-1
// name-based, URL namespace) from the trimmed bot ID and source entry
// ID, so re-upserting the same source entry always targets the same
// point. The "\n" separator keeps (botID, sourceID) pairs unambiguous.
func sparsePointID(botID, sourceID string) string {
	return uuid.NewSHA1(uuid.NameSpaceURL, []byte(strings.TrimSpace(botID)+"\n"+strings.TrimSpace(sourceID))).String()
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
"github.com/memohai/memoh/internal/memory/sparse"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
)
|
||||
|
||||
// fakeSparseStore is an in-memory stand-in for the markdown source
// store used by the sparse runtime tests, keyed by memory item ID.
type fakeSparseStore struct {
	items map[string]storefs.MemoryItem
}

// newFakeSparseStore seeds the fake store with the given items.
func newFakeSparseStore(items ...storefs.MemoryItem) *fakeSparseStore {
	store := &fakeSparseStore{items: map[string]storefs.MemoryItem{}}
	for _, item := range items {
		store.items[item.ID] = item
	}
	return store
}

// PersistMemories upserts the given items into the fake store.
func (s *fakeSparseStore) PersistMemories(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
	for _, item := range items {
		s.items[item.ID] = item
	}
	return nil
}

// ReadAllMemoryFiles returns all stored items sorted by ID so test
// assertions are deterministic.
func (s *fakeSparseStore) ReadAllMemoryFiles(_ context.Context, _ string) ([]storefs.MemoryItem, error) {
	out := make([]storefs.MemoryItem, 0, len(s.items))
	for _, item := range s.items {
		out = append(out, item)
	}
	sort.Slice(out, func(i, j int) bool { return out[i].ID < out[j].ID })
	return out, nil
}

// RemoveMemories deletes the given IDs (whitespace-trimmed) if present.
func (s *fakeSparseStore) RemoveMemories(_ context.Context, _ string, ids []string) error {
	for _, id := range ids {
		delete(s.items, strings.TrimSpace(id))
	}
	return nil
}

// RemoveAllMemories clears the store.
func (s *fakeSparseStore) RemoveAllMemories(_ context.Context, _ string) error {
	s.items = map[string]storefs.MemoryItem{}
	return nil
}

// RebuildFiles replaces the store contents with the given items.
func (s *fakeSparseStore) RebuildFiles(_ context.Context, _ string, items []storefs.MemoryItem, _ map[string]any) error {
	s.items = map[string]storefs.MemoryItem{}
	for _, item := range items {
		s.items[item.ID] = item
	}
	return nil
}

// SyncOverview is a no-op in the fake.
func (*fakeSparseStore) SyncOverview(context.Context, string) error { return nil }

// CountMemoryFiles reports 1 when any items exist and 0 otherwise —
// the fake models everything as a single virtual file.
func (s *fakeSparseStore) CountMemoryFiles(_ context.Context, _ string) (int, error) {
	if len(s.items) == 0 {
		return 0, nil
	}
	return 1, nil
}
|
||||
|
||||
type fakeSparseEncoder struct {
|
||||
lastQuery string
|
||||
}
|
||||
|
||||
func (*fakeSparseEncoder) EncodeDocument(_ context.Context, _ string) (*sparse.SparseVector, error) {
|
||||
return &sparse.SparseVector{Indices: []uint32{1, 2, 3}, Values: []float32{1, 3, 2}}, nil
|
||||
}
|
||||
|
||||
func (*fakeSparseEncoder) EncodeDocuments(_ context.Context, texts []string) ([]sparse.SparseVector, error) {
|
||||
out := make([]sparse.SparseVector, 0, len(texts))
|
||||
for _, text := range texts {
|
||||
_ = text
|
||||
out = append(out, sparse.SparseVector{Indices: []uint32{1, 2, 3}, Values: []float32{1, 3, 2}})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (e *fakeSparseEncoder) EncodeQuery(_ context.Context, text string) (*sparse.SparseVector, error) {
|
||||
e.lastQuery = text
|
||||
return &sparse.SparseVector{Indices: []uint32{9}, Values: []float32{1}}, nil
|
||||
}
|
||||
|
||||
func (*fakeSparseEncoder) Health(context.Context) error { return nil }
|
||||
|
||||
// fakeSparseIndex is an in-memory stand-in for the Qdrant sparse
// index. Search matches by case-insensitive substring against the
// encoder's last recorded query rather than by vector similarity.
type fakeSparseIndex struct {
	encoder    *fakeSparseEncoder
	collection string
	exists     bool
	points     map[string]qdrantclient.SearchResult
}

// newFakeSparseIndex builds an empty fake index bound to the encoder.
func newFakeSparseIndex(encoder *fakeSparseEncoder) *fakeSparseIndex {
	return &fakeSparseIndex{
		encoder:    encoder,
		collection: "memory_sparse_test",
		points:     map[string]qdrantclient.SearchResult{},
	}
}

func (i *fakeSparseIndex) CollectionName() string { return i.collection }

func (i *fakeSparseIndex) CollectionExists(context.Context) (bool, error) { return i.exists, nil }

func (i *fakeSparseIndex) EnsureCollection(context.Context) error {
	i.exists = true
	return nil
}

// Upsert stores the payload under the point ID with a fixed score of 1.
func (i *fakeSparseIndex) Upsert(_ context.Context, id string, _ qdrantclient.SparseVector, payload map[string]string) error {
	i.exists = true
	i.points[id] = qdrantclient.SearchResult{
		ID:      id,
		Score:   1,
		Payload: payload,
	}
	return nil
}

// Search filters points by bot and by substring match on the last
// encoded query, returning results sorted by ID and capped at limit.
func (i *fakeSparseIndex) Search(_ context.Context, _ qdrantclient.SparseVector, botID string, limit int) ([]qdrantclient.SearchResult, error) {
	query := strings.ToLower(strings.TrimSpace(i.encoder.lastQuery))
	results := make([]qdrantclient.SearchResult, 0, len(i.points))
	for _, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) != strings.TrimSpace(botID) {
			continue
		}
		text := strings.ToLower(point.Payload["memory"])
		if query != "" && !strings.Contains(text, query) {
			continue
		}
		point.Score = 1
		results = append(results, point)
	}
	sort.Slice(results, func(a, b int) bool { return results[a].ID < results[b].ID })
	if limit > 0 && len(results) > limit {
		results = results[:limit]
	}
	return results, nil
}

// Scroll returns all of a bot's points sorted by ID, capped at limit.
func (i *fakeSparseIndex) Scroll(_ context.Context, botID string, limit int) ([]qdrantclient.SearchResult, error) {
	results := make([]qdrantclient.SearchResult, 0, len(i.points))
	for _, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) != strings.TrimSpace(botID) {
			continue
		}
		results = append(results, point)
	}
	sort.Slice(results, func(a, b int) bool { return results[a].ID < results[b].ID })
	if limit > 0 && len(results) > limit {
		results = results[:limit]
	}
	return results, nil
}

// Count reports how many points belong to the bot.
func (i *fakeSparseIndex) Count(_ context.Context, botID string) (int, error) {
	count := 0
	for _, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) == strings.TrimSpace(botID) {
			count++
		}
	}
	return count, nil
}

func (i *fakeSparseIndex) DeleteByIDs(_ context.Context, ids []string) error {
	for _, id := range ids {
		delete(i.points, strings.TrimSpace(id))
	}
	return nil
}

func (i *fakeSparseIndex) DeleteByBotID(_ context.Context, botID string) error {
	for id, point := range i.points {
		if strings.TrimSpace(point.Payload["bot_id"]) == strings.TrimSpace(botID) {
			delete(i.points, id)
		}
	}
	return nil
}
|
||||
|
||||
// TestSparseRuntimeAddWritesSourceAndSupportsRecall verifies that Add
// persists the memory to the source store, upserts a matching Qdrant
// point keyed by source entry ID, skips explain stats on the write
// path, and that the stored memory is recallable via Search.
func TestSparseRuntimeAddWritesSourceAndSupportsRecall(t *testing.T) {
	t.Parallel()

	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore()
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}

	resp, err := runtime.Add(context.Background(), adapters.AddRequest{
		BotID:   "bot-1",
		Message: "Ran likes oolong tea",
		Filters: map[string]any{"scopeId": "bot-1"},
	})
	if err != nil {
		t.Fatalf("Add() error = %v", err)
	}
	if len(resp.Results) != 1 {
		t.Fatalf("expected 1 add result, got %d", len(resp.Results))
	}
	item := resp.Results[0]
	if item.ID == "" {
		t.Fatal("expected source memory id to be populated")
	}
	// The source-of-truth store must contain the new entry.
	if _, ok := store.items[item.ID]; !ok {
		t.Fatalf("expected memory %q to be written to markdown source", item.ID)
	}
	// The Qdrant point must reference the source entry by ID.
	point, ok := index.points[sparsePointID("bot-1", item.ID)]
	if !ok {
		t.Fatalf("expected qdrant point for source memory %q", item.ID)
	}
	if point.Payload["source_entry_id"] != item.ID {
		t.Fatalf("expected source_entry_id payload %q, got %q", item.ID, point.Payload["source_entry_id"])
	}
	if len(item.TopKBuckets) != 0 || len(item.CDFCurve) != 0 {
		t.Fatalf("expected add response to skip explain stats, got %#v", item)
	}

	searchResp, err := runtime.Search(context.Background(), adapters.SearchRequest{
		BotID: "bot-1",
		Query: "oolong tea",
	})
	if err != nil {
		t.Fatalf("Search() error = %v", err)
	}
	if len(searchResp.Results) != 1 {
		t.Fatalf("expected 1 search result, got %d", len(searchResp.Results))
	}
	if searchResp.Results[0].ID != item.ID {
		t.Fatalf("expected search result id %q, got %q", item.ID, searchResp.Results[0].ID)
	}
	if len(searchResp.Results[0].TopKBuckets) != 0 || len(searchResp.Results[0].CDFCurve) != 0 {
		t.Fatalf("expected search result to skip explain stats, got %#v", searchResp.Results[0])
	}
}
|
||||
|
||||
// TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints verifies
// that Rebuild re-indexes entries whose stored hash diverged, restores
// entries missing from the index, and deletes points whose source
// entry no longer exists.
func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
	t.Parallel()

	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore(
		storefs.MemoryItem{
			ID:        "bot-1:mem_1",
			Memory:    "Ran likes tea",
			Hash:      sparseRuntimeHash("Ran likes tea"),
			CreatedAt: "2026-03-13T09:00:00Z",
			UpdatedAt: "2026-03-13T09:00:00Z",
		},
		storefs.MemoryItem{
			ID:        "bot-1:mem_2",
			Memory:    "Ran works in Berlin",
			Hash:      sparseRuntimeHash("Ran works in Berlin"),
			CreatedAt: "2026-03-13T10:00:00Z",
			UpdatedAt: "2026-03-13T10:00:00Z",
		},
	)
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}

	// mem_1 exists in the index but with an outdated hash; mem_2 is
	// missing from the index entirely.
	index.points[sparsePointID("bot-1", "bot-1:mem_1")] = qdrantclient.SearchResult{
		ID: sparsePointID("bot-1", "bot-1:mem_1"),
		Payload: map[string]string{
			"bot_id":          "bot-1",
			"memory":          "Ran likes tea",
			"source_entry_id": "bot-1:mem_1",
			"hash":            "outdated",
			"created_at":      "2026-03-13T09:00:00Z",
			"updated_at":      "2026-03-13T09:00:00Z",
		},
	}
	// This point has no backing source entry and must be removed.
	index.points[sparsePointID("bot-1", "bot-1:stale")] = qdrantclient.SearchResult{
		ID: sparsePointID("bot-1", "bot-1:stale"),
		Payload: map[string]string{
			"bot_id":          "bot-1",
			"memory":          "stale memory",
			"source_entry_id": "bot-1:stale",
		},
	}

	result, err := runtime.Rebuild(context.Background(), "bot-1")
	if err != nil {
		t.Fatalf("Rebuild() error = %v", err)
	}
	if result.FsCount != 2 {
		t.Fatalf("expected fs_count=2, got %d", result.FsCount)
	}
	if result.StorageCount != 2 {
		t.Fatalf("expected storage_count=2, got %d", result.StorageCount)
	}
	if result.MissingCount != 1 {
		t.Fatalf("expected missing_count=1, got %d", result.MissingCount)
	}
	if result.RestoredCount != 2 {
		t.Fatalf("expected restored_count=2, got %d", result.RestoredCount)
	}
	if _, ok := index.points[sparsePointID("bot-1", "bot-1:stale")]; ok {
		t.Fatal("expected stale qdrant point to be removed")
	}
}
|
||||
|
||||
// TestSparseRuntimeGetAllIncludesExplainStats verifies that GetAll
// attaches top-k bucket and CDF explain stats to each returned item.
// The fake encoder's largest weight (3) sits at index 2, and the CDF
// must accumulate to exactly 1.
func TestSparseRuntimeGetAllIncludesExplainStats(t *testing.T) {
	t.Parallel()

	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore(
		storefs.MemoryItem{
			ID:        "bot-1:mem_1",
			Memory:    "Ran likes tea",
			Hash:      sparseRuntimeHash("Ran likes tea"),
			CreatedAt: "2026-03-13T09:00:00Z",
			UpdatedAt: "2026-03-13T09:00:00Z",
		},
	)
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}

	resp, err := runtime.GetAll(context.Background(), adapters.GetAllRequest{BotID: "bot-1"})
	if err != nil {
		t.Fatalf("GetAll() error = %v", err)
	}
	if len(resp.Results) != 1 {
		t.Fatalf("expected 1 result, got %d", len(resp.Results))
	}
	if len(resp.Results[0].TopKBuckets) == 0 || len(resp.Results[0].CDFCurve) == 0 {
		t.Fatalf("expected get all result to include explain stats, got %#v", resp.Results[0])
	}
	if resp.Results[0].TopKBuckets[0].Index != 2 {
		t.Fatalf("expected top bucket index 2, got %d", resp.Results[0].TopKBuckets[0].Index)
	}
	if got := resp.Results[0].CDFCurve[len(resp.Results[0].CDFCurve)-1].Cumulative; got != 1 {
		t.Fatalf("expected final CDF cumulative to be 1, got %v", got)
	}
}
|
||||
|
||||
// TestBuiltinProviderMultiTurnRecallUsesSparseSourceRuntime verifies
// the end-to-end flow through the builtin provider: OnAfterChat stores
// facts from two chat rounds via the sparse runtime, and OnBeforeChat
// recalls the matching fact for each follow-up query.
func TestBuiltinProviderMultiTurnRecallUsesSparseSourceRuntime(t *testing.T) {
	t.Parallel()

	encoder := &fakeSparseEncoder{}
	index := newFakeSparseIndex(encoder)
	store := newFakeSparseStore()
	runtime := &sparseRuntime{
		qdrant:  index,
		encoder: encoder,
		store:   store,
	}
	provider := NewBuiltinProvider(slog.Default(), runtime, nil, nil)

	// Round 1: store a tea preference.
	err := provider.OnAfterChat(context.Background(), adapters.AfterChatRequest{
		BotID: "bot-1",
		Messages: []adapters.Message{
			{Role: "user", Content: "I like oolong tea."},
			{Role: "assistant", Content: "Noted, you like oolong tea."},
		},
	})
	if err != nil {
		t.Fatalf("OnAfterChat round 1 error = %v", err)
	}
	// Round 2: store a location fact.
	err = provider.OnAfterChat(context.Background(), adapters.AfterChatRequest{
		BotID: "bot-1",
		Messages: []adapters.Message{
			{Role: "user", Content: "I am based in Berlin."},
			{Role: "assistant", Content: "Understood, you are based in Berlin."},
		},
	})
	if err != nil {
		t.Fatalf("OnAfterChat round 2 error = %v", err)
	}

	// Each query must recall the corresponding fact.
	before, err := provider.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
		BotID: "bot-1",
		Query: "berlin",
	})
	if err != nil {
		t.Fatalf("OnBeforeChat() error = %v", err)
	}
	if before == nil || !strings.Contains(strings.ToLower(before.ContextText), "berlin") {
		t.Fatalf("expected recalled context to mention berlin, got %#v", before)
	}

	before, err = provider.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
		BotID: "bot-1",
		Query: "oolong tea",
	})
	if err != nil {
		t.Fatalf("OnBeforeChat() tea error = %v", err)
	}
	if before == nil || !strings.Contains(strings.ToLower(before.ContextText), "oolong tea") {
		t.Fatalf("expected recalled context to mention oolong tea, got %#v", before)
	}
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TruncateSnippet truncates a string to n runes, appending "..." if truncated.
|
||||
func TruncateSnippet(s string, n int) string {
|
||||
trimmed := strings.TrimSpace(s)
|
||||
runes := []rune(trimmed)
|
||||
if len(runes) <= n {
|
||||
return trimmed
|
||||
}
|
||||
return strings.TrimSpace(string(runes[:n])) + "..."
|
||||
}
|
||||
|
||||
// DeduplicateItems removes duplicate MemoryItems by ID.
|
||||
func DeduplicateItems(items []MemoryItem) []MemoryItem {
|
||||
if len(items) == 0 {
|
||||
return items
|
||||
}
|
||||
seen := make(map[string]struct{}, len(items))
|
||||
result := make([]MemoryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
id := strings.TrimSpace(item.ID)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(item.Memory)
|
||||
}
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
result = append(result, item)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// StringFromConfig extracts a trimmed string value from a config map.
|
||||
func StringFromConfig(config map[string]any, key string) string {
|
||||
if config == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := config[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
// MergeMetadata combines two metadata maps into a new map. Keys from
// extra override keys from base; nil is returned when both inputs are
// empty so callers can omit the field entirely.
func MergeMetadata(base map[string]any, extra map[string]any) map[string]any {
	if len(base)+len(extra) == 0 {
		return nil
	}
	merged := make(map[string]any, len(base)+len(extra))
	for _, source := range []map[string]any{base, extra} {
		for key, value := range source {
			merged[key] = value
		}
	}
	return merged
}
|
||||
|
||||
// BuildProfileMetadata assembles profile-related metadata for a memory
// entry. All inputs are whitespace-trimmed; nil is returned when every
// input is empty. A user ID takes precedence over a channel identity
// when deriving the profile_ref key.
func BuildProfileMetadata(userID, channelIdentityID, displayName string) map[string]any {
	userID = strings.TrimSpace(userID)
	channelIdentityID = strings.TrimSpace(channelIdentityID)
	displayName = strings.TrimSpace(displayName)
	if userID == "" && channelIdentityID == "" && displayName == "" {
		return nil
	}
	meta := make(map[string]any)
	switch {
	case userID != "":
		meta["profile_user_id"] = userID
		meta["profile_ref"] = fmt.Sprintf("user:%s", userID)
	case channelIdentityID != "":
		meta["profile_ref"] = fmt.Sprintf("channel_identity:%s", channelIdentityID)
	}
	if channelIdentityID != "" {
		meta["profile_channel_identity_id"] = channelIdentityID
	}
	if displayName != "" {
		meta["profile_display_name"] = displayName
	}
	return meta
}
|
||||
@@ -0,0 +1,418 @@
|
||||
package mem0
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
const (
	// mem0DefaultBaseURL is the hosted mem0 SaaS endpoint used when no
	// base_url is configured.
	mem0DefaultBaseURL = "https://api.mem0.ai"
	// mem0OutputFormatV11 selects the v1.1 response shape.
	mem0OutputFormatV11 = "v1.1"
	// mem0VersionV2 selects v2 memory API semantics.
	mem0VersionV2 = "v2"
	// mem0BatchDeleteMaxSize caps the number of IDs accepted per batch
	// delete request.
	mem0BatchDeleteMaxSize = 1000
	// mem0ListPageSize is the page size used when listing all memories.
	mem0ListPageSize = 1000
)

// mem0Client is a thin HTTP client for the mem0 SaaS API. The optional
// org/project IDs, when set, are attached to every request.
type mem0Client struct {
	baseURL    string
	apiKey     string
	orgID      string
	projectID  string
	httpClient *http.Client
}
|
||||
|
||||
// newMem0Client builds a mem0 SaaS API client from provider config.
// base_url defaults to the hosted endpoint (trailing slashes are
// stripped); api_key is mandatory. org_id and project_id are optional
// and forwarded on every request. The HTTP client carries a 30s
// timeout so calls cannot hang indefinitely.
func newMem0Client(config map[string]any) (*mem0Client, error) {
	baseURL := adapters.StringFromConfig(config, "base_url")
	if baseURL == "" {
		baseURL = mem0DefaultBaseURL
	}
	apiKey := adapters.StringFromConfig(config, "api_key")
	if apiKey == "" {
		return nil, errors.New("mem0: api_key is required for SaaS")
	}
	baseURL = strings.TrimRight(baseURL, "/")
	return &mem0Client{
		baseURL:   baseURL,
		apiKey:    apiKey,
		orgID:     adapters.StringFromConfig(config, "org_id"),
		projectID: adapters.StringFromConfig(config, "project_id"),
		httpClient: &http.Client{
			Timeout: 30 * time.Second,
		},
	}, nil
}
|
||||
|
||||
// mem0AddRequest is the payload for POST /v1/memories/.
type mem0AddRequest struct {
	Messages     []adapters.Message `json:"messages,omitempty"`
	UserID       string             `json:"user_id,omitempty"`
	AgentID      string             `json:"agent_id,omitempty"`
	RunID        string             `json:"run_id,omitempty"`
	Metadata     map[string]any     `json:"metadata,omitempty"`
	Infer        *bool              `json:"infer,omitempty"`
	AsyncMode    *bool              `json:"async_mode,omitempty"`
	OutputFormat string             `json:"output_format,omitempty"`
	Version      string             `json:"version,omitempty"`
	OrgID        string             `json:"org_id,omitempty"`
	ProjectID    string             `json:"project_id,omitempty"`
}

// mem0AddEvent is one change event returned by the add endpoint.
type mem0AddEvent struct {
	ID    string `json:"id"`
	Event string `json:"event,omitempty"`
	Data  struct {
		Memory string `json:"memory,omitempty"`
		Text   string `json:"text,omitempty"`
	} `json:"data"`
}

// mem0Memory is a memory record as returned by the mem0 API.
type mem0Memory struct {
	ID        string         `json:"id"`
	Memory    string         `json:"memory"`
	Text      string         `json:"text,omitempty"`
	Hash      string         `json:"hash,omitempty"`
	CreatedAt string         `json:"created_at,omitempty"`
	UpdatedAt string         `json:"updated_at,omitempty"`
	Score     float64        `json:"score,omitempty"`
	UserID    string         `json:"user_id,omitempty"`
	AgentID   string         `json:"agent_id,omitempty"`
	RunID     string         `json:"run_id,omitempty"`
	Metadata  map[string]any `json:"metadata,omitempty"`
}

// mem0SearchRequest is the payload for POST /v2/memories/search/.
type mem0SearchRequest struct {
	Query     string         `json:"query"`
	Version   string         `json:"version,omitempty"`
	Filters   map[string]any `json:"filters,omitempty"`
	TopK      int            `json:"top_k,omitempty"`
	OrgID     string         `json:"org_id,omitempty"`
	ProjectID string         `json:"project_id,omitempty"`
}

// mem0GetAllRequest is the payload for POST /v2/memories/ (paged list).
type mem0GetAllRequest struct {
	Filters      map[string]any `json:"filters"`
	Page         int            `json:"page,omitempty"`
	PageSize     int            `json:"page_size,omitempty"`
	OutputFormat string         `json:"output_format,omitempty"`
	OrgID        string         `json:"org_id,omitempty"`
	ProjectID    string         `json:"project_id,omitempty"`
}

// mem0UpdateRequest is the payload for PUT /v1/memories/{id}/.
type mem0UpdateRequest struct {
	Text     string         `json:"text"`
	Metadata map[string]any `json:"metadata,omitempty"`
}

// mem0MemoriesResponse covers the list/search response variants: some
// endpoints populate Results, others Memories.
type mem0MemoriesResponse struct {
	Results   []mem0Memory `json:"results"`
	Memories  []mem0Memory `json:"memories"`
	Total     int          `json:"total,omitempty"`
	Relations []any        `json:"relations,omitempty"`
}

// mem0AddEventsResponse wraps the add endpoint's event list.
type mem0AddEventsResponse struct {
	Results []mem0AddEvent `json:"results"`
}
|
||||
|
||||
// Add creates memories via POST /v1/memories/. It defaults the output
// format to v1.1 and the API version to v2, overwrites the request's
// org/project scope with the client's, and parses the returned add
// events into memory records.
func (c *mem0Client) Add(ctx context.Context, req mem0AddRequest) ([]mem0Memory, error) {
	if req.OutputFormat == "" {
		req.OutputFormat = mem0OutputFormatV11
	}
	if req.Version == "" {
		req.Version = mem0VersionV2
	}
	req.OrgID = c.orgID
	req.ProjectID = c.projectID
	body, err := c.doJSONRaw(ctx, http.MethodPost, "/v1/memories/", req)
	if err != nil {
		return nil, fmt.Errorf("mem0 add: %w", err)
	}
	memories, err := parseMem0AddMemories(body)
	if err != nil {
		return nil, fmt.Errorf("mem0 add: %w", err)
	}
	return memories, nil
}
|
||||
|
||||
// Search queries memories via POST /v2/memories/search/, defaulting
// the API version to v2 and attaching the client's org/project scope.
func (c *mem0Client) Search(ctx context.Context, req mem0SearchRequest) ([]mem0Memory, error) {
	if req.Version == "" {
		req.Version = mem0VersionV2
	}
	req.OrgID = c.orgID
	req.ProjectID = c.projectID
	body, err := c.doJSONRaw(ctx, http.MethodPost, "/v2/memories/search/", req)
	if err != nil {
		return nil, fmt.Errorf("mem0 search: %w", err)
	}
	results, err := parseMem0Memories(body)
	if err != nil {
		return nil, fmt.Errorf("mem0 search: %w", err)
	}
	return results, nil
}
|
||||
|
||||
// GetAll lists memories via POST /v2/memories/ using the request's
// filters and paging, defaulting the output format to v1.1 and
// attaching the client's org/project scope.
func (c *mem0Client) GetAll(ctx context.Context, req mem0GetAllRequest) ([]mem0Memory, error) {
	req.OrgID = c.orgID
	req.ProjectID = c.projectID
	if req.OutputFormat == "" {
		req.OutputFormat = mem0OutputFormatV11
	}
	body, err := c.doJSONRaw(ctx, http.MethodPost, "/v2/memories/", req)
	if err != nil {
		return nil, fmt.Errorf("mem0 get all: %w", err)
	}
	results, err := parseMem0Memories(body)
	if err != nil {
		return nil, fmt.Errorf("mem0 get all: %w", err)
	}
	return results, nil
}
|
||||
|
||||
// ListAllByAgent pages through every memory belonging to the agent,
// deduplicating by memory ID. Pagination stops at the first page that
// comes back shorter than the page size. The agent ID is required.
func (c *mem0Client) ListAllByAgent(ctx context.Context, agentID string) ([]mem0Memory, error) {
	agentID = strings.TrimSpace(agentID)
	if agentID == "" {
		return nil, errors.New("agent_id is required")
	}
	all := make([]mem0Memory, 0)
	seen := map[string]struct{}{}
	for page := 1; ; page++ {
		results, err := c.GetAll(ctx, mem0GetAllRequest{
			Filters: map[string]any{
				"agent_id": agentID,
			},
			Page:     page,
			PageSize: mem0ListPageSize,
		})
		if err != nil {
			return nil, err
		}
		for _, item := range results {
			id := strings.TrimSpace(item.ID)
			if id == "" {
				continue
			}
			// Skip duplicates in case pages overlap between requests.
			if _, ok := seen[id]; ok {
				continue
			}
			seen[id] = struct{}{}
			all = append(all, item)
		}
		// A short page means we have reached the end.
		if len(results) < mem0ListPageSize {
			break
		}
	}
	return all, nil
}
|
||||
|
||||
// Update replaces a memory's text and metadata via
// PUT /v1/memories/{id}/ and returns the normalized result.
func (c *mem0Client) Update(ctx context.Context, memoryID string, text string, metadata map[string]any) (*mem0Memory, error) {
	var result mem0Memory
	if err := c.doJSON(ctx, http.MethodPut, "/v1/memories/"+memoryID+"/", mem0UpdateRequest{
		Text:     text,
		Metadata: metadata,
	}, &result); err != nil {
		return nil, fmt.Errorf("mem0 update: %w", err)
	}
	result = normalizeMem0Memory(result)
	return &result, nil
}
|
||||
|
||||
func (c *mem0Client) Delete(ctx context.Context, memoryID string) error {
|
||||
if err := c.doJSON(ctx, http.MethodDelete, "/v1/memories/"+memoryID+"/", nil, nil); err != nil {
|
||||
return fmt.Errorf("mem0 delete: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mem0Client) DeleteAll(ctx context.Context, agentID string) error {
|
||||
q := url.Values{}
|
||||
q.Set("agent_id", agentID)
|
||||
if c.orgID != "" {
|
||||
q.Set("org_id", c.orgID)
|
||||
}
|
||||
if c.projectID != "" {
|
||||
q.Set("project_id", c.projectID)
|
||||
}
|
||||
u := "/v1/memories/?" + q.Encode()
|
||||
if err := c.doJSON(ctx, http.MethodDelete, u, nil, nil); err != nil {
|
||||
return fmt.Errorf("mem0 delete all: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mem0Client) BatchDelete(ctx context.Context, memoryIDs []string) error {
|
||||
if len(memoryIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(memoryIDs) > mem0BatchDeleteMaxSize {
|
||||
return fmt.Errorf("mem0 batch delete: maximum %d memories allowed", mem0BatchDeleteMaxSize)
|
||||
}
|
||||
memories := make([]map[string]string, 0, len(memoryIDs))
|
||||
for _, id := range memoryIDs {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
memories = append(memories, map[string]string{"memory_id": id})
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0, len(memories))
|
||||
for _, memory := range memories {
|
||||
ids = append(ids, memory["memory_id"])
|
||||
}
|
||||
if err := c.doJSON(ctx, http.MethodDelete, "/v1/batch/", map[string]any{
|
||||
"memory_ids": ids,
|
||||
"memories": memories,
|
||||
}, nil); err != nil {
|
||||
return fmt.Errorf("mem0 batch delete: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mem0Client) doJSON(ctx context.Context, method, urlPath string, body any, result any) error {
|
||||
respBody, err := c.doJSONRaw(ctx, method, urlPath, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result != nil && len(respBody) > 0 {
|
||||
if err := json.Unmarshal(respBody, result); err != nil {
|
||||
return fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// doJSONRaw performs an authenticated JSON request against the mem0 API and
// returns the raw response body. Non-2xx responses are turned into errors
// carrying a truncated copy of the body.
func (c *mem0Client) doJSONRaw(ctx context.Context, method, urlPath string, body any) ([]byte, error) {
	var bodyReader io.Reader
	if body != nil {
		data, err := json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("marshal request: %w", err)
		}
		bodyReader = bytes.NewReader(data)
	}
	req, err := http.NewRequestWithContext(ctx, method, c.baseURL+urlPath, bodyReader)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	// mem0 uses "Token <key>" (not "Bearer") as its auth scheme.
	req.Header.Set("Authorization", "Token "+c.apiKey)
	// Optional org/project scoping headers; some request bodies also carry
	// these IDs, the headers are set regardless.
	if c.orgID != "" {
		req.Header.Set("X-Org-Id", c.orgID)
	}
	if c.projectID != "" {
		req.Header.Set("X-Project-Id", c.projectID)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // URL is from admin-configured base_url
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	// Read the full body before checking status so error responses can
	// include the server's message.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("mem0 API error %d: %s", resp.StatusCode, truncateBody(respBody))
	}
	return respBody, nil
}
|
||||
|
||||
// parseMem0AddMemories decodes the response of a mem0 add call. The API may
// return plain memory objects or ADD/UPDATE event envelopes, so the memory
// shapes are tried first and the event shapes are used as fallbacks.
func parseMem0AddMemories(body []byte) ([]mem0Memory, error) {
	// The outer err is deliberately kept alive: it is consulted again at the
	// end if none of the event-shaped fallbacks matched.
	memories, err := parseMem0Memories(body)
	if err == nil && hasConcreteMem0Memories(memories) {
		return memories, nil
	}

	// Fallback 1: {"results": [{id, event, data: {memory|text}}, ...]}.
	var envelope mem0AddEventsResponse
	if err := json.Unmarshal(body, &envelope); err == nil && len(envelope.Results) > 0 {
		return mem0EventsToMemories(envelope.Results), nil
	}

	// Fallback 2: a bare event array. Note the err here shadows the outer err.
	var events []mem0AddEvent
	if err := json.Unmarshal(body, &events); err == nil && len(events) > 0 {
		return mem0EventsToMemories(events), nil
	}

	// No event shape matched: return whatever parseMem0Memories produced
	// (possibly an empty list), or its parse error.
	if err == nil {
		return memories, nil
	}
	return nil, err
}
|
||||
|
||||
func parseMem0Memories(body []byte) ([]mem0Memory, error) {
|
||||
var list []mem0Memory
|
||||
if err := json.Unmarshal(body, &list); err == nil {
|
||||
return normalizeMem0Memories(list), nil
|
||||
}
|
||||
|
||||
var envelope mem0MemoriesResponse
|
||||
if err := json.Unmarshal(body, &envelope); err == nil {
|
||||
switch {
|
||||
case len(envelope.Results) > 0:
|
||||
return normalizeMem0Memories(envelope.Results), nil
|
||||
case len(envelope.Memories) > 0:
|
||||
return normalizeMem0Memories(envelope.Memories), nil
|
||||
default:
|
||||
return []mem0Memory{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported mem0 response shape")
|
||||
}
|
||||
|
||||
func mem0EventsToMemories(events []mem0AddEvent) []mem0Memory {
|
||||
memories := make([]mem0Memory, 0, len(events))
|
||||
for _, event := range events {
|
||||
memory := strings.TrimSpace(event.Data.Memory)
|
||||
if memory == "" {
|
||||
memory = strings.TrimSpace(event.Data.Text)
|
||||
}
|
||||
memories = append(memories, mem0Memory{
|
||||
ID: strings.TrimSpace(event.ID),
|
||||
Memory: memory,
|
||||
})
|
||||
}
|
||||
return memories
|
||||
}
|
||||
|
||||
func normalizeMem0Memories(memories []mem0Memory) []mem0Memory {
|
||||
for i := range memories {
|
||||
memories[i] = normalizeMem0Memory(memories[i])
|
||||
}
|
||||
return memories
|
||||
}
|
||||
|
||||
func normalizeMem0Memory(memory mem0Memory) mem0Memory {
|
||||
if strings.TrimSpace(memory.Memory) == "" {
|
||||
memory.Memory = strings.TrimSpace(memory.Text)
|
||||
}
|
||||
return memory
|
||||
}
|
||||
|
||||
func hasConcreteMem0Memories(memories []mem0Memory) bool {
|
||||
for _, memory := range memories {
|
||||
if strings.TrimSpace(memory.Memory) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// truncateBody renders a response body for error messages, capping it at 300
// bytes with a "..." suffix.
func truncateBody(b []byte) string {
	const max = 300
	if len(b) <= max {
		return string(b)
	}
	return string(b[:max]) + "..."
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package mem0
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNewMem0ClientDefaultsToSaaS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, err := newMem0Client(map[string]any{
|
||||
"api_key": "test-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("newMem0Client() error = %v", err)
|
||||
}
|
||||
if client.baseURL != mem0DefaultBaseURL {
|
||||
t.Fatalf("baseURL = %q, want %q", client.baseURL, mem0DefaultBaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMem0AddMemoriesSupportsEventResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`[
|
||||
{
|
||||
"id": "mem_123",
|
||||
"event": "ADD",
|
||||
"data": {
|
||||
"memory": "The user likes oolong tea."
|
||||
}
|
||||
}
|
||||
]`)
|
||||
|
||||
memories, err := parseMem0AddMemories(body)
|
||||
if err != nil {
|
||||
t.Fatalf("parseMem0AddMemories() error = %v", err)
|
||||
}
|
||||
if len(memories) != 1 {
|
||||
t.Fatalf("len(memories) = %d, want 1", len(memories))
|
||||
}
|
||||
if memories[0].ID != "mem_123" {
|
||||
t.Fatalf("memory id = %q, want %q", memories[0].ID, "mem_123")
|
||||
}
|
||||
if memories[0].Memory != "The user likes oolong tea." {
|
||||
t.Fatalf("memory text = %q", memories[0].Memory)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMem0MemoriesSupportsEnvelopeResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{
|
||||
"results": [
|
||||
{
|
||||
"id": "mem_456",
|
||||
"memory": "The user lives in Shanghai.",
|
||||
"score": 0.92,
|
||||
"agent_id": "bot-1",
|
||||
"run_id": "run-1",
|
||||
"created_at": "2026-01-01T00:00:00Z",
|
||||
"updated_at": "2026-01-02T00:00:00Z"
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
}`)
|
||||
|
||||
memories, err := parseMem0Memories(body)
|
||||
if err != nil {
|
||||
t.Fatalf("parseMem0Memories() error = %v", err)
|
||||
}
|
||||
if len(memories) != 1 {
|
||||
t.Fatalf("len(memories) = %d, want 1", len(memories))
|
||||
}
|
||||
if memories[0].Score != 0.92 {
|
||||
t.Fatalf("score = %v, want 0.92", memories[0].Score)
|
||||
}
|
||||
if memories[0].AgentID != "bot-1" {
|
||||
t.Fatalf("agent_id = %q, want %q", memories[0].AgentID, "bot-1")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,546 @@
|
||||
package mem0
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
)
|
||||
|
||||
const (
	// Mem0Type is the registry identifier for this provider.
	Mem0Type = "mem0"

	// MCP tool name and result-shaping limits for memory search.
	mem0ToolSearchMemory = "search_memory"
	mem0DefaultLimit     = 8
	mem0MaxLimit         = 50
	// Caps applied when injecting memories into chat context.
	mem0ContextMaxItems = 8
	mem0ContextMaxChars = 220

	// Metadata keys written onto remote memories during filesystem sync so
	// that Rebuild can match remote records back to local source entries.
	mem0SyncMetadataKeySourceEntryID = "source_entry_id"
	mem0SyncMetadataKeySourceHash    = "source_hash"
	mem0SyncMetadataKeySourceBotID   = "source_bot_id"
	mem0SyncMetadataKeySourceManaged = "source_managed"
)
|
||||
|
||||
// Mem0Provider implements adapters.Provider by delegating to the Mem0 SaaS API.
type Mem0Provider struct {
	// client talks to the mem0 HTTP API.
	client *mem0Client
	// logger is scoped with the provider name at construction.
	logger *slog.Logger
	// store is the optional local memory filesystem; nil disables manual sync.
	store *storefs.Service
}
|
||||
|
||||
func NewMem0Provider(log *slog.Logger, config map[string]any, store *storefs.Service) (*Mem0Provider, error) {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
c, err := newMem0Client(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Mem0Provider{
|
||||
client: c,
|
||||
logger: log.With(slog.String("provider", Mem0Type)),
|
||||
store: store,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Type reports the provider type identifier used for registry wiring.
func (*Mem0Provider) Type() string { return Mem0Type }
|
||||
|
||||
// --- Conversation Hooks ---
|
||||
|
||||
func (p *Mem0Provider) OnBeforeChat(ctx context.Context, req adapters.BeforeChatRequest) (*adapters.BeforeChatResult, error) {
|
||||
query := strings.TrimSpace(req.Query)
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if query == "" || botID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
memories, err := p.searchMemories(ctx, query, botID, mem0ContextMaxItems)
|
||||
if err != nil {
|
||||
p.logger.Warn("mem0 search for context failed", slog.Any("error", err))
|
||||
return nil, nil
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<memory-context>\nRelevant memory context (use when helpful):\n")
|
||||
for i, mem := range memories {
|
||||
if i >= mem0ContextMaxItems {
|
||||
break
|
||||
}
|
||||
text := strings.TrimSpace(mem.Memory)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(adapters.TruncateSnippet(text, mem0ContextMaxChars))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</memory-context>")
|
||||
return &adapters.BeforeChatResult{ContextText: sb.String()}, nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) OnAfterChat(ctx context.Context, req adapters.AfterChatRequest) error {
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if botID == "" || len(req.Messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := p.client.Add(ctx, mem0AddRequest{
|
||||
Messages: req.Messages,
|
||||
AgentID: botID,
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Warn("mem0 store memory failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- MCP Tools ---
|
||||
|
||||
// ListTools advertises the provider's MCP tools: a single memory-search tool
// taking a required query and an optional result limit.
func (*Mem0Provider) ListTools(_ context.Context, _ mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error) {
	return []mcp.ToolDescriptor{
		{
			Name:        mem0ToolSearchMemory,
			Description: "Search for memories relevant to the current chat",
			// JSON Schema describing the tool arguments.
			InputSchema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{
						"type":        "string",
						"description": "The query to search memories",
					},
					"limit": map[string]any{
						"type":        "integer",
						"description": "Maximum number of memory results",
					},
				},
				"required": []string{"query"},
			},
		},
	}, nil
}
|
||||
|
||||
func (p *Mem0Provider) CallTool(ctx context.Context, session mcp.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) {
|
||||
if toolName != mem0ToolSearchMemory {
|
||||
return nil, mcp.ErrToolNotFound
|
||||
}
|
||||
query := mcp.StringArg(arguments, "query")
|
||||
if query == "" {
|
||||
return mcp.BuildToolErrorResult("query is required"), nil
|
||||
}
|
||||
botID := strings.TrimSpace(session.BotID)
|
||||
if botID == "" {
|
||||
return mcp.BuildToolErrorResult("bot_id is required"), nil
|
||||
}
|
||||
limit := mem0DefaultLimit
|
||||
if value, ok, err := mcp.IntArg(arguments, "limit"); err != nil {
|
||||
return mcp.BuildToolErrorResult(err.Error()), nil
|
||||
} else if ok {
|
||||
limit = value
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = mem0DefaultLimit
|
||||
}
|
||||
if limit > mem0MaxLimit {
|
||||
limit = mem0MaxLimit
|
||||
}
|
||||
|
||||
memories, err := p.searchMemories(ctx, query, botID, limit)
|
||||
if err != nil {
|
||||
return mcp.BuildToolErrorResult("memory search failed"), nil
|
||||
}
|
||||
|
||||
results := make([]map[string]any, 0, len(memories))
|
||||
for _, mem := range memories {
|
||||
results = append(results, map[string]any{
|
||||
"id": mem.ID,
|
||||
"memory": mem.Memory,
|
||||
"score": mem.Score,
|
||||
})
|
||||
}
|
||||
return mcp.BuildToolSuccessResult(map[string]any{
|
||||
"query": query,
|
||||
"total": len(results),
|
||||
"results": results,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
func (p *Mem0Provider) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
|
||||
agentID := mem0ScopeID(req.BotID, req.AgentID)
|
||||
if agentID == "" {
|
||||
return adapters.SearchResponse{}, errors.New("bot_id or agent_id is required")
|
||||
}
|
||||
addReq := mem0AddRequest{
|
||||
AgentID: agentID,
|
||||
RunID: req.RunID,
|
||||
Infer: req.Infer,
|
||||
}
|
||||
if req.Message != "" {
|
||||
addReq.Messages = []adapters.Message{{Role: "user", Content: req.Message}}
|
||||
} else {
|
||||
addReq.Messages = req.Messages
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
addReq.Metadata = req.Metadata
|
||||
}
|
||||
memories, err := p.client.Add(ctx, addReq)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
return adapters.SearchResponse{Results: mem0ToItems(memories)}, nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
|
||||
agentID := mem0ScopeID(req.BotID, req.AgentID)
|
||||
if agentID == "" {
|
||||
return adapters.SearchResponse{}, errors.New("bot_id or agent_id is required")
|
||||
}
|
||||
limit := req.Limit
|
||||
if limit <= 0 {
|
||||
limit = mem0DefaultLimit
|
||||
} else if limit > mem0MaxLimit {
|
||||
limit = mem0MaxLimit
|
||||
}
|
||||
memories, err := p.client.Search(ctx, mem0SearchRequest{
|
||||
Query: req.Query,
|
||||
TopK: limit,
|
||||
Filters: mem0AgentFilter(agentID, req.RunID),
|
||||
})
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
items := mem0ToItems(memories)
|
||||
sort.Slice(items, func(i, j int) bool { return items[i].Score > items[j].Score })
|
||||
return adapters.SearchResponse{Results: adapters.DeduplicateItems(items)}, nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
|
||||
agentID := mem0ScopeID(req.BotID, req.AgentID)
|
||||
if agentID == "" {
|
||||
return adapters.SearchResponse{}, errors.New("bot_id or agent_id is required")
|
||||
}
|
||||
getReq := mem0GetAllRequest{
|
||||
Filters: mem0AgentFilter(agentID, req.RunID),
|
||||
}
|
||||
if req.Limit > 0 {
|
||||
getReq.PageSize = req.Limit
|
||||
}
|
||||
memories, err := p.client.GetAll(ctx, getReq)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
items := mem0ToItems(memories)
|
||||
sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt > items[j].UpdatedAt })
|
||||
return adapters.SearchResponse{Results: items}, nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
|
||||
memoryID := strings.TrimSpace(req.MemoryID)
|
||||
if memoryID == "" {
|
||||
return adapters.MemoryItem{}, errors.New("memory_id is required")
|
||||
}
|
||||
mem, err := p.client.Update(ctx, memoryID, req.Memory, nil)
|
||||
if err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
}
|
||||
return mem0ToItem(*mem), nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
|
||||
if err := p.client.Delete(ctx, strings.TrimSpace(memoryID)); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "Memory deleted successfully"}, nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
|
||||
if err := p.client.BatchDelete(ctx, memoryIDs); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "Memories deleted successfully"}, nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
|
||||
agentID := mem0ScopeID(req.BotID, req.AgentID)
|
||||
if agentID == "" {
|
||||
return adapters.DeleteResponse{}, errors.New("bot_id or agent_id is required")
|
||||
}
|
||||
if err := p.client.DeleteAll(ctx, agentID); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "All memories deleted"}, nil
|
||||
}
|
||||
|
||||
// --- Lifecycle ---
|
||||
|
||||
// Compact is unsupported: compaction happens inside the mem0 service itself.
func (*Mem0Provider) Compact(_ context.Context, _ map[string]any, _ float64, _ int) (adapters.CompactResult, error) {
	return adapters.CompactResult{}, errors.New("compact is not supported by mem0 provider")
}

// Usage is unsupported: the mem0 API does not expose usage accounting here.
func (*Mem0Provider) Usage(_ context.Context, _ map[string]any) (adapters.UsageResponse, error) {
	return adapters.UsageResponse{}, errors.New("usage is not supported by mem0 provider")
}
|
||||
|
||||
func (p *Mem0Provider) Status(ctx context.Context, botID string) (adapters.MemoryStatusResponse, error) {
|
||||
status := adapters.MemoryStatusResponse{
|
||||
ProviderType: Mem0Type,
|
||||
CanManualSync: p.store != nil,
|
||||
SourceDir: path.Join(config.DefaultDataMount, "memory"),
|
||||
OverviewPath: path.Join(config.DefaultDataMount, "MEMORY.md"),
|
||||
}
|
||||
if p.store == nil {
|
||||
return status, nil
|
||||
}
|
||||
fileCount, err := p.store.CountMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.MemoryStatusResponse{}, err
|
||||
}
|
||||
items, err := p.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.MemoryStatusResponse{}, err
|
||||
}
|
||||
status.MarkdownFileCount = fileCount
|
||||
status.SourceCount = len(mem0CanonicalStoreItems(items))
|
||||
remote, err := p.client.ListAllByAgent(ctx, strings.TrimSpace(botID))
|
||||
if err != nil {
|
||||
return adapters.MemoryStatusResponse{}, err
|
||||
}
|
||||
status.IndexedCount = len(remote)
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (p *Mem0Provider) Rebuild(ctx context.Context, botID string) (adapters.RebuildResult, error) {
|
||||
if p.store == nil {
|
||||
return adapters.RebuildResult{}, errors.New("memory filesystem not configured")
|
||||
}
|
||||
items, err := p.store.ReadAllMemoryFiles(ctx, botID)
|
||||
if err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
if err := p.store.SyncOverview(ctx, botID); err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
return p.syncSourceItems(ctx, strings.TrimSpace(botID), items)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func mem0ToItems(memories []mem0Memory) []adapters.MemoryItem {
|
||||
items := make([]adapters.MemoryItem, 0, len(memories))
|
||||
for _, m := range memories {
|
||||
items = append(items, mem0ToItem(m))
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// mem0ToItem maps a raw mem0 API memory onto the adapter-level item type.
func mem0ToItem(m mem0Memory) adapters.MemoryItem {
	return adapters.MemoryItem{
		ID:        m.ID,
		Memory:    m.Memory,
		Hash:      m.Hash,
		CreatedAt: m.CreatedAt,
		UpdatedAt: m.UpdatedAt,
		Score:     m.Score,
		Metadata:  m.Metadata,
		// This provider scopes memories by mem0 agent_id only, so both the
		// bot ID and agent ID mirror the same remote field (see mem0ScopeID).
		BotID:   m.AgentID,
		AgentID: m.AgentID,
		RunID:   m.RunID,
	}
}
|
||||
|
||||
// syncSourceItems reconciles canonical local memory entries into mem0 for one
// bot. It runs three phases against the remote set:
//  1. upsert: add any local entry missing remotely, and update remote entries
//     whose text/hash/bot metadata drifted from the local source;
//  2. prune: batch-delete remote entries tagged with a source_entry_id that
//     no longer exists locally (untagged/manually created memories are left
//     alone);
//  3. recount: re-list the remote scope to report final numbers.
// RestoredCount counts adds plus updates; MissingCount counts adds only.
func (p *Mem0Provider) syncSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) (adapters.RebuildResult, error) {
	canonical := mem0CanonicalStoreItems(items)
	existing, err := p.client.ListAllByAgent(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	// Index remote memories by their recorded local source entry ID.
	existingBySource := make(map[string]mem0Memory, len(existing))
	for _, memory := range existing {
		sourceID := mem0SourceEntryID(memory.Metadata)
		if sourceID == "" {
			continue
		}
		existingBySource[sourceID] = memory
	}
	sourceIDs := make(map[string]struct{}, len(canonical))
	missingCount := 0
	restoredCount := 0
	for _, item := range canonical {
		sourceIDs[item.ID] = struct{}{}
		metadata := mem0SourceMetadata(botID, item)
		existingMemory, ok := existingBySource[item.ID]
		if !ok {
			// Not present remotely: restore it verbatim. Infer/AsyncMode are
			// disabled so the text is stored as-is, synchronously.
			missingCount++
			restoredCount++
			if _, err := p.client.Add(ctx, mem0AddRequest{
				Messages:  []adapters.Message{{Role: "system", Content: item.Memory}},
				AgentID:   botID,
				RunID:     item.RunID,
				Metadata:  metadata,
				Infer:     mem0BoolPtr(false),
				AsyncMode: mem0BoolPtr(false),
			}); err != nil {
				return adapters.RebuildResult{}, err
			}
			continue
		}
		if mem0SourceMemoryMatches(existingMemory, item, botID) {
			// Remote copy is already in sync; nothing to do.
			continue
		}
		// Remote copy drifted: push the local text and refreshed metadata.
		restoredCount++
		if _, err := p.client.Update(ctx, existingMemory.ID, item.Memory, metadata); err != nil {
			return adapters.RebuildResult{}, err
		}
	}
	// Collect managed remote memories whose local source entry disappeared.
	staleIDs := make([]string, 0)
	for _, memory := range existing {
		sourceID := mem0SourceEntryID(memory.Metadata)
		if sourceID == "" {
			continue
		}
		if _, ok := sourceIDs[sourceID]; ok {
			continue
		}
		if strings.TrimSpace(memory.ID) != "" {
			staleIDs = append(staleIDs, memory.ID)
		}
	}
	// Delete stale memories in API-sized chunks.
	for len(staleIDs) > 0 {
		chunkSize := mem0BatchDeleteMaxSize
		if len(staleIDs) < chunkSize {
			chunkSize = len(staleIDs)
		}
		if err := p.client.BatchDelete(ctx, staleIDs[:chunkSize]); err != nil {
			return adapters.RebuildResult{}, err
		}
		staleIDs = staleIDs[chunkSize:]
	}
	// Re-list to report the authoritative post-sync remote count.
	remote, err := p.client.ListAllByAgent(ctx, botID)
	if err != nil {
		return adapters.RebuildResult{}, err
	}
	return adapters.RebuildResult{
		FsCount:       len(canonical),
		StorageCount:  len(remote),
		MissingCount:  missingCount,
		RestoredCount: restoredCount,
	}, nil
}
|
||||
|
||||
func (p *Mem0Provider) searchMemories(ctx context.Context, query, agentID string, limit int) ([]adapters.MemoryItem, error) {
|
||||
memories, err := p.client.Search(ctx, mem0SearchRequest{
|
||||
Query: query,
|
||||
TopK: limit,
|
||||
Filters: mem0AgentFilter(agentID, ""),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items := mem0ToItems(memories)
|
||||
sort.Slice(items, func(i, j int) bool { return items[i].Score > items[j].Score })
|
||||
return adapters.DeduplicateItems(items), nil
|
||||
}
|
||||
|
||||
// mem0ScopeID resolves the remote scope, preferring a non-blank bot ID and
// falling back to the agent ID. Both values are trimmed.
func mem0ScopeID(botID, agentID string) string {
	bot := strings.TrimSpace(botID)
	if bot != "" {
		return bot
	}
	return strings.TrimSpace(agentID)
}
|
||||
|
||||
// mem0AgentFilter builds the mem0 filter map: agent_id always, run_id only
// when non-blank. Values are trimmed.
func mem0AgentFilter(agentID, runID string) map[string]any {
	filter := map[string]any{"agent_id": strings.TrimSpace(agentID)}
	if run := strings.TrimSpace(runID); run != "" {
		filter["run_id"] = run
	}
	return filter
}
|
||||
|
||||
func mem0CanonicalStoreItems(items []storefs.MemoryItem) []storefs.MemoryItem {
|
||||
result := make([]storefs.MemoryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
item.ID = strings.TrimSpace(item.ID)
|
||||
item.Memory = strings.TrimSpace(item.Memory)
|
||||
if item.ID == "" || item.Memory == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, item)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func mem0SourceMetadata(botID string, item storefs.MemoryItem) map[string]any {
|
||||
metadata := make(map[string]any, len(item.Metadata)+4)
|
||||
for key, value := range item.Metadata {
|
||||
metadata[key] = value
|
||||
}
|
||||
metadata[mem0SyncMetadataKeySourceEntryID] = item.ID
|
||||
metadata[mem0SyncMetadataKeySourceHash] = strings.TrimSpace(item.Hash)
|
||||
metadata[mem0SyncMetadataKeySourceBotID] = strings.TrimSpace(botID)
|
||||
metadata[mem0SyncMetadataKeySourceManaged] = true
|
||||
return metadata
|
||||
}
|
||||
|
||||
// mem0SourceEntryID extracts the local source entry ID recorded in a remote
// memory's metadata; empty when the memory is not sync-managed.
func mem0SourceEntryID(metadata map[string]any) string {
	return strings.TrimSpace(mem0MetadataString(metadata, mem0SyncMetadataKeySourceEntryID))
}
|
||||
|
||||
func mem0SourceMemoryMatches(memory mem0Memory, item storefs.MemoryItem, botID string) bool {
|
||||
if strings.TrimSpace(memory.Memory) != strings.TrimSpace(item.Memory) {
|
||||
return false
|
||||
}
|
||||
metadata := memory.Metadata
|
||||
if mem0SourceEntryID(metadata) != strings.TrimSpace(item.ID) {
|
||||
return false
|
||||
}
|
||||
if mem0MetadataString(metadata, mem0SyncMetadataKeySourceHash) != strings.TrimSpace(item.Hash) {
|
||||
return false
|
||||
}
|
||||
if mem0MetadataString(metadata, mem0SyncMetadataKeySourceBotID) != strings.TrimSpace(botID) {
|
||||
return false
|
||||
}
|
||||
return mem0MetadataBool(metadata, mem0SyncMetadataKeySourceManaged)
|
||||
}
|
||||
|
||||
// mem0MetadataString reads a metadata value as a trimmed string. Non-string
// values are rendered with %v and stripped of surrounding double quotes;
// missing keys, nil values, and nil maps yield "".
func mem0MetadataString(metadata map[string]any, key string) string {
	raw, ok := metadata[key] // lookup on a nil map is safe and yields nil
	if !ok || raw == nil {
		return ""
	}
	text, isString := raw.(string)
	if !isString {
		text = strings.Trim(fmt.Sprintf("%v", raw), "\"")
	}
	return strings.TrimSpace(text)
}
|
||||
|
||||
// mem0MetadataBool reads a metadata value as a bool. Anything other than an
// actual true bool (missing key, nil map, non-bool, false) yields false.
func mem0MetadataBool(metadata map[string]any, key string) bool {
	if value, ok := metadata[key].(bool); ok {
		return value
	}
	return false
}
|
||||
|
||||
// mem0BoolPtr returns a pointer to a copy of v (for optional API fields).
func mem0BoolPtr(v bool) *bool {
	value := v
	return &value
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package openviking
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
// openVikingClient is a thin HTTP client for an OpenViking memory API.
type openVikingClient struct {
	// baseURL has any trailing slash stripped at construction.
	baseURL string
	// apiKey, when non-empty, is sent as a Bearer token.
	apiKey     string
	httpClient *http.Client
}
|
||||
|
||||
func newOpenVikingClient(config map[string]any) (*openVikingClient, error) {
|
||||
baseURL := adapters.StringFromConfig(config, "base_url")
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("openviking: base_url is required")
|
||||
}
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
return &openVikingClient{
|
||||
baseURL: baseURL,
|
||||
apiKey: adapters.StringFromConfig(config, "api_key"),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ovMemory is the wire representation of a memory record returned by the
// OpenViking API.
type ovMemory struct {
	ID        string         `json:"id"`
	Content   string         `json:"content"`
	CreatedAt string         `json:"created_at,omitempty"`
	UpdatedAt string         `json:"updated_at,omitempty"`
	Metadata  map[string]any `json:"metadata,omitempty"`
	// Score is populated on search responses.
	Score float64 `json:"score,omitempty"`
}

// ovAddRequest is the payload for creating a memory in an agent's scope.
type ovAddRequest struct {
	AgentID string `json:"agent_id"`
	Content string `json:"content"`
}

// ovSearchRequest is the payload for a scoped semantic search.
type ovSearchRequest struct {
	Query   string `json:"query"`
	AgentID string `json:"agent_id"`
	Limit   int    `json:"limit,omitempty"`
}

// ovUpdateRequest is the payload for rewriting a memory's content.
type ovUpdateRequest struct {
	Content string `json:"content"`
}
|
||||
|
||||
func (c *openVikingClient) Add(ctx context.Context, agentID, content string) (*ovMemory, error) {
|
||||
var result ovMemory
|
||||
if err := c.doJSON(ctx, http.MethodPost, "/memories", ovAddRequest{
|
||||
AgentID: agentID,
|
||||
Content: content,
|
||||
}, &result); err != nil {
|
||||
return nil, fmt.Errorf("openviking add: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *openVikingClient) Search(ctx context.Context, agentID, query string, limit int) ([]ovMemory, error) {
|
||||
var results []ovMemory
|
||||
if err := c.doJSON(ctx, http.MethodPost, "/memories/search", ovSearchRequest{
|
||||
Query: query,
|
||||
AgentID: agentID,
|
||||
Limit: limit,
|
||||
}, &results); err != nil {
|
||||
return nil, fmt.Errorf("openviking search: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *openVikingClient) GetAll(ctx context.Context, agentID string, limit int) ([]ovMemory, error) {
|
||||
u := "/memories?agent_id=" + url.QueryEscape(agentID)
|
||||
if limit > 0 {
|
||||
u += fmt.Sprintf("&limit=%d", limit)
|
||||
}
|
||||
var results []ovMemory
|
||||
if err := c.doJSON(ctx, http.MethodGet, u, nil, &results); err != nil {
|
||||
return nil, fmt.Errorf("openviking get all: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *openVikingClient) Update(ctx context.Context, memoryID, content string) (*ovMemory, error) {
|
||||
var result ovMemory
|
||||
if err := c.doJSON(ctx, http.MethodPut, "/memories/"+memoryID, ovUpdateRequest{Content: content}, &result); err != nil {
|
||||
return nil, fmt.Errorf("openviking update: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *openVikingClient) Delete(ctx context.Context, memoryID string) error {
|
||||
if err := c.doJSON(ctx, http.MethodDelete, "/memories/"+memoryID, nil, nil); err != nil {
|
||||
return fmt.Errorf("openviking delete: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *openVikingClient) DeleteAll(ctx context.Context, agentID string) error {
|
||||
u := "/memories?agent_id=" + url.QueryEscape(agentID)
|
||||
if err := c.doJSON(ctx, http.MethodDelete, u, nil, nil); err != nil {
|
||||
return fmt.Errorf("openviking delete all: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// doJSON performs a JSON request against the OpenViking API, optionally
// decoding the response body into result. Non-2xx responses become errors
// carrying a truncated copy of the body.
func (c *openVikingClient) doJSON(ctx context.Context, method, urlPath string, body any, result any) error {
	var bodyReader io.Reader
	if body != nil {
		data, err := json.Marshal(body)
		if err != nil {
			return fmt.Errorf("marshal request: %w", err)
		}
		bodyReader = bytes.NewReader(data)
	}
	req, err := http.NewRequestWithContext(ctx, method, c.baseURL+urlPath, bodyReader)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	// Auth is optional: self-hosted deployments may run without a key.
	if c.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(req) //nolint:gosec // URL is from admin-configured base_url
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()
	// Read the full body before the status check so error responses can
	// include the server's message.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("openviking API error %d: %s", resp.StatusCode, truncateBody(respBody))
	}
	if result != nil && len(respBody) > 0 {
		if err := json.Unmarshal(respBody, result); err != nil {
			return fmt.Errorf("unmarshal response: %w", err)
		}
	}
	return nil
}
|
||||
|
||||
func truncateBody(b []byte) string {
|
||||
if len(b) > 300 {
|
||||
return string(b[:300]) + "..."
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -0,0 +1,309 @@
|
||||
package openviking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
const (
|
||||
OpenVikingType = "openviking"
|
||||
|
||||
ovToolSearchMemory = "search_memory"
|
||||
ovDefaultLimit = 10
|
||||
ovMaxLimit = 50
|
||||
ovContextMaxItems = 8
|
||||
ovContextMaxChars = 220
|
||||
)
|
||||
|
||||
// OpenVikingProvider implements adapters.Provider by delegating to an OpenViking API (self-hosted or SaaS).
|
||||
type OpenVikingProvider struct {
|
||||
client *openVikingClient
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewOpenVikingProvider(log *slog.Logger, config map[string]any) (*OpenVikingProvider, error) {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
c, err := newOpenVikingClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OpenVikingProvider{
|
||||
client: c,
|
||||
logger: log.With(slog.String("provider", OpenVikingType)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*OpenVikingProvider) Type() string { return OpenVikingType }
|
||||
|
||||
// --- Conversation Hooks ---
|
||||
|
||||
func (p *OpenVikingProvider) OnBeforeChat(ctx context.Context, req adapters.BeforeChatRequest) (*adapters.BeforeChatResult, error) {
|
||||
query := strings.TrimSpace(req.Query)
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if query == "" || botID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
memories, err := p.client.Search(ctx, botID, query, ovContextMaxItems)
|
||||
if err != nil {
|
||||
p.logger.Warn("openviking search for context failed", slog.Any("error", err))
|
||||
return nil, nil
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<memory-context>\nRelevant memory context (use when helpful):\n")
|
||||
for i, mem := range memories {
|
||||
if i >= ovContextMaxItems {
|
||||
break
|
||||
}
|
||||
text := strings.TrimSpace(mem.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(adapters.TruncateSnippet(text, ovContextMaxChars))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</memory-context>")
|
||||
return &adapters.BeforeChatResult{ContextText: sb.String()}, nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) OnAfterChat(ctx context.Context, req adapters.AfterChatRequest) error {
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if botID == "" || len(req.Messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
var parts []string
|
||||
for _, msg := range req.Messages {
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
role := strings.ToUpper(strings.TrimSpace(msg.Role))
|
||||
if role == "" {
|
||||
role = "MESSAGE"
|
||||
}
|
||||
parts = append(parts, "["+role+"] "+content)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := p.client.Add(ctx, botID, strings.Join(parts, "\n"))
|
||||
if err != nil {
|
||||
p.logger.Warn("openviking store memory failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- MCP Tools ---
|
||||
|
||||
func (*OpenVikingProvider) ListTools(_ context.Context, _ mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error) {
|
||||
return []mcp.ToolDescriptor{
|
||||
{
|
||||
Name: ovToolSearchMemory,
|
||||
Description: "Search for memories relevant to the current chat",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The query to search memories",
|
||||
},
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of memory results",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) CallTool(ctx context.Context, session mcp.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) {
|
||||
if toolName != ovToolSearchMemory {
|
||||
return nil, mcp.ErrToolNotFound
|
||||
}
|
||||
query := mcp.StringArg(arguments, "query")
|
||||
if query == "" {
|
||||
return mcp.BuildToolErrorResult("query is required"), nil
|
||||
}
|
||||
botID := strings.TrimSpace(session.BotID)
|
||||
if botID == "" {
|
||||
return mcp.BuildToolErrorResult("bot_id is required"), nil
|
||||
}
|
||||
limit := ovDefaultLimit
|
||||
if value, ok, err := mcp.IntArg(arguments, "limit"); err != nil {
|
||||
return mcp.BuildToolErrorResult(err.Error()), nil
|
||||
} else if ok {
|
||||
limit = value
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = ovDefaultLimit
|
||||
}
|
||||
if limit > ovMaxLimit {
|
||||
limit = ovMaxLimit
|
||||
}
|
||||
memories, err := p.client.Search(ctx, botID, query, limit)
|
||||
if err != nil {
|
||||
return mcp.BuildToolErrorResult("memory search failed"), nil
|
||||
}
|
||||
results := make([]map[string]any, 0, len(memories))
|
||||
for _, mem := range memories {
|
||||
results = append(results, map[string]any{
|
||||
"id": mem.ID,
|
||||
"memory": mem.Content,
|
||||
"score": mem.Score,
|
||||
})
|
||||
}
|
||||
return mcp.BuildToolSuccessResult(map[string]any{
|
||||
"query": query,
|
||||
"total": len(results),
|
||||
"results": results,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
func (p *OpenVikingProvider) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if botID == "" {
|
||||
return adapters.SearchResponse{}, errors.New("bot_id is required")
|
||||
}
|
||||
text := strings.TrimSpace(req.Message)
|
||||
if text == "" && len(req.Messages) > 0 {
|
||||
parts := make([]string, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
role := strings.ToUpper(strings.TrimSpace(m.Role))
|
||||
if role == "" {
|
||||
role = "MESSAGE"
|
||||
}
|
||||
parts = append(parts, "["+role+"] "+content)
|
||||
}
|
||||
text = strings.Join(parts, "\n")
|
||||
}
|
||||
if text == "" {
|
||||
return adapters.SearchResponse{}, errors.New("message is required")
|
||||
}
|
||||
mem, err := p.client.Add(ctx, botID, text)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
return adapters.SearchResponse{Results: []adapters.MemoryItem{ovToItem(*mem)}}, nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if botID == "" {
|
||||
return adapters.SearchResponse{}, errors.New("bot_id is required")
|
||||
}
|
||||
limit := req.Limit
|
||||
if limit <= 0 {
|
||||
limit = ovDefaultLimit
|
||||
} else if limit > ovMaxLimit {
|
||||
limit = ovMaxLimit
|
||||
}
|
||||
memories, err := p.client.Search(ctx, botID, req.Query, limit)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
return adapters.SearchResponse{Results: ovToItems(memories)}, nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if botID == "" {
|
||||
return adapters.SearchResponse{}, errors.New("bot_id is required")
|
||||
}
|
||||
memories, err := p.client.GetAll(ctx, botID, req.Limit)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
items := ovToItems(memories)
|
||||
sort.Slice(items, func(i, j int) bool { return items[i].UpdatedAt > items[j].UpdatedAt })
|
||||
return adapters.SearchResponse{Results: items}, nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) Update(ctx context.Context, req adapters.UpdateRequest) (adapters.MemoryItem, error) {
|
||||
memoryID := strings.TrimSpace(req.MemoryID)
|
||||
if memoryID == "" {
|
||||
return adapters.MemoryItem{}, errors.New("memory_id is required")
|
||||
}
|
||||
mem, err := p.client.Update(ctx, memoryID, req.Memory)
|
||||
if err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
}
|
||||
return ovToItem(*mem), nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) Delete(ctx context.Context, memoryID string) (adapters.DeleteResponse, error) {
|
||||
if err := p.client.Delete(ctx, strings.TrimSpace(memoryID)); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "Memory deleted successfully"}, nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) DeleteBatch(ctx context.Context, memoryIDs []string) (adapters.DeleteResponse, error) {
|
||||
for _, id := range memoryIDs {
|
||||
if err := p.client.Delete(ctx, strings.TrimSpace(id)); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "Memories deleted successfully"}, nil
|
||||
}
|
||||
|
||||
func (p *OpenVikingProvider) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
if botID == "" {
|
||||
return adapters.DeleteResponse{}, errors.New("bot_id is required")
|
||||
}
|
||||
if err := p.client.DeleteAll(ctx, botID); err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
return adapters.DeleteResponse{Message: "All memories deleted"}, nil
|
||||
}
|
||||
|
||||
// --- Lifecycle ---
|
||||
|
||||
func (*OpenVikingProvider) Compact(_ context.Context, _ map[string]any, _ float64, _ int) (adapters.CompactResult, error) {
|
||||
return adapters.CompactResult{}, errors.New("compact is not supported by openviking provider")
|
||||
}
|
||||
|
||||
func (*OpenVikingProvider) Usage(_ context.Context, _ map[string]any) (adapters.UsageResponse, error) {
|
||||
return adapters.UsageResponse{}, errors.New("usage is not supported by openviking provider")
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func ovToItems(memories []ovMemory) []adapters.MemoryItem {
|
||||
items := make([]adapters.MemoryItem, 0, len(memories))
|
||||
for _, m := range memories {
|
||||
items = append(items, ovToItem(m))
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func ovToItem(m ovMemory) adapters.MemoryItem {
|
||||
return adapters.MemoryItem{
|
||||
ID: m.ID,
|
||||
Memory: m.Content,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Metadata: m.Metadata,
|
||||
Score: m.Score,
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package provider
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
// Provider is the unified interface for memory systems. Each provider type
|
||||
// (builtin, mem0, openmemory, etc.) implements this independently with its
|
||||
// (builtin, mem0, openviking, etc.) implements this independently with its
|
||||
// own storage, retrieval, and tool logic.
|
||||
type Provider interface {
|
||||
// Type returns the provider type identifier (e.g. "builtin", "mem0").
|
||||
@@ -15,20 +15,12 @@ type Provider interface {
|
||||
|
||||
// --- Conversation Hooks ---
|
||||
|
||||
// OnBeforeChat is called before sending to the agent gateway.
|
||||
// It returns memory context to inject into the conversation, or nil if none.
|
||||
OnBeforeChat(ctx context.Context, req BeforeChatRequest) (*BeforeChatResult, error)
|
||||
|
||||
// OnAfterChat is called after receiving the gateway response.
|
||||
// It extracts facts from the conversation and stores them.
|
||||
OnAfterChat(ctx context.Context, req AfterChatRequest) error
|
||||
|
||||
// --- MCP Tools ---
|
||||
|
||||
// ListTools returns MCP tool descriptors provided by this memory provider.
|
||||
ListTools(ctx context.Context, session mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error)
|
||||
|
||||
// CallTool executes an MCP tool owned by this memory provider.
|
||||
CallTool(ctx context.Context, session mcp.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error)
|
||||
|
||||
// --- CRUD ---
|
||||
@@ -46,3 +38,10 @@ type Provider interface {
|
||||
Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error)
|
||||
Usage(ctx context.Context, filters map[string]any) (UsageResponse, error)
|
||||
}
|
||||
|
||||
// SourceSyncProvider is implemented by providers that can report runtime status
|
||||
// and rebuild derived storage from a canonical source of truth.
|
||||
type SourceSyncProvider interface {
|
||||
Status(ctx context.Context, botID string) (MemoryStatusResponse, error)
|
||||
Rebuild(ctx context.Context, botID string) (RebuildResult, error)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package provider
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -0,0 +1,348 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
queries *sqlc.Queries
|
||||
registry *Registry
|
||||
logger *slog.Logger
|
||||
cfg config.Config
|
||||
}
|
||||
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries, cfg config.Config) *Service {
|
||||
return &Service{
|
||||
queries: queries,
|
||||
logger: log.With(slog.String("service", "memory_providers")),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// SetRegistry configures the runtime registry so that CRUD operations
|
||||
// can instantiate/evict provider instances automatically.
|
||||
func (s *Service) SetRegistry(registry *Registry) {
|
||||
s.registry = registry
|
||||
}
|
||||
|
||||
func (*Service) ListMeta(_ context.Context) []ProviderMeta {
|
||||
return []ProviderMeta{
|
||||
{
|
||||
Provider: string(ProviderBuiltin),
|
||||
DisplayName: "Built-in",
|
||||
ConfigSchema: ProviderConfigSchema{
|
||||
Fields: map[string]ProviderFieldSchema{
|
||||
"memory_mode": {
|
||||
Type: "select",
|
||||
Title: "Memory Mode",
|
||||
Description: "off = file-based, sparse = Qdrant sparse vectors, dense = embedding API + Qdrant dense vectors",
|
||||
Required: false,
|
||||
},
|
||||
"embedding_model_id": {
|
||||
Type: "model_select",
|
||||
Title: "Embedding Model",
|
||||
Description: "Embedding model for dense vector search (dense mode only)",
|
||||
Required: false,
|
||||
},
|
||||
"qdrant_collection": {
|
||||
Type: "string",
|
||||
Title: "Qdrant Collection",
|
||||
Description: "Qdrant collection name for sparse mode. Defaults to memory_sparse.",
|
||||
Required: false,
|
||||
Example: "memory_sparse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Provider: string(ProviderMem0),
|
||||
DisplayName: "Mem0",
|
||||
ConfigSchema: ProviderConfigSchema{
|
||||
Fields: map[string]ProviderFieldSchema{
|
||||
"base_url": {
|
||||
Type: "string",
|
||||
Title: "Base URL",
|
||||
Description: "Mem0 SaaS API base URL. Defaults to https://api.mem0.ai when empty.",
|
||||
Required: false,
|
||||
Example: "https://api.mem0.ai",
|
||||
},
|
||||
"api_key": {
|
||||
Type: "string",
|
||||
Title: "API Key",
|
||||
Description: "API key for Mem0 SaaS authentication",
|
||||
Required: true,
|
||||
Secret: true,
|
||||
},
|
||||
"org_id": {
|
||||
Type: "string",
|
||||
Title: "Organization ID",
|
||||
Description: "Organization ID for Mem0 SaaS workspace context",
|
||||
},
|
||||
"project_id": {
|
||||
Type: "string",
|
||||
Title: "Project ID",
|
||||
Description: "Project ID for Mem0 SaaS workspace context",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Provider: string(ProviderOpenViking),
|
||||
DisplayName: "OpenViking",
|
||||
ConfigSchema: ProviderConfigSchema{
|
||||
Fields: map[string]ProviderFieldSchema{
|
||||
"base_url": {
|
||||
Type: "string",
|
||||
Title: "Base URL",
|
||||
Description: "OpenViking API base URL (self-hosted or SaaS)",
|
||||
Required: true,
|
||||
Example: "http://openviking:8088",
|
||||
},
|
||||
"api_key": {
|
||||
Type: "string",
|
||||
Title: "API Key",
|
||||
Description: "API key for OpenViking authentication",
|
||||
Secret: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Create(ctx context.Context, req ProviderCreateRequest) (ProviderGetResponse, error) {
|
||||
if !isValidProviderType(req.Provider) {
|
||||
return ProviderGetResponse{}, fmt.Errorf("invalid provider type: %s", req.Provider)
|
||||
}
|
||||
configJSON, err := json.Marshal(req.Config)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
row, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Provider: string(req.Provider),
|
||||
Config: configJSON,
|
||||
IsDefault: false,
|
||||
})
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("create memory provider: %w", err)
|
||||
}
|
||||
resp := s.toGetResponse(row)
|
||||
s.tryInstantiate(resp.ID, resp.Provider, resp.Config)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Service) Get(ctx context.Context, id string) (ProviderGetResponse, error) {
|
||||
pgID, err := db.ParseUUID(id)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, err
|
||||
}
|
||||
row, err := s.queries.GetMemoryProviderByID(ctx, pgID)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
|
||||
}
|
||||
return s.toGetResponse(row), nil
|
||||
}
|
||||
|
||||
func (s *Service) Status(ctx context.Context, id string) (ProviderStatusResponse, error) {
|
||||
resp, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return ProviderStatusResponse{}, err
|
||||
}
|
||||
status := ProviderStatusResponse{
|
||||
ProviderType: resp.Provider,
|
||||
}
|
||||
if resp.Provider != string(ProviderBuiltin) {
|
||||
return status, nil
|
||||
}
|
||||
status.MemoryMode = StringFromConfig(resp.Config, "memory_mode")
|
||||
status.EmbeddingModelID = StringFromConfig(resp.Config, "embedding_model_id")
|
||||
collections := []string{"memory_sparse", "memory_dense"}
|
||||
status.Collections = make([]ProviderCollectionStatus, 0, len(collections))
|
||||
for _, collection := range collections {
|
||||
collStatus := ProviderCollectionStatus{Name: collection}
|
||||
host, port := parseQdrantHostPort(s.cfg.Qdrant.BaseURL)
|
||||
client, err := qdrantclient.NewClient(host, port, s.cfg.Qdrant.APIKey, collection)
|
||||
if err != nil {
|
||||
collStatus.Qdrant.Error = err.Error()
|
||||
status.Collections = append(status.Collections, collStatus)
|
||||
continue
|
||||
}
|
||||
exists, err := client.CollectionExists(ctx)
|
||||
if err != nil {
|
||||
collStatus.Qdrant.Error = err.Error()
|
||||
status.Collections = append(status.Collections, collStatus)
|
||||
continue
|
||||
}
|
||||
collStatus.Qdrant.OK = true
|
||||
collStatus.Exists = exists
|
||||
if exists {
|
||||
points, err := client.CountAll(ctx)
|
||||
if err != nil {
|
||||
collStatus.Qdrant.OK = false
|
||||
collStatus.Qdrant.Error = err.Error()
|
||||
} else {
|
||||
collStatus.Points = points
|
||||
}
|
||||
}
|
||||
status.Collections = append(status.Collections, collStatus)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *Service) List(ctx context.Context) ([]ProviderGetResponse, error) {
|
||||
rows, err := s.queries.ListMemoryProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list memory providers: %w", err)
|
||||
}
|
||||
items := make([]ProviderGetResponse, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
items = append(items, s.toGetResponse(row))
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *Service) Update(ctx context.Context, id string, req ProviderUpdateRequest) (ProviderGetResponse, error) {
|
||||
pgID, err := db.ParseUUID(id)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, err
|
||||
}
|
||||
current, err := s.queries.GetMemoryProviderByID(ctx, pgID)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
|
||||
}
|
||||
name := current.Name
|
||||
if req.Name != nil {
|
||||
name = strings.TrimSpace(*req.Name)
|
||||
}
|
||||
config := current.Config
|
||||
if req.Config != nil {
|
||||
configJSON, marshalErr := json.Marshal(req.Config)
|
||||
if marshalErr != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", marshalErr)
|
||||
}
|
||||
config = configJSON
|
||||
}
|
||||
updated, err := s.queries.UpdateMemoryProvider(ctx, sqlc.UpdateMemoryProviderParams{
|
||||
ID: pgID,
|
||||
Name: name,
|
||||
Config: config,
|
||||
})
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("update memory provider: %w", err)
|
||||
}
|
||||
resp := s.toGetResponse(updated)
|
||||
s.tryEvictAndReinstantiate(resp.ID, resp.Provider, resp.Config)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Service) Delete(ctx context.Context, id string) error {
|
||||
pgID, err := db.ParseUUID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.queries.DeleteMemoryProvider(ctx, pgID); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.registry != nil {
|
||||
s.registry.Remove(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureDefault creates a default builtin provider if none exists.
|
||||
func (s *Service) EnsureDefault(ctx context.Context) (ProviderGetResponse, error) {
|
||||
row, err := s.queries.GetDefaultMemoryProvider(ctx)
|
||||
if err == nil {
|
||||
return s.toGetResponse(row), nil
|
||||
}
|
||||
configJSON, _ := json.Marshal(map[string]any{})
|
||||
created, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
|
||||
Name: "Built-in Memory",
|
||||
Provider: string(ProviderBuiltin),
|
||||
Config: configJSON,
|
||||
IsDefault: true,
|
||||
})
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("create default memory provider: %w", err)
|
||||
}
|
||||
return s.toGetResponse(created), nil
|
||||
}
|
||||
|
||||
func (s *Service) toGetResponse(row sqlc.MemoryProvider) ProviderGetResponse {
|
||||
var cfg map[string]any
|
||||
if len(row.Config) > 0 {
|
||||
if err := json.Unmarshal(row.Config, &cfg); err != nil {
|
||||
s.logger.Warn("memory provider config unmarshal failed", slog.String("id", row.ID.String()), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
return ProviderGetResponse{
|
||||
ID: row.ID.String(),
|
||||
Name: row.Name,
|
||||
Provider: row.Provider,
|
||||
Config: cfg,
|
||||
IsDefault: row.IsDefault,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
UpdatedAt: row.UpdatedAt.Time,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) tryInstantiate(id, providerType string, config map[string]any) {
|
||||
if s.registry == nil {
|
||||
return
|
||||
}
|
||||
if _, err := s.registry.Instantiate(id, providerType, config); err != nil {
|
||||
s.logger.Warn("auto-instantiate memory provider failed",
|
||||
slog.String("id", id), slog.String("provider", providerType), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) tryEvictAndReinstantiate(id, providerType string, config map[string]any) {
|
||||
if s.registry == nil {
|
||||
return
|
||||
}
|
||||
s.registry.Remove(id)
|
||||
s.tryInstantiate(id, providerType, config)
|
||||
}
|
||||
|
||||
func isValidProviderType(t ProviderType) bool {
|
||||
switch t {
|
||||
case ProviderBuiltin, ProviderMem0, ProviderOpenViking:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func parseQdrantHostPort(baseURL string) (string, int) {
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return "", 0
|
||||
}
|
||||
baseURL = strings.TrimPrefix(baseURL, "http://")
|
||||
baseURL = strings.TrimPrefix(baseURL, "https://")
|
||||
parts := strings.SplitN(baseURL, ":", 2)
|
||||
host := parts[0]
|
||||
if len(parts) == 2 {
|
||||
httpPort, err := strconv.Atoi(strings.TrimRight(parts[1], "/"))
|
||||
if err == nil {
|
||||
switch httpPort {
|
||||
case 6333, 6334:
|
||||
return host, 6334
|
||||
default:
|
||||
return host, httpPort
|
||||
}
|
||||
}
|
||||
}
|
||||
return host, 6334
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package provider
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -19,8 +19,11 @@ type BeforeChatResult struct {
|
||||
|
||||
// AfterChatRequest is passed to OnAfterChat after receiving the gateway response.
|
||||
type AfterChatRequest struct {
|
||||
BotID string
|
||||
Messages []Message
|
||||
BotID string
|
||||
Messages []Message
|
||||
UserID string
|
||||
ChannelIdentityID string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
// LLM is the interface for LLM operations needed by memory service.
|
||||
@@ -205,16 +208,37 @@ type UsageResponse struct {
|
||||
|
||||
type RebuildResult struct {
|
||||
FsCount int `json:"fs_count"`
|
||||
QdrantCount int `json:"qdrant_count"`
|
||||
StorageCount int `json:"storage_count"`
|
||||
MissingCount int `json:"missing_count"`
|
||||
RestoredCount int `json:"restored_count"`
|
||||
}
|
||||
|
||||
type HealthStatus struct {
|
||||
OK bool `json:"ok"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type MemoryStatusResponse struct {
|
||||
ProviderType string `json:"provider_type,omitempty"`
|
||||
MemoryMode string `json:"memory_mode,omitempty"`
|
||||
CanManualSync bool `json:"can_manual_sync"`
|
||||
SourceDir string `json:"source_dir,omitempty"`
|
||||
OverviewPath string `json:"overview_path,omitempty"`
|
||||
MarkdownFileCount int `json:"markdown_file_count,omitempty"`
|
||||
SourceCount int `json:"source_count,omitempty"`
|
||||
IndexedCount int `json:"indexed_count,omitempty"`
|
||||
QdrantCollection string `json:"qdrant_collection,omitempty"`
|
||||
Encoder HealthStatus `json:"encoder"`
|
||||
Qdrant HealthStatus `json:"qdrant"`
|
||||
}
|
||||
|
||||
// Memory provider admin types.
|
||||
type ProviderType string
|
||||
|
||||
const (
|
||||
ProviderBuiltin ProviderType = "builtin"
|
||||
ProviderBuiltin ProviderType = "builtin"
|
||||
ProviderMem0 ProviderType = "mem0"
|
||||
ProviderOpenViking ProviderType = "openviking"
|
||||
)
|
||||
|
||||
type ProviderCreateRequest struct {
|
||||
@@ -247,6 +271,7 @@ type ProviderFieldSchema struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Required bool `json:"required,omitempty"`
|
||||
Secret bool `json:"secret,omitempty"`
|
||||
Example any `json:"example,omitempty"`
|
||||
}
|
||||
|
||||
@@ -255,3 +280,17 @@ type ProviderMeta struct {
|
||||
DisplayName string `json:"display_name"`
|
||||
ConfigSchema ProviderConfigSchema `json:"config_schema"`
|
||||
}
|
||||
|
||||
type ProviderCollectionStatus struct {
|
||||
Name string `json:"name"`
|
||||
Exists bool `json:"exists"`
|
||||
Points int `json:"points"`
|
||||
Qdrant HealthStatus `json:"qdrant"`
|
||||
}
|
||||
|
||||
type ProviderStatusResponse struct {
|
||||
ProviderType string `json:"provider_type"`
|
||||
MemoryMode string `json:"memory_mode,omitempty"`
|
||||
EmbeddingModelID string `json:"embedding_model_id,omitempty"`
|
||||
Collections []ProviderCollectionStatus `json:"collections,omitempty"`
|
||||
}
|
||||
@@ -1,179 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
queries *sqlc.Queries
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||
return &Service{
|
||||
queries: queries,
|
||||
logger: log.With(slog.String("service", "memory_providers")),
|
||||
}
|
||||
}
|
||||
|
||||
func (*Service) ListMeta(_ context.Context) []ProviderMeta {
|
||||
return []ProviderMeta{
|
||||
{
|
||||
Provider: string(ProviderBuiltin),
|
||||
DisplayName: "Built-in",
|
||||
ConfigSchema: ProviderConfigSchema{
|
||||
Fields: map[string]ProviderFieldSchema{
|
||||
"memory_model_id": {
|
||||
Type: "model_select",
|
||||
Title: "Memory Model",
|
||||
Description: "LLM model used for memory extraction and decision",
|
||||
Required: false,
|
||||
},
|
||||
"embedding_model_id": {
|
||||
Type: "model_select",
|
||||
Title: "Embedding Model",
|
||||
Description: "Embedding model for dense vector search",
|
||||
Required: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Create(ctx context.Context, req ProviderCreateRequest) (ProviderGetResponse, error) {
|
||||
if !isValidProviderType(req.Provider) {
|
||||
return ProviderGetResponse{}, fmt.Errorf("invalid provider type: %s", req.Provider)
|
||||
}
|
||||
configJSON, err := json.Marshal(req.Config)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
row, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Provider: string(req.Provider),
|
||||
Config: configJSON,
|
||||
IsDefault: false,
|
||||
})
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("create memory provider: %w", err)
|
||||
}
|
||||
return s.toGetResponse(row), nil
|
||||
}
|
||||
|
||||
func (s *Service) Get(ctx context.Context, id string) (ProviderGetResponse, error) {
|
||||
pgID, err := db.ParseUUID(id)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, err
|
||||
}
|
||||
row, err := s.queries.GetMemoryProviderByID(ctx, pgID)
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
|
||||
}
|
||||
return s.toGetResponse(row), nil
|
||||
}
|
||||
|
||||
func (s *Service) List(ctx context.Context) ([]ProviderGetResponse, error) {
|
||||
rows, err := s.queries.ListMemoryProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list memory providers: %w", err)
|
||||
}
|
||||
items := make([]ProviderGetResponse, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
items = append(items, s.toGetResponse(row))
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// Update applies a partial update to a memory provider. Only fields present
// in the request (Name, Config) change; nil fields keep their current values.
func (s *Service) Update(ctx context.Context, id string, req ProviderUpdateRequest) (ProviderGetResponse, error) {
	pgID, err := db.ParseUUID(id)
	if err != nil {
		return ProviderGetResponse{}, err
	}
	// Load the current row so unspecified fields can be carried over.
	current, err := s.queries.GetMemoryProviderByID(ctx, pgID)
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
	}
	name := current.Name
	if req.Name != nil {
		name = strings.TrimSpace(*req.Name)
	}
	config := current.Config
	if req.Config != nil {
		configJSON, marshalErr := json.Marshal(req.Config)
		if marshalErr != nil {
			return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", marshalErr)
		}
		config = configJSON
	}
	updated, err := s.queries.UpdateMemoryProvider(ctx, sqlc.UpdateMemoryProviderParams{
		ID:     pgID,
		Name:   name,
		Config: config,
	})
	if err != nil {
		return ProviderGetResponse{}, fmt.Errorf("update memory provider: %w", err)
	}
	return s.toGetResponse(updated), nil
}
|
||||
|
||||
func (s *Service) Delete(ctx context.Context, id string) error {
|
||||
pgID, err := db.ParseUUID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.queries.DeleteMemoryProvider(ctx, pgID)
|
||||
}
|
||||
|
||||
// EnsureDefault creates a default builtin provider if none exists.
|
||||
func (s *Service) EnsureDefault(ctx context.Context) (ProviderGetResponse, error) {
|
||||
row, err := s.queries.GetDefaultMemoryProvider(ctx)
|
||||
if err == nil {
|
||||
return s.toGetResponse(row), nil
|
||||
}
|
||||
configJSON, _ := json.Marshal(map[string]any{})
|
||||
created, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
|
||||
Name: "Built-in Memory",
|
||||
Provider: string(ProviderBuiltin),
|
||||
Config: configJSON,
|
||||
IsDefault: true,
|
||||
})
|
||||
if err != nil {
|
||||
return ProviderGetResponse{}, fmt.Errorf("create default memory provider: %w", err)
|
||||
}
|
||||
return s.toGetResponse(created), nil
|
||||
}
|
||||
|
||||
// toGetResponse maps a database row to the API response shape. A config blob
// that fails to unmarshal is logged and surfaced as a nil Config rather than
// failing the request.
func (s *Service) toGetResponse(row sqlc.MemoryProvider) ProviderGetResponse {
	var cfg map[string]any
	if len(row.Config) > 0 {
		if err := json.Unmarshal(row.Config, &cfg); err != nil {
			s.logger.Warn("memory provider config unmarshal failed", slog.String("id", row.ID.String()), slog.Any("error", err))
		}
	}
	return ProviderGetResponse{
		ID:        row.ID.String(),
		Name:      row.Name,
		Provider:  row.Provider,
		Config:    cfg,
		IsDefault: row.IsDefault,
		CreatedAt: row.CreatedAt.Time,
		UpdatedAt: row.UpdatedAt.Time,
	}
}
|
||||
|
||||
func isValidProviderType(t ProviderType) bool {
|
||||
switch t {
|
||||
case ProviderBuiltin:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,423 @@
|
||||
// Package qdrant wraps the official github.com/qdrant/go-client SDK,
|
||||
// providing a thin facade for sparse-vector memory operations.
|
||||
package qdrant
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pb "github.com/qdrant/go-client/qdrant"
|
||||
)
|
||||
|
||||
const (
	// sparseVectorName is the named-vector slot used for sparse encodings
	// in the Qdrant collection schema and in every sparse query.
	sparseVectorName = "sparse"
)
|
||||
|
||||
// Client wraps the official Qdrant gRPC client with sparse-memory-specific helpers.
// All operations target the single collection fixed at construction time.
type Client struct {
	inner      *pb.Client // underlying gRPC client from github.com/qdrant/go-client
	collection string     // collection name all operations are scoped to
}
|
||||
|
||||
// NewClient creates a Qdrant client connected via gRPC.
// host should be a bare hostname/IP; port is the gRPC port (default 6334).
// Empty host/port/collection fall back to "localhost", 6334, and "memory".
func NewClient(host string, port int, apiKey, collection string) (*Client, error) {
	if host == "" {
		host = "localhost"
	}
	if port == 0 {
		port = 6334
	}
	if collection == "" {
		collection = "memory"
	}

	cfg := &pb.Config{
		Host: host,
		Port: port,
	}
	// API key is optional (self-hosted instances typically run without one).
	if apiKey != "" {
		cfg.APIKey = apiKey
	}

	inner, err := pb.NewClient(cfg)
	if err != nil {
		return nil, fmt.Errorf("qdrant: connect: %w", err)
	}
	return &Client{inner: inner, collection: collection}, nil
}
|
||||
|
||||
// Close closes the underlying gRPC connection. The client must not be used
// after Close returns.
func (c *Client) Close() error {
	return c.inner.Close()
}
|
||||
|
||||
// CollectionName returns the collection this client is scoped to.
func (c *Client) CollectionName() string {
	return c.collection
}
|
||||
|
||||
// CollectionExists reports whether the configured collection already exists
// on the server.
func (c *Client) CollectionExists(ctx context.Context) (bool, error) {
	exists, err := c.inner.CollectionExists(ctx, c.collection)
	if err != nil {
		return false, fmt.Errorf("qdrant: check collection: %w", err)
	}
	return exists, nil
}
|
||||
|
||||
// EnsureCollection creates the collection with a named sparse vector config if it does not exist.
// NOTE(review): the exists/create pair is not atomic; concurrent callers may
// race — confirm the server tolerates a duplicate-create error upstream.
func (c *Client) EnsureCollection(ctx context.Context) error {
	exists, err := c.CollectionExists(ctx)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}
	// Empty SparseVectorParams: server-side defaults for the sparse index.
	err = c.inner.CreateCollection(ctx, &pb.CreateCollection{
		CollectionName: c.collection,
		SparseVectorsConfig: pb.NewSparseVectorsConfig(map[string]*pb.SparseVectorParams{
			sparseVectorName: {},
		}),
	})
	if err != nil {
		return fmt.Errorf("qdrant: create collection: %w", err)
	}
	return nil
}
|
||||
|
||||
// EnsureDenseCollection creates the collection with dense vector config if it
// does not exist. dimensions is the fixed embedding size and must be positive.
func (c *Client) EnsureDenseCollection(ctx context.Context, dimensions int) error {
	if dimensions <= 0 {
		return fmt.Errorf("qdrant: dense dimensions must be positive, got %d", dimensions)
	}
	exists, err := c.CollectionExists(ctx)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}
	err = c.inner.CreateCollection(ctx, &pb.CreateCollection{
		CollectionName: c.collection,
		// Cosine distance matches the typical normalization of embedding models.
		VectorsConfig: pb.NewVectorsConfig(&pb.VectorParams{
			Size:     uint64(dimensions),
			Distance: pb.Distance_Cosine,
		}),
	})
	if err != nil {
		return fmt.Errorf("qdrant: create dense collection: %w", err)
	}
	return nil
}
|
||||
|
||||
// SparseVector holds the non-zero components of a sparse text encoding.
// Indices and Values are parallel slices and must be the same length.
type SparseVector struct {
	Indices []uint32
	Values  []float32
}
|
||||
|
||||
// DenseVector holds a fixed-dimension dense embedding.
type DenseVector struct {
	Values []float32
}
|
||||
|
||||
// SearchResult is one result from a sparse search or scroll.
// Score is zero for scroll/get results, which carry no similarity score.
type SearchResult struct {
	ID      string
	Score   float64
	Payload map[string]string
}
|
||||
|
||||
// Upsert inserts or updates a single point with a named sparse vector and a
// string payload. The call waits for the write to be applied before returning.
func (c *Client) Upsert(ctx context.Context, id string, vec SparseVector, payload map[string]string) error {
	wait := true
	_, err := c.inner.Upsert(ctx, &pb.UpsertPoints{
		CollectionName: c.collection,
		Wait:           &wait,
		Points: []*pb.PointStruct{
			{
				Id: pb.NewID(id),
				Vectors: pb.NewVectorsMap(map[string]*pb.Vector{
					sparseVectorName: {
						Data:    vec.Values,
						Indices: &pb.SparseIndices{Data: vec.Indices},
					},
				}),
				Payload: stringPayloadToValueMap(payload),
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant: upsert: %w", err)
	}
	return nil
}
|
||||
|
||||
// UpsertDense inserts or updates a single point with a dense vector and a
// string payload, waiting for the write to be applied.
func (c *Client) UpsertDense(ctx context.Context, id string, vec DenseVector, payload map[string]string) error {
	wait := true
	_, err := c.inner.Upsert(ctx, &pb.UpsertPoints{
		CollectionName: c.collection,
		Wait:           &wait,
		Points: []*pb.PointStruct{
			{
				Id:      pb.NewID(id),
				Vectors: pb.NewVectorsDense(vec.Values),
				Payload: stringPayloadToValueMap(payload),
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant: upsert dense: %w", err)
	}
	return nil
}
|
||||
|
||||
// Search performs a sparse-vector query against the collection, filtered by
// bot_id. limit defaults to 10 when non-positive. Payloads are returned.
func (c *Client) Search(ctx context.Context, vec SparseVector, botID string, limit int) ([]SearchResult, error) {
	if limit <= 0 {
		limit = 10
	}
	queryLimit, err := intToUint64(limit)
	if err != nil {
		return nil, fmt.Errorf("qdrant: invalid search limit: %w", err)
	}
	scored, err := c.inner.Query(ctx, &pb.QueryPoints{
		CollectionName: c.collection,
		Query:          pb.NewQuerySparse(vec.Indices, vec.Values),
		// Target the named sparse-vector slot declared at collection creation.
		Using:       strPtr(sparseVectorName),
		Filter:      botFilter(botID),
		Limit:       uint64Ptr(queryLimit),
		WithPayload: pb.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant: search: %w", err)
	}
	return scoredPointsToResults(scored), nil
}
|
||||
|
||||
// SearchDense performs a dense-vector query against the collection, filtered
// by bot_id. limit defaults to 10 when non-positive. Payloads are returned.
func (c *Client) SearchDense(ctx context.Context, vec DenseVector, botID string, limit int) ([]SearchResult, error) {
	if limit <= 0 {
		limit = 10
	}
	queryLimit, err := intToUint64(limit)
	if err != nil {
		return nil, fmt.Errorf("qdrant: invalid dense search limit: %w", err)
	}
	scored, err := c.inner.Query(ctx, &pb.QueryPoints{
		CollectionName: c.collection,
		Query:          pb.NewQueryDense(vec.Values),
		Filter:         botFilter(botID),
		Limit:          uint64Ptr(queryLimit),
		WithPayload:    pb.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant: dense search: %w", err)
	}
	return scoredPointsToResults(scored), nil
}
|
||||
|
||||
// GetByID fetches a single point by UUID. It returns (nil, nil) when the
// point does not exist — callers must nil-check the result.
func (c *Client) GetByID(ctx context.Context, id string) (*SearchResult, error) {
	points, err := c.inner.Get(ctx, &pb.GetPoints{
		CollectionName: c.collection,
		Ids:            []*pb.PointId{pb.NewID(id)},
		WithPayload:    pb.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant: get: %w", err)
	}
	if len(points) == 0 {
		return nil, nil
	}
	r := retrievedPointToResult(points[0])
	return &r, nil
}
|
||||
|
||||
// Scroll returns all points matching bot_id, up to limit.
|
||||
func (c *Client) Scroll(ctx context.Context, botID string, limit int) ([]SearchResult, error) {
|
||||
if limit <= 0 {
|
||||
limit = 1000
|
||||
}
|
||||
if limit > math.MaxUint32 {
|
||||
limit = math.MaxUint32
|
||||
}
|
||||
l, err := intToUint32(limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qdrant: invalid scroll limit: %w", err)
|
||||
}
|
||||
points, err := c.inner.Scroll(ctx, &pb.ScrollPoints{
|
||||
CollectionName: c.collection,
|
||||
Filter: botFilter(botID),
|
||||
Limit: &l,
|
||||
WithPayload: pb.NewWithPayload(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qdrant: scroll: %w", err)
|
||||
}
|
||||
results := make([]SearchResult, 0, len(points))
|
||||
for _, p := range points {
|
||||
results = append(results, retrievedPointToResult(p))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Count returns the exact number of points matching bot_id. The server's
// uint64 count is guarded against overflowing a platform int.
func (c *Client) Count(ctx context.Context, botID string) (int, error) {
	exact := true
	n, err := c.inner.Count(ctx, &pb.CountPoints{
		CollectionName: c.collection,
		Filter:         botFilter(botID),
		Exact:          &exact,
	})
	if err != nil {
		return 0, fmt.Errorf("qdrant: count: %w", err)
	}
	if n > uint64(math.MaxInt) {
		return 0, fmt.Errorf("qdrant: count overflow: %d", n)
	}
	return int(n), nil
}
|
||||
|
||||
// CountAll returns the exact total number of points in the collection,
// across all bots.
func (c *Client) CountAll(ctx context.Context) (int, error) {
	exact := true
	n, err := c.inner.Count(ctx, &pb.CountPoints{
		CollectionName: c.collection,
		Exact:          &exact,
	})
	if err != nil {
		return 0, fmt.Errorf("qdrant: count all: %w", err)
	}
	if n > uint64(math.MaxInt) {
		return 0, fmt.Errorf("qdrant: count overflow: %d", n)
	}
	return int(n), nil
}
|
||||
|
||||
// DeleteByIDs removes specific points by their UUID strings.
|
||||
func (c *Client) DeleteByIDs(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
pointIDs := make([]*pb.PointId, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if strings.TrimSpace(id) != "" {
|
||||
pointIDs = append(pointIDs, pb.NewID(strings.TrimSpace(id)))
|
||||
}
|
||||
}
|
||||
wait := true
|
||||
_, err := c.inner.Delete(ctx, &pb.DeletePoints{
|
||||
CollectionName: c.collection,
|
||||
Wait: &wait,
|
||||
Points: &pb.PointsSelector{
|
||||
PointsSelectorOneOf: &pb.PointsSelector_Points{
|
||||
Points: &pb.PointsIdsList{Ids: pointIDs},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("qdrant: delete by ids: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByBotID removes all points for a given bot_id using a filter-based
// delete, waiting for the operation to be applied.
func (c *Client) DeleteByBotID(ctx context.Context, botID string) error {
	wait := true
	_, err := c.inner.Delete(ctx, &pb.DeletePoints{
		CollectionName: c.collection,
		Wait:           &wait,
		Points: &pb.PointsSelector{
			PointsSelectorOneOf: &pb.PointsSelector_Filter{
				Filter: botFilter(botID),
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant: delete by bot_id: %w", err)
	}
	return nil
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func botFilter(botID string) *pb.Filter {
|
||||
return &pb.Filter{
|
||||
Must: []*pb.Condition{
|
||||
pb.NewMatch("bot_id", botID),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func stringPayloadToValueMap(payload map[string]string) map[string]*pb.Value {
|
||||
m := make(map[string]*pb.Value, len(payload))
|
||||
for k, v := range payload {
|
||||
m[k] = pb.NewValueString(v)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// valueMapToStringPayload converts a protobuf value map back into a plain
// string map. Nil values, non-string values, and empty strings are silently
// dropped; an empty input yields nil.
func valueMapToStringPayload(m map[string]*pb.Value) map[string]string {
	if len(m) == 0 {
		return nil
	}
	out := make(map[string]string, len(m))
	for k, v := range m {
		if v != nil {
			if sv := v.GetStringValue(); sv != "" {
				out[k] = sv
			}
		}
	}
	return out
}
|
||||
|
||||
// scoredPointsToResults maps query results into the package's SearchResult
// form, preserving score and payload.
func scoredPointsToResults(scored []*pb.ScoredPoint) []SearchResult {
	results := make([]SearchResult, 0, len(scored))
	for _, p := range scored {
		results = append(results, SearchResult{
			ID:      extractID(p.GetId()),
			Score:   float64(p.GetScore()),
			Payload: valueMapToStringPayload(p.GetPayload()),
		})
	}
	return results
}
|
||||
|
||||
// retrievedPointToResult maps a get/scroll point into a SearchResult.
// Retrieved points carry no similarity score, so Score stays zero.
func retrievedPointToResult(p *pb.RetrievedPoint) SearchResult {
	return SearchResult{
		ID:      extractID(p.GetId()),
		Payload: valueMapToStringPayload(p.GetPayload()),
	}
}
|
||||
|
||||
func extractID(id *pb.PointId) string {
|
||||
if id == nil {
|
||||
return ""
|
||||
}
|
||||
if uuid := id.GetUuid(); uuid != "" {
|
||||
return uuid
|
||||
}
|
||||
return strconv.FormatUint(id.GetNum(), 10)
|
||||
}
|
||||
|
||||
// strPtr returns a pointer to s, for optional proto fields.
func strPtr(s string) *string { return &s }
|
||||
|
||||
// uint64Ptr returns a pointer to v, for optional proto fields.
func uint64Ptr(v uint64) *uint64 { return &v }
|
||||
|
||||
// intToUint64 converts a non-negative int to uint64, rejecting negative
// values instead of silently wrapping. The previous implementation
// round-tripped through strconv.Itoa/ParseUint, allocating a string per call
// for the same range check.
func intToUint64(v int) (uint64, error) {
	if v < 0 {
		return 0, fmt.Errorf("negative value %d", v)
	}
	return uint64(v), nil
}
|
||||
|
||||
// intToUint32 converts an int to uint32, failing on negatives and on values
// above math.MaxUint32 rather than truncating. Replaces a wasteful
// strconv.Itoa/ParseUint round-trip with a direct range check.
func intToUint32(v int) (uint32, error) {
	if v < 0 || uint64(v) > math.MaxUint32 {
		return 0, fmt.Errorf("value %d out of uint32 range", v)
	}
	return uint32(v), nil
}
|
||||
@@ -1,366 +0,0 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
)
|
||||
|
||||
const (
	// memoryDateLayout is the per-day file-name date format (YYYY-MM-DD).
	memoryDateLayout = "2006-01-02"
	// memEntryStartPrefix/memEntryStartSuffix bracket the inline JSON
	// metadata line that opens one canonical entry.
	memEntryStartPrefix = "<!-- MEMOH:ENTRY "
	memEntryStartSuffix = " -->"
	// memEntryEndMarker terminates one entry's body.
	memEntryEndMarker = "<!-- /MEMOH:ENTRY -->"
	// memFileHeaderTemplate is the canonical per-day file header.
	memFileHeaderTemplate = "# Memory %s\n\n"
)
||||
|
||||
// writeRecord is the intermediate form of one memory entry. Memory, Text,
// and Content are alternative body fields accepted on input; normalizeRecord
// collapses them into Memory (first non-empty wins, in that order).
type writeRecord struct {
	Topic     string `json:"topic"`
	ID        string `json:"id"`
	Memory    string `json:"memory"`
	Text      string `json:"text"`
	Content   string `json:"content"`
	Hash      string `json:"hash"`
	CreatedAt string `json:"created_at"`
	UpdatedAt string `json:"updated_at"`
}
|
||||
|
||||
// NormalizeMemoryDayContent converts user/LLM writes to canonical memory day format.
// Non-memory-day paths are returned unchanged, as is content that already
// carries canonical entry markers (assumed pre-normalized).
func NormalizeMemoryDayContent(containerPath, raw string) string {
	if !isMemoryDayMarkdownPath(containerPath) {
		return raw
	}
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return raw
	}
	// Already canonical: both start and end markers present.
	if strings.Contains(trimmed, memEntryStartPrefix) && strings.Contains(trimmed, memEntryEndMarker) {
		return raw
	}
	date := strings.TrimSuffix(path.Base(containerPath), ".md")
	records := parseStructuredRecords(trimmed)
	if len(records) == 0 {
		// Not structured JSON: wrap the raw text as a single fallback record.
		records = []writeRecord{buildFallbackRecord(trimmed, date, time.Now().UTC())}
	}
	return formatDayMarkdown(date, records)
}
|
||||
|
||||
// RenderMemoryDayForDisplay converts canonical memory day markdown into
// a user-facing timeline view. Non-memory-day paths are returned unchanged,
// as is content with no parseable canonical entries.
func RenderMemoryDayForDisplay(containerPath, raw string) string {
	if !isMemoryDayMarkdownPath(containerPath) {
		return raw
	}
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return raw
	}
	date := strings.TrimSuffix(path.Base(containerPath), ".md")
	records := parseCanonicalDayRecords(trimmed)
	if len(records) == 0 {
		return raw
	}
	// Chronological order; ties broken by ID for deterministic output.
	sort.Slice(records, func(i, j int) bool {
		ti := recordTime(records[i])
		tj := recordTime(records[j])
		if ti.Equal(tj) {
			return records[i].ID < records[j].ID
		}
		return ti.Before(tj)
	})
	var b strings.Builder
	b.WriteString("# ")
	b.WriteString(date)
	b.WriteString("\n\n")
	for idx, r := range records {
		if idx > 0 {
			b.WriteString("\n")
		}
		// One "## hh:mm AM/PM - Topic" section per entry.
		b.WriteString("## ")
		b.WriteString(formatRecordTime(r))
		b.WriteString(" - ")
		b.WriteString(recordTitle(r))
		b.WriteString("\n")
		b.WriteString(formatRecordBody(r.Memory))
		b.WriteString("\n")
	}
	return strings.TrimSpace(b.String())
}
|
||||
|
||||
// isMemoryDayMarkdownPath reports whether containerPath is a per-day memory
// file: <DefaultDataMount>/memory/<YYYY-MM-DD>.md with a valid date.
func isMemoryDayMarkdownPath(containerPath string) bool {
	clean := path.Clean("/" + strings.TrimSpace(containerPath))
	memoryDir := path.Clean(config.DefaultDataMount+"/memory") + "/"
	if !strings.HasPrefix(clean, memoryDir) || !strings.HasSuffix(clean, ".md") {
		return false
	}
	// The base name (sans extension) must parse as a calendar date.
	datePart := strings.TrimSuffix(path.Base(clean), ".md")
	_, err := time.Parse(memoryDateLayout, datePart)
	return err == nil
}
|
||||
|
||||
func parseStructuredRecords(content string) []writeRecord {
|
||||
now := time.Now().UTC()
|
||||
normalize := func(in []writeRecord) []writeRecord {
|
||||
out := make([]writeRecord, 0, len(in))
|
||||
for _, r := range in {
|
||||
nr, ok := normalizeRecord(r, now)
|
||||
if ok {
|
||||
out = append(out, nr)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
var list []writeRecord
|
||||
if err := json.Unmarshal([]byte(content), &list); err == nil {
|
||||
return normalize(list)
|
||||
}
|
||||
var obj writeRecord
|
||||
if err := json.Unmarshal([]byte(content), &obj); err == nil {
|
||||
return normalize([]writeRecord{obj})
|
||||
}
|
||||
var wrapped struct {
|
||||
Items []writeRecord `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(content), &wrapped); err == nil {
|
||||
return normalize(wrapped.Items)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseCanonicalDayRecords extracts entries from canonical day markdown:
// each entry is a metadata line (JSON between start/end marker comments)
// followed by body lines up to the end marker. Entries with malformed
// metadata are skipped; an unterminated final entry is dropped entirely.
func parseCanonicalDayRecords(content string) []writeRecord {
	lines := strings.Split(content, "\n")
	out := make([]writeRecord, 0, 8)
	for i := 0; i < len(lines); i++ {
		line := strings.TrimSpace(lines[i])
		if !strings.HasPrefix(line, memEntryStartPrefix) || !strings.HasSuffix(line, memEntryStartSuffix) {
			continue
		}
		// Metadata JSON sits between the marker prefix and suffix.
		metaJSON := strings.TrimSuffix(strings.TrimPrefix(line, memEntryStartPrefix), memEntryStartSuffix)
		var rec writeRecord
		if err := json.Unmarshal([]byte(metaJSON), &rec); err != nil {
			continue
		}
		// Collect body lines until the end marker.
		start := i + 1
		end := start
		for ; end < len(lines); end++ {
			if strings.TrimSpace(lines[end]) == memEntryEndMarker {
				break
			}
		}
		if end >= len(lines) {
			// No closing marker: discard the partial entry and stop.
			break
		}
		rec.Memory = strings.TrimSpace(strings.Join(lines[start:end], "\n"))
		out = append(out, rec)
		// Resume scanning after the end marker.
		i = end
	}
	return out
}
|
||||
|
||||
// formatDayMarkdown renders records as canonical day markdown: a "# Memory
// <date>" header followed by one marker-delimited entry per record, in
// chronological order (ties broken by ID).
func formatDayMarkdown(date string, records []writeRecord) string {
	sort.Slice(records, func(i, j int) bool {
		ti := parseRFC3339OrZero(records[i].CreatedAt)
		tj := parseRFC3339OrZero(records[j].CreatedAt)
		if ti.Equal(tj) {
			return records[i].ID < records[j].ID
		}
		return ti.Before(tj)
	})

	var b strings.Builder
	fmt.Fprintf(&b, memFileHeaderTemplate, date)
	for _, r := range records {
		// Only non-empty metadata fields are emitted; ID is always present.
		meta := map[string]string{"id": r.ID}
		if r.Topic != "" {
			meta["topic"] = r.Topic
		}
		if r.Hash != "" {
			meta["hash"] = r.Hash
		}
		if r.CreatedAt != "" {
			meta["created_at"] = r.CreatedAt
		}
		if r.UpdatedAt != "" {
			meta["updated_at"] = r.UpdatedAt
		}
		// Marshal of a map[string]string cannot fail, so the error is dropped.
		rawMeta, _ := json.Marshal(meta)
		b.WriteString(memEntryStartPrefix)
		b.Write(rawMeta)
		b.WriteString(memEntryStartSuffix)
		b.WriteString("\n")
		b.WriteString(r.Memory)
		b.WriteString("\n")
		b.WriteString(memEntryEndMarker)
		b.WriteString("\n\n")
	}
	return b.String()
}
|
||||
|
||||
// parseRFC3339OrZero parses raw as RFC 3339 and returns the instant in UTC.
// Blank or malformed input yields the zero time.
func parseRFC3339OrZero(raw string) time.Time {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		if t, err := time.Parse(time.RFC3339, trimmed); err == nil {
			return t.UTC()
		}
	}
	return time.Time{}
}
|
||||
|
||||
func recordTime(r writeRecord) time.Time {
|
||||
if t := parseRFC3339OrZero(r.CreatedAt); !t.IsZero() {
|
||||
return t
|
||||
}
|
||||
if t := parseRFC3339OrZero(r.UpdatedAt); !t.IsZero() {
|
||||
return t
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func formatRecordTime(r writeRecord) string {
|
||||
t := recordTime(r)
|
||||
if t.IsZero() {
|
||||
return "--:--"
|
||||
}
|
||||
return t.Format("03:04 PM")
|
||||
}
|
||||
|
||||
func recordTitle(r writeRecord) string {
|
||||
if topic := strings.TrimSpace(r.Topic); topic != "" {
|
||||
return topic
|
||||
}
|
||||
return "Notes"
|
||||
}
|
||||
|
||||
// formatRecordBody renders a record body as a markdown bullet list. Blank
// lines are dropped; lines that already look like list items ("- ", "* ",
// "1. ") pass through; everything else gets a "- " prefix. An effectively
// empty body renders as "- (empty)".
func formatRecordBody(body string) string {
	var bullets []string
	for _, raw := range strings.Split(strings.TrimSpace(body), "\n") {
		trimmed := strings.TrimSpace(raw)
		switch {
		case trimmed == "":
			// skip blank lines
		case strings.HasPrefix(trimmed, "- "), strings.HasPrefix(trimmed, "* "), strings.HasPrefix(trimmed, "1. "):
			bullets = append(bullets, trimmed)
		default:
			bullets = append(bullets, "- "+trimmed)
		}
	}
	if len(bullets) == 0 {
		return "- (empty)"
	}
	return strings.Join(bullets, "\n")
}
|
||||
|
||||
func buildFallbackRecord(content, date string, now time.Time) writeRecord {
|
||||
record := writeRecord{
|
||||
ID: fmt.Sprintf("mem_%d", now.UnixNano()),
|
||||
Memory: sanitizeFallbackBody(content, date),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
UpdatedAt: now.Format(time.RFC3339),
|
||||
}
|
||||
if legacy, ok := parseLegacyFrontmatterRecord(content); ok {
|
||||
if normalized, ok := normalizeRecord(legacy, now); ok {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
if record.Hash == "" {
|
||||
record.Hash = generateMemoryHash(record.Topic, record.Memory)
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
// parseLegacyFrontmatterRecord parses the pre-canonical "---" frontmatter
// format: a YAML-ish key/value header followed by the body. Only id, hash,
// created_at, and updated_at keys are recognized; other keys are ignored.
// Returns false when the content does not start with a frontmatter block.
func parseLegacyFrontmatterRecord(content string) (writeRecord, bool) {
	trimmed := strings.TrimSpace(content)
	if !strings.HasPrefix(trimmed, "---") {
		return writeRecord{}, false
	}
	// Split once on the closing "---"; everything after it is the body.
	parts := strings.SplitN(trimmed[3:], "---", 2)
	if len(parts) < 2 {
		return writeRecord{}, false
	}
	frontmatter := strings.TrimSpace(parts[0])
	body := strings.TrimSpace(parts[1])
	record := writeRecord{Memory: body}
	for _, line := range strings.Split(frontmatter, "\n") {
		// Cut at the first ":" so RFC 3339 values keep their own colons.
		key, value, found := strings.Cut(strings.TrimSpace(line), ":")
		if !found {
			continue
		}
		key = strings.TrimSpace(key)
		value = strings.TrimSpace(value)
		switch key {
		case "id":
			record.ID = value
		case "hash":
			record.Hash = value
		case "created_at":
			record.CreatedAt = value
		case "updated_at":
			record.UpdatedAt = value
		}
	}
	return record, true
}
|
||||
|
||||
// sanitizeFallbackBody strips the generated "# Memory <date>" header, when
// present, so fallback records store only the note text.
func sanitizeFallbackBody(content, date string) string {
	body := strings.TrimSpace(content)
	header := "# Memory " + strings.TrimSpace(date)
	if !strings.HasPrefix(body, header) {
		return body
	}
	return strings.TrimSpace(strings.TrimPrefix(body, header))
}
|
||||
|
||||
// normalizeRecord fills in a record's defaults and collapses its alternative
// body fields. The body is the first non-empty of Memory, Content, Text;
// records with no body are rejected. Missing id/hash/timestamps are derived
// from now and the content. Returns false when the record is unusable.
func normalizeRecord(r writeRecord, now time.Time) (writeRecord, bool) {
	mem := strings.TrimSpace(r.Memory)
	if mem == "" {
		mem = strings.TrimSpace(r.Content)
	}
	if mem == "" {
		mem = strings.TrimSpace(r.Text)
	}
	if mem == "" {
		return writeRecord{}, false
	}
	topic := strings.TrimSpace(r.Topic)
	id := strings.TrimSpace(r.ID)
	if id == "" {
		// Nanosecond timestamp keeps generated ids unique within a batch
		// in practice (one now per call).
		id = fmt.Sprintf("mem_%d", now.UnixNano())
	}
	createdAt := strings.TrimSpace(r.CreatedAt)
	if createdAt == "" {
		createdAt = now.Format(time.RFC3339)
	}
	updatedAt := strings.TrimSpace(r.UpdatedAt)
	if updatedAt == "" {
		updatedAt = createdAt
	}
	hash := strings.TrimSpace(r.Hash)
	if hash == "" {
		hash = generateMemoryHash(topic, mem)
	}
	return writeRecord{
		Topic:     topic,
		ID:        id,
		Memory:    mem,
		Hash:      hash,
		CreatedAt: createdAt,
		UpdatedAt: updatedAt,
	}, true
}
|
||||
|
||||
// generateMemoryHash derives a stable content hash: lowercase hex SHA-256 of
// the trimmed topic and trimmed memory body joined by a newline.
func generateMemoryHash(topic, memory string) string {
	payload := strings.TrimSpace(topic) + "\n" + strings.TrimSpace(memory)
	digest := sha256.Sum256([]byte(payload))
	return hex.EncodeToString(digest[:])
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Verifies that structured JSON writes are canonicalized: day header, entry
// markers, topic metadata, preserved body, and a generated hash.
func TestNormalizeMemoryDayContent_StructuredJSON(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	input := `[
{
"topic": "Decision",
"memory": "Choose provider architecture."
}
]`

	out := NormalizeMemoryDayContent(path, input)
	if !strings.Contains(out, "# Memory 2026-03-01") {
		t.Fatalf("expected day header, got: %s", out)
	}
	if !strings.Contains(out, `<!-- MEMOH:ENTRY `) {
		t.Fatalf("expected entry marker, got: %s", out)
	}
	if !strings.Contains(out, `"topic":"Decision"`) {
		t.Fatalf("expected topic metadata, got: %s", out)
	}
	if !strings.Contains(out, "Choose provider architecture.") {
		t.Fatalf("expected memory body, got: %s", out)
	}
	if !strings.Contains(out, `"hash":"`) {
		t.Fatalf("expected generated hash metadata, got: %s", out)
	}
}
|
||||
|
||||
// Verifies the plain-text fallback path: original text preserved inside a
// generated record with timestamps and a "mem_<nanos>" id.
func TestNormalizeMemoryDayContent_FallbackPlainText(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	input := "Unstructured note from model output."
	out := NormalizeMemoryDayContent(path, input)

	if !strings.Contains(out, "# Memory 2026-03-01") {
		t.Fatalf("expected day header, got: %s", out)
	}
	if !strings.Contains(out, "Unstructured note from model output.") {
		t.Fatalf("expected original text preserved, got: %s", out)
	}
	if !strings.Contains(out, `"created_at":"`) || !strings.Contains(out, `"updated_at":"`) {
		t.Fatalf("expected timestamps, got: %s", out)
	}
	if !regexp.MustCompile(`"id":"mem_\d+"`).MatchString(out) {
		t.Fatalf("expected generated id, got: %s", out)
	}
}
|
||||
|
||||
// Verifies that legacy "---" frontmatter documents keep their original id,
// hash, and body when canonicalized.
func TestNormalizeMemoryDayContent_LegacyFrontmatter(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	input := `---
id: mem_legacy_1
hash: legacyhash
created_at: 2026-03-01T09:00:00Z
updated_at: 2026-03-01T10:00:00Z
---
Legacy body text.`

	out := NormalizeMemoryDayContent(path, input)
	if !strings.Contains(out, `"id":"mem_legacy_1"`) {
		t.Fatalf("expected legacy id reused, got: %s", out)
	}
	if !strings.Contains(out, `"hash":"legacyhash"`) {
		t.Fatalf("expected legacy hash reused, got: %s", out)
	}
	if !strings.Contains(out, "Legacy body text.") {
		t.Fatalf("expected legacy body reused, got: %s", out)
	}
}
|
||||
|
||||
// Verifies the display rendering: raw markers are hidden, entries become
// chronologically ordered "## hh:mm AM/PM - Topic" sections with bulletized
// bodies.
func TestRenderMemoryDayForDisplay(t *testing.T) {
	path := "/data/memory/2026-03-01.md"
	raw := `# Memory 2026-03-01

<!-- MEMOH:ENTRY {"id":"mem_1","topic":"Decision","created_at":"2026-03-01T09:40:00Z"} -->
结论:采用 provider 架构
<!-- /MEMOH:ENTRY -->

<!-- MEMOH:ENTRY {"id":"mem_2","topic":"Notes","created_at":"2026-03-01T11:15:00Z"} -->
用户偏好:简短回复
<!-- /MEMOH:ENTRY -->
`

	out := RenderMemoryDayForDisplay(path, raw)
	if strings.Contains(out, "MEMOH:ENTRY") {
		t.Fatalf("display output should hide raw markers: %s", out)
	}
	if !strings.Contains(out, "# 2026-03-01") {
		t.Fatalf("expected display day header, got: %s", out)
	}
	if !strings.Contains(out, "## 09:40 AM - Decision") {
		t.Fatalf("expected timeline section, got: %s", out)
	}
	if !strings.Contains(out, "- 结论:采用 provider 架构") {
		t.Fatalf("expected bulletized body, got: %s", out)
	}
	if !strings.Contains(out, "## 11:15 AM - Notes") {
		t.Fatalf("expected second timeline section, got: %s", out)
	}
}
|
||||
|
||||
// Verifies that paths outside the memory day directory pass through verbatim.
func TestRenderMemoryDayForDisplay_NonMemoryPathUnchanged(t *testing.T) {
	raw := "plain content"
	out := RenderMemoryDayForDisplay("/data/notes.md", raw)
	if out != raw {
		t.Fatalf("non-memory path should be unchanged, got: %s", out)
	}
}
|
||||
@@ -0,0 +1,159 @@
|
||||
// Package sparse provides a Go client for the sparse encoding Python service.
|
||||
// The Python service loads the OpenSearch neural sparse model from HuggingFace
|
||||
// and exposes HTTP endpoints for text → sparse vector encoding.
|
||||
package sparse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SparseVector holds the non-zero components of a sparse text encoding.
// Indices and Values are parallel slices produced by the Python service.
type SparseVector struct {
	Indices []uint32  `json:"indices"`
	Values  []float32 `json:"values"`
}
|
||||
|
||||
// Client calls the Python sparse encoding service.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a sparse encoding client pointing to the Python service.
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
http: &http.Client{Timeout: 60 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeDocument encodes a document text into a sparse vector using the neural model.
|
||||
func (c *Client) EncodeDocument(ctx context.Context, text string) (*SparseVector, error) {
|
||||
return c.encode(ctx, "/encode/document", text)
|
||||
}
|
||||
|
||||
// EncodeQuery encodes a query text into a sparse vector (IDF-weighted tokenizer lookup).
|
||||
func (c *Client) EncodeQuery(ctx context.Context, text string) (*SparseVector, error) {
|
||||
return c.encode(ctx, "/encode/query", text)
|
||||
}
|
||||
|
||||
// EncodeDocuments encodes multiple document texts in a single batch call.
|
||||
func (c *Client) EncodeDocuments(ctx context.Context, texts []string) ([]SparseVector, error) {
|
||||
body, err := json.Marshal(map[string]any{"texts": texts})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
endpoint, err := joinEndpointURL(c.baseURL, "/encode/documents")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.http.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured sparse encoder base URL
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sparse encode failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("sparse encode error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
var vectors []SparseVector
|
||||
if err := json.NewDecoder(resp.Body).Decode(&vectors); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return vectors, nil
|
||||
}
|
||||
|
||||
func (c *Client) Health(ctx context.Context) error {
|
||||
endpoint, err := joinEndpointURL(c.baseURL, "/health")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.http.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured sparse encoder base URL
|
||||
if err != nil {
|
||||
return fmt.Errorf("sparse health check failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("sparse health error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) encode(ctx context.Context, path, text string) (*SparseVector, error) {
|
||||
body, err := json.Marshal(map[string]string{"text": text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
endpoint, err := joinEndpointURL(c.baseURL, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.http.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured sparse encoder base URL
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sparse encode failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("sparse encode error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
var vec SparseVector
|
||||
if err := json.NewDecoder(resp.Body).Decode(&vec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &vec, nil
|
||||
}
|
||||
|
||||
func joinEndpointURL(baseURL, path string) (string, error) {
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return "", errors.New("sparse encode base URL is required")
|
||||
}
|
||||
|
||||
base, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid sparse encode base URL: %w", err)
|
||||
}
|
||||
if base.Scheme != "http" && base.Scheme != "https" {
|
||||
return "", fmt.Errorf("invalid sparse encode base URL scheme: %q", base.Scheme)
|
||||
}
|
||||
if base.Host == "" {
|
||||
return "", errors.New("invalid sparse encode base URL: host is required")
|
||||
}
|
||||
|
||||
ref, err := url.Parse(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid sparse encode path: %w", err)
|
||||
}
|
||||
return base.ResolveReference(ref).String(), nil
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Sparse encoding Flask service using OpenSearch neural sparse model."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from flask import Flask, jsonify, request
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
# Default HuggingFace repo of the OpenSearch neural sparse encoder to serve.
DEFAULT_MODEL_REPO = "opensearch-project/opensearch-neural-sparse-encoding-multilingual-v1"
DEFAULT_PORT = 8085
# Local model/tokenizer download cache; overridable via SPARSE_CACHE_DIR so a
# container can mount a persistent volume and skip re-downloading.
DEFAULT_CACHE_DIR = os.environ.get(
    "SPARSE_CACHE_DIR",
    str(Path(__file__).resolve().parent / "hf-cache"),
)

model_repo = DEFAULT_MODEL_REPO
cache_dir = DEFAULT_CACHE_DIR
# Listen port, overridable via SPARSE_PORT.
port = int(os.environ.get("SPARSE_PORT", DEFAULT_PORT))

app = Flask(__name__)

# Module-level model state, populated by _load_model() before serving:
# _model is the masked-LM encoder, _tokenizer its tokenizer, _idf the dense
# IDF weight tensor, and _special_token_ids the vocab ids to zero out of
# document encodings.
_model = None
_tokenizer = None
_idf = None
_special_token_ids: list[int] = []
|
||||
def _load_model() -> None:
    """Download (or reuse cached) model artifacts and populate module globals.

    Fills in ``_model``, ``_tokenizer``, ``_idf`` and ``_special_token_ids``,
    which the encode functions read. Must run before the HTTP server starts
    handling encode requests.
    """
    global _model, _tokenizer, _idf, _special_token_ids
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    _model = AutoModelForMaskedLM.from_pretrained(model_repo, cache_dir=cache_dir)
    _tokenizer = AutoTokenizer.from_pretrained(model_repo, cache_dir=cache_dir)
    # Inference-only service: switch the model out of training mode.
    _model.eval()
    _idf = _load_idf(_tokenizer)
    # Vocab ids of special tokens (from the tokenizer's special_tokens_map);
    # document encodings zero these positions out.
    _special_token_ids = [
        _tokenizer.vocab[tok]
        for tok in _tokenizer.special_tokens_map.values()
        if tok in _tokenizer.vocab
    ]
|
||||
|
||||
|
||||
def _load_idf(tokenizer):
    """Load the per-token IDF weights shipped in the model repo (idf.json).

    Returns a dense torch tensor of length ``tokenizer.vocab_size`` where
    tokens without an IDF entry keep weight 0.0.
    """
    local_path = hf_hub_download(
        repo_id=model_repo, filename="idf.json", cache_dir=cache_dir
    )
    with open(local_path, encoding="utf-8") as f:
        idf_data = json.load(f)
    idf_vector = [0.0] * tokenizer.vocab_size
    for tok, weight in idf_data.items():
        # NOTE(review): relies on a private tokenizer method; confirm it
        # remains available across transformers versions.
        tid = tokenizer._convert_token_to_id_with_added_voc(tok)
        idf_vector[tid] = weight
    return torch.tensor(idf_vector)
|
||||
|
||||
|
||||
@torch.no_grad()
def _encode_document(text: str) -> dict:
    """Encode a single document text into a sparse ``{indices, values}`` dict.

    Delegates to the batch encoder with a one-element batch instead of
    duplicating the model forward-pass / pooling / saturation logic.
    """
    return _encode_documents([text])[0]
|
||||
|
||||
|
||||
@torch.no_grad()
def _encode_documents(texts: list[str]) -> list[dict]:
    """Batch-encode document texts into sparse ``{indices, values}`` dicts.

    Pipeline: masked-LM forward pass, max-pool the token logits over the
    sequence dimension (padding masked to zero first), apply the double
    log-saturation ``log(1 + log(1 + relu(x)))``, zero out special-token
    positions, then extract the non-zero components per input text.
    """
    feat = _tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=False,
    )
    out = _model(**feat)[0]
    # Max over the sequence dimension; the attention mask zeroes padding
    # positions so they cannot win the max.
    vals, _ = torch.max(out * feat["attention_mask"].unsqueeze(-1), dim=1)
    # Clamp negatives to zero and saturate large activations.
    vals = torch.log(1 + torch.log(1 + torch.relu(vals)))
    # Special tokens carry no content weight.
    vals[:, _special_token_ids] = 0
    return [_sparse_to_dict(vals[i]) for i in range(vals.shape[0])]
|
||||
|
||||
|
||||
def _encode_query(text: str) -> dict:
    """Encode a query via IDF-weighted tokenizer lookup (no model forward pass).

    Builds a binary bag-of-tokens vector over the vocabulary from the query's
    token ids, then scales it by the precomputed IDF weights. Much cheaper
    than the document path, as intended for query-time use.
    """
    feat = _tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=False,
    )
    input_ids = feat["input_ids"]
    batch_size = input_ids.shape[0]
    qv = torch.zeros(batch_size, _tokenizer.vocab_size)
    # Set a 1 at every token id that occurs in each query.
    qv[torch.arange(batch_size).unsqueeze(-1), input_ids] = 1
    sparse_vector = qv * _idf
    return _sparse_to_dict(sparse_vector[0])
|
||||
|
||||
|
||||
def _sparse_to_dict(vector: torch.Tensor) -> dict:
|
||||
nz = torch.nonzero(vector, as_tuple=True)[0]
|
||||
return {"indices": nz.tolist(), "values": vector[nz].tolist()}
|
||||
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def health():
|
||||
return jsonify(status="ok", model_loaded=True, model_repo=model_repo)
|
||||
|
||||
|
||||
@app.route("/encode/document", methods=["POST"])
|
||||
def encode_document():
|
||||
body = request.get_json(silent=True) or {}
|
||||
text = body.get("text", "")
|
||||
if not text:
|
||||
return jsonify(error="text is required"), 400
|
||||
return jsonify(_encode_document(text))
|
||||
|
||||
|
||||
@app.route("/encode/query", methods=["POST"])
|
||||
def encode_query():
|
||||
body = request.get_json(silent=True) or {}
|
||||
text = body.get("text", "")
|
||||
if not text:
|
||||
return jsonify(error="text is required"), 400
|
||||
return jsonify(_encode_query(text))
|
||||
|
||||
|
||||
@app.route("/encode/documents", methods=["POST"])
|
||||
def encode_documents():
|
||||
body = request.get_json(silent=True) or {}
|
||||
texts = body.get("texts", [])
|
||||
if not texts:
|
||||
return jsonify(error="texts is required"), 400
|
||||
return jsonify(_encode_documents(texts))
|
||||
|
||||
|
||||
def main():
    """Load the model, then serve HTTP until interrupted."""
    # Log to stderr so stdout stays clean for any supervising process.
    print(f"[sparse-service] loading model {model_repo}...", file=sys.stderr, flush=True)
    _load_model()
    print(f"[sparse-service] listening on port {port}", file=sys.stderr, flush=True)
    # Binds all interfaces — presumably intended for container deployments.
    # NOTE(review): this is Flask's built-in server; confirm a production WSGI
    # server is not required for the expected load.
    app.run(host="0.0.0.0", port=port, threaded=True)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,6 @@
|
||||
flask
|
||||
torch
|
||||
transformers
|
||||
huggingface_hub
|
||||
sentencepiece
|
||||
protobuf
|
||||
@@ -13,15 +13,17 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/mcp/mcpclient"
|
||||
)
|
||||
|
||||
const (
|
||||
memoryDateLayout = "2006-01-02"
|
||||
entryStartPrefix = "<!-- MEMOH:ENTRY "
|
||||
entryStartSuffix = " -->"
|
||||
entryEndMarker = "<!-- /MEMOH:ENTRY -->"
|
||||
memoryDateLayout = "2006-01-02"
|
||||
entryHeadingPrefix = "## Entry "
|
||||
yamlFence = "```yaml"
|
||||
codeFence = "```"
|
||||
)
|
||||
|
||||
var ErrNotConfigured = errors.New("memory filesystem not configured")
|
||||
@@ -49,6 +51,14 @@ type MemoryItem struct {
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
}
|
||||
|
||||
type memoryEntryMeta struct {
|
||||
ID string `yaml:"id"`
|
||||
Hash string `yaml:"hash,omitempty"`
|
||||
CreatedAt string `yaml:"created_at,omitempty"`
|
||||
UpdatedAt string `yaml:"updated_at,omitempty"`
|
||||
Metadata map[string]any `yaml:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func New(log *slog.Logger, provider mcpclient.Provider) *Service {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
@@ -123,13 +133,16 @@ func (s *Service) buildScanIndex(ctx context.Context, botID string) (map[string]
|
||||
}
|
||||
parsed, parseErr := parseMemoryDayMD(content)
|
||||
if parseErr != nil {
|
||||
legacy, legacyErr := parseLegacyMemoryMD(content)
|
||||
if legacyErr != nil {
|
||||
jsonItems, jsonErr := parseJSONMemoryItems(content)
|
||||
if jsonErr != nil {
|
||||
s.logger.Warn("buildScanIndex: failed to parse memory file",
|
||||
slog.String("bot_id", botID), slog.String("path", entryPath), slog.Any("error", parseErr))
|
||||
continue
|
||||
}
|
||||
parsed = []MemoryItem{legacy}
|
||||
if err := s.writeMemoryDay(ctx, botID, entryPath, jsonItems); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed = jsonItems
|
||||
}
|
||||
for _, item := range parsed {
|
||||
id := strings.TrimSpace(item.ID)
|
||||
@@ -239,11 +252,10 @@ func (s *Service) RemoveMemories(ctx context.Context, botID string, ids []string
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
targets := make([]string, 0, 2)
|
||||
targets := make([]string, 0, 1)
|
||||
if entry, ok := index[id]; ok {
|
||||
targets = append(targets, entry.FilePath)
|
||||
}
|
||||
targets = append(targets, memoryLegacyItemPath(id))
|
||||
for _, target := range targets {
|
||||
if removals[target] == nil {
|
||||
removals[target] = map[string]struct{}{}
|
||||
@@ -295,11 +307,14 @@ func (s *Service) ReadAllMemoryFiles(ctx context.Context, botID string) ([]Memor
|
||||
}
|
||||
parsed, parseErr := parseMemoryDayMD(content)
|
||||
if parseErr != nil {
|
||||
legacy, legacyErr := parseLegacyMemoryMD(content)
|
||||
if legacyErr != nil {
|
||||
jsonItems, jsonErr := parseJSONMemoryItems(content)
|
||||
if jsonErr != nil {
|
||||
continue
|
||||
}
|
||||
parsed = []MemoryItem{legacy}
|
||||
if err := s.writeMemoryDay(ctx, botID, entryPath, jsonItems); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed = jsonItems
|
||||
}
|
||||
for _, item := range parsed {
|
||||
if strings.TrimSpace(item.ID) == "" {
|
||||
@@ -318,6 +333,31 @@ func (s *Service) ReadAllMemoryFiles(ctx context.Context, botID string) ([]Memor
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *Service) CountMemoryFiles(ctx context.Context, botID string) (int, error) {
|
||||
if s.provider == nil {
|
||||
return 0, ErrNotConfigured
|
||||
}
|
||||
c, err := s.client(ctx, botID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
entries, err := c.ListDir(ctx, memoryDirPath(), false)
|
||||
if err != nil {
|
||||
if isNotFound(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
count := 0
|
||||
for _, entry := range entries {
|
||||
if entry.GetIsDir() || !strings.HasSuffix(entry.GetPath(), ".md") {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *Service) SyncOverview(ctx context.Context, botID string) error {
|
||||
if s.provider == nil {
|
||||
return ErrNotConfigured
|
||||
@@ -341,11 +381,14 @@ func (s *Service) readMemoryDay(ctx context.Context, botID, filePath string) ([]
|
||||
if parseErr == nil {
|
||||
return items, nil
|
||||
}
|
||||
legacy, legacyErr := parseLegacyMemoryMD(content)
|
||||
if legacyErr != nil {
|
||||
jsonItems, jsonErr := parseJSONMemoryItems(content)
|
||||
if jsonErr != nil {
|
||||
return []MemoryItem{}, nil
|
||||
}
|
||||
return []MemoryItem{legacy}, nil
|
||||
if err := s.writeMemoryDay(ctx, botID, filePath, jsonItems); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jsonItems, nil
|
||||
}
|
||||
|
||||
func (s *Service) writeMemoryDay(ctx context.Context, botID, filePath string, items []MemoryItem) error {
|
||||
@@ -393,10 +436,6 @@ func memoryDayPath(date string) string {
|
||||
return path.Join(memoryDirPath(), strings.TrimSpace(date)+".md")
|
||||
}
|
||||
|
||||
func memoryLegacyItemPath(id string) string {
|
||||
return path.Join(memoryDirPath(), strings.TrimSpace(id)+".md")
|
||||
}
|
||||
|
||||
// --- format / parse helpers ---
|
||||
|
||||
func formatMemoryDayMD(date string, items []MemoryItem) string {
|
||||
@@ -417,24 +456,24 @@ func formatMemoryDayMD(date string, items []MemoryItem) string {
|
||||
if item.ID == "" || item.Memory == "" {
|
||||
continue
|
||||
}
|
||||
meta := map[string]string{"id": item.ID}
|
||||
if item.Hash != "" {
|
||||
meta["hash"] = item.Hash
|
||||
meta := memoryEntryMeta{
|
||||
ID: item.ID,
|
||||
Hash: item.Hash,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Metadata: item.Metadata,
|
||||
}
|
||||
if item.CreatedAt != "" {
|
||||
meta["created_at"] = item.CreatedAt
|
||||
}
|
||||
if item.UpdatedAt != "" {
|
||||
meta["updated_at"] = item.UpdatedAt
|
||||
}
|
||||
rawMeta, _ := json.Marshal(meta)
|
||||
b.WriteString(entryStartPrefix)
|
||||
b.Write(rawMeta)
|
||||
b.WriteString(entryStartSuffix)
|
||||
rawMeta, _ := yaml.Marshal(meta)
|
||||
b.WriteString(entryHeadingPrefix)
|
||||
b.WriteString(item.ID)
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(yamlFence)
|
||||
b.WriteString("\n")
|
||||
b.WriteString(strings.TrimSpace(string(rawMeta)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(codeFence)
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(item.Memory)
|
||||
b.WriteString("\n")
|
||||
b.WriteString(entryEndMarker)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
return b.String()
|
||||
@@ -449,35 +488,52 @@ func parseMemoryDayMD(content string) ([]MemoryItem, error) {
|
||||
items := make([]MemoryItem, 0, 8)
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := strings.TrimSpace(lines[i])
|
||||
if !strings.HasPrefix(line, entryStartPrefix) || !strings.HasSuffix(line, entryStartSuffix) {
|
||||
if !strings.HasPrefix(line, entryHeadingPrefix) {
|
||||
continue
|
||||
}
|
||||
metaJSON := strings.TrimSuffix(strings.TrimPrefix(line, entryStartPrefix), entryStartSuffix)
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal([]byte(metaJSON), &meta); err != nil {
|
||||
entryID := strings.TrimSpace(strings.TrimPrefix(line, entryHeadingPrefix))
|
||||
j := i + 1
|
||||
for ; j < len(lines) && strings.TrimSpace(lines[j]) == ""; j++ {
|
||||
}
|
||||
if j >= len(lines) || strings.TrimSpace(lines[j]) != yamlFence {
|
||||
continue
|
||||
}
|
||||
start := i + 1
|
||||
end := start
|
||||
for ; end < len(lines); end++ {
|
||||
if strings.TrimSpace(lines[end]) == entryEndMarker {
|
||||
metaStart := j + 1
|
||||
metaEnd := metaStart
|
||||
for ; metaEnd < len(lines); metaEnd++ {
|
||||
if strings.TrimSpace(lines[metaEnd]) == codeFence {
|
||||
break
|
||||
}
|
||||
}
|
||||
if end >= len(lines) {
|
||||
if metaEnd >= len(lines) {
|
||||
break
|
||||
}
|
||||
var meta memoryEntryMeta
|
||||
if err := yaml.Unmarshal([]byte(strings.Join(lines[metaStart:metaEnd], "\n")), &meta); err != nil {
|
||||
continue
|
||||
}
|
||||
bodyStart := metaEnd + 1
|
||||
if bodyStart < len(lines) && strings.TrimSpace(lines[bodyStart]) == "" {
|
||||
bodyStart++
|
||||
}
|
||||
bodyEnd := bodyStart
|
||||
for ; bodyEnd < len(lines); bodyEnd++ {
|
||||
if strings.HasPrefix(strings.TrimSpace(lines[bodyEnd]), entryHeadingPrefix) {
|
||||
break
|
||||
}
|
||||
}
|
||||
item := MemoryItem{
|
||||
ID: strings.TrimSpace(meta["id"]),
|
||||
Hash: strings.TrimSpace(meta["hash"]),
|
||||
CreatedAt: strings.TrimSpace(meta["created_at"]),
|
||||
UpdatedAt: strings.TrimSpace(meta["updated_at"]),
|
||||
Memory: strings.TrimSpace(strings.Join(lines[start:end], "\n")),
|
||||
ID: firstNonEmpty(meta.ID, entryID),
|
||||
Hash: strings.TrimSpace(meta.Hash),
|
||||
CreatedAt: strings.TrimSpace(meta.CreatedAt),
|
||||
UpdatedAt: strings.TrimSpace(meta.UpdatedAt),
|
||||
Metadata: meta.Metadata,
|
||||
Memory: strings.TrimSpace(strings.Join(lines[bodyStart:bodyEnd], "\n")),
|
||||
}
|
||||
if item.ID != "" && item.Memory != "" {
|
||||
items = append(items, item)
|
||||
}
|
||||
i = end
|
||||
i = bodyEnd - 1
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil, errors.New("no memory entries found")
|
||||
@@ -485,36 +541,78 @@ func parseMemoryDayMD(content string) ([]MemoryItem, error) {
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func parseLegacyMemoryMD(content string) (MemoryItem, error) {
|
||||
type jsonMemoryRecord struct {
|
||||
Topic string `json:"topic"`
|
||||
ID string `json:"id"`
|
||||
Memory string `json:"memory"`
|
||||
Text string `json:"text"`
|
||||
Content string `json:"content"`
|
||||
Hash string `json:"hash"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
func parseJSONMemoryItems(content string) ([]MemoryItem, error) {
|
||||
content = strings.TrimSpace(content)
|
||||
if !strings.HasPrefix(content, "---") {
|
||||
return MemoryItem{}, errors.New("missing frontmatter")
|
||||
if content == "" {
|
||||
return nil, errors.New("empty memory file")
|
||||
}
|
||||
parts := strings.SplitN(content[3:], "---", 2)
|
||||
if len(parts) < 2 {
|
||||
return MemoryItem{}, errors.New("incomplete frontmatter")
|
||||
var list []jsonMemoryRecord
|
||||
if err := json.Unmarshal([]byte(content), &list); err == nil {
|
||||
return normalizeJSONMemoryItems(list), nil
|
||||
}
|
||||
item := MemoryItem{Memory: strings.TrimSpace(parts[1])}
|
||||
for _, line := range strings.Split(strings.TrimSpace(parts[0]), "\n") {
|
||||
key, value, found := strings.Cut(strings.TrimSpace(line), ":")
|
||||
if !found {
|
||||
var obj jsonMemoryRecord
|
||||
if err := json.Unmarshal([]byte(content), &obj); err == nil {
|
||||
return normalizeJSONMemoryItems([]jsonMemoryRecord{obj}), nil
|
||||
}
|
||||
var wrapped struct {
|
||||
Items []jsonMemoryRecord `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(content), &wrapped); err == nil {
|
||||
return normalizeJSONMemoryItems(wrapped.Items), nil
|
||||
}
|
||||
return nil, errors.New("not json memory format")
|
||||
}
|
||||
|
||||
func normalizeJSONMemoryItems(records []jsonMemoryRecord) []MemoryItem {
|
||||
now := time.Now().UTC()
|
||||
items := make([]MemoryItem, 0, len(records))
|
||||
for _, record := range records {
|
||||
text := strings.TrimSpace(record.Memory)
|
||||
if text == "" {
|
||||
text = strings.TrimSpace(record.Content)
|
||||
}
|
||||
if text == "" {
|
||||
text = strings.TrimSpace(record.Text)
|
||||
}
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(key) {
|
||||
case "id":
|
||||
item.ID = strings.TrimSpace(value)
|
||||
case "hash":
|
||||
item.Hash = strings.TrimSpace(value)
|
||||
case "created_at":
|
||||
item.CreatedAt = strings.TrimSpace(value)
|
||||
case "updated_at":
|
||||
item.UpdatedAt = strings.TrimSpace(value)
|
||||
item := MemoryItem{
|
||||
ID: strings.TrimSpace(record.ID),
|
||||
Hash: strings.TrimSpace(record.Hash),
|
||||
CreatedAt: strings.TrimSpace(record.CreatedAt),
|
||||
UpdatedAt: strings.TrimSpace(record.UpdatedAt),
|
||||
Memory: text,
|
||||
}
|
||||
if item.ID == "" {
|
||||
item.ID = "mem_" + strconv.FormatInt(now.UnixNano(), 10)
|
||||
}
|
||||
if item.CreatedAt == "" {
|
||||
item.CreatedAt = now.Format(time.RFC3339)
|
||||
}
|
||||
if item.UpdatedAt == "" {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
if item.Hash == "" {
|
||||
item.Hash = "json_" + strconv.FormatInt(now.UnixNano(), 10)
|
||||
}
|
||||
if topic := strings.TrimSpace(record.Topic); topic != "" {
|
||||
item.Metadata = map[string]any{"topic": topic}
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if item.ID == "" {
|
||||
return MemoryItem{}, errors.New("missing id in frontmatter")
|
||||
}
|
||||
return item, nil
|
||||
return items
|
||||
}
|
||||
|
||||
func formatMemoryOverviewMD(items []MemoryItem) string {
|
||||
@@ -638,3 +736,13 @@ func memoryTime(item MemoryItem) time.Time {
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -12,12 +12,14 @@ func TestFormatAndParseMemoryDayMD_Roundtrip(t *testing.T) {
|
||||
Memory: "second record",
|
||||
Hash: "h2",
|
||||
CreatedAt: "2026-03-01T11:15:00Z",
|
||||
Metadata: map[string]any{"topic": "Notes"},
|
||||
},
|
||||
{
|
||||
ID: "mem_1",
|
||||
Memory: "first record",
|
||||
Hash: "h1",
|
||||
CreatedAt: "2026-03-01T09:40:00Z",
|
||||
Metadata: map[string]any{"topic": "Decision"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -25,6 +27,12 @@ func TestFormatAndParseMemoryDayMD_Roundtrip(t *testing.T) {
|
||||
if !strings.Contains(md, "# Memory 2026-03-01") {
|
||||
t.Fatalf("expected header in markdown: %s", md)
|
||||
}
|
||||
if !strings.Contains(md, "## Entry mem_1") {
|
||||
t.Fatalf("expected entry heading in markdown: %s", md)
|
||||
}
|
||||
if !strings.Contains(md, "```yaml") {
|
||||
t.Fatalf("expected yaml block in markdown: %s", md)
|
||||
}
|
||||
|
||||
parsed, err := parseMemoryDayMD(md)
|
||||
if err != nil {
|
||||
@@ -37,28 +45,63 @@ func TestFormatAndParseMemoryDayMD_Roundtrip(t *testing.T) {
|
||||
if parsed[0].ID != "mem_1" || parsed[1].ID != "mem_2" {
|
||||
t.Fatalf("unexpected order after roundtrip: %#v", parsed)
|
||||
}
|
||||
if got := parsed[0].Metadata["topic"]; got != "Decision" {
|
||||
t.Fatalf("expected metadata preserved, got %#v", parsed[0].Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLegacyMemoryMD(t *testing.T) {
|
||||
legacy := `---
|
||||
id: mem_legacy
|
||||
hash: legacyhash
|
||||
created_at: 2026-03-01T09:00:00Z
|
||||
updated_at: 2026-03-01T10:00:00Z
|
||||
---
|
||||
legacy content`
|
||||
func TestParseJSONMemoryItems(t *testing.T) {
|
||||
raw := `[
|
||||
{
|
||||
"id": "mem_json",
|
||||
"topic": "Decision",
|
||||
"memory": "Choose provider architecture."
|
||||
}
|
||||
]`
|
||||
|
||||
item, err := parseLegacyMemoryMD(legacy)
|
||||
items, err := parseJSONMemoryItems(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parseLegacyMemoryMD error: %v", err)
|
||||
t.Fatalf("parseJSONMemoryItems error: %v", err)
|
||||
}
|
||||
if item.ID != "mem_legacy" {
|
||||
t.Fatalf("unexpected id: %#v", item)
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected 1 item, got %d", len(items))
|
||||
}
|
||||
if item.Hash != "legacyhash" {
|
||||
t.Fatalf("unexpected hash: %#v", item)
|
||||
if items[0].ID != "mem_json" {
|
||||
t.Fatalf("unexpected id: %#v", items[0])
|
||||
}
|
||||
if item.Memory != "legacy content" {
|
||||
t.Fatalf("unexpected memory body: %#v", item)
|
||||
if got := items[0].Metadata["topic"]; got != "Decision" {
|
||||
t.Fatalf("expected topic metadata, got %#v", items[0].Metadata)
|
||||
}
|
||||
if items[0].Memory != "Choose provider architecture." {
|
||||
t.Fatalf("unexpected memory body: %#v", items[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONMemoryItemsCanBeFormattedToCanonicalMarkdown(t *testing.T) {
|
||||
raw := `[
|
||||
{
|
||||
"id": "mem_json",
|
||||
"topic": "Decision",
|
||||
"memory": "Choose provider architecture."
|
||||
}
|
||||
]`
|
||||
|
||||
items, err := parseJSONMemoryItems(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONMemoryItems error: %v", err)
|
||||
}
|
||||
md := formatMemoryDayMD("2026-03-01", items)
|
||||
if !strings.Contains(md, "## Entry mem_json") {
|
||||
t.Fatalf("expected canonical heading, got: %s", md)
|
||||
}
|
||||
if !strings.Contains(md, "topic: Decision") {
|
||||
t.Fatalf("expected yaml metadata topic, got: %s", md)
|
||||
}
|
||||
parsed, err := parseMemoryDayMD(md)
|
||||
if err != nil {
|
||||
t.Fatalf("parseMemoryDayMD error: %v", err)
|
||||
}
|
||||
if len(parsed) != 1 || parsed[0].ID != "mem_json" {
|
||||
t.Fatalf("unexpected parsed canonical items: %#v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user