mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
refactor(memory): replace sdk to twilight
This commit is contained in:
+17
-14
@@ -66,6 +66,7 @@ import (
|
||||
membuiltin "github.com/memohai/memoh/internal/memory/adapters/builtin"
|
||||
memmem0 "github.com/memohai/memoh/internal/memory/adapters/mem0"
|
||||
memopenviking "github.com/memohai/memoh/internal/memory/adapters/openviking"
|
||||
"github.com/memohai/memoh/internal/memory/memllm"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
"github.com/memohai/memoh/internal/message"
|
||||
"github.com/memohai/memoh/internal/message/event"
|
||||
@@ -354,7 +355,7 @@ func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, lo
|
||||
}
|
||||
}
|
||||
|
||||
func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, manager *workspace.Manager, queries *dbsqlc.Queries, cfg config.Config) *memprovider.Registry {
|
||||
func provideMemoryProviderRegistry(log *slog.Logger, llm memprovider.LLM, chatService *conversation.Service, accountService *accounts.Service, manager *workspace.Manager, queries *dbsqlc.Queries, cfg config.Config) *memprovider.Registry {
|
||||
registry := memprovider.NewRegistry(log)
|
||||
fileRuntime := handlers.NewBuiltinMemoryRuntime(manager)
|
||||
fileStore := storefs.New(log, manager)
|
||||
@@ -363,7 +364,10 @@ func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.S
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return membuiltin.NewBuiltinProvider(log, runtime, chatService, accountService), nil
|
||||
p := membuiltin.NewBuiltinProvider(log, runtime, chatService, accountService)
|
||||
p.SetLLM(llm)
|
||||
p.ApplyProviderConfig(providerConfig)
|
||||
return p, nil
|
||||
})
|
||||
registry.RegisterFactory(string(memprovider.ProviderMem0), func(_ string, providerConfig map[string]any) (memprovider.Provider, error) {
|
||||
return memmem0.NewMem0Provider(log, providerConfig, fileStore)
|
||||
@@ -371,7 +375,9 @@ func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.S
|
||||
registry.RegisterFactory(string(memprovider.ProviderOpenViking), func(_ string, providerConfig map[string]any) (memprovider.Provider, error) {
|
||||
return memopenviking.NewOpenVikingProvider(log, providerConfig)
|
||||
})
|
||||
registry.Register("__builtin_default__", membuiltin.NewBuiltinProvider(log, fileRuntime, chatService, accountService))
|
||||
defaultProvider := membuiltin.NewBuiltinProvider(log, fileRuntime, chatService, accountService)
|
||||
defaultProvider.SetLLM(llm)
|
||||
registry.Register("__builtin_default__", defaultProvider)
|
||||
return registry
|
||||
}
|
||||
|
||||
@@ -1049,20 +1055,17 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) {
|
||||
if c.modelsService == nil || c.queries == nil {
|
||||
return nil, errors.New("models service not configured")
|
||||
}
|
||||
botID := ""
|
||||
memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, botID)
|
||||
memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientType := memoryProvider.ClientType
|
||||
switch clientType {
|
||||
case "openai-responses", "openai-completions", "anthropic-messages", "google-generative-ai":
|
||||
default:
|
||||
return nil, fmt.Errorf("memory model client type not supported: %s", clientType)
|
||||
}
|
||||
_ = memoryProvider
|
||||
_ = memoryModel
|
||||
return nil, errors.New("memory llm runtime is not available")
|
||||
return memllm.New(memllm.Config{
|
||||
ModelID: memoryModel.ModelID,
|
||||
BaseURL: strings.TrimRight(memoryProvider.BaseUrl, "/"),
|
||||
APIKey: memoryProvider.ApiKey,
|
||||
ClientType: memoryProvider.ClientType,
|
||||
Timeout: c.timeout,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// skillLoaderAdapter bridges handlers.ContainerdHandler to flow.SkillLoader.
|
||||
|
||||
+17
-14
@@ -67,6 +67,7 @@ import (
|
||||
membuiltin "github.com/memohai/memoh/internal/memory/adapters/builtin"
|
||||
memmem0 "github.com/memohai/memoh/internal/memory/adapters/mem0"
|
||||
memopenviking "github.com/memohai/memoh/internal/memory/adapters/openviking"
|
||||
"github.com/memohai/memoh/internal/memory/memllm"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
"github.com/memohai/memoh/internal/message"
|
||||
"github.com/memohai/memoh/internal/message/event"
|
||||
@@ -251,7 +252,7 @@ func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, lo
|
||||
return &lazyLLMClient{modelsService: modelsService, queries: queries, timeout: 30 * time.Second, logger: log}
|
||||
}
|
||||
|
||||
func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, manager *workspace.Manager, queries *dbsqlc.Queries, cfg config.Config) *memprovider.Registry {
|
||||
func provideMemoryProviderRegistry(log *slog.Logger, llm memprovider.LLM, chatService *conversation.Service, accountService *accounts.Service, manager *workspace.Manager, queries *dbsqlc.Queries, cfg config.Config) *memprovider.Registry {
|
||||
registry := memprovider.NewRegistry(log)
|
||||
builtinRuntime := handlers.NewBuiltinMemoryRuntime(manager)
|
||||
fileStore := storefs.New(log, manager)
|
||||
@@ -260,7 +261,10 @@ func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.S
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return membuiltin.NewBuiltinProvider(log, runtime, chatService, accountService), nil
|
||||
p := membuiltin.NewBuiltinProvider(log, runtime, chatService, accountService)
|
||||
p.SetLLM(llm)
|
||||
p.ApplyProviderConfig(providerConfig)
|
||||
return p, nil
|
||||
})
|
||||
registry.RegisterFactory(string(memprovider.ProviderMem0), func(_ string, config map[string]any) (memprovider.Provider, error) {
|
||||
return memmem0.NewMem0Provider(log, config, fileStore)
|
||||
@@ -268,7 +272,9 @@ func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.S
|
||||
registry.RegisterFactory(string(memprovider.ProviderOpenViking), func(_ string, config map[string]any) (memprovider.Provider, error) {
|
||||
return memopenviking.NewOpenVikingProvider(log, config)
|
||||
})
|
||||
registry.Register("__builtin_default__", membuiltin.NewBuiltinProvider(log, builtinRuntime, chatService, accountService))
|
||||
defaultProvider := membuiltin.NewBuiltinProvider(log, builtinRuntime, chatService, accountService)
|
||||
defaultProvider.SetLLM(llm)
|
||||
registry.Register("__builtin_default__", defaultProvider)
|
||||
return registry
|
||||
}
|
||||
|
||||
@@ -977,20 +983,17 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) {
|
||||
if c.modelsService == nil || c.queries == nil {
|
||||
return nil, errors.New("models service not configured")
|
||||
}
|
||||
botID := ""
|
||||
memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, botID)
|
||||
memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientType := memoryProvider.ClientType
|
||||
switch clientType {
|
||||
case "openai-responses", "openai-completions", "anthropic-messages", "google-generative-ai":
|
||||
default:
|
||||
return nil, fmt.Errorf("memory model client type not supported: %s", clientType)
|
||||
}
|
||||
_ = memoryProvider
|
||||
_ = memoryModel
|
||||
return nil, errors.New("memory llm runtime is not available")
|
||||
return memllm.New(memllm.Config{
|
||||
ModelID: memoryModel.ModelID,
|
||||
BaseURL: strings.TrimRight(memoryProvider.BaseUrl, "/"),
|
||||
APIKey: memoryProvider.ApiKey,
|
||||
ClientType: memoryProvider.ClientType,
|
||||
Timeout: c.timeout,
|
||||
}), nil
|
||||
}
|
||||
|
||||
type skillLoaderAdapter struct{ handler *handlers.ContainerdHandler }
|
||||
|
||||
@@ -20,6 +20,9 @@ var (
|
||||
scheduleTmpl string
|
||||
heartbeatTmpl string
|
||||
|
||||
MemoryExtractPrompt string
|
||||
MemoryUpdatePrompt string
|
||||
|
||||
includes map[string]string
|
||||
)
|
||||
|
||||
@@ -32,6 +35,8 @@ func init() {
|
||||
systemSubagentTmpl = mustReadPrompt("prompts/system_subagent.md")
|
||||
scheduleTmpl = mustReadPrompt("prompts/schedule.md")
|
||||
heartbeatTmpl = mustReadPrompt("prompts/heartbeat.md")
|
||||
MemoryExtractPrompt = mustReadPrompt("prompts/memory_extract.md")
|
||||
MemoryUpdatePrompt = mustReadPrompt("prompts/memory_update.md")
|
||||
|
||||
includes = map[string]string{
|
||||
"_memory": mustReadPrompt("prompts/_memory.md"),
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
You are a Personal Information Organizer, specialized in accurately storing facts, user memories, and preferences. Your primary role is to extract relevant pieces of information from conversations and organize them into distinct, manageable facts. This allows for easy retrieval and personalization in future interactions. Below are the types of information you need to focus on and the detailed instructions on how to handle the input data.
|
||||
|
||||
Types of Information to Remember:
|
||||
|
||||
1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment.
|
||||
2. Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates.
|
||||
3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared.
|
||||
4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services.
|
||||
5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information.
|
||||
6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information.
|
||||
7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares.
|
||||
|
||||
Here are some few-shot examples:
|
||||
|
||||
Input: Hi.
|
||||
Output: {"facts" : []}
|
||||
|
||||
Input: There are branches in trees.
|
||||
Output: {"facts" : []}
|
||||
|
||||
Input: Hi, I am looking for a restaurant in San Francisco.
|
||||
Output: {"facts" : ["Looking for a restaurant in San Francisco"]}
|
||||
|
||||
Input: Yesterday, I had a meeting with John at 3pm. We discussed the new project.
|
||||
Output: {"facts" : ["Had a meeting with John at 3pm", "Discussed the new project"]}
|
||||
|
||||
Input: Hi, my name is John. I am a software engineer.
|
||||
Output: {"facts" : ["Name is John", "Is a software engineer"]}
|
||||
|
||||
Input: My favourite movies are Inception and Interstellar.
|
||||
Output: {"facts" : ["Favourite movies are Inception and Interstellar"]}
|
||||
|
||||
Return the facts and preferences in a json format as shown above.
|
||||
|
||||
Remember the following:
|
||||
- Today's date is {{today}}.
|
||||
- Do not return anything from the custom few shot example prompts provided above.
|
||||
- If you do not find anything relevant in the below conversation, you can return an empty list corresponding to the "facts" key.
|
||||
- Create the facts based on the user and assistant messages only. Do not pick anything from the system messages.
|
||||
- Make sure to return the response in the format mentioned in the examples. The response should be in json with a key as "facts" and corresponding value will be a list of strings.
|
||||
- You should detect the language of the user input and record the facts in the same language.
|
||||
|
||||
Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the json format as shown above.
|
||||
@@ -0,0 +1,147 @@
|
||||
You are a smart memory manager which controls the memory of a system.
|
||||
You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change.
|
||||
|
||||
Based on the above four operations, the memory will change.
|
||||
|
||||
Compare newly retrieved facts with the existing memory. For each new fact, decide whether to:
|
||||
- ADD: Add it to the memory as a new element
|
||||
- UPDATE: Update an existing memory element
|
||||
- DELETE: Delete an existing memory element
|
||||
- NONE: Make no change (if the fact is already present or irrelevant)
|
||||
|
||||
There are specific guidelines to select which operation to perform:
|
||||
|
||||
1. **Add**: If the retrieved facts contain new information not present in the memory, then you have to add it by generating a new ID in the id field.
|
||||
- **Example**:
|
||||
- Old Memory:
|
||||
[
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "User is a software engineer"
|
||||
}
|
||||
]
|
||||
- Retrieved facts: ["Name is John"]
|
||||
- New Memory:
|
||||
{
|
||||
"memory" : [
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "User is a software engineer",
|
||||
"event" : "NONE"
|
||||
},
|
||||
{
|
||||
"id" : "1",
|
||||
"text" : "Name is John",
|
||||
"event" : "ADD"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
2. **Update**: If the retrieved facts contain information that is already present in the memory but the information is totally different, then you have to update it.
|
||||
If the retrieved fact contains information that conveys the same thing as the elements present in the memory, then you have to keep the fact which has the most information.
|
||||
Example (a) -- if the memory contains "User likes to play cricket" and the retrieved fact is "Loves to play cricket with friends", then update the memory with the retrieved facts.
|
||||
Example (b) -- if the memory contains "Likes cheese pizza" and the retrieved fact is "Loves cheese pizza", then you do not need to update it because they convey the same information.
|
||||
If the direction is to update the memory, then you have to update it.
|
||||
Please keep in mind while updating you have to keep the same ID.
|
||||
Please note to return the IDs in the output from the input IDs only and do not generate any new ID.
|
||||
- **Example**:
|
||||
- Old Memory:
|
||||
[
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "I really like cheese pizza"
|
||||
},
|
||||
{
|
||||
"id" : "1",
|
||||
"text" : "User is a software engineer"
|
||||
},
|
||||
{
|
||||
"id" : "2",
|
||||
"text" : "User likes to play cricket"
|
||||
}
|
||||
]
|
||||
- Retrieved facts: ["Loves chicken pizza", "Loves to play cricket with friends"]
|
||||
- New Memory:
|
||||
{
|
||||
"memory" : [
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "Loves cheese and chicken pizza",
|
||||
"event" : "UPDATE",
|
||||
"old_memory" : "I really like cheese pizza"
|
||||
},
|
||||
{
|
||||
"id" : "1",
|
||||
"text" : "User is a software engineer",
|
||||
"event" : "NONE"
|
||||
},
|
||||
{
|
||||
"id" : "2",
|
||||
"text" : "Loves to play cricket with friends",
|
||||
"event" : "UPDATE",
|
||||
"old_memory" : "User likes to play cricket"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
3. **Delete**: If the retrieved facts contain information that contradicts the information present in the memory, then you have to delete it. Or if the direction is to delete the memory, then you have to delete it.
|
||||
Please note to return the IDs in the output from the input IDs only and do not generate any new ID.
|
||||
- **Example**:
|
||||
- Old Memory:
|
||||
[
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "Name is John"
|
||||
},
|
||||
{
|
||||
"id" : "1",
|
||||
"text" : "Loves cheese pizza"
|
||||
}
|
||||
]
|
||||
- Retrieved facts: ["Dislikes cheese pizza"]
|
||||
- New Memory:
|
||||
{
|
||||
"memory" : [
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "Name is John",
|
||||
"event" : "NONE"
|
||||
},
|
||||
{
|
||||
"id" : "1",
|
||||
"text" : "Loves cheese pizza",
|
||||
"event" : "DELETE"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
4. **No Change**: If the retrieved facts contain information that is already present in the memory, then you do not need to make any changes.
|
||||
- **Example**:
|
||||
- Old Memory:
|
||||
[
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "Name is John"
|
||||
},
|
||||
{
|
||||
"id" : "1",
|
||||
"text" : "Loves cheese pizza"
|
||||
}
|
||||
]
|
||||
- Retrieved facts: ["Name is John"]
|
||||
- New Memory:
|
||||
{
|
||||
"memory" : [
|
||||
{
|
||||
"id" : "0",
|
||||
"text" : "Name is John",
|
||||
"event" : "NONE"
|
||||
},
|
||||
{
|
||||
"id" : "1",
|
||||
"text" : "Loves cheese pizza",
|
||||
"event" : "NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -15,10 +15,7 @@ import (
|
||||
const (
|
||||
BuiltinType = "builtin"
|
||||
|
||||
sharedMemoryNamespace = "bot"
|
||||
memoryContextLimitPerScope = 4
|
||||
memoryContextMaxItems = 8
|
||||
memoryContextItemMaxChars = 220
|
||||
sharedMemoryNamespace = "bot"
|
||||
|
||||
defaultMemoryToolLimit = 8
|
||||
maxMemoryToolLimit = 50
|
||||
@@ -28,9 +25,11 @@ const (
|
||||
// BuiltinProvider wraps the existing Service as a Provider.
|
||||
type BuiltinProvider struct {
|
||||
service memoryRuntime
|
||||
llm adapters.LLM
|
||||
chatAccessor conversation.Accessor
|
||||
adminChecker AdminChecker
|
||||
logger *slog.Logger
|
||||
packer contextPackerConfig
|
||||
}
|
||||
|
||||
// memoryRuntime is the runtime memory backend required by the builtin provider.
|
||||
@@ -60,17 +59,97 @@ func NewBuiltinProvider(log *slog.Logger, service any, chatAccessor conversation
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
runtimeService, _ := service.(memoryRuntime)
|
||||
logger := log.With(slog.String("provider", BuiltinType))
|
||||
runtimeService, ok := service.(memoryRuntime)
|
||||
if service != nil && !ok {
|
||||
logger.Warn("service does not implement memoryRuntime; provider will operate without a backend")
|
||||
}
|
||||
return &BuiltinProvider{
|
||||
service: runtimeService,
|
||||
chatAccessor: chatAccessor,
|
||||
adminChecker: adminChecker,
|
||||
logger: log.With(slog.String("provider", BuiltinType)),
|
||||
logger: logger,
|
||||
packer: defaultPackerConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// SetLLM injects the LLM client used for Extract/Decide in memory formation.
|
||||
func (p *BuiltinProvider) SetLLM(llm adapters.LLM) {
|
||||
p.llm = llm
|
||||
}
|
||||
|
||||
// SetPackerConfig overrides the default context packing configuration.
|
||||
// Zero-valued fields fall back to defaults.
|
||||
func (p *BuiltinProvider) SetPackerConfig(cfg contextPackerConfig) {
|
||||
if cfg.TargetItems > 0 {
|
||||
p.packer.TargetItems = cfg.TargetItems
|
||||
}
|
||||
if cfg.MaxTotalChars > 0 {
|
||||
p.packer.MaxTotalChars = cfg.MaxTotalChars
|
||||
}
|
||||
if cfg.MinItemChars > 0 {
|
||||
p.packer.MinItemChars = cfg.MinItemChars
|
||||
}
|
||||
if cfg.MaxItemChars > 0 {
|
||||
p.packer.MaxItemChars = cfg.MaxItemChars
|
||||
}
|
||||
if cfg.OverfetchRatio > 0 {
|
||||
p.packer.OverfetchRatio = cfg.OverfetchRatio
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyProviderConfig reads context packing knobs from a provider config map
|
||||
// and applies any non-zero values to the provider's packer configuration.
|
||||
func (p *BuiltinProvider) ApplyProviderConfig(providerConfig map[string]any) {
|
||||
p.SetPackerConfig(contextPackerConfig{
|
||||
TargetItems: intFromConfig(providerConfig, "context_target_items"),
|
||||
MaxTotalChars: intFromConfig(providerConfig, "context_max_total_chars"),
|
||||
})
|
||||
}
|
||||
|
||||
func intFromConfig(m map[string]any, key string) int {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
return 0
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n)
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (*BuiltinProvider) Type() string { return BuiltinType }
|
||||
|
||||
func memorySourceLabel(item adapters.MemoryItem) string {
|
||||
var parts []string
|
||||
if item.Metadata != nil {
|
||||
if name, ok := item.Metadata["profile_display_name"].(string); ok {
|
||||
name = strings.TrimSpace(name)
|
||||
if name != "" {
|
||||
parts = append(parts, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ts := strings.TrimSpace(item.CreatedAt); ts != "" {
|
||||
if len(ts) > 10 {
|
||||
ts = ts[:10]
|
||||
}
|
||||
parts = append(parts, ts)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// --- Conversation Hooks ---
|
||||
|
||||
func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req adapters.BeforeChatRequest) (*adapters.BeforeChatResult, error) {
|
||||
@@ -81,10 +160,11 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req adapters.BeforeC
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fetchLimit := overfetchLimit(p.packer)
|
||||
resp, err := p.service.Search(ctx, adapters.SearchRequest{
|
||||
Query: req.Query,
|
||||
BotID: req.BotID,
|
||||
Limit: memoryContextLimitPerScope,
|
||||
Limit: fetchLimit,
|
||||
Filters: map[string]any{
|
||||
"namespace": sharedMemoryNamespace,
|
||||
"scopeId": req.BotID,
|
||||
@@ -97,48 +177,26 @@ func (p *BuiltinProvider) OnBeforeChat(ctx context.Context, req adapters.BeforeC
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
seen := map[string]struct{}{}
|
||||
type contextItem struct {
|
||||
namespace string
|
||||
item adapters.MemoryItem
|
||||
}
|
||||
results := make([]contextItem, 0, memoryContextLimitPerScope)
|
||||
for _, item := range resp.Results {
|
||||
key := strings.TrimSpace(item.ID)
|
||||
if key == "" {
|
||||
key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory)
|
||||
}
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
results = append(results, contextItem{namespace: sharedMemoryNamespace, item: item})
|
||||
}
|
||||
if len(results) == 0 {
|
||||
candidates := deduplicateAndSort(resp.Results)
|
||||
if len(candidates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].item.Score > results[j].item.Score
|
||||
})
|
||||
if len(results) > memoryContextMaxItems {
|
||||
results = results[:memoryContextMaxItems]
|
||||
packed := packContext(candidates, p.packer)
|
||||
if len(packed.Items) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<memory-context>\nRelevant memory context (use when helpful):\n")
|
||||
for _, entry := range results {
|
||||
text := strings.TrimSpace(entry.item.Memory)
|
||||
if text == "" {
|
||||
continue
|
||||
for _, entry := range packed.Items {
|
||||
sb.WriteString("- ")
|
||||
if label := memorySourceLabel(entry.Item); label != "" {
|
||||
sb.WriteString("[")
|
||||
sb.WriteString(label)
|
||||
sb.WriteString("] ")
|
||||
}
|
||||
sb.WriteString("- [")
|
||||
sb.WriteString(entry.namespace)
|
||||
sb.WriteString("] ")
|
||||
sb.WriteString(adapters.TruncateSnippet(text, memoryContextItemMaxChars))
|
||||
sb.WriteString(entry.Snippet)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</memory-context>")
|
||||
@@ -160,6 +218,21 @@ func (p *BuiltinProvider) OnAfterChat(ctx context.Context, req adapters.AfterCha
|
||||
if len(req.Messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.llm != nil {
|
||||
result := runFormation(ctx, p.logger, p.llm, p.service, req)
|
||||
p.logger.Debug("memory formation completed",
|
||||
slog.String("bot_id", botID),
|
||||
slog.Int("extracted", result.ExtractedFacts),
|
||||
slog.Int("added", result.Added),
|
||||
slog.Int("updated", result.Updated),
|
||||
slog.Int("deleted", result.Deleted),
|
||||
slog.Int("skipped", result.Skipped),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback: no LLM configured, store raw transcript (legacy path).
|
||||
filters := map[string]any{
|
||||
"namespace": sharedMemoryNamespace,
|
||||
"scopeId": botID,
|
||||
|
||||
@@ -1,56 +1,229 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
"github.com/memohai/memoh/internal/memory/sparse"
|
||||
)
|
||||
|
||||
func TestTruncateSnippet_ASCII(t *testing.T) {
|
||||
func TestBuiltinProviderNilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := adapters.TruncateSnippet("hello world", 5)
|
||||
if got != "hello..." {
|
||||
t.Fatalf("expected %q, got %q", "hello...", got)
|
||||
p := NewBuiltinProvider(slog.Default(), nil, nil, nil)
|
||||
if p.Type() != BuiltinType {
|
||||
t.Fatalf("expected type %q, got %q", BuiltinType, p.Type())
|
||||
}
|
||||
|
||||
result, err := p.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
|
||||
BotID: "bot-1",
|
||||
Query: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("OnBeforeChat error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Fatalf("expected nil result for nil service, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_NoTruncation(t *testing.T) {
|
||||
func TestBuiltinProviderOnBeforeChatEmptyQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := adapters.TruncateSnippet("short", 100)
|
||||
if got != "short" {
|
||||
t.Fatalf("expected %q, got %q", "short", got)
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
p := NewBuiltinProvider(slog.Default(), runtime, nil, nil)
|
||||
|
||||
result, err := p.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
|
||||
BotID: "bot-1",
|
||||
Query: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("OnBeforeChat error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Fatal("expected nil result for empty query")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_CJK(t *testing.T) {
|
||||
func TestBuiltinProviderContextPackingProducesMemoryContextTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
// 5 CJK characters (15 bytes in UTF-8), truncate to 3 runes.
|
||||
got := adapters.TruncateSnippet("你好世界啊", 3)
|
||||
if !utf8.ValidString(got) {
|
||||
t.Fatalf("result is not valid UTF-8: %q", got)
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
p := NewBuiltinProvider(slog.Default(), runtime, nil, nil)
|
||||
|
||||
_ = p.OnAfterChat(context.Background(), adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{{Role: "user", Content: "I like green tea"}},
|
||||
})
|
||||
_ = p.OnAfterChat(context.Background(), adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{{Role: "user", Content: "I work in Tokyo"}},
|
||||
})
|
||||
|
||||
result, err := p.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
|
||||
BotID: "bot-1",
|
||||
Query: "tea",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("OnBeforeChat error: %v", err)
|
||||
}
|
||||
if got != "你好世..." {
|
||||
t.Fatalf("expected %q, got %q", "你好世...", got)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if !strings.Contains(result.ContextText, "<memory-context>") {
|
||||
t.Fatalf("expected memory-context tags, got: %s", result.ContextText)
|
||||
}
|
||||
if !strings.Contains(result.ContextText, "</memory-context>") {
|
||||
t.Fatalf("expected closing memory-context tag, got: %s", result.ContextText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_Emoji(t *testing.T) {
|
||||
func TestBuiltinProviderApplyProviderConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Emoji are 4 bytes each in UTF-8.
|
||||
got := adapters.TruncateSnippet("😀😁😂🤣😃", 2)
|
||||
if !utf8.ValidString(got) {
|
||||
t.Fatalf("result is not valid UTF-8: %q", got)
|
||||
p := NewBuiltinProvider(slog.Default(), nil, nil, nil)
|
||||
|
||||
p.ApplyProviderConfig(map[string]any{
|
||||
"context_target_items": float64(10),
|
||||
"context_max_total_chars": float64(3000),
|
||||
})
|
||||
|
||||
if p.packer.TargetItems != 10 {
|
||||
t.Fatalf("expected TargetItems=10, got %d", p.packer.TargetItems)
|
||||
}
|
||||
if got != "😀😁..." {
|
||||
t.Fatalf("expected %q, got %q", "😀😁...", got)
|
||||
if p.packer.MaxTotalChars != 3000 {
|
||||
t.Fatalf("expected MaxTotalChars=3000, got %d", p.packer.MaxTotalChars)
|
||||
}
|
||||
if p.packer.MinItemChars != defaultPackerConfig.MinItemChars {
|
||||
t.Fatalf("expected MinItemChars to remain default, got %d", p.packer.MinItemChars)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_TrimWhitespace(t *testing.T) {
|
||||
func TestBuiltinProviderApplyProviderConfigNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := adapters.TruncateSnippet(" hello ", 100)
|
||||
if got != "hello" {
|
||||
t.Fatalf("expected %q, got %q", "hello", got)
|
||||
p := NewBuiltinProvider(slog.Default(), nil, nil, nil)
|
||||
p.ApplyProviderConfig(nil)
|
||||
if p.packer.TargetItems != defaultPackerConfig.TargetItems {
|
||||
t.Fatalf("expected default TargetItems, got %d", p.packer.TargetItems)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
m map[string]any
|
||||
key string
|
||||
expected int
|
||||
}{
|
||||
{"float64", map[string]any{"k": float64(42)}, "k", 42},
|
||||
{"int", map[string]any{"k": 10}, "k", 10},
|
||||
{"missing", map[string]any{}, "k", 0},
|
||||
{"nil_map", nil, "k", 0},
|
||||
{"string_value", map[string]any{"k": "abc"}, "k", 0},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := intFromConfig(tc.m, tc.key)
|
||||
if got != tc.expected {
|
||||
t.Fatalf("expected %d, got %d", tc.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinProviderBadServiceTypeDoesNotPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := NewBuiltinProvider(slog.Default(), "not a runtime", nil, nil)
|
||||
if p.service != nil {
|
||||
t.Fatal("expected nil service for non-memoryRuntime value")
|
||||
}
|
||||
_, err := p.Search(context.Background(), adapters.SearchRequest{BotID: "b", Query: "q"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from nil service")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinProviderCRUDErrorsWithNilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := NewBuiltinProvider(slog.Default(), nil, nil, nil)
|
||||
if _, err := p.Add(context.Background(), adapters.AddRequest{}); err == nil {
|
||||
t.Fatal("expected Add error")
|
||||
}
|
||||
if _, err := p.GetAll(context.Background(), adapters.GetAllRequest{}); err == nil {
|
||||
t.Fatal("expected GetAll error")
|
||||
}
|
||||
if _, err := p.Update(context.Background(), adapters.UpdateRequest{}); err == nil {
|
||||
t.Fatal("expected Update error")
|
||||
}
|
||||
if _, err := p.Delete(context.Background(), "x"); err == nil {
|
||||
t.Fatal("expected Delete error")
|
||||
}
|
||||
if _, err := p.DeleteBatch(context.Background(), []string{"x"}); err == nil {
|
||||
t.Fatal("expected DeleteBatch error")
|
||||
}
|
||||
if _, err := p.DeleteAll(context.Background(), adapters.DeleteAllRequest{}); err == nil {
|
||||
t.Fatal("expected DeleteAll error")
|
||||
}
|
||||
if _, err := p.Compact(context.Background(), nil, 0.5, 0); err == nil {
|
||||
t.Fatal("expected Compact error")
|
||||
}
|
||||
if _, err := p.Usage(context.Background(), nil); err == nil {
|
||||
t.Fatal("expected Usage error")
|
||||
}
|
||||
if _, err := p.Status(context.Background(), "b"); err == nil {
|
||||
t.Fatal("expected Status error")
|
||||
}
|
||||
if _, err := p.Rebuild(context.Background(), "b"); err == nil {
|
||||
t.Fatal("expected Rebuild error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBuiltinRuntimeFromConfig_DefaultReturnsFileRuntime(t *testing.T) {
|
||||
t.Parallel()
|
||||
sentinel := "file-runtime-sentinel"
|
||||
rt, err := NewBuiltinRuntimeFromConfig(nil, nil, sentinel, nil, nil, defaultTestConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if rt != sentinel {
|
||||
t.Fatalf("expected file runtime sentinel, got %v", rt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBuiltinRuntimeFromConfig_DenseErrorPropagates(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := map[string]any{"memory_mode": "dense"}
|
||||
_, err := NewBuiltinRuntimeFromConfig(nil, cfg, "fallback", nil, nil, defaultTestConfig())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for dense mode without embedding_model_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBuiltinRuntimeFromConfig_SparseErrorPropagates(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := map[string]any{"memory_mode": "sparse"}
|
||||
_, err := NewBuiltinRuntimeFromConfig(nil, cfg, "fallback", nil, nil, defaultTestConfig())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for sparse mode without encoder base URL")
|
||||
}
|
||||
}
|
||||
|
||||
func defaultTestConfig() config.Config {
|
||||
return config.Config{}
|
||||
}
|
||||
|
||||
// Fakes from sparse_runtime_test.go are in the same package and accessible.
|
||||
|
||||
var _ sparseEncoder = (*fakeSparseEncoder)(nil)
|
||||
|
||||
func init() {
|
||||
_ = sparse.SparseVector{}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
// contextPackerConfig controls how memory items are packed into a context
// window with a fixed character budget. Zero or negative fields are replaced
// with the corresponding defaultPackerConfig values inside packContext, so
// the zero value is usable. Budgets are counted in runes, not bytes.
type contextPackerConfig struct {
	TargetItems    int // desired number of items in final context
	MaxTotalChars  int // hard budget for combined snippet length
	MinItemChars   int // minimum snippet length per item
	MaxItemChars   int // maximum snippet length per item
	OverfetchRatio int // fetch TargetItems * OverfetchRatio candidates
}
|
||||
|
||||
// defaultPackerConfig supplies the fallback values packContext uses when a
// caller passes a config with missing (non-positive) fields.
var defaultPackerConfig = contextPackerConfig{
	TargetItems:    6,
	MaxTotalChars:  1800,
	MinItemChars:   80,
	MaxItemChars:   360,
	OverfetchRatio: 3,
}
|
||||
|
||||
// contextPackResult contains the items selected for context injection.
type contextPackResult struct {
	// Items holds the packed items in their final presentation order
	// (after anti-lost-in-the-middle reordering), each paired with its
	// budget-limited snippet.
	Items []packedItem
}
|
||||
|
||||
// packedItem pairs a memory item with the snippet of its text that fits the
// packer's character budget.
type packedItem struct {
	Item    adapters.MemoryItem // the original memory item
	Snippet string              // possibly-truncated text to inject as context
}
|
||||
|
||||
// packContext selects and truncates memory items to fit within the character
|
||||
// budget. Items must already be deduplicated and sorted by score descending.
|
||||
//
|
||||
// The algorithm:
|
||||
// 1. Assign each item its full text or MaxItemChars, whichever is shorter.
|
||||
// 2. Walk items in score order; greedily include items while budget allows.
|
||||
// 3. If we haven't reached TargetItems, try compressing already-included
|
||||
// items to make room for more.
|
||||
// 4. Apply anti-lost-in-the-middle reordering: best items at head and tail.
|
||||
func packContext(items []adapters.MemoryItem, cfg contextPackerConfig) contextPackResult {
|
||||
if len(items) == 0 {
|
||||
return contextPackResult{}
|
||||
}
|
||||
if cfg.TargetItems <= 0 {
|
||||
cfg.TargetItems = defaultPackerConfig.TargetItems
|
||||
}
|
||||
if cfg.MaxTotalChars <= 0 {
|
||||
cfg.MaxTotalChars = defaultPackerConfig.MaxTotalChars
|
||||
}
|
||||
if cfg.MinItemChars <= 0 {
|
||||
cfg.MinItemChars = defaultPackerConfig.MinItemChars
|
||||
}
|
||||
if cfg.MaxItemChars <= 0 {
|
||||
cfg.MaxItemChars = defaultPackerConfig.MaxItemChars
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
item adapters.MemoryItem
|
||||
text string
|
||||
charLen int
|
||||
}
|
||||
candidates := make([]candidate, 0, len(items))
|
||||
for _, it := range items {
|
||||
text := strings.TrimSpace(it.Memory)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
runes := []rune(text)
|
||||
cl := len(runes)
|
||||
if cl > cfg.MaxItemChars {
|
||||
cl = cfg.MaxItemChars
|
||||
}
|
||||
candidates = append(candidates, candidate{item: it, text: text, charLen: cl})
|
||||
}
|
||||
|
||||
// Phase 1: greedily pack items at their natural (capped) length.
|
||||
selected := make([]candidate, 0, cfg.TargetItems)
|
||||
usedChars := 0
|
||||
for _, c := range candidates {
|
||||
if len(selected) >= cfg.TargetItems {
|
||||
break
|
||||
}
|
||||
if usedChars+c.charLen > cfg.MaxTotalChars {
|
||||
// Try with minimum length.
|
||||
if usedChars+cfg.MinItemChars > cfg.MaxTotalChars {
|
||||
continue
|
||||
}
|
||||
c.charLen = cfg.MinItemChars
|
||||
}
|
||||
selected = append(selected, c)
|
||||
usedChars += c.charLen
|
||||
}
|
||||
|
||||
// Phase 2: if we didn't reach TargetItems, try compressing existing items
|
||||
// to free budget for more candidates.
|
||||
if len(selected) < cfg.TargetItems && len(selected) < len(candidates) {
|
||||
for i := range selected {
|
||||
if selected[i].charLen > cfg.MinItemChars {
|
||||
freed := selected[i].charLen - cfg.MinItemChars
|
||||
selected[i].charLen = cfg.MinItemChars
|
||||
usedChars -= freed
|
||||
}
|
||||
}
|
||||
for _, c := range candidates[len(selected):] {
|
||||
if len(selected) >= cfg.TargetItems {
|
||||
break
|
||||
}
|
||||
needed := c.charLen
|
||||
if needed > cfg.MaxTotalChars-usedChars {
|
||||
needed = cfg.MaxTotalChars - usedChars
|
||||
}
|
||||
if needed < cfg.MinItemChars {
|
||||
continue
|
||||
}
|
||||
c.charLen = needed
|
||||
selected = append(selected, c)
|
||||
usedChars += needed
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: redistribute remaining budget to compressed items.
|
||||
remaining := cfg.MaxTotalChars - usedChars
|
||||
if remaining > 0 && len(selected) > 0 {
|
||||
perItem := remaining / len(selected)
|
||||
for i := range selected {
|
||||
textLen := len([]rune(selected[i].text))
|
||||
maxGrow := cfg.MaxItemChars - selected[i].charLen
|
||||
if maxGrow > textLen-selected[i].charLen {
|
||||
maxGrow = textLen - selected[i].charLen
|
||||
}
|
||||
if maxGrow <= 0 {
|
||||
continue
|
||||
}
|
||||
grow := perItem
|
||||
if grow > maxGrow {
|
||||
grow = maxGrow
|
||||
}
|
||||
selected[i].charLen += grow
|
||||
usedChars += grow
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 4: anti-lost-in-the-middle reordering.
|
||||
// Place best items at positions 0 and last; fill middle with the rest.
|
||||
reordered := antiLostInMiddle(selected)
|
||||
|
||||
result := make([]packedItem, 0, len(reordered))
|
||||
for _, c := range reordered {
|
||||
snippet := adapters.TruncateSnippet(c.text, c.charLen)
|
||||
result = append(result, packedItem{Item: c.item, Snippet: snippet})
|
||||
}
|
||||
return contextPackResult{Items: result}
|
||||
}
|
||||
|
||||
// antiLostInMiddle reorders candidates so the highest-scored items appear at
// the beginning and end of the sequence, reducing the "lost in the middle"
// effect observed in LLMs.
func antiLostInMiddle[T any](items []T) []T {
	count := len(items)
	if count <= 2 {
		return items
	}
	// Deal items alternately to the front and the back: even positions in
	// the input fill from the head forward, odd positions from the tail
	// backward.
	reordered := make([]T, count)
	front, back := 0, count-1
	for idx, value := range items {
		if idx%2 != 0 {
			reordered[back] = value
			back--
			continue
		}
		reordered[front] = value
		front++
	}
	return reordered
}
|
||||
|
||||
// overfetchLimit returns the number of candidates to request from the backend,
|
||||
// given the desired target item count and overfetch ratio.
|
||||
func overfetchLimit(cfg contextPackerConfig) int {
|
||||
limit := cfg.TargetItems * cfg.OverfetchRatio
|
||||
if limit < cfg.TargetItems {
|
||||
limit = cfg.TargetItems
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// deduplicateAndSort removes duplicate items by ID, then sorts by score desc.
|
||||
func deduplicateAndSort(items []adapters.MemoryItem) []adapters.MemoryItem {
|
||||
deduped := adapters.DeduplicateItems(items)
|
||||
sort.Slice(deduped, func(i, j int) bool {
|
||||
return deduped[i].Score > deduped[j].Score
|
||||
})
|
||||
return deduped
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
func makeItems(texts ...string) []adapters.MemoryItem {
|
||||
items := make([]adapters.MemoryItem, len(texts))
|
||||
for i, text := range texts {
|
||||
items[i] = adapters.MemoryItem{
|
||||
ID: "id-" + text[:min(len(text), 8)],
|
||||
Memory: text,
|
||||
Score: float64(len(texts) - i),
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func TestPackContext_BasicPacking(t *testing.T) {
|
||||
t.Parallel()
|
||||
items := makeItems("alpha", "bravo", "charlie", "delta", "echo", "foxtrot")
|
||||
cfg := contextPackerConfig{
|
||||
TargetItems: 4,
|
||||
MaxTotalChars: 2000,
|
||||
MinItemChars: 3,
|
||||
MaxItemChars: 100,
|
||||
OverfetchRatio: 2,
|
||||
}
|
||||
result := packContext(items, cfg)
|
||||
if len(result.Items) != 4 {
|
||||
t.Fatalf("expected 4 packed items, got %d", len(result.Items))
|
||||
}
|
||||
for _, pi := range result.Items {
|
||||
if pi.Snippet == "" {
|
||||
t.Fatal("expected non-empty snippet")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackContext_BudgetLimitsItems(t *testing.T) {
|
||||
t.Parallel()
|
||||
long := strings.Repeat("x", 500)
|
||||
items := makeItems(long, long, long, long, long)
|
||||
cfg := contextPackerConfig{
|
||||
TargetItems: 5,
|
||||
MaxTotalChars: 800,
|
||||
MinItemChars: 100,
|
||||
MaxItemChars: 500,
|
||||
OverfetchRatio: 2,
|
||||
}
|
||||
result := packContext(items, cfg)
|
||||
totalChars := 0
|
||||
for _, pi := range result.Items {
|
||||
totalChars += len([]rune(pi.Snippet))
|
||||
}
|
||||
if totalChars > cfg.MaxTotalChars+50 {
|
||||
t.Fatalf("total chars %d exceeds budget %d by too much", totalChars, cfg.MaxTotalChars)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackContext_CompressesToFitMore(t *testing.T) {
|
||||
t.Parallel()
|
||||
medium := strings.Repeat("m", 200)
|
||||
items := makeItems(medium, medium, medium, medium, medium, medium)
|
||||
cfg := contextPackerConfig{
|
||||
TargetItems: 6,
|
||||
MaxTotalChars: 600,
|
||||
MinItemChars: 50,
|
||||
MaxItemChars: 200,
|
||||
OverfetchRatio: 2,
|
||||
}
|
||||
result := packContext(items, cfg)
|
||||
if len(result.Items) < 3 {
|
||||
t.Fatalf("expected at least 3 items after compression, got %d", len(result.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackContext_ShortItemsNotTruncated(t *testing.T) {
|
||||
t.Parallel()
|
||||
items := makeItems("hi", "yo", "ok")
|
||||
cfg := contextPackerConfig{
|
||||
TargetItems: 3,
|
||||
MaxTotalChars: 1000,
|
||||
MinItemChars: 10,
|
||||
MaxItemChars: 200,
|
||||
OverfetchRatio: 2,
|
||||
}
|
||||
result := packContext(items, cfg)
|
||||
if len(result.Items) != 3 {
|
||||
t.Fatalf("expected 3 items, got %d", len(result.Items))
|
||||
}
|
||||
for _, pi := range result.Items {
|
||||
if strings.HasSuffix(pi.Snippet, "...") {
|
||||
t.Fatalf("short item should not be truncated: %q", pi.Snippet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackContext_EmptyInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := packContext(nil, defaultPackerConfig)
|
||||
if len(result.Items) != 0 {
|
||||
t.Fatalf("expected 0 items for nil input, got %d", len(result.Items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntiLostInMiddle_Reordering(t *testing.T) {
|
||||
t.Parallel()
|
||||
items := []int{1, 2, 3, 4, 5}
|
||||
reordered := antiLostInMiddle(items)
|
||||
if reordered[0] != 1 {
|
||||
t.Fatalf("expected first item to be 1, got %d", reordered[0])
|
||||
}
|
||||
if reordered[len(reordered)-1] != 2 {
|
||||
t.Fatalf("expected last item to be 2, got %d", reordered[len(reordered)-1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntiLostInMiddle_SmallSlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
single := antiLostInMiddle([]string{"a"})
|
||||
if len(single) != 1 || single[0] != "a" {
|
||||
t.Fatalf("unexpected result for single item: %v", single)
|
||||
}
|
||||
pair := antiLostInMiddle([]string{"a", "b"})
|
||||
if len(pair) != 2 {
|
||||
t.Fatalf("unexpected result for pair: %v", pair)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOverfetchLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := contextPackerConfig{TargetItems: 5, OverfetchRatio: 3}
|
||||
if got := overfetchLimit(cfg); got != 15 {
|
||||
t.Fatalf("expected 15, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeduplicateAndSort(t *testing.T) {
|
||||
t.Parallel()
|
||||
items := []adapters.MemoryItem{
|
||||
{ID: "a", Score: 1.0, Memory: "first"},
|
||||
{ID: "b", Score: 3.0, Memory: "second"},
|
||||
{ID: "a", Score: 2.0, Memory: "duplicate"},
|
||||
{ID: "c", Score: 2.5, Memory: "third"},
|
||||
}
|
||||
result := deduplicateAndSort(items)
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 items after dedup, got %d", len(result))
|
||||
}
|
||||
if result[0].ID != "b" {
|
||||
t.Fatalf("expected highest score first, got %q", result[0].ID)
|
||||
}
|
||||
}
|
||||
@@ -1,49 +1,36 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
const denseEmbedTimeout = 30 * time.Second
|
||||
|
||||
type denseRuntime struct {
|
||||
qdrant *qdrantclient.Client
|
||||
store *storefs.Service
|
||||
embedder *denseEmbeddingClient
|
||||
collection string
|
||||
}
|
||||
|
||||
type denseEmbeddingClient struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
modelID string
|
||||
embedModel *sdk.EmbeddingModel
|
||||
dimensions int
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type denseEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
collection string
|
||||
}
|
||||
|
||||
type denseModelSpec struct {
|
||||
@@ -66,7 +53,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg
|
||||
return nil, errors.New("dense runtime: embedding_model_id is required")
|
||||
}
|
||||
|
||||
modelSpec, err := resolveDenseEmbeddingModel(context.Background(), queries, modelRef)
|
||||
spec, err := resolveDenseEmbeddingModel(context.Background(), queries, modelRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -87,63 +74,105 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg
|
||||
return nil, fmt.Errorf("dense runtime: %w", err)
|
||||
}
|
||||
|
||||
embedModel := models.NewSDKEmbeddingModel(spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
|
||||
|
||||
return &denseRuntime{
|
||||
qdrant: qClient,
|
||||
store: store,
|
||||
embedder: &denseEmbeddingClient{
|
||||
baseURL: strings.TrimRight(modelSpec.baseURL, "/"),
|
||||
apiKey: modelSpec.apiKey,
|
||||
modelID: modelSpec.modelID,
|
||||
dimensions: modelSpec.dimensions,
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
},
|
||||
qdrant: qClient,
|
||||
store: store,
|
||||
embedModel: embedModel,
|
||||
dimensions: spec.dimensions,
|
||||
collection: collection,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- embedder helpers using Twilight SDK ---
|
||||
|
||||
func (r *denseRuntime) embedQuery(ctx context.Context, text string) ([]float32, error) {
|
||||
client := sdk.NewClient()
|
||||
vec, err := client.Embed(ctx, text, sdk.WithEmbeddingModel(r.embedModel))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dense embed query: %w", err)
|
||||
}
|
||||
return float64sToFloat32s(vec), nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) embedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
client := sdk.NewClient()
|
||||
result, err := client.EmbedMany(ctx, texts, sdk.WithEmbeddingModel(r.embedModel))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dense embed documents: %w", err)
|
||||
}
|
||||
out := make([][]float32, len(result.Embeddings))
|
||||
for i, emb := range result.Embeddings {
|
||||
out[i] = float64sToFloat32s(emb)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// embedHealth performs a minimal smoke-test embedding to verify that the
|
||||
// configured embedding model is reachable and functional.
|
||||
func (r *denseRuntime) embedHealth(ctx context.Context) error {
|
||||
client := sdk.NewClient()
|
||||
_, err := client.Embed(ctx, "health", sdk.WithEmbeddingModel(r.embedModel))
|
||||
if err != nil {
|
||||
return fmt.Errorf("dense embedding health check failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func float64sToFloat32s(in []float64) []float32 {
|
||||
out := make([]float32, len(in))
|
||||
for i, v := range in {
|
||||
out[i] = float32(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// --- memoryRuntime interface ---
|
||||
|
||||
func (r *denseRuntime) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
text := sparseRuntimeText(req.Message, req.Messages)
|
||||
text := runtimeText(req.Message, req.Messages)
|
||||
if text == "" {
|
||||
return adapters.SearchResponse{}, errors.New("dense runtime: message is required")
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
item := adapters.MemoryItem{
|
||||
ID: sparseRuntimeMemoryID(botID, time.Now().UTC()),
|
||||
ID: runtimeMemoryID(botID, time.Now().UTC()),
|
||||
Memory: text,
|
||||
Hash: denseRuntimeHash(text),
|
||||
Hash: runtimeHash(text),
|
||||
Metadata: req.Metadata,
|
||||
BotID: botID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{denseStoreItemFromMemoryItem(item)}, req.Filters); err != nil {
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{storeItemFromMemoryItem(item)}, req.Filters); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{denseStoreItemFromMemoryItem(item)}); err != nil {
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{storeItemFromMemoryItem(item)}); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
return adapters.SearchResponse{Results: []adapters.MemoryItem{item}}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.dimensions); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
limit := req.Limit
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
vec, err := r.embedder.EmbedQuery(ctx, req.Query)
|
||||
vec, err := r.embedQuery(ctx, req.Query)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, fmt.Errorf("dense embed query: %w", err)
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
results, err := r.qdrant.SearchDense(ctx, qdrantclient.DenseVector{Values: vec}, botID, limit)
|
||||
if err != nil {
|
||||
@@ -151,13 +180,13 @@ func (r *denseRuntime) Search(ctx context.Context, req adapters.SearchRequest) (
|
||||
}
|
||||
items := make([]adapters.MemoryItem, 0, len(results))
|
||||
for _, result := range results {
|
||||
items = append(items, denseResultToItem(result))
|
||||
items = append(items, resultToItem(result))
|
||||
}
|
||||
return adapters.SearchResponse{Results: items}, nil
|
||||
}
|
||||
|
||||
func (r *denseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
@@ -167,7 +196,7 @@ func (r *denseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest) (
|
||||
}
|
||||
result := make([]adapters.MemoryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
mem := denseMemoryItemFromStore(item)
|
||||
mem := memoryItemFromStore(item)
|
||||
mem.BotID = botID
|
||||
result = append(result, mem)
|
||||
}
|
||||
@@ -187,7 +216,7 @@ func (r *denseRuntime) Update(ctx context.Context, req adapters.UpdateRequest) (
|
||||
if text == "" {
|
||||
return adapters.MemoryItem{}, errors.New("dense runtime: memory is required")
|
||||
}
|
||||
botID := sparseRuntimeBotIDFromMemoryID(memoryID)
|
||||
botID := runtimeBotIDFromMemoryID(memoryID)
|
||||
if botID == "" {
|
||||
return adapters.MemoryItem{}, errors.New("dense runtime: invalid memory_id")
|
||||
}
|
||||
@@ -207,7 +236,7 @@ func (r *denseRuntime) Update(ctx context.Context, req adapters.UpdateRequest) (
|
||||
return adapters.MemoryItem{}, errors.New("dense runtime: memory not found")
|
||||
}
|
||||
existing.Memory = text
|
||||
existing.Hash = denseRuntimeHash(text)
|
||||
existing.Hash = runtimeHash(text)
|
||||
existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{*existing}, nil); err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
@@ -215,7 +244,7 @@ func (r *denseRuntime) Update(ctx context.Context, req adapters.UpdateRequest) (
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{*existing}); err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
}
|
||||
item := denseMemoryItemFromStore(*existing)
|
||||
item := memoryItemFromStore(*existing)
|
||||
item.BotID = botID
|
||||
return item, nil
|
||||
}
|
||||
@@ -232,12 +261,12 @@ func (r *denseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (ada
|
||||
if memoryID == "" {
|
||||
continue
|
||||
}
|
||||
botID := sparseRuntimeBotIDFromMemoryID(memoryID)
|
||||
botID := runtimeBotIDFromMemoryID(memoryID)
|
||||
if botID == "" {
|
||||
continue
|
||||
}
|
||||
grouped[botID] = append(grouped[botID], memoryID)
|
||||
pointIDs = append(pointIDs, sparsePointID(botID, memoryID))
|
||||
pointIDs = append(pointIDs, runtimePointID(botID, memoryID))
|
||||
}
|
||||
for botID, ids := range grouped {
|
||||
if err := r.store.RemoveMemories(ctx, botID, ids); err != nil {
|
||||
@@ -251,7 +280,7 @@ func (r *denseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (ada
|
||||
}
|
||||
|
||||
func (r *denseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
@@ -265,7 +294,7 @@ func (r *denseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllRequ
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Compact(ctx context.Context, filters map[string]any, ratio float64, _ int) (adapters.CompactResult, error) {
|
||||
botID, err := sparseRuntimeBotID("", filters)
|
||||
botID, err := runtimeBotID("", filters)
|
||||
if err != nil {
|
||||
return adapters.CompactResult{}, err
|
||||
}
|
||||
@@ -297,7 +326,7 @@ func (r *denseRuntime) Compact(ctx context.Context, filters map[string]any, rati
|
||||
}
|
||||
kept := make([]adapters.MemoryItem, 0, len(keptStore))
|
||||
for _, item := range keptStore {
|
||||
kept = append(kept, denseMemoryItemFromStore(item))
|
||||
kept = append(kept, memoryItemFromStore(item))
|
||||
}
|
||||
return adapters.CompactResult{
|
||||
BeforeCount: before,
|
||||
@@ -308,7 +337,7 @@ func (r *denseRuntime) Compact(ctx context.Context, filters map[string]any, rati
|
||||
}
|
||||
|
||||
func (r *denseRuntime) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
|
||||
botID, err := sparseRuntimeBotID("", filters)
|
||||
botID, err := runtimeBotID("", filters)
|
||||
if err != nil {
|
||||
return adapters.UsageResponse{}, err
|
||||
}
|
||||
@@ -351,7 +380,7 @@ func (r *denseRuntime) Status(ctx context.Context, botID string) (adapters.Memor
|
||||
SourceCount: len(items),
|
||||
QdrantCollection: r.collection,
|
||||
}
|
||||
if err := r.embedder.Health(ctx); err != nil {
|
||||
if err := r.embedHealth(ctx); err != nil {
|
||||
status.Encoder.Error = err.Error()
|
||||
} else {
|
||||
status.Encoder.OK = true
|
||||
@@ -386,7 +415,7 @@ func (r *denseRuntime) Rebuild(ctx context.Context, botID string) (adapters.Rebu
|
||||
}
|
||||
|
||||
func (r *denseRuntime) syncSourceItems(ctx context.Context, botID string, items []storefs.MemoryItem) (adapters.RebuildResult, error) {
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.dimensions); err != nil {
|
||||
return adapters.RebuildResult{}, err
|
||||
}
|
||||
existing, err := r.qdrant.Scroll(ctx, botID, 10000)
|
||||
@@ -408,12 +437,12 @@ func (r *denseRuntime) syncSourceItems(ctx context.Context, botID string, items
|
||||
missingCount := 0
|
||||
restoredCount := 0
|
||||
for _, item := range items {
|
||||
item = denseCanonicalStoreItem(item)
|
||||
item = canonicalStoreItem(item)
|
||||
if item.ID == "" || item.Memory == "" {
|
||||
continue
|
||||
}
|
||||
sourceIDs[item.ID] = struct{}{}
|
||||
payload := densePayload(botID, item)
|
||||
payload := runtimePayload(botID, item)
|
||||
existingItem, ok := existingBySource[item.ID]
|
||||
if !ok {
|
||||
missingCount++
|
||||
@@ -421,7 +450,7 @@ func (r *denseRuntime) syncSourceItems(ctx context.Context, botID string, items
|
||||
toUpsert = append(toUpsert, item)
|
||||
continue
|
||||
}
|
||||
if !densePayloadMatches(existingItem.Payload, payload) {
|
||||
if !payloadMatches(existingItem.Payload, payload) {
|
||||
restoredCount++
|
||||
toUpsert = append(toUpsert, item)
|
||||
}
|
||||
@@ -463,13 +492,13 @@ func (r *denseRuntime) upsertSourceItems(ctx context.Context, botID string, item
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.embedder.dimensions); err != nil {
|
||||
if err := r.qdrant.EnsureDenseCollection(ctx, r.dimensions); err != nil {
|
||||
return err
|
||||
}
|
||||
canonical := make([]storefs.MemoryItem, 0, len(items))
|
||||
texts := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
item = denseCanonicalStoreItem(item)
|
||||
item = canonicalStoreItem(item)
|
||||
if item.ID == "" || item.Memory == "" {
|
||||
continue
|
||||
}
|
||||
@@ -479,17 +508,17 @@ func (r *denseRuntime) upsertSourceItems(ctx context.Context, botID string, item
|
||||
if len(canonical) == 0 {
|
||||
return nil
|
||||
}
|
||||
vectors, err := r.embedder.EmbedDocuments(ctx, texts)
|
||||
vectors, err := r.embedDocuments(ctx, texts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dense embed documents: %w", err)
|
||||
return err
|
||||
}
|
||||
if len(vectors) != len(canonical) {
|
||||
return fmt.Errorf("dense embed documents: expected %d vectors, got %d", len(canonical), len(vectors))
|
||||
}
|
||||
for i, item := range canonical {
|
||||
if err := r.qdrant.UpsertDense(ctx, sparsePointID(botID, item.ID), qdrantclient.DenseVector{
|
||||
if err := r.qdrant.UpsertDense(ctx, runtimePointID(botID, item.ID), qdrantclient.DenseVector{
|
||||
Values: vectors[i],
|
||||
}, densePayload(botID, item)); err != nil {
|
||||
}, runtimePayload(botID, item)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -542,201 +571,9 @@ func resolveDenseEmbeddingModel(ctx context.Context, queries *dbsqlc.Queries, mo
|
||||
}, nil
|
||||
}
|
||||
|
||||
func joinDenseEmbeddingEndpointURL(baseURL, endpointPath string) (string, error) {
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return "", errors.New("dense embedding base URL is required")
|
||||
}
|
||||
// --- shared helpers (used by both dense and sparse runtimes) ---
|
||||
|
||||
base, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid dense embedding base URL: %w", err)
|
||||
}
|
||||
if base.Scheme != "http" && base.Scheme != "https" {
|
||||
return "", fmt.Errorf("invalid dense embedding base URL scheme: %q", base.Scheme)
|
||||
}
|
||||
if base.Host == "" {
|
||||
return "", errors.New("invalid dense embedding base URL: host is required")
|
||||
}
|
||||
|
||||
ref, err := url.Parse(endpointPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid dense embedding path: %w", err)
|
||||
}
|
||||
return base.ResolveReference(ref).String(), nil
|
||||
}
|
||||
|
||||
func (c *denseEmbeddingClient) Health(ctx context.Context) error {
|
||||
endpoint, err := joinDenseEmbeddingEndpointURL(c.baseURL, "/models")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
}
|
||||
resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured embedding provider base URL
|
||||
if err != nil {
|
||||
return fmt.Errorf("dense embedding health check failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("dense embedding health error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *denseEmbeddingClient) EmbedQuery(ctx context.Context, text string) ([]float32, error) {
|
||||
vectors, err := c.EmbedDocuments(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return nil, errors.New("dense embed query: empty embedding response")
|
||||
}
|
||||
return vectors[0], nil
|
||||
}
|
||||
|
||||
func (c *denseEmbeddingClient) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": c.modelID,
|
||||
"input": texts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
endpoint, err := joinDenseEmbeddingEndpointURL(c.baseURL, "/embeddings")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
}
|
||||
resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL is validated and derived from operator-configured embedding provider base URL
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dense embed request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dense embed read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("dense embed api error %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
var parsed denseEmbeddingResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("dense embed decode response: %w", err)
|
||||
}
|
||||
vectors := make([][]float32, len(parsed.Data))
|
||||
for _, item := range parsed.Data {
|
||||
if item.Index >= 0 && item.Index < len(vectors) {
|
||||
vectors[item.Index] = item.Embedding
|
||||
}
|
||||
}
|
||||
out := make([][]float32, 0, len(vectors))
|
||||
for _, vector := range vectors {
|
||||
if len(vector) > 0 {
|
||||
out = append(out, vector)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func denseCanonicalStoreItem(item storefs.MemoryItem) storefs.MemoryItem {
|
||||
item.ID = strings.TrimSpace(item.ID)
|
||||
item.Memory = strings.TrimSpace(item.Memory)
|
||||
if item.Memory != "" && strings.TrimSpace(item.Hash) == "" {
|
||||
item.Hash = denseRuntimeHash(item.Memory)
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func densePayload(botID string, item storefs.MemoryItem) map[string]string {
|
||||
item = denseCanonicalStoreItem(item)
|
||||
payload := map[string]string{
|
||||
"memory": item.Memory,
|
||||
"bot_id": strings.TrimSpace(botID),
|
||||
"source_entry_id": item.ID,
|
||||
"hash": item.Hash,
|
||||
}
|
||||
if item.CreatedAt != "" {
|
||||
payload["created_at"] = item.CreatedAt
|
||||
}
|
||||
if item.UpdatedAt != "" {
|
||||
payload["updated_at"] = item.UpdatedAt
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func densePayloadMatches(existing, expected map[string]string) bool {
|
||||
for key, value := range expected {
|
||||
if strings.TrimSpace(existing[key]) != strings.TrimSpace(value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func denseStoreItemFromMemoryItem(item adapters.MemoryItem) storefs.MemoryItem {
|
||||
return denseCanonicalStoreItem(storefs.MemoryItem{
|
||||
ID: item.ID,
|
||||
Memory: item.Memory,
|
||||
Hash: item.Hash,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Score: item.Score,
|
||||
Metadata: item.Metadata,
|
||||
BotID: item.BotID,
|
||||
AgentID: item.AgentID,
|
||||
RunID: item.RunID,
|
||||
})
|
||||
}
|
||||
|
||||
func denseMemoryItemFromStore(item storefs.MemoryItem) adapters.MemoryItem {
|
||||
item = denseCanonicalStoreItem(item)
|
||||
return adapters.MemoryItem{
|
||||
ID: item.ID,
|
||||
Memory: item.Memory,
|
||||
Hash: item.Hash,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Score: item.Score,
|
||||
Metadata: item.Metadata,
|
||||
BotID: item.BotID,
|
||||
AgentID: item.AgentID,
|
||||
RunID: item.RunID,
|
||||
}
|
||||
}
|
||||
|
||||
func denseResultToItem(r qdrantclient.SearchResult) adapters.MemoryItem {
|
||||
item := adapters.MemoryItem{
|
||||
ID: r.ID,
|
||||
Score: r.Score,
|
||||
}
|
||||
if r.Payload != nil {
|
||||
if sourceID := strings.TrimSpace(r.Payload["source_entry_id"]); sourceID != "" {
|
||||
item.ID = sourceID
|
||||
}
|
||||
item.Memory = r.Payload["memory"]
|
||||
item.Hash = r.Payload["hash"]
|
||||
item.BotID = r.Payload["bot_id"]
|
||||
item.CreatedAt = r.Payload["created_at"]
|
||||
item.UpdatedAt = r.Payload["updated_at"]
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func denseRuntimeHash(text string) string {
|
||||
func runtimeHash(text string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(text)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
@@ -20,9 +20,12 @@ const (
|
||||
ModeDense BuiltinMemoryMode = "dense"
|
||||
)
|
||||
|
||||
// NewBuiltinRuntimeFromConfig returns the appropriate memoryRuntime based on the
|
||||
// provider's persisted config (memory_mode field). Falls back to the file runtime for "off" or unknown.
|
||||
func NewBuiltinRuntimeFromConfig(log *slog.Logger, providerConfig map[string]any, fileRuntime any, store *storefs.Service, queries *dbsqlc.Queries, cfg config.Config) (any, error) {
|
||||
// NewBuiltinRuntimeFromConfig returns the appropriate memoryRuntime based on
|
||||
// the provider's persisted config (memory_mode field). Returns the file
|
||||
// runtime for "off" or unknown modes. Returns an error if a sparse or dense
|
||||
// runtime was explicitly requested but failed to initialise, so that callers
|
||||
// can surface configuration problems rather than silently degrading.
|
||||
func NewBuiltinRuntimeFromConfig(_ *slog.Logger, providerConfig map[string]any, fileRuntime any, store *storefs.Service, queries *dbsqlc.Queries, cfg config.Config) (any, error) {
|
||||
mode := BuiltinMemoryMode(strings.TrimSpace(adapters.StringFromConfig(providerConfig, "memory_mode")))
|
||||
|
||||
switch mode {
|
||||
@@ -38,7 +41,7 @@ func NewBuiltinRuntimeFromConfig(log *slog.Logger, providerConfig map[string]any
|
||||
if collection == "" {
|
||||
collection = "memory_sparse"
|
||||
}
|
||||
rt, err := newSparseRuntime(
|
||||
return newSparseRuntime(
|
||||
host,
|
||||
port,
|
||||
cfg.Qdrant.APIKey,
|
||||
@@ -46,23 +49,9 @@ func NewBuiltinRuntimeFromConfig(log *slog.Logger, providerConfig map[string]any
|
||||
strings.TrimSpace(cfg.Sparse.BaseURL),
|
||||
store,
|
||||
)
|
||||
if err != nil {
|
||||
if log != nil {
|
||||
log.Warn("sparse runtime init failed, falling back to file runtime", slog.Any("error", err))
|
||||
}
|
||||
return fileRuntime, nil
|
||||
}
|
||||
return rt, nil
|
||||
|
||||
case ModeDense:
|
||||
rt, err := newDenseRuntime(providerConfig, queries, cfg, store)
|
||||
if err != nil {
|
||||
if log != nil {
|
||||
log.Warn("dense runtime init failed, falling back to file runtime", slog.Any("error", err))
|
||||
}
|
||||
return fileRuntime, nil
|
||||
}
|
||||
return rt, nil
|
||||
return newDenseRuntime(providerConfig, queries, cfg, store)
|
||||
|
||||
default:
|
||||
return fileRuntime, nil
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
const (
|
||||
formationTimeout = 60 * time.Second
|
||||
candidateSearchLimit = 20
|
||||
candidateGetAllLimit = 50
|
||||
maxCandidatesPerDecide = 30
|
||||
|
||||
actionADD = "ADD"
|
||||
actionUPDATE = "UPDATE"
|
||||
actionDELETE = "DELETE"
|
||||
actionNOOP = "NOOP"
|
||||
)
|
||||
|
||||
// formationResult holds the outcome of a memory formation cycle.
|
||||
type formationResult struct {
|
||||
ExtractedFacts int
|
||||
Added int
|
||||
Updated int
|
||||
Deleted int
|
||||
Skipped int
|
||||
}
|
||||
|
||||
// runFormation executes the Extract -> candidate retrieval -> Decide -> apply pipeline.
|
||||
func runFormation(ctx context.Context, logger *slog.Logger, llm adapters.LLM, runtime memoryRuntime, req adapters.AfterChatRequest) formationResult {
|
||||
ctx, cancel := context.WithTimeout(ctx, formationTimeout)
|
||||
defer cancel()
|
||||
|
||||
botID := strings.TrimSpace(req.BotID)
|
||||
result := formationResult{}
|
||||
|
||||
extracted, err := llm.Extract(ctx, adapters.ExtractRequest{
|
||||
Messages: req.Messages,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("memory formation: extract failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
return result
|
||||
}
|
||||
facts := filterNonEmpty(extracted.Facts)
|
||||
if len(facts) == 0 {
|
||||
return result
|
||||
}
|
||||
result.ExtractedFacts = len(facts)
|
||||
|
||||
candidates := gatherCandidates(ctx, logger, runtime, botID, facts)
|
||||
|
||||
decided, err := llm.Decide(ctx, adapters.DecideRequest{
|
||||
Facts: facts,
|
||||
Candidates: candidates,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("memory formation: decide failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
return result
|
||||
}
|
||||
|
||||
filters := map[string]any{
|
||||
"namespace": sharedMemoryNamespace,
|
||||
"scopeId": botID,
|
||||
"bot_id": botID,
|
||||
}
|
||||
metadata := adapters.BuildProfileMetadata(req.UserID, req.ChannelIdentityID, req.DisplayName)
|
||||
|
||||
applyActions(ctx, logger, runtime, botID, decided.Actions, filters, metadata, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
// gatherCandidates collects existing memories relevant to the extracted facts.
|
||||
func gatherCandidates(ctx context.Context, logger *slog.Logger, runtime memoryRuntime, botID string, facts []string) []adapters.CandidateMemory {
|
||||
seen := make(map[string]struct{})
|
||||
candidates := make([]adapters.CandidateMemory, 0, candidateSearchLimit)
|
||||
|
||||
filters := map[string]any{
|
||||
"namespace": sharedMemoryNamespace,
|
||||
"scopeId": botID,
|
||||
"bot_id": botID,
|
||||
}
|
||||
|
||||
for _, fact := range facts {
|
||||
if len(candidates) >= maxCandidatesPerDecide {
|
||||
break
|
||||
}
|
||||
resp, err := runtime.Search(ctx, adapters.SearchRequest{
|
||||
Query: fact,
|
||||
BotID: botID,
|
||||
Limit: candidateSearchLimit / max(len(facts), 1),
|
||||
Filters: filters,
|
||||
NoStats: true,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Debug("memory formation: search candidates failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
for _, item := range resp.Results {
|
||||
id := strings.TrimSpace(item.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
candidates = append(candidates, adapters.CandidateMemory{
|
||||
ID: id,
|
||||
Memory: item.Memory,
|
||||
CreatedAt: item.CreatedAt,
|
||||
Metadata: item.Metadata,
|
||||
})
|
||||
if len(candidates) >= maxCandidatesPerDecide {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) < maxCandidatesPerDecide {
|
||||
resp, err := runtime.GetAll(ctx, adapters.GetAllRequest{
|
||||
BotID: botID,
|
||||
Limit: candidateGetAllLimit,
|
||||
Filters: filters,
|
||||
NoStats: true,
|
||||
})
|
||||
if err == nil {
|
||||
for _, item := range resp.Results {
|
||||
id := strings.TrimSpace(item.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
candidates = append(candidates, adapters.CandidateMemory{
|
||||
ID: id,
|
||||
Memory: item.Memory,
|
||||
CreatedAt: item.CreatedAt,
|
||||
Metadata: item.Metadata,
|
||||
})
|
||||
if len(candidates) >= maxCandidatesPerDecide {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// applyActions executes the decided CRUD actions against the runtime.
|
||||
func applyActions(ctx context.Context, logger *slog.Logger, runtime memoryRuntime, botID string, actions []adapters.DecisionAction, filters map[string]any, metadata map[string]any, result *formationResult) {
|
||||
deleted := make(map[string]struct{})
|
||||
updated := make(map[string]struct{})
|
||||
|
||||
for _, action := range actions {
|
||||
event := strings.ToUpper(strings.TrimSpace(action.Event))
|
||||
switch event {
|
||||
case actionADD:
|
||||
text := strings.TrimSpace(action.Text)
|
||||
if text == "" {
|
||||
logger.Debug("memory formation: ADD skipped (empty text)", slog.String("bot_id", botID))
|
||||
result.Skipped++
|
||||
continue
|
||||
}
|
||||
if _, err := runtime.Add(ctx, adapters.AddRequest{
|
||||
Message: text,
|
||||
BotID: botID,
|
||||
Metadata: metadata,
|
||||
Filters: filters,
|
||||
}); err != nil {
|
||||
logger.Warn("memory formation: ADD failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
} else {
|
||||
result.Added++
|
||||
}
|
||||
|
||||
case actionUPDATE:
|
||||
id := strings.TrimSpace(action.ID)
|
||||
text := strings.TrimSpace(action.Text)
|
||||
if id == "" || text == "" {
|
||||
logger.Debug("memory formation: UPDATE skipped (missing id or text)", slog.String("bot_id", botID))
|
||||
result.Skipped++
|
||||
continue
|
||||
}
|
||||
if _, ok := updated[id]; ok {
|
||||
result.Skipped++
|
||||
continue
|
||||
}
|
||||
if _, err := runtime.Update(ctx, adapters.UpdateRequest{
|
||||
MemoryID: id,
|
||||
Memory: text,
|
||||
}); err != nil {
|
||||
logger.Warn("memory formation: UPDATE failed", slog.String("bot_id", botID), slog.String("memory_id", id), slog.Any("error", err))
|
||||
} else {
|
||||
updated[id] = struct{}{}
|
||||
result.Updated++
|
||||
}
|
||||
|
||||
case actionDELETE:
|
||||
id := strings.TrimSpace(action.ID)
|
||||
if id == "" {
|
||||
logger.Debug("memory formation: DELETE skipped (missing id)", slog.String("bot_id", botID))
|
||||
result.Skipped++
|
||||
continue
|
||||
}
|
||||
if _, ok := deleted[id]; ok {
|
||||
result.Skipped++
|
||||
continue
|
||||
}
|
||||
if _, err := runtime.Delete(ctx, id); err != nil {
|
||||
logger.Warn("memory formation: DELETE failed", slog.String("bot_id", botID), slog.String("memory_id", id), slog.Any("error", err))
|
||||
} else {
|
||||
deleted[id] = struct{}{}
|
||||
result.Deleted++
|
||||
}
|
||||
|
||||
case actionNOOP, "":
|
||||
result.Skipped++
|
||||
|
||||
default:
|
||||
logger.Debug("memory formation: unknown action event", slog.String("bot_id", botID), slog.String("event", event))
|
||||
result.Skipped++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterNonEmpty(ss []string) []string {
|
||||
out := make([]string, 0, len(ss))
|
||||
for _, s := range ss {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
// fakeLLM implements adapters.LLM for testing the formation pipeline.
|
||||
type fakeLLM struct {
|
||||
extractFacts []string
|
||||
extractErr error
|
||||
decideActions []adapters.DecisionAction
|
||||
decideErr error
|
||||
extractCalls int
|
||||
decideCalls int
|
||||
}
|
||||
|
||||
func (f *fakeLLM) Extract(_ context.Context, _ adapters.ExtractRequest) (adapters.ExtractResponse, error) {
|
||||
f.extractCalls++
|
||||
return adapters.ExtractResponse{Facts: f.extractFacts}, f.extractErr
|
||||
}
|
||||
|
||||
func (f *fakeLLM) Decide(_ context.Context, _ adapters.DecideRequest) (adapters.DecideResponse, error) {
|
||||
f.decideCalls++
|
||||
return adapters.DecideResponse{Actions: f.decideActions}, f.decideErr
|
||||
}
|
||||
|
||||
func (*fakeLLM) Compact(context.Context, adapters.CompactRequest) (adapters.CompactResponse, error) {
|
||||
return adapters.CompactResponse{}, nil
|
||||
}
|
||||
|
||||
func (*fakeLLM) DetectLanguage(context.Context, string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func TestFormationExtractAndAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User likes oolong tea", "User is based in Berlin"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "ADD", Text: "User likes oolong tea"},
|
||||
{Event: "ADD", Text: "User is based in Berlin"},
|
||||
},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I like oolong tea and I live in Berlin"},
|
||||
{Role: "assistant", Content: "Noted!"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.ExtractedFacts != 2 {
|
||||
t.Fatalf("expected 2 extracted facts, got %d", result.ExtractedFacts)
|
||||
}
|
||||
if result.Added != 2 {
|
||||
t.Fatalf("expected 2 adds, got %d", result.Added)
|
||||
}
|
||||
if result.Updated != 0 || result.Deleted != 0 {
|
||||
t.Fatalf("expected no updates/deletes, got updated=%d deleted=%d", result.Updated, result.Deleted)
|
||||
}
|
||||
if len(store.items) != 2 {
|
||||
t.Fatalf("expected 2 items in store, got %d", len(store.items))
|
||||
}
|
||||
if llm.extractCalls != 1 || llm.decideCalls != 1 {
|
||||
t.Fatalf("expected 1 extract + 1 decide call, got %d/%d", llm.extractCalls, llm.decideCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormationUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
addResp, err := runtime.Add(context.Background(), adapters.AddRequest{
|
||||
BotID: "bot-1",
|
||||
Message: "User lives in Tokyo",
|
||||
Filters: map[string]any{"bot_id": "bot-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("seed Add failed: %v", err)
|
||||
}
|
||||
memID := addResp.Results[0].ID
|
||||
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User moved to Berlin"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "UPDATE", ID: memID, Text: "User is based in Berlin", OldMemory: "User lives in Tokyo"},
|
||||
},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "Actually, I moved to Berlin"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.Updated != 1 {
|
||||
t.Fatalf("expected 1 update, got %d", result.Updated)
|
||||
}
|
||||
if result.Added != 0 {
|
||||
t.Fatalf("expected 0 adds, got %d", result.Added)
|
||||
}
|
||||
|
||||
item, ok := store.items[memID]
|
||||
if !ok {
|
||||
t.Fatalf("expected memory %q to still exist", memID)
|
||||
}
|
||||
if !strings.Contains(item.Memory, "Berlin") {
|
||||
t.Fatalf("expected updated memory to contain Berlin, got %q", item.Memory)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormationDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
addResp, err := runtime.Add(context.Background(), adapters.AddRequest{
|
||||
BotID: "bot-1",
|
||||
Message: "User likes coffee",
|
||||
Filters: map[string]any{"bot_id": "bot-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("seed Add failed: %v", err)
|
||||
}
|
||||
memID := addResp.Results[0].ID
|
||||
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User no longer drinks coffee"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "DELETE", ID: memID},
|
||||
},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I stopped drinking coffee"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.Deleted != 1 {
|
||||
t.Fatalf("expected 1 delete, got %d", result.Deleted)
|
||||
}
|
||||
if _, ok := store.items[memID]; ok {
|
||||
t.Fatal("expected memory to be deleted from store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormationNOOP(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User likes tea"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "NOOP"},
|
||||
},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I like tea"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.Skipped != 1 {
|
||||
t.Fatalf("expected 1 skipped, got %d", result.Skipped)
|
||||
}
|
||||
if result.Added != 0 || result.Updated != 0 || result.Deleted != 0 {
|
||||
t.Fatalf("expected no mutations, got added=%d updated=%d deleted=%d", result.Added, result.Updated, result.Deleted)
|
||||
}
|
||||
if len(store.items) != 0 {
|
||||
t.Fatalf("expected 0 items in store, got %d", len(store.items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormationNoFacts(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.ExtractedFacts != 0 {
|
||||
t.Fatalf("expected 0 extracted facts, got %d", result.ExtractedFacts)
|
||||
}
|
||||
if llm.decideCalls != 0 {
|
||||
t.Fatal("expected Decide to NOT be called when no facts extracted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormationMixedActions(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
addResp, _ := runtime.Add(context.Background(), adapters.AddRequest{
|
||||
BotID: "bot-1",
|
||||
Message: "User lives in Tokyo",
|
||||
Filters: map[string]any{"bot_id": "bot-1"},
|
||||
})
|
||||
existingID := addResp.Results[0].ID
|
||||
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User moved to Berlin", "User prefers dark mode"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "UPDATE", ID: existingID, Text: "User lives in Berlin"},
|
||||
{Event: "ADD", Text: "User prefers dark mode"},
|
||||
{Event: "NOOP"},
|
||||
},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I moved to Berlin and I like dark mode"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.Added != 1 {
|
||||
t.Fatalf("expected 1 add, got %d", result.Added)
|
||||
}
|
||||
if result.Updated != 1 {
|
||||
t.Fatalf("expected 1 update, got %d", result.Updated)
|
||||
}
|
||||
if result.Skipped != 1 {
|
||||
t.Fatalf("expected 1 skipped, got %d", result.Skipped)
|
||||
}
|
||||
if len(store.items) != 2 {
|
||||
t.Fatalf("expected 2 items in store, got %d", len(store.items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormationInvalidActionsSkipped(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User likes cats"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "ADD", Text: ""},
|
||||
{Event: "UPDATE", ID: "", Text: "something"},
|
||||
{Event: "DELETE", ID: ""},
|
||||
{Event: "UNKNOWN_EVENT", Text: "foo"},
|
||||
{Event: "ADD", Text: "User likes cats"},
|
||||
},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I like cats"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.Added != 1 {
|
||||
t.Fatalf("expected 1 valid add, got %d", result.Added)
|
||||
}
|
||||
if result.Skipped != 4 {
|
||||
t.Fatalf("expected 4 skipped (3 invalid + 1 unknown), got %d", result.Skipped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormationDuplicateActionsSameID(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
addResp, _ := runtime.Add(context.Background(), adapters.AddRequest{
|
||||
BotID: "bot-1",
|
||||
Message: "User likes tea",
|
||||
Filters: map[string]any{"bot_id": "bot-1"},
|
||||
})
|
||||
memID := addResp.Results[0].ID
|
||||
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"Updated fact"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "UPDATE", ID: memID, Text: "User prefers coffee"},
|
||||
{Event: "UPDATE", ID: memID, Text: "User prefers juice"},
|
||||
},
|
||||
}
|
||||
|
||||
result := runFormation(context.Background(), slog.Default(), llm, runtime, adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I changed my mind"},
|
||||
},
|
||||
})
|
||||
|
||||
if result.Updated != 1 {
|
||||
t.Fatalf("expected 1 update (second should be deduped), got %d", result.Updated)
|
||||
}
|
||||
if result.Skipped != 1 {
|
||||
t.Fatalf("expected 1 skipped (duplicate), got %d", result.Skipped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnAfterChatWithLLM(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User prefers dark mode"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "ADD", Text: "User prefers dark mode"},
|
||||
},
|
||||
}
|
||||
|
||||
p := NewBuiltinProvider(slog.Default(), runtime, nil, nil)
|
||||
p.SetLLM(llm)
|
||||
|
||||
err := p.OnAfterChat(context.Background(), adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I prefer dark mode"},
|
||||
{Role: "assistant", Content: "Got it!"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("OnAfterChat error: %v", err)
|
||||
}
|
||||
if len(store.items) != 1 {
|
||||
t.Fatalf("expected 1 fact stored, got %d", len(store.items))
|
||||
}
|
||||
for _, item := range store.items {
|
||||
if !strings.Contains(item.Memory, "dark mode") {
|
||||
t.Fatalf("expected stored fact to mention dark mode, got %q", item.Memory)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnAfterChatFallbackWithoutLLM(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
|
||||
p := NewBuiltinProvider(slog.Default(), runtime, nil, nil)
|
||||
|
||||
err := p.OnAfterChat(context.Background(), adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "Hello world"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("OnAfterChat error: %v", err)
|
||||
}
|
||||
if len(store.items) != 1 {
|
||||
t.Fatalf("expected 1 item in store (legacy fallback), got %d", len(store.items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnBeforeChatRecallsFactMemory(t *testing.T) {
|
||||
t.Parallel()
|
||||
encoder := &fakeSparseEncoder{}
|
||||
index := newFakeSparseIndex(encoder)
|
||||
store := newFakeSparseStore()
|
||||
runtime := &sparseRuntime{qdrant: index, encoder: encoder, store: store}
|
||||
llm := &fakeLLM{
|
||||
extractFacts: []string{"User prefers oolong tea"},
|
||||
decideActions: []adapters.DecisionAction{
|
||||
{Event: "ADD", Text: "User prefers oolong tea"},
|
||||
},
|
||||
}
|
||||
|
||||
p := NewBuiltinProvider(slog.Default(), runtime, nil, nil)
|
||||
p.SetLLM(llm)
|
||||
|
||||
_ = p.OnAfterChat(context.Background(), adapters.AfterChatRequest{
|
||||
BotID: "bot-1",
|
||||
Messages: []adapters.Message{
|
||||
{Role: "user", Content: "I prefer oolong tea"},
|
||||
},
|
||||
})
|
||||
|
||||
result, err := p.OnBeforeChat(context.Background(), adapters.BeforeChatRequest{
|
||||
BotID: "bot-1",
|
||||
Query: "tea",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("OnBeforeChat error: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil context result")
|
||||
}
|
||||
lower := strings.ToLower(result.ContextText)
|
||||
if !strings.Contains(lower, "oolong tea") {
|
||||
t.Fatalf("expected recalled context to mention oolong tea, got %q", result.ContextText)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
)
|
||||
|
||||
func canonicalStoreItem(item storefs.MemoryItem) storefs.MemoryItem {
|
||||
item.ID = strings.TrimSpace(item.ID)
|
||||
item.Memory = strings.TrimSpace(item.Memory)
|
||||
if item.Memory != "" && strings.TrimSpace(item.Hash) == "" {
|
||||
item.Hash = runtimeHash(item.Memory)
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func runtimePayload(botID string, item storefs.MemoryItem) map[string]string {
|
||||
item = canonicalStoreItem(item)
|
||||
payload := map[string]string{
|
||||
"memory": item.Memory,
|
||||
"bot_id": strings.TrimSpace(botID),
|
||||
"source_entry_id": item.ID,
|
||||
"hash": item.Hash,
|
||||
}
|
||||
if item.CreatedAt != "" {
|
||||
payload["created_at"] = item.CreatedAt
|
||||
}
|
||||
if item.UpdatedAt != "" {
|
||||
payload["updated_at"] = item.UpdatedAt
|
||||
}
|
||||
for _, key := range []string{"profile_user_id", "profile_channel_identity_id", "profile_display_name", "profile_ref"} {
|
||||
if v, ok := item.Metadata[key]; ok {
|
||||
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||
payload[key] = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func payloadMatches(existing, expected map[string]string) bool {
|
||||
for key, value := range expected {
|
||||
if strings.TrimSpace(existing[key]) != strings.TrimSpace(value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func storeItemFromMemoryItem(item adapters.MemoryItem) storefs.MemoryItem {
|
||||
return canonicalStoreItem(storefs.MemoryItem{
|
||||
ID: item.ID,
|
||||
Memory: item.Memory,
|
||||
Hash: item.Hash,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Score: item.Score,
|
||||
Metadata: item.Metadata,
|
||||
BotID: item.BotID,
|
||||
AgentID: item.AgentID,
|
||||
RunID: item.RunID,
|
||||
})
|
||||
}
|
||||
|
||||
func memoryItemFromStore(item storefs.MemoryItem) adapters.MemoryItem {
|
||||
item = canonicalStoreItem(item)
|
||||
return adapters.MemoryItem{
|
||||
ID: item.ID,
|
||||
Memory: item.Memory,
|
||||
Hash: item.Hash,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Score: item.Score,
|
||||
Metadata: item.Metadata,
|
||||
BotID: item.BotID,
|
||||
AgentID: item.AgentID,
|
||||
RunID: item.RunID,
|
||||
}
|
||||
}
|
||||
|
||||
func resultToItem(r qdrantclient.SearchResult) adapters.MemoryItem {
|
||||
item := adapters.MemoryItem{
|
||||
ID: r.ID,
|
||||
Score: r.Score,
|
||||
}
|
||||
if r.Payload != nil {
|
||||
if sourceID := strings.TrimSpace(r.Payload["source_entry_id"]); sourceID != "" {
|
||||
item.ID = sourceID
|
||||
}
|
||||
item.Memory = r.Payload["memory"]
|
||||
item.Hash = r.Payload["hash"]
|
||||
item.BotID = r.Payload["bot_id"]
|
||||
item.CreatedAt = r.Payload["created_at"]
|
||||
item.UpdatedAt = r.Payload["updated_at"]
|
||||
meta := map[string]any{}
|
||||
for _, key := range []string{"profile_user_id", "profile_channel_identity_id", "profile_display_name", "profile_ref"} {
|
||||
if v := strings.TrimSpace(r.Payload[key]); v != "" {
|
||||
meta[key] = v
|
||||
}
|
||||
}
|
||||
if len(meta) > 0 {
|
||||
item.Metadata = meta
|
||||
}
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func runtimeBotID(botID string, filters map[string]any) (string, error) {
|
||||
botID = strings.TrimSpace(botID)
|
||||
if botID == "" {
|
||||
botID = strings.TrimSpace(runtimeFilterString(filters, "bot_id"))
|
||||
}
|
||||
if botID == "" {
|
||||
botID = strings.TrimSpace(runtimeFilterString(filters, "scopeId"))
|
||||
}
|
||||
if botID == "" {
|
||||
return "", errors.New("bot_id is required")
|
||||
}
|
||||
return botID, nil
|
||||
}
|
||||
|
||||
// runtimeBotIDFromMemoryID extracts the bot prefix from a composite
// "<bot_id>:<memory_suffix>" identifier. When no colon separator is present
// it returns the empty string.
func runtimeBotIDFromMemoryID(memoryID string) string {
	prefix, _, found := strings.Cut(strings.TrimSpace(memoryID), ":")
	if !found {
		return ""
	}
	return strings.TrimSpace(prefix)
}
|
||||
|
||||
func runtimeText(message string, messages []adapters.Message) string {
|
||||
text := strings.TrimSpace(message)
|
||||
if text == "" && len(messages) > 0 {
|
||||
parts := make([]string, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
role := strings.ToUpper(strings.TrimSpace(m.Role))
|
||||
if role == "" {
|
||||
role = "MESSAGE"
|
||||
}
|
||||
parts = append(parts, "["+role+"] "+content)
|
||||
}
|
||||
text = strings.Join(parts, "\n")
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func runtimeMemoryID(botID string, now time.Time) string {
|
||||
return botID + ":" + "mem_" + strconv.FormatInt(now.UnixNano(), 10)
|
||||
}
|
||||
|
||||
func runtimePointID(botID, sourceID string) string {
|
||||
return uuid.NewSHA1(uuid.NameSpaceURL, []byte(strings.TrimSpace(botID)+"\n"+strings.TrimSpace(sourceID))).String()
|
||||
}
|
||||
|
||||
// runtimeFilterString reads key from the filter map and renders the value as
// a trimmed string via fmt.Sprint. Absent keys, nil values, and a nil map all
// yield "" (reading from a nil map is safe in Go).
func runtimeFilterString(m map[string]any, key string) string {
	value, ok := m[key]
	if !ok || value == nil {
		return ""
	}
	return strings.TrimSpace(fmt.Sprint(value))
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
storefs "github.com/memohai/memoh/internal/memory/storefs"
|
||||
)
|
||||
|
||||
func TestRuntimeHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
h1 := runtimeHash("hello")
|
||||
h2 := runtimeHash(" hello ")
|
||||
if h1 != h2 {
|
||||
t.Fatalf("expected trimmed strings to produce same hash, got %q vs %q", h1, h2)
|
||||
}
|
||||
h3 := runtimeHash("world")
|
||||
if h1 == h3 {
|
||||
t.Fatal("expected different hashes for different inputs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeBotID(t *testing.T) {
|
||||
t.Parallel()
|
||||
id, err := runtimeBotID("bot-1", nil)
|
||||
if err != nil || id != "bot-1" {
|
||||
t.Fatalf("expected bot-1, got %q, err=%v", id, err)
|
||||
}
|
||||
|
||||
id, err = runtimeBotID("", map[string]any{"bot_id": "bot-2"})
|
||||
if err != nil || id != "bot-2" {
|
||||
t.Fatalf("expected bot-2 from filter, got %q, err=%v", id, err)
|
||||
}
|
||||
|
||||
_, err = runtimeBotID("", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty bot_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeBotIDFromMemoryID(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := runtimeBotIDFromMemoryID("bot-1:mem_123"); got != "bot-1" {
|
||||
t.Fatalf("expected bot-1, got %q", got)
|
||||
}
|
||||
if got := runtimeBotIDFromMemoryID("invalid"); got != "" {
|
||||
t.Fatalf("expected empty for invalid format, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimePointID_Deterministic(t *testing.T) {
|
||||
t.Parallel()
|
||||
p1 := runtimePointID("bot-1", "mem-1")
|
||||
p2 := runtimePointID("bot-1", "mem-1")
|
||||
if p1 != p2 {
|
||||
t.Fatalf("expected deterministic point ID, got %q vs %q", p1, p2)
|
||||
}
|
||||
p3 := runtimePointID("bot-1", "mem-2")
|
||||
if p1 == p3 {
|
||||
t.Fatal("expected different point IDs for different sources")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalStoreItem(t *testing.T) {
|
||||
t.Parallel()
|
||||
item := storefs.MemoryItem{
|
||||
ID: " id-1 ",
|
||||
Memory: " hello world ",
|
||||
}
|
||||
c := canonicalStoreItem(item)
|
||||
if c.ID != "id-1" {
|
||||
t.Fatalf("expected trimmed ID, got %q", c.ID)
|
||||
}
|
||||
if c.Memory != "hello world" {
|
||||
t.Fatalf("expected trimmed memory, got %q", c.Memory)
|
||||
}
|
||||
if c.Hash == "" {
|
||||
t.Fatal("expected hash to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadMatches(t *testing.T) {
|
||||
t.Parallel()
|
||||
existing := map[string]string{"memory": "hello", "bot_id": "b1"}
|
||||
expected := map[string]string{"memory": "hello", "bot_id": "b1"}
|
||||
if !payloadMatches(existing, expected) {
|
||||
t.Fatal("expected matching payloads")
|
||||
}
|
||||
expected["memory"] = "world"
|
||||
if payloadMatches(existing, expected) {
|
||||
t.Fatal("expected non-matching payloads")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResultToItem(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := qdrantclient.SearchResult{
|
||||
ID: "point-1",
|
||||
Score: 0.95,
|
||||
Payload: map[string]string{
|
||||
"source_entry_id": "mem-1",
|
||||
"memory": "test memory",
|
||||
"hash": "abc",
|
||||
"bot_id": "bot-1",
|
||||
"created_at": "2026-01-01T00:00:00Z",
|
||||
"updated_at": "2026-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
item := resultToItem(r)
|
||||
if item.ID != "mem-1" {
|
||||
t.Fatalf("expected source_entry_id as ID, got %q", item.ID)
|
||||
}
|
||||
if item.Score != 0.95 {
|
||||
t.Fatalf("expected score 0.95, got %f", item.Score)
|
||||
}
|
||||
if item.Memory != "test memory" {
|
||||
t.Fatalf("expected memory, got %q", item.Memory)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreItemRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
original := adapters.MemoryItem{
|
||||
ID: "id-1",
|
||||
Memory: "hello world",
|
||||
Hash: "abc",
|
||||
BotID: "bot-1",
|
||||
}
|
||||
store := storeItemFromMemoryItem(original)
|
||||
back := memoryItemFromStore(store)
|
||||
if back.ID != original.ID || back.Memory != original.Memory || back.BotID != original.BotID {
|
||||
t.Fatalf("round-trip failed: got %+v", back)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeText_SingleMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
text := runtimeText("hello", nil)
|
||||
if text != "hello" {
|
||||
t.Fatalf("expected 'hello', got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeText_MultipleMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adapters.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
text := runtimeText("", msgs)
|
||||
if text == "" {
|
||||
t.Fatal("expected non-empty text from messages")
|
||||
}
|
||||
if !contains(text, "[USER] hi") || !contains(text, "[ASSISTANT] hello") {
|
||||
t.Fatalf("unexpected text format: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
// contains reports whether substr occurs within s, mirroring
// strings.Contains; kept local so the test file needs no extra import.
// The previous guard conditions (s == substr, len(s) > 0) were redundant:
// containsHelper already covers the equal-string and empty cases, so
// contains simply delegates.
func contains(s, substr string) bool {
	return containsHelper(s, substr)
}

// containsHelper performs a naive substring scan. An empty substr matches at
// index 0; a substr longer than s never matches (the loop body never runs).
func containsHelper(s, substr string) bool {
	for i := 0; i+len(substr) <= len(s); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}
|
||||
@@ -2,18 +2,13 @@ package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
qdrantclient "github.com/memohai/memoh/internal/memory/qdrant"
|
||||
@@ -92,36 +87,36 @@ func (*sparseRuntime) Mode() string {
|
||||
}
|
||||
|
||||
func (r *sparseRuntime) Add(ctx context.Context, req adapters.AddRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
text := sparseRuntimeText(req.Message, req.Messages)
|
||||
text := runtimeText(req.Message, req.Messages)
|
||||
if text == "" {
|
||||
return adapters.SearchResponse{}, errors.New("sparse runtime: message is required")
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
item := adapters.MemoryItem{
|
||||
ID: sparseRuntimeMemoryID(botID, time.Now().UTC()),
|
||||
ID: runtimeMemoryID(botID, time.Now().UTC()),
|
||||
Memory: text,
|
||||
Hash: sparseRuntimeHash(text),
|
||||
Hash: runtimeHash(text),
|
||||
Metadata: req.Metadata,
|
||||
BotID: botID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{sparseStoreItemFromMemoryItem(item)}, req.Filters); err != nil {
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{storeItemFromMemoryItem(item)}, req.Filters); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{sparseStoreItemFromMemoryItem(item)}); err != nil {
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{storeItemFromMemoryItem(item)}); err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
return adapters.SearchResponse{Results: []adapters.MemoryItem{item}}, nil
|
||||
}
|
||||
|
||||
func (r *sparseRuntime) Search(ctx context.Context, req adapters.SearchRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
@@ -147,13 +142,13 @@ func (r *sparseRuntime) Search(ctx context.Context, req adapters.SearchRequest)
|
||||
}
|
||||
items := make([]adapters.MemoryItem, 0, len(results))
|
||||
for _, r := range results {
|
||||
items = append(items, sparseResultToItem(r))
|
||||
items = append(items, resultToItem(r))
|
||||
}
|
||||
return adapters.SearchResponse{Results: items}, nil
|
||||
}
|
||||
|
||||
func (r *sparseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest) (adapters.SearchResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.SearchResponse{}, err
|
||||
}
|
||||
@@ -163,7 +158,7 @@ func (r *sparseRuntime) GetAll(ctx context.Context, req adapters.GetAllRequest)
|
||||
}
|
||||
result := make([]adapters.MemoryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
mem := sparseMemoryItemFromStore(item)
|
||||
mem := memoryItemFromStore(item)
|
||||
mem.BotID = botID
|
||||
result = append(result, mem)
|
||||
}
|
||||
@@ -184,7 +179,7 @@ func (r *sparseRuntime) Update(ctx context.Context, req adapters.UpdateRequest)
|
||||
if text == "" {
|
||||
return adapters.MemoryItem{}, errors.New("sparse runtime: memory is required")
|
||||
}
|
||||
botID := sparseRuntimeBotIDFromMemoryID(memoryID)
|
||||
botID := runtimeBotIDFromMemoryID(memoryID)
|
||||
if botID == "" {
|
||||
return adapters.MemoryItem{}, errors.New("sparse runtime: invalid memory_id")
|
||||
}
|
||||
@@ -204,7 +199,7 @@ func (r *sparseRuntime) Update(ctx context.Context, req adapters.UpdateRequest)
|
||||
return adapters.MemoryItem{}, errors.New("sparse runtime: memory not found")
|
||||
}
|
||||
existing.Memory = text
|
||||
existing.Hash = sparseRuntimeHash(text)
|
||||
existing.Hash = runtimeHash(text)
|
||||
existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := r.store.PersistMemories(ctx, botID, []storefs.MemoryItem{*existing}, nil); err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
@@ -212,7 +207,7 @@ func (r *sparseRuntime) Update(ctx context.Context, req adapters.UpdateRequest)
|
||||
if err := r.upsertSourceItems(ctx, botID, []storefs.MemoryItem{*existing}); err != nil {
|
||||
return adapters.MemoryItem{}, err
|
||||
}
|
||||
item := sparseMemoryItemFromStore(*existing)
|
||||
item := memoryItemFromStore(*existing)
|
||||
item.BotID = botID
|
||||
return item, nil
|
||||
}
|
||||
@@ -229,12 +224,12 @@ func (r *sparseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (ad
|
||||
if memoryID == "" {
|
||||
continue
|
||||
}
|
||||
botID := sparseRuntimeBotIDFromMemoryID(memoryID)
|
||||
botID := runtimeBotIDFromMemoryID(memoryID)
|
||||
if botID == "" {
|
||||
continue
|
||||
}
|
||||
grouped[botID] = append(grouped[botID], memoryID)
|
||||
pointIDs = append(pointIDs, sparsePointID(botID, memoryID))
|
||||
pointIDs = append(pointIDs, runtimePointID(botID, memoryID))
|
||||
}
|
||||
for botID, ids := range grouped {
|
||||
if err := r.store.RemoveMemories(ctx, botID, ids); err != nil {
|
||||
@@ -251,7 +246,7 @@ func (r *sparseRuntime) DeleteBatch(ctx context.Context, memoryIDs []string) (ad
|
||||
}
|
||||
|
||||
func (r *sparseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllRequest) (adapters.DeleteResponse, error) {
|
||||
botID, err := sparseRuntimeBotID(req.BotID, req.Filters)
|
||||
botID, err := runtimeBotID(req.BotID, req.Filters)
|
||||
if err != nil {
|
||||
return adapters.DeleteResponse{}, err
|
||||
}
|
||||
@@ -268,7 +263,7 @@ func (r *sparseRuntime) DeleteAll(ctx context.Context, req adapters.DeleteAllReq
|
||||
}
|
||||
|
||||
func (r *sparseRuntime) Compact(ctx context.Context, filters map[string]any, ratio float64, _ int) (adapters.CompactResult, error) {
|
||||
botID, err := sparseRuntimeBotID("", filters)
|
||||
botID, err := runtimeBotID("", filters)
|
||||
if err != nil {
|
||||
return adapters.CompactResult{}, err
|
||||
}
|
||||
@@ -300,7 +295,7 @@ func (r *sparseRuntime) Compact(ctx context.Context, filters map[string]any, rat
|
||||
}
|
||||
kept := make([]adapters.MemoryItem, 0, len(keptStore))
|
||||
for _, item := range keptStore {
|
||||
kept = append(kept, sparseMemoryItemFromStore(item))
|
||||
kept = append(kept, memoryItemFromStore(item))
|
||||
}
|
||||
return adapters.CompactResult{
|
||||
BeforeCount: before,
|
||||
@@ -311,7 +306,7 @@ func (r *sparseRuntime) Compact(ctx context.Context, filters map[string]any, rat
|
||||
}
|
||||
|
||||
func (r *sparseRuntime) Usage(ctx context.Context, filters map[string]any) (adapters.UsageResponse, error) {
|
||||
botID, err := sparseRuntimeBotID("", filters)
|
||||
botID, err := runtimeBotID("", filters)
|
||||
if err != nil {
|
||||
return adapters.UsageResponse{}, err
|
||||
}
|
||||
@@ -411,13 +406,13 @@ func (r *sparseRuntime) syncSourceItems(ctx context.Context, botID string, items
|
||||
missingCount := 0
|
||||
restoredCount := 0
|
||||
for _, item := range items {
|
||||
item = sparseCanonicalStoreItem(item)
|
||||
item = canonicalStoreItem(item)
|
||||
if item.ID == "" || item.Memory == "" {
|
||||
continue
|
||||
}
|
||||
canonical = append(canonical, item)
|
||||
sourceIDs[item.ID] = struct{}{}
|
||||
payload := sparsePayload(botID, item)
|
||||
payload := runtimePayload(botID, item)
|
||||
existingItem, ok := existingBySource[item.ID]
|
||||
if !ok {
|
||||
missingCount++
|
||||
@@ -425,7 +420,7 @@ func (r *sparseRuntime) syncSourceItems(ctx context.Context, botID string, items
|
||||
toUpsert = append(toUpsert, item)
|
||||
continue
|
||||
}
|
||||
if !sparsePayloadMatches(existingItem.Payload, payload) {
|
||||
if !payloadMatches(existingItem.Payload, payload) {
|
||||
restoredCount++
|
||||
toUpsert = append(toUpsert, item)
|
||||
}
|
||||
@@ -473,7 +468,7 @@ func (r *sparseRuntime) upsertSourceItems(ctx context.Context, botID string, ite
|
||||
texts := make([]string, 0, len(items))
|
||||
canonical := make([]storefs.MemoryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
item = sparseCanonicalStoreItem(item)
|
||||
item = canonicalStoreItem(item)
|
||||
if item.ID == "" || item.Memory == "" {
|
||||
continue
|
||||
}
|
||||
@@ -492,34 +487,16 @@ func (r *sparseRuntime) upsertSourceItems(ctx context.Context, botID string, ite
|
||||
}
|
||||
for i, item := range canonical {
|
||||
vec := vectors[i]
|
||||
if err := r.qdrant.Upsert(ctx, sparsePointID(botID, item.ID), qdrantclient.SparseVector{
|
||||
if err := r.qdrant.Upsert(ctx, runtimePointID(botID, item.ID), qdrantclient.SparseVector{
|
||||
Indices: vec.Indices,
|
||||
Values: vec.Values,
|
||||
}, sparsePayload(botID, item)); err != nil {
|
||||
}, runtimePayload(botID, item)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sparseResultToItem(r qdrantclient.SearchResult) adapters.MemoryItem {
|
||||
item := adapters.MemoryItem{
|
||||
ID: r.ID,
|
||||
Score: r.Score,
|
||||
}
|
||||
if r.Payload != nil {
|
||||
if sourceID := strings.TrimSpace(r.Payload["source_entry_id"]); sourceID != "" {
|
||||
item.ID = sourceID
|
||||
}
|
||||
item.Memory = r.Payload["memory"]
|
||||
item.Hash = r.Payload["hash"]
|
||||
item.BotID = r.Payload["bot_id"]
|
||||
item.CreatedAt = r.Payload["created_at"]
|
||||
item.UpdatedAt = r.Payload["updated_at"]
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func (r *sparseRuntime) populateExplainStats(ctx context.Context, items []*adapters.MemoryItem) {
|
||||
if len(items) == 0 {
|
||||
return
|
||||
@@ -608,135 +585,3 @@ func sparseMemoryItemPointers(items []adapters.MemoryItem) []*adapters.MemoryIte
|
||||
}
|
||||
return pointers
|
||||
}
|
||||
|
||||
func sparseCanonicalStoreItem(item storefs.MemoryItem) storefs.MemoryItem {
|
||||
item.ID = strings.TrimSpace(item.ID)
|
||||
item.Memory = strings.TrimSpace(item.Memory)
|
||||
if item.Memory != "" && strings.TrimSpace(item.Hash) == "" {
|
||||
item.Hash = sparseRuntimeHash(item.Memory)
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func sparsePayload(botID string, item storefs.MemoryItem) map[string]string {
|
||||
item = sparseCanonicalStoreItem(item)
|
||||
payload := map[string]string{
|
||||
"memory": item.Memory,
|
||||
"bot_id": strings.TrimSpace(botID),
|
||||
"source_entry_id": item.ID,
|
||||
"hash": item.Hash,
|
||||
}
|
||||
if item.CreatedAt != "" {
|
||||
payload["created_at"] = item.CreatedAt
|
||||
}
|
||||
if item.UpdatedAt != "" {
|
||||
payload["updated_at"] = item.UpdatedAt
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func sparsePayloadMatches(existing, expected map[string]string) bool {
|
||||
for key, value := range expected {
|
||||
if strings.TrimSpace(existing[key]) != strings.TrimSpace(value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func sparseMemoryItemFromStore(item storefs.MemoryItem) adapters.MemoryItem {
|
||||
item = sparseCanonicalStoreItem(item)
|
||||
return adapters.MemoryItem{
|
||||
ID: item.ID,
|
||||
Memory: item.Memory,
|
||||
Hash: item.Hash,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Score: item.Score,
|
||||
Metadata: item.Metadata,
|
||||
BotID: item.BotID,
|
||||
AgentID: item.AgentID,
|
||||
RunID: item.RunID,
|
||||
}
|
||||
}
|
||||
|
||||
func sparseStoreItemFromMemoryItem(item adapters.MemoryItem) storefs.MemoryItem {
|
||||
return sparseCanonicalStoreItem(storefs.MemoryItem{
|
||||
ID: item.ID,
|
||||
Memory: item.Memory,
|
||||
Hash: item.Hash,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Score: item.Score,
|
||||
Metadata: item.Metadata,
|
||||
BotID: item.BotID,
|
||||
AgentID: item.AgentID,
|
||||
RunID: item.RunID,
|
||||
})
|
||||
}
|
||||
|
||||
func sparseRuntimeText(message string, messages []adapters.Message) string {
|
||||
text := strings.TrimSpace(message)
|
||||
if text == "" && len(messages) > 0 {
|
||||
parts := make([]string, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
role := strings.ToUpper(strings.TrimSpace(m.Role))
|
||||
if role == "" {
|
||||
role = "MESSAGE"
|
||||
}
|
||||
parts = append(parts, "["+role+"] "+content)
|
||||
}
|
||||
text = strings.Join(parts, "\n")
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func sparseRuntimeMemoryID(botID string, now time.Time) string {
|
||||
return botID + ":" + "mem_" + strconv.FormatInt(now.UnixNano(), 10)
|
||||
}
|
||||
|
||||
func sparseRuntimeHash(text string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(text)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func sparseRuntimeBotID(botID string, filters map[string]any) (string, error) {
|
||||
botID = strings.TrimSpace(botID)
|
||||
if botID == "" {
|
||||
botID = strings.TrimSpace(sparseRuntimeAny(filters, "bot_id"))
|
||||
}
|
||||
if botID == "" {
|
||||
botID = strings.TrimSpace(sparseRuntimeAny(filters, "scopeId"))
|
||||
}
|
||||
if botID == "" {
|
||||
return "", errors.New("bot_id is required")
|
||||
}
|
||||
return botID, nil
|
||||
}
|
||||
|
||||
func sparseRuntimeBotIDFromMemoryID(memoryID string) string {
|
||||
parts := strings.SplitN(strings.TrimSpace(memoryID), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
func sparseRuntimeAny(m map[string]any, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
|
||||
func sparsePointID(botID, sourceID string) string {
|
||||
return uuid.NewSHA1(uuid.NameSpaceURL, []byte(strings.TrimSpace(botID)+"\n"+strings.TrimSpace(sourceID))).String()
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestSparseRuntimeAddWritesSourceAndSupportsRecall(t *testing.T) {
|
||||
if _, ok := store.items[item.ID]; !ok {
|
||||
t.Fatalf("expected memory %q to be written to markdown source", item.ID)
|
||||
}
|
||||
point, ok := index.points[sparsePointID("bot-1", item.ID)]
|
||||
point, ok := index.points[runtimePointID("bot-1", item.ID)]
|
||||
if !ok {
|
||||
t.Fatalf("expected qdrant point for source memory %q", item.ID)
|
||||
}
|
||||
@@ -258,14 +258,14 @@ func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
|
||||
storefs.MemoryItem{
|
||||
ID: "bot-1:mem_1",
|
||||
Memory: "Ran likes tea",
|
||||
Hash: sparseRuntimeHash("Ran likes tea"),
|
||||
Hash: runtimeHash("Ran likes tea"),
|
||||
CreatedAt: "2026-03-13T09:00:00Z",
|
||||
UpdatedAt: "2026-03-13T09:00:00Z",
|
||||
},
|
||||
storefs.MemoryItem{
|
||||
ID: "bot-1:mem_2",
|
||||
Memory: "Ran works in Berlin",
|
||||
Hash: sparseRuntimeHash("Ran works in Berlin"),
|
||||
Hash: runtimeHash("Ran works in Berlin"),
|
||||
CreatedAt: "2026-03-13T10:00:00Z",
|
||||
UpdatedAt: "2026-03-13T10:00:00Z",
|
||||
},
|
||||
@@ -276,8 +276,8 @@ func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
|
||||
store: store,
|
||||
}
|
||||
|
||||
index.points[sparsePointID("bot-1", "bot-1:mem_1")] = qdrantclient.SearchResult{
|
||||
ID: sparsePointID("bot-1", "bot-1:mem_1"),
|
||||
index.points[runtimePointID("bot-1", "bot-1:mem_1")] = qdrantclient.SearchResult{
|
||||
ID: runtimePointID("bot-1", "bot-1:mem_1"),
|
||||
Payload: map[string]string{
|
||||
"bot_id": "bot-1",
|
||||
"memory": "Ran likes tea",
|
||||
@@ -287,8 +287,8 @@ func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
|
||||
"updated_at": "2026-03-13T09:00:00Z",
|
||||
},
|
||||
}
|
||||
index.points[sparsePointID("bot-1", "bot-1:stale")] = qdrantclient.SearchResult{
|
||||
ID: sparsePointID("bot-1", "bot-1:stale"),
|
||||
index.points[runtimePointID("bot-1", "bot-1:stale")] = qdrantclient.SearchResult{
|
||||
ID: runtimePointID("bot-1", "bot-1:stale"),
|
||||
Payload: map[string]string{
|
||||
"bot_id": "bot-1",
|
||||
"memory": "stale memory",
|
||||
@@ -312,7 +312,7 @@ func TestSparseRuntimeRebuildSyncsSourceAndRemovesStalePoints(t *testing.T) {
|
||||
if result.RestoredCount != 2 {
|
||||
t.Fatalf("expected restored_count=2, got %d", result.RestoredCount)
|
||||
}
|
||||
if _, ok := index.points[sparsePointID("bot-1", "bot-1:stale")]; ok {
|
||||
if _, ok := index.points[runtimePointID("bot-1", "bot-1:stale")]; ok {
|
||||
t.Fatal("expected stale qdrant point to be removed")
|
||||
}
|
||||
}
|
||||
@@ -326,7 +326,7 @@ func TestSparseRuntimeGetAllIncludesExplainStats(t *testing.T) {
|
||||
storefs.MemoryItem{
|
||||
ID: "bot-1:mem_1",
|
||||
Memory: "Ran likes tea",
|
||||
Hash: sparseRuntimeHash("Ran likes tea"),
|
||||
Hash: runtimeHash("Ran likes tea"),
|
||||
CreatedAt: "2026-03-13T09:00:00Z",
|
||||
UpdatedAt: "2026-03-13T09:00:00Z",
|
||||
},
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestTruncateSnippet_ASCII(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := TruncateSnippet("hello world", 5)
|
||||
if got != "hello..." {
|
||||
t.Fatalf("expected %q, got %q", "hello...", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_NoTruncation(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := TruncateSnippet("short", 100)
|
||||
if got != "short" {
|
||||
t.Fatalf("expected %q, got %q", "short", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_CJK(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := TruncateSnippet("你好世界啊", 3)
|
||||
if !utf8.ValidString(got) {
|
||||
t.Fatalf("result is not valid UTF-8: %q", got)
|
||||
}
|
||||
if got != "你好世..." {
|
||||
t.Fatalf("expected %q, got %q", "你好世...", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_Emoji(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := TruncateSnippet("😀😁😂🤣😃", 2)
|
||||
if !utf8.ValidString(got) {
|
||||
t.Fatalf("result is not valid UTF-8: %q", got)
|
||||
}
|
||||
if got != "😀😁..." {
|
||||
t.Fatalf("expected %q, got %q", "😀😁...", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSnippet_TrimWhitespace(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := TruncateSnippet(" hello ", 100)
|
||||
if got != "hello" {
|
||||
t.Fatalf("expected %q, got %q", "hello", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeduplicateItems(t *testing.T) {
|
||||
t.Parallel()
|
||||
items := []MemoryItem{
|
||||
{ID: "a", Memory: "first"},
|
||||
{ID: "b", Memory: "second"},
|
||||
{ID: "a", Memory: "duplicate"},
|
||||
}
|
||||
result := DeduplicateItems(items)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 items, got %d", len(result))
|
||||
}
|
||||
}
|
||||
@@ -61,6 +61,20 @@ func (*Service) ListMeta(_ context.Context) []ProviderMeta {
|
||||
Required: false,
|
||||
Example: "memory_sparse",
|
||||
},
|
||||
"context_target_items": {
|
||||
Type: "integer",
|
||||
Title: "Context Target Items",
|
||||
Description: "Target number of memory snippets to inject per chat turn. Defaults to 6.",
|
||||
Required: false,
|
||||
Example: 6,
|
||||
},
|
||||
"context_max_total_chars": {
|
||||
Type: "integer",
|
||||
Title: "Context Max Total Chars",
|
||||
Description: "Maximum total characters for all memory snippets combined. Defaults to 1800.",
|
||||
Required: false,
|
||||
Example: 1800,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -0,0 +1,320 @@
|
||||
package memllm
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"
	"unicode/utf8"

	sdk "github.com/memohai/twilight-ai/sdk"

	"github.com/memohai/memoh/internal/agent"
	adapters "github.com/memohai/memoh/internal/memory/adapters"
)
|
||||
|
||||
// Tunables for the memory LLM client.
const (
	// defaultTimeout bounds each model call when Config.Timeout is unset.
	defaultTimeout = 30 * time.Second
	// maxExtractFacts caps the facts returned by Extract.
	maxExtractFacts = 10
	// maxDecideActions caps the actions returned by Decide.
	maxDecideActions = 20
)
|
||||
|
||||
// Config holds model resolution details for the memory LLM.
type Config struct {
	ModelID    string        // model identifier, passed through to agent.CreateModel
	BaseURL    string        // endpoint base URL, passed through to agent.CreateModel
	APIKey     string        `json:"-"` // credential; json:"-" keeps it out of serialized config
	ClientType string        // client type selector, passed through to agent.CreateModel
	Timeout    time.Duration // per-call deadline; New substitutes defaultTimeout when <= 0
}
|
||||
|
||||
// Client implements adapters.LLM using the Twilight AI SDK.
// It holds no state besides its configuration; each call builds a fresh SDK
// model via model().
type Client struct {
	cfg Config // fixed at construction time by New
}
|
||||
|
||||
// New creates a memory LLM client.
|
||||
func New(cfg Config) *Client {
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = defaultTimeout
|
||||
}
|
||||
return &Client{cfg: cfg}
|
||||
}
|
||||
|
||||
func (c *Client) model() *sdk.Model {
|
||||
return agent.CreateModel(agent.ModelConfig{
|
||||
ModelID: c.cfg.ModelID,
|
||||
ClientType: c.cfg.ClientType,
|
||||
APIKey: c.cfg.APIKey,
|
||||
BaseURL: c.cfg.BaseURL,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) Extract(ctx context.Context, req adapters.ExtractRequest) (adapters.ExtractResponse, error) {
|
||||
if len(req.Messages) == 0 {
|
||||
return adapters.ExtractResponse{}, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, c.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
var sb strings.Builder
|
||||
for _, m := range req.Messages {
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
role := strings.TrimSpace(m.Role)
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
sb.WriteString(strings.ToUpper(role[:1]) + role[1:])
|
||||
sb.WriteString(": ")
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
transcript := strings.TrimSpace(sb.String())
|
||||
if transcript == "" {
|
||||
return adapters.ExtractResponse{}, nil
|
||||
}
|
||||
|
||||
systemPrompt := strings.ReplaceAll(agent.MemoryExtractPrompt, "{{today}}", time.Now().UTC().Format("2006-01-02"))
|
||||
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
sdk.WithModel(c.model()),
|
||||
sdk.WithSystem(systemPrompt),
|
||||
sdk.WithMessages([]sdk.Message{sdk.UserMessage(transcript)}),
|
||||
)
|
||||
if err != nil {
|
||||
return adapters.ExtractResponse{}, fmt.Errorf("extract: %w", err)
|
||||
}
|
||||
|
||||
facts := parseExtractResponse(result.Text)
|
||||
if len(facts) > maxExtractFacts {
|
||||
facts = facts[:maxExtractFacts]
|
||||
}
|
||||
return adapters.ExtractResponse{Facts: facts}, nil
|
||||
}
|
||||
|
||||
func (c *Client) Decide(ctx context.Context, req adapters.DecideRequest) (adapters.DecideResponse, error) {
|
||||
if len(req.Facts) == 0 {
|
||||
return adapters.DecideResponse{}, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, c.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
userMessage := buildUpdateUserMessage(req.Candidates, req.Facts)
|
||||
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
sdk.WithModel(c.model()),
|
||||
sdk.WithSystem(agent.MemoryUpdatePrompt),
|
||||
sdk.WithMessages([]sdk.Message{sdk.UserMessage(userMessage)}),
|
||||
)
|
||||
if err != nil {
|
||||
return adapters.DecideResponse{}, fmt.Errorf("decide: %w", err)
|
||||
}
|
||||
|
||||
actions := parseUpdateResponse(result.Text)
|
||||
if len(actions) > maxDecideActions {
|
||||
actions = actions[:maxDecideActions]
|
||||
}
|
||||
return adapters.DecideResponse{Actions: actions}, nil
|
||||
}
|
||||
|
||||
func (c *Client) Compact(ctx context.Context, req adapters.CompactRequest) (adapters.CompactResponse, error) {
|
||||
if len(req.Memories) == 0 {
|
||||
return adapters.CompactResponse{}, nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, c.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"memories": req.Memories,
|
||||
"target_count": req.TargetCount,
|
||||
"decay_days": req.DecayDays,
|
||||
})
|
||||
if err != nil {
|
||||
return adapters.CompactResponse{}, fmt.Errorf("compact: marshal input: %w", err)
|
||||
}
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
sdk.WithModel(c.model()),
|
||||
sdk.WithSystem(compactSystemPrompt),
|
||||
sdk.WithMessages([]sdk.Message{sdk.UserMessage(string(payload))}),
|
||||
)
|
||||
if err != nil {
|
||||
return adapters.CompactResponse{}, fmt.Errorf("compact: %w", err)
|
||||
}
|
||||
facts := parseJSONStringArray(result.Text)
|
||||
return adapters.CompactResponse{Facts: facts}, nil
|
||||
}
|
||||
|
||||
func (*Client) DetectLanguage(_ context.Context, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// buildUpdateUserMessage formats the Decide user message following Mem0's
// update prompt convention: current memory + retrieved facts in triple backticks.
//
// candidates are the existing memories the model may update or delete; facts
// are the newly extracted statements to reconcile against them. Both are
// embedded as JSON inside fenced blocks, followed by strict output
// instructions so the reply parses with parseUpdateResponse.
func buildUpdateUserMessage(candidates []adapters.CandidateMemory, facts []string) string {
	var sb strings.Builder

	if len(candidates) > 0 {
		sb.WriteString("Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n```\n")
		// Project candidates down to id/text pairs — the only fields the
		// update prompt expects the model to see.
		oldMem := make([]map[string]string, 0, len(candidates))
		for _, c := range candidates {
			oldMem = append(oldMem, map[string]string{
				"id":   c.ID,
				"text": c.Memory,
			})
		}
		// Marshaling a []map[string]string cannot fail; error deliberately ignored.
		raw, _ := json.MarshalIndent(oldMem, "", " ")
		sb.Write(raw)
		sb.WriteString("\n```\n\n")
	} else {
		sb.WriteString("Current memory is empty.\n\n")
	}

	sb.WriteString("The new retrieved facts are mentioned in the triple backticks. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n```\n")
	factsJSON, _ := json.Marshal(facts)
	sb.Write(factsJSON)
	sb.WriteString("\n```\n\n")

	// Output-format contract the model must follow; parsed by parseUpdateResponse.
	sb.WriteString(`You must return your response in the following JSON structure only:

{
    "memory" : [
        {
            "id" : " ",
            "text" : " ",
            "event" : " ",
            "old_memory" : " "
        }
    ]
}

Follow the instruction mentioned below:
- Do not return anything from the custom few shot prompts provided above.
- If the current memory is empty, then you have to add the new retrieved facts to the memory.
- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.
- If there is an addition, generate a new key and add the new memory corresponding to it.
- If there is a deletion, the memory key-value pair should be removed from the memory.
- If there is an update, the ID key should remain the same and only the value needs to be updated.

Do not return anything except the JSON format.
`)
	return sb.String()
}
|
||||
|
||||
// --- JSON parsing helpers ---
|
||||
|
||||
// parseExtractResponse parses the {"facts": [...]} response from Extract.
|
||||
func parseExtractResponse(text string) []string {
|
||||
text = extractJSONBlock(text)
|
||||
|
||||
var wrapper struct {
|
||||
Facts []string `json:"facts"`
|
||||
}
|
||||
if json.Unmarshal([]byte(text), &wrapper) == nil && len(wrapper.Facts) > 0 {
|
||||
return filterNonEmpty(wrapper.Facts)
|
||||
}
|
||||
|
||||
return parseJSONStringArray(text)
|
||||
}
|
||||
|
||||
func parseJSONStringArray(text string) []string {
|
||||
text = extractJSONBlock(text)
|
||||
var facts []string
|
||||
if json.Unmarshal([]byte(text), &facts) == nil {
|
||||
return filterNonEmpty(facts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateResponseEntry mirrors a single item in Mem0's {"memory": [...]} response.
type updateResponseEntry struct {
	ID        string `json:"id"`         // memory identifier (existing, or new for additions)
	Text      string `json:"text"`       // memory text after the event is applied
	Event     string `json:"event"`      // ADD / UPDATE / DELETE / NONE (NONE maps to NOOP downstream)
	OldMemory string `json:"old_memory"` // prior text, populated for UPDATE events
}
|
||||
|
||||
// parseUpdateResponse parses the {"memory": [...]} response from Decide.
|
||||
func parseUpdateResponse(text string) []adapters.DecisionAction {
|
||||
text = extractJSONBlock(text)
|
||||
|
||||
var wrapper struct {
|
||||
Memory []updateResponseEntry `json:"memory"`
|
||||
}
|
||||
if json.Unmarshal([]byte(text), &wrapper) == nil && len(wrapper.Memory) > 0 {
|
||||
actions := make([]adapters.DecisionAction, 0, len(wrapper.Memory))
|
||||
for _, entry := range wrapper.Memory {
|
||||
event := strings.ToUpper(strings.TrimSpace(entry.Event))
|
||||
if event == "NONE" {
|
||||
event = "NOOP"
|
||||
}
|
||||
actions = append(actions, adapters.DecisionAction{
|
||||
Event: event,
|
||||
ID: strings.TrimSpace(entry.ID),
|
||||
Text: strings.TrimSpace(entry.Text),
|
||||
OldMemory: strings.TrimSpace(entry.OldMemory),
|
||||
})
|
||||
}
|
||||
return actions
|
||||
}
|
||||
|
||||
var flat []adapters.DecisionAction
|
||||
if json.Unmarshal([]byte(text), &flat) == nil {
|
||||
return flat
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractJSONBlock strips markdown code fences and any leading prose so the
// remaining text starts at the first JSON object or array. It never fails:
// input with no recognizable JSON is returned trimmed but otherwise intact.
func extractJSONBlock(text string) string {
	trimmed := strings.TrimSpace(text)

	// Prefer an explicit ```json fence; otherwise accept a bare ``` fence.
	if _, rest, ok := strings.Cut(trimmed, "```json"); ok {
		if inner, _, closed := strings.Cut(rest, "```"); closed {
			trimmed = inner
		} else {
			trimmed = rest
		}
	} else if _, rest, ok := strings.Cut(trimmed, "```"); ok {
		if inner, _, closed := strings.Cut(rest, "```"); closed {
			trimmed = inner
		} else {
			trimmed = rest
		}
	}

	trimmed = strings.TrimSpace(trimmed)
	// Drop any leading prose ahead of the first '{' or '[' — whichever
	// appears earliest (IndexAny returns the first occurrence of either).
	if len(trimmed) > 0 && trimmed[0] != '{' && trimmed[0] != '[' {
		if i := strings.IndexAny(trimmed, "{["); i >= 0 {
			trimmed = trimmed[i:]
		}
	}
	return trimmed
}
|
||||
|
||||
// filterNonEmpty trims whitespace from every entry and drops the blanks,
// preserving the order of the remaining strings.
func filterNonEmpty(ss []string) []string {
	kept := make([]string, 0, len(ss))
	for _, raw := range ss {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}
|
||||
|
||||
// compactSystemPrompt instructs the model to merge, dedupe, and prune a list
// of memories toward a target count, replying with a JSON array of fact
// strings (decoded by parseJSONStringArray in Compact).
const compactSystemPrompt = `You are a memory compaction assistant. Given a list of memories and a target count, consolidate them into fewer, higher-quality entries.

Merge duplicate or overlapping facts. Drop obsolete or low-value entries. Keep the most important and recent information.

Return a JSON array of concise fact strings representing the compacted memory set.`
|
||||
@@ -0,0 +1,175 @@
|
||||
package memllm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||
)
|
||||
|
||||
// Verifies a plain JSON string array decodes into its elements in order.
func TestParseJSONStringArray_Valid(t *testing.T) {
	t.Parallel()
	result := parseJSONStringArray(`["fact one", "fact two"]`)
	if len(result) != 2 {
		t.Fatalf("expected 2 facts, got %d", len(result))
	}
	if result[0] != "fact one" || result[1] != "fact two" {
		t.Fatalf("unexpected facts: %v", result)
	}
}

// Verifies an empty JSON array yields zero facts.
func TestParseJSONStringArray_Empty(t *testing.T) {
	t.Parallel()
	result := parseJSONStringArray(`[]`)
	if len(result) != 0 {
		t.Fatalf("expected 0 facts, got %d", len(result))
	}
}

// Verifies a ```json code fence around the array is stripped before decoding.
func TestParseJSONStringArray_CodeFence(t *testing.T) {
	t.Parallel()
	input := "```json\n[\"hello\", \"world\"]\n```"
	result := parseJSONStringArray(input)
	if len(result) != 2 {
		t.Fatalf("expected 2 facts from code fence, got %d", len(result))
	}
}

// Verifies leading prose before the array is skipped.
func TestParseJSONStringArray_PrefixText(t *testing.T) {
	t.Parallel()
	input := "Here are the facts:\n[\"a\", \"b\"]"
	result := parseJSONStringArray(input)
	if len(result) != 2 {
		t.Fatalf("expected 2 facts with prefix text, got %d", len(result))
	}
}

// Verifies non-JSON input returns nil rather than erroring or panicking.
func TestParseJSONStringArray_Garbage(t *testing.T) {
	t.Parallel()
	result := parseJSONStringArray("this is not json at all")
	if result != nil {
		t.Fatalf("expected nil for garbage input, got %v", result)
	}
}

// Verifies empty and whitespace-only entries are filtered out.
func TestParseJSONStringArray_FiltersBlanks(t *testing.T) {
	t.Parallel()
	result := parseJSONStringArray(`["fact one", "", " ", "fact two"]`)
	if len(result) != 2 {
		t.Fatalf("expected 2 non-empty facts, got %d", len(result))
	}
}
|
||||
|
||||
// Verifies the canonical {"facts": [...]} wrapper form is decoded.
func TestParseExtractResponse_FactsWrapper(t *testing.T) {
	t.Parallel()
	input := `{"facts": ["Name is John", "Is a Software engineer"]}`
	result := parseExtractResponse(input)
	if len(result) != 2 {
		t.Fatalf("expected 2 facts, got %d", len(result))
	}
	if result[0] != "Name is John" {
		t.Fatalf("unexpected first fact: %q", result[0])
	}
}

// Verifies the fallback path: a bare string array without the wrapper.
func TestParseExtractResponse_BareArray(t *testing.T) {
	t.Parallel()
	input := `["fact one", "fact two"]`
	result := parseExtractResponse(input)
	if len(result) != 2 {
		t.Fatalf("expected 2 facts from bare array, got %d", len(result))
	}
}

// Verifies an empty facts wrapper yields zero facts.
func TestParseExtractResponse_EmptyFacts(t *testing.T) {
	t.Parallel()
	result := parseExtractResponse(`{"facts": []}`)
	if len(result) != 0 {
		t.Fatalf("expected 0 facts, got %d", len(result))
	}
}
|
||||
|
||||
// Verifies the Mem0 {"memory": [...]} form: all four event kinds parse,
// NONE is normalized to NOOP, and id/text/old_memory survive intact.
func TestParseUpdateResponse_Mem0Format(t *testing.T) {
	t.Parallel()
	input := `{"memory": [
		{"id": "0", "text": "User is a software engineer", "event": "NONE"},
		{"id": "1", "text": "Name is John", "event": "ADD"},
		{"id": "2", "text": "Loves cheese pizza", "event": "DELETE"},
		{"id": "3", "text": "Moved to Berlin", "event": "UPDATE", "old_memory": "Lives in Tokyo"}
	]}`
	result := parseUpdateResponse(input)
	if len(result) != 4 {
		t.Fatalf("expected 4 actions, got %d", len(result))
	}
	if result[0].Event != "NOOP" {
		t.Fatalf("expected NONE mapped to NOOP, got %q", result[0].Event)
	}
	if result[1].Event != "ADD" || result[1].Text != "Name is John" {
		t.Fatalf("unexpected ADD action: %+v", result[1])
	}
	if result[2].Event != "DELETE" || result[2].ID != "2" {
		t.Fatalf("unexpected DELETE action: %+v", result[2])
	}
	if result[3].Event != "UPDATE" || result[3].OldMemory != "Lives in Tokyo" {
		t.Fatalf("unexpected UPDATE action: %+v", result[3])
	}
}
|
||||
|
||||
// Verifies the fallback path: a bare array of actions without the wrapper.
func TestParseUpdateResponse_FlatArrayFallback(t *testing.T) {
	t.Parallel()
	input := `[{"event":"ADD","text":"User likes tea"},{"event":"NOOP"},{"event":"DELETE","id":"bot-1:mem_123"}]`
	result := parseUpdateResponse(input)
	if len(result) != 3 {
		t.Fatalf("expected 3 actions, got %d", len(result))
	}
}

// Verifies a code-fenced wrapper response still parses.
func TestParseUpdateResponse_WithCodeFence(t *testing.T) {
	t.Parallel()
	input := "```json\n{\"memory\": [{\"event\":\"ADD\",\"text\":\"foo\",\"id\":\"1\"}]}\n```"
	result := parseUpdateResponse(input)
	if len(result) != 1 {
		t.Fatalf("expected 1 action from code fence, got %d", len(result))
	}
}

// Verifies an empty memory array yields nil (no actions).
func TestParseUpdateResponse_EmptyMemory(t *testing.T) {
	t.Parallel()
	result := parseUpdateResponse(`{"memory": []}`)
	if result != nil {
		t.Fatalf("expected nil for empty memory array, got %v", result)
	}
}

// Verifies non-JSON input returns nil.
func TestParseUpdateResponse_Garbage(t *testing.T) {
	t.Parallel()
	result := parseUpdateResponse("not json")
	if result != nil {
		t.Fatalf("expected nil for garbage, got %v", result)
	}
}
|
||||
|
||||
// Verifies unfenced JSON passes through untouched.
func TestExtractJSONBlock_NoFence(t *testing.T) {
	t.Parallel()
	got := extractJSONBlock(`["a"]`)
	if got != `["a"]` {
		t.Fatalf("expected raw pass-through, got %q", got)
	}
}

// Verifies content inside a ```json fence is extracted.
func TestExtractJSONBlock_JSONFence(t *testing.T) {
	t.Parallel()
	got := extractJSONBlock("```json\n[\"a\"]\n```")
	if got != `["a"]` {
		t.Fatalf("expected extracted content, got %q", got)
	}
}

// Verifies content inside a plain ``` fence is extracted.
func TestExtractJSONBlock_PlainFence(t *testing.T) {
	t.Parallel()
	got := extractJSONBlock("```\n[\"a\"]\n```")
	if got != `["a"]` {
		t.Fatalf("expected extracted content, got %q", got)
	}
}
|
||||
|
||||
// Compile-time check that *Client satisfies the adapters.LLM interface.
var _ adapters.LLM = (*Client)(nil)
|
||||
@@ -0,0 +1,33 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
)
|
||||
|
||||
// NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given
|
||||
// provider configuration. Currently all embedding providers use the
|
||||
// OpenAI-compatible /embeddings endpoint (including Google-hosted models that
|
||||
// expose the same wire format), so we route everything through the OpenAI
|
||||
// embedding provider. If a future provider requires a different wire protocol,
|
||||
// add a branch here.
|
||||
func NewSDKEmbeddingModel(baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
httpClient := &http.Client{Timeout: timeout}
|
||||
|
||||
opts := []openaiembedding.Option{
|
||||
openaiembedding.WithAPIKey(apiKey),
|
||||
openaiembedding.WithHTTPClient(httpClient),
|
||||
}
|
||||
if baseURL != "" {
|
||||
opts = append(opts, openaiembedding.WithBaseURL(baseURL))
|
||||
}
|
||||
|
||||
p := openaiembedding.New(opts...)
|
||||
return p.EmbeddingModel(modelID)
|
||||
}
|
||||
+10
-38
@@ -1,11 +1,8 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -100,26 +97,18 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// testEmbeddingModel probes an embedding model by sending a minimal
|
||||
// request to the /embeddings endpoint.
|
||||
// testEmbeddingModel probes an embedding model by performing a minimal
|
||||
// embedding request via the Twilight SDK, verifying that the model is
|
||||
// reachable and functional rather than merely checking HTTP connectivity.
|
||||
func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) {
|
||||
body, _ := json.Marshal(map[string]any{"model": modelID, "input": "hello"})
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
baseURL+"/embeddings", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return TestResponse{Status: TestStatusError, Message: err.Error()}, nil
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
model := NewSDKEmbeddingModel(baseURL, apiKey, modelID, probeTimeout)
|
||||
client := sdk.NewClient()
|
||||
|
||||
start := time.Now()
|
||||
httpClient := &http.Client{Timeout: probeTimeout}
|
||||
// #nosec G704 -- baseURL comes from the configured provider endpoint that this health probe is expected to test.
|
||||
resp, err := httpClient.Do(req)
|
||||
_, err := client.Embed(ctx, "hello", sdk.WithEmbeddingModel(model))
|
||||
latency := time.Since(start).Milliseconds()
|
||||
|
||||
if err != nil {
|
||||
@@ -130,30 +119,13 @@ func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID
|
||||
Message: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
result, classifyErr := sdk.ClassifyProbeStatus(resp.StatusCode)
|
||||
if classifyErr != nil {
|
||||
return TestResponse{
|
||||
Status: TestStatusError,
|
||||
Reachable: true,
|
||||
LatencyMs: latency,
|
||||
Message: classifyErr.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
tr := TestResponse{
|
||||
return TestResponse{
|
||||
Status: TestStatusOK,
|
||||
Reachable: true,
|
||||
LatencyMs: latency,
|
||||
Message: result.Message,
|
||||
}
|
||||
if result.Supported {
|
||||
tr.Status = TestStatusOK
|
||||
} else {
|
||||
tr.Status = TestStatusModelNotSupported
|
||||
}
|
||||
return tr, nil
|
||||
Message: "embedding model is operational",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewSDKProvider creates a Twilight AI SDK Provider for the given client type.
|
||||
|
||||
Reference in New Issue
Block a user