mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
65b2797626
Move CreateModel, BuildReasoningOptions, ReasoningBudgetTokens and related types from internal/agent to internal/models as NewSDKChatModel, SDKModelConfig, etc. This eliminates duplicate ClientType constants and centralises all Twilight AI SDK instance creation in a single package. NewSDKEmbeddingModel now accepts a clientType parameter and dispatches to the native Google embedding provider for google-generative-ai, instead of always using the OpenAI-compatible endpoint.
322 lines
9.0 KiB
Go
322 lines
9.0 KiB
Go
package memllm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
|
|
"github.com/memohai/memoh/internal/agent"
|
|
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
|
"github.com/memohai/memoh/internal/models"
|
|
)
|
|
|
|
const (
|
|
defaultTimeout = 30 * time.Second
|
|
maxExtractFacts = 10
|
|
maxDecideActions = 20
|
|
)
|
|
|
|
// Config holds model resolution details for the memory LLM.
|
|
type Config struct {
|
|
ModelID string
|
|
BaseURL string
|
|
APIKey string `json:"-"`
|
|
ClientType string
|
|
Timeout time.Duration
|
|
}
|
|
|
|
// Client implements adapters.LLM using the Twilight AI SDK.
|
|
type Client struct {
|
|
cfg Config
|
|
}
|
|
|
|
// New creates a memory LLM client.
|
|
func New(cfg Config) *Client {
|
|
if cfg.Timeout <= 0 {
|
|
cfg.Timeout = defaultTimeout
|
|
}
|
|
return &Client{cfg: cfg}
|
|
}
|
|
|
|
func (c *Client) model() *sdk.Model {
|
|
return models.NewSDKChatModel(models.SDKModelConfig{
|
|
ModelID: c.cfg.ModelID,
|
|
ClientType: c.cfg.ClientType,
|
|
APIKey: c.cfg.APIKey,
|
|
BaseURL: c.cfg.BaseURL,
|
|
})
|
|
}
|
|
|
|
func (c *Client) Extract(ctx context.Context, req adapters.ExtractRequest) (adapters.ExtractResponse, error) {
|
|
if len(req.Messages) == 0 {
|
|
return adapters.ExtractResponse{}, nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, c.cfg.Timeout)
|
|
defer cancel()
|
|
|
|
var sb strings.Builder
|
|
for _, m := range req.Messages {
|
|
content := strings.TrimSpace(m.Content)
|
|
if content == "" {
|
|
continue
|
|
}
|
|
role := strings.TrimSpace(m.Role)
|
|
if role == "" {
|
|
role = "user"
|
|
}
|
|
sb.WriteString(strings.ToUpper(role[:1]) + role[1:])
|
|
sb.WriteString(": ")
|
|
sb.WriteString(content)
|
|
sb.WriteString("\n")
|
|
}
|
|
transcript := strings.TrimSpace(sb.String())
|
|
if transcript == "" {
|
|
return adapters.ExtractResponse{}, nil
|
|
}
|
|
|
|
systemPrompt := strings.ReplaceAll(agent.MemoryExtractPrompt, "{{today}}", time.Now().UTC().Format("2006-01-02"))
|
|
|
|
result, err := sdk.GenerateTextResult(ctx,
|
|
sdk.WithModel(c.model()),
|
|
sdk.WithSystem(systemPrompt),
|
|
sdk.WithMessages([]sdk.Message{sdk.UserMessage(transcript)}),
|
|
)
|
|
if err != nil {
|
|
return adapters.ExtractResponse{}, fmt.Errorf("extract: %w", err)
|
|
}
|
|
|
|
facts := parseExtractResponse(result.Text)
|
|
if len(facts) > maxExtractFacts {
|
|
facts = facts[:maxExtractFacts]
|
|
}
|
|
return adapters.ExtractResponse{Facts: facts}, nil
|
|
}
|
|
|
|
func (c *Client) Decide(ctx context.Context, req adapters.DecideRequest) (adapters.DecideResponse, error) {
|
|
if len(req.Facts) == 0 {
|
|
return adapters.DecideResponse{}, nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, c.cfg.Timeout)
|
|
defer cancel()
|
|
|
|
userMessage := buildUpdateUserMessage(req.Candidates, req.Facts)
|
|
|
|
result, err := sdk.GenerateTextResult(ctx,
|
|
sdk.WithModel(c.model()),
|
|
sdk.WithSystem(agent.MemoryUpdatePrompt),
|
|
sdk.WithMessages([]sdk.Message{sdk.UserMessage(userMessage)}),
|
|
)
|
|
if err != nil {
|
|
return adapters.DecideResponse{}, fmt.Errorf("decide: %w", err)
|
|
}
|
|
|
|
actions := parseUpdateResponse(result.Text)
|
|
if len(actions) > maxDecideActions {
|
|
actions = actions[:maxDecideActions]
|
|
}
|
|
return adapters.DecideResponse{Actions: actions}, nil
|
|
}
|
|
|
|
func (c *Client) Compact(ctx context.Context, req adapters.CompactRequest) (adapters.CompactResponse, error) {
|
|
if len(req.Memories) == 0 {
|
|
return adapters.CompactResponse{}, nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, c.cfg.Timeout)
|
|
defer cancel()
|
|
|
|
payload, err := json.Marshal(map[string]any{
|
|
"memories": req.Memories,
|
|
"target_count": req.TargetCount,
|
|
"decay_days": req.DecayDays,
|
|
})
|
|
if err != nil {
|
|
return adapters.CompactResponse{}, fmt.Errorf("compact: marshal input: %w", err)
|
|
}
|
|
result, err := sdk.GenerateTextResult(ctx,
|
|
sdk.WithModel(c.model()),
|
|
sdk.WithSystem(compactSystemPrompt),
|
|
sdk.WithMessages([]sdk.Message{sdk.UserMessage(string(payload))}),
|
|
)
|
|
if err != nil {
|
|
return adapters.CompactResponse{}, fmt.Errorf("compact: %w", err)
|
|
}
|
|
facts := parseJSONStringArray(result.Text)
|
|
return adapters.CompactResponse{Facts: facts}, nil
|
|
}
|
|
|
|
func (*Client) DetectLanguage(_ context.Context, _ string) (string, error) {
|
|
return "", nil
|
|
}
|
|
|
|
// buildUpdateUserMessage formats the Decide user message following Mem0's
|
|
// update prompt convention: current memory + retrieved facts in triple backticks.
|
|
func buildUpdateUserMessage(candidates []adapters.CandidateMemory, facts []string) string {
|
|
var sb strings.Builder
|
|
|
|
if len(candidates) > 0 {
|
|
sb.WriteString("Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n```\n")
|
|
oldMem := make([]map[string]string, 0, len(candidates))
|
|
for _, c := range candidates {
|
|
oldMem = append(oldMem, map[string]string{
|
|
"id": c.ID,
|
|
"text": c.Memory,
|
|
})
|
|
}
|
|
raw, _ := json.MarshalIndent(oldMem, "", " ")
|
|
sb.Write(raw)
|
|
sb.WriteString("\n```\n\n")
|
|
} else {
|
|
sb.WriteString("Current memory is empty.\n\n")
|
|
}
|
|
|
|
sb.WriteString("The new retrieved facts are mentioned in the triple backticks. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n```\n")
|
|
factsJSON, _ := json.Marshal(facts)
|
|
sb.Write(factsJSON)
|
|
sb.WriteString("\n```\n\n")
|
|
|
|
sb.WriteString(`You must return your response in the following JSON structure only:
|
|
|
|
{
|
|
"memory" : [
|
|
{
|
|
"id" : " ",
|
|
"text" : " ",
|
|
"event" : " ",
|
|
"old_memory" : " "
|
|
}
|
|
]
|
|
}
|
|
|
|
Follow the instruction mentioned below:
|
|
- Do not return anything from the custom few shot prompts provided above.
|
|
- If the current memory is empty, then you have to add the new retrieved facts to the memory.
|
|
- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.
|
|
- If there is an addition, generate a new key and add the new memory corresponding to it.
|
|
- If there is a deletion, the memory key-value pair should be removed from the memory.
|
|
- If there is an update, the ID key should remain the same and only the value needs to be updated.
|
|
|
|
Do not return anything except the JSON format.
|
|
`)
|
|
return sb.String()
|
|
}
|
|
|
|
// --- JSON parsing helpers ---
|
|
|
|
// parseExtractResponse parses the {"facts": [...]} response from Extract.
|
|
func parseExtractResponse(text string) []string {
|
|
text = extractJSONBlock(text)
|
|
|
|
var wrapper struct {
|
|
Facts []string `json:"facts"`
|
|
}
|
|
if json.Unmarshal([]byte(text), &wrapper) == nil && len(wrapper.Facts) > 0 {
|
|
return filterNonEmpty(wrapper.Facts)
|
|
}
|
|
|
|
return parseJSONStringArray(text)
|
|
}
|
|
|
|
func parseJSONStringArray(text string) []string {
|
|
text = extractJSONBlock(text)
|
|
var facts []string
|
|
if json.Unmarshal([]byte(text), &facts) == nil {
|
|
return filterNonEmpty(facts)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// updateResponseEntry mirrors a single item in Mem0's {"memory": [...]} response.
|
|
type updateResponseEntry struct {
|
|
ID string `json:"id"`
|
|
Text string `json:"text"`
|
|
Event string `json:"event"`
|
|
OldMemory string `json:"old_memory"`
|
|
}
|
|
|
|
// parseUpdateResponse parses the {"memory": [...]} response from Decide.
|
|
func parseUpdateResponse(text string) []adapters.DecisionAction {
|
|
text = extractJSONBlock(text)
|
|
|
|
var wrapper struct {
|
|
Memory []updateResponseEntry `json:"memory"`
|
|
}
|
|
if json.Unmarshal([]byte(text), &wrapper) == nil && len(wrapper.Memory) > 0 {
|
|
actions := make([]adapters.DecisionAction, 0, len(wrapper.Memory))
|
|
for _, entry := range wrapper.Memory {
|
|
event := strings.ToUpper(strings.TrimSpace(entry.Event))
|
|
if event == "NONE" {
|
|
event = "NOOP"
|
|
}
|
|
actions = append(actions, adapters.DecisionAction{
|
|
Event: event,
|
|
ID: strings.TrimSpace(entry.ID),
|
|
Text: strings.TrimSpace(entry.Text),
|
|
OldMemory: strings.TrimSpace(entry.OldMemory),
|
|
})
|
|
}
|
|
return actions
|
|
}
|
|
|
|
var flat []adapters.DecisionAction
|
|
if json.Unmarshal([]byte(text), &flat) == nil {
|
|
return flat
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func extractJSONBlock(text string) string {
|
|
text = strings.TrimSpace(text)
|
|
if start := strings.Index(text, "```json"); start >= 0 {
|
|
text = text[start+7:]
|
|
if end := strings.Index(text, "```"); end >= 0 {
|
|
text = text[:end]
|
|
}
|
|
} else if start := strings.Index(text, "```"); start >= 0 {
|
|
text = text[start+3:]
|
|
if end := strings.Index(text, "```"); end >= 0 {
|
|
text = text[:end]
|
|
}
|
|
}
|
|
text = strings.TrimSpace(text)
|
|
if len(text) > 0 && text[0] != '{' && text[0] != '[' {
|
|
braceIdx := strings.IndexByte(text, '{')
|
|
bracketIdx := strings.IndexByte(text, '[')
|
|
cutIdx := -1
|
|
switch {
|
|
case braceIdx >= 0 && bracketIdx >= 0:
|
|
cutIdx = min(braceIdx, bracketIdx)
|
|
case braceIdx >= 0:
|
|
cutIdx = braceIdx
|
|
case bracketIdx >= 0:
|
|
cutIdx = bracketIdx
|
|
}
|
|
if cutIdx >= 0 {
|
|
text = text[cutIdx:]
|
|
}
|
|
}
|
|
return text
|
|
}
|
|
|
|
func filterNonEmpty(ss []string) []string {
|
|
out := make([]string, 0, len(ss))
|
|
for _, s := range ss {
|
|
s = strings.TrimSpace(s)
|
|
if s != "" {
|
|
out = append(out, s)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
const compactSystemPrompt = `You are a memory compaction assistant. Given a list of memories and a target count, consolidate them into fewer, higher-quality entries.
|
|
|
|
Merge duplicate or overlapping facts. Drop obsolete or low-value entries. Keep the most important and recent information.
|
|
|
|
Return a JSON array of concise fact strings representing the compacted memory set.`
|