From 39215309daaedcbadc62b66d5a6b6dfa0cd9c763 Mon Sep 17 00:00:00 2001
From: Acbox
Date: Wed, 28 Jan 2026 15:15:57 +0800
Subject: [PATCH] feat: chat api

---
 .gitignore                             |   4 +-
 PROVIDERS_MODULE.md                    | 477 -----------------------
 REFACTORING_SUMMARY.md                 | 273 -------------
 cmd/agent/main.go                      |  59 ++-
 config.toml.example                    |   7 +-
 db/migrations/0001_init.up.sql         |  10 +
 db/queries/history.sql                 |  11 +
 docs/docs.go                           |  97 ++---
 docs/swagger.json                      |  97 ++---
 docs/swagger.yaml                      |  71 ++--
 internal/chat/ARCHITECTURE.md          | 213 -----------
 internal/chat/anthropic.go             |  43 ---
 internal/chat/chat.go                  |   3 +-
 internal/chat/factory.go               |  39 --
 internal/chat/google.go                |  43 ---
 internal/chat/ollama.go                |  48 ---
 internal/chat/openai.go                |  54 ---
 internal/chat/prompts.go               | 141 -------
 internal/chat/resolver.go              | 509 ++++++++++++++++---------
 internal/chat/types.go                 |  73 +---
 internal/config/config.go              |  23 ++
 internal/db/sqlc/history.sql.go        |  73 ++++
 internal/db/sqlc/models.go             |   7 +
 internal/handlers/chat.go              |  34 +-
 internal/memory/llm_provider_client.go | 129 -------
 internal/server/server.go              |   8 +-
 26 files changed, 673 insertions(+), 1873 deletions(-)
 delete mode 100644 PROVIDERS_MODULE.md
 delete mode 100644 REFACTORING_SUMMARY.md
 create mode 100644 db/queries/history.sql
 delete mode 100644 internal/chat/ARCHITECTURE.md
 delete mode 100644 internal/chat/anthropic.go
 delete mode 100644 internal/chat/factory.go
 delete mode 100644 internal/chat/google.go
 delete mode 100644 internal/chat/ollama.go
 delete mode 100644 internal/chat/openai.go
 delete mode 100644 internal/chat/prompts.go
 create mode 100644 internal/db/sqlc/history.sql.go
 delete mode 100644 internal/memory/llm_provider_client.go

diff --git a/.gitignore b/.gitignore
index ac420d5a..baa7a08b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -95,4 +95,6 @@ docs/docs/.vitepress/cache
 dump.rdb
 memory.db
-config.toml
\ No newline at end of file
+config.toml
+
+.workdocs/
\ No newline at end of file
diff --git a/PROVIDERS_MODULE.md b/PROVIDERS_MODULE.md
deleted file mode 100644
index 84434e3f..00000000
--- a/PROVIDERS_MODULE.md
+++ /dev/null
@@ -1,477 +0,0 @@
-# Providers Module Documentation
-
-## Overview
-
-This document describes the standalone Providers module, which manages LLM provider configurations.
-
-## Architecture Design
-
-### Module Structure
-
-```
-internal/
-├── providers/          # standalone provider module
-│   ├── types.go        # type definitions
-│   └── service.go      # business logic
-└── handlers/
-    └── providers.go    # API handlers
-```
-
-### Layered Design
-
-```
-┌─────────────────────┐
-│     API Layer       │
-│    (handlers)       │
-└──────────┬──────────┘
-           │
-┌──────────▼──────────┐
-│   Service Layer     │
-│   (providers pkg)   │
-└──────────┬──────────┘
-           │
-┌──────────▼──────────┐
-│    Data Layer       │
-│   (sqlc queries)    │
-└─────────────────────┘
-```
-
-## API Endpoints
-
-### Providers API (`/providers`)
-
-All endpoints require JWT authentication.
-
-#### 1. Create a Provider
-
-```http
-POST /providers
-Content-Type: application/json
-Authorization: Bearer
-
-{
-  "name": "OpenAI Official",
-  "client_type": "openai",
-  "base_url": "https://api.openai.com/v1",
-  "api_key": "sk-...",
-  "metadata": {
-    "description": "Official OpenAI API"
-  }
-}
-```
-
-**Response** (201 Created):
-```json
-{
-  "id": "550e8400-e29b-41d4-a716-446655440000",
-  "name": "OpenAI Official",
-  "client_type": "openai",
-  "base_url": "https://api.openai.com/v1",
-  "api_key": "sk-12345***",
-  "metadata": {
-    "description": "Official OpenAI API"
-  },
-  "created_at": "2024-01-01T00:00:00Z",
-  "updated_at": "2024-01-01T00:00:00Z"
-}
-```
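-
-For illustration, the same create call from Go might look like the sketch below. The base URL and bearer token are placeholders, and the `createRequest` mirror is local to the example rather than imported from the module:
-
-```go
-package main
-
-import (
-	"bytes"
-	"encoding/json"
-	"fmt"
-	"net/http"
-)
-
-// createRequest mirrors the request body documented above.
-type createRequest struct {
-	Name       string                 `json:"name"`
-	ClientType string                 `json:"client_type"`
-	BaseURL    string                 `json:"base_url"`
-	APIKey     string                 `json:"api_key,omitempty"`
-	Metadata   map[string]interface{} `json:"metadata,omitempty"`
-}
-
-func main() {
-	body, _ := json.Marshal(createRequest{
-		Name:       "OpenAI Official",
-		ClientType: "openai",
-		BaseURL:    "https://api.openai.com/v1",
-		APIKey:     "sk-...",
-	})
-	req, _ := http.NewRequest(http.MethodPost, "http://localhost:8080/providers", bytes.NewReader(body))
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Authorization", "Bearer <jwt>") // placeholder token
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		panic(err)
-	}
-	defer resp.Body.Close()
-	fmt.Println(resp.Status) // expect 201 Created
-}
-```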
-#### 2. List All Providers
-
-```http
-GET /providers
-Authorization: Bearer
-```
-
-**Optional query parameters**:
-- `client_type` - filter by client type (openai, anthropic, google, ollama)
-
-**Response** (200 OK):
-```json
-[
-  {
-    "id": "550e8400-e29b-41d4-a716-446655440000",
-    "name": "OpenAI Official",
-    "client_type": "openai",
-    "base_url": "https://api.openai.com/v1",
-    "api_key": "sk-12345***",
-    "created_at": "2024-01-01T00:00:00Z",
-    "updated_at": "2024-01-01T00:00:00Z"
-  }
-]
-```
-
-#### 3. Get a Single Provider
-
-```http
-GET /providers/{id}
-Authorization: Bearer
-```
-
-Or fetch it by name:
-
-```http
-GET /providers/name/{name}
-Authorization: Bearer
-```
-
-**Response** (200 OK): same as the create response
-
-#### 4. Update a Provider
-
-```http
-PUT /providers/{id}
-Content-Type: application/json
-Authorization: Bearer
-
-{
-  "name": "OpenAI Updated",
-  "api_key": "sk-newkey..."
-}
-```
-
-**Note**: all fields are optional; only the provided fields are updated.
-
-**Response** (200 OK): returns the updated provider
-
-#### 5. Delete a Provider
-
-```http
-DELETE /providers/{id}
-Authorization: Bearer
-```
-
-**Response** (204 No Content)
-
-#### 6. Count Providers
-
-```http
-GET /providers/count
-Authorization: Bearer
-```
-
-**Optional query parameters**:
-- `client_type` - filter by client type
-
-**Response** (200 OK):
-```json
-{
-  "count": 5
-}
-```
-
-## Supported Client Types
-
-| Client Type | Description | Requires API Key |
-|------------|------|-------------|
-| `openai` | Official OpenAI API | ✅ |
-| `openai-compat` | OpenAI-compatible APIs | ✅ |
-| `anthropic` | Anthropic Claude API | ✅ |
-| `google` | Google Gemini API | ✅ |
-| `ollama` | Local Ollama | ❌ |
-
-## Data Model
-
-### Provider Structures
-
-```go
-type CreateRequest struct {
-    Name       string                 `json:"name"`        // required
-    ClientType ClientType             `json:"client_type"` // required
-    BaseURL    string                 `json:"base_url"`    // required
-    APIKey     string                 `json:"api_key"`     // optional
-    Metadata   map[string]interface{} `json:"metadata"`    // optional
-}
-
-type GetResponse struct {
-    ID         string                 `json:"id"`
-    Name       string                 `json:"name"`
-    ClientType string                 `json:"client_type"`
-    BaseURL    string                 `json:"base_url"`
-    APIKey     string                 `json:"api_key"` // masked
-    Metadata   map[string]interface{} `json:"metadata"`
-    CreatedAt  time.Time              `json:"created_at"`
-    UpdatedAt  time.Time              `json:"updated_at"`
-}
-```
-
-## Security Features
-
-### 1. API Key Masking
-
-API keys are automatically masked in responses:
-- only the first 8 characters are shown
-- the rest is replaced with `*`
-- for example: `sk-12345678***`
-
-### 2. Authentication
-
-All API endpoints require JWT authentication:
-```http
-Authorization: Bearer
-```
-
-### 3. Input Validation
-
-- UUID formats are validated automatically
-- client_type is checked against the supported list
-- Required fields are validated
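-
-As a sketch of the masking rule above, assuming a fixed cutoff of 8 visible characters (the helper name is illustrative, not the module's actual function):
-
-```go
-package main
-
-import (
-	"fmt"
-	"strings"
-)
-
-// maskAPIKey keeps the first 8 characters of a key and masks the rest.
-func maskAPIKey(key string) string {
-	const visible = 8
-	if len(key) <= visible {
-		return key
-	}
-	return key[:visible] + strings.Repeat("*", len(key)-visible)
-}
-
-func main() {
-	fmt.Println(maskAPIKey("sk-12345678abcdef")) // sk-12345*********
-}
-```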
-
-## Usage Examples
-
-### 1. Configure an OpenAI Provider
-
-```bash
-curl -X POST http://localhost:8080/providers \
-  -H "Authorization: Bearer $TOKEN" \
-  -H "Content-Type: application/json" \
-  -d '{
-    "name": "OpenAI GPT-4",
-    "client_type": "openai",
-    "base_url": "https://api.openai.com/v1",
-    "api_key": "sk-..."
-  }'
-```
-
-### 2. Configure a Custom OpenAI-Compatible Service
-
-```bash
-curl -X POST http://localhost:8080/providers \
-  -H "Authorization: Bearer $TOKEN" \
-  -H "Content-Type: application/json" \
-  -d '{
-    "name": "Azure OpenAI",
-    "client_type": "openai-compat",
-    "base_url": "https://your-resource.openai.azure.com/v1",
-    "api_key": "your-azure-key",
-    "metadata": {
-      "deployment": "gpt-4",
-      "region": "eastus"
-    }
-  }'
-```
-
-### 3. Configure a Local Ollama
-
-```bash
-curl -X POST http://localhost:8080/providers \
-  -H "Authorization: Bearer $TOKEN" \
-  -H "Content-Type: application/json" \
-  -d '{
-    "name": "Local Ollama",
-    "client_type": "ollama",
-    "base_url": "http://localhost:11434"
-  }'
-```
-
-### 4. List All OpenAI Providers
-
-```bash
-curl http://localhost:8080/providers?client_type=openai \
-  -H "Authorization: Bearer $TOKEN"
-```
-
-### 5. Update a Provider's API Key
-
-```bash
-curl -X PUT http://localhost:8080/providers/{id} \
-  -H "Authorization: Bearer $TOKEN" \
-  -H "Content-Type: application/json" \
-  -d '{
-    "api_key": "sk-new-key..."
-  }'
-```
-
-## Relationship with Models
-
-Providers and models have a one-to-many relationship:
-
-```
-┌─────────────┐
-│  Provider   │
-│  (OpenAI)   │
-└──────┬──────┘
-       │
-       ├─── Model (gpt-4)
-       ├─── Model (gpt-3.5-turbo)
-       └─── Model (text-embedding-ada-002)
-```
-
-### Referencing a Provider when Creating a Model
-
-```bash
-# 1. Create the provider
-PROVIDER_ID=$(curl -X POST http://localhost:8080/providers \
-  -H "Authorization: Bearer $TOKEN" \
-  -d '{"name":"OpenAI","client_type":"openai",...}' \
-  | jq -r '.id')
-
-# 2. Create the model and reference the provider
-curl -X POST http://localhost:8080/models \
-  -H "Authorization: Bearer $TOKEN" \
-  -d '{
-    "model_id": "gpt-4",
-    "name": "GPT-4",
-    "llm_provider_id": "'$PROVIDER_ID'",
-    "type": "chat"
-  }'
-```
-
-## Code Integration
-
-### Using the Provider Service in Code
-
-```go
-import "github.com/memohai/memoh/internal/providers"
-
-// create the service
-providersService := providers.NewService(queries)
-
-// create a provider
-provider, err := providersService.Create(ctx, providers.CreateRequest{
-    Name:       "OpenAI",
-    ClientType: providers.ClientTypeOpenAI,
-    BaseURL:    "https://api.openai.com/v1",
-    APIKey:     "sk-...",
-})
-
-// list all providers
-allProviders, err := providersService.List(ctx)
-
-// filter by type
-openaiProviders, err := providersService.ListByClientType(ctx, providers.ClientTypeOpenAI)
-
-// get a single provider
-provider, err := providersService.Get(ctx, "provider-uuid")
-
-// update a provider
-updated, err := providersService.Update(ctx, "provider-uuid", providers.UpdateRequest{
-    APIKey: stringPtr("new-key"),
-})
-
-// delete a provider
-err := providersService.Delete(ctx, "provider-uuid")
-```
-
-## Error Handling
-
-### Common Errors
-
-| Status Code | Error | Cause |
-|-------|------|------|
-| 400 | Bad Request | Missing required fields or malformed input |
-| 404 | Not Found | The provider ID does not exist |
-| 409 | Conflict | The provider name already exists |
-| 500 | Internal Server Error | Server-side error |
-
-### Error Response Format
-
-```json
-{
-  "message": "invalid UUID: invalid UUID length: 5"
-}
-```
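-
-A client can decode this shape with a single-field struct; the struct name below is illustrative:
-
-```go
-package main
-
-import (
-	"encoding/json"
-	"fmt"
-)
-
-// apiError mirrors the error response shape shown above.
-type apiError struct {
-	Message string `json:"message"`
-}
-
-func main() {
-	raw := []byte(`{"message": "invalid UUID: invalid UUID length: 5"}`)
-	var e apiError
-	if err := json.Unmarshal(raw, &e); err != nil {
-		panic(err)
-	}
-	fmt.Println(e.Message)
-}
-```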
-
-## Database Schema
-
-Providers are stored in the `llm_providers` table:
-
-```sql
-CREATE TABLE llm_providers (
-    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
-    name TEXT NOT NULL,
-    client_type TEXT NOT NULL,
-    base_url TEXT NOT NULL,
-    api_key TEXT NOT NULL,
-    metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
-    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
-    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
-    CONSTRAINT llm_providers_name_unique UNIQUE (name),
-    CONSTRAINT llm_providers_client_type_check
-        CHECK (client_type IN ('openai', 'openai-compat', 'anthropic', 'google', 'ollama'))
-);
-```
-
-## Best Practices
-
-### 1. Naming Conventions
-
-- Use descriptive names: `OpenAI GPT-4`, `Anthropic Claude 3`
-- Include environment information: `OpenAI Production`, `Ollama Local`
-- Avoid special characters and spaces
-
-### 2. API Key Management
-
-- Rotate API keys regularly
-- Store sensitive values in environment variables
-- Never write complete API keys to logs
-
-### 3. Using Metadata
-
-Use the metadata field to store extra information:
-
-```json
-{
-  "metadata": {
-    "environment": "production",
-    "rate_limit": 10000,
-    "contact": "admin@example.com",
-    "notes": "Primary provider for production"
-  }
-}
-```
-
-### 4. Provider Testing
-
-After creating a provider, it is recommended to test the connection:
-
-```go
-// a test endpoint that could be implemented in the future
-POST /providers/test
-{
-  "client_type": "openai",
-  "base_url": "https://api.openai.com/v1",
-  "api_key": "sk-...",
-  "model": "gpt-3.5-turbo"
-}
-```
-
-## Migration Guide
-
-### From Config File to Database
-
-**Old approach** (config.toml):
-```toml
-[memory]
-base_url = "https://api.openai.com/v1"
-api_key = "sk-..."
-model = "gpt-4"
-```
-
-**New approach** (API):
-```bash
-# 1. Create the provider
-curl -X POST http://localhost:8080/providers \
-  -d '{"name":"OpenAI","client_type":"openai","base_url":"https://api.openai.com/v1","api_key":"sk-..."}'
-
-# 2. Create the model
-curl -X POST http://localhost:8080/models \
-  -d '{"model_id":"gpt-4","llm_provider_id":"","type":"chat","enable_as":"memory"}'
-```
-
-## Related Documentation
-
-- [Provider Refactoring Summary](./REFACTORING_SUMMARY.md)
-- [Agent Migration Guide](./AGENT_MIGRATION.md)
-- [Chat Architecture](./internal/chat/ARCHITECTURE.md)
-- [Models API Documentation](#) (to be written)
-
-## TODO
-
-- [ ] Implement a provider connection-test endpoint
-- [ ] Add provider usage statistics
-- [ ] Implement provider health checks
-- [ ] Add provider cost tracking
-- [ ] Support provider load balancing
-- [ ] Implement provider failover
-
diff --git a/REFACTORING_SUMMARY.md b/REFACTORING_SUMMARY.md
deleted file mode 100644
index fa0f2bce..00000000
--- a/REFACTORING_SUMMARY.md
+++ /dev/null
@@ -1,273 +0,0 @@
-# Provider Refactoring Summary
-
-## Refactoring Goals
-
-Unify the memory service and the chat service on the Provider interface in `internal/chat`, enabling code reuse and a consistent architecture.
-
-## Main Changes
-
-### 1. Extended the Chat Provider Interface
-
-**File**: `internal/chat/types.go`
-
-- Extended the `Request` struct with:
-  - `Temperature *float32` - temperature parameter
-  - `ResponseFormat *ResponseFormat` - response format (supports JSON mode)
-  - `MaxTokens *int` - maximum token count
-
-- Added the `ResponseFormat` struct:
-  ```go
-  type ResponseFormat struct {
-      Type string // "json_object" or "text"
-  }
-  ```
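-
-For illustration, a request built with the new fields might look like the sketch below; the model name and values are examples only:
-
-```go
-package main
-
-import (
-	"fmt"
-
-	"github.com/memohai/memoh/internal/chat"
-)
-
-func main() {
-	temperature := float32(0) // deterministic output, useful for memory operations
-	maxTokens := 512
-	req := chat.Request{
-		Model:          "gpt-4",
-		Temperature:    &temperature,
-		MaxTokens:      &maxTokens,
-		ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
-		Messages: []chat.Message{
-			{Role: "user", Content: "Extract the facts from this text as JSON."},
-		},
-	}
-	fmt.Printf("%+v\n", req)
-}
-```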
-
-### 2. Created the Memory LLM Interface
-
-**File**: `internal/memory/types.go`
-
-- Defined the `LLM` interface:
-  ```go
-  type LLM interface {
-      Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
-      Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
-  }
-  ```
-
-- This interface has two implementations:
-  - `LLMClient` (the old implementation, kept for backward compatibility)
-  - `ProviderLLMClient` (the new implementation, built on chat.Provider)
-
-### 3. Created the Provider-Based LLM Client
-
-**File**: `internal/memory/llm_provider_client.go` (new file)
-
-- Implements the `ProviderLLMClient` struct
-- Uses `chat.Provider` to execute LLM calls
-- Supports JSON-mode output to guarantee structured responses
-- Reuses the helper functions in the `internal/memory` package
-
-Key code:
-```go
-type ProviderLLMClient struct {
-    provider chat.Provider
-    model    string
-}
-
-func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient
-func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
-func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
-```
-
-### 4. Updated the Memory Service
-
-**File**: `internal/memory/service.go`
-
-- Changed `llm *LLMClient` to `llm LLM`
-- The service now accepts any type that implements the `LLM` interface
-- All existing behavior is preserved
-
-### 5. Updated the Main Program
-
-**File**: `cmd/agent/main.go`
-
-#### Added functions:
-
-1. **selectMemoryModel** - selects the model used for memory operations
-   - Priority: memory model → chat model → any chat-typed model
-
-2. **fetchProviderByID** - fetches a provider configuration by ID
-
-3. **createChatProvider** - creates a provider instance from the configuration
-   - Supports OpenAI, Anthropic, Google, and Ollama
-
-#### Updated initialization flow:
-
-```go
-// 1. Initialize the chat resolver (used by both chat and memory)
-chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
-
-// 2. Try to create a provider-based client for memory
-memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries, &cfg)
-if err != nil {
-    // fall back to the old LLMClient
-    llmClient = memory.NewLLMClient(cfg.Memory.BaseURL, cfg.Memory.APIKey, ...)
-} else {
-    // use the new provider-based client
-    provider, _ := createChatProvider(memoryProvider, 30*time.Second)
-    llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)
-}
-
-// 3. Create the memory service
-memoryService = memory.NewService(llmClient, embedder, store, resolver, ...)
-```
-
-## Architectural Benefits
-
-### 1. Unified Provider Management
-- The chat and memory services share the same provider implementations
-- Less code duplication
-- Unified configuration and management
-
-### 2. Flexible Model Selection
-- Different features can be configured with different models
-- The `enable_as` field specifies what a model is used for
-- Automatic fallback behavior
-
-### 3. Backward Compatibility
-- The old `LLMClient` implementation is retained
-- If no model is configured in the database, the system automatically falls back to the config file
-- A smooth migration path
-
-### 4. Type Safety
-- Uses Go interfaces instead of runtime type checks
-- Compile-time type checking
-- Better IDE support
-
-### 5. Easy to Extend
-- Adding a new provider only requires implementing the `Provider` interface
-- Adding a new LLM client only requires implementing the `LLM` interface
-- Modular design
-
-## Configuration
-
-### Database Configuration
-
-Configure a model for memory operations:
-
-```sql
--- Option 1: use a dedicated memory model
-UPDATE models SET enable_as = 'memory'
-WHERE model_id = 'gpt-4-turbo-preview';
-
--- Option 2: use a chat model (when no dedicated memory model exists)
-UPDATE models SET enable_as = 'chat'
-WHERE model_id = 'gpt-4';
-```
-
-### Environment Variables (Fallback Configuration)
-
-If no model is configured in the database, the system uses these settings:
-
-```toml
-[memory]
-base_url = "https://api.openai.com/v1"
-api_key = "sk-..."
-model = "gpt-4.1-nano"
-timeout_seconds = 10
-```
-
-## Testing Suggestions
-
-### 1. Memory Operation Tests
-```bash
-# Test Extract (fact extraction)
-curl -X POST http://localhost:8080/api/memory/add \
-  -H "Authorization: Bearer $TOKEN" \
-  -d '{
-    "messages": [
-      {"role": "user", "content": "My name is Alice and I like pizza"}
-    ],
-    "user_id": "user-123"
-  }'
-
-# Test Search (memory search)
-curl -X POST http://localhost:8080/api/memory/search \
-  -H "Authorization: Bearer $TOKEN" \
-  -d '{
-    "query": "What food do I like?",
-    "user_id": "user-123"
-  }'
-```
-
-### 2. Chat Operation Tests
-```bash
-# Test a plain chat request
-curl -X POST http://localhost:8080/api/chat \
-  -H "Authorization: Bearer $TOKEN" \
-  -d '{
-    "messages": [
-      {"role": "user", "content": "Hello, how are you?"}
-    ]
-  }'
-```
-
-### 3. Verify the Logs
-On startup you should see:
-```
-Using memory model: gpt-4-turbo-preview (provider: openai)
-```
-
-or (when falling back):
-```
-WARNING: No memory model configured, using fallback LLMClient: ...
-```
-
-## Follow-Up Work
-
-### Short Term (Required)
-- [ ] Implement the concrete logic of each provider (most currently return "not yet implemented")
-- [ ] Add streaming response support
-- [ ] Improve error handling
-
-### Mid Term (Recommended)
-- [ ] Add unit tests for the providers
-- [ ] Add memory integration tests
-- [ ] Implement a provider connection pool
-- [ ] Add a request retry mechanism
-
-### Long Term (Optimization)
-- [ ] Provider performance monitoring
-- [ ] Automatic model selection and load balancing
-- [ ] Model response caching
-- [ ] Support more providers (e.g. Cohere, HuggingFace)
-
-## Migration Checklist
-
-- [x] Extend the Request struct to support JSON mode
-- [x] Create the LLM interface
-- [x] Implement ProviderLLMClient
-- [x] Update the memory service to use the interface
-- [x] Update the main program initialization flow
-- [x] Add model selection logic
-- [x] Add provider creation logic
-- [x] Keep backward compatibility
-- [x] Add architecture documentation
-- [ ] Add unit tests
-- [ ] Add integration tests
-- [ ] Update deployment documentation
-
-## File List
-
-### New Files
-- `internal/memory/llm_provider_client.go` - provider-based LLM client
-- `internal/chat/ARCHITECTURE.md` - architecture documentation
-- `REFACTORING_SUMMARY.md` - this file
-
-### Modified Files
-- `internal/chat/types.go` - extended the Request struct
-- `internal/memory/types.go` - added the LLM interface
-- `internal/memory/service.go` - uses the LLM interface
-- `cmd/agent/main.go` - updated the initialization flow
-
-### Retained Files (Backward Compatibility)
-- `internal/memory/llm_client.go` - the old HTTP client implementation
-
-## Notes
-
-1. **JSON mode compatibility**: not every model supports JSON mode; provider implementations need to handle this
-2. **Error handling**: error handling is currently minimal; production deployments need more detailed error information
-3. **Timeouts**: different operations may need different timeouts; consider making them configurable
-4. **Concurrency safety**: provider instances should be safe for concurrent use
-5. **Resource cleanup**: make sure provider resources (such as HTTP connections) are released correctly
-
-## Troubleshooting
-
-If you run into problems, check:
-1. Whether a chat model is configured in the database
-2. Whether the provider configuration (API key, base URL) is correct
-3. The error messages in the logs
-4. Whether the chat resolver was initialized correctly
-
-For the detailed architecture, see `internal/chat/ARCHITECTURE.md`.
-
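
The `cmd/agent/main.go` diff below replaces the eager provider construction described above with a lazily resolved memory client. As a standalone sketch of that pattern, with the interfaces simplified and the names invented for the example rather than taken from the repository:

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// LLM is a simplified stand-in for the memory.LLM interface.
type LLM interface {
	Extract(ctx context.Context, input string) (string, error)
}

// lazyLLM defers resolving a concrete client until the first call,
// so the process can start even when no model is configured yet.
type lazyLLM struct {
	resolve func(ctx context.Context) (LLM, error)
}

func (l *lazyLLM) Extract(ctx context.Context, input string) (string, error) {
	client, err := l.resolve(ctx)
	if err != nil {
		return "", fmt.Errorf("resolve llm: %w", err)
	}
	return client.Extract(ctx, input)
}

func main() {
	var llm LLM = &lazyLLM{
		resolve: func(ctx context.Context) (LLM, error) {
			return nil, errors.New("no chat model configured yet")
		},
	}
	if _, err := llm.Extract(context.Background(), "hi"); err != nil {
		fmt.Println(err) // resolve llm: no chat model configured yet
	}
}
```
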
diff --git a/cmd/agent/main.go b/cmd/agent/main.go
index 55b4d9f4..ffed7ccc 100644
--- a/cmd/agent/main.go
+++ b/cmd/agent/main.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"fmt"
 	"log"
 	"os"
 	"strings"
@@ -71,22 +72,15 @@ func main() {
 	authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn)
 
 	// Initialize chat resolver for both chat and memory operations
-	chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
+	chatResolver := chat.NewResolver(modelsService, queries, cfg.AgentGateway.BaseURL(), 30*time.Second)
 
-	// Create LLM client for memory operations using chat provider
-	var llmClient memory.LLM
-	memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, modelsService, queries)
-	if err != nil {
-		log.Fatalf("select memory model: %v\nPlease configure at least one chat model in the database.", err)
+	// Create LLM client for memory operations (deferred model/provider selection).
+	var llmClient memory.LLM = &lazyLLMClient{
+		modelsService: modelsService,
+		queries:       queries,
+		timeout:       30 * time.Second,
 	}
-	log.Printf("Using memory model: %s (provider: %s)", memoryModel.ModelID, memoryProvider.ClientType)
-	provider, err := chat.CreateProvider(memoryProvider, 30*time.Second)
-	if err != nil {
-		log.Fatalf("create memory provider: %v", err)
-	}
-	llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)
-
 	resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
 	vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
 	if err != nil {
@@ -154,9 +148,46 @@ func main() {
 	providersService := providers.NewService(queries)
 	providersHandler := handlers.NewProvidersHandler(providersService)
 	modelsHandler := handlers.NewModelsHandler(modelsService)
-	srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler, containerdHandler)
+	srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, containerdHandler)
 
 	if err := srv.Start(); err != nil {
 		log.Fatalf("server failed: %v", err)
 	}
 }
+
+type lazyLLMClient struct {
+	modelsService *models.Service
+	queries       *dbsqlc.Queries
+	timeout       time.Duration
+}
+
+func (c *lazyLLMClient) Extract(ctx context.Context, req memory.ExtractRequest) (memory.ExtractResponse, error) {
+	client, err := c.resolve(ctx)
+	if err != nil {
+		return memory.ExtractResponse{}, err
+	}
+	return client.Extract(ctx, req)
+}
+
+func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (memory.DecideResponse, error) {
+	client, err := c.resolve(ctx)
+	if err != nil {
+		return memory.DecideResponse{}, err
+	}
+	return client.Decide(ctx, req)
+}
+
+func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
+	if c.modelsService == nil || c.queries == nil {
+		return nil, fmt.Errorf("models service not configured")
+	}
+	memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
+	if err != nil {
+		return nil, err
+	}
+	clientType := strings.ToLower(strings.TrimSpace(memoryProvider.ClientType))
+	if clientType != "openai" && clientType != "openai-compat" {
+		return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
+	}
+	return 
memory.NewLLMClient(memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout), nil +} diff --git a/config.toml.example b/config.toml.example index ac0f00a7..17f20051 100644 --- a/config.toml.example +++ b/config.toml.example @@ -33,4 +33,9 @@ sslmode = "disable" base_url = "http://127.0.0.1:6334" api_key = "" collection = "memory" -timeout_seconds = 10 \ No newline at end of file +timeout_seconds = 10 + +## Agent Gateway +[agent_gateway] +host = "127.0.0.1" +port = 8081 \ No newline at end of file diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index af6fc5ba..06490fec 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -130,3 +130,13 @@ CREATE TABLE IF NOT EXISTS lifecycle_events ( CREATE INDEX IF NOT EXISTS idx_lifecycle_events_container_id ON lifecycle_events(container_id); CREATE INDEX IF NOT EXISTS idx_lifecycle_events_event_type ON lifecycle_events(event_type); + +CREATE TABLE IF NOT EXISTS history ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + messages JSONB NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + "user" UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_history_user ON history("user"); +CREATE INDEX IF NOT EXISTS idx_history_timestamp ON history(timestamp); diff --git a/db/queries/history.sql b/db/queries/history.sql new file mode 100644 index 00000000..1583c442 --- /dev/null +++ b/db/queries/history.sql @@ -0,0 +1,11 @@ +-- name: CreateHistory :one +INSERT INTO history (messages, timestamp, "user") +VALUES ($1, $2, $3) +RETURNING id, messages, timestamp, "user"; + +-- name: ListHistoryByUserSince :many +SELECT id, messages, timestamp, "user" +FROM history +WHERE "user" = $1 AND timestamp >= $2 +ORDER BY timestamp ASC; + diff --git a/docs/docs.go b/docs/docs.go index 92be2126..dc09eefe 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -135,7 +135,11 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/chat.StreamChunk" + "type": "array", + "items": { + "type": "integer", + "format": "int32" + } } }, "400": { @@ -1431,71 +1435,53 @@ const docTemplate = `{ "chat.ChatRequest": { "type": "object", "properties": { + "current_platform": { + "type": "string" + }, + "language": { + "type": "string" + }, + "locale": { + "type": "string" + }, + "max_context_load_time": { + "type": "integer" + }, + "max_steps": { + "type": "integer" + }, "messages": { "type": "array", "items": { - "$ref": "#/definitions/chat.Message" + "$ref": "#/definitions/chat.GatewayMessage" } }, "model": { - "description": "optional: specific model to use", "type": "string" }, + "platforms": { + "type": "array", + "items": { + "type": "string" + } + }, "provider": { - "description": "optional: specific provider to use", "type": "string" }, - "stream": { - "type": "boolean" + "query": { + "type": "string" } } }, "chat.ChatResponse": { "type": "object", "properties": { - "finish_reason": { - "type": "string" - }, - "message": { - "$ref": "#/definitions/chat.Message" - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "usage": { - "$ref": "#/definitions/chat.Usage" - } - } - }, - "chat.Message": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "role": { - "description": "user, assistant, system", - "type": "string" - } - } - }, - "chat.StreamChunk": { - "type": "object", - "properties": { - "delta": { - "type": "object", - "properties": { - "content": { - "type": "string" - } + "messages": { 
+ "type": "array", + "items": { + "$ref": "#/definitions/chat.GatewayMessage" } }, - "finish_reason": { - "type": "string" - }, "model": { "type": "string" }, @@ -1504,19 +1490,9 @@ const docTemplate = `{ } } }, - "chat.Usage": { + "chat.GatewayMessage": { "type": "object", - "properties": { - "completion_tokens": { - "type": "integer" - }, - "prompt_tokens": { - "type": "integer" - }, - "total_tokens": { - "type": "integer" - } - } + "additionalProperties": true }, "handlers.CreateContainerRequest": { "type": "object", @@ -1524,6 +1500,9 @@ const docTemplate = `{ "container_id": { "type": "string" }, + "image": { + "type": "string" + }, "snapshotter": { "type": "string" } diff --git a/docs/swagger.json b/docs/swagger.json index 42f97d50..904b30c4 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -124,7 +124,11 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/chat.StreamChunk" + "type": "array", + "items": { + "type": "integer", + "format": "int32" + } } }, "400": { @@ -1420,71 +1424,53 @@ "chat.ChatRequest": { "type": "object", "properties": { + "current_platform": { + "type": "string" + }, + "language": { + "type": "string" + }, + "locale": { + "type": "string" + }, + "max_context_load_time": { + "type": "integer" + }, + "max_steps": { + "type": "integer" + }, "messages": { "type": "array", "items": { - "$ref": "#/definitions/chat.Message" + "$ref": "#/definitions/chat.GatewayMessage" } }, "model": { - "description": "optional: specific model to use", "type": "string" }, + "platforms": { + "type": "array", + "items": { + "type": "string" + } + }, "provider": { - "description": "optional: specific provider to use", "type": "string" }, - "stream": { - "type": "boolean" + "query": { + "type": "string" } } }, "chat.ChatResponse": { "type": "object", "properties": { - "finish_reason": { - "type": "string" - }, - "message": { - "$ref": "#/definitions/chat.Message" - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "usage": { - "$ref": "#/definitions/chat.Usage" - } - } - }, - "chat.Message": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "role": { - "description": "user, assistant, system", - "type": "string" - } - } - }, - "chat.StreamChunk": { - "type": "object", - "properties": { - "delta": { - "type": "object", - "properties": { - "content": { - "type": "string" - } + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/chat.GatewayMessage" } }, - "finish_reason": { - "type": "string" - }, "model": { "type": "string" }, @@ -1493,19 +1479,9 @@ } } }, - "chat.Usage": { + "chat.GatewayMessage": { "type": "object", - "properties": { - "completion_tokens": { - "type": "integer" - }, - "prompt_tokens": { - "type": "integer" - }, - "total_tokens": { - "type": "integer" - } - } + "additionalProperties": true }, "handlers.CreateContainerRequest": { "type": "object", @@ -1513,6 +1489,9 @@ "container_id": { "type": "string" }, + "image": { + "type": "string" + }, "snapshotter": { "type": "string" } diff --git a/docs/swagger.yaml b/docs/swagger.yaml index e58b55b4..eafafdfb 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,67 +1,51 @@ definitions: chat.ChatRequest: properties: + current_platform: + type: string + language: + type: string + locale: + type: string + max_context_load_time: + type: integer + max_steps: + type: integer messages: items: - $ref: '#/definitions/chat.Message' + $ref: '#/definitions/chat.GatewayMessage' type: array model: - description: 'optional: 
specific model to use'
         type: string
+      platforms:
+        items:
+          type: string
+        type: array
       provider:
-        description: 'optional: specific provider to use'
         type: string
-      stream:
-        type: boolean
+      query:
+        type: string
     type: object
   chat.ChatResponse:
     properties:
-      finish_reason:
-        type: string
-      message:
-        $ref: '#/definitions/chat.Message'
-      model:
-        type: string
-      provider:
-        type: string
-      usage:
-        $ref: '#/definitions/chat.Usage'
-    type: object
-  chat.Message:
-    properties:
-      content:
-        type: string
-      role:
-        description: user, assistant, system
-        type: string
-    type: object
-  chat.StreamChunk:
-    properties:
-      delta:
-        properties:
-          content:
-            type: string
-        type: object
-      finish_reason:
-        type: string
+      messages:
+        items:
+          $ref: '#/definitions/chat.GatewayMessage'
+        type: array
       model:
         type: string
       provider:
         type: string
     type: object
-  chat.Usage:
-    properties:
-      completion_tokens:
-        type: integer
-      prompt_tokens:
-        type: integer
-      total_tokens:
-        type: integer
+  chat.GatewayMessage:
+    additionalProperties: true
     type: object
   handlers.CreateContainerRequest:
     properties:
       container_id:
         type: string
+      image:
+        type: string
       snapshotter:
         type: string
     type: object
@@ -553,7 +537,10 @@ paths:
       "200":
         description: OK
         schema:
-          $ref: '#/definitions/chat.StreamChunk'
+          items:
+            format: int32
+            type: integer
+          type: array
       "400":
         description: Bad Request
         schema:
diff --git a/internal/chat/ARCHITECTURE.md b/internal/chat/ARCHITECTURE.md
deleted file mode 100644
index 6fde5faf..00000000
--- a/internal/chat/ARCHITECTURE.md
+++ /dev/null
@@ -1,213 +0,0 @@
-# Chat Provider Architecture
-
-## Overview
-
-This document describes the unified chat provider architecture in the Memoh project. The architecture is shared by the chat service and the memory service.
-
-## Architecture Design
-
-### Core Interfaces
-
-#### The Provider Interface
-
-Every LLM provider implements the `chat.Provider` interface:
-
-```go
-type Provider interface {
-    Chat(ctx context.Context, req Request) (Result, error)
-    StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error)
-}
-```
-
-#### The Request Struct
-
-Provider requests support several configuration options:
-
-```go
-type Request struct {
-    Messages       []Message
-    Model          string
-    Provider       string
-    Temperature    *float32        // optional: temperature
-    ResponseFormat *ResponseFormat // optional: response format (JSON mode)
-    MaxTokens      *int            // optional: max tokens
-}
-
-type ResponseFormat struct {
-    Type string // "json_object" or "text"
-}
-```
-
-### Supported Providers
-
-1. **OpenAI** (`openai` / `openai-compat`)
-   - The standard OpenAI API
-   - Custom endpoints compatible with the OpenAI format
-
-2. **Anthropic** (`anthropic`)
-   - The Claude model family
-
-3. **Google** (`google`)
-   - The Gemini model family
-
-4. **Ollama** (`ollama`)
-   - Locally deployed open-source models
-
-## Usage Scenarios
-
-### 1. The Chat Service
-
-The chat service consumes providers through `chat.Resolver`:
-
-```go
-chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
-response, err := chatResolver.Chat(ctx, ChatRequest{
-    Messages: messages,
-    Model:    "gpt-4",
-})
-```
-
-### 2. The Memory Service
-
-The memory service consumes providers through `memory.ProviderLLMClient`:
-
-```go
-// create the provider
-provider, err := chat.NewOpenAIProvider(apiKey, baseURL, timeout)
-
-// create the memory LLM client
-llmClient := memory.NewProviderLLMClient(provider, modelID)
-
-// use the memory service
-memoryService := memory.NewService(llmClient, embedder, store, resolver, ...)
-```
-
-The memory service needs two core capabilities:
-- **Extract**: extract factual information from a conversation
-- **Decide**: decide how to update memories (add/update/delete)
-
-Both operations use JSON mode to guarantee structured output.
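-
-For illustration, an Extract-style call through this interface could be sketched as follows. The system prompt wording and the `facts` field are illustrative only, and the stub providers in this package still return "not yet implemented":
-
-```go
-package main
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"time"
-
-	"github.com/memohai/memoh/internal/chat"
-)
-
-func main() {
-	provider, err := chat.NewOpenAIProvider("sk-...", "https://api.openai.com/v1", 30*time.Second)
-	if err != nil {
-		panic(err)
-	}
-	temperature := float32(0) // deterministic output for extraction
-	req := chat.Request{
-		Model:          "gpt-4",
-		Temperature:    &temperature,
-		ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
-		Messages: []chat.Message{
-			{Role: "system", Content: `Return a JSON object with a "facts" array.`},
-			{Role: "user", Content: "My name is Alice and I like pizza"},
-		},
-	}
-	result, err := provider.Chat(context.Background(), req)
-	if err != nil {
-		panic(err)
-	}
-	var parsed struct {
-		Facts []string `json:"facts"`
-	}
-	if err := json.Unmarshal([]byte(result.Message.Content), &parsed); err != nil {
-		panic(err)
-	}
-	fmt.Println(parsed.Facts)
-}
-```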
-``` - -Memory 服务需要两个核心功能: -- **Extract**: 从对话中提取事实信息 -- **Decide**: 决定如何更新记忆(添加/更新/删除) - -这两个操作都使用 JSON 模式来确保结构化输出。 - -## 配置示例 - -### 数据库配置 - -Provider 配置存储在 `llm_providers` 表: - -```sql -CREATE TABLE llm_providers ( - id UUID PRIMARY KEY, - name TEXT NOT NULL, - client_type TEXT NOT NULL, -- 'openai', 'anthropic', 'google', 'ollama' - base_url TEXT NOT NULL, - api_key TEXT NOT NULL, - metadata JSONB -); -``` - -模型配置存储在 `models` 表: - -```sql -CREATE TABLE models ( - id UUID PRIMARY KEY, - model_id TEXT NOT NULL, - name TEXT, - llm_provider_id UUID REFERENCES llm_providers(id), - type TEXT NOT NULL, -- 'chat' or 'embedding' - enable_as TEXT, -- 'chat', 'memory', 'embedding' - ... -); -``` - -### 启动时初始化 - -在 `cmd/agent/main.go` 中: - -```go -// 1. 初始化 chat resolver(用于 chat 和 memory) -chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second) - -// 2. 为 memory 选择模型和创建 provider -memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries, cfg) -provider, err := createChatProvider(memoryProvider, 30*time.Second) - -// 3. 创建 memory LLM 客户端 -llmClient := memory.NewProviderLLMClient(provider, memoryModel.ModelID) - -// 4. 创建 memory 服务 -memoryService := memory.NewService(llmClient, embedder, store, resolver, ...) -``` - -## 模型选择策略 - -### Memory 模型选择优先级 - -1. `enable_as = 'memory'` 的模型(专用 memory 模型) -2. `enable_as = 'chat'` 的模型(通用 chat 模型) -3. 任何可用的 chat 类型模型 -4. 回退到配置文件中的 LLMClient(向后兼容) - -### Chat 模型选择优先级 - -1. 请求中指定的模型 -2. `enable_as = 'chat'` 的模型 -3. 任何可用的 chat 类型模型 - -## 优势 - -1. **统一架构**: Chat 和 Memory 使用相同的 Provider 接口 -2. **灵活配置**: 支持多个提供商和模型 -3. **向后兼容**: 保留旧的 LLMClient 作为回退选项 -4. **类型安全**: 使用 Go 接口确保类型安全 -5. **易于扩展**: 添加新的提供商只需实现 Provider 接口 - -## 扩展新提供商 - -要添加新的 LLM 提供商: - -1. 在 `internal/chat/` 创建新文件(如 `newprovider.go`) -2. 实现 `Provider` 接口 -3. 在 `resolver.go` 的 `createProvider()` 中添加新的 case -4. 在数据库的 `llm_providers_client_type_check` 约束中添加新类型 - -示例: - -```go -// newprovider.go -type NewProvider struct { - apiKey string - timeout time.Duration -} - -func NewNewProvider(apiKey string, timeout time.Duration) (*NewProvider, error) { - return &NewProvider{apiKey: apiKey, timeout: timeout}, nil -} - -func (p *NewProvider) Chat(ctx context.Context, req Request) (Result, error) { - // 实现 chat 逻辑 -} - -func (p *NewProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) { - // 实现流式 chat 逻辑 -} -``` - -## 迁移指南 - -从旧的 TypeScript 后端迁移到 Go: - -1. ✅ 创建 Provider 接口和实现 -2. ✅ 实现 Chat Resolver -3. ✅ 创建 Memory 的 Provider 适配器 -4. ✅ 更新主程序使用统一 Provider -5. 🚧 实现各个 Provider 的具体逻辑(OpenAI, Anthropic, Google, Ollama) -6. 🚧 添加流式响应支持 -7. 🚧 添加完整的错误处理和重试机制 - -## 注意事项 - -1. **JSON 模式**: Memory 操作需要 `ResponseFormat.Type = "json_object"` 来确保结构化输出 -2. **温度参数**: Memory 操作使用 `Temperature = 0` 确保确定性输出 -3. **超时设置**: 不同操作可能需要不同的超时时间 -4. 
-
-## Benefits
-
-1. **Unified architecture**: chat and memory use the same Provider interface
-2. **Flexible configuration**: multiple providers and models are supported
-3. **Backward compatible**: the old LLMClient is kept as a fallback option
-4. **Type safe**: Go interfaces provide type safety
-5. **Easy to extend**: adding a provider only requires implementing the Provider interface
-
-## Adding a New Provider
-
-To add a new LLM provider:
-
-1. Create a new file in `internal/chat/` (e.g. `newprovider.go`)
-2. Implement the `Provider` interface
-3. Add a new case to `createProvider()` in `resolver.go`
-4. Add the new type to the `llm_providers_client_type_check` constraint in the database
-
-Example:
-
-```go
-// newprovider.go
-type NewProvider struct {
-    apiKey  string
-    timeout time.Duration
-}
-
-func NewNewProvider(apiKey string, timeout time.Duration) (*NewProvider, error) {
-    return &NewProvider{apiKey: apiKey, timeout: timeout}, nil
-}
-
-func (p *NewProvider) Chat(ctx context.Context, req Request) (Result, error) {
-    // implement the chat logic
-}
-
-func (p *NewProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
-    // implement the streaming chat logic
-}
-```
-
-## Migration Guide
-
-Migrating from the old TypeScript backend to Go:
-
-1. ✅ Create the Provider interface and implementations
-2. ✅ Implement the chat resolver
-3. ✅ Create the provider adapter for memory
-4. ✅ Update the main program to use the unified providers
-5. 🚧 Implement the concrete logic of each provider (OpenAI, Anthropic, Google, Ollama)
-6. 🚧 Add streaming response support
-7. 🚧 Add complete error handling and retry mechanisms
-
-## Notes
-
-1. **JSON mode**: memory operations require `ResponseFormat.Type = "json_object"` to guarantee structured output
-2. **Temperature**: memory operations use `Temperature = 0` to keep output deterministic
-3. **Timeouts**: different operations may need different timeouts
-4. **Error handling**: providers should return clear error messages, including API error details
-
diff --git a/internal/chat/anthropic.go b/internal/chat/anthropic.go
deleted file mode 100644
index 18023c88..00000000
--- a/internal/chat/anthropic.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package chat
-
-import (
-	"context"
-	"fmt"
-	"time"
-
-	"github.com/firebase/genkit/go/genkit"
-)
-
-// AnthropicProvider wraps Genkit's Anthropic provider
-type AnthropicProvider struct {
-	g       *genkit.Genkit
-	apiKey  string
-	timeout time.Duration
-}
-
-func NewAnthropicProvider(apiKey string, timeout time.Duration) (*AnthropicProvider, error) {
-	if timeout <= 0 {
-		timeout = 30 * time.Second
-	}
-	return &AnthropicProvider{
-		apiKey:  apiKey,
-		timeout: timeout,
-	}, nil
-}
-
-func (p *AnthropicProvider) Chat(ctx context.Context, req Request) (Result, error) {
-	return Result{}, fmt.Errorf("anthropic provider not yet implemented")
-}
-
-func (p *AnthropicProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
-	chunkChan := make(chan StreamChunk, 10)
-	errChan := make(chan error, 1)
-
-	go func() {
-		defer close(chunkChan)
-		defer close(errChan)
-		errChan <- fmt.Errorf("anthropic streaming not yet implemented")
-	}()
-
-	return chunkChan, errChan
-}
diff --git a/internal/chat/chat.go b/internal/chat/chat.go
index 0efde489..91c1a427 100644
--- a/internal/chat/chat.go
+++ b/internal/chat/chat.go
@@ -1,2 +1 @@
-package chat
-
+package chat
\ No newline at end of file
diff --git a/internal/chat/factory.go b/internal/chat/factory.go
deleted file mode 100644
index ce80870f..00000000
--- a/internal/chat/factory.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package chat
-
-import (
-	"fmt"
-	"strings"
-	"time"
-
-	dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
-)
-
-// CreateProvider creates a chat provider instance.
-func CreateProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (Provider, error) { - clientType := strings.ToLower(strings.TrimSpace(provider.ClientType)) - if timeout <= 0 { - timeout = 30 * time.Second - } - - switch clientType { - case ProviderOpenAI, ProviderOpenAICompat: - if strings.TrimSpace(provider.ApiKey) == "" { - return nil, fmt.Errorf("openai api key is required") - } - return NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout) - case ProviderAnthropic: - if strings.TrimSpace(provider.ApiKey) == "" { - return nil, fmt.Errorf("anthropic api key is required") - } - return NewAnthropicProvider(provider.ApiKey, timeout) - case ProviderGoogle: - if strings.TrimSpace(provider.ApiKey) == "" { - return nil, fmt.Errorf("google api key is required") - } - return NewGoogleProvider(provider.ApiKey, timeout) - case ProviderOllama: - return NewOllamaProvider(provider.BaseUrl, timeout) - default: - return nil, fmt.Errorf("unsupported provider type: %s", clientType) - } -} diff --git a/internal/chat/google.go b/internal/chat/google.go deleted file mode 100644 index 07a7bfbb..00000000 --- a/internal/chat/google.go +++ /dev/null @@ -1,43 +0,0 @@ -package chat - -import ( - "context" - "fmt" - "time" - - "github.com/firebase/genkit/go/genkit" -) - -// GoogleProvider wraps Genkit's Google AI provider -type GoogleProvider struct { - g *genkit.Genkit - apiKey string - timeout time.Duration -} - -func NewGoogleProvider(apiKey string, timeout time.Duration) (*GoogleProvider, error) { - if timeout <= 0 { - timeout = 30 * time.Second - } - return &GoogleProvider{ - apiKey: apiKey, - timeout: timeout, - }, nil -} - -func (p *GoogleProvider) Chat(ctx context.Context, req Request) (Result, error) { - return Result{}, fmt.Errorf("google provider not yet implemented") -} - -func (p *GoogleProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) { - chunkChan := make(chan StreamChunk, 10) - errChan := make(chan error, 1) - - go func() { - defer close(chunkChan) - defer close(errChan) - errChan <- fmt.Errorf("google streaming not yet implemented") - }() - - return chunkChan, errChan -} diff --git a/internal/chat/ollama.go b/internal/chat/ollama.go deleted file mode 100644 index 0e6d612c..00000000 --- a/internal/chat/ollama.go +++ /dev/null @@ -1,48 +0,0 @@ -package chat - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/firebase/genkit/go/genkit" -) - -// OllamaProvider wraps Genkit's Ollama provider -type OllamaProvider struct { - g *genkit.Genkit - baseURL string - timeout time.Duration -} - -func NewOllamaProvider(baseURL string, timeout time.Duration) (*OllamaProvider, error) { - if baseURL == "" { - baseURL = "http://localhost:11434" - } - baseURL = strings.TrimRight(baseURL, "/") - if timeout <= 0 { - timeout = 60 * time.Second - } - return &OllamaProvider{ - baseURL: baseURL, - timeout: timeout, - }, nil -} - -func (p *OllamaProvider) Chat(ctx context.Context, req Request) (Result, error) { - return Result{}, fmt.Errorf("ollama provider not yet implemented") -} - -func (p *OllamaProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) { - chunkChan := make(chan StreamChunk, 10) - errChan := make(chan error, 1) - - go func() { - defer close(chunkChan) - defer close(errChan) - errChan <- fmt.Errorf("ollama streaming not yet implemented") - }() - - return chunkChan, errChan -} diff --git a/internal/chat/openai.go b/internal/chat/openai.go deleted file mode 100644 index fe664488..00000000 --- 
a/internal/chat/openai.go +++ /dev/null @@ -1,54 +0,0 @@ -package chat - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/firebase/genkit/go/genkit" -) - -// OpenAIProvider wraps Genkit's OpenAI provider -type OpenAIProvider struct { - g *genkit.Genkit - apiKey string - baseURL string - timeout time.Duration -} - -func NewOpenAIProvider(apiKey, baseURL string, timeout time.Duration) (*OpenAIProvider, error) { - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - baseURL = strings.TrimRight(baseURL, "/") - if timeout <= 0 { - timeout = 30 * time.Second - } - - // For now, we'll create a simple HTTP client-based implementation - // since Genkit Go plugins require initialization at startup - return &OpenAIProvider{ - apiKey: apiKey, - baseURL: baseURL, - timeout: timeout, - }, nil -} - -func (p *OpenAIProvider) Chat(ctx context.Context, req Request) (Result, error) { - // Use direct HTTP API call since Genkit plugins need to be initialized at startup - return Result{}, fmt.Errorf("openai provider not yet implemented - please use openai-compat provider") -} - -func (p *OpenAIProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) { - chunkChan := make(chan StreamChunk, 10) - errChan := make(chan error, 1) - - go func() { - defer close(chunkChan) - defer close(errChan) - errChan <- fmt.Errorf("openai streaming not yet implemented") - }() - - return chunkChan, errChan -} diff --git a/internal/chat/prompts.go b/internal/chat/prompts.go deleted file mode 100644 index 51335c62..00000000 --- a/internal/chat/prompts.go +++ /dev/null @@ -1,141 +0,0 @@ -package chat - -import ( - "fmt" - "strings" - "time" -) - -// PromptParams contains parameters for generating system prompts -type PromptParams struct { - Date time.Time - Locale string - Language string - MaxContextLoadTime int // in minutes - Platforms []string // available platforms (e.g., ["telegram", "wechat"]) - CurrentPlatform string // current platform the user is using -} - -// SystemPrompt generates the system prompt for the AI assistant -// This is migrated from packages/agent/src/prompts/system.ts -func SystemPrompt(params PromptParams) string { - if params.Language == "" { - params.Language = "Same as user input" - } - if params.MaxContextLoadTime == 0 { - params.MaxContextLoadTime = 24 * 60 // 24 hours default - } - if params.CurrentPlatform == "" { - params.CurrentPlatform = "client" - } - - // Build platforms list - platformsList := "" - if len(params.Platforms) > 0 { - lines := make([]string, len(params.Platforms)) - for i, p := range params.Platforms { - lines[i] = fmt.Sprintf(" - %s", p) - } - platformsList = strings.Join(lines, "\n") - } - - timeStr := FormatTime(params.Date, params.Locale) - - return fmt.Sprintf(`--- -%s -language: %s -available-platforms: -%s -current-platform: %s ---- -You are a personal housekeeper assistant, which able to manage the master's daily affairs. - -Your abilities: -- Long memory: You possess long-term memory; conversations from the last %d minutes will be directly loaded into your context. Additionally, you can use tools to search for past memories. -- Scheduled tasks: You can create scheduled tasks to automatically remind you to do something. -- Messaging: You may allowed to use message software to send messages to the master. - -**Response Guidelines** -- Always respond in the language specified above, unless it says "Same as user input", then match the user's language. -- Be helpful, concise, and friendly. 
-- For complex questions, break down your answer into clear steps. -- If you're unsure about something, acknowledge it honestly.`, - timeStr, - params.Language, - platformsList, - params.CurrentPlatform, - params.MaxContextLoadTime, - ) -} - -// SchedulePrompt generates a prompt for scheduled task execution -// This is migrated from packages/agent/src/prompts/schedule.ts -type SchedulePromptParams struct { - Date time.Time - Locale string - ScheduleName string - ScheduleDescription string - ScheduleID string - MaxCalls *int // nil means unlimited - CronPattern string - Command string // the natural language command to execute -} - -func SchedulePrompt(params SchedulePromptParams) string { - timeStr := FormatTime(params.Date, params.Locale) - - maxCallsStr := "Unlimited" - if params.MaxCalls != nil { - maxCallsStr = fmt.Sprintf("%d", *params.MaxCalls) - } - - return fmt.Sprintf(`--- -notice: **This is a scheduled task automatically send to you by the system, not the user input** -%s -schedule-name: %s -schedule-description: %s -schedule-id: %s -max-calls: %s -cron-pattern: %s ---- - -**COMMAND** - -%s`, - timeStr, - params.ScheduleName, - params.ScheduleDescription, - params.ScheduleID, - maxCallsStr, - params.CronPattern, - params.Command, - ) -} - -// FormatTime formats the date and time according to locale -func FormatTime(date time.Time, locale string) string { - if locale == "" { - locale = "en-US" - } - - // Format date and time - // For simplicity, using standard format. In production, you might want to use - // a proper i18n library for locale-specific formatting - dateStr := date.Format("2006-01-02") - timeStr := date.Format("15:04:05") - - return fmt.Sprintf("date: %s\ntime: %s", dateStr, timeStr) -} - -// Quote wraps content in backticks for markdown code formatting -func Quote(content string) string { - return fmt.Sprintf("`%s`", content) -} - -// Block wraps content in code block with optional language tag -func Block(content, tag string) string { - if tag == "" { - return fmt.Sprintf("```\n%s\n```", content) - } - return fmt.Sprintf("```%s\n%s\n```", tag, content) -} diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 50b185c9..a49b5054 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -1,8 +1,13 @@ package chat import ( + "bufio" + "bytes" "context" - "errors" + "encoding/json" + "fmt" + "io" + "net/http" "strings" "time" @@ -13,169 +18,359 @@ import ( "github.com/memohai/memoh/internal/models" ) -const ( - ProviderOpenAI = "openai" - ProviderOpenAICompat = "openai-compat" - ProviderAnthropic = "anthropic" - ProviderGoogle = "google" - ProviderOllama = "ollama" -) +const defaultMaxContextMinutes = 24 * 60 -// Provider interface for chat providers -type Provider interface { - Chat(ctx context.Context, req Request) (Result, error) - StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) -} - -// Resolver resolves chat models and providers type Resolver struct { - modelsService *models.Service - queries *sqlc.Queries - timeout time.Duration + modelsService *models.Service + queries *sqlc.Queries + gatewayBaseURL string + timeout time.Duration + httpClient *http.Client + streamingClient *http.Client } -// NewResolver creates a new chat resolver -func NewResolver(modelsService *models.Service, queries *sqlc.Queries, timeout time.Duration) *Resolver { +func NewResolver(modelsService *models.Service, queries *sqlc.Queries, gatewayBaseURL string, timeout time.Duration) *Resolver { + if 
strings.TrimSpace(gatewayBaseURL) == "" { + gatewayBaseURL = "http://127.0.0.1:8081" + } + gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/") + if timeout <= 0 { + timeout = 30 * time.Second + } return &Resolver{ - modelsService: modelsService, - queries: queries, - timeout: timeout, + modelsService: modelsService, + queries: queries, + gatewayBaseURL: gatewayBaseURL, + timeout: timeout, + httpClient: &http.Client{ + Timeout: timeout, + }, + streamingClient: &http.Client{}, } } -// Chat performs a chat completion func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { - // Select model and provider - selected, provider, err := r.selectChatModel(ctx, req.Model, req.Provider) + if strings.TrimSpace(req.Query) == "" { + return ChatResponse{}, fmt.Errorf("query is required") + } + if strings.TrimSpace(req.UserID) == "" { + return ChatResponse{}, fmt.Errorf("user id is required") + } + + chatModel, provider, err := r.selectChatModel(ctx, req) + if err != nil { + return ChatResponse{}, err + } + clientType, err := normalizeClientType(provider.ClientType) if err != nil { return ChatResponse{}, err } - // Create internal request - internalReq := Request{ - Messages: req.Messages, - Model: selected.ModelID, - Provider: strings.ToLower(provider.ClientType), + messages, err := r.loadHistoryMessages(ctx, req.UserID, req.MaxContextLoadTime) + if err != nil { + return ChatResponse{}, err + } + if len(req.Messages) > 0 { + messages = append(messages, req.Messages...) } - // Add system prompt - if len(internalReq.Messages) > 0 && internalReq.Messages[0].Role != "system" { - systemPrompt := SystemPrompt(PromptParams{ - Date: time.Now(), - Locale: "en-US", - Language: "Same as user input", - MaxContextLoadTime: 24 * 60, // 24 hours - Platforms: []string{}, - CurrentPlatform: "api", - }) - internalReq.Messages = append([]Message{ - {Role: "system", Content: systemPrompt}, - }, internalReq.Messages...) 
+ payload := agentGatewayRequest{ + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + Model: chatModel.ModelID, + ClientType: clientType, + Locale: req.Locale, + Language: req.Language, + MaxSteps: req.MaxSteps, + MaxContextLoadTime: normalizeMaxContextLoad(req.MaxContextLoadTime), + Platforms: req.Platforms, + CurrentPlatform: req.CurrentPlatform, + Messages: messages, + Query: req.Query, } - // Create provider instance - providerInst, err := r.createProvider(provider) + resp, err := r.postChat(ctx, payload) if err != nil { return ChatResponse{}, err } - // Execute chat - result, err := providerInst.Chat(ctx, internalReq) - if err != nil { + if err := r.storeHistory(ctx, req.UserID, req.Query, resp.Messages); err != nil { return ChatResponse{}, err } return ChatResponse{ - Message: result.Message, - Model: result.Model, - Provider: result.Provider, - FinishReason: result.FinishReason, - Usage: result.Usage, + Messages: resp.Messages, + Model: chatModel.ModelID, + Provider: provider.ClientType, }, nil } -// StreamChat performs a streaming chat completion func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { - chunkChan := make(chan StreamChunk, 10) + chunkChan := make(chan StreamChunk) errChan := make(chan error, 1) go func() { defer close(chunkChan) defer close(errChan) - // Select model and provider - selected, provider, err := r.selectChatModel(ctx, req.Model, req.Provider) + if strings.TrimSpace(req.Query) == "" { + errChan <- fmt.Errorf("query is required") + return + } + if strings.TrimSpace(req.UserID) == "" { + errChan <- fmt.Errorf("user id is required") + return + } + + chatModel, provider, err := r.selectChatModel(ctx, req) + if err != nil { + errChan <- err + return + } + clientType, err := normalizeClientType(provider.ClientType) if err != nil { errChan <- err return } - // Create internal request - internalReq := Request{ - Messages: req.Messages, - Model: selected.ModelID, - Provider: strings.ToLower(provider.ClientType), - } - - // Add system prompt - if len(internalReq.Messages) > 0 && internalReq.Messages[0].Role != "system" { - systemPrompt := SystemPrompt(PromptParams{ - Date: time.Now(), - Locale: "en-US", - Language: "Same as user input", - MaxContextLoadTime: 24 * 60, // 24 hours - Platforms: []string{}, - CurrentPlatform: "api", - }) - internalReq.Messages = append([]Message{ - {Role: "system", Content: systemPrompt}, - }, internalReq.Messages...) - } - - // Create provider instance - providerInst, err := r.createProvider(provider) + messages, err := r.loadHistoryMessages(ctx, req.UserID, req.MaxContextLoadTime) if err != nil { errChan <- err return } + if len(req.Messages) > 0 { + messages = append(messages, req.Messages...) 
+ } - // Execute streaming chat - providerChunkChan, providerErrChan := providerInst.StreamChat(ctx, internalReq) + payload := agentGatewayRequest{ + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + Model: chatModel.ModelID, + ClientType: clientType, + Locale: req.Locale, + Language: req.Language, + MaxSteps: req.MaxSteps, + MaxContextLoadTime: normalizeMaxContextLoad(req.MaxContextLoadTime), + Platforms: req.Platforms, + CurrentPlatform: req.CurrentPlatform, + Messages: messages, + Query: req.Query, + } - // Forward chunks and errors - for { - select { - case chunk, ok := <-providerChunkChan: - if !ok { - return - } - chunkChan <- chunk - case err := <-providerErrChan: - if err != nil { - errChan <- err - } - return - } + if err := r.streamChat(ctx, payload, req.UserID, req.Query, chunkChan); err != nil { + errChan <- err + return } }() return chunkChan, errChan } -// selectChatModel selects a chat model based on the request -func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType string) (models.GetResponse, sqlc.LlmProvider, error) { - if r.modelsService == nil { - return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured") +type agentGatewayRequest struct { + APIKey string `json:"apiKey"` + BaseURL string `json:"baseUrl"` + Model string `json:"model"` + ClientType string `json:"clientType"` + Locale string `json:"locale,omitempty"` + Language string `json:"language,omitempty"` + MaxSteps int `json:"maxSteps,omitempty"` + MaxContextLoadTime int `json:"maxContextLoadTime"` + Platforms []string `json:"platforms,omitempty"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + Messages []GatewayMessage `json:"messages"` + Query string `json:"query"` +} + +type agentGatewayResponse struct { + Messages []GatewayMessage `json:"messages"` +} + +func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest) (agentGatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return agentGatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return agentGatewayResponse{}, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := r.httpClient.Do(req) + if err != nil { + return agentGatewayResponse{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + payload, _ := io.ReadAll(resp.Body) + return agentGatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(payload))) } - modelID = strings.TrimSpace(modelID) - providerType = strings.ToLower(strings.TrimSpace(providerType)) + var parsed agentGatewayResponse + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return agentGatewayResponse{}, err + } + return parsed, nil +} - // If no model specified, try to get default chat model - if modelID == "" && providerType == "" { +func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, userID, query string, chunkChan chan<- StreamChunk) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + url := r.gatewayBaseURL + "/chat/stream" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + + resp, err := r.streamingClient.Do(req) + if err != nil { + return err + } + 
defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + payload, _ := io.ReadAll(resp.Body) + return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(payload))) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" || data == "[DONE]" { + continue + } + chunkChan <- StreamChunk([]byte(data)) + + var envelope struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + } + if err := json.Unmarshal([]byte(data), &envelope); err != nil { + continue + } + if envelope.Type != "done" || len(envelope.Data) == 0 { + continue + } + var parsed agentGatewayResponse + if err := json.Unmarshal(envelope.Data, &parsed); err != nil { + continue + } + if err := r.storeHistory(ctx, userID, query, parsed.Messages); err != nil { + return err + } + } + + if err := scanner.Err(); err != nil { + return err + } + return nil +} + +func (r *Resolver) loadHistoryMessages(ctx context.Context, userID string, maxContextLoadTime int) ([]GatewayMessage, error) { + if r.queries == nil { + return nil, fmt.Errorf("history queries not configured") + } + pgUserID, err := parseUUID(userID) + if err != nil { + return nil, err + } + from := time.Now().UTC().Add(-time.Duration(normalizeMaxContextLoad(maxContextLoadTime)) * time.Minute) + rows, err := r.queries.ListHistoryByUserSince(ctx, sqlc.ListHistoryByUserSinceParams{ + User: pgUserID, + Timestamp: pgtype.Timestamptz{ + Time: from, + Valid: true, + }, + }) + if err != nil { + return nil, err + } + messages := make([]GatewayMessage, 0, len(rows)) + for _, row := range rows { + var batch []GatewayMessage + if len(row.Messages) == 0 { + continue + } + if err := json.Unmarshal(row.Messages, &batch); err != nil { + return nil, err + } + messages = append(messages, batch...) + } + return messages, nil +} + +func (r *Resolver) storeHistory(ctx context.Context, userID, query string, responseMessages []GatewayMessage) error { + if r.queries == nil { + return fmt.Errorf("history queries not configured") + } + if strings.TrimSpace(userID) == "" { + return fmt.Errorf("user id is required") + } + if strings.TrimSpace(query) == "" && len(responseMessages) == 0 { + return nil + } + userMessage := GatewayMessage{ + "role": "user", + "content": query, + } + messages := append([]GatewayMessage{userMessage}, responseMessages...) 
+ payload, err := json.Marshal(messages) + if err != nil { + return err + } + pgUserID, err := parseUUID(userID) + if err != nil { + return err + } + _, err = r.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{ + Messages: payload, + Timestamp: pgtype.Timestamptz{ + Time: time.Now().UTC(), + Valid: true, + }, + User: pgUserID, + }) + return err +} + +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest) (models.GetResponse, sqlc.LlmProvider, error) { + if r.modelsService == nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + } + modelID := strings.TrimSpace(req.Model) + providerFilter := strings.TrimSpace(req.Provider) + + if modelID != "" && providerFilter == "" { + model, err := r.modelsService.GetByModelID(ctx, modelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + if model.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + } + provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, provider, nil + } + + if providerFilter == "" && modelID == "" { defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsChat) if err == nil { - provider, err := r.fetchProvider(ctx, defaultModel.LlmProviderID) + provider, err := models.FetchProviderByID(ctx, r.queries, defaultModel.LlmProviderID) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } @@ -183,11 +378,10 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st } } - // List available models var candidates []models.GetResponse var err error - if providerType != "" { - candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerType)) + if providerFilter != "" { + candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) } else { candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) } @@ -195,7 +389,6 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st return models.GetResponse{}, sqlc.LlmProvider{}, err } - // Filter chat models filtered := make([]models.GetResponse, 0, len(candidates)) for _, model := range candidates { if model.Type != models.ModelTypeChat { @@ -204,92 +397,60 @@ func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType st filtered = append(filtered, model) } if len(filtered) == 0 { - return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("no chat models available") + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available") } - // If model specified, find it if modelID != "" { for _, model := range filtered { if model.ModelID == modelID { - provider, err := r.fetchProvider(ctx, model.LlmProviderID) + provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } return model, provider, nil } } - return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("chat model not found") + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found") } - // Return first available model selected := filtered[0] - provider, err := r.fetchProvider(ctx, selected.LlmProviderID) + provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } 
 	return selected, provider, nil
 }
 
-// fetchProvider fetches provider information from database
-func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.LlmProvider, error) {
-	if r.queries == nil {
-		return sqlc.LlmProvider{}, errors.New("llm provider queries not configured")
+// normalizeMaxContextLoad clamps the requested history window (in minutes)
+// to the package default when the value is unset or negative.
+func normalizeMaxContextLoad(value int) int {
+	if value <= 0 {
+		return defaultMaxContextMinutes
 	}
-	if strings.TrimSpace(providerID) == "" {
-		return sqlc.LlmProvider{}, errors.New("llm provider id missing")
-	}
-	parsed, err := uuid.Parse(providerID)
-	if err != nil {
-		return sqlc.LlmProvider{}, err
-	}
-	pgID := pgtype.UUID{Valid: true}
-	copy(pgID.Bytes[:], parsed[:])
-	return r.queries.GetLlmProviderByID(ctx, pgID)
+	return value
 }
 
-// createProvider creates a provider instance based on configuration
-func (r *Resolver) createProvider(provider sqlc.LlmProvider) (Provider, error) {
-	clientType := strings.ToLower(strings.TrimSpace(provider.ClientType))
-	timeout := r.timeout
-	if timeout <= 0 {
-		timeout = 30 * time.Second
-	}
-
-	switch clientType {
-	case ProviderOpenAI, ProviderOpenAICompat:
-		if strings.TrimSpace(provider.ApiKey) == "" {
-			return nil, errors.New("openai api key is required")
-		}
-		p, err := NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout)
-		if err != nil {
-			return nil, err
-		}
-		return p, nil
-	case ProviderAnthropic:
-		if strings.TrimSpace(provider.ApiKey) == "" {
-			return nil, errors.New("anthropic api key is required")
-		}
-		p, err := NewAnthropicProvider(provider.ApiKey, timeout)
-		if err != nil {
-			return nil, err
-		}
-		return p, nil
-	case ProviderGoogle:
-		if strings.TrimSpace(provider.ApiKey) == "" {
-			return nil, errors.New("google api key is required")
-		}
-		p, err := NewGoogleProvider(provider.ApiKey, timeout)
-		if err != nil {
-			return nil, err
-		}
-		return p, nil
-	case ProviderOllama:
-		p, err := NewOllamaProvider(provider.BaseUrl, timeout)
-		if err != nil {
-			return nil, err
-		}
-		return p, nil
+// normalizeClientType maps a stored provider client type onto the identifier
+// the agent gateway understands; OpenAI-compatible providers collapse to "openai".
+func normalizeClientType(clientType string) (string, error) {
+	switch strings.ToLower(strings.TrimSpace(clientType)) {
+	case "openai", "openai-compat":
+		return "openai", nil
+	case "anthropic":
+		return "anthropic", nil
+	case "google":
+		return "google", nil
 	default:
-		return nil, errors.New("unsupported provider type: " + clientType)
+		return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType)
 	}
 }
+
+// parseUUID converts a string UUID into pgtype.UUID for sqlc parameters.
+func parseUUID(id string) (pgtype.UUID, error) {
+	parsed, err := uuid.Parse(id)
+	if err != nil {
+		return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err)
+	}
+	var pgID pgtype.UUID
+	pgID.Valid = true
+	copy(pgID.Bytes[:], parsed[:])
+	return pgID, nil
+}
+
diff --git a/internal/chat/types.go b/internal/chat/types.go
index 90e73138..10cff90e 100644
--- a/internal/chat/types.go
+++ b/internal/chat/types.go
@@ -1,65 +1,32 @@
 package chat
 
-// Message represents a chat message
+import "encoding/json"
+
+// Message is a simple role/content chat message.
 type Message struct {
-	Role string `json:"role"` // user, assistant, system
+	Role string `json:"role"`
 	Content string `json:"content"`
 }
 
-// ChatRequest represents an incoming chat request
+// GatewayMessage is an opaque message payload exchanged with the agent gateway.
+type GatewayMessage map[string]interface{}
+
+// ChatRequest is an incoming chat request forwarded to the agent gateway.
 type ChatRequest struct {
-	Messages []Message `json:"messages"`
-	Model string `json:"model,omitempty"` // optional: specific model to use
-	Provider string `json:"provider,omitempty"` // optional: specific provider to use
-	Stream bool `json:"stream,omitempty"`
+	UserID string `json:"-"`
+	Query string `json:"query"`
+	Model string `json:"model,omitempty"`
+	
Provider string `json:"provider,omitempty"`
+	MaxContextLoadTime int `json:"max_context_load_time,omitempty"`
+	Locale string `json:"locale,omitempty"`
+	Language string `json:"language,omitempty"`
+	MaxSteps int `json:"max_steps,omitempty"`
+	Platforms []string `json:"platforms,omitempty"`
+	CurrentPlatform string `json:"current_platform,omitempty"`
+	Messages []GatewayMessage `json:"messages,omitempty"`
 }
 
-// ChatResponse represents a chat response
+// ChatResponse carries the gateway messages produced for a chat request.
 type ChatResponse struct {
-	Message Message `json:"message"`
-	Model string `json:"model"`
-	Provider string `json:"provider"`
-	FinishReason string `json:"finish_reason,omitempty"`
-	Usage Usage `json:"usage,omitempty"`
+	Messages []GatewayMessage `json:"messages"`
+	Model string `json:"model,omitempty"`
+	Provider string `json:"provider,omitempty"`
 }
 
-// StreamChunk represents a chunk in streaming response
-type StreamChunk struct {
-	Delta struct {
-		Content string `json:"content"`
-	} `json:"delta"`
-	Model string `json:"model,omitempty"`
-	Provider string `json:"provider,omitempty"`
-	FinishReason string `json:"finish_reason,omitempty"`
-}
-
-// Usage represents token usage information
-type Usage struct {
-	PromptTokens int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens int `json:"total_tokens"`
-}
-
-// Request is the internal request structure
-type Request struct {
-	Messages []Message
-	Model string
-	Provider string
-	Temperature *float32 // optional temperature
-	ResponseFormat *ResponseFormat // optional response format
-	MaxTokens *int // optional max tokens
-}
-
-// ResponseFormat specifies the format of the response
-type ResponseFormat struct {
-	Type string `json:"type"` // "json_object" or "text"
-}
-
-// Result is the internal result structure
-type Result struct {
-	Message Message
-	Model string
-	Provider string
-	FinishReason string
-	Usage Usage
-}
+// StreamChunk is a raw JSON chunk relayed verbatim from the streaming backend.
+type StreamChunk = json.RawMessage
diff --git a/internal/config/config.go b/internal/config/config.go
index b5dd0c93..3b7b74da 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,6 +1,7 @@
 package config
 
 import (
+	"fmt"
 	"os"
 
 	"github.com/BurntSushi/toml"
@@ -31,6 +32,7 @@ type Config struct {
 	MCP MCPConfig `toml:"mcp"`
 	Postgres PostgresConfig `toml:"postgres"`
 	Qdrant QdrantConfig `toml:"qdrant"`
+	AgentGateway AgentGatewayConfig `toml:"agent_gateway"`
 }
 
 type ServerConfig struct {
@@ -70,6 +72,23 @@ type QdrantConfig struct {
 	TimeoutSeconds int `toml:"timeout_seconds"`
 }
 
+// AgentGatewayConfig holds the address of the agent gateway service.
+type AgentGatewayConfig struct {
+	Host string `toml:"host"`
+	Port int `toml:"port"`
+}
+
+// BaseURL returns the gateway's HTTP base URL, defaulting to 127.0.0.1:8081.
+func (c AgentGatewayConfig) BaseURL() string {
+	host := c.Host
+	if host == "" {
+		host = "127.0.0.1"
+	}
+	port := c.Port
+	if port == 0 {
+		port = 8081
+	}
+	return fmt.Sprintf("http://%s:%d", host, port)
+}
+
 func Load(path string) (Config, error) {
 	cfg := Config{
 		Server: ServerConfig{
@@ -98,6 +117,10 @@ func Load(path string) (Config, error) {
 			BaseURL: DefaultQdrantURL,
 			Collection: DefaultQdrantCollection,
 		},
+		AgentGateway: AgentGatewayConfig{
+			Host: "127.0.0.1",
+			Port: 8081,
+		},
 	}
 
 	if path == "" {
diff --git a/internal/db/sqlc/history.sql.go b/internal/db/sqlc/history.sql.go
new file mode 100644
index 00000000..7d2712fa
--- /dev/null
+++ b/internal/db/sqlc/history.sql.go
@@ -0,0 +1,73 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions: +// sqlc v1.30.0 +// source: history.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createHistory = `-- name: CreateHistory :one +INSERT INTO history (messages, timestamp, "user") +VALUES ($1, $2, $3) +RETURNING id, messages, timestamp, "user" +` + +type CreateHistoryParams struct { + Messages []byte `json:"messages"` + Timestamp pgtype.Timestamptz `json:"timestamp"` + User pgtype.UUID `json:"user"` +} + +func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (History, error) { + row := q.db.QueryRow(ctx, createHistory, arg.Messages, arg.Timestamp, arg.User) + var i History + err := row.Scan( + &i.ID, + &i.Messages, + &i.Timestamp, + &i.User, + ) + return i, err +} + +const listHistoryByUserSince = `-- name: ListHistoryByUserSince :many +SELECT id, messages, timestamp, "user" +FROM history +WHERE "user" = $1 AND timestamp >= $2 +ORDER BY timestamp ASC +` + +type ListHistoryByUserSinceParams struct { + User pgtype.UUID `json:"user"` + Timestamp pgtype.Timestamptz `json:"timestamp"` +} + +func (q *Queries) ListHistoryByUserSince(ctx context.Context, arg ListHistoryByUserSinceParams) ([]History, error) { + rows, err := q.db.Query(ctx, listHistoryByUserSince, arg.User, arg.Timestamp) + if err != nil { + return nil, err + } + defer rows.Close() + var items []History + for rows.Next() { + var i History + if err := rows.Scan( + &i.ID, + &i.Messages, + &i.Timestamp, + &i.User, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 3714b95c..123ea448 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -33,6 +33,13 @@ type ContainerVersion struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } +type History struct { + ID pgtype.UUID `json:"id"` + Messages []byte `json:"messages"` + Timestamp pgtype.Timestamptz `json:"timestamp"` + User pgtype.UUID `json:"user"` +} + type LifecycleEvent struct { ID string `json:"id"` ContainerID string `json:"container_id"` diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go index f3371941..aa7329c8 100644 --- a/internal/handlers/chat.go +++ b/internal/handlers/chat.go @@ -7,7 +7,10 @@ import ( "net/http" "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/identity" ) type ChatHandler struct { @@ -36,14 +39,20 @@ func (h *ChatHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /chat [post] func (h *ChatHandler) Chat(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + var req chat.ChatRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if len(req.Messages) == 0 { - return echo.NewHTTPError(http.StatusBadRequest, "messages are required") + if req.Query == "" { + return echo.NewHTTPError(http.StatusBadRequest, "query is required") } + req.UserID = userID resp, err := h.resolver.Chat(c.Request().Context(), req) if err != nil { @@ -65,14 +74,20 @@ func (h *ChatHandler) Chat(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /chat/stream [post] func (h *ChatHandler) StreamChat(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + var req chat.ChatRequest if err := c.Bind(&req); err 
!= nil {
 		return echo.NewHTTPError(http.StatusBadRequest, err.Error())
 	}
 
-	if len(req.Messages) == 0 {
-		return echo.NewHTTPError(http.StatusBadRequest, "messages are required")
+	if req.Query == "" {
+		return echo.NewHTTPError(http.StatusBadRequest, "query is required")
 	}
+	req.UserID = userID
 
 	// Set headers for SSE
 	c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
@@ -127,3 +142,14 @@ func (h *ChatHandler) StreamChat(c echo.Context) error {
 		}
 	}
 }
+
+// requireUserID extracts the authenticated user ID from the JWT context and
+// validates it before the request is processed.
+func (h *ChatHandler) requireUserID(c echo.Context) (string, error) {
+	userID, err := auth.UserIDFromContext(c)
+	if err != nil {
+		return "", err
+	}
+	if err := identity.ValidateUserID(userID); err != nil {
+		return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
+	}
+	return userID, nil
+}
diff --git a/internal/memory/llm_provider_client.go b/internal/memory/llm_provider_client.go
deleted file mode 100644
index 426d3694..00000000
--- a/internal/memory/llm_provider_client.go
+++ /dev/null
@@ -1,129 +0,0 @@
-package memory
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"strings"
-
-	"github.com/memohai/memoh/internal/chat"
-)
-
-// ProviderLLMClient uses chat.Provider to make LLM calls for memory operations
-type ProviderLLMClient struct {
-	provider chat.Provider
-	model string
-}
-
-// NewProviderLLMClient creates a new LLM client that uses chat.Provider
-func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient {
-	if model == "" {
-		model = "gpt-4.1-nano-2025-04-14"
-	}
-	return &ProviderLLMClient{
-		provider: provider,
-		model: model,
-	}
-}
-
-// Extract extracts facts from messages using the provider
-func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
-	if len(req.Messages) == 0 {
-		return ExtractResponse{}, fmt.Errorf("messages is required")
-	}
-
-	parsedMessages := parseMessages(formatMessages(req.Messages))
-	systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages)
-
-	// Call provider with JSON mode
-	temp := float32(0)
-	result, err := c.provider.Chat(ctx, chat.Request{
-		Model: c.model,
-		Temperature: &temp,
-		ResponseFormat: &chat.ResponseFormat{
-			Type: "json_object",
-		},
-		Messages: []chat.Message{
-			{Role: "system", Content: systemPrompt},
-			{Role: "user", Content: userPrompt},
-		},
-	})
-	if err != nil {
-		return ExtractResponse{}, err
-	}
-
-	content := result.Message.Content
-	var parsed ExtractResponse
-	if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
-		return ExtractResponse{}, err
-	}
-	return parsed, nil
-}
-
-// Decide decides what actions to take based on facts and existing memories
-func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) {
-	if len(req.Facts) == 0 {
-		return DecideResponse{}, fmt.Errorf("facts is required")
-	}
-
-	retrieved := make([]map[string]string, 0, len(req.Candidates))
-	for _, candidate := range req.Candidates {
-		retrieved = append(retrieved, map[string]string{
-			"id": candidate.ID,
-			"text": candidate.Memory,
-		})
-	}
-
-	prompt := getUpdateMemoryMessages(retrieved, req.Facts)
-
-	// Call provider with JSON mode
-	temp := float32(0)
-	result, err := c.provider.Chat(ctx, chat.Request{
-		Model: c.model,
-		Temperature: &temp,
-		ResponseFormat: &chat.ResponseFormat{
-			Type: "json_object",
-		},
-		Messages: []chat.Message{
-			{Role: "user", Content: prompt},
-		},
-	})
-	if err != nil {
-		return DecideResponse{}, err
-	}
-
-	content := result.Message.Content
-	var raw map[string]interface{}
-	
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &raw); err != nil { - return DecideResponse{}, err - } - - memoryItems := normalizeMemoryItems(raw["memory"]) - actions := make([]DecisionAction, 0, len(memoryItems)) - for _, item := range memoryItems { - event := strings.ToUpper(asString(item["event"])) - if event == "" { - event = "ADD" - } - if event == "NONE" { - continue - } - - text := asString(item["text"]) - if text == "" { - text = asString(item["fact"]) - } - if strings.TrimSpace(text) == "" { - continue - } - - actions = append(actions, DecisionAction{ - Event: event, - ID: normalizeID(item["id"]), - Text: text, - OldMemory: asString(item["old_memory"]), - }) - } - return DecideResponse{Actions: actions}, nil -} - diff --git a/internal/server/server.go b/internal/server/server.go index bfcff7a4..4f4c1a40 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,7 @@ type Server struct { addr string } -func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, containerdHandler *handlers.ContainerdHandler) *Server { +func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, containerdHandler *handlers.ContainerdHandler) *Server { if addr == "" { addr = ":8080" } @@ -50,12 +50,12 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, if embeddingsHandler != nil { embeddingsHandler.Register(e) } - if swaggerHandler != nil { - swaggerHandler.Register(e) - } if chatHandler != nil { chatHandler.Register(e) } + if swaggerHandler != nil { + swaggerHandler.Register(e) + } if providersHandler != nil { providersHandler.Register(e) }
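
Note for reviewers (not applied by this patch): the sketch below shows how a
client would call the reworked chat API. The /chat route, the bearer-token JWT
requirement, and the query/messages field shapes come from the handler and
type changes above; the localhost:8080 address, the TOKEN environment
variable, and all sample values are illustrative assumptions only.

// Minimal client sketch for POST /chat after this change.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

// chatRequest mirrors a subset of the new chat.ChatRequest JSON shape.
type chatRequest struct {
	Query              string `json:"query"`
	Model              string `json:"model,omitempty"`
	MaxContextLoadTime int    `json:"max_context_load_time,omitempty"`
}

// chatResponse mirrors the new chat.ChatResponse JSON shape.
type chatResponse struct {
	Messages []map[string]interface{} `json:"messages"`
	Model    string                   `json:"model,omitempty"`
	Provider string                   `json:"provider,omitempty"`
}

func main() {
	// MaxContextLoadTime is in minutes; values <= 0 fall back to the server default.
	body, err := json.Marshal(chatRequest{Query: "What did we discuss yesterday?", MaxContextLoadTime: 30})
	if err != nil {
		panic(err)
	}
	req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/chat", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+os.Getenv("TOKEN")) // JWT, as enforced by requireUserID

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	// Model and Provider echo whatever selectChatModel resolved server-side,
	// since both fields are optional in the request.
	fmt.Println(out.Model, out.Provider, len(out.Messages))
}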