Mirror of https://github.com/memohai/Memoh.git (synced 2026-04-25 07:00:48 +09:00)
feat: chat api
+3 -1
@@ -95,4 +95,6 @@ docs/docs/.vitepress/cache
dump.rdb
memory.db

config.toml

.workdocs/
@@ -1,477 +0,0 @@

# Providers Module Documentation

## Overview

This document describes the standalone Providers module, which manages LLM provider configurations.

## Architecture

### Module Structure

```
internal/
├── providers/          # standalone provider module
│   ├── types.go        # type definitions
│   └── service.go      # business logic
└── handlers/
    └── providers.go    # API handlers
```

### Layered Design

```
┌─────────────────────┐
│     API Layer       │
│    (handlers)       │
└──────────┬──────────┘
           │
┌──────────▼──────────┐
│   Service Layer     │
│  (providers pkg)    │
└──────────┬──────────┘
           │
┌──────────▼──────────┐
│    Data Layer       │
│  (sqlc queries)     │
└─────────────────────┘
```

## API Endpoints

### Providers API (`/providers`)

All endpoints require JWT authentication.

#### 1. Create a Provider

```http
POST /providers
Content-Type: application/json
Authorization: Bearer <token>

{
  "name": "OpenAI Official",
  "client_type": "openai",
  "base_url": "https://api.openai.com/v1",
  "api_key": "sk-...",
  "metadata": {
    "description": "Official OpenAI API"
  }
}
```

**Response** (201 Created):
```json
{
  "id": "550e8400-e29b-41d4-a716-446655440000",
  "name": "OpenAI Official",
  "client_type": "openai",
  "base_url": "https://api.openai.com/v1",
  "api_key": "sk-12345***",
  "metadata": {
    "description": "Official OpenAI API"
  },
  "created_at": "2024-01-01T00:00:00Z",
  "updated_at": "2024-01-01T00:00:00Z"
}
```

#### 2. List All Providers

```http
GET /providers
Authorization: Bearer <token>
```

**Optional query parameters**:
- `client_type` - filter by client type (openai, anthropic, google, ollama)

**Response** (200 OK):
```json
[
  {
    "id": "550e8400-e29b-41d4-a716-446655440000",
    "name": "OpenAI Official",
    "client_type": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": "sk-12345***",
    "created_at": "2024-01-01T00:00:00Z",
    "updated_at": "2024-01-01T00:00:00Z"
  }
]
```

#### 3. Get a Single Provider

```http
GET /providers/{id}
Authorization: Bearer <token>
```

Or fetch by name:

```http
GET /providers/name/{name}
Authorization: Bearer <token>
```

**Response** (200 OK): same shape as the create response

#### 4. Update a Provider

```http
PUT /providers/{id}
Content-Type: application/json
Authorization: Bearer <token>

{
  "name": "OpenAI Updated",
  "api_key": "sk-newkey..."
}
```

**Note**: all fields are optional; only the fields provided are updated.

**Response** (200 OK): returns the updated provider

#### 5. Delete a Provider

```http
DELETE /providers/{id}
Authorization: Bearer <token>
```

**Response** (204 No Content)

#### 6. Count Providers

```http
GET /providers/count
Authorization: Bearer <token>
```

**Optional query parameters**:
- `client_type` - filter by client type

**Response** (200 OK):
```json
{
  "count": 5
}
```

## Supported Client Types

| Client Type | Description | API Key Required |
|------------|-------------|------------------|
| `openai` | Official OpenAI API | ✅ |
| `openai-compat` | OpenAI-compatible APIs | ✅ |
| `anthropic` | Anthropic Claude API | ✅ |
| `google` | Google Gemini API | ✅ |
| `ollama` | Local Ollama | ❌ |

## Data Model

### Provider Structures

```go
type CreateRequest struct {
    Name       string                 `json:"name"`        // required
    ClientType ClientType             `json:"client_type"` // required
    BaseURL    string                 `json:"base_url"`    // required
    APIKey     string                 `json:"api_key"`     // optional
    Metadata   map[string]interface{} `json:"metadata"`    // optional
}

type GetResponse struct {
    ID         string                 `json:"id"`
    Name       string                 `json:"name"`
    ClientType string                 `json:"client_type"`
    BaseURL    string                 `json:"base_url"`
    APIKey     string                 `json:"api_key"` // masked
    Metadata   map[string]interface{} `json:"metadata"`
    CreatedAt  time.Time              `json:"created_at"`
    UpdatedAt  time.Time              `json:"updated_at"`
}
```

## Security Features

### 1. API Key Masking

In responses, the API key is automatically masked:
- only the first 8 characters are shown
- the remainder is replaced with `*`
- for example: `sk-12345678***`
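A minimal sketch of that masking rule (the helper name is an assumption for illustration, not the module's actual code):

```go
// maskAPIKey keeps the first 8 characters of a key and hides the rest,
// matching the masking rule described above. Hypothetical helper.
func maskAPIKey(key string) string {
	if len(key) <= 8 {
		return "***" // too short to expose any prefix safely
	}
	return key[:8] + "***"
}
```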

### 2. Authentication

All API endpoints require JWT authentication:
```http
Authorization: Bearer <your-jwt-token>
```

### 3. Input Validation

- UUID format is validated automatically
- `client_type` is checked against the supported set
- required fields are validated
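A sketch of what the `client_type` check could look like, based on the supported set above (the function name and placement are assumptions):

```go
// isSupportedClientType reports whether ct names one of the client
// types listed in this document. Hypothetical validation helper.
func isSupportedClientType(ct string) bool {
	switch ct {
	case "openai", "openai-compat", "anthropic", "google", "ollama":
		return true
	}
	return false
}
```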

## Usage Examples

### 1. Configure an OpenAI Provider

```bash
curl -X POST http://localhost:8080/providers \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{
    "name": "OpenAI GPT-4",
    "client_type": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": "sk-..."
  }'
```

### 2. Configure a Custom OpenAI-Compatible Service

```bash
curl -X POST http://localhost:8080/providers \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{
    "name": "Azure OpenAI",
    "client_type": "openai-compat",
    "base_url": "https://your-resource.openai.azure.com/v1",
    "api_key": "your-azure-key",
    "metadata": {
      "deployment": "gpt-4",
      "region": "eastus"
    }
  }'
```

### 3. Configure Local Ollama

```bash
curl -X POST http://localhost:8080/providers \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{
    "name": "Local Ollama",
    "client_type": "ollama",
    "base_url": "http://localhost:11434"
  }'
```

### 4. List All OpenAI Providers

```bash
curl http://localhost:8080/providers?client_type=openai \
  -H "Authorization: Bearer $TOKEN"
```

### 5. Update a Provider's API Key

```bash
curl -X PUT http://localhost:8080/providers/{id} \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{
    "api_key": "sk-new-key..."
  }'
```

## Relationship to Models

Provider and Model form a one-to-many relationship:

```
┌─────────────┐
│  Provider   │
│  (OpenAI)   │
└──────┬──────┘
       │
       ├─── Model (gpt-4)
       ├─── Model (gpt-3.5-turbo)
       └─── Model (text-embedding-ada-002)
```

### Referencing a Provider when Creating a Model

```bash
# 1. Create the provider
PROVIDER_ID=$(curl -X POST http://localhost:8080/providers \
  -H "Authorization: Bearer $TOKEN" \
  -d '{"name":"OpenAI","client_type":"openai",...}' \
  | jq -r '.id')

# 2. Create the model, referencing the provider
curl -X POST http://localhost:8080/models \
  -H "Authorization: Bearer $TOKEN" \
  -d '{
    "model_id": "gpt-4",
    "name": "GPT-4",
    "llm_provider_id": "'$PROVIDER_ID'",
    "type": "chat"
  }'
```

## Code Integration

### Using the Provider Service in Code

```go
import "github.com/memohai/memoh/internal/providers"

// Create the service
providersService := providers.NewService(queries)

// Create a provider
provider, err := providersService.Create(ctx, providers.CreateRequest{
    Name:       "OpenAI",
    ClientType: providers.ClientTypeOpenAI,
    BaseURL:    "https://api.openai.com/v1",
    APIKey:     "sk-...",
})

// List all providers
allProviders, err := providersService.List(ctx)

// Filter by type
openaiProviders, err := providersService.ListByClientType(ctx, providers.ClientTypeOpenAI)

// Get a single provider
provider, err := providersService.Get(ctx, "provider-uuid")

// Update a provider (stringPtr is a small helper returning a pointer to its argument)
updated, err := providersService.Update(ctx, "provider-uuid", providers.UpdateRequest{
    APIKey: stringPtr("new-key"),
})

// Delete a provider
err := providersService.Delete(ctx, "provider-uuid")
```

## Error Handling

### Common Errors

| Status Code | Error | Cause |
|------------|-------|-------|
| 400 | Bad Request | missing required fields or malformed input |
| 404 | Not Found | the provider ID does not exist |
| 409 | Conflict | the provider name already exists |
| 500 | Internal Server Error | server-side error |

### Error Response Format

```json
{
  "message": "invalid UUID: invalid UUID length: 5"
}
```

## Database Schema

Providers are stored in the `llm_providers` table:

```sql
CREATE TABLE llm_providers (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    name TEXT NOT NULL,
    client_type TEXT NOT NULL,
    base_url TEXT NOT NULL,
    api_key TEXT NOT NULL,
    metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    CONSTRAINT llm_providers_name_unique UNIQUE (name),
    CONSTRAINT llm_providers_client_type_check
        CHECK (client_type IN ('openai', 'openai-compat', 'anthropic', 'google', 'ollama'))
);
```

## Best Practices

### 1. Naming Conventions

- Use descriptive names: `OpenAI GPT-4`, `Anthropic Claude 3`
- Include environment information: `OpenAI Production`, `Ollama Local`
- Avoid special characters and spaces

### 2. API Key Management

- Rotate API keys regularly
- Store sensitive values in environment variables
- Never log a complete API key

### 3. Using Metadata

Use the `metadata` field to store extra information:

```json
{
  "metadata": {
    "environment": "production",
    "rate_limit": 10000,
    "contact": "admin@example.com",
    "notes": "Primary provider for production"
  }
}
```

### 4. Provider Testing

After creating a provider, it is a good idea to test the connection:

```http
# A test endpoint that could be implemented in the future
POST /providers/test
{
  "client_type": "openai",
  "base_url": "https://api.openai.com/v1",
  "api_key": "sk-...",
  "model": "gpt-3.5-turbo"
}
```

## Migration Guide

### From the Config File to the Database

**Old way** (config.toml):
```toml
[memory]
base_url = "https://api.openai.com/v1"
api_key = "sk-..."
model = "gpt-4"
```

**New way** (API):
```bash
# 1. Create the provider
curl -X POST http://localhost:8080/providers \
  -d '{"name":"OpenAI","client_type":"openai","base_url":"https://api.openai.com/v1","api_key":"sk-..."}'

# 2. Create the model
curl -X POST http://localhost:8080/models \
  -d '{"model_id":"gpt-4","llm_provider_id":"<provider-id>","type":"chat","enable_as":"memory"}'
```

## Related Documentation

- [Provider Refactoring Summary](./REFACTORING_SUMMARY.md)
- [Agent Migration Guide](./AGENT_MIGRATION.md)
- [Chat Architecture](./internal/chat/ARCHITECTURE.md)
- [Models API Documentation](#) (to be written)

## TODO

- [ ] Implement a provider connection-test endpoint
- [ ] Add provider usage statistics
- [ ] Implement provider health checks
- [ ] Add provider cost tracking
- [ ] Support provider load balancing
- [ ] Implement provider failover
@@ -1,273 +0,0 @@

# Provider Refactoring Summary

## Goal

Have the Memory service and the Chat service share the Provider interface in `internal/chat`, unifying the architecture and reusing code.

## Main Changes

### 1. Extended the Chat Provider Interface

**File**: `internal/chat/types.go`

- Extended the `Request` struct with:
  - `Temperature *float32` - sampling temperature
  - `ResponseFormat *ResponseFormat` - response format (supports JSON mode)
  - `MaxTokens *int` - maximum number of tokens

- Added a `ResponseFormat` struct:
```go
type ResponseFormat struct {
    Type string // "json_object" or "text"
}
```

### 2. Created the Memory LLM Interface

**File**: `internal/memory/types.go`

- Defined the `LLM` interface:
```go
type LLM interface {
    Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
    Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
}
```

- The interface has two implementations:
  - `LLMClient` (the old implementation, kept for backward compatibility)
  - `ProviderLLMClient` (the new implementation, built on chat.Provider)

### 3. Created a Provider-Based LLM Client

**File**: `internal/memory/llm_provider_client.go` (new file)

- Implements the `ProviderLLMClient` struct
- Uses a `chat.Provider` to perform LLM calls
- Supports JSON-mode output to guarantee structured responses
- Reuses the helper functions in the `internal/memory` package

Key code:
```go
type ProviderLLMClient struct {
    provider chat.Provider
    model    string
}

func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient
func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
```
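To make the JSON-mode requirement concrete, here is a hedged sketch of how `Extract` could drive the provider; the prompt text, response shape, and method name are placeholders, not the actual implementation (assumes `encoding/json` is imported):

```go
// extractSketch shows the JSON-mode call pattern: zero temperature and
// a "json_object" response format so the reply can be unmarshalled.
func (c *ProviderLLMClient) extractSketch(ctx context.Context, prompt string) ([]string, error) {
	temp := float32(0)
	res, err := c.provider.Chat(ctx, chat.Request{
		Model:          c.model,
		Temperature:    &temp,
		ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
		Messages: []chat.Message{
			{Role: "system", Content: "Extract facts and reply as JSON."},
			{Role: "user", Content: prompt},
		},
	})
	if err != nil {
		return nil, err
	}
	var out struct {
		Facts []string `json:"facts"` // assumed response shape
	}
	if err := json.Unmarshal([]byte(res.Message.Content), &out); err != nil {
		return nil, err
	}
	return out.Facts, nil
}
```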

### 4. Updated the Memory Service

**File**: `internal/memory/service.go`

- Changed `llm *LLMClient` to `llm LLM`
- The service now accepts any type implementing the `LLM` interface
- All existing behavior is preserved

### 5. Updated the Main Program

**File**: `cmd/agent/main.go`

#### Functions added:

1. **selectMemoryModel** - picks the model used for memory operations
   - Priority: memory model → chat model → any chat-type model (see the sketch after the initialization flow below)

2. **fetchProviderByID** - fetches a provider configuration by ID

3. **createChatProvider** - creates a provider instance from a configuration
   - Supports OpenAI, Anthropic, Google, Ollama

#### Updated initialization flow:

```go
// 1. Initialize the chat resolver (used by both chat and memory)
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)

// 2. Try to create a provider-based client for memory
memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries, &cfg)
if err != nil {
    // Fall back to the old LLMClient
    llmClient = memory.NewLLMClient(cfg.Memory.BaseURL, cfg.Memory.APIKey, ...)
} else {
    // Use the new provider-based client
    provider, _ := createChatProvider(memoryProvider, 30*time.Second)
    llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)
}

// 3. Create the memory service
memoryService = memory.NewService(llmClient, embedder, store, resolver, ...)
```
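The selection priority could be implemented as a simple fallback chain; a hypothetical sketch, since the real `selectMemoryModel` signature and the model-listing helpers are not shown in this summary:

```go
// selectMemoryModelSketch walks the priority chain described above.
// listModelsByEnableAs and listModelsByType are assumed helpers.
func selectMemoryModelSketch(ctx context.Context) (models.GetResponse, error) {
	if m, err := listModelsByEnableAs(ctx, "memory"); err == nil && len(m) > 0 {
		return m[0], nil // 1. dedicated memory model
	}
	if m, err := listModelsByEnableAs(ctx, "chat"); err == nil && len(m) > 0 {
		return m[0], nil // 2. model explicitly enabled for chat
	}
	if m, err := listModelsByType(ctx, "chat"); err == nil && len(m) > 0 {
		return m[0], nil // 3. any chat-type model
	}
	return models.GetResponse{}, fmt.Errorf("no usable chat model configured")
}
```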

## Architectural Benefits

### 1. Unified Provider Management
- The Chat and Memory services share the same Provider implementations
- Less code duplication
- One place for configuration and management

### 2. Flexible Model Selection
- Different features can be configured with different models
- The `enable_as` field designates a model's purpose
- Automatic fallback

### 3. Backward Compatibility
- The old `LLMClient` implementation is kept
- If no model is configured in the database, the system falls back to the config file
- A smooth migration path

### 4. Type Safety
- Go interfaces instead of runtime type checks
- Compile-time type checking
- Better IDE support

### 5. Easy to Extend
- Adding a new provider only requires implementing the `Provider` interface
- Adding a new LLM client only requires implementing the `LLM` interface
- Modular design

## Configuration

### Database Configuration

Configure a model for memory operations:

```sql
-- Option 1: use a dedicated memory model
UPDATE models SET enable_as = 'memory'
WHERE model_id = 'gpt-4-turbo-preview';

-- Option 2: use a chat model (when there is no dedicated memory model)
UPDATE models SET enable_as = 'chat'
WHERE model_id = 'gpt-4';
```

### Environment Variables (Fallback Configuration)

If no model is configured in the database, the system uses these settings:

```toml
[memory]
base_url = "https://api.openai.com/v1"
api_key = "sk-..."
model = "gpt-4.1-nano"
timeout_seconds = 10
```

## Testing Suggestions

### 1. Memory Operations
```bash
# Test Extract (fact extraction)
curl -X POST http://localhost:8080/api/memory/add \
  -H "Authorization: Bearer $TOKEN" \
  -d '{
    "messages": [
      {"role": "user", "content": "My name is Alice and I like pizza"}
    ],
    "user_id": "user-123"
  }'

# Test Search (memory retrieval)
curl -X POST http://localhost:8080/api/memory/search \
  -H "Authorization: Bearer $TOKEN" \
  -d '{
    "query": "What food do I like?",
    "user_id": "user-123"
  }'
```

### 2. Chat Operations
```bash
# Test plain chat
curl -X POST http://localhost:8080/api/chat \
  -H "Authorization: Bearer $TOKEN" \
  -d '{
    "messages": [
      {"role": "user", "content": "Hello, how are you?"}
    ]
  }'
```

### 3. Log Verification
On startup you should see:
```
Using memory model: gpt-4-turbo-preview (provider: openai)
```

Or, when the fallback is taken:
```
WARNING: No memory model configured, using fallback LLMClient: ...
```

## Follow-Up Work

### Short Term (Required)
- [ ] Implement the per-provider logic (most providers currently return "not yet implemented")
- [ ] Add streaming response support
- [ ] Improve error handling

### Medium Term (Recommended)
- [ ] Add unit tests for the providers
- [ ] Add memory integration tests
- [ ] Implement a provider connection pool
- [ ] Add request retries

### Long Term (Optimization)
- [ ] Provider performance monitoring
- [ ] Automatic model selection and load balancing
- [ ] Caching of model responses
- [ ] Support more providers (e.g., Cohere, HuggingFace)

## Migration Checklist

- [x] Extend the Request struct to support JSON mode
- [x] Create the LLM interface
- [x] Implement ProviderLLMClient
- [x] Make the Memory Service use the interface
- [x] Update the main program's initialization flow
- [x] Add model-selection logic
- [x] Add provider-creation logic
- [x] Preserve backward compatibility
- [x] Add architecture documentation
- [ ] Add unit tests
- [ ] Add integration tests
- [ ] Update deployment docs

## File Inventory

### New Files
- `internal/memory/llm_provider_client.go` - provider-based LLM client
- `internal/chat/ARCHITECTURE.md` - architecture documentation
- `REFACTORING_SUMMARY.md` - this file

### Modified Files
- `internal/chat/types.go` - extended the Request struct
- `internal/memory/types.go` - added the LLM interface
- `internal/memory/service.go` - uses the LLM interface
- `cmd/agent/main.go` - updated the initialization flow

### Retained Files (Backward Compatibility)
- `internal/memory/llm_client.go` - the old HTTP client implementation

## Caveats

1. **JSON-mode compatibility**: not every model supports JSON mode; provider implementations need to handle this
2. **Error handling**: the current error handling is simplistic; production use needs richer error detail
3. **Timeouts**: different operations may need different timeouts; consider making them configurable
4. **Concurrency**: provider instances should be safe for concurrent use
5. **Resource cleanup**: make sure provider resources (e.g., HTTP connections) are released properly

## Troubleshooting

If you run into problems, check:
1. Whether a chat model is configured in the database
2. Whether the provider configuration is correct (API key, base URL)
3. The error messages in the logs
4. Whether the chat resolver was initialized correctly

For a detailed architecture description, see `internal/chat/ARCHITECTURE.md`.
+45 -14

@@ -2,6 +2,7 @@ package main

import (
	"context"
	"fmt"
	"log"
	"os"
	"strings"

@@ -71,22 +72,15 @@ func main() {
	authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn)

	// Initialize chat resolver for both chat and memory operations
	chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
	chatResolver := chat.NewResolver(modelsService, queries, cfg.AgentGateway.BaseURL(), 30*time.Second)

	// Create LLM client for memory operations using chat provider
	var llmClient memory.LLM
	memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, modelsService, queries)
	if err != nil {
		log.Fatalf("select memory model: %v\nPlease configure at least one chat model in the database.", err)
	// Create LLM client for memory operations (deferred model/provider selection).
	var llmClient memory.LLM = &lazyLLMClient{
		modelsService: modelsService,
		queries:       queries,
		timeout:       30 * time.Second,
	}

	log.Printf("Using memory model: %s (provider: %s)", memoryModel.ModelID, memoryProvider.ClientType)
	provider, err := chat.CreateProvider(memoryProvider, 30*time.Second)
	if err != nil {
		log.Fatalf("create memory provider: %v", err)
	}
	llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)

	resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
	vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
	if err != nil {
@@ -154,9 +148,46 @@ func main() {
	providersService := providers.NewService(queries)
	providersHandler := handlers.NewProvidersHandler(providersService)
	modelsHandler := handlers.NewModelsHandler(modelsService)
	srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler, containerdHandler)
	srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, containerdHandler)

	if err := srv.Start(); err != nil {
		log.Fatalf("server failed: %v", err)
	}
}

type lazyLLMClient struct {
	modelsService *models.Service
	queries       *dbsqlc.Queries
	timeout       time.Duration
}

func (c *lazyLLMClient) Extract(ctx context.Context, req memory.ExtractRequest) (memory.ExtractResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.ExtractResponse{}, err
	}
	return client.Extract(ctx, req)
}

func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (memory.DecideResponse, error) {
	client, err := c.resolve(ctx)
	if err != nil {
		return memory.DecideResponse{}, err
	}
	return client.Decide(ctx, req)
}

func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
	if c.modelsService == nil || c.queries == nil {
		return nil, fmt.Errorf("models service not configured")
	}
	memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
	if err != nil {
		return nil, err
	}
	clientType := strings.ToLower(strings.TrimSpace(memoryProvider.ClientType))
	if clientType != "openai" && clientType != "openai-compat" {
		return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
	}
	return memory.NewLLMClient(memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout), nil
}
+6 -1

@@ -33,4 +33,9 @@ sslmode = "disable"
base_url = "http://127.0.0.1:6334"
api_key = ""
collection = "memory"
timeout_seconds = 10

## Agent Gateway
[agent_gateway]
host = "127.0.0.1"
port = 8081
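The updated `chat.NewResolver` call in `cmd/agent/main.go` above reads `cfg.AgentGateway.BaseURL()`. Assuming it simply joins scheme, host, and port from this section, it might look like the following sketch (not the actual config code; assumes `fmt` is imported):

```go
// AgentGatewayConfig mirrors the [agent_gateway] section above.
type AgentGatewayConfig struct {
	Host string `toml:"host"`
	Port int    `toml:"port"`
}

// BaseURL is a hypothetical sketch that builds a plain http URL.
func (c AgentGatewayConfig) BaseURL() string {
	return fmt.Sprintf("http://%s:%d", c.Host, c.Port)
}
```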
@@ -130,3 +130,13 @@ CREATE TABLE IF NOT EXISTS lifecycle_events (

CREATE INDEX IF NOT EXISTS idx_lifecycle_events_container_id ON lifecycle_events(container_id);
CREATE INDEX IF NOT EXISTS idx_lifecycle_events_event_type ON lifecycle_events(event_type);

CREATE TABLE IF NOT EXISTS history (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    messages JSONB NOT NULL,
    timestamp TIMESTAMPTZ NOT NULL,
    "user" UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE
);

CREATE INDEX IF NOT EXISTS idx_history_user ON history("user");
CREATE INDEX IF NOT EXISTS idx_history_timestamp ON history(timestamp);

@@ -0,0 +1,11 @@
-- name: CreateHistory :one
INSERT INTO history (messages, timestamp, "user")
VALUES ($1, $2, $3)
RETURNING id, messages, timestamp, "user";

-- name: ListHistoryByUserSince :many
SELECT id, messages, timestamp, "user"
FROM history
WHERE "user" = $1 AND timestamp >= $2
ORDER BY timestamp ASC;
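sqlc turns the `-- name:` annotations above into typed Go methods; a hedged usage sketch (the generated params struct and field types are assumptions and depend on the configured sqlc driver):

```go
// loadRecentHistory reads one user's history window via the generated
// wrapper. ListHistoryByUserSinceParams and the pgtype mappings are
// assumed, not verified against the generated code.
func loadRecentHistory(ctx context.Context, q *dbsqlc.Queries, user pgtype.UUID, since time.Time) ([]dbsqlc.History, error) {
	return q.ListHistoryByUserSince(ctx, dbsqlc.ListHistoryByUserSinceParams{
		User:      user,
		Timestamp: pgtype.Timestamptz{Time: since, Valid: true},
	})
}
```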
+38 -59

@@ -135,7 +135,11 @@ const docTemplate = `{
                "200": {
                    "description": "OK",
                    "schema": {
                        "$ref": "#/definitions/chat.StreamChunk"
                        "type": "array",
                        "items": {
                            "type": "integer",
                            "format": "int32"
                        }
                    }
                },
                "400": {
@@ -1431,71 +1435,53 @@ const docTemplate = `{
        "chat.ChatRequest": {
            "type": "object",
            "properties": {
                "current_platform": {
                    "type": "string"
                },
                "language": {
                    "type": "string"
                },
                "locale": {
                    "type": "string"
                },
                "max_context_load_time": {
                    "type": "integer"
                },
                "max_steps": {
                    "type": "integer"
                },
                "messages": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/chat.Message"
                        "$ref": "#/definitions/chat.GatewayMessage"
                    }
                },
                "model": {
                    "description": "optional: specific model to use",
                    "type": "string"
                },
                "platforms": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "provider": {
                    "description": "optional: specific provider to use",
                    "type": "string"
                },
                "stream": {
                    "type": "boolean"
                "query": {
                    "type": "string"
                }
            }
        },
        "chat.ChatResponse": {
            "type": "object",
            "properties": {
                "finish_reason": {
                    "type": "string"
                },
                "message": {
                    "$ref": "#/definitions/chat.Message"
                },
                "model": {
                    "type": "string"
                },
                "provider": {
                    "type": "string"
                },
                "usage": {
                    "$ref": "#/definitions/chat.Usage"
                }
            }
        },
        "chat.Message": {
            "type": "object",
            "properties": {
                "content": {
                    "type": "string"
                },
                "role": {
                    "description": "user, assistant, system",
                    "type": "string"
                }
            }
        },
        "chat.StreamChunk": {
            "type": "object",
            "properties": {
                "delta": {
                    "type": "object",
                    "properties": {
                        "content": {
                            "type": "string"
                        }
                "messages": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/chat.GatewayMessage"
                    }
                },
                "finish_reason": {
                    "type": "string"
                },
                "model": {
                    "type": "string"
                },
@@ -1504,19 +1490,9 @@ const docTemplate = `{
            }
        },
        "chat.Usage": {
        "chat.GatewayMessage": {
            "type": "object",
            "properties": {
                "completion_tokens": {
                    "type": "integer"
                },
                "prompt_tokens": {
                    "type": "integer"
                },
                "total_tokens": {
                    "type": "integer"
                }
            }
            "additionalProperties": true
        },
        "handlers.CreateContainerRequest": {
            "type": "object",
@@ -1524,6 +1500,9 @@ const docTemplate = `{
                "container_id": {
                    "type": "string"
                },
                "image": {
                    "type": "string"
                },
                "snapshotter": {
                    "type": "string"
                }
+38 -59

@@ -124,7 +124,11 @@
                "200": {
                    "description": "OK",
                    "schema": {
                        "$ref": "#/definitions/chat.StreamChunk"
                        "type": "array",
                        "items": {
                            "type": "integer",
                            "format": "int32"
                        }
                    }
                },
                "400": {
@@ -1420,71 +1424,53 @@
        "chat.ChatRequest": {
            "type": "object",
            "properties": {
                "current_platform": {
                    "type": "string"
                },
                "language": {
                    "type": "string"
                },
                "locale": {
                    "type": "string"
                },
                "max_context_load_time": {
                    "type": "integer"
                },
                "max_steps": {
                    "type": "integer"
                },
                "messages": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/chat.Message"
                        "$ref": "#/definitions/chat.GatewayMessage"
                    }
                },
                "model": {
                    "description": "optional: specific model to use",
                    "type": "string"
                },
                "platforms": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "provider": {
                    "description": "optional: specific provider to use",
                    "type": "string"
                },
                "stream": {
                    "type": "boolean"
                "query": {
                    "type": "string"
                }
            }
        },
        "chat.ChatResponse": {
            "type": "object",
            "properties": {
                "finish_reason": {
                    "type": "string"
                },
                "message": {
                    "$ref": "#/definitions/chat.Message"
                },
                "model": {
                    "type": "string"
                },
                "provider": {
                    "type": "string"
                },
                "usage": {
                    "$ref": "#/definitions/chat.Usage"
                }
            }
        },
        "chat.Message": {
            "type": "object",
            "properties": {
                "content": {
                    "type": "string"
                },
                "role": {
                    "description": "user, assistant, system",
                    "type": "string"
                }
            }
        },
        "chat.StreamChunk": {
            "type": "object",
            "properties": {
                "delta": {
                    "type": "object",
                    "properties": {
                        "content": {
                            "type": "string"
                        }
                "messages": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/chat.GatewayMessage"
                    }
                },
                "finish_reason": {
                    "type": "string"
                },
                "model": {
                    "type": "string"
                },
@@ -1493,19 +1479,9 @@
            }
        },
        "chat.Usage": {
        "chat.GatewayMessage": {
            "type": "object",
            "properties": {
                "completion_tokens": {
                    "type": "integer"
                },
                "prompt_tokens": {
                    "type": "integer"
                },
                "total_tokens": {
                    "type": "integer"
                }
            }
            "additionalProperties": true
        },
        "handlers.CreateContainerRequest": {
            "type": "object",
@@ -1513,6 +1489,9 @@
                "container_id": {
                    "type": "string"
                },
                "image": {
                    "type": "string"
                },
                "snapshotter": {
                    "type": "string"
                }
+29 -42

@@ -1,67 +1,51 @@
definitions:
  chat.ChatRequest:
    properties:
      current_platform:
        type: string
      language:
        type: string
      locale:
        type: string
      max_context_load_time:
        type: integer
      max_steps:
        type: integer
      messages:
        items:
          $ref: '#/definitions/chat.Message'
          $ref: '#/definitions/chat.GatewayMessage'
        type: array
      model:
        description: 'optional: specific model to use'
        type: string
      platforms:
        items:
          type: string
        type: array
      provider:
        description: 'optional: specific provider to use'
        type: string
      stream:
        type: boolean
      query:
        type: string
    type: object
  chat.ChatResponse:
    properties:
      finish_reason:
        type: string
      message:
        $ref: '#/definitions/chat.Message'
      model:
        type: string
      provider:
        type: string
      usage:
        $ref: '#/definitions/chat.Usage'
    type: object
  chat.Message:
    properties:
      content:
        type: string
      role:
        description: user, assistant, system
        type: string
    type: object
  chat.StreamChunk:
    properties:
      delta:
        properties:
          content:
            type: string
        type: object
      finish_reason:
        type: string
      messages:
        items:
          $ref: '#/definitions/chat.GatewayMessage'
        type: array
      model:
        type: string
      provider:
        type: string
    type: object
  chat.Usage:
    properties:
      completion_tokens:
        type: integer
      prompt_tokens:
        type: integer
      total_tokens:
        type: integer
  chat.GatewayMessage:
    additionalProperties: true
    type: object
  handlers.CreateContainerRequest:
    properties:
      container_id:
        type: string
      image:
        type: string
      snapshotter:
        type: string
    type: object
@@ -553,7 +537,10 @@ paths:
        "200":
          description: OK
          schema:
            $ref: '#/definitions/chat.StreamChunk'
            items:
              format: int32
              type: integer
            type: array
        "400":
          description: Bad Request
          schema:
@@ -1,213 +0,0 @@

# Chat Provider Architecture

## Overview

This document describes the unified Chat Provider architecture in the Memoh project, shared by the chat service and the memory service.

## Architecture

### Core Interfaces

#### The Provider Interface

Every LLM provider implements the `chat.Provider` interface:

```go
type Provider interface {
    Chat(ctx context.Context, req Request) (Result, error)
    StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error)
}
```

#### The Request Struct

Provider requests support several configuration options:

```go
type Request struct {
    Messages       []Message
    Model          string
    Provider       string
    Temperature    *float32        // optional: sampling temperature
    ResponseFormat *ResponseFormat // optional: response format (JSON mode)
    MaxTokens      *int            // optional: maximum number of tokens
}

type ResponseFormat struct {
    Type string // "json_object" or "text"
}
```

### Supported Providers

1. **OpenAI** (`openai` / `openai-compat`)
   - the standard OpenAI API
   - custom endpoints compatible with the OpenAI format

2. **Anthropic** (`anthropic`)
   - the Claude model family

3. **Google** (`google`)
   - the Gemini model family

4. **Ollama** (`ollama`)
   - locally hosted open-source models

## Usage

### 1. The Chat Service

The chat service uses providers through `chat.Resolver`:

```go
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
response, err := chatResolver.Chat(ctx, ChatRequest{
    Messages: messages,
    Model:    "gpt-4",
})
```

### 2. The Memory Service

The memory service uses providers through `memory.ProviderLLMClient`:

```go
// Create the provider
provider, err := chat.NewOpenAIProvider(apiKey, baseURL, timeout)

// Create the memory LLM client
llmClient := memory.NewProviderLLMClient(provider, modelID)

// Use the memory service
memoryService := memory.NewService(llmClient, embedder, store, resolver, ...)
```

The memory service needs two core operations:
- **Extract**: extract facts from a conversation
- **Decide**: decide how to update memory (add/update/delete)

Both operations use JSON mode to guarantee structured output.
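For illustration, a JSON-mode request could be assembled like this; a sketch with placeholder prompt contents, not code from the repository:

```go
// buildMemoryRequestSketch issues a deterministic JSON-mode call, as
// the Extract/Decide operations described above require.
func buildMemoryRequestSketch(ctx context.Context, provider chat.Provider) (chat.Result, error) {
	temp := float32(0)
	return provider.Chat(ctx, chat.Request{
		Model:          "gpt-4", // hypothetical model id
		Temperature:    &temp,   // deterministic output
		ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
		Messages: []chat.Message{
			{Role: "system", Content: "Reply with a single JSON object."},
			{Role: "user", Content: "My name is Alice and I like pizza"},
		},
	})
}
```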

## Configuration Examples

### Database Configuration

Provider configuration lives in the `llm_providers` table:

```sql
CREATE TABLE llm_providers (
    id UUID PRIMARY KEY,
    name TEXT NOT NULL,
    client_type TEXT NOT NULL, -- 'openai', 'anthropic', 'google', 'ollama'
    base_url TEXT NOT NULL,
    api_key TEXT NOT NULL,
    metadata JSONB
);
```

Model configuration lives in the `models` table:

```sql
CREATE TABLE models (
    id UUID PRIMARY KEY,
    model_id TEXT NOT NULL,
    name TEXT,
    llm_provider_id UUID REFERENCES llm_providers(id),
    type TEXT NOT NULL, -- 'chat' or 'embedding'
    enable_as TEXT,     -- 'chat', 'memory', 'embedding'
    ...
);
```

### Startup Initialization

In `cmd/agent/main.go`:

```go
// 1. Initialize the chat resolver (used by both chat and memory)
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)

// 2. Select a model for memory and create its provider
memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries, cfg)
provider, err := createChatProvider(memoryProvider, 30*time.Second)

// 3. Create the memory LLM client
llmClient := memory.NewProviderLLMClient(provider, memoryModel.ModelID)

// 4. Create the memory service
memoryService := memory.NewService(llmClient, embedder, store, resolver, ...)
```

## Model Selection Strategy

### Memory Model Priority

1. Models with `enable_as = 'memory'` (dedicated memory models)
2. Models with `enable_as = 'chat'` (general chat models)
3. Any available chat-type model
4. Fall back to the config-file LLMClient (backward compatibility)

### Chat Model Priority

1. The model specified in the request
2. Models with `enable_as = 'chat'`
3. Any available chat-type model

## Benefits

1. **Unified architecture**: Chat and Memory share the same Provider interface
2. **Flexible configuration**: multiple providers and models are supported
3. **Backward compatible**: the old LLMClient is kept as a fallback
4. **Type safe**: Go interfaces enforce type safety
5. **Extensible**: adding a provider only requires implementing the Provider interface

## Adding a New Provider

To add a new LLM provider:

1. Create a new file in `internal/chat/` (e.g. `newprovider.go`)
2. Implement the `Provider` interface
3. Add a new case to `createProvider()` in `resolver.go`
4. Add the new type to the database's `llm_providers_client_type_check` constraint

Example:

```go
// newprovider.go
type NewProvider struct {
    apiKey  string
    timeout time.Duration
}

func NewNewProvider(apiKey string, timeout time.Duration) (*NewProvider, error) {
    return &NewProvider{apiKey: apiKey, timeout: timeout}, nil
}

func (p *NewProvider) Chat(ctx context.Context, req Request) (Result, error) {
    // implement the chat logic
}

func (p *NewProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
    // implement the streaming chat logic
}
```

## Migration Guide

Migrating from the old TypeScript backend to Go:

1. ✅ Create the Provider interface and implementations
2. ✅ Implement the Chat Resolver
3. ✅ Create the memory Provider adapter
4. ✅ Make the main program use the unified Provider
5. 🚧 Implement the per-provider logic (OpenAI, Anthropic, Google, Ollama)
6. 🚧 Add streaming response support
7. 🚧 Add complete error handling and retries

## Caveats

1. **JSON mode**: memory operations need `ResponseFormat.Type = "json_object"` to guarantee structured output
2. **Temperature**: memory operations use `Temperature = 0` for deterministic output
3. **Timeouts**: different operations may need different timeouts
4. **Error handling**: providers should return clear errors, including API error details
@@ -1,43 +0,0 @@
package chat

import (
	"context"
	"fmt"
	"time"

	"github.com/firebase/genkit/go/genkit"
)

// AnthropicProvider wraps Genkit's Anthropic provider
type AnthropicProvider struct {
	g       *genkit.Genkit
	apiKey  string
	timeout time.Duration
}

func NewAnthropicProvider(apiKey string, timeout time.Duration) (*AnthropicProvider, error) {
	if timeout <= 0 {
		timeout = 30 * time.Second
	}
	return &AnthropicProvider{
		apiKey:  apiKey,
		timeout: timeout,
	}, nil
}

func (p *AnthropicProvider) Chat(ctx context.Context, req Request) (Result, error) {
	return Result{}, fmt.Errorf("anthropic provider not yet implemented")
}

func (p *AnthropicProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
	chunkChan := make(chan StreamChunk, 10)
	errChan := make(chan error, 1)

	go func() {
		defer close(chunkChan)
		defer close(errChan)
		errChan <- fmt.Errorf("anthropic streaming not yet implemented")
	}()

	return chunkChan, errChan
}
@@ -1,2 +1 @@
package chat
@@ -1,39 +0,0 @@
package chat

import (
	"fmt"
	"strings"
	"time"

	dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
)

// CreateProvider creates a chat provider instance.
func CreateProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (Provider, error) {
	clientType := strings.ToLower(strings.TrimSpace(provider.ClientType))
	if timeout <= 0 {
		timeout = 30 * time.Second
	}

	switch clientType {
	case ProviderOpenAI, ProviderOpenAICompat:
		if strings.TrimSpace(provider.ApiKey) == "" {
			return nil, fmt.Errorf("openai api key is required")
		}
		return NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout)
	case ProviderAnthropic:
		if strings.TrimSpace(provider.ApiKey) == "" {
			return nil, fmt.Errorf("anthropic api key is required")
		}
		return NewAnthropicProvider(provider.ApiKey, timeout)
	case ProviderGoogle:
		if strings.TrimSpace(provider.ApiKey) == "" {
			return nil, fmt.Errorf("google api key is required")
		}
		return NewGoogleProvider(provider.ApiKey, timeout)
	case ProviderOllama:
		return NewOllamaProvider(provider.BaseUrl, timeout)
	default:
		return nil, fmt.Errorf("unsupported provider type: %s", clientType)
	}
}
@@ -1,43 +0,0 @@
package chat

import (
	"context"
	"fmt"
	"time"

	"github.com/firebase/genkit/go/genkit"
)

// GoogleProvider wraps Genkit's Google AI provider
type GoogleProvider struct {
	g       *genkit.Genkit
	apiKey  string
	timeout time.Duration
}

func NewGoogleProvider(apiKey string, timeout time.Duration) (*GoogleProvider, error) {
	if timeout <= 0 {
		timeout = 30 * time.Second
	}
	return &GoogleProvider{
		apiKey:  apiKey,
		timeout: timeout,
	}, nil
}

func (p *GoogleProvider) Chat(ctx context.Context, req Request) (Result, error) {
	return Result{}, fmt.Errorf("google provider not yet implemented")
}

func (p *GoogleProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
	chunkChan := make(chan StreamChunk, 10)
	errChan := make(chan error, 1)

	go func() {
		defer close(chunkChan)
		defer close(errChan)
		errChan <- fmt.Errorf("google streaming not yet implemented")
	}()

	return chunkChan, errChan
}
@@ -1,48 +0,0 @@
package chat

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/firebase/genkit/go/genkit"
)

// OllamaProvider wraps Genkit's Ollama provider
type OllamaProvider struct {
	g       *genkit.Genkit
	baseURL string
	timeout time.Duration
}

func NewOllamaProvider(baseURL string, timeout time.Duration) (*OllamaProvider, error) {
	if baseURL == "" {
		baseURL = "http://localhost:11434"
	}
	baseURL = strings.TrimRight(baseURL, "/")
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	return &OllamaProvider{
		baseURL: baseURL,
		timeout: timeout,
	}, nil
}

func (p *OllamaProvider) Chat(ctx context.Context, req Request) (Result, error) {
	return Result{}, fmt.Errorf("ollama provider not yet implemented")
}

func (p *OllamaProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
	chunkChan := make(chan StreamChunk, 10)
	errChan := make(chan error, 1)

	go func() {
		defer close(chunkChan)
		defer close(errChan)
		errChan <- fmt.Errorf("ollama streaming not yet implemented")
	}()

	return chunkChan, errChan
}
@@ -1,54 +0,0 @@
package chat

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/firebase/genkit/go/genkit"
)

// OpenAIProvider wraps Genkit's OpenAI provider
type OpenAIProvider struct {
	g       *genkit.Genkit
	apiKey  string
	baseURL string
	timeout time.Duration
}

func NewOpenAIProvider(apiKey, baseURL string, timeout time.Duration) (*OpenAIProvider, error) {
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}
	baseURL = strings.TrimRight(baseURL, "/")
	if timeout <= 0 {
		timeout = 30 * time.Second
	}

	// For now, we'll create a simple HTTP client-based implementation
	// since Genkit Go plugins require initialization at startup
	return &OpenAIProvider{
		apiKey:  apiKey,
		baseURL: baseURL,
		timeout: timeout,
	}, nil
}

func (p *OpenAIProvider) Chat(ctx context.Context, req Request) (Result, error) {
	// Use direct HTTP API call since Genkit plugins need to be initialized at startup
	return Result{}, fmt.Errorf("openai provider not yet implemented - please use openai-compat provider")
}

func (p *OpenAIProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
	chunkChan := make(chan StreamChunk, 10)
	errChan := make(chan error, 1)

	go func() {
		defer close(chunkChan)
		defer close(errChan)
		errChan <- fmt.Errorf("openai streaming not yet implemented")
	}()

	return chunkChan, errChan
}
@@ -1,141 +0,0 @@
package chat

import (
	"fmt"
	"strings"
	"time"
)

// PromptParams contains parameters for generating system prompts
type PromptParams struct {
	Date               time.Time
	Locale             string
	Language           string
	MaxContextLoadTime int      // in minutes
	Platforms          []string // available platforms (e.g., ["telegram", "wechat"])
	CurrentPlatform    string   // current platform the user is using
}

// SystemPrompt generates the system prompt for the AI assistant
// This is migrated from packages/agent/src/prompts/system.ts
func SystemPrompt(params PromptParams) string {
	if params.Language == "" {
		params.Language = "Same as user input"
	}
	if params.MaxContextLoadTime == 0 {
		params.MaxContextLoadTime = 24 * 60 // 24 hours default
	}
	if params.CurrentPlatform == "" {
		params.CurrentPlatform = "client"
	}

	// Build platforms list
	platformsList := ""
	if len(params.Platforms) > 0 {
		lines := make([]string, len(params.Platforms))
		for i, p := range params.Platforms {
			lines[i] = fmt.Sprintf(" - %s", p)
		}
		platformsList = strings.Join(lines, "\n")
	}

	timeStr := FormatTime(params.Date, params.Locale)

	return fmt.Sprintf(`---
%s
language: %s
available-platforms:
%s
current-platform: %s
---
You are a personal housekeeper assistant, which able to manage the master's daily affairs.

Your abilities:
- Long memory: You possess long-term memory; conversations from the last %d minutes will be directly loaded into your context. Additionally, you can use tools to search for past memories.
- Scheduled tasks: You can create scheduled tasks to automatically remind you to do something.
- Messaging: You may allowed to use message software to send messages to the master.

**Response Guidelines**
- Always respond in the language specified above, unless it says "Same as user input", then match the user's language.
- Be helpful, concise, and friendly.
- For complex questions, break down your answer into clear steps.
- If you're unsure about something, acknowledge it honestly.`,
		timeStr,
		params.Language,
		platformsList,
		params.CurrentPlatform,
		params.MaxContextLoadTime,
	)
}

// SchedulePrompt generates a prompt for scheduled task execution
// This is migrated from packages/agent/src/prompts/schedule.ts
type SchedulePromptParams struct {
	Date                time.Time
	Locale              string
	ScheduleName        string
	ScheduleDescription string
	ScheduleID          string
	MaxCalls            *int // nil means unlimited
	CronPattern         string
	Command             string // the natural language command to execute
}

func SchedulePrompt(params SchedulePromptParams) string {
	timeStr := FormatTime(params.Date, params.Locale)

	maxCallsStr := "Unlimited"
	if params.MaxCalls != nil {
		maxCallsStr = fmt.Sprintf("%d", *params.MaxCalls)
	}

	return fmt.Sprintf(`---
notice: **This is a scheduled task automatically send to you by the system, not the user input**
%s
schedule-name: %s
schedule-description: %s
schedule-id: %s
max-calls: %s
cron-pattern: %s
---

**COMMAND**

%s`,
		timeStr,
		params.ScheduleName,
		params.ScheduleDescription,
		params.ScheduleID,
		maxCallsStr,
		params.CronPattern,
		params.Command,
	)
}

// FormatTime formats the date and time according to locale
func FormatTime(date time.Time, locale string) string {
	if locale == "" {
		locale = "en-US"
	}

	// Format date and time
	// For simplicity, using standard format. In production, you might want to use
	// a proper i18n library for locale-specific formatting
	dateStr := date.Format("2006-01-02")
	timeStr := date.Format("15:04:05")

	return fmt.Sprintf("date: %s\ntime: %s", dateStr, timeStr)
}

// Quote wraps content in backticks for markdown code formatting
func Quote(content string) string {
	return fmt.Sprintf("`%s`", content)
}

// Block wraps content in code block with optional language tag
func Block(content, tag string) string {
	if tag == "" {
		return fmt.Sprintf("```\n%s\n```", content)
	}
	return fmt.Sprintf("```%s\n%s\n```", tag, content)
}
+335
-174
@@ -1,8 +1,13 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -13,169 +18,359 @@ import (
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderOpenAI = "openai"
|
||||
ProviderOpenAICompat = "openai-compat"
|
||||
ProviderAnthropic = "anthropic"
|
||||
ProviderGoogle = "google"
|
||||
ProviderOllama = "ollama"
|
||||
)
|
||||
const defaultMaxContextMinutes = 24 * 60
|
||||
|
||||
// Provider interface for chat providers
|
||||
type Provider interface {
|
||||
Chat(ctx context.Context, req Request) (Result, error)
|
||||
StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error)
|
||||
}
|
||||
|
||||
// Resolver resolves chat models and providers
|
||||
type Resolver struct {
|
||||
modelsService *models.Service
|
||||
queries *sqlc.Queries
|
||||
timeout time.Duration
|
||||
modelsService *models.Service
|
||||
queries *sqlc.Queries
|
||||
gatewayBaseURL string
|
||||
timeout time.Duration
|
||||
httpClient *http.Client
|
||||
streamingClient *http.Client
|
||||
}
|
||||
|
||||
// NewResolver creates a new chat resolver
|
||||
func NewResolver(modelsService *models.Service, queries *sqlc.Queries, timeout time.Duration) *Resolver {
|
||||
func NewResolver(modelsService *models.Service, queries *sqlc.Queries, gatewayBaseURL string, timeout time.Duration) *Resolver {
|
||||
if strings.TrimSpace(gatewayBaseURL) == "" {
|
||||
gatewayBaseURL = "http://127.0.0.1:8081"
|
||||
}
|
||||
gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/")
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
return &Resolver{
|
||||
modelsService: modelsService,
|
||||
queries: queries,
|
||||
timeout: timeout,
|
||||
modelsService: modelsService,
|
||||
queries: queries,
|
||||
gatewayBaseURL: gatewayBaseURL,
|
||||
timeout: timeout,
|
||||
httpClient: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
streamingClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// Chat performs a chat completion
|
||||
func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) {
|
||||
// Select model and provider
|
||||
selected, provider, err := r.selectChatModel(ctx, req.Model, req.Provider)
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
return ChatResponse{}, fmt.Errorf("query is required")
|
||||
}
|
||||
if strings.TrimSpace(req.UserID) == "" {
|
||||
return ChatResponse{}, fmt.Errorf("user id is required")
|
||||
}
|
||||
|
||||
chatModel, provider, err := r.selectChatModel(ctx, req)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
clientType, err := normalizeClientType(provider.ClientType)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
|
||||
// Create internal request
|
||||
internalReq := Request{
|
||||
Messages: req.Messages,
|
||||
Model: selected.ModelID,
|
||||
Provider: strings.ToLower(provider.ClientType),
|
||||
messages, err := r.loadHistoryMessages(ctx, req.UserID, req.MaxContextLoadTime)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
if len(req.Messages) > 0 {
|
||||
messages = append(messages, req.Messages...)
|
||||
}
|
||||
|
||||
// Add system prompt
|
||||
if len(internalReq.Messages) > 0 && internalReq.Messages[0].Role != "system" {
|
||||
systemPrompt := SystemPrompt(PromptParams{
|
||||
Date: time.Now(),
|
||||
Locale: "en-US",
|
||||
Language: "Same as user input",
|
||||
MaxContextLoadTime: 24 * 60, // 24 hours
|
||||
Platforms: []string{},
|
||||
CurrentPlatform: "api",
|
||||
})
|
||||
internalReq.Messages = append([]Message{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
}, internalReq.Messages...)
|
||||
payload := agentGatewayRequest{
|
||||
APIKey: provider.ApiKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
Model: chatModel.ModelID,
|
||||
ClientType: clientType,
|
||||
Locale: req.Locale,
|
||||
Language: req.Language,
|
||||
MaxSteps: req.MaxSteps,
|
||||
MaxContextLoadTime: normalizeMaxContextLoad(req.MaxContextLoadTime),
|
||||
Platforms: req.Platforms,
|
||||
CurrentPlatform: req.CurrentPlatform,
|
||||
Messages: messages,
|
||||
Query: req.Query,
|
||||
}
|
||||
|
||||
// Create provider instance
|
||||
providerInst, err := r.createProvider(provider)
|
||||
resp, err := r.postChat(ctx, payload)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
|
||||
// Execute chat
|
||||
result, err := providerInst.Chat(ctx, internalReq)
|
||||
if err != nil {
|
||||
if err := r.storeHistory(ctx, req.UserID, req.Query, resp.Messages); err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
|
||||
return ChatResponse{
|
||||
Message: result.Message,
|
||||
Model: result.Model,
|
||||
Provider: result.Provider,
|
||||
FinishReason: result.FinishReason,
|
||||
Usage: result.Usage,
|
||||
Messages: resp.Messages,
|
||||
Model: chatModel.ModelID,
|
||||
Provider: provider.ClientType,
|
||||
}, nil
|
||||
}

// StreamChat performs a streaming chat completion
func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) {
	chunkChan := make(chan StreamChunk)
	errChan := make(chan error, 1)

	go func() {
		defer close(chunkChan)
		defer close(errChan)

		if strings.TrimSpace(req.Query) == "" {
			errChan <- fmt.Errorf("query is required")
			return
		}
		if strings.TrimSpace(req.UserID) == "" {
			errChan <- fmt.Errorf("user id is required")
			return
		}

		// Select model and provider
		chatModel, provider, err := r.selectChatModel(ctx, req)
		if err != nil {
			errChan <- err
			return
		}
		clientType, err := normalizeClientType(provider.ClientType)
		if err != nil {
			errChan <- err
			return
		}

		// Load this user's recent history and append the incoming messages
		messages, err := r.loadHistoryMessages(ctx, req.UserID, req.MaxContextLoadTime)
		if err != nil {
			errChan <- err
			return
		}
		if len(req.Messages) > 0 {
			messages = append(messages, req.Messages...)
		}

		payload := agentGatewayRequest{
			APIKey:             provider.ApiKey,
			BaseURL:            provider.BaseUrl,
			Model:              chatModel.ModelID,
			ClientType:         clientType,
			Locale:             req.Locale,
			Language:           req.Language,
			MaxSteps:           req.MaxSteps,
			MaxContextLoadTime: normalizeMaxContextLoad(req.MaxContextLoadTime),
			Platforms:          req.Platforms,
			CurrentPlatform:    req.CurrentPlatform,
			Messages:           messages,
			Query:              req.Query,
		}

		// Execute streaming chat against the agent gateway
		if err := r.streamChat(ctx, payload, req.UserID, req.Query, chunkChan); err != nil {
			errChan <- err
			return
		}
	}()

	return chunkChan, errChan
}
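
How a caller would consume the two channels returned by `StreamChat`: the chunk channel closes when the stream ends, and `errChan` is closed without a send on success, so the final receive yields `nil` on the happy path. `resolver`, `ctx`, and `req` are assumed as in the sketch above:

```go
chunks, errs := resolver.StreamChat(ctx, req)
for chunk := range chunks {
	// Each chunk is the raw JSON payload of one SSE "data:" line.
	fmt.Printf("%s\n", chunk)
}
if err := <-errs; err != nil { // nil once errChan closes without an error
	log.Println("stream failed:", err)
}
```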

type agentGatewayRequest struct {
	APIKey             string           `json:"apiKey"`
	BaseURL            string           `json:"baseUrl"`
	Model              string           `json:"model"`
	ClientType         string           `json:"clientType"`
	Locale             string           `json:"locale,omitempty"`
	Language           string           `json:"language,omitempty"`
	MaxSteps           int              `json:"maxSteps,omitempty"`
	MaxContextLoadTime int              `json:"maxContextLoadTime"`
	Platforms          []string         `json:"platforms,omitempty"`
	CurrentPlatform    string           `json:"currentPlatform,omitempty"`
	Messages           []GatewayMessage `json:"messages"`
	Query              string           `json:"query"`
}

type agentGatewayResponse struct {
	Messages []GatewayMessage `json:"messages"`
}

func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest) (agentGatewayResponse, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return agentGatewayResponse{}, err
	}
	url := r.gatewayBaseURL + "/chat"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return agentGatewayResponse{}, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := r.httpClient.Do(req)
	if err != nil {
		return agentGatewayResponse{}, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		payload, _ := io.ReadAll(resp.Body)
		return agentGatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(payload)))
	}

	var parsed agentGatewayResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return agentGatewayResponse{}, err
	}
	return parsed, nil
}
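
For reference, the JSON body `postChat` sends to `POST {gateway}/chat`, as implied by the `agentGatewayRequest` struct tags; every value below is illustrative rather than taken from the commit:

```json
{
  "apiKey": "sk-...",
  "baseUrl": "https://api.openai.com/v1",
  "model": "gpt-4o-mini",
  "clientType": "openai",
  "locale": "en-US",
  "maxContextLoadTime": 1440,
  "messages": [
    { "role": "user", "content": "Hello" }
  ],
  "query": "Hello"
}
```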

func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, userID, query string, chunkChan chan<- StreamChunk) error {
	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	url := r.gatewayBaseURL + "/chat/stream"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")

	resp, err := r.streamingClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		payload, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(payload)))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "" || data == "[DONE]" {
			continue
		}
		// Forward every payload to the caller as-is
		chunkChan <- StreamChunk([]byte(data))

		// On the final "done" envelope, persist the full exchange as history
		var envelope struct {
			Type string          `json:"type"`
			Data json.RawMessage `json:"data"`
		}
		if err := json.Unmarshal([]byte(data), &envelope); err != nil {
			continue
		}
		if envelope.Type != "done" || len(envelope.Data) == 0 {
			continue
		}
		var parsed agentGatewayResponse
		if err := json.Unmarshal(envelope.Data, &parsed); err != nil {
			continue
		}
		if err := r.storeHistory(ctx, userID, query, parsed.Messages); err != nil {
			return err
		}
	}

	if err := scanner.Err(); err != nil {
		return err
	}
	return nil
}
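
A sketch of the SSE stream `streamChat` expects from `POST {gateway}/chat/stream`. Only the `data:` prefix, the `[DONE]` sentinel, and the `"done"` envelope are actually inspected above; the `chunk` event name and delta shape are assumptions for illustration:

```
data: {"type":"chunk","data":{"delta":"Hel"}}
data: {"type":"chunk","data":{"delta":"lo"}}
data: {"type":"done","data":{"messages":[{"role":"assistant","content":"Hello"}]}}
data: [DONE]
```

Every payload is forwarded to the caller verbatim; the `done` envelope additionally triggers `storeHistory`.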

func (r *Resolver) loadHistoryMessages(ctx context.Context, userID string, maxContextLoadTime int) ([]GatewayMessage, error) {
	if r.queries == nil {
		return nil, fmt.Errorf("history queries not configured")
	}
	pgUserID, err := parseUUID(userID)
	if err != nil {
		return nil, err
	}
	from := time.Now().UTC().Add(-time.Duration(normalizeMaxContextLoad(maxContextLoadTime)) * time.Minute)
	rows, err := r.queries.ListHistoryByUserSince(ctx, sqlc.ListHistoryByUserSinceParams{
		User: pgUserID,
		Timestamp: pgtype.Timestamptz{
			Time:  from,
			Valid: true,
		},
	})
	if err != nil {
		return nil, err
	}
	messages := make([]GatewayMessage, 0, len(rows))
	for _, row := range rows {
		if len(row.Messages) == 0 {
			continue
		}
		var batch []GatewayMessage
		if err := json.Unmarshal(row.Messages, &batch); err != nil {
			return nil, err
		}
		messages = append(messages, batch...)
	}
	return messages, nil
}

func (r *Resolver) storeHistory(ctx context.Context, userID, query string, responseMessages []GatewayMessage) error {
	if r.queries == nil {
		return fmt.Errorf("history queries not configured")
	}
	if strings.TrimSpace(userID) == "" {
		return fmt.Errorf("user id is required")
	}
	if strings.TrimSpace(query) == "" && len(responseMessages) == 0 {
		return nil
	}
	userMessage := GatewayMessage{
		"role":    "user",
		"content": query,
	}
	messages := append([]GatewayMessage{userMessage}, responseMessages...)
	payload, err := json.Marshal(messages)
	if err != nil {
		return err
	}
	pgUserID, err := parseUUID(userID)
	if err != nil {
		return err
	}
	_, err = r.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{
		Messages: payload,
		Timestamp: pgtype.Timestamptz{
			Time:  time.Now().UTC(),
			Valid: true,
		},
		User: pgUserID,
	})
	return err
}

// selectChatModel selects a chat model based on the request
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest) (models.GetResponse, sqlc.LlmProvider, error) {
	if r.modelsService == nil {
		return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
	}
	modelID := strings.TrimSpace(req.Model)
	providerFilter := strings.TrimSpace(req.Provider)

	if modelID != "" && providerFilter == "" {
		model, err := r.modelsService.GetByModelID(ctx, modelID)
		if err != nil {
			return models.GetResponse{}, sqlc.LlmProvider{}, err
		}
		if model.Type != models.ModelTypeChat {
			return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
		}
		provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
		if err != nil {
			return models.GetResponse{}, sqlc.LlmProvider{}, err
		}
		return model, provider, nil
	}

	// If no model specified, try to get default chat model
	if providerFilter == "" && modelID == "" {
		defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsChat)
		if err == nil {
			provider, err := models.FetchProviderByID(ctx, r.queries, defaultModel.LlmProviderID)
			if err != nil {
				return models.GetResponse{}, sqlc.LlmProvider{}, err
			}
@@ -183,11 +378,10 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st
		}
	}

	// List available models
	var candidates []models.GetResponse
	var err error
	if providerFilter != "" {
		candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter))
	} else {
		candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeChat)
	}
@@ -195,7 +389,6 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st
		return models.GetResponse{}, sqlc.LlmProvider{}, err
	}

	// Filter chat models
	filtered := make([]models.GetResponse, 0, len(candidates))
	for _, model := range candidates {
		if model.Type != models.ModelTypeChat {
@@ -204,92 +397,60 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st
		filtered = append(filtered, model)
	}
	if len(filtered) == 0 {
		return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available")
	}

	// If model specified, find it
	if modelID != "" {
		for _, model := range filtered {
			if model.ModelID == modelID {
				provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
				if err != nil {
					return models.GetResponse{}, sqlc.LlmProvider{}, err
				}
				return model, provider, nil
			}
		}
		return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found")
	}

	// Return first available model
	selected := filtered[0]
	provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID)
	if err != nil {
		return models.GetResponse{}, sqlc.LlmProvider{}, err
	}
	return selected, provider, nil
}

func normalizeMaxContextLoad(value int) int {
	if value <= 0 {
		return defaultMaxContextMinutes
	}
	return value
}
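
The context window is expressed in minutes and clamped by `normalizeMaxContextLoad`; `defaultMaxContextMinutes` is defined elsewhere in the package and not visible in this hunk. Expected behavior:

```go
normalizeMaxContextLoad(0)   // -> defaultMaxContextMinutes (fallback window)
normalizeMaxContextLoad(-5)  // -> defaultMaxContextMinutes (fallback window)
normalizeMaxContextLoad(120) // -> 120, i.e. load the last two hours of history
```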

func normalizeClientType(clientType string) (string, error) {
	switch strings.ToLower(strings.TrimSpace(clientType)) {
	case "openai":
		return "openai", nil
	case "openai-compat":
		return "openai", nil
	case "anthropic":
		return "anthropic", nil
	case "google":
		return "google", nil
	default:
		return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType)
	}
}

func parseUUID(id string) (pgtype.UUID, error) {
	parsed, err := uuid.Parse(id)
	if err != nil {
		return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err)
	}
	var pgID pgtype.UUID
	pgID.Valid = true
	copy(pgID.Bytes[:], parsed[:])
	return pgID, nil
}

+20
-53
@@ -1,65 +1,32 @@
package chat

import "encoding/json"

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type GatewayMessage map[string]interface{}

type ChatRequest struct {
	UserID             string           `json:"-"`
	Query              string           `json:"query"`
	Model              string           `json:"model,omitempty"`
	Provider           string           `json:"provider,omitempty"`
	MaxContextLoadTime int              `json:"max_context_load_time,omitempty"`
	Locale             string           `json:"locale,omitempty"`
	Language           string           `json:"language,omitempty"`
	MaxSteps           int              `json:"max_steps,omitempty"`
	Platforms          []string         `json:"platforms,omitempty"`
	CurrentPlatform    string           `json:"current_platform,omitempty"`
	Messages           []GatewayMessage `json:"messages,omitempty"`
}

type ChatResponse struct {
	Messages []GatewayMessage `json:"messages"`
	Model    string           `json:"model,omitempty"`
	Provider string           `json:"provider,omitempty"`
}

type StreamChunk = json.RawMessage

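Matching the `ChatRequest` tags above, a `POST /chat` body might look like this (values illustrative; `UserID` is never read from the body because of its `json:"-"` tag):

```json
{
  "query": "Summarize my notes from today",
  "model": "gpt-4o-mini",
  "provider": "openai",
  "max_context_load_time": 120,
  "current_platform": "api"
}
```
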
@@ -1,6 +1,7 @@
package config

import (
	"fmt"
	"os"

	"github.com/BurntSushi/toml"
@@ -31,6 +32,7 @@ type Config struct {
	MCP          MCPConfig          `toml:"mcp"`
	Postgres     PostgresConfig     `toml:"postgres"`
	Qdrant       QdrantConfig       `toml:"qdrant"`
	AgentGateway AgentGatewayConfig `toml:"agent_gateway"`
}

type ServerConfig struct {
@@ -70,6 +72,23 @@ type QdrantConfig struct {
	TimeoutSeconds int `toml:"timeout_seconds"`
}

type AgentGatewayConfig struct {
	Host string `toml:"host"`
	Port int    `toml:"port"`
}

func (c AgentGatewayConfig) BaseURL() string {
	host := c.Host
	if host == "" {
		host = "127.0.0.1"
	}
	port := c.Port
	if port == 0 {
		port = 8081
	}
	return "http://" + host + ":" + fmt.Sprint(port)
}

func Load(path string) (Config, error) {
	cfg := Config{
		Server: ServerConfig{
@@ -98,6 +117,10 @@ func Load(path string) (Config, error) {
			BaseURL:    DefaultQdrantURL,
			Collection: DefaultQdrantCollection,
		},
		AgentGateway: AgentGatewayConfig{
			Host: "127.0.0.1",
			Port: 8081,
		},
	}

	if path == "" {

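The new `[agent_gateway]` section of `config.toml` follows directly from the struct tags above; this snippet simply restates the built-in defaults:

```toml
[agent_gateway]
host = "127.0.0.1"
port = 8081
```
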
@@ -0,0 +1,73 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
//   sqlc v1.30.0
// source: history.sql

package sqlc

import (
	"context"

	"github.com/jackc/pgx/v5/pgtype"
)

const createHistory = `-- name: CreateHistory :one
INSERT INTO history (messages, timestamp, "user")
VALUES ($1, $2, $3)
RETURNING id, messages, timestamp, "user"
`

type CreateHistoryParams struct {
	Messages  []byte             `json:"messages"`
	Timestamp pgtype.Timestamptz `json:"timestamp"`
	User      pgtype.UUID        `json:"user"`
}

func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (History, error) {
	row := q.db.QueryRow(ctx, createHistory, arg.Messages, arg.Timestamp, arg.User)
	var i History
	err := row.Scan(
		&i.ID,
		&i.Messages,
		&i.Timestamp,
		&i.User,
	)
	return i, err
}

const listHistoryByUserSince = `-- name: ListHistoryByUserSince :many
SELECT id, messages, timestamp, "user"
FROM history
WHERE "user" = $1 AND timestamp >= $2
ORDER BY timestamp ASC
`

type ListHistoryByUserSinceParams struct {
	User      pgtype.UUID        `json:"user"`
	Timestamp pgtype.Timestamptz `json:"timestamp"`
}

func (q *Queries) ListHistoryByUserSince(ctx context.Context, arg ListHistoryByUserSinceParams) ([]History, error) {
	rows, err := q.db.Query(ctx, listHistoryByUserSince, arg.User, arg.Timestamp)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []History
	for rows.Next() {
		var i History
		if err := rows.Scan(
			&i.ID,
			&i.Messages,
			&i.Timestamp,
			&i.User,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

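Since the generated file embeds its source queries verbatim, the underlying `history.sql` must read, up to whitespace:

```sql
-- name: CreateHistory :one
INSERT INTO history (messages, timestamp, "user")
VALUES ($1, $2, $3)
RETURNING id, messages, timestamp, "user";

-- name: ListHistoryByUserSince :many
SELECT id, messages, timestamp, "user"
FROM history
WHERE "user" = $1 AND timestamp >= $2
ORDER BY timestamp ASC;
```
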
@@ -33,6 +33,13 @@ type ContainerVersion struct {
	CreatedAt pgtype.Timestamptz `json:"created_at"`
}

type History struct {
	ID        pgtype.UUID        `json:"id"`
	Messages  []byte             `json:"messages"`
	Timestamp pgtype.Timestamptz `json:"timestamp"`
	User      pgtype.UUID        `json:"user"`
}

type LifecycleEvent struct {
	ID          string `json:"id"`
	ContainerID string `json:"container_id"`

@@ -7,7 +7,10 @@ import (
	"net/http"

	"github.com/labstack/echo/v4"

	"github.com/memohai/memoh/internal/auth"
	"github.com/memohai/memoh/internal/chat"
	"github.com/memohai/memoh/internal/identity"
)

type ChatHandler struct {
@@ -36,14 +39,20 @@ func (h *ChatHandler) Register(e *echo.Echo) {
// @Failure 500 {object} ErrorResponse
// @Router /chat [post]
func (h *ChatHandler) Chat(c echo.Context) error {
	userID, err := h.requireUserID(c)
	if err != nil {
		return err
	}

	var req chat.ChatRequest
	if err := c.Bind(&req); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, err.Error())
	}

	if req.Query == "" {
		return echo.NewHTTPError(http.StatusBadRequest, "query is required")
	}
	req.UserID = userID

	resp, err := h.resolver.Chat(c.Request().Context(), req)
	if err != nil {
@@ -65,14 +74,20 @@ func (h *ChatHandler) Chat(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /chat/stream [post]
func (h *ChatHandler) StreamChat(c echo.Context) error {
	userID, err := h.requireUserID(c)
	if err != nil {
		return err
	}

	var req chat.ChatRequest
	if err := c.Bind(&req); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, err.Error())
	}

	if req.Query == "" {
		return echo.NewHTTPError(http.StatusBadRequest, "query is required")
	}
	req.UserID = userID

	// Set headers for SSE
	c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
@@ -127,3 +142,14 @@ func (h *ChatHandler) StreamChat(c echo.Context) error {
		}
	}
}

func (h *ChatHandler) requireUserID(c echo.Context) (string, error) {
	userID, err := auth.UserIDFromContext(c)
	if err != nil {
		return "", err
	}
	if err := identity.ValidateUserID(userID); err != nil {
		return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
	}
	return userID, nil
}

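With these handler changes, a chat call needs only a `query` plus the JWT the middleware already requires; an illustrative request:

```http
POST /chat
Content-Type: application/json
Authorization: Bearer <token>

{
  "query": "Hello"
}
```
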
@@ -1,129 +0,0 @@
package memory

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/memohai/memoh/internal/chat"
)

// ProviderLLMClient uses chat.Provider to make LLM calls for memory operations
type ProviderLLMClient struct {
	provider chat.Provider
	model    string
}

// NewProviderLLMClient creates a new LLM client that uses chat.Provider
func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient {
	if model == "" {
		model = "gpt-4.1-nano-2025-04-14"
	}
	return &ProviderLLMClient{
		provider: provider,
		model:    model,
	}
}

// Extract extracts facts from messages using the provider
func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
	if len(req.Messages) == 0 {
		return ExtractResponse{}, fmt.Errorf("messages is required")
	}

	parsedMessages := parseMessages(formatMessages(req.Messages))
	systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages)

	// Call provider with JSON mode
	temp := float32(0)
	result, err := c.provider.Chat(ctx, chat.Request{
		Model:       c.model,
		Temperature: &temp,
		ResponseFormat: &chat.ResponseFormat{
			Type: "json_object",
		},
		Messages: []chat.Message{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: userPrompt},
		},
	})
	if err != nil {
		return ExtractResponse{}, err
	}

	content := result.Message.Content
	var parsed ExtractResponse
	if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
		return ExtractResponse{}, err
	}
	return parsed, nil
}

// Decide decides what actions to take based on facts and existing memories
func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) {
	if len(req.Facts) == 0 {
		return DecideResponse{}, fmt.Errorf("facts is required")
	}

	retrieved := make([]map[string]string, 0, len(req.Candidates))
	for _, candidate := range req.Candidates {
		retrieved = append(retrieved, map[string]string{
			"id":   candidate.ID,
			"text": candidate.Memory,
		})
	}

	prompt := getUpdateMemoryMessages(retrieved, req.Facts)

	// Call provider with JSON mode
	temp := float32(0)
	result, err := c.provider.Chat(ctx, chat.Request{
		Model:       c.model,
		Temperature: &temp,
		ResponseFormat: &chat.ResponseFormat{
			Type: "json_object",
		},
		Messages: []chat.Message{
			{Role: "user", Content: prompt},
		},
	})
	if err != nil {
		return DecideResponse{}, err
	}

	content := result.Message.Content
	var raw map[string]interface{}
	if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &raw); err != nil {
		return DecideResponse{}, err
	}

	memoryItems := normalizeMemoryItems(raw["memory"])
	actions := make([]DecisionAction, 0, len(memoryItems))
	for _, item := range memoryItems {
		event := strings.ToUpper(asString(item["event"]))
		if event == "" {
			event = "ADD"
		}
		if event == "NONE" {
			continue
		}

		text := asString(item["text"])
		if text == "" {
			text = asString(item["fact"])
		}
		if strings.TrimSpace(text) == "" {
			continue
		}

		actions = append(actions, DecisionAction{
			Event:     event,
			ID:        normalizeID(item["id"]),
			Text:      text,
			OldMemory: asString(item["old_memory"]),
		})
	}
	return DecideResponse{Actions: actions}, nil
}

@@ -15,7 +15,7 @@ type Server struct {
	addr string
}

func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, containerdHandler *handlers.ContainerdHandler) *Server {
	if addr == "" {
		addr = ":8080"
	}
@@ -50,12 +50,12 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
	if embeddingsHandler != nil {
		embeddingsHandler.Register(e)
	}
	if chatHandler != nil {
		chatHandler.Register(e)
	}
	if swaggerHandler != nil {
		swaggerHandler.Register(e)
	}
	if providersHandler != nil {
		providersHandler.Register(e)
	}