Files
Memoh/internal/memory/llm_client.go
T
BBQ 83b6ee608c refactor: bind container lifecycle to bot and improve schedule trigger flow
- Add SetupBotContainer to ContainerLifecycle interface so containers
  are automatically created when a bot is created, matching the existing
  cleanup-on-delete behavior.
- Refactor schedule tools to use bot-scoped API paths and pass identity
  context for proper authorization.
- Introduce dedicated trigger-schedule endpoint in chat resolver with
  explicit schedule payload instead of reusing the generic chat path.
- Generate short-lived JWT tokens for schedule trigger callbacks with
  resolved bot owner identity.
- Validate required parameters in NewLLMClient and NewOpenAIEmbedder
  constructors, returning errors instead of falling back to defaults.
- Add unit tests for schedule token generation and chat resolver.
2026-02-07 12:04:37 +08:00

297 lines
7.4 KiB
Go

package memory
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
)
type LLMClient struct {
baseURL string
apiKey string
model string
logger *slog.Logger
http *http.Client
}
func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) (*LLMClient, error) {
if strings.TrimSpace(baseURL) == "" {
return nil, fmt.Errorf("llm client: base url is required")
}
if strings.TrimSpace(apiKey) == "" {
return nil, fmt.Errorf("llm client: api key is required")
}
if strings.TrimSpace(model) == "" {
return nil, fmt.Errorf("llm client: model is required")
}
if log == nil {
log = slog.Default()
}
if timeout <= 0 {
timeout = 10 * time.Second
}
return &LLMClient{
baseURL: strings.TrimRight(baseURL, "/"),
apiKey: apiKey,
model: model,
logger: log.With(slog.String("client", "llm")),
http: &http.Client{
Timeout: timeout,
},
}, nil
}
func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
if len(req.Messages) == 0 {
return ExtractResponse{}, fmt.Errorf("messages is required")
}
parsedMessages := parseMessages(formatMessages(req.Messages))
systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages)
content, err := c.callChat(ctx, []chatMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
})
if err != nil {
return ExtractResponse{}, err
}
var parsed ExtractResponse
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
return ExtractResponse{}, err
}
return parsed, nil
}
func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) {
if len(req.Facts) == 0 {
return DecideResponse{}, fmt.Errorf("facts is required")
}
retrieved := make([]map[string]string, 0, len(req.Candidates))
for _, candidate := range req.Candidates {
retrieved = append(retrieved, map[string]string{
"id": candidate.ID,
"text": candidate.Memory,
})
}
prompt := getUpdateMemoryMessages(retrieved, req.Facts)
content, err := c.callChat(ctx, []chatMessage{
{Role: "user", Content: prompt},
})
if err != nil {
return DecideResponse{}, err
}
cleaned := removeCodeBlocks(content)
var memoryItems []map[string]any
// Try parsing as object first
var raw map[string]any
if err := json.Unmarshal([]byte(cleaned), &raw); err == nil {
memoryItems = normalizeMemoryItems(raw["memory"])
} else {
// If object parsing fails, try parsing as array directly
var arr []any
if err := json.Unmarshal([]byte(cleaned), &arr); err != nil {
return DecideResponse{}, fmt.Errorf("failed to parse LLM response: %w", err)
}
memoryItems = normalizeMemoryItems(arr)
}
actions := make([]DecisionAction, 0, len(memoryItems))
for _, item := range memoryItems {
event := strings.ToUpper(asString(item["event"]))
if event == "" {
event = "ADD"
}
if event == "NONE" {
continue
}
text := asString(item["text"])
if text == "" {
text = asString(item["fact"])
}
if strings.TrimSpace(text) == "" {
continue
}
actions = append(actions, DecisionAction{
Event: event,
ID: normalizeID(item["id"]),
Text: text,
OldMemory: asString(item["old_memory"]),
})
}
return DecideResponse{Actions: actions}, nil
}
func (c *LLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
if strings.TrimSpace(text) == "" {
return "", fmt.Errorf("text is required")
}
systemPrompt, userPrompt := getLanguageDetectionMessages(text)
content, err := c.callChat(ctx, []chatMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
})
if err != nil {
return "", err
}
var parsed struct {
Language string `json:"language"`
}
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
return "", err
}
lang := strings.ToLower(strings.TrimSpace(parsed.Language))
if !isAllowedLanguageCode(lang) {
return "", fmt.Errorf("unsupported language code: %s", lang)
}
return lang, nil
}
type chatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type chatRequest struct {
Model string `json:"model"`
Temperature float32 `json:"temperature"`
ResponseFormat map[string]string `json:"response_format,omitempty"`
Messages []chatMessage `json:"messages"`
}
type chatResponse struct {
Choices []struct {
Message chatMessage `json:"message"`
} `json:"choices"`
}
func (c *LLMClient) callChat(ctx context.Context, messages []chatMessage) (string, error) {
if c.apiKey == "" {
return "", fmt.Errorf("llm api key is required")
}
body, err := json.Marshal(chatRequest{
Model: c.model,
Temperature: 0,
ResponseFormat: map[string]string{
"type": "json_object",
},
Messages: messages,
})
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.apiKey)
resp, err := c.http.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
b, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("llm error: %s", strings.TrimSpace(string(b)))
}
var parsed chatResponse
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
return "", err
}
if len(parsed.Choices) == 0 || parsed.Choices[0].Message.Content == "" {
return "", fmt.Errorf("llm response missing content")
}
return parsed.Choices[0].Message.Content, nil
}
func formatMessages(messages []Message) []string {
formatted := make([]string, 0, len(messages))
for _, message := range messages {
formatted = append(formatted, fmt.Sprintf("%s: %s", message.Role, message.Content))
}
return formatted
}
func asString(value any) string {
switch typed := value.(type) {
case string:
return typed
case float64:
if typed == float64(int64(typed)) {
return fmt.Sprintf("%d", int64(typed))
}
return fmt.Sprintf("%f", typed)
case int:
return fmt.Sprintf("%d", typed)
case int64:
return fmt.Sprintf("%d", typed)
default:
return ""
}
}
func normalizeID(value any) string {
id := asString(value)
if id == "" {
return ""
}
return id
}
func normalizeMemoryItems(value any) []map[string]any {
switch typed := value.(type) {
case []any:
items := make([]map[string]any, 0, len(typed))
for _, item := range typed {
if m, ok := item.(map[string]any); ok {
items = append(items, m)
}
}
return items
case map[string]any:
// If this map looks like a single item, wrap it.
if _, hasText := typed["text"]; hasText {
return []map[string]any{typed}
}
if _, hasFact := typed["fact"]; hasFact {
return []map[string]any{typed}
}
if _, hasEvent := typed["event"]; hasEvent {
return []map[string]any{typed}
}
// Otherwise treat as map of items.
items := make([]map[string]any, 0, len(typed))
for _, item := range typed {
if m, ok := item.(map[string]any); ok {
items = append(items, m)
}
}
return items
default:
return nil
}
}
func isAllowedLanguageCode(code string) bool {
switch strings.ToLower(strings.TrimSpace(code)) {
case "ar", "bg", "ca", "cjk", "ckb", "da", "de", "el", "en", "es", "eu",
"fa", "fi", "fr", "ga", "gl", "hi", "hr", "hu", "hy", "id", "in",
"it", "nl", "no", "pl", "pt", "ro", "ru", "sv", "tr":
return true
default:
return false
}
}