feat: provider management & chat

This commit is contained in:
Acbox
2026-01-26 23:06:54 +08:00
parent 35a8927a79
commit da6a264699
28 changed files with 4699 additions and 63 deletions
+129
View File
@@ -0,0 +1,129 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/memohai/memoh/internal/chat"
)
// ProviderLLMClient uses chat.Provider to make LLM calls for memory operations
type ProviderLLMClient struct {
	provider chat.Provider // backend used to issue chat completion requests
	model    string        // model identifier sent with every request
}
// NewProviderLLMClient builds an LLM client backed by the given chat.Provider.
// An empty model falls back to the default memory-extraction model.
func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient {
	c := &ProviderLLMClient{provider: provider, model: model}
	if c.model == "" {
		c.model = "gpt-4.1-nano-2025-04-14"
	}
	return c
}
// Extract asks the underlying provider to pull discrete facts out of the
// supplied conversation messages; the model is expected to reply with a
// JSON object matching ExtractResponse.
func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
	if len(req.Messages) == 0 {
		return ExtractResponse{}, fmt.Errorf("messages is required")
	}

	systemPrompt, userPrompt := getFactRetrievalMessages(parseMessages(formatMessages(req.Messages)))

	// Deterministic output: zero temperature plus JSON response mode.
	zero := float32(0)
	chatReq := chat.Request{
		Model:          c.model,
		Temperature:    &zero,
		ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
		Messages: []chat.Message{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: userPrompt},
		},
	}

	result, err := c.provider.Chat(ctx, chatReq)
	if err != nil {
		return ExtractResponse{}, err
	}

	// Strip any markdown fencing before decoding the model's JSON reply.
	var out ExtractResponse
	if err := json.Unmarshal([]byte(removeCodeBlocks(result.Message.Content)), &out); err != nil {
		return ExtractResponse{}, err
	}
	return out, nil
}
// Decide compares extracted facts against existing candidate memories and
// returns the memory actions (e.g. ADD/UPDATE/DELETE) the model chooses.
// Items the model marks NONE, or whose text is blank, are dropped.
func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) {
	if len(req.Facts) == 0 {
		return DecideResponse{}, fmt.Errorf("facts is required")
	}

	// Present the existing memories as id/text pairs for the prompt template.
	existing := make([]map[string]string, 0, len(req.Candidates))
	for _, cand := range req.Candidates {
		existing = append(existing, map[string]string{
			"id":   cand.ID,
			"text": cand.Memory,
		})
	}
	prompt := getUpdateMemoryMessages(existing, req.Facts)

	// Deterministic output: zero temperature plus JSON response mode.
	zero := float32(0)
	result, err := c.provider.Chat(ctx, chat.Request{
		Model:          c.model,
		Temperature:    &zero,
		ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
		Messages:       []chat.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return DecideResponse{}, err
	}

	var raw map[string]interface{}
	if err := json.Unmarshal([]byte(removeCodeBlocks(result.Message.Content)), &raw); err != nil {
		return DecideResponse{}, err
	}

	items := normalizeMemoryItems(raw["memory"])
	actions := make([]DecisionAction, 0, len(items))
	for _, item := range items {
		event := strings.ToUpper(asString(item["event"]))
		switch event {
		case "":
			event = "ADD" // missing event defaults to an addition
		case "NONE":
			continue // model explicitly chose no action
		}

		// The model may put the memory text under "text" or "fact".
		text := asString(item["text"])
		if text == "" {
			text = asString(item["fact"])
		}
		if strings.TrimSpace(text) == "" {
			continue
		}

		actions = append(actions, DecisionAction{
			Event:     event,
			ID:        normalizeID(item["id"]),
			Text:      text,
			OldMemory: asString(item["old_memory"]),
		})
	}
	return DecideResponse{Actions: actions}, nil
}
+24 -25
View File
@@ -16,20 +16,20 @@ import (
)
type Service struct {
llm *LLMClient
embedder embeddings.Embedder
store *QdrantStore
resolver *embeddings.Resolver
llm LLM
embedder embeddings.Embedder
store *QdrantStore
resolver *embeddings.Resolver
defaultTextModelID string
defaultMultimodalModelID string
}
func NewService(llm *LLMClient, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
func NewService(llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
return &Service{
llm: llm,
embedder: embedder,
store: store,
resolver: resolver,
llm: llm,
embedder: embedder,
store: store,
resolver: resolver,
defaultTextModelID: defaultTextModelID,
defaultMultimodalModelID: defaultMultimodalModelID,
}
@@ -138,10 +138,10 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
}
var (
vector []float32
store *QdrantStore
vector []float32
store *QdrantStore
vectorName string
err error
err error
)
if modality == embeddings.TypeMultimodal {
if s.resolver == nil {
@@ -237,10 +237,10 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
metadata["model_id"] = result.Model
}
if err := s.store.Upsert(ctx, []qdrantPoint{{
ID: id,
Vector: result.Embedding,
ID: id,
Vector: result.Embedding,
VectorName: vectorName,
Payload: payload,
Payload: payload,
}}); err != nil {
return EmbedUpsertResponse{}, err
}
@@ -280,10 +280,10 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
return MemoryItem{}, err
}
if err := s.store.Upsert(ctx, []qdrantPoint{{
ID: req.MemoryID,
Vector: vector,
ID: req.MemoryID,
Vector: vector,
VectorName: s.vectorNameForText(),
Payload: payload,
Payload: payload,
}}); err != nil {
return MemoryItem{}, err
}
@@ -411,10 +411,10 @@ func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]
id := uuid.NewString()
payload := buildPayload(text, filters, metadata, "")
if err := s.store.Upsert(ctx, []qdrantPoint{{
ID: id,
Vector: vector,
ID: id,
Vector: vector,
VectorName: s.vectorNameForText(),
Payload: payload,
Payload: payload,
}}); err != nil {
return MemoryItem{}, err
}
@@ -448,10 +448,10 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
return MemoryItem{}, err
}
if err := s.store.Upsert(ctx, []qdrantPoint{{
ID: id,
Vector: vector,
ID: id,
Vector: vector,
VectorName: s.vectorNameForText(),
Payload: payload,
Payload: payload,
}}); err != nil {
return MemoryItem{}, err
}
@@ -756,4 +756,3 @@ func normalizeScore(score, minScore, maxScore float64) float64 {
}
return (score - minScore) / (maxScore - minScore)
}
+8
View File
@@ -1,5 +1,13 @@
package memory
import "context"
// LLM is the interface for LLM operations needed by memory service
type LLM interface {
	// Extract derives discrete facts from the messages in req.
	Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
	// Decide chooses memory actions given extracted facts and candidate memories.
	Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`