feat: chat api

This commit is contained in:
Acbox
2026-01-28 15:15:57 +08:00
parent 0711b1f086
commit 39215309da
26 changed files with 673 additions and 1873 deletions
+3 -1
@@ -95,4 +95,6 @@ docs/docs/.vitepress/cache
dump.rdb
memory.db
config.toml
config.toml
.workdocs/
-477
@@ -1,477 +0,0 @@
# Providers Module Documentation
## Overview
This document describes the standalone Providers module, which manages LLM provider configurations.
## Architecture
### Module Structure
```
internal/
├── providers/          # standalone provider module
│   ├── types.go        # type definitions
│   └── service.go      # business logic
└── handlers/
    └── providers.go    # API handler
```
### Layered Design
```
┌─────────────────────┐
│ API Layer │
│ (handlers) │
└──────────┬──────────┘
┌──────────▼──────────┐
│ Service Layer │
│ (providers pkg) │
└──────────┬──────────┘
┌──────────▼──────────┐
│ Data Layer │
│ (sqlc queries) │
└─────────────────────┘
```
## API Endpoints
### Providers API (`/providers`)
All endpoints require JWT authentication.
#### 1. Create a Provider
```http
POST /providers
Content-Type: application/json
Authorization: Bearer <token>
{
"name": "OpenAI Official",
"client_type": "openai",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-...",
"metadata": {
"description": "Official OpenAI API"
}
}
```
**Response** (201 Created):
```json
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "OpenAI Official",
"client_type": "openai",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-12345***",
"metadata": {
"description": "Official OpenAI API"
},
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
#### 2. List All Providers
```http
GET /providers
Authorization: Bearer <token>
```
**Optional query parameters**:
- `client_type` - filter by client type (openai, anthropic, google, ollama)
**Response** (200 OK):
```json
[
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "OpenAI Official",
"client_type": "openai",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-12345***",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
]
```
#### 3. Get a Single Provider
```http
GET /providers/{id}
Authorization: Bearer <token>
```
Or fetch by name:
```http
GET /providers/name/{name}
Authorization: Bearer <token>
```
**Response** (200 OK): same shape as the create response
#### 4. Update a Provider
```http
PUT /providers/{id}
Content-Type: application/json
Authorization: Bearer <token>
{
"name": "OpenAI Updated",
"api_key": "sk-newkey..."
}
```
**Note**: all fields are optional; only the provided fields are updated (see the sketch below).
**Response** (200 OK): returns the updated provider
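The `UpdateRequest` type itself is not shown in this document; a plausible sketch of how partial updates are usually modeled in Go, with pointer fields whose nil value means "leave unchanged" (field names mirror `CreateRequest` and are otherwise assumptions):
```go
// Hypothetical shape of UpdateRequest: nil means "not provided, leave
// unchanged". The real type in internal/providers may differ.
type UpdateRequest struct {
    Name     *string                `json:"name,omitempty"`
    BaseURL  *string                `json:"base_url,omitempty"`
    APIKey   *string                `json:"api_key,omitempty"`
    Metadata map[string]interface{} `json:"metadata,omitempty"`
}
```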
#### 5. Delete a Provider
```http
DELETE /providers/{id}
Authorization: Bearer <token>
```
**Response** (204 No Content)
#### 6. Count Providers
```http
GET /providers/count
Authorization: Bearer <token>
```
**Optional query parameters**:
- `client_type` - filter by client type
**Response** (200 OK):
```json
{
"count": 5
}
```
## Supported Client Types
| Client Type | Description | Requires API Key |
|------------|-------------|------------------|
| `openai` | Official OpenAI API | ✅ |
| `openai-compat` | OpenAI-compatible APIs | ✅ |
| `anthropic` | Anthropic Claude API | ✅ |
| `google` | Google Gemini API | ✅ |
| `ollama` | Local Ollama | ❌ |
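At the time of this commit, these client types correspond to string constants in the `internal/chat` package:
```go
const (
    ProviderOpenAI       = "openai"
    ProviderOpenAICompat = "openai-compat"
    ProviderAnthropic    = "anthropic"
    ProviderGoogle       = "google"
    ProviderOllama       = "ollama"
)
```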
## Data Model
### Provider Structures
```go
type CreateRequest struct {
    Name       string                 `json:"name"`        // required
    ClientType ClientType             `json:"client_type"` // required
    BaseURL    string                 `json:"base_url"`    // required
    APIKey     string                 `json:"api_key"`     // optional
    Metadata   map[string]interface{} `json:"metadata"`    // optional
}
type GetResponse struct {
    ID         string                 `json:"id"`
    Name       string                 `json:"name"`
    ClientType string                 `json:"client_type"`
    BaseURL    string                 `json:"base_url"`
    APIKey     string                 `json:"api_key"` // masked
    Metadata   map[string]interface{} `json:"metadata"`
    CreatedAt  time.Time              `json:"created_at"`
    UpdatedAt  time.Time              `json:"updated_at"`
}
```
## Security Features
### 1. API Key Masking
API keys are automatically masked in responses (a sketch of the rule follows this list):
- only the first 8 characters are shown
- the remainder is replaced with `*`
- for example: `sk-12345678***`
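A minimal sketch of this rule, assuming a helper along these lines (the actual function may differ):
```go
// maskAPIKey keeps the first 8 characters and hides the remainder,
// matching the masked examples in this document (e.g. "sk-12345678***").
// Sketch only; the real helper may behave differently.
func maskAPIKey(key string) string {
    if len(key) <= 8 {
        return "***" // too short to reveal a prefix safely
    }
    return key[:8] + "***"
}
```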
### 2. Authentication
All API endpoints require JWT authentication:
```http
Authorization: Bearer <your-jwt-token>
```
### 3. Input Validation
- UUID format is validated automatically
- `client_type` is checked against the supported set
- required fields are enforced (see the sketch below)
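A sketch of these checks (function and message names are assumptions, not taken from the actual handlers):
```go
import (
    "fmt"
    "strings"

    "github.com/google/uuid"
)

// Hypothetical validation mirroring the checks listed above.
func validateCreate(req CreateRequest) error {
    if strings.TrimSpace(req.Name) == "" || strings.TrimSpace(req.BaseURL) == "" {
        return fmt.Errorf("name and base_url are required")
    }
    switch string(req.ClientType) {
    case "openai", "openai-compat", "anthropic", "google", "ollama":
        return nil
    default:
        return fmt.Errorf("unsupported client_type: %s", req.ClientType)
    }
}

// Path IDs are parsed as UUIDs (e.g. with github.com/google/uuid), which
// yields errors like the one shown under Error Handling below.
func validateID(id string) error {
    if _, err := uuid.Parse(id); err != nil {
        return fmt.Errorf("invalid UUID: %w", err)
    }
    return nil
}
```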
## Usage Examples
### 1. Configure an OpenAI Provider
```bash
curl -X POST http://localhost:8080/providers \
-H "Authorization: Bearer $TOKEN" \
-H "Content-Type: application/json" \
-d '{
"name": "OpenAI GPT-4",
"client_type": "openai",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-..."
}'
```
### 2. Configure a Custom OpenAI-Compatible Service
```bash
curl -X POST http://localhost:8080/providers \
-H "Authorization: Bearer $TOKEN" \
-H "Content-Type: application/json" \
-d '{
"name": "Azure OpenAI",
"client_type": "openai-compat",
"base_url": "https://your-resource.openai.azure.com/v1",
"api_key": "your-azure-key",
"metadata": {
"deployment": "gpt-4",
"region": "eastus"
}
}'
```
### 3. Configure Local Ollama
```bash
curl -X POST http://localhost:8080/providers \
-H "Authorization: Bearer $TOKEN" \
-H "Content-Type: application/json" \
-d '{
"name": "Local Ollama",
"client_type": "ollama",
"base_url": "http://localhost:11434"
}'
```
### 4. List All OpenAI Providers
```bash
curl http://localhost:8080/providers?client_type=openai \
-H "Authorization: Bearer $TOKEN"
```
### 5. Update a Provider's API Key
```bash
curl -X PUT http://localhost:8080/providers/{id} \
-H "Authorization: Bearer $TOKEN" \
-H "Content-Type: application/json" \
-d '{
"api_key": "sk-new-key..."
}'
```
## Relationship to Models
A Provider has a one-to-many relationship with Models:
```
┌─────────────┐
│ Provider │
│ (OpenAI) │
└──────┬──────┘
├─── Model (gpt-4)
├─── Model (gpt-3.5-turbo)
└─── Model (text-embedding-ada-002)
```
### Referencing a Provider When Creating a Model
```bash
# 1. Create the provider
PROVIDER_ID=$(curl -X POST http://localhost:8080/providers \
-H "Authorization: Bearer $TOKEN" \
-d '{"name":"OpenAI","client_type":"openai",...}' \
| jq -r '.id')
# 2. Create a model that references the provider
curl -X POST http://localhost:8080/models \
-H "Authorization: Bearer $TOKEN" \
-d '{
"model_id": "gpt-4",
"name": "GPT-4",
"llm_provider_id": "'$PROVIDER_ID'",
"type": "chat"
}'
```
## Code Integration
### Using the Provider Service in Code
```go
import "github.com/memohai/memoh/internal/providers"
// 创建 service
providersService := providers.NewService(queries)
// 创建 provider
provider, err := providersService.Create(ctx, providers.CreateRequest{
Name: "OpenAI",
ClientType: providers.ClientTypeOpenAI,
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-...",
})
// 列出所有 providers
allProviders, err := providersService.List(ctx)
// 按类型过滤
openaiProviders, err := providersService.ListByClientType(ctx, providers.ClientTypeOpenAI)
// 获取单个 provider
provider, err := providersService.Get(ctx, "provider-uuid")
// 更新 provider
updated, err := providersService.Update(ctx, "provider-uuid", providers.UpdateRequest{
APIKey: stringPtr("new-key"),
})
// 删除 provider
err := providersService.Delete(ctx, "provider-uuid")
```
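The `stringPtr` helper used above is not defined in this excerpt; a minimal version:
```go
// stringPtr returns a pointer to s, for optional UpdateRequest fields.
func stringPtr(s string) *string { return &s }
```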
## Error Handling
### Common Errors
| Status | Error | Cause |
|--------|-------|-------|
| 400 | Bad Request | missing required fields or malformed input |
| 404 | Not Found | the provider ID does not exist |
| 409 | Conflict | the provider name already exists |
| 500 | Internal Server Error | server-side error |
### Error Response Format
```json
{
"message": "invalid UUID: invalid UUID length: 5"
}
```
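On the Go side, this body corresponds to a small struct (name assumed; the handler may define it differently):
```go
type errorResponse struct {
    Message string `json:"message"`
}
```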
## Database Schema
Providers are stored in the `llm_providers` table:
```sql
CREATE TABLE llm_providers (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name TEXT NOT NULL,
client_type TEXT NOT NULL,
base_url TEXT NOT NULL,
api_key TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT llm_providers_name_unique UNIQUE (name),
CONSTRAINT llm_providers_client_type_check
CHECK (client_type IN ('openai', 'openai-compat', 'anthropic', 'google', 'ollama'))
);
```
## Best Practices
### 1. Naming Conventions
- Use descriptive names: `OpenAI GPT-4`, `Anthropic Claude 3`
- Include environment information: `OpenAI Production`, `Ollama Local`
- Avoid special characters and spaces
### 2. API Key Management
- Rotate API keys regularly
- Store sensitive values in environment variables
- Never log a full API key
### 3. Using Metadata
Use the metadata field to store extra information:
```json
{
"metadata": {
"environment": "production",
"rate_limit": 10000,
"contact": "admin@example.com",
"notes": "Primary provider for production"
}
}
```
### 4. Provider Testing
After creating a provider, it is a good idea to test the connection. A test endpoint that could be implemented in the future:
```http
POST /providers/test
{
    "client_type": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": "sk-...",
    "model": "gpt-3.5-turbo"
}
```
## Migration Guide
### Migrating from the Config File to the Database
**Old approach** (config.toml):
```toml
[memory]
base_url = "https://api.openai.com/v1"
api_key = "sk-..."
model = "gpt-4"
```
**New approach** (API):
```bash
# 1. Create the provider
curl -X POST http://localhost:8080/providers \
-d '{"name":"OpenAI","client_type":"openai","base_url":"https://api.openai.com/v1","api_key":"sk-..."}'
# 2. Create the model
curl -X POST http://localhost:8080/models \
-d '{"model_id":"gpt-4","llm_provider_id":"<provider-id>","type":"chat","enable_as":"memory"}'
```
## Related Documents
- [Provider Refactoring Summary](./REFACTORING_SUMMARY.md)
- [Agent Migration Notes](./AGENT_MIGRATION.md)
- [Chat Architecture](./internal/chat/ARCHITECTURE.md)
- [Models API Documentation](#) (to be created)
## TODO
- [ ] Implement a provider connection-test endpoint
- [ ] Add provider usage statistics
- [ ] Implement provider health checks
- [ ] Add provider cost tracking
- [ ] Support provider load balancing
- [ ] Implement provider failover
-273
@@ -1,273 +0,0 @@
# Provider Refactoring Summary
## Refactoring Goals
Unify the Memory and Chat services behind the Provider interface in `internal/chat`, for code reuse and a consistent architecture.
## Main Changes
### 1. Extended the Chat Provider Interface
**File**: `internal/chat/types.go`
- Extended the `Request` struct with:
  - `Temperature *float32` - sampling temperature
  - `ResponseFormat *ResponseFormat` - response format (supports JSON mode)
  - `MaxTokens *int` - maximum token count
- Added a `ResponseFormat` struct:
```go
type ResponseFormat struct {
    Type string // "json_object" or "text"
}
```
### 2. Created an LLM Interface for Memory
**File**: `internal/memory/types.go`
- Defined the `LLM` interface:
```go
type LLM interface {
    Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
    Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
}
```
- This interface has two implementations (see the assertion sketch below):
  - `LLMClient` (the old implementation, kept for backward compatibility)
  - `ProviderLLMClient` (the new implementation, built on `chat.Provider`)
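Since both clients must satisfy the same interface, compile-time assertions (a common Go idiom; not necessarily present in the actual source) catch drift early:
```go
// Compile-time checks that both implementations satisfy memory.LLM.
var (
    _ LLM = (*LLMClient)(nil)
    _ LLM = (*ProviderLLMClient)(nil)
)
```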
### 3. Created a Provider-Based LLM Client
**File**: `internal/memory/llm_provider_client.go` (new file)
- Implements the `ProviderLLMClient` struct
- Uses `chat.Provider` to perform LLM calls
- Supports JSON-mode output for structured responses
- Reuses helper functions from the `internal/memory` package
Key code:
```go
type ProviderLLMClient struct {
    provider chat.Provider
    model    string
}
func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient
func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
```
### 4. Updated the Memory Service
**File**: `internal/memory/service.go`
- Changed `llm *LLMClient` to `llm LLM`
- Now accepts any type implementing the `LLM` interface
- All existing functionality is preserved
### 5. Updated the Main Program
**File**: `cmd/agent/main.go`
#### Added functions:
1. **selectMemoryModel** - selects the model used for memory operations
   - Priority: memory model → chat model → any chat-type model
2. **fetchProviderByID** - fetches a provider configuration by ID
3. **createChatProvider** - creates a provider instance from its configuration
   - Supports OpenAI, Anthropic, Google, and Ollama
#### Updated initialization flow:
```go
// 1. Initialize the chat resolver (used by both chat and memory)
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)

// 2. Try to create a provider-based client for memory
memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries, &cfg)
if err != nil {
    // fall back to the old LLMClient
    llmClient = memory.NewLLMClient(cfg.Memory.BaseURL, cfg.Memory.APIKey, ...)
} else {
    // use the new provider-based client
    provider, _ := createChatProvider(memoryProvider, 30*time.Second)
    llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)
}

// 3. Create the memory service
memoryService = memory.NewService(llmClient, embedder, store, resolver, ...)
```
## Architectural Benefits
### 1. Unified Provider Management
- Chat and Memory share the same Provider implementations
- Less code duplication
- Unified configuration and management
### 2. Flexible Model Selection
- Different features can be configured with different models
- The `enable_as` field designates a model's purpose
- Automatic fallback
### 3. Backward Compatibility
- The old `LLMClient` implementation is retained
- Falls back to the config file automatically when no model is configured in the database
- Smooth migration path
### 4. Type Safety
- Uses Go interfaces instead of runtime type checks
- Compile-time type checking
- Better IDE support
### 5. Extensibility
- A new provider only needs to implement the `Provider` interface
- A new LLM client only needs to implement the `LLM` interface
- Modular design
## Configuration
### Database Configuration
Configure a model for memory operations:
```sql
-- Option 1: use a dedicated memory model
UPDATE models SET enable_as = 'memory'
WHERE model_id = 'gpt-4-turbo-preview';

-- Option 2: use a chat model (when no dedicated memory model exists)
UPDATE models SET enable_as = 'chat'
WHERE model_id = 'gpt-4';
```
### Environment Variables (Fallback Configuration)
If no model is configured in the database, the system falls back to these settings:
```toml
[memory]
base_url = "https://api.openai.com/v1"
api_key = "sk-..."
model = "gpt-4.1-nano"
timeout_seconds = 10
```
## Testing Suggestions
### 1. Memory Operation Tests
```bash
# Test Extract (fact extraction)
curl -X POST http://localhost:8080/api/memory/add \
-H "Authorization: Bearer $TOKEN" \
-d '{
"messages": [
{"role": "user", "content": "My name is Alice and I like pizza"}
],
"user_id": "user-123"
}'
# Test Search (memory search)
curl -X POST http://localhost:8080/api/memory/search \
-H "Authorization: Bearer $TOKEN" \
-d '{
"query": "What food do I like?",
"user_id": "user-123"
}'
```
### 2. Chat Operation Tests
```bash
# Test a plain chat request
curl -X POST http://localhost:8080/api/chat \
-H "Authorization: Bearer $TOKEN" \
-d '{
"messages": [
{"role": "user", "content": "Hello, how are you?"}
]
}'
```
### 3. Verify the Logs
On startup you should see:
```
Using memory model: gpt-4-turbo-preview (provider: openai)
```
Or, if the fallback is used:
```
WARNING: No memory model configured, using fallback LLMClient: ...
```
## Follow-Up Work
### Short Term (Required)
- [ ] Implement the concrete logic for each provider (most currently return "not yet implemented")
- [ ] Add streaming response support
- [ ] Improve error handling
### Mid Term (Recommended)
- [ ] Add unit tests for providers
- [ ] Add memory integration tests
- [ ] Implement a provider connection pool
- [ ] Add request retry logic
### Long Term (Optimization)
- [ ] Provider performance monitoring
- [ ] Automatic model selection and load balancing
- [ ] Model response caching
- [ ] Support more providers (e.g., Cohere, HuggingFace)
## Migration Checklist
- [x] Extend the Request struct to support JSON mode
- [x] Create the LLM interface
- [x] Implement ProviderLLMClient
- [x] Update the Memory Service to use the interface
- [x] Update the main program's initialization flow
- [x] Add model selection logic
- [x] Add provider creation logic
- [x] Preserve backward compatibility
- [x] Add architecture documentation
- [ ] Add unit tests
- [ ] Add integration tests
- [ ] Update deployment docs
## File Inventory
### New Files
- `internal/memory/llm_provider_client.go` - provider-based LLM client
- `internal/chat/ARCHITECTURE.md` - architecture documentation
- `REFACTORING_SUMMARY.md` - this file
### Modified Files
- `internal/chat/types.go` - extended the Request struct
- `internal/memory/types.go` - added the LLM interface
- `internal/memory/service.go` - uses the LLM interface
- `cmd/agent/main.go` - updated the initialization flow
### Retained Files (Backward Compatibility)
- `internal/memory/llm_client.go` - the old HTTP client implementation
## Notes
1. **JSON mode compatibility**: not every model supports JSON mode; provider implementations must handle this
2. **Error handling**: the current error handling is simple; production needs more detailed error information
3. **Timeouts**: different operations may need different timeouts; consider making them configurable
4. **Concurrency safety**: provider instances should be safe for concurrent use
5. **Resource cleanup**: make sure provider resources (e.g., HTTP connections) are released properly
## Troubleshooting
If you run into problems, check:
1. Whether a chat model is configured in the database
2. Whether the provider configuration is correct (API key, base URL)
3. Error messages in the logs
4. Whether the chat resolver was initialized correctly
For detailed architecture notes, see `internal/chat/ARCHITECTURE.md`.
+45 -14
@@ -2,6 +2,7 @@ package main
import (
"context"
"fmt"
"log"
"os"
"strings"
@@ -71,22 +72,15 @@ func main() {
authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn)
// Initialize chat resolver for both chat and memory operations
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
chatResolver := chat.NewResolver(modelsService, queries, cfg.AgentGateway.BaseURL(), 30*time.Second)
// Create LLM client for memory operations using chat provider
var llmClient memory.LLM
memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, modelsService, queries)
if err != nil {
log.Fatalf("select memory model: %v\nPlease configure at least one chat model in the database.", err)
// Create LLM client for memory operations (deferred model/provider selection).
var llmClient memory.LLM = &lazyLLMClient{
modelsService: modelsService,
queries: queries,
timeout: 30 * time.Second,
}
log.Printf("Using memory model: %s (provider: %s)", memoryModel.ModelID, memoryProvider.ClientType)
provider, err := chat.CreateProvider(memoryProvider, 30*time.Second)
if err != nil {
log.Fatalf("create memory provider: %v", err)
}
llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)
resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
if err != nil {
@@ -154,9 +148,46 @@ func main() {
providersService := providers.NewService(queries)
providersHandler := handlers.NewProvidersHandler(providersService)
modelsHandler := handlers.NewModelsHandler(modelsService)
srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler, containerdHandler)
srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, containerdHandler)
if err := srv.Start(); err != nil {
log.Fatalf("server failed: %v", err)
}
}
type lazyLLMClient struct {
modelsService *models.Service
queries *dbsqlc.Queries
timeout time.Duration
}
func (c *lazyLLMClient) Extract(ctx context.Context, req memory.ExtractRequest) (memory.ExtractResponse, error) {
client, err := c.resolve(ctx)
if err != nil {
return memory.ExtractResponse{}, err
}
return client.Extract(ctx, req)
}
func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (memory.DecideResponse, error) {
client, err := c.resolve(ctx)
if err != nil {
return memory.DecideResponse{}, err
}
return client.Decide(ctx, req)
}
func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
if c.modelsService == nil || c.queries == nil {
return nil, fmt.Errorf("models service not configured")
}
memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
if err != nil {
return nil, err
}
clientType := strings.ToLower(strings.TrimSpace(memoryProvider.ClientType))
if clientType != "openai" && clientType != "openai-compat" {
return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
}
return memory.NewLLMClient(memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout), nil
}
+6 -1
@@ -33,4 +33,9 @@ sslmode = "disable"
base_url = "http://127.0.0.1:6334"
api_key = ""
collection = "memory"
timeout_seconds = 10
timeout_seconds = 10
## Agent Gateway
[agent_gateway]
host = "127.0.0.1"
port = 8081
+10
@@ -130,3 +130,13 @@ CREATE TABLE IF NOT EXISTS lifecycle_events (
CREATE INDEX IF NOT EXISTS idx_lifecycle_events_container_id ON lifecycle_events(container_id);
CREATE INDEX IF NOT EXISTS idx_lifecycle_events_event_type ON lifecycle_events(event_type);
CREATE TABLE IF NOT EXISTS history (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
messages JSONB NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
"user" UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_history_user ON history("user");
CREATE INDEX IF NOT EXISTS idx_history_timestamp ON history(timestamp);
+11
@@ -0,0 +1,11 @@
-- name: CreateHistory :one
INSERT INTO history (messages, timestamp, "user")
VALUES ($1, $2, $3)
RETURNING id, messages, timestamp, "user";
-- name: ListHistoryByUserSince :many
SELECT id, messages, timestamp, "user"
FROM history
WHERE "user" = $1 AND timestamp >= $2
ORDER BY timestamp ASC;
+38 -59
@@ -135,7 +135,11 @@ const docTemplate = `{
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/chat.StreamChunk"
"type": "array",
"items": {
"type": "integer",
"format": "int32"
}
}
},
"400": {
@@ -1431,71 +1435,53 @@ const docTemplate = `{
"chat.ChatRequest": {
"type": "object",
"properties": {
"current_platform": {
"type": "string"
},
"language": {
"type": "string"
},
"locale": {
"type": "string"
},
"max_context_load_time": {
"type": "integer"
},
"max_steps": {
"type": "integer"
},
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/chat.Message"
"$ref": "#/definitions/chat.GatewayMessage"
}
},
"model": {
"description": "optional: specific model to use",
"type": "string"
},
"platforms": {
"type": "array",
"items": {
"type": "string"
}
},
"provider": {
"description": "optional: specific provider to use",
"type": "string"
},
"stream": {
"type": "boolean"
"query": {
"type": "string"
}
}
},
"chat.ChatResponse": {
"type": "object",
"properties": {
"finish_reason": {
"type": "string"
},
"message": {
"$ref": "#/definitions/chat.Message"
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"usage": {
"$ref": "#/definitions/chat.Usage"
}
}
},
"chat.Message": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"role": {
"description": "user, assistant, system",
"type": "string"
}
}
},
"chat.StreamChunk": {
"type": "object",
"properties": {
"delta": {
"type": "object",
"properties": {
"content": {
"type": "string"
}
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/chat.GatewayMessage"
}
},
"finish_reason": {
"type": "string"
},
"model": {
"type": "string"
},
@@ -1504,19 +1490,9 @@ const docTemplate = `{
}
}
},
"chat.Usage": {
"chat.GatewayMessage": {
"type": "object",
"properties": {
"completion_tokens": {
"type": "integer"
},
"prompt_tokens": {
"type": "integer"
},
"total_tokens": {
"type": "integer"
}
}
"additionalProperties": true
},
"handlers.CreateContainerRequest": {
"type": "object",
@@ -1524,6 +1500,9 @@ const docTemplate = `{
"container_id": {
"type": "string"
},
"image": {
"type": "string"
},
"snapshotter": {
"type": "string"
}
+38 -59
@@ -124,7 +124,11 @@
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/chat.StreamChunk"
"type": "array",
"items": {
"type": "integer",
"format": "int32"
}
}
},
"400": {
@@ -1420,71 +1424,53 @@
"chat.ChatRequest": {
"type": "object",
"properties": {
"current_platform": {
"type": "string"
},
"language": {
"type": "string"
},
"locale": {
"type": "string"
},
"max_context_load_time": {
"type": "integer"
},
"max_steps": {
"type": "integer"
},
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/chat.Message"
"$ref": "#/definitions/chat.GatewayMessage"
}
},
"model": {
"description": "optional: specific model to use",
"type": "string"
},
"platforms": {
"type": "array",
"items": {
"type": "string"
}
},
"provider": {
"description": "optional: specific provider to use",
"type": "string"
},
"stream": {
"type": "boolean"
"query": {
"type": "string"
}
}
},
"chat.ChatResponse": {
"type": "object",
"properties": {
"finish_reason": {
"type": "string"
},
"message": {
"$ref": "#/definitions/chat.Message"
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"usage": {
"$ref": "#/definitions/chat.Usage"
}
}
},
"chat.Message": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"role": {
"description": "user, assistant, system",
"type": "string"
}
}
},
"chat.StreamChunk": {
"type": "object",
"properties": {
"delta": {
"type": "object",
"properties": {
"content": {
"type": "string"
}
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/chat.GatewayMessage"
}
},
"finish_reason": {
"type": "string"
},
"model": {
"type": "string"
},
@@ -1493,19 +1479,9 @@
}
}
},
"chat.Usage": {
"chat.GatewayMessage": {
"type": "object",
"properties": {
"completion_tokens": {
"type": "integer"
},
"prompt_tokens": {
"type": "integer"
},
"total_tokens": {
"type": "integer"
}
}
"additionalProperties": true
},
"handlers.CreateContainerRequest": {
"type": "object",
@@ -1513,6 +1489,9 @@
"container_id": {
"type": "string"
},
"image": {
"type": "string"
},
"snapshotter": {
"type": "string"
}
+29 -42
@@ -1,67 +1,51 @@
definitions:
chat.ChatRequest:
properties:
current_platform:
type: string
language:
type: string
locale:
type: string
max_context_load_time:
type: integer
max_steps:
type: integer
messages:
items:
$ref: '#/definitions/chat.Message'
$ref: '#/definitions/chat.GatewayMessage'
type: array
model:
description: 'optional: specific model to use'
type: string
platforms:
items:
type: string
type: array
provider:
description: 'optional: specific provider to use'
type: string
stream:
type: boolean
query:
type: string
type: object
chat.ChatResponse:
properties:
finish_reason:
type: string
message:
$ref: '#/definitions/chat.Message'
model:
type: string
provider:
type: string
usage:
$ref: '#/definitions/chat.Usage'
type: object
chat.Message:
properties:
content:
type: string
role:
description: user, assistant, system
type: string
type: object
chat.StreamChunk:
properties:
delta:
properties:
content:
type: string
type: object
finish_reason:
type: string
messages:
items:
$ref: '#/definitions/chat.GatewayMessage'
type: array
model:
type: string
provider:
type: string
type: object
chat.Usage:
properties:
completion_tokens:
type: integer
prompt_tokens:
type: integer
total_tokens:
type: integer
chat.GatewayMessage:
additionalProperties: true
type: object
handlers.CreateContainerRequest:
properties:
container_id:
type: string
image:
type: string
snapshotter:
type: string
type: object
@@ -553,7 +537,10 @@ paths:
"200":
description: OK
schema:
$ref: '#/definitions/chat.StreamChunk'
items:
format: int32
type: integer
type: array
"400":
description: Bad Request
schema:
-213
@@ -1,213 +0,0 @@
# Chat Provider Architecture
## Overview
This document describes the unified Chat Provider architecture in the Memoh project, shared by the chat service and the memory service.
## Architecture
### Core Interfaces
#### The Provider Interface
Every LLM provider implements the `chat.Provider` interface:
```go
type Provider interface {
    Chat(ctx context.Context, req Request) (Result, error)
    StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error)
}
```
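A caller consumes the streaming variant by draining both channels; a sketch using the `StreamChunk` shape this document describes:
```go
// Sketch: stream a reply and print the deltas as they arrive.
chunks, errs := provider.StreamChat(ctx, req)
for chunk := range chunks {
    fmt.Print(chunk.Delta.Content)
}
if err := <-errs; err != nil {
    log.Println("stream failed:", err)
}
```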
#### The Request Struct
Provider requests support several configuration options:
```go
type Request struct {
    Messages       []Message
    Model          string
    Provider       string
    Temperature    *float32        // optional: sampling temperature
    ResponseFormat *ResponseFormat // optional: response format (JSON mode)
    MaxTokens      *int            // optional: maximum token count
}

type ResponseFormat struct {
    Type string // "json_object" or "text"
}
```
### Supported Providers
1. **OpenAI** (`openai` / `openai-compat`)
   - The standard OpenAI API
   - Custom endpoints compatible with the OpenAI format
2. **Anthropic** (`anthropic`)
   - The Claude model family
3. **Google** (`google`)
   - The Gemini model family
4. **Ollama** (`ollama`)
   - Locally deployed open-source models
## Usage Scenarios
### 1. The Chat Service
The chat service uses providers through `chat.Resolver`:
```go
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
response, err := chatResolver.Chat(ctx, ChatRequest{
Messages: messages,
Model: "gpt-4",
})
```
### 2. The Memory Service
The memory service uses providers through `memory.ProviderLLMClient`:
```go
// create a provider
provider, err := chat.NewOpenAIProvider(apiKey, baseURL, timeout)

// create the memory LLM client
llmClient := memory.NewProviderLLMClient(provider, modelID)

// build the memory service
memoryService := memory.NewService(llmClient, embedder, store, resolver, ...)
```
The memory service needs two core capabilities:
- **Extract**: pull factual information out of a conversation
- **Decide**: decide how to update memory (add/update/delete)
Both operations use JSON mode to guarantee structured output, as illustrated below.
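For illustration, an Extract call through the provider might be built like this (a sketch based on the `Request` struct above; the prompt text is invented):
```go
// Sketch: deterministic, JSON-mode call as used by Extract/Decide.
temp := float32(0) // Temperature = 0 for deterministic output (see Notes)
result, err := provider.Chat(ctx, chat.Request{
    Model:          modelID,
    Temperature:    &temp,
    ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
    Messages: []chat.Message{
        {Role: "system", Content: "Extract facts and reply as JSON."}, // invented prompt
        {Role: "user", Content: conversationText},
    },
})
if err != nil {
    return ExtractResponse{}, err
}
// result.Message.Content now holds a JSON document to unmarshal.
```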
## Configuration Examples
### Database Configuration
Provider configurations are stored in the `llm_providers` table:
```sql
CREATE TABLE llm_providers (
id UUID PRIMARY KEY,
name TEXT NOT NULL,
client_type TEXT NOT NULL, -- 'openai', 'anthropic', 'google', 'ollama'
base_url TEXT NOT NULL,
api_key TEXT NOT NULL,
metadata JSONB
);
```
Model configurations are stored in the `models` table:
```sql
CREATE TABLE models (
id UUID PRIMARY KEY,
model_id TEXT NOT NULL,
name TEXT,
llm_provider_id UUID REFERENCES llm_providers(id),
type TEXT NOT NULL, -- 'chat' or 'embedding'
enable_as TEXT, -- 'chat', 'memory', 'embedding'
...
);
```
### Initialization at Startup
In `cmd/agent/main.go`:
```go
// 1. Initialize the chat resolver (used by both chat and memory)
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)

// 2. Select a model and create a provider for memory
memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries, cfg)
provider, err := createChatProvider(memoryProvider, 30*time.Second)

// 3. Create the memory LLM client
llmClient := memory.NewProviderLLMClient(provider, memoryModel.ModelID)

// 4. Build the memory service
memoryService := memory.NewService(llmClient, embedder, store, resolver, ...)
```
## Model Selection Strategy
### Memory Model Selection Priority
1. Models with `enable_as = 'memory'` (dedicated memory models)
2. Models with `enable_as = 'chat'` (general chat models)
3. Any available chat-type model
4. Fall back to the config-file `LLMClient` (backward compatible) — see the sketch after this list
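A sketch of that fallback chain (simplified to model choice; the real `selectMemoryModel` also resolves the provider, and `models.EnableAsMemory` is assumed by analogy with `models.EnableAsChat`):
```go
func selectMemoryModel(ctx context.Context, svc *models.Service) (models.GetResponse, error) {
    if m, err := svc.GetByEnableAs(ctx, models.EnableAsMemory); err == nil {
        return m, nil // 1. dedicated memory model
    }
    if m, err := svc.GetByEnableAs(ctx, models.EnableAsChat); err == nil {
        return m, nil // 2. general chat model
    }
    candidates, err := svc.ListByType(ctx, models.ModelTypeChat)
    if err != nil || len(candidates) == 0 {
        // 4. nothing in the database; caller falls back to the config-file LLMClient
        return models.GetResponse{}, fmt.Errorf("no chat models available")
    }
    return candidates[0], nil // 3. any chat-type model
}
```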
### Chat Model Selection Priority
1. The model specified in the request
2. Models with `enable_as = 'chat'`
3. Any available chat-type model
## Benefits
1. **Unified architecture**: chat and memory use the same Provider interface
2. **Flexible configuration**: multiple providers and models are supported
3. **Backward compatible**: the old LLMClient is kept as a fallback
4. **Type safe**: Go interfaces provide compile-time guarantees
5. **Extensible**: adding a provider only requires implementing the Provider interface
## Adding a New Provider
To add a new LLM provider:
1. Create a new file in `internal/chat/` (e.g., `newprovider.go`)
2. Implement the `Provider` interface
3. Add a new case to `createProvider()` in `resolver.go`
4. Add the new type to the database's `llm_providers_client_type_check` constraint
Example:
```go
// newprovider.go
type NewProvider struct {
    apiKey  string
    timeout time.Duration
}

func NewNewProvider(apiKey string, timeout time.Duration) (*NewProvider, error) {
    return &NewProvider{apiKey: apiKey, timeout: timeout}, nil
}

func (p *NewProvider) Chat(ctx context.Context, req Request) (Result, error) {
    // implement the chat logic
}

func (p *NewProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
    // implement the streaming chat logic
}
```
## Migration Guide
Migrating from the old TypeScript backend to Go:
1. ✅ Create the Provider interface and implementations
2. ✅ Implement the chat resolver
3. ✅ Create the provider adapter for memory
4. ✅ Update the main program to use the unified provider
5. 🚧 Implement the concrete logic for each provider (OpenAI, Anthropic, Google, Ollama)
6. 🚧 Add streaming response support
7. 🚧 Add complete error handling and retry logic
## Notes
1. **JSON mode**: memory operations require `ResponseFormat.Type = "json_object"` to guarantee structured output
2. **Temperature**: memory operations use `Temperature = 0` for deterministic output
3. **Timeouts**: different operations may need different timeouts
4. **Error handling**: providers should return clear errors, including API error details
-43
@@ -1,43 +0,0 @@
package chat
import (
"context"
"fmt"
"time"
"github.com/firebase/genkit/go/genkit"
)
// AnthropicProvider wraps Genkit's Anthropic provider
type AnthropicProvider struct {
g *genkit.Genkit
apiKey string
timeout time.Duration
}
func NewAnthropicProvider(apiKey string, timeout time.Duration) (*AnthropicProvider, error) {
if timeout <= 0 {
timeout = 30 * time.Second
}
return &AnthropicProvider{
apiKey: apiKey,
timeout: timeout,
}, nil
}
func (p *AnthropicProvider) Chat(ctx context.Context, req Request) (Result, error) {
return Result{}, fmt.Errorf("anthropic provider not yet implemented")
}
func (p *AnthropicProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
chunkChan := make(chan StreamChunk, 10)
errChan := make(chan error, 1)
go func() {
defer close(chunkChan)
defer close(errChan)
errChan <- fmt.Errorf("anthropic streaming not yet implemented")
}()
return chunkChan, errChan
}
+1 -2
@@ -1,2 +1 @@
package chat
package chat
-39
@@ -1,39 +0,0 @@
package chat
import (
"fmt"
"strings"
"time"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
)
// CreateProvider creates a chat provider instance.
func CreateProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (Provider, error) {
clientType := strings.ToLower(strings.TrimSpace(provider.ClientType))
if timeout <= 0 {
timeout = 30 * time.Second
}
switch clientType {
case ProviderOpenAI, ProviderOpenAICompat:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, fmt.Errorf("openai api key is required")
}
return NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout)
case ProviderAnthropic:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, fmt.Errorf("anthropic api key is required")
}
return NewAnthropicProvider(provider.ApiKey, timeout)
case ProviderGoogle:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, fmt.Errorf("google api key is required")
}
return NewGoogleProvider(provider.ApiKey, timeout)
case ProviderOllama:
return NewOllamaProvider(provider.BaseUrl, timeout)
default:
return nil, fmt.Errorf("unsupported provider type: %s", clientType)
}
}
-43
@@ -1,43 +0,0 @@
package chat
import (
"context"
"fmt"
"time"
"github.com/firebase/genkit/go/genkit"
)
// GoogleProvider wraps Genkit's Google AI provider
type GoogleProvider struct {
g *genkit.Genkit
apiKey string
timeout time.Duration
}
func NewGoogleProvider(apiKey string, timeout time.Duration) (*GoogleProvider, error) {
if timeout <= 0 {
timeout = 30 * time.Second
}
return &GoogleProvider{
apiKey: apiKey,
timeout: timeout,
}, nil
}
func (p *GoogleProvider) Chat(ctx context.Context, req Request) (Result, error) {
return Result{}, fmt.Errorf("google provider not yet implemented")
}
func (p *GoogleProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
chunkChan := make(chan StreamChunk, 10)
errChan := make(chan error, 1)
go func() {
defer close(chunkChan)
defer close(errChan)
errChan <- fmt.Errorf("google streaming not yet implemented")
}()
return chunkChan, errChan
}
-48
@@ -1,48 +0,0 @@
package chat
import (
"context"
"fmt"
"strings"
"time"
"github.com/firebase/genkit/go/genkit"
)
// OllamaProvider wraps Genkit's Ollama provider
type OllamaProvider struct {
g *genkit.Genkit
baseURL string
timeout time.Duration
}
func NewOllamaProvider(baseURL string, timeout time.Duration) (*OllamaProvider, error) {
if baseURL == "" {
baseURL = "http://localhost:11434"
}
baseURL = strings.TrimRight(baseURL, "/")
if timeout <= 0 {
timeout = 60 * time.Second
}
return &OllamaProvider{
baseURL: baseURL,
timeout: timeout,
}, nil
}
func (p *OllamaProvider) Chat(ctx context.Context, req Request) (Result, error) {
return Result{}, fmt.Errorf("ollama provider not yet implemented")
}
func (p *OllamaProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
chunkChan := make(chan StreamChunk, 10)
errChan := make(chan error, 1)
go func() {
defer close(chunkChan)
defer close(errChan)
errChan <- fmt.Errorf("ollama streaming not yet implemented")
}()
return chunkChan, errChan
}
-54
@@ -1,54 +0,0 @@
package chat
import (
"context"
"fmt"
"strings"
"time"
"github.com/firebase/genkit/go/genkit"
)
// OpenAIProvider wraps Genkit's OpenAI provider
type OpenAIProvider struct {
g *genkit.Genkit
apiKey string
baseURL string
timeout time.Duration
}
func NewOpenAIProvider(apiKey, baseURL string, timeout time.Duration) (*OpenAIProvider, error) {
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
baseURL = strings.TrimRight(baseURL, "/")
if timeout <= 0 {
timeout = 30 * time.Second
}
// For now, we'll create a simple HTTP client-based implementation
// since Genkit Go plugins require initialization at startup
return &OpenAIProvider{
apiKey: apiKey,
baseURL: baseURL,
timeout: timeout,
}, nil
}
func (p *OpenAIProvider) Chat(ctx context.Context, req Request) (Result, error) {
// Use direct HTTP API call since Genkit plugins need to be initialized at startup
return Result{}, fmt.Errorf("openai provider not yet implemented - please use openai-compat provider")
}
func (p *OpenAIProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
chunkChan := make(chan StreamChunk, 10)
errChan := make(chan error, 1)
go func() {
defer close(chunkChan)
defer close(errChan)
errChan <- fmt.Errorf("openai streaming not yet implemented")
}()
return chunkChan, errChan
}
-141
@@ -1,141 +0,0 @@
package chat
import (
"fmt"
"strings"
"time"
)
// PromptParams contains parameters for generating system prompts
type PromptParams struct {
Date time.Time
Locale string
Language string
MaxContextLoadTime int // in minutes
Platforms []string // available platforms (e.g., ["telegram", "wechat"])
CurrentPlatform string // current platform the user is using
}
// SystemPrompt generates the system prompt for the AI assistant
// This is migrated from packages/agent/src/prompts/system.ts
func SystemPrompt(params PromptParams) string {
if params.Language == "" {
params.Language = "Same as user input"
}
if params.MaxContextLoadTime == 0 {
params.MaxContextLoadTime = 24 * 60 // 24 hours default
}
if params.CurrentPlatform == "" {
params.CurrentPlatform = "client"
}
// Build platforms list
platformsList := ""
if len(params.Platforms) > 0 {
lines := make([]string, len(params.Platforms))
for i, p := range params.Platforms {
lines[i] = fmt.Sprintf(" - %s", p)
}
platformsList = strings.Join(lines, "\n")
}
timeStr := FormatTime(params.Date, params.Locale)
return fmt.Sprintf(`---
%s
language: %s
available-platforms:
%s
current-platform: %s
---
You are a personal housekeeper assistant, able to manage the master's daily affairs.
Your abilities:
- Long memory: You possess long-term memory; conversations from the last %d minutes will be directly loaded into your context. Additionally, you can use tools to search for past memories.
- Scheduled tasks: You can create scheduled tasks to automatically remind you to do something.
- Messaging: You may be allowed to use messaging software to send messages to the master.
**Response Guidelines**
- Always respond in the language specified above, unless it says "Same as user input", then match the user's language.
- Be helpful, concise, and friendly.
- For complex questions, break down your answer into clear steps.
- If you're unsure about something, acknowledge it honestly.`,
timeStr,
params.Language,
platformsList,
params.CurrentPlatform,
params.MaxContextLoadTime,
)
}
// SchedulePrompt generates a prompt for scheduled task execution
// This is migrated from packages/agent/src/prompts/schedule.ts
type SchedulePromptParams struct {
Date time.Time
Locale string
ScheduleName string
ScheduleDescription string
ScheduleID string
MaxCalls *int // nil means unlimited
CronPattern string
Command string // the natural language command to execute
}
func SchedulePrompt(params SchedulePromptParams) string {
timeStr := FormatTime(params.Date, params.Locale)
maxCallsStr := "Unlimited"
if params.MaxCalls != nil {
maxCallsStr = fmt.Sprintf("%d", *params.MaxCalls)
}
return fmt.Sprintf(`---
notice: **This is a scheduled task automatically sent to you by the system, not user input**
%s
schedule-name: %s
schedule-description: %s
schedule-id: %s
max-calls: %s
cron-pattern: %s
---
**COMMAND**
%s`,
timeStr,
params.ScheduleName,
params.ScheduleDescription,
params.ScheduleID,
maxCallsStr,
params.CronPattern,
params.Command,
)
}
// FormatTime formats the date and time according to locale
func FormatTime(date time.Time, locale string) string {
if locale == "" {
locale = "en-US"
}
// Format date and time
// For simplicity, using standard format. In production, you might want to use
// a proper i18n library for locale-specific formatting
dateStr := date.Format("2006-01-02")
timeStr := date.Format("15:04:05")
return fmt.Sprintf("date: %s\ntime: %s", dateStr, timeStr)
}
// Quote wraps content in backticks for markdown code formatting
func Quote(content string) string {
return fmt.Sprintf("`%s`", content)
}
// Block wraps content in code block with optional language tag
func Block(content, tag string) string {
if tag == "" {
return fmt.Sprintf("```\n%s\n```", content)
}
return fmt.Sprintf("```%s\n%s\n```", tag, content)
}
+335 -174
@@ -1,8 +1,13 @@
package chat
import (
"bufio"
"bytes"
"context"
"errors"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
@@ -13,169 +18,359 @@ import (
"github.com/memohai/memoh/internal/models"
)
const (
ProviderOpenAI = "openai"
ProviderOpenAICompat = "openai-compat"
ProviderAnthropic = "anthropic"
ProviderGoogle = "google"
ProviderOllama = "ollama"
)
const defaultMaxContextMinutes = 24 * 60
// Provider interface for chat providers
type Provider interface {
Chat(ctx context.Context, req Request) (Result, error)
StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error)
}
// Resolver resolves chat models and providers
type Resolver struct {
modelsService *models.Service
queries *sqlc.Queries
timeout time.Duration
modelsService *models.Service
queries *sqlc.Queries
gatewayBaseURL string
timeout time.Duration
httpClient *http.Client
streamingClient *http.Client
}
// NewResolver creates a new chat resolver
func NewResolver(modelsService *models.Service, queries *sqlc.Queries, timeout time.Duration) *Resolver {
func NewResolver(modelsService *models.Service, queries *sqlc.Queries, gatewayBaseURL string, timeout time.Duration) *Resolver {
if strings.TrimSpace(gatewayBaseURL) == "" {
gatewayBaseURL = "http://127.0.0.1:8081"
}
gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/")
if timeout <= 0 {
timeout = 30 * time.Second
}
return &Resolver{
modelsService: modelsService,
queries: queries,
timeout: timeout,
modelsService: modelsService,
queries: queries,
gatewayBaseURL: gatewayBaseURL,
timeout: timeout,
httpClient: &http.Client{
Timeout: timeout,
},
streamingClient: &http.Client{},
}
}
// Chat performs a chat completion
func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) {
// Select model and provider
selected, provider, err := r.selectChatModel(ctx, req.Model, req.Provider)
if strings.TrimSpace(req.Query) == "" {
return ChatResponse{}, fmt.Errorf("query is required")
}
if strings.TrimSpace(req.UserID) == "" {
return ChatResponse{}, fmt.Errorf("user id is required")
}
chatModel, provider, err := r.selectChatModel(ctx, req)
if err != nil {
return ChatResponse{}, err
}
clientType, err := normalizeClientType(provider.ClientType)
if err != nil {
return ChatResponse{}, err
}
// Create internal request
internalReq := Request{
Messages: req.Messages,
Model: selected.ModelID,
Provider: strings.ToLower(provider.ClientType),
messages, err := r.loadHistoryMessages(ctx, req.UserID, req.MaxContextLoadTime)
if err != nil {
return ChatResponse{}, err
}
if len(req.Messages) > 0 {
messages = append(messages, req.Messages...)
}
// Add system prompt
if len(internalReq.Messages) > 0 && internalReq.Messages[0].Role != "system" {
systemPrompt := SystemPrompt(PromptParams{
Date: time.Now(),
Locale: "en-US",
Language: "Same as user input",
MaxContextLoadTime: 24 * 60, // 24 hours
Platforms: []string{},
CurrentPlatform: "api",
})
internalReq.Messages = append([]Message{
{Role: "system", Content: systemPrompt},
}, internalReq.Messages...)
payload := agentGatewayRequest{
APIKey: provider.ApiKey,
BaseURL: provider.BaseUrl,
Model: chatModel.ModelID,
ClientType: clientType,
Locale: req.Locale,
Language: req.Language,
MaxSteps: req.MaxSteps,
MaxContextLoadTime: normalizeMaxContextLoad(req.MaxContextLoadTime),
Platforms: req.Platforms,
CurrentPlatform: req.CurrentPlatform,
Messages: messages,
Query: req.Query,
}
// Create provider instance
providerInst, err := r.createProvider(provider)
resp, err := r.postChat(ctx, payload)
if err != nil {
return ChatResponse{}, err
}
// Execute chat
result, err := providerInst.Chat(ctx, internalReq)
if err != nil {
if err := r.storeHistory(ctx, req.UserID, req.Query, resp.Messages); err != nil {
return ChatResponse{}, err
}
return ChatResponse{
Message: result.Message,
Model: result.Model,
Provider: result.Provider,
FinishReason: result.FinishReason,
Usage: result.Usage,
Messages: resp.Messages,
Model: chatModel.ModelID,
Provider: provider.ClientType,
}, nil
}
// StreamChat performs a streaming chat completion
func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) {
chunkChan := make(chan StreamChunk, 10)
chunkChan := make(chan StreamChunk)
errChan := make(chan error, 1)
go func() {
defer close(chunkChan)
defer close(errChan)
// Select model and provider
selected, provider, err := r.selectChatModel(ctx, req.Model, req.Provider)
if strings.TrimSpace(req.Query) == "" {
errChan <- fmt.Errorf("query is required")
return
}
if strings.TrimSpace(req.UserID) == "" {
errChan <- fmt.Errorf("user id is required")
return
}
chatModel, provider, err := r.selectChatModel(ctx, req)
if err != nil {
errChan <- err
return
}
clientType, err := normalizeClientType(provider.ClientType)
if err != nil {
errChan <- err
return
}
// Create internal request
internalReq := Request{
Messages: req.Messages,
Model: selected.ModelID,
Provider: strings.ToLower(provider.ClientType),
}
// Add system prompt
if len(internalReq.Messages) > 0 && internalReq.Messages[0].Role != "system" {
systemPrompt := SystemPrompt(PromptParams{
Date: time.Now(),
Locale: "en-US",
Language: "Same as user input",
MaxContextLoadTime: 24 * 60, // 24 hours
Platforms: []string{},
CurrentPlatform: "api",
})
internalReq.Messages = append([]Message{
{Role: "system", Content: systemPrompt},
}, internalReq.Messages...)
}
// Create provider instance
providerInst, err := r.createProvider(provider)
messages, err := r.loadHistoryMessages(ctx, req.UserID, req.MaxContextLoadTime)
if err != nil {
errChan <- err
return
}
if len(req.Messages) > 0 {
messages = append(messages, req.Messages...)
}
// Execute streaming chat
providerChunkChan, providerErrChan := providerInst.StreamChat(ctx, internalReq)
payload := agentGatewayRequest{
APIKey: provider.ApiKey,
BaseURL: provider.BaseUrl,
Model: chatModel.ModelID,
ClientType: clientType,
Locale: req.Locale,
Language: req.Language,
MaxSteps: req.MaxSteps,
MaxContextLoadTime: normalizeMaxContextLoad(req.MaxContextLoadTime),
Platforms: req.Platforms,
CurrentPlatform: req.CurrentPlatform,
Messages: messages,
Query: req.Query,
}
// Forward chunks and errors
for {
select {
case chunk, ok := <-providerChunkChan:
if !ok {
return
}
chunkChan <- chunk
case err := <-providerErrChan:
if err != nil {
errChan <- err
}
return
}
if err := r.streamChat(ctx, payload, req.UserID, req.Query, chunkChan); err != nil {
errChan <- err
return
}
}()
return chunkChan, errChan
}
// selectChatModel selects a chat model based on the request
func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType string) (models.GetResponse, sqlc.LlmProvider, error) {
if r.modelsService == nil {
return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured")
type agentGatewayRequest struct {
APIKey string `json:"apiKey"`
BaseURL string `json:"baseUrl"`
Model string `json:"model"`
ClientType string `json:"clientType"`
Locale string `json:"locale,omitempty"`
Language string `json:"language,omitempty"`
MaxSteps int `json:"maxSteps,omitempty"`
MaxContextLoadTime int `json:"maxContextLoadTime"`
Platforms []string `json:"platforms,omitempty"`
CurrentPlatform string `json:"currentPlatform,omitempty"`
Messages []GatewayMessage `json:"messages"`
Query string `json:"query"`
}
type agentGatewayResponse struct {
Messages []GatewayMessage `json:"messages"`
}
func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest) (agentGatewayResponse, error) {
body, err := json.Marshal(payload)
if err != nil {
return agentGatewayResponse{}, err
}
url := r.gatewayBaseURL + "/chat"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return agentGatewayResponse{}, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := r.httpClient.Do(req)
if err != nil {
return agentGatewayResponse{}, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
payload, _ := io.ReadAll(resp.Body)
return agentGatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(payload)))
}
modelID = strings.TrimSpace(modelID)
providerType = strings.ToLower(strings.TrimSpace(providerType))
var parsed agentGatewayResponse
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
return agentGatewayResponse{}, err
}
return parsed, nil
}
// If no model specified, try to get default chat model
if modelID == "" && providerType == "" {
func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, userID, query string, chunkChan chan<- StreamChunk) error {
body, err := json.Marshal(payload)
if err != nil {
return err
}
url := r.gatewayBaseURL + "/chat/stream"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
resp, err := r.streamingClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
payload, _ := io.ReadAll(resp.Body)
return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(payload)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
continue
}
chunkChan <- StreamChunk([]byte(data))
var envelope struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
if err := json.Unmarshal([]byte(data), &envelope); err != nil {
continue
}
if envelope.Type != "done" || len(envelope.Data) == 0 {
continue
}
var parsed agentGatewayResponse
if err := json.Unmarshal(envelope.Data, &parsed); err != nil {
continue
}
if err := r.storeHistory(ctx, userID, query, parsed.Messages); err != nil {
return err
}
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}
func (r *Resolver) loadHistoryMessages(ctx context.Context, userID string, maxContextLoadTime int) ([]GatewayMessage, error) {
if r.queries == nil {
return nil, fmt.Errorf("history queries not configured")
}
pgUserID, err := parseUUID(userID)
if err != nil {
return nil, err
}
from := time.Now().UTC().Add(-time.Duration(normalizeMaxContextLoad(maxContextLoadTime)) * time.Minute)
rows, err := r.queries.ListHistoryByUserSince(ctx, sqlc.ListHistoryByUserSinceParams{
User: pgUserID,
Timestamp: pgtype.Timestamptz{
Time: from,
Valid: true,
},
})
if err != nil {
return nil, err
}
messages := make([]GatewayMessage, 0, len(rows))
for _, row := range rows {
var batch []GatewayMessage
if len(row.Messages) == 0 {
continue
}
if err := json.Unmarshal(row.Messages, &batch); err != nil {
return nil, err
}
messages = append(messages, batch...)
}
return messages, nil
}
func (r *Resolver) storeHistory(ctx context.Context, userID, query string, responseMessages []GatewayMessage) error {
if r.queries == nil {
return fmt.Errorf("history queries not configured")
}
if strings.TrimSpace(userID) == "" {
return fmt.Errorf("user id is required")
}
if strings.TrimSpace(query) == "" && len(responseMessages) == 0 {
return nil
}
userMessage := GatewayMessage{
"role": "user",
"content": query,
}
messages := append([]GatewayMessage{userMessage}, responseMessages...)
payload, err := json.Marshal(messages)
if err != nil {
return err
}
pgUserID, err := parseUUID(userID)
if err != nil {
return err
}
_, err = r.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{
Messages: payload,
Timestamp: pgtype.Timestamptz{
Time: time.Now().UTC(),
Valid: true,
},
User: pgUserID,
})
return err
}
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest) (models.GetResponse, sqlc.LlmProvider, error) {
if r.modelsService == nil {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
}
modelID := strings.TrimSpace(req.Model)
providerFilter := strings.TrimSpace(req.Provider)
if modelID != "" && providerFilter == "" {
model, err := r.modelsService.GetByModelID(ctx, modelID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
if model.Type != models.ModelTypeChat {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
}
provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return model, provider, nil
}
if providerFilter == "" && modelID == "" {
defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsChat)
if err == nil {
provider, err := r.fetchProvider(ctx, defaultModel.LlmProviderID)
provider, err := models.FetchProviderByID(ctx, r.queries, defaultModel.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
@@ -183,11 +378,10 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st
}
}
// List available models
var candidates []models.GetResponse
var err error
if providerType != "" {
candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerType))
if providerFilter != "" {
candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter))
} else {
candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeChat)
}
@@ -195,7 +389,6 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
// Filter chat models
filtered := make([]models.GetResponse, 0, len(candidates))
for _, model := range candidates {
if model.Type != models.ModelTypeChat {
@@ -204,92 +397,60 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st
filtered = append(filtered, model)
}
if len(filtered) == 0 {
return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("no chat models available")
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available")
}
// If model specified, find it
if modelID != "" {
for _, model := range filtered {
if model.ModelID == modelID {
provider, err := r.fetchProvider(ctx, model.LlmProviderID)
provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return model, provider, nil
}
}
return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("chat model not found")
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found")
}
// Return first available model
selected := filtered[0]
provider, err := r.fetchProvider(ctx, selected.LlmProviderID)
provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return selected, provider, nil
}
// fetchProvider fetches provider information from database
func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.LlmProvider, error) {
if r.queries == nil {
return sqlc.LlmProvider{}, errors.New("llm provider queries not configured")
func normalizeMaxContextLoad(value int) int {
if value <= 0 {
return defaultMaxContextMinutes
}
if strings.TrimSpace(providerID) == "" {
return sqlc.LlmProvider{}, errors.New("llm provider id missing")
}
parsed, err := uuid.Parse(providerID)
if err != nil {
return sqlc.LlmProvider{}, err
}
pgID := pgtype.UUID{Valid: true}
copy(pgID.Bytes[:], parsed[:])
return r.queries.GetLlmProviderByID(ctx, pgID)
return value
}
// createProvider creates a provider instance based on configuration
func (r *Resolver) createProvider(provider sqlc.LlmProvider) (Provider, error) {
clientType := strings.ToLower(strings.TrimSpace(provider.ClientType))
timeout := r.timeout
if timeout <= 0 {
timeout = 30 * time.Second
}
switch clientType {
case ProviderOpenAI, ProviderOpenAICompat:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, errors.New("openai api key is required")
}
p, err := NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout)
if err != nil {
return nil, err
}
return p, nil
case ProviderAnthropic:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, errors.New("anthropic api key is required")
}
p, err := NewAnthropicProvider(provider.ApiKey, timeout)
if err != nil {
return nil, err
}
return p, nil
case ProviderGoogle:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, errors.New("google api key is required")
}
p, err := NewGoogleProvider(provider.ApiKey, timeout)
if err != nil {
return nil, err
}
return p, nil
case ProviderOllama:
p, err := NewOllamaProvider(provider.BaseUrl, timeout)
if err != nil {
return nil, err
}
return p, nil
func normalizeClientType(clientType string) (string, error) {
switch strings.ToLower(strings.TrimSpace(clientType)) {
case "openai":
return "openai", nil
case "openai-compat":
return "openai", nil
case "anthropic":
return "anthropic", nil
case "google":
return "google", nil
default:
return nil, errors.New("unsupported provider type: " + clientType)
return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType)
}
}
func parseUUID(id string) (pgtype.UUID, error) {
parsed, err := uuid.Parse(id)
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err)
}
var pgID pgtype.UUID
pgID.Valid = true
copy(pgID.Bytes[:], parsed[:])
return pgID, nil
}
+20 -53
@@ -1,65 +1,32 @@
package chat
// Message represents a chat message
import "encoding/json"
type Message struct {
Role string `json:"role"` // user, assistant, system
Role string `json:"role"`
Content string `json:"content"`
}
// ChatRequest represents an incoming chat request
type GatewayMessage map[string]interface{}
type ChatRequest struct {
Messages []Message `json:"messages"`
Model string `json:"model,omitempty"` // optional: specific model to use
Provider string `json:"provider,omitempty"` // optional: specific provider to use
Stream bool `json:"stream,omitempty"`
UserID string `json:"-"`
Query string `json:"query"`
Model string `json:"model,omitempty"`
Provider string `json:"provider,omitempty"`
MaxContextLoadTime int `json:"max_context_load_time,omitempty"`
Locale string `json:"locale,omitempty"`
Language string `json:"language,omitempty"`
MaxSteps int `json:"max_steps,omitempty"`
Platforms []string `json:"platforms,omitempty"`
CurrentPlatform string `json:"current_platform,omitempty"`
Messages []GatewayMessage `json:"messages,omitempty"`
}
// ChatResponse represents a chat response
type ChatResponse struct {
Message Message `json:"message"`
Model string `json:"model"`
Provider string `json:"provider"`
FinishReason string `json:"finish_reason,omitempty"`
Usage Usage `json:"usage,omitempty"`
Messages []GatewayMessage `json:"messages"`
Model string `json:"model,omitempty"`
Provider string `json:"provider,omitempty"`
}
// StreamChunk represents a chunk in streaming response
type StreamChunk struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
Model string `json:"model,omitempty"`
Provider string `json:"provider,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
// Usage represents token usage information
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// Request is the internal request structure
type Request struct {
Messages []Message
Model string
Provider string
Temperature *float32 // optional temperature
ResponseFormat *ResponseFormat // optional response format
MaxTokens *int // optional max tokens
}
// ResponseFormat specifies the format of the response
type ResponseFormat struct {
Type string `json:"type"` // "json_object" or "text"
}
// Result is the internal result structure
type Result struct {
Message Message
Model string
Provider string
FinishReason string
Usage Usage
}
// StreamChunk represents a chunk in a streaming response.
type StreamChunk = json.RawMessage
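For reference, a request body matching the `ChatRequest` shape above; every value is made up, and `UserID` never appears on the wire (it is filled in from the JWT, hence the `json:"-"` tag):
```json
{
  "query": "What did we decide about the gateway timeout?",
  "model": "gpt-4o-mini",
  "provider": "openai",
  "locale": "en-US",
  "max_steps": 4,
  "platforms": ["web", "cli"],
  "current_platform": "web"
}
```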
+23
View File
@@ -1,6 +1,7 @@
package config
import (
"fmt"
"os"
"github.com/BurntSushi/toml"
@@ -31,6 +32,7 @@ type Config struct {
MCP MCPConfig `toml:"mcp"`
Postgres PostgresConfig `toml:"postgres"`
Qdrant QdrantConfig `toml:"qdrant"`
AgentGateway AgentGatewayConfig `toml:"agent_gateway"`
}
type ServerConfig struct {
@@ -70,6 +72,23 @@ type QdrantConfig struct {
TimeoutSeconds int `toml:"timeout_seconds"`
}
type AgentGatewayConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
}
func (c AgentGatewayConfig) BaseURL() string {
host := c.Host
if host == "" {
host = "127.0.0.1"
}
port := c.Port
if port == 0 {
port = 8081
}
return "http://" + host + ":" + fmt.Sprint(port)
}
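For reference, the matching config.toml stanza; the values shown are the built-in defaults, so omitting the section entirely yields the same `BaseURL` of `http://127.0.0.1:8081`:
```toml
[agent_gateway]
host = "127.0.0.1"
port = 8081
```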
func Load(path string) (Config, error) {
cfg := Config{
Server: ServerConfig{
@@ -98,6 +117,10 @@ func Load(path string) (Config, error) {
BaseURL: DefaultQdrantURL,
Collection: DefaultQdrantCollection,
},
AgentGateway: AgentGatewayConfig{
Host: "127.0.0.1",
Port: 8081,
},
}
if path == "" {
+73
View File
@@ -0,0 +1,73 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: history.sql
package sqlc
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const createHistory = `-- name: CreateHistory :one
INSERT INTO history (messages, timestamp, "user")
VALUES ($1, $2, $3)
RETURNING id, messages, timestamp, "user"
`
type CreateHistoryParams struct {
Messages []byte `json:"messages"`
Timestamp pgtype.Timestamptz `json:"timestamp"`
User pgtype.UUID `json:"user"`
}
func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (History, error) {
row := q.db.QueryRow(ctx, createHistory, arg.Messages, arg.Timestamp, arg.User)
var i History
err := row.Scan(
&i.ID,
&i.Messages,
&i.Timestamp,
&i.User,
)
return i, err
}
const listHistoryByUserSince = `-- name: ListHistoryByUserSince :many
SELECT id, messages, timestamp, "user"
FROM history
WHERE "user" = $1 AND timestamp >= $2
ORDER BY timestamp ASC
`
type ListHistoryByUserSinceParams struct {
User pgtype.UUID `json:"user"`
Timestamp pgtype.Timestamptz `json:"timestamp"`
}
func (q *Queries) ListHistoryByUserSince(ctx context.Context, arg ListHistoryByUserSinceParams) ([]History, error) {
rows, err := q.db.Query(ctx, listHistoryByUserSince, arg.User, arg.Timestamp)
if err != nil {
return nil, err
}
defer rows.Close()
var items []History
for rows.Next() {
var i History
if err := rows.Scan(
&i.ID,
&i.Messages,
&i.Timestamp,
&i.User,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
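A hedged sketch of how the generated queries might be called; the function name, the `queries` handle, and the 24-hour window are illustrative, not part of the generated code:
```go
// saveAndReplay persists one chat turn, then reads back the last day of history.
func saveAndReplay(ctx context.Context, queries *sqlc.Queries, user pgtype.UUID, payload []byte) ([]sqlc.History, error) {
	now := pgtype.Timestamptz{Time: time.Now(), Valid: true}
	if _, err := queries.CreateHistory(ctx, sqlc.CreateHistoryParams{
		Messages:  payload, // JSON-encoded []GatewayMessage
		Timestamp: now,
		User:      user,
	}); err != nil {
		return nil, err
	}
	since := pgtype.Timestamptz{Time: time.Now().Add(-24 * time.Hour), Valid: true}
	return queries.ListHistoryByUserSince(ctx, sqlc.ListHistoryByUserSinceParams{
		User:      user,
		Timestamp: since,
	})
}
```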
+7
View File
@@ -33,6 +33,13 @@ type ContainerVersion struct {
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type History struct {
ID pgtype.UUID `json:"id"`
Messages []byte `json:"messages"`
Timestamp pgtype.Timestamptz `json:"timestamp"`
User pgtype.UUID `json:"user"`
}
type LifecycleEvent struct {
ID string `json:"id"`
ContainerID string `json:"container_id"`
+30 -4
View File
@@ -7,7 +7,10 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/chat"
"github.com/memohai/memoh/internal/identity"
)
type ChatHandler struct {
@@ -36,14 +39,20 @@ func (h *ChatHandler) Register(e *echo.Echo) {
// @Failure 500 {object} ErrorResponse
// @Router /chat [post]
func (h *ChatHandler) Chat(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
var req chat.ChatRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
	if req.Query == "" {
		return echo.NewHTTPError(http.StatusBadRequest, "query is required")
	}
req.UserID = userID
resp, err := h.resolver.Chat(c.Request().Context(), req)
if err != nil {
@@ -65,14 +74,20 @@ func (h *ChatHandler) Chat(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /chat/stream [post]
func (h *ChatHandler) StreamChat(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
var req chat.ChatRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
	if req.Query == "" {
		return echo.NewHTTPError(http.StatusBadRequest, "query is required")
	}
req.UserID = userID
// Set headers for SSE
c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
@@ -127,3 +142,14 @@ func (h *ChatHandler) StreamChat(c echo.Context) error {
}
}
}
func (h *ChatHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateUserID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return userID, nil
}
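Putting the handler together, a hypothetical exchange exercising the validation path (the token is a placeholder):
```http
POST /chat HTTP/1.1
Authorization: Bearer <token>
Content-Type: application/json

{"query": ""}
```
Echo serializes the resulting `HTTPError` as JSON, so the client sees a 400 Bad Request with:
```json
{"message": "query is required"}
```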
-129
View File
@@ -1,129 +0,0 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/memohai/memoh/internal/chat"
)
// ProviderLLMClient uses chat.Provider to make LLM calls for memory operations
type ProviderLLMClient struct {
provider chat.Provider
model string
}
// NewProviderLLMClient creates a new LLM client that uses chat.Provider
func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient {
if model == "" {
model = "gpt-4.1-nano-2025-04-14"
}
return &ProviderLLMClient{
provider: provider,
model: model,
}
}
// Extract extracts facts from messages using the provider
func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
if len(req.Messages) == 0 {
return ExtractResponse{}, fmt.Errorf("messages is required")
}
parsedMessages := parseMessages(formatMessages(req.Messages))
systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages)
// Call provider with JSON mode
temp := float32(0)
result, err := c.provider.Chat(ctx, chat.Request{
Model: c.model,
Temperature: &temp,
ResponseFormat: &chat.ResponseFormat{
Type: "json_object",
},
Messages: []chat.Message{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
},
})
if err != nil {
return ExtractResponse{}, err
}
content := result.Message.Content
var parsed ExtractResponse
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
return ExtractResponse{}, err
}
return parsed, nil
}
// Decide decides what actions to take based on facts and existing memories
func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) {
if len(req.Facts) == 0 {
return DecideResponse{}, fmt.Errorf("facts is required")
}
retrieved := make([]map[string]string, 0, len(req.Candidates))
for _, candidate := range req.Candidates {
retrieved = append(retrieved, map[string]string{
"id": candidate.ID,
"text": candidate.Memory,
})
}
prompt := getUpdateMemoryMessages(retrieved, req.Facts)
// Call provider with JSON mode
temp := float32(0)
result, err := c.provider.Chat(ctx, chat.Request{
Model: c.model,
Temperature: &temp,
ResponseFormat: &chat.ResponseFormat{
Type: "json_object",
},
Messages: []chat.Message{
{Role: "user", Content: prompt},
},
})
if err != nil {
return DecideResponse{}, err
}
content := result.Message.Content
var raw map[string]interface{}
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &raw); err != nil {
return DecideResponse{}, err
}
memoryItems := normalizeMemoryItems(raw["memory"])
actions := make([]DecisionAction, 0, len(memoryItems))
for _, item := range memoryItems {
event := strings.ToUpper(asString(item["event"]))
if event == "" {
event = "ADD"
}
if event == "NONE" {
continue
}
text := asString(item["text"])
if text == "" {
text = asString(item["fact"])
}
if strings.TrimSpace(text) == "" {
continue
}
actions = append(actions, DecisionAction{
Event: event,
ID: normalizeID(item["id"]),
Text: text,
OldMemory: asString(item["old_memory"]),
})
}
return DecideResponse{Actions: actions}, nil
}
+4 -4
View File
@@ -15,7 +15,7 @@ type Server struct {
addr string
}
func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, containerdHandler *handlers.ContainerdHandler) *Server {
if addr == "" {
addr = ":8080"
}
@@ -50,12 +50,12 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
if embeddingsHandler != nil {
embeddingsHandler.Register(e)
}
	if chatHandler != nil {
		chatHandler.Register(e)
	}
	if swaggerHandler != nil {
		swaggerHandler.Register(e)
	}
if providersHandler != nil {
providersHandler.Register(e)
}