mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: provider management & chat
This commit is contained in:
@@ -0,0 +1,129 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/chat"
|
||||
)
|
||||
|
||||
// ProviderLLMClient uses chat.Provider to make LLM calls for memory operations
// (fact extraction and memory-update decisions). It satisfies the package's
// LLM interface.
type ProviderLLMClient struct {
	provider chat.Provider // backend that executes the chat completions
	model    string        // model identifier sent with every request
}
|
||||
|
||||
// NewProviderLLMClient creates a new LLM client that uses chat.Provider
|
||||
func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient {
|
||||
if model == "" {
|
||||
model = "gpt-4.1-nano-2025-04-14"
|
||||
}
|
||||
return &ProviderLLMClient{
|
||||
provider: provider,
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract extracts facts from messages using the provider
|
||||
func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
|
||||
if len(req.Messages) == 0 {
|
||||
return ExtractResponse{}, fmt.Errorf("messages is required")
|
||||
}
|
||||
|
||||
parsedMessages := parseMessages(formatMessages(req.Messages))
|
||||
systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages)
|
||||
|
||||
// Call provider with JSON mode
|
||||
temp := float32(0)
|
||||
result, err := c.provider.Chat(ctx, chat.Request{
|
||||
Model: c.model,
|
||||
Temperature: &temp,
|
||||
ResponseFormat: &chat.ResponseFormat{
|
||||
Type: "json_object",
|
||||
},
|
||||
Messages: []chat.Message{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return ExtractResponse{}, err
|
||||
}
|
||||
|
||||
content := result.Message.Content
|
||||
var parsed ExtractResponse
|
||||
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
|
||||
return ExtractResponse{}, err
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// Decide decides what actions to take based on facts and existing memories
|
||||
func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) {
|
||||
if len(req.Facts) == 0 {
|
||||
return DecideResponse{}, fmt.Errorf("facts is required")
|
||||
}
|
||||
|
||||
retrieved := make([]map[string]string, 0, len(req.Candidates))
|
||||
for _, candidate := range req.Candidates {
|
||||
retrieved = append(retrieved, map[string]string{
|
||||
"id": candidate.ID,
|
||||
"text": candidate.Memory,
|
||||
})
|
||||
}
|
||||
|
||||
prompt := getUpdateMemoryMessages(retrieved, req.Facts)
|
||||
|
||||
// Call provider with JSON mode
|
||||
temp := float32(0)
|
||||
result, err := c.provider.Chat(ctx, chat.Request{
|
||||
Model: c.model,
|
||||
Temperature: &temp,
|
||||
ResponseFormat: &chat.ResponseFormat{
|
||||
Type: "json_object",
|
||||
},
|
||||
Messages: []chat.Message{
|
||||
{Role: "user", Content: prompt},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return DecideResponse{}, err
|
||||
}
|
||||
|
||||
content := result.Message.Content
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &raw); err != nil {
|
||||
return DecideResponse{}, err
|
||||
}
|
||||
|
||||
memoryItems := normalizeMemoryItems(raw["memory"])
|
||||
actions := make([]DecisionAction, 0, len(memoryItems))
|
||||
for _, item := range memoryItems {
|
||||
event := strings.ToUpper(asString(item["event"]))
|
||||
if event == "" {
|
||||
event = "ADD"
|
||||
}
|
||||
if event == "NONE" {
|
||||
continue
|
||||
}
|
||||
|
||||
text := asString(item["text"])
|
||||
if text == "" {
|
||||
text = asString(item["fact"])
|
||||
}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
actions = append(actions, DecisionAction{
|
||||
Event: event,
|
||||
ID: normalizeID(item["id"]),
|
||||
Text: text,
|
||||
OldMemory: asString(item["old_memory"]),
|
||||
})
|
||||
}
|
||||
return DecideResponse{Actions: actions}, nil
|
||||
}
|
||||
|
||||
+24
-25
@@ -16,20 +16,20 @@ import (
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
llm *LLMClient
|
||||
embedder embeddings.Embedder
|
||||
store *QdrantStore
|
||||
resolver *embeddings.Resolver
|
||||
llm LLM
|
||||
embedder embeddings.Embedder
|
||||
store *QdrantStore
|
||||
resolver *embeddings.Resolver
|
||||
defaultTextModelID string
|
||||
defaultMultimodalModelID string
|
||||
}
|
||||
|
||||
func NewService(llm *LLMClient, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
func NewService(llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
return &Service{
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
store: store,
|
||||
resolver: resolver,
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
store: store,
|
||||
resolver: resolver,
|
||||
defaultTextModelID: defaultTextModelID,
|
||||
defaultMultimodalModelID: defaultMultimodalModelID,
|
||||
}
|
||||
@@ -138,10 +138,10 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
}
|
||||
|
||||
var (
|
||||
vector []float32
|
||||
store *QdrantStore
|
||||
vector []float32
|
||||
store *QdrantStore
|
||||
vectorName string
|
||||
err error
|
||||
err error
|
||||
)
|
||||
if modality == embeddings.TypeMultimodal {
|
||||
if s.resolver == nil {
|
||||
@@ -237,10 +237,10 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
|
||||
metadata["model_id"] = result.Model
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: result.Embedding,
|
||||
ID: id,
|
||||
Vector: result.Embedding,
|
||||
VectorName: vectorName,
|
||||
Payload: payload,
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
return EmbedUpsertResponse{}, err
|
||||
}
|
||||
@@ -280,10 +280,10 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: req.MemoryID,
|
||||
Vector: vector,
|
||||
ID: req.MemoryID,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
@@ -411,10 +411,10 @@ func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]
|
||||
id := uuid.NewString()
|
||||
payload := buildPayload(text, filters, metadata, "")
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
@@ -448,10 +448,10 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
@@ -756,4 +756,3 @@ func normalizeScore(score, minScore, maxScore float64) float64 {
|
||||
}
|
||||
return (score - minScore) / (maxScore - minScore)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
package memory
|
||||
|
||||
import "context"
|
||||
|
||||
// LLM is the interface for LLM operations needed by memory service.
// ProviderLLMClient is one implementation.
type LLM interface {
	// Extract derives facts from the messages in req.
	Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
	// Decide chooses memory actions for the facts in req.
	Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
|
||||
Reference in New Issue
Block a user