Files
Memoh/internal/memory/llm_client.go
T
BBQ 5a35ef34ac feat: channel gateway implementation and multi-bot refactor
- Refactor channel manager with support for Sender/Receiver interfaces and hot-swappable adapters.
- Implement identity routing and pre-authentication logic for inbound messages.
- Update database schema to support bot pre-auth keys and extended channel session metadata.
- Add Telegram and Feishu channel configuration and adapter enhancements.
- Update Swagger documentation and internal handlers for channel management.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-06 14:41:54 +08:00

295 lines
7.3 KiB
Go

package memory
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"time"
)
// LLMClient issues chat completion requests (POST baseURL + "/chat/completions")
// on behalf of the memory package.
type LLMClient struct {
	// baseURL is the API root, apiKey is sent as a Bearer token and model is
	// the model name placed in every request body.
	baseURL, apiKey, model string

	// logger is the logger handed to NewLLMClient, tagged with client=llm.
	logger *slog.Logger
	// http performs the requests; NewLLMClient sets its Timeout.
	http *http.Client
}
// NewLLMClient builds an LLMClient, filling in defaults for zero arguments.
//
// A nil log falls back to slog.Default(), an empty baseURL to
// "https://api.openai.com/v1", an empty model to "gpt-4.1-nano" and a
// non-positive timeout to ten seconds. Trailing slashes are stripped from
// baseURL, and the logger is tagged with client=llm.
func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) *LLMClient {
	const (
		defaultBaseURL = "https://api.openai.com/v1"
		defaultModel   = "gpt-4.1-nano"
		defaultTimeout = 10 * time.Second
	)

	if log == nil {
		log = slog.Default()
	}
	if baseURL == "" {
		baseURL = defaultBaseURL
	}
	if model == "" {
		model = defaultModel
	}
	if timeout <= 0 {
		timeout = defaultTimeout
	}

	client := &LLMClient{
		baseURL: strings.TrimRight(baseURL, "/"),
		apiKey:  apiKey,
		model:   model,
		logger:  log.With(slog.String("client", "llm")),
		http:    &http.Client{Timeout: timeout},
	}
	return client
}
// Extract sends the conversation in req.Messages to the model and decodes the
// reply into an ExtractResponse.
//
// The system and user prompts come from getFactRetrievalMessages, applied to
// the formatted and parsed messages. The reply is passed through
// removeCodeBlocks before JSON decoding. An error is returned when
// req.Messages is empty, when the chat call fails, or when the reply does not
// decode.
func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
	if len(req.Messages) == 0 {
		return ExtractResponse{}, fmt.Errorf("messages is required")
	}

	systemPrompt, userPrompt := getFactRetrievalMessages(parseMessages(formatMessages(req.Messages)))
	prompts := []chatMessage{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: userPrompt},
	}

	reply, err := c.callChat(ctx, prompts)
	if err != nil {
		return ExtractResponse{}, err
	}

	var result ExtractResponse
	if err = json.Unmarshal([]byte(removeCodeBlocks(reply)), &result); err != nil {
		return ExtractResponse{}, err
	}
	return result, nil
}
// Decide asks the model how req.Facts should change the memories listed in
// req.Candidates and converts the reply into a list of DecisionAction values.
//
// The reply (after removeCodeBlocks) may be a JSON object, in which case its
// "memory" member is used, or a bare JSON array of items. For every item:
//   - "event" is upper-cased; an empty event becomes "ADD" and "NONE" items
//     are skipped;
//   - the text is taken from "text", falling back to "fact"; items whose text
//     is blank are skipped.
//
// An error is returned when req.Facts is empty, when the chat call fails, or
// when the reply is neither a JSON object nor a JSON array.
func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) {
	if len(req.Facts) == 0 {
		return DecideResponse{}, fmt.Errorf("facts is required")
	}

	// Reduce each candidate to the id/text pair the prompt builder expects.
	existing := make([]map[string]string, 0, len(req.Candidates))
	for _, cand := range req.Candidates {
		entry := map[string]string{"id": cand.ID, "text": cand.Memory}
		existing = append(existing, entry)
	}

	reply, err := c.callChat(ctx, []chatMessage{
		{Role: "user", Content: getUpdateMemoryMessages(existing, req.Facts)},
	})
	if err != nil {
		return DecideResponse{}, err
	}

	payload := []byte(removeCodeBlocks(reply))

	var items []map[string]any
	var asObject map[string]any
	if objErr := json.Unmarshal(payload, &asObject); objErr == nil {
		items = normalizeMemoryItems(asObject["memory"])
	} else {
		// Not an object: accept a top-level array of items instead.
		var asArray []any
		if arrErr := json.Unmarshal(payload, &asArray); arrErr != nil {
			return DecideResponse{}, fmt.Errorf("failed to parse LLM response: %w", arrErr)
		}
		items = normalizeMemoryItems(asArray)
	}

	actions := make([]DecisionAction, 0, len(items))
	for _, item := range items {
		event := strings.ToUpper(asString(item["event"]))
		switch event {
		case "":
			event = "ADD"
		case "NONE":
			continue
		}

		text := asString(item["text"])
		if text == "" {
			text = asString(item["fact"])
		}
		if strings.TrimSpace(text) == "" {
			continue
		}

		actions = append(actions, DecisionAction{
			Event:     event,
			ID:        normalizeID(item["id"]),
			Text:      text,
			OldMemory: asString(item["old_memory"]),
		})
	}

	return DecideResponse{Actions: actions}, nil
}
// DetectLanguage asks the model for the language of text and returns the
// lower-cased, trimmed code from the reply's "language" field.
//
// An error is returned when text is blank, when the chat call fails, when the
// reply (after removeCodeBlocks) is not valid JSON, or when the code is
// rejected by isAllowedLanguageCode.
func (c *LLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
	if strings.TrimSpace(text) == "" {
		return "", fmt.Errorf("text is required")
	}

	systemPrompt, userPrompt := getLanguageDetectionMessages(text)
	prompts := []chatMessage{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: userPrompt},
	}

	reply, err := c.callChat(ctx, prompts)
	if err != nil {
		return "", err
	}

	var payload struct {
		Language string `json:"language"`
	}
	if err = json.Unmarshal([]byte(removeCodeBlocks(reply)), &payload); err != nil {
		return "", err
	}

	code := strings.ToLower(strings.TrimSpace(payload.Language))
	if isAllowedLanguageCode(code) {
		return code, nil
	}
	return "", fmt.Errorf("unsupported language code: %s", code)
}
// chatMessage is a single role/content pair, used both in the "messages" list
// of a request and as the "message" of a response choice.
type chatMessage struct {
	Role    string `json:"role"`    // e.g. "system" or "user"
	Content string `json:"content"` // message text
}
// chatRequest is the JSON body posted to the chat completions endpoint.
// Field order is kept as declared because it determines the order of keys in
// the encoded body.
type chatRequest struct {
	Model       string  `json:"model"`
	Temperature float32 `json:"temperature"`
	// ResponseFormat is left out of the body when empty.
	ResponseFormat map[string]string `json:"response_format,omitempty"`
	Messages       []chatMessage     `json:"messages"`
}
// chatResponse holds the part of a chat completions reply that this client
// reads: the list of choices and the message carried by each one.
type chatResponse struct {
	Choices []struct {
		Message chatMessage `json:"message"`
	} `json:"choices"`
}
// callChat posts messages to baseURL + "/chat/completions" and returns the
// content of the first choice in the reply.
//
// Every request uses the client's model, temperature 0 and a "json_object"
// response format, and authenticates with the API key as a Bearer token.
//
// Errors:
//   - the API key is empty (no request is made);
//   - the request cannot be built or sent;
//   - the status is not 2xx: the error carries the numeric status code plus
//     the trimmed response body (truncated to a few KiB), or the standard
//     status text when the body is empty;
//   - the body is not valid JSON, or contains no choice with content.
func (c *LLMClient) callChat(ctx context.Context, messages []chatMessage) (string, error) {
	// Upper bound on how much of a failed response is copied into the error,
	// so a misbehaving endpoint cannot make the error arbitrarily large.
	const maxErrorBody = 4 << 10

	if c.apiKey == "" {
		return "", fmt.Errorf("llm api key is required")
	}

	body, err := json.Marshal(chatRequest{
		Model:       c.model,
		Temperature: 0,
		ResponseFormat: map[string]string{
			"type": "json_object",
		},
		Messages: messages,
	})
	if err != nil {
		return "", err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)

	resp, err := c.http.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: if reading the body fails, whatever was read (possibly
		// nothing) is still reported together with the status code.
		raw, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		detail := strings.TrimSpace(string(raw))
		if detail == "" {
			detail = http.StatusText(resp.StatusCode)
		}
		return "", fmt.Errorf("llm error: status %d: %s", resp.StatusCode, detail)
	}

	var parsed chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return "", fmt.Errorf("decode llm response: %w", err)
	}
	if len(parsed.Choices) == 0 || parsed.Choices[0].Message.Content == "" {
		return "", fmt.Errorf("llm response missing content")
	}
	return parsed.Choices[0].Message.Content, nil
}
// formatMessages renders each message as "<role>: <content>", keeping the
// input order. The result is always non-nil and has one entry per message.
func formatMessages(messages []Message) []string {
	lines := make([]string, len(messages))
	for i := range messages {
		lines[i] = fmt.Sprintf("%s: %s", messages[i].Role, messages[i].Content)
	}
	return lines
}
// asString converts a loosely typed value (typically one decoded from JSON
// into any) to its string form.
//
//   - string values are returned unchanged;
//   - float64 values that are whole numbers within the int64 range are
//     printed as plain integers ("42");
//   - other float64 values use the shortest decimal form that round-trips
//     ("1.5", not "1.500000"), so no digits are lost or padded;
//   - int and int64 values are printed in base 10;
//   - every other type, including nil, yields "".
func asString(value any) string {
	switch typed := value.(type) {
	case string:
		return typed
	case float64:
		// The range check comes first: converting an out-of-range float to
		// int64 has an implementation-defined result. NaN and the infinities
		// fail the check and fall through to FormatFloat.
		if typed >= -(1<<63) && typed < 1<<63 && typed == float64(int64(typed)) {
			return strconv.FormatInt(int64(typed), 10)
		}
		return strconv.FormatFloat(typed, 'f', -1, 64)
	case int:
		return strconv.Itoa(typed)
	case int64:
		return strconv.FormatInt(typed, 10)
	default:
		return ""
	}
}
// normalizeID returns the string form of an item's id as produced by
// asString; values asString cannot convert come back as "".
func normalizeID(value any) string {
	return asString(value)
}
// normalizeMemoryItems turns the loosely structured value found in a model
// reply into a flat list of item objects.
//
//   - A JSON array yields its object elements in order; other elements are
//     dropped.
//   - A JSON object that has a "text", "fact" or "event" key is treated as a
//     single item and returned as a one-element list.
//   - Any other JSON object is treated as a collection keyed by arbitrary
//     names: its object-valued members are returned ordered by key
//     (lexicographically), so the result is the same on every call despite
//     Go's randomized map iteration. Non-object members are dropped.
//   - Every other value, including nil, yields nil.
func normalizeMemoryItems(value any) []map[string]any {
	switch typed := value.(type) {
	case []any:
		items := make([]map[string]any, 0, len(typed))
		for _, element := range typed {
			if m, ok := element.(map[string]any); ok {
				items = append(items, m)
			}
		}
		return items
	case map[string]any:
		// Any of these keys marks the map as an item in its own right.
		for _, marker := range [...]string{"text", "fact", "event"} {
			if _, ok := typed[marker]; ok {
				return []map[string]any{typed}
			}
		}

		// Otherwise it is a map of items; walk the keys in sorted order to
		// get a deterministic result.
		keys := make([]string, 0, len(typed))
		for key := range typed {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		items := make([]map[string]any, 0, len(keys))
		for _, key := range keys {
			if m, ok := typed[key].(map[string]any); ok {
				items = append(items, m)
			}
		}
		return items
	default:
		return nil
	}
}
// isAllowedLanguageCode reports whether code, after trimming surrounding
// whitespace and lower-casing, is one of the language codes this package
// accepts.
func isAllowedLanguageCode(code string) bool {
	allowed := [...]string{
		"ar", "bg", "ca", "cjk", "ckb", "da", "de", "el", "en", "es", "eu",
		"fa", "fi", "fr", "ga", "gl", "hi", "hr", "hu", "hy", "id", "in",
		"it", "nl", "no", "pl", "pt", "ro", "ru", "sv", "tr",
	}

	normalized := strings.ToLower(strings.TrimSpace(code))
	for _, lang := range allowed {
		if lang == normalized {
			return true
		}
	}
	return false
}