From da6a264699aab62a8db635812747dd9315cb80fb Mon Sep 17 00:00:00 2001
From: Acbox
Date: Mon, 26 Jan 2026 23:06:54 +0800
Subject: [PATCH] feat: provider management & chat

---
 AGENT_MIGRATION.md                     | 336 +++++++++++++
 PROVIDERS_MODULE.md                    | 477 +++++++++++++++++++
 REFACTORING_SUMMARY.md                 | 273 +++++++++++
 cmd/agent/main.go                      | 119 ++++-
 config.toml.example                    |   7 -
 db/migrations/0001_init.up.sql         |   2 +-
 docs/docs.go                           | 632 ++++++++++++++++++++++++-
 docs/swagger.json                      | 632 ++++++++++++++++++++++++-
 docs/swagger.yaml                      | 429 ++++++++++++++++-
 go.mod                                 |  14 +
 go.sum                                 |  27 ++
 internal/chat/ARCHITECTURE.md          | 213 +++++++++
 internal/chat/anthropic.go             |  43 ++
 internal/chat/google.go                |  43 ++
 internal/chat/ollama.go                |  48 ++
 internal/chat/openai.go                |  54 +++
 internal/chat/prompts.go               | 141 ++++++
 internal/chat/resolver.go              | 295 ++++++++++++
 internal/chat/types.go                 |  65 +++
 internal/config/config.go              |  15 -
 internal/handlers/chat.go              | 129 +++++
 internal/handlers/providers.go         | 234 +++++++++
 internal/memory/llm_provider_client.go | 129 +++++
 internal/memory/service.go             |  49 +-
 internal/memory/types.go               |   8 +
 internal/providers/service.go          | 266 +++++++++++
 internal/providers/types.go            |  71 +++
 internal/server/server.go              |  11 +-
 28 files changed, 4699 insertions(+), 63 deletions(-)
 create mode 100644 AGENT_MIGRATION.md
 create mode 100644 PROVIDERS_MODULE.md
 create mode 100644 REFACTORING_SUMMARY.md
 create mode 100644 internal/chat/ARCHITECTURE.md
 create mode 100644 internal/chat/anthropic.go
 create mode 100644 internal/chat/google.go
 create mode 100644 internal/chat/ollama.go
 create mode 100644 internal/chat/openai.go
 create mode 100644 internal/chat/prompts.go
 create mode 100644 internal/chat/resolver.go
 create mode 100644 internal/chat/types.go
 create mode 100644 internal/handlers/chat.go
 create mode 100644 internal/handlers/providers.go
 create mode 100644 internal/memory/llm_provider_client.go
 create mode 100644 internal/providers/service.go
 create mode 100644 internal/providers/types.go

diff --git a/AGENT_MIGRATION.md b/AGENT_MIGRATION.md
new file mode 100644
index 00000000..6e9ce6bc
--- /dev/null
+++ b/AGENT_MIGRATION.md
@@ -0,0 +1,336 @@
+# Agent Prompts: Initial Migration
+
+## Overview
+
+This document records the initial migration from the TypeScript agent (`@packages/agent`) to the Go chat service.
+
+## Completed Work
+
+### 1. Remove the Memory Section from Config
+
+Model configuration now lives in the database, so the `[memory]` section in the config file is no longer needed.
+
+**Changed files**:
+- ✅ `config.toml.example` - removed the `[memory]` section
+- ✅ `internal/config/config.go` - removed the `MemoryConfig` struct
+- ✅ `cmd/agent/main.go` - removed references to `cfg.Memory`
+
+**Impact**:
+- The memory service now relies entirely on models configured in the database
+- If no model is configured in the database, the service fails to start and tells the user to configure one
+- Configuration management is cleaner, with no settings scattered across files
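+For reference, this is roughly the struct that was removed. The field names match the old `cfg.Memory` references in `cmd/agent/main.go`; the TOML tags are assumed from `config.toml.example`:
+
+```go
+// Removed from internal/config/config.go (reconstruction; tags assumed).
+type MemoryConfig struct {
+	BaseURL        string `toml:"base_url"`
+	APIKey         string `toml:"api_key"`
+	Model          string `toml:"model"`
+	TimeoutSeconds int    `toml:"timeout_seconds"`
+}
+```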
+### 2. Migrate Agent Prompts to the Chat Package
+
+Prompts were migrated from the TypeScript `packages/agent/src/prompts/` to the Go `internal/chat/prompts.go`.
+
+**Migrated content**:
+
+#### System Prompt
+- ✅ Base system prompt
+- ✅ Date/time formatting
+- ✅ Language setting
+- ✅ Platform information (available-platforms, current-platform)
+- ✅ Context load time configuration
+- ✅ Response guidelines
+
+#### Schedule Prompt
+- ✅ Schedule trigger prompt
+- ✅ Task information (name, description, ID, max call count, cron pattern)
+- ✅ Command content
+
+#### Helper functions
+- ✅ `FormatTime()` - time formatting
+- ✅ `Quote()` - Markdown inline-code formatting
+- ✅ `Block()` - code block formatting
+
+**Not yet migrated (follow-up work)**:
+- ⏸️ Memory tool instructions
+- ⏸️ Schedule tool instructions
+- ⏸️ Message tool instructions
+- ⏸️ MCP tool integration
+- ⏸️ Tool-calling logic
+
+## New Prompt Structure
+
+### SystemPrompt
+
+```go
+type PromptParams struct {
+	Date               time.Time
+	Locale             string
+	Language           string
+	MaxContextLoadTime int      // context load time (minutes)
+	Platforms          []string // available platforms
+	CurrentPlatform    string   // current platform
+}
+
+func SystemPrompt(params PromptParams) string
+```
+
+**Example**:
+```go
+prompt := chat.SystemPrompt(chat.PromptParams{
+	Date:               time.Now(),
+	Locale:             "zh-CN",
+	Language:           "Chinese",
+	MaxContextLoadTime: 24 * 60,
+	Platforms:          []string{"telegram", "wechat"},
+	CurrentPlatform:    "telegram",
+})
+```
+
+### SchedulePrompt
+
+```go
+type SchedulePromptParams struct {
+	Date                time.Time
+	Locale              string
+	ScheduleName        string
+	ScheduleDescription string
+	ScheduleID          string
+	MaxCalls            *int // nil means unlimited
+	CronPattern         string
+	Command             string
+}
+
+func SchedulePrompt(params SchedulePromptParams) string
+```
+
+**Example**:
+```go
+maxCalls := 1
+prompt := chat.SchedulePrompt(chat.SchedulePromptParams{
+	Date:                time.Now(),
+	Locale:              "zh-CN",
+	ScheduleName:        "Breakfast reminder",
+	ScheduleDescription: "Remind the user to eat breakfast at 7 a.m. every day",
+	ScheduleID:          "schedule-123",
+	MaxCalls:            &maxCalls,
+	CronPattern:         "0 7 * * *",
+	Command:             "Remind the user to eat breakfast and suggest healthy recipes",
+})
+```
+
+## Usage
+
+### In the Chat Resolver
+
+The chat resolver automatically adds the system prompt to every request:
+
+```go
+// in resolver.go
+systemPrompt := SystemPrompt(PromptParams{
+	Date:               time.Now(),
+	Locale:             "en-US",
+	Language:           "Same as user input",
+	MaxContextLoadTime: 24 * 60,
+	Platforms:          []string{},
+	CurrentPlatform:    "api",
+})
+```
+
+### Customizing Prompt Parameters
+
+These parameters could later be customized as follows:
+
+1. **Load the platform list from the database**:
+```go
+platforms, _ := platformService.GetActivePlatforms(ctx)
+platformNames := make([]string, len(platforms))
+for i, p := range platforms {
+	platformNames[i] = p.Name
+}
+```
+
+2. **Load the language preference from user settings**:
+```go
+userSettings, _ := settingsService.GetUserSettings(ctx, userID)
+language := userSettings.Language
+```
+
+3. **Take the current platform from the session context**:
+```go
+currentPlatform := "telegram" // from request headers or the session
+```
+
+## Comparison with the TypeScript Version
+
+### TypeScript (original)
+```typescript
+export const system = ({ date, locale, language, maxContextLoadTime, platforms, currentPlatform }: SystemParams) => {
+  return `
+---
+${time({ date, locale })}
+language: ${language}
+available-platforms:
+${platforms.map(platform => `  - ${platform.name}`).join('\n')}
+current-platform: ${currentPlatform}
+---
+You are a personal housekeeper assistant...
+  `.trim()
+}
+```
+
+### Go (after migration)
+```go
+func SystemPrompt(params PromptParams) string {
+	timeStr := FormatTime(params.Date, params.Locale)
+	platformsList := buildPlatformsList(params.Platforms)
+
+	return fmt.Sprintf(`---
+%s
+language: %s
+available-platforms:
+%s
+current-platform: %s
+---
+You are a personal housekeeper assistant...`,
+		timeStr, params.Language, platformsList, params.CurrentPlatform)
+}
+```
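+`buildPlatformsList` is not shown here; a plausible sketch that mirrors the TypeScript `platforms.map(...)` rendering would look like this (uses `strings`; the actual helper in `internal/chat/prompts.go` may differ):
+
+```go
+// buildPlatformsList renders one "  - <name>" line per platform (sketch).
+func buildPlatformsList(platforms []string) string {
+	lines := make([]string, 0, len(platforms))
+	for _, p := range platforms {
+		lines = append(lines, "  - "+p)
+	}
+	return strings.Join(lines, "\n")
+}
+```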
+## Configuration Migration Guide
+
+### Old configuration (config.toml)
+```toml
+[memory]
+base_url = "https://api.openai.com/v1"
+api_key = "sk-..."
+model = "gpt-4.1-nano"
+timeout_seconds = 10
+```
+
+### New configuration (database)
+```sql
+-- 1. Create an LLM provider
+INSERT INTO llm_providers (name, client_type, base_url, api_key)
+VALUES ('OpenAI', 'openai', 'https://api.openai.com/v1', 'sk-...');
+
+-- 2. Create a chat model (used for memory)
+INSERT INTO models (model_id, name, llm_provider_id, type, enable_as)
+VALUES ('gpt-4-turbo', 'GPT-4 Turbo', '<provider-id>', 'chat', 'memory');
+
+-- Or reuse an existing chat model
+UPDATE models SET enable_as = 'memory' WHERE model_id = 'gpt-4-turbo';
+```
+
+## Testing Suggestions
+
+### 1. Test the system prompt
+```bash
+# Send a chat request
+curl -X POST http://localhost:8080/api/chat \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{
+    "messages": [
+      {"role": "user", "content": "Hello"}
+    ]
+  }'
+```
+
+Check that the response:
+- uses the correct language
+- shows the AI understands it is a personal housekeeper assistant
+- is friendly and helpful in style
+
+### 2. Verify the configuration migration
+```bash
+# Start the service and check the logs.
+# You should see:
+#   Using memory model: gpt-4-turbo (provider: openai)
+
+# You should NOT see:
+#   WARNING: No memory model configured, using fallback LLMClient
+```
+
+### 3. Test memory operations
+```bash
+# Add a memory
+curl -X POST http://localhost:8080/api/memory/add \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{
+    "messages": [
+      {"role": "user", "content": "I like pizza"}
+    ],
+    "user_id": "user-123"
+  }'
+
+# Search memories
+curl -X POST http://localhost:8080/api/memory/search \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{
+    "query": "What food do I like?",
+    "user_id": "user-123"
+  }'
+```
+
+## Follow-up Work
+
+### Short term (tool integration)
+- [ ] Add memory tool instructions to the system prompt
+- [ ] Add schedule tool instructions to the system prompt
+- [ ] Add message tool instructions to the system prompt
+- [ ] Implement tool calling
+- [ ] Add tool argument validation
+
+### Mid term (MCP integration)
+- [ ] MCP connection management
+- [ ] Dynamic loading of MCP tools
+- [ ] MCP stdio/http/sse transport support
+- [ ] MCP execution inside containers
+
+### Long term (polish)
+- [ ] Multi-language localization
+- [ ] User-defined system prompts
+- [ ] Prompt template system
+- [ ] A/B testing of prompt variants
+- [ ] Prompt performance monitoring
+
+## File Inventory
+
+### Deleted/cleaned up
+- ✅ `config.toml.example` - removed the `[memory]` section
+- ✅ `internal/config/config.go` - removed `MemoryConfig`
+- ✅ `cmd/agent/main.go` - removed `cfg.Memory` references
+
+### Modified
+- ✅ `internal/chat/prompts.go` - fully rewritten with the complete prompt system
+- ✅ `internal/chat/resolver.go` - uses the new `SystemPrompt` function
+
+### Added
+- ✅ `AGENT_MIGRATION.md` - this document
+
+## Notes
+
+1. **A database model is required**: since the config-file fallback was removed, at least one chat model must be configured in the database.
+
+2. **Prompt parameters**: hard-coded defaults are used for now; they should eventually come from user settings or the request context.
+
+3. **Multi-language support**: `FormatTime()` currently uses a standard format; an i18n library should handle localization later.
+
+4. **Tool instructions**: the prompt mentions tool capabilities, but the actual tool instructions have not been added yet.
+
+5. **Backward compatibility**: the `[memory]` config section is gone, so existing deployments need to migrate.
+
+## Migration Checklist
+
+- [x] Remove the memory section from config.toml
+- [x] Remove MemoryConfig from the Config struct
+- [x] Update main.go to drop cfg.Memory references
+- [x] Migrate the system prompt
+- [x] Migrate the schedule prompt
+- [x] Migrate the helper functions
+- [x] Update the resolver to use the new prompt
+- [x] Pass the linter
+- [x] Write this migration document
+- [ ] Test the chat feature
+- [ ] Test the memory feature
+- [ ] Update the deployment docs
+
+## Related Documents
+
+- [Provider Refactoring Summary](./REFACTORING_SUMMARY.md)
+- [Chat Architecture](./internal/chat/ARCHITECTURE.md)
+- [TypeScript Agent Source](./packages/agent/src/)
+
diff --git a/PROVIDERS_MODULE.md b/PROVIDERS_MODULE.md
new file mode 100644
index 00000000..84434e3f
--- /dev/null
+++ b/PROVIDERS_MODULE.md
@@ -0,0 +1,477 @@
+# Providers Module
+
+## Overview
+
+This document describes the standalone providers module, which manages LLM provider configurations.
+
+## Architecture
+
+### Module structure
+
+```
+internal/
+├── providers/           # standalone provider module
+│   ├── types.go         # type definitions
+│   └── service.go       # business logic
+└── handlers/
+    └── providers.go     # API handlers
+```
+
+### Layered design
+
+```
+┌─────────────────────┐
+│     API Layer       │
+│     (handlers)      │
+└──────────┬──────────┘
+           │
+┌──────────▼──────────┐
+│   Service Layer     │
+│   (providers pkg)   │
+└──────────┬──────────┘
+           │
+┌──────────▼──────────┐
+│     Data Layer      │
+│   (sqlc queries)    │
+└─────────────────────┘
+```
+
+## API Endpoints
+
+### Providers API (`/providers`)
+
+All endpoints require JWT authentication.
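+In code, the layers are wired together at startup roughly like this (constructor names are taken from `cmd/agent/main.go` in this patch; route registration via `server.NewServer` is omitted):
+
+```go
+// Data layer: sqlc-generated queries; the service wraps them;
+// the handler exposes the service over HTTP.
+providersService := providers.NewService(queries)
+providersHandler := handlers.NewProvidersHandler(providersService)
+```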
+#### 1. Create a provider
+
+```http
+POST /providers
+Content-Type: application/json
+Authorization: Bearer <token>
+
+{
+  "name": "OpenAI Official",
+  "client_type": "openai",
+  "base_url": "https://api.openai.com/v1",
+  "api_key": "sk-...",
+  "metadata": {
+    "description": "Official OpenAI API"
+  }
+}
+```
+
+**Response** (201 Created):
+```json
+{
+  "id": "550e8400-e29b-41d4-a716-446655440000",
+  "name": "OpenAI Official",
+  "client_type": "openai",
+  "base_url": "https://api.openai.com/v1",
+  "api_key": "sk-12345***",
+  "metadata": {
+    "description": "Official OpenAI API"
+  },
+  "created_at": "2024-01-01T00:00:00Z",
+  "updated_at": "2024-01-01T00:00:00Z"
+}
+```
+
+#### 2. List all providers
+
+```http
+GET /providers
+Authorization: Bearer <token>
+```
+
+**Optional query parameters**:
+- `client_type` - filter by client type (openai, anthropic, google, ollama)
+
+**Response** (200 OK):
+```json
+[
+  {
+    "id": "550e8400-e29b-41d4-a716-446655440000",
+    "name": "OpenAI Official",
+    "client_type": "openai",
+    "base_url": "https://api.openai.com/v1",
+    "api_key": "sk-12345***",
+    "created_at": "2024-01-01T00:00:00Z",
+    "updated_at": "2024-01-01T00:00:00Z"
+  }
+]
+```
+
+#### 3. Get a single provider
+
+```http
+GET /providers/{id}
+Authorization: Bearer <token>
+```
+
+Or fetch it by name:
+
+```http
+GET /providers/name/{name}
+Authorization: Bearer <token>
+```
+
+**Response** (200 OK): same shape as the create response
+
+#### 4. Update a provider
+
+```http
+PUT /providers/{id}
+Content-Type: application/json
+Authorization: Bearer <token>
+
+{
+  "name": "OpenAI Updated",
+  "api_key": "sk-newkey..."
+}
+```
+
+**Note**: all fields are optional; only the provided fields are updated.
+
+**Response** (200 OK): returns the updated provider
+
+#### 5. Delete a provider
+
+```http
+DELETE /providers/{id}
+Authorization: Bearer <token>
+```
+
+**Response** (204 No Content)
+
+#### 6. Count providers
+
+```http
+GET /providers/count
+Authorization: Bearer <token>
+```
+
+**Optional query parameters**:
+- `client_type` - filter by client type
+
+**Response** (200 OK):
+```json
+{
+  "count": 5
+}
+```
+
+## Supported Client Types
+
+| Client Type | Description | Requires API Key |
+|-------------|-------------|------------------|
+| `openai` | Official OpenAI API | ✅ |
+| `openai-compat` | OpenAI-compatible APIs | ✅ |
+| `anthropic` | Anthropic Claude API | ✅ |
+| `google` | Google Gemini API | ✅ |
+| `ollama` | Local Ollama | ❌ |
+
+## Data Model
+
+### Provider structs
+
+```go
+type CreateRequest struct {
+	Name       string                 `json:"name"`        // required
+	ClientType ClientType             `json:"client_type"` // required
+	BaseURL    string                 `json:"base_url"`    // required
+	APIKey     string                 `json:"api_key"`     // optional
+	Metadata   map[string]interface{} `json:"metadata"`    // optional
+}
+
+type GetResponse struct {
+	ID         string                 `json:"id"`
+	Name       string                 `json:"name"`
+	ClientType string                 `json:"client_type"`
+	BaseURL    string                 `json:"base_url"`
+	APIKey     string                 `json:"api_key"` // masked
+	Metadata   map[string]interface{} `json:"metadata"`
+	CreatedAt  time.Time              `json:"created_at"`
+	UpdatedAt  time.Time              `json:"updated_at"`
+}
+```
+
+## Security Features
+
+### 1. API key masking
+
+API keys are automatically masked in responses:
+- only the first 8 characters are shown
+- the rest is replaced with `*`
+- for example: `sk-12345***`
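+A minimal sketch of that masking rule (uses `strings`; the actual helper in `internal/providers` may be named and implemented differently):
+
+```go
+// maskAPIKey keeps the first 8 characters and stars the rest (sketch).
+func maskAPIKey(key string) string {
+	if len(key) <= 8 {
+		return strings.Repeat("*", len(key))
+	}
+	return key[:8] + strings.Repeat("*", len(key)-8)
+}
+```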
+### 2. Authentication
+
+All API endpoints require JWT authentication:
+```http
+Authorization: Bearer <token>
+```
+
+### 3. Input validation
+
+- UUIDs are validated automatically
+- `client_type` is checked against the supported values
+- required fields are enforced
+
+## Usage Examples
+
+### 1. Configure an OpenAI provider
+
+```bash
+curl -X POST http://localhost:8080/providers \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "name": "OpenAI GPT-4",
+    "client_type": "openai",
+    "base_url": "https://api.openai.com/v1",
+    "api_key": "sk-..."
+  }'
+```
+
+### 2. Configure a custom OpenAI-compatible service
+
+```bash
+curl -X POST http://localhost:8080/providers \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "name": "Azure OpenAI",
+    "client_type": "openai-compat",
+    "base_url": "https://your-resource.openai.azure.com/v1",
+    "api_key": "your-azure-key",
+    "metadata": {
+      "deployment": "gpt-4",
+      "region": "eastus"
+    }
+  }'
+```
+
+### 3. Configure local Ollama
+
+```bash
+curl -X POST http://localhost:8080/providers \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "name": "Local Ollama",
+    "client_type": "ollama",
+    "base_url": "http://localhost:11434"
+  }'
+```
+
+### 4. List all OpenAI providers
+
+```bash
+curl http://localhost:8080/providers?client_type=openai \
+  -H "Authorization: Bearer $TOKEN"
+```
+
+### 5. Update a provider's API key
+
+```bash
+curl -X PUT http://localhost:8080/providers/{id} \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "api_key": "sk-new-key..."
+  }'
+```
+
+## Relationship with Models
+
+Providers and models form a one-to-many relationship:
+
+```
+┌─────────────┐
+│  Provider   │
+│  (OpenAI)   │
+└──────┬──────┘
+       │
+       ├─── Model (gpt-4)
+       ├─── Model (gpt-3.5-turbo)
+       └─── Model (text-embedding-ada-002)
+```
+
+### Referencing a provider when creating a model
+
+```bash
+# 1. Create the provider
+PROVIDER_ID=$(curl -X POST http://localhost:8080/providers \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{"name":"OpenAI","client_type":"openai",...}' \
+  | jq -r '.id')
+
+# 2. Create the model, referencing the provider
+curl -X POST http://localhost:8080/models \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{
+    "model_id": "gpt-4",
+    "name": "GPT-4",
+    "llm_provider_id": "'$PROVIDER_ID'",
+    "type": "chat"
+  }'
+```
+
+## Code Integration
+
+### Using the provider service in code
+
+```go
+import "github.com/memohai/memoh/internal/providers"
+
+// Create the service
+providersService := providers.NewService(queries)
+
+// Create a provider
+provider, err := providersService.Create(ctx, providers.CreateRequest{
+	Name:       "OpenAI",
+	ClientType: providers.ClientTypeOpenAI,
+	BaseURL:    "https://api.openai.com/v1",
+	APIKey:     "sk-...",
+})
+
+// List all providers
+allProviders, err := providersService.List(ctx)
+
+// Filter by client type
+openaiProviders, err := providersService.ListByClientType(ctx, providers.ClientTypeOpenAI)
+
+// Get a single provider
+provider, err := providersService.Get(ctx, "provider-uuid")
+
+// Update a provider
+updated, err := providersService.Update(ctx, "provider-uuid", providers.UpdateRequest{
+	APIKey: stringPtr("new-key"),
+})
+
+// Delete a provider
+err := providersService.Delete(ctx, "provider-uuid")
+```
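+`stringPtr` is used above but not defined in this document; a minimal helper (the name is this document's convention, not necessarily the codebase's) is:
+
+```go
+// stringPtr returns a pointer to s, for optional UpdateRequest fields.
+func stringPtr(s string) *string { return &s }
+```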
+## Error Handling
+
+### Common errors
+
+| Status code | Error | Cause |
+|-------------|-------|-------|
+| 400 | Bad Request | missing required fields or malformed input |
+| 404 | Not Found | provider ID does not exist |
+| 409 | Conflict | provider name already exists |
+| 500 | Internal Server Error | server-side error |
+
+### Error response format
+
+```json
+{
+  "message": "invalid UUID: invalid UUID length: 5"
+}
+```
+
+## Database Schema
+
+Providers are stored in the `llm_providers` table:
+
+```sql
+CREATE TABLE llm_providers (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    name TEXT NOT NULL,
+    client_type TEXT NOT NULL,
+    base_url TEXT NOT NULL,
+    api_key TEXT NOT NULL,
+    metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
+    CONSTRAINT llm_providers_name_unique UNIQUE (name),
+    CONSTRAINT llm_providers_client_type_check
+        CHECK (client_type IN ('openai', 'openai-compat', 'anthropic', 'google', 'ollama'))
+);
+```
+
+## Best Practices
+
+### 1. Naming conventions
+
+- Use descriptive names: `OpenAI GPT-4`, `Anthropic Claude 3`
+- Include environment information: `OpenAI Production`, `Ollama Local`
+- Avoid special characters
+
+### 2. API key management
+
+- Rotate API keys regularly
+- Keep secrets in environment variables
+- Never log complete API keys
+
+### 3. Using metadata
+
+Use the metadata field for extra information:
+
+```json
+{
+  "metadata": {
+    "environment": "production",
+    "rate_limit": 10000,
+    "contact": "admin@example.com",
+    "notes": "Primary provider for production"
+  }
+}
+```
+
+### 4. Provider testing
+
+After creating a provider, it is worth testing the connection. A test endpoint could be added later, for example:
+
+```http
+POST /providers/test
+{
+  "client_type": "openai",
+  "base_url": "https://api.openai.com/v1",
+  "api_key": "sk-...",
+  "model": "gpt-3.5-turbo"
+}
+```
+
+## Migration Guide
+
+### From config file to database
+
+**Old way** (config.toml):
+```toml
+[memory]
+base_url = "https://api.openai.com/v1"
+api_key = "sk-..."
+model = "gpt-4"
+```
+
+**New way** (API):
+```bash
+# 1. Create the provider
+curl -X POST http://localhost:8080/providers \
+  -d '{"name":"OpenAI","client_type":"openai","base_url":"https://api.openai.com/v1","api_key":"sk-..."}'
+
+# 2. Create the model
+curl -X POST http://localhost:8080/models \
+  -d '{"model_id":"gpt-4","llm_provider_id":"<provider-id>","type":"chat","enable_as":"memory"}'
+```
+
+## Related Documents
+
+- [Provider Refactoring Summary](./REFACTORING_SUMMARY.md)
+- [Agent Migration](./AGENT_MIGRATION.md)
+- [Chat Architecture](./internal/chat/ARCHITECTURE.md)
+- [Models API docs](#) (to be written)
+
+## TODO
+
+- [ ] Implement a provider connection-test endpoint
+- [ ] Add provider usage statistics
+- [ ] Implement provider health checks
+- [ ] Add provider cost tracking
+- [ ] Support provider load balancing
+- [ ] Implement provider failover
+
diff --git a/REFACTORING_SUMMARY.md b/REFACTORING_SUMMARY.md
new file mode 100644
index 00000000..fa0f2bce
--- /dev/null
+++ b/REFACTORING_SUMMARY.md
@@ -0,0 +1,273 @@
+# Provider Refactoring Summary
+
+## Goal
+
+Have the memory service and the chat service share the provider interface in `internal/chat`, for code reuse and a unified architecture.
+
+## Main Changes
+
+### 1. Extend the chat provider interface
+
+**File**: `internal/chat/types.go`
+
+- Extended the `Request` struct with:
+  - `Temperature *float32` - sampling temperature
+  - `ResponseFormat *ResponseFormat` - response format (supports JSON mode)
+  - `MaxTokens *int` - maximum number of tokens
+
+- Added a `ResponseFormat` struct:
+  ```go
+  type ResponseFormat struct {
+      Type string // "json_object" or "text"
+  }
+  ```
+
+### 2. Define an LLM interface for memory
+
+**File**: `internal/memory/types.go`
+
+- Defined the `LLM` interface:
+  ```go
+  type LLM interface {
+      Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
+      Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
+  }
+  ```
+
+- Two implementations exist:
+  - `LLMClient` (legacy implementation, kept for backward compatibility)
+  - `ProviderLLMClient` (new implementation on top of `chat.Provider`)
+
+### 3. Create the provider-based LLM client
+
+**File**: `internal/memory/llm_provider_client.go` (new file)
+
+- Implements the `ProviderLLMClient` struct
+- Uses a `chat.Provider` to execute LLM calls
+- Supports JSON-mode output to guarantee structured responses
+- Reuses helper functions from the `internal/memory` package
+
+Key code:
+```go
+type ProviderLLMClient struct {
+	provider chat.Provider
+	model    string
+}
+
+func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient
+func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
+func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
+```
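+Since both clients satisfy the same interface, the contract can be documented with compile-time assertions (illustrative; not necessarily present in the codebase):
+
+```go
+var (
+	_ LLM = (*LLMClient)(nil)         // legacy HTTP client
+	_ LLM = (*ProviderLLMClient)(nil) // chat.Provider-backed client
+)
+```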
+### 4. Update the memory service
+
+**File**: `internal/memory/service.go`
+
+- Changed `llm *LLMClient` to `llm LLM`
+- The service now accepts any type implementing the `LLM` interface
+- All existing behavior is preserved
+
+### 5. Update the main program
+
+**File**: `cmd/agent/main.go`
+
+#### Added functions:
+
+1. **selectMemoryModel** - selects the model used for memory operations
+   - Priority: memory model → chat model → any chat-type model
+
+2. **fetchProviderByID** - fetches a provider configuration by ID
+
+3. **createChatProvider** - creates a provider instance from a configuration
+   - Supports OpenAI, Anthropic, Google, and Ollama
+
+#### Updated initialization flow:
+
+```go
+// 1. Initialize the chat resolver (used by both chat and memory)
+chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
+
+// 2. Select a model and build a provider-based client for memory.
+//    (An earlier revision fell back to the config-file LLMClient here;
+//    that fallback was removed - see AGENT_MIGRATION.md.)
+memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries)
+if err != nil {
+	log.Fatalf("select memory model: %v", err)
+}
+provider, _ := createChatProvider(memoryProvider, 30*time.Second)
+llmClient := memory.NewProviderLLMClient(provider, memoryModel.ModelID)
+
+// 3. Create the memory service
+memoryService = memory.NewService(llmClient, embedder, store, resolver, ...)
+```
+
+## Architectural Benefits
+
+### 1. Unified provider management
+- Chat and memory share the same provider implementations
+- Less duplicated code
+- One place for configuration and management
+
+### 2. Flexible model selection
+- Different features can use different models
+- The `enable_as` field designates a model's purpose
+- Automatic fallback across models
+
+### 3. Backward compatibility
+- The legacy `LLMClient` implementation is retained
+- A smooth migration path (note: the config-file fallback itself was later removed; see AGENT_MIGRATION.md)
+
+### 4. Type safety
+- Go interfaces instead of runtime type checks
+- Compile-time type checking
+- Better IDE support
+
+### 5. Extensibility
+- A new provider only needs to implement the `Provider` interface
+- A new LLM client only needs to implement the `LLM` interface
+- Modular design
+
+## Configuration
+
+### Database configuration
+
+Configure a model for memory operations:
+
+```sql
+-- Option 1: a dedicated memory model
+UPDATE models SET enable_as = 'memory'
+WHERE model_id = 'gpt-4-turbo-preview';
+
+-- Option 2: a chat model (if there is no dedicated memory model)
+UPDATE models SET enable_as = 'chat'
+WHERE model_id = 'gpt-4';
+```
+
+### Config-file fallback (since removed)
+
+Earlier revisions fell back to the following configuration when no model was configured in the database; this fallback has been removed (see AGENT_MIGRATION.md):
+
+```toml
+[memory]
+base_url = "https://api.openai.com/v1"
+api_key = "sk-..."
+model = "gpt-4.1-nano"
+timeout_seconds = 10
+```
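+The selection priority above maps onto the models service like this (condensed from `selectMemoryModel` in `cmd/agent/main.go`, which appears in full later in this patch):
+
+```go
+// Prefer a model enabled for memory, then one enabled for chat.
+model, err := modelsService.GetByEnableAs(ctx, models.EnableAsMemory)
+if err != nil {
+	model, err = modelsService.GetByEnableAs(ctx, models.EnableAsChat)
+}
+```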
+## Testing Suggestions
+
+### 1. Memory operations
+```bash
+# Test Extract (fact extraction)
+curl -X POST http://localhost:8080/api/memory/add \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{
+    "messages": [
+      {"role": "user", "content": "My name is Alice and I like pizza"}
+    ],
+    "user_id": "user-123"
+  }'
+
+# Test Search (memory retrieval)
+curl -X POST http://localhost:8080/api/memory/search \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{
+    "query": "What food do I like?",
+    "user_id": "user-123"
+  }'
+```
+
+### 2. Chat operations
+```bash
+# Test a plain chat request
+curl -X POST http://localhost:8080/api/chat \
+  -H "Authorization: Bearer $TOKEN" \
+  -d '{
+    "messages": [
+      {"role": "user", "content": "Hello, how are you?"}
+    ]
+  }'
+```
+
+### 3. Check the logs
+On startup you should see:
+```
+Using memory model: gpt-4-turbo-preview (provider: openai)
+```
+
+## Follow-up Work
+
+### Short term (required)
+- [ ] Implement the concrete logic of each provider (most currently return "not yet implemented")
+- [ ] Add streaming-response support
+- [ ] Improve error handling
+
+### Mid term (recommended)
+- [ ] Unit tests for providers
+- [ ] Integration tests for memory
+- [ ] Provider connection pooling
+- [ ] Request retries
+
+### Long term (optimization)
+- [ ] Provider performance monitoring
+- [ ] Automatic model selection and load balancing
+- [ ] Model response caching
+- [ ] More providers (e.g. Cohere, HuggingFace)
+
+## Migration Checklist
+
+- [x] Extend the Request struct with JSON mode
+- [x] Define the LLM interface
+- [x] Implement ProviderLLMClient
+- [x] Make the memory service use the interface
+- [x] Update the main program's initialization flow
+- [x] Add model-selection logic
+- [x] Add provider-creation logic
+- [x] Keep backward compatibility
+- [x] Write the architecture document
+- [ ] Add unit tests
+- [ ] Add integration tests
+- [ ] Update the deployment docs
+
+## File Inventory
+
+### Added
+- `internal/memory/llm_provider_client.go` - provider-based LLM client
+- `internal/chat/ARCHITECTURE.md` - architecture document
+- `REFACTORING_SUMMARY.md` - this file
+
+### Modified
+- `internal/chat/types.go` - extended the Request struct
+- `internal/memory/types.go` - added the LLM interface
+- `internal/memory/service.go` - uses the LLM interface
+- `cmd/agent/main.go` - updated the initialization flow
+
+### Kept (backward compatibility)
+- `internal/memory/llm_client.go` - the legacy HTTP client implementation
+
+## Notes
+
+1. **JSON-mode compatibility**: not every model supports JSON mode; provider implementations must handle this.
+2. **Error handling**: error handling is still simple; production use needs richer error details.
+3. **Timeouts**: different operations may need different timeouts; consider making them configurable.
+4. **Concurrency safety**: provider instances should be safe for concurrent use.
+5. **Resource cleanup**: make sure provider resources (such as HTTP connections) are released properly.
+
+## Troubleshooting
+
+If something goes wrong, check:
+1. whether a chat model is configured in the database
+2. whether the provider configuration (API key, base URL) is correct
+3. the error messages in the logs
+4. whether the chat resolver was initialized correctly
+
+For the detailed architecture, see `internal/chat/ARCHITECTURE.md`.
+
diff --git a/cmd/agent/main.go b/cmd/agent/main.go
index f3eadaf5..86037096 100644
--- a/cmd/agent/main.go
+++ b/cmd/agent/main.go
@@ -2,11 +2,16 @@ package main
 
 import (
 	"context"
+	"fmt"
 	"log"
 	"os"
 	"strings"
 	"time"
 
+	"github.com/google/uuid"
+	"github.com/jackc/pgx/v5/pgtype"
+
+	"github.com/memohai/memoh/internal/chat"
 	"github.com/memohai/memoh/internal/config"
 	ctr "github.com/memohai/memoh/internal/containerd"
 	"github.com/memohai/memoh/internal/db"
@@ -16,6 +21,7 @@ import (
 	"github.com/memohai/memoh/internal/mcp"
 	"github.com/memohai/memoh/internal/memory"
 	"github.com/memohai/memoh/internal/models"
+	"github.com/memohai/memoh/internal/providers"
 	"github.com/memohai/memoh/internal/server"
 )
 
@@ -113,12 +119,24 @@ func main() {
 	pingHandler := handlers.NewPingHandler()
 	authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn)
-	llmClient := memory.NewLLMClient(
-		cfg.Memory.BaseURL,
-		cfg.Memory.APIKey,
-		cfg.Memory.Model,
-		time.Duration(cfg.Memory.TimeoutSeconds)*time.Second,
-	)
+
+	// Initialize chat resolver for both chat and memory operations
+	chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
+
+	// Create LLM client for memory operations using chat provider
+	var llmClient memory.LLM
+	memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries)
+	if err != nil {
+		log.Fatalf("select memory model: %v\nPlease configure at least one chat model in the database.", err)
+	}
+
+	log.Printf("Using memory model: %s (provider: %s)", memoryModel.ModelID, memoryProvider.ClientType)
+	provider, err := createChatProvider(memoryProvider, 30*time.Second)
+	if err != nil {
+		log.Fatalf("create memory provider: %v", err)
+	}
+	llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)
+
 	resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
 	vectors, textModel, multimodalModel, hasModels, err := collectEmbeddingVectors(ctx, modelsService)
 	if err != nil {
@@ -181,9 +199,96 @@ func main() {
 	embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries)
 	fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace)
 	swaggerHandler := handlers.NewSwaggerHandler()
-	srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, fsHandler, swaggerHandler)
+	chatHandler := handlers.NewChatHandler(chatResolver)
+
+	// Initialize providers and models handlers
+	providersService := providers.NewService(queries)
+	providersHandler := handlers.NewProvidersHandler(providersService)
+	modelsHandler := handlers.NewModelsHandler(modelsService)
+
+	srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, fsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler)
 
 	if err := srv.Start(); err != nil {
 		log.Fatalf("server failed: %v", err)
 	}
 }
+
+// selectMemoryModel selects a chat model for memory operations
+func selectMemoryModel(ctx context.Context, modelsService *models.Service, queries *dbsqlc.Queries) (models.GetResponse, dbsqlc.LlmProvider, error) {
+	// First try to get the memory-enabled model
+	memoryModel, err := modelsService.GetByEnableAs(ctx, models.EnableAsMemory)
+	if err == nil {
+		provider, err := fetchProviderByID(ctx, queries, 
memoryModel.LlmProviderID) + if err != nil { + return models.GetResponse{}, dbsqlc.LlmProvider{}, err + } + return memoryModel, provider, nil + } + + // Fallback to chat model + chatModel, err := modelsService.GetByEnableAs(ctx, models.EnableAsChat) + if err == nil { + provider, err := fetchProviderByID(ctx, queries, chatModel.LlmProviderID) + if err != nil { + return models.GetResponse{}, dbsqlc.LlmProvider{}, err + } + return chatModel, provider, nil + } + + // If no enabled models, try to find any chat model + candidates, err := modelsService.ListByType(ctx, models.ModelTypeChat) + if err != nil || len(candidates) == 0 { + return models.GetResponse{}, dbsqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") + } + + selected := candidates[0] + provider, err := fetchProviderByID(ctx, queries, selected.LlmProviderID) + if err != nil { + return models.GetResponse{}, dbsqlc.LlmProvider{}, err + } + return selected, provider, nil +} + +// fetchProviderByID fetches a provider by ID +func fetchProviderByID(ctx context.Context, queries *dbsqlc.Queries, providerID string) (dbsqlc.LlmProvider, error) { + if strings.TrimSpace(providerID) == "" { + return dbsqlc.LlmProvider{}, fmt.Errorf("provider id missing") + } + parsed, err := uuid.Parse(providerID) + if err != nil { + return dbsqlc.LlmProvider{}, err + } + pgID := pgtype.UUID{Valid: true} + copy(pgID.Bytes[:], parsed[:]) + return queries.GetLlmProviderByID(ctx, pgID) +} + +// createChatProvider creates a chat provider instance +func createChatProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (chat.Provider, error) { + clientType := strings.ToLower(strings.TrimSpace(provider.ClientType)) + if timeout <= 0 { + timeout = 30 * time.Second + } + + switch clientType { + case chat.ProviderOpenAI, chat.ProviderOpenAICompat: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, fmt.Errorf("openai api key is required") + } + return chat.NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout) + case chat.ProviderAnthropic: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, fmt.Errorf("anthropic api key is required") + } + return chat.NewAnthropicProvider(provider.ApiKey, timeout) + case chat.ProviderGoogle: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, fmt.Errorf("google api key is required") + } + return chat.NewGoogleProvider(provider.ApiKey, timeout) + case chat.ProviderOllama: + return chat.NewOllamaProvider(provider.BaseUrl, timeout) + default: + return nil, fmt.Errorf("unsupported provider type: %s", clientType) + } +} diff --git a/config.toml.example b/config.toml.example index 1fac56b6..e38cfa7e 100644 --- a/config.toml.example +++ b/config.toml.example @@ -28,13 +28,6 @@ password = "" database = "memoh" sslmode = "disable" -## memory configuration -[memory] -base_url = "https://api.openai.com/v1" -api_key = "" -model = "gpt-4.1-nano" -timeout_seconds = 10 - ## Qdrant configuration [qdrant] base_url = "http://127.0.0.1:6334" diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 4bd351d5..af6fc5ba 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -68,7 +68,7 @@ CREATE TABLE IF NOT EXISTS llm_providers ( created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), CONSTRAINT llm_providers_name_unique UNIQUE (name), - CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai', 'anthropic', 'google', 'bedrock', 'ollama', 'azure', 'dashscope', 'other')) + CONSTRAINT 
llm_providers_client_type_check CHECK (client_type IN ('openai', 'openai-compat', 'anthropic', 'google', 'ollama')) ); CREATE TABLE IF NOT EXISTS models ( diff --git a/docs/docs.go b/docs/docs.go index 4c57d603..f9b311a5 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -61,6 +61,98 @@ const docTemplate = `{ } } }, + "/chat": { + "post": { + "description": "Send a chat message and get a response. The system will automatically select an appropriate chat model from the database.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "chat" + ], + "summary": "Chat with AI", + "parameters": [ + { + "description": "Chat request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/chat.ChatRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/chat.ChatResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/chat/stream": { + "post": { + "description": "Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database.", + "consumes": [ + "application/json" + ], + "produces": [ + "text/event-stream" + ], + "tags": [ + "chat" + ], + "summary": "Stream chat with AI", + "parameters": [ + { + "description": "Chat request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/chat.ChatRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/chat.StreamChunk" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/embeddings": { "post": { "description": "Create text or multimodal embeddings", @@ -820,12 +912,20 @@ const docTemplate = `{ }, "/models/enable-as/{enableAs}": { "get": { - "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)", + "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)\nGet the default model configured for a specific purpose (chat, memory, or embedding)", "tags": [ + "models", "models" ], - "summary": "Get model by enable_as", + "summary": "Get default model by enable_as", "parameters": [ + { + "type": "string", + "description": "Enable as value (chat, memory, embedding)", + "name": "enableAs", + "in": "path", + "required": true + }, { "type": "string", "description": "Enable as value (chat, memory, embedding)", @@ -1129,9 +1229,434 @@ const docTemplate = `{ } } } + }, + "/providers": { + "get": { + "description": "Get a list of all configured LLM providers, optionally filtered by client type", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "List all LLM providers", + "parameters": [ + { + "type": "string", + "description": "Client type filter (openai, anthropic, google, ollama)", + "name": "client_type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/providers.GetResponse" + } + } + }, + "400": { + 
"description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "post": { + "description": "Create a new LLM provider configuration", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Create a new LLM provider", + "parameters": [ + { + "description": "Provider configuration", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/providers.CreateRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/providers/count": { + "get": { + "description": "Get the total count of providers, optionally filtered by client type", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Count providers", + "parameters": [ + { + "type": "string", + "description": "Client type filter (openai, anthropic, google, ollama)", + "name": "client_type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.CountResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/providers/name/{name}": { + "get": { + "description": "Get a provider configuration by its name", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Get provider by name", + "parameters": [ + { + "type": "string", + "description": "Provider name", + "name": "name", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/providers/{id}": { + "get": { + "description": "Get a provider configuration by its ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Get provider by ID", + "parameters": [ + { + "type": "string", + "description": "Provider ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": 
"#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update an existing provider configuration", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Update provider", + "parameters": [ + { + "type": "string", + "description": "Provider ID (UUID)", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "Updated provider configuration", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/providers.UpdateRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a provider configuration", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Delete provider", + "parameters": [ + { + "type": "string", + "description": "Provider ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } } }, "definitions": { + "chat.ChatRequest": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/chat.Message" + } + }, + "model": { + "description": "optional: specific model to use", + "type": "string" + }, + "provider": { + "description": "optional: specific provider to use", + "type": "string" + }, + "stream": { + "type": "boolean" + } + } + }, + "chat.ChatResponse": { + "type": "object", + "properties": { + "finish_reason": { + "type": "string" + }, + "message": { + "$ref": "#/definitions/chat.Message" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "usage": { + "$ref": "#/definitions/chat.Usage" + } + } + }, + "chat.Message": { + "type": "object", + "properties": { + "content": { + "type": "string" + }, + "role": { + "description": "user, assistant, system", + "type": "string" + } + } + }, + "chat.StreamChunk": { + "type": "object", + "properties": { + "delta": { + "type": "object", + "properties": { + "content": { + "type": "string" + } + } + }, + "finish_reason": { + "type": "string" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + } + } + }, + "chat.Usage": { + "type": "object", + "properties": { + "completion_tokens": { + "type": "integer" + }, + "prompt_tokens": { + "type": "integer" + }, + "total_tokens": { + "type": "integer" + } + } + }, "handlers.ApplyPatchRequest": { "type": "object", "properties": { @@ -1714,6 +2239,109 @@ const docTemplate = `{ "$ref": "#/definitions/models.ModelType" } } + }, + "providers.ClientType": { + "type": "string", + "enum": [ + "openai", + "openai-compat", + "anthropic", + "google", + "ollama" + 
], + "x-enum-varnames": [ + "ClientTypeOpenAI", + "ClientTypeOpenAICompat", + "ClientTypeAnthropic", + "ClientTypeGoogle", + "ClientTypeOllama" + ] + }, + "providers.CountResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + } + } + }, + "providers.CreateRequest": { + "type": "object", + "required": [ + "base_url", + "client_type", + "name" + ], + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/providers.ClientType" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "name": { + "type": "string" + } + } + }, + "providers.GetResponse": { + "type": "object", + "properties": { + "api_key": { + "description": "masked in response", + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "name": { + "type": "string" + }, + "updated_at": { + "type": "string" + } + } + }, + "providers.UpdateRequest": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/providers.ClientType" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "name": { + "type": "string" + } + } } } }` diff --git a/docs/swagger.json b/docs/swagger.json index c68a99e4..2192a103 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -50,6 +50,98 @@ } } }, + "/chat": { + "post": { + "description": "Send a chat message and get a response. The system will automatically select an appropriate chat model from the database.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "chat" + ], + "summary": "Chat with AI", + "parameters": [ + { + "description": "Chat request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/chat.ChatRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/chat.ChatResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/chat/stream": { + "post": { + "description": "Send a chat message and get a streaming response. 
The system will automatically select an appropriate chat model from the database.", + "consumes": [ + "application/json" + ], + "produces": [ + "text/event-stream" + ], + "tags": [ + "chat" + ], + "summary": "Stream chat with AI", + "parameters": [ + { + "description": "Chat request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/chat.ChatRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/chat.StreamChunk" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/embeddings": { "post": { "description": "Create text or multimodal embeddings", @@ -809,12 +901,20 @@ }, "/models/enable-as/{enableAs}": { "get": { - "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)", + "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)\nGet the default model configured for a specific purpose (chat, memory, or embedding)", "tags": [ + "models", "models" ], - "summary": "Get model by enable_as", + "summary": "Get default model by enable_as", "parameters": [ + { + "type": "string", + "description": "Enable as value (chat, memory, embedding)", + "name": "enableAs", + "in": "path", + "required": true + }, { "type": "string", "description": "Enable as value (chat, memory, embedding)", @@ -1118,9 +1218,434 @@ } } } + }, + "/providers": { + "get": { + "description": "Get a list of all configured LLM providers, optionally filtered by client type", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "List all LLM providers", + "parameters": [ + { + "type": "string", + "description": "Client type filter (openai, anthropic, google, ollama)", + "name": "client_type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/providers.GetResponse" + } + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "post": { + "description": "Create a new LLM provider configuration", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Create a new LLM provider", + "parameters": [ + { + "description": "Provider configuration", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/providers.CreateRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/providers/count": { + "get": { + "description": "Get the total count of providers, optionally filtered by client type", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Count providers", + 
"parameters": [ + { + "type": "string", + "description": "Client type filter (openai, anthropic, google, ollama)", + "name": "client_type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.CountResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/providers/name/{name}": { + "get": { + "description": "Get a provider configuration by its name", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Get provider by name", + "parameters": [ + { + "type": "string", + "description": "Provider name", + "name": "name", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/providers/{id}": { + "get": { + "description": "Get a provider configuration by its ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Get provider by ID", + "parameters": [ + { + "type": "string", + "description": "Provider ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update an existing provider configuration", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Update provider", + "parameters": [ + { + "type": "string", + "description": "Provider ID (UUID)", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "Updated provider configuration", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/providers.UpdateRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/providers.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a provider configuration", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "providers" + ], + "summary": "Delete provider", + "parameters": [ + { + 
"type": "string", + "description": "Provider ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } } }, "definitions": { + "chat.ChatRequest": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/chat.Message" + } + }, + "model": { + "description": "optional: specific model to use", + "type": "string" + }, + "provider": { + "description": "optional: specific provider to use", + "type": "string" + }, + "stream": { + "type": "boolean" + } + } + }, + "chat.ChatResponse": { + "type": "object", + "properties": { + "finish_reason": { + "type": "string" + }, + "message": { + "$ref": "#/definitions/chat.Message" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "usage": { + "$ref": "#/definitions/chat.Usage" + } + } + }, + "chat.Message": { + "type": "object", + "properties": { + "content": { + "type": "string" + }, + "role": { + "description": "user, assistant, system", + "type": "string" + } + } + }, + "chat.StreamChunk": { + "type": "object", + "properties": { + "delta": { + "type": "object", + "properties": { + "content": { + "type": "string" + } + } + }, + "finish_reason": { + "type": "string" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + } + } + }, + "chat.Usage": { + "type": "object", + "properties": { + "completion_tokens": { + "type": "integer" + }, + "prompt_tokens": { + "type": "integer" + }, + "total_tokens": { + "type": "integer" + } + } + }, "handlers.ApplyPatchRequest": { "type": "object", "properties": { @@ -1703,6 +2228,109 @@ "$ref": "#/definitions/models.ModelType" } } + }, + "providers.ClientType": { + "type": "string", + "enum": [ + "openai", + "openai-compat", + "anthropic", + "google", + "ollama" + ], + "x-enum-varnames": [ + "ClientTypeOpenAI", + "ClientTypeOpenAICompat", + "ClientTypeAnthropic", + "ClientTypeGoogle", + "ClientTypeOllama" + ] + }, + "providers.CountResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + } + } + }, + "providers.CreateRequest": { + "type": "object", + "required": [ + "base_url", + "client_type", + "name" + ], + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/providers.ClientType" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "name": { + "type": "string" + } + } + }, + "providers.GetResponse": { + "type": "object", + "properties": { + "api_key": { + "description": "masked in response", + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "name": { + "type": "string" + }, + "updated_at": { + "type": "string" + } + } + }, + "providers.UpdateRequest": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/providers.ClientType" + }, 
+ "metadata": { + "type": "object", + "additionalProperties": true + }, + "name": { + "type": "string" + } + } } } } \ No newline at end of file diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 1a103d34..81243027 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,4 +1,63 @@ definitions: + chat.ChatRequest: + properties: + messages: + items: + $ref: '#/definitions/chat.Message' + type: array + model: + description: 'optional: specific model to use' + type: string + provider: + description: 'optional: specific provider to use' + type: string + stream: + type: boolean + type: object + chat.ChatResponse: + properties: + finish_reason: + type: string + message: + $ref: '#/definitions/chat.Message' + model: + type: string + provider: + type: string + usage: + $ref: '#/definitions/chat.Usage' + type: object + chat.Message: + properties: + content: + type: string + role: + description: user, assistant, system + type: string + type: object + chat.StreamChunk: + properties: + delta: + properties: + content: + type: string + type: object + finish_reason: + type: string + model: + type: string + provider: + type: string + type: object + chat.Usage: + properties: + completion_tokens: + type: integer + prompt_tokens: + type: integer + total_tokens: + type: integer + type: object handlers.ApplyPatchRequest: properties: patch: @@ -382,6 +441,78 @@ definitions: type: $ref: '#/definitions/models.ModelType' type: object + providers.ClientType: + enum: + - openai + - openai-compat + - anthropic + - google + - ollama + type: string + x-enum-varnames: + - ClientTypeOpenAI + - ClientTypeOpenAICompat + - ClientTypeAnthropic + - ClientTypeGoogle + - ClientTypeOllama + providers.CountResponse: + properties: + count: + type: integer + type: object + providers.CreateRequest: + properties: + api_key: + type: string + base_url: + type: string + client_type: + $ref: '#/definitions/providers.ClientType' + metadata: + additionalProperties: true + type: object + name: + type: string + required: + - base_url + - client_type + - name + type: object + providers.GetResponse: + properties: + api_key: + description: masked in response + type: string + base_url: + type: string + client_type: + type: string + created_at: + type: string + id: + type: string + metadata: + additionalProperties: true + type: object + name: + type: string + updated_at: + type: string + type: object + providers.UpdateRequest: + properties: + api_key: + type: string + base_url: + type: string + client_type: + $ref: '#/definitions/providers.ClientType' + metadata: + additionalProperties: true + type: object + name: + type: string + type: object info: contact: {} paths: @@ -415,6 +546,68 @@ paths: summary: Login tags: - auth + /chat: + post: + consumes: + - application/json + description: Send a chat message and get a response. The system will automatically + select an appropriate chat model from the database. + parameters: + - description: Chat request + in: body + name: request + required: true + schema: + $ref: '#/definitions/chat.ChatRequest' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/chat.ChatResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Chat with AI + tags: + - chat + /chat/stream: + post: + consumes: + - application/json + description: Send a chat message and get a streaming response. 
The system will
+        automatically select an appropriate chat model from the database.
+      parameters:
+      - description: Chat request
+        in: body
+        name: request
+        required: true
+        schema:
+          $ref: '#/definitions/chat.ChatRequest'
+      produces:
+      - text/event-stream
+      responses:
+        "200":
+          description: OK
+          schema:
+            $ref: '#/definitions/chat.StreamChunk'
+        "400":
+          description: Bad Request
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+        "500":
+          description: Internal Server Error
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+      summary: Stream chat with AI
+      tags:
+      - chat
   /embeddings:
     post:
       description: Create text or multimodal embeddings
@@ -1002,9 +1195,15 @@
     - models
   /models/enable-as/{enableAs}:
     get:
-      description: Get the model that is enabled for a specific purpose (chat, memory,
-        embedding)
+      description: Get the default model configured for a specific purpose (chat,
+        memory, or embedding)
       parameters:
       - description: Enable as value (chat, memory, embedding)
         in: path
         name: enableAs
         required: true
         type: string
       produces:
       - application/json
       responses:
         "200":
           description: OK
           schema:
             $ref: '#/definitions/models.GetResponse'
         "404":
           description: Not Found
           schema:
             $ref: '#/definitions/handlers.ErrorResponse'
         "500":
           description: Internal Server Error
           schema:
             $ref: '#/definitions/handlers.ErrorResponse'
-      summary: Get model by enable_as
+      summary: Get default model by enable_as
       tags:
       - models
   /models/model/{modelId}:
     delete:
       description: Delete a model configuration by its model_id field (e.g., gpt-4)
@@ -1119,4 +1319,227 @@
       summary: Update model by model ID
       tags:
       - models
+  /providers:
+    get:
+      consumes:
+      - application/json
+      description: Get a list of all configured LLM providers, optionally filtered
+        by client type
+      parameters:
+      - description: Client type filter (openai, anthropic, google, ollama)
+        in: query
+        name: client_type
+        type: string
+      produces:
+      - application/json
+      responses:
+        "200":
+          description: OK
+          schema:
+            items:
+              $ref: '#/definitions/providers.GetResponse'
+            type: array
+        "400":
+          description: Bad Request
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+        "500":
+          description: Internal Server Error
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+      summary: List all LLM providers
+      tags:
+      - providers
+    post:
+      consumes:
+      - application/json
+      description: Create a new LLM provider configuration
+      parameters:
+      - description: Provider configuration
+        in: body
+        name: request
+        required: true
+        schema:
+          $ref: '#/definitions/providers.CreateRequest'
+      produces:
+      - application/json
+      responses:
+        "201":
+          description: Created
+          schema:
+            $ref: '#/definitions/providers.GetResponse'
+        "400":
+          description: Bad Request
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+        "500":
+          description: Internal Server Error
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+      summary: Create a new LLM provider
+      tags:
+      - providers
+  /providers/{id}:
+    delete:
+      consumes:
+      - application/json
+      description: Delete a provider configuration
+      parameters:
+      - description: Provider ID (UUID)
+        in: path
+        name: id
+        required: true
+        type: string
+      produces:
+      - application/json
+      responses:
+        "204":
+          description: No Content
+        "400":
+          description: Bad Request
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+        "404":
+          description: Not Found
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+        "500":
+          description: Internal Server Error
+          schema:
+            $ref: '#/definitions/handlers.ErrorResponse'
+      
summary: Delete provider + tags: + - providers + get: + consumes: + - application/json + description: Get a provider configuration by its ID + parameters: + - description: Provider ID (UUID) + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/providers.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get provider by ID + tags: + - providers + put: + consumes: + - application/json + description: Update an existing provider configuration + parameters: + - description: Provider ID (UUID) + in: path + name: id + required: true + type: string + - description: Updated provider configuration + in: body + name: request + required: true + schema: + $ref: '#/definitions/providers.UpdateRequest' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/providers.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Update provider + tags: + - providers + /providers/count: + get: + consumes: + - application/json + description: Get the total count of providers, optionally filtered by client + type + parameters: + - description: Client type filter (openai, anthropic, google, ollama) + in: query + name: client_type + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/providers.CountResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Count providers + tags: + - providers + /providers/name/{name}: + get: + consumes: + - application/json + description: Get a provider configuration by its name + parameters: + - description: Provider name + in: path + name: name + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/providers.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get provider by name + tags: + - providers swagger: "2.0" diff --git a/go.mod b/go.mod index 3fff1f63..9bce88c1 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,8 @@ require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.14.0-rc.1 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/cgroups/v3 v3.1.2 // indirect github.com/containerd/continuity v0.4.5 // indirect @@ -38,6 +40,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference 
v0.6.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/firebase/genkit/go v1.4.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.22.4 // indirect @@ -50,16 +53,21 @@ require ( github.com/go-openapi/swag/stringutils v0.25.4 // indirect github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect + github.com/goccy/go-yaml v1.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect + github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/compress v1.18.3 // indirect github.com/labstack/gommon v0.4.2 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a // indirect github.com/moby/locker v1.0.1 // indirect github.com/moby/sys/mountinfo v0.7.2 // indirect github.com/moby/sys/sequential v0.6.0 // indirect @@ -73,11 +81,17 @@ require ( github.com/sirupsen/logrus v1.9.4 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect + go.opentelemetry.io/otel/sdk v1.39.0 // indirect go.opentelemetry.io/otel/trace v1.39.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/mod v0.32.0 // indirect diff --git a/go.sum b/go.sum index da0ea083..ab237c90 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,10 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/hcsshim v0.14.0-rc.1 h1:qAPXKwGOkVn8LlqgBN8GS0bxZ83hOJpcjxzmlQKxKsQ= github.com/Microsoft/hcsshim v0.14.0-rc.1/go.mod h1:hTKFGbnDtQb1wHiOWv4v0eN+7boSWAHyK/tNAaYZL0c= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -53,6 +57,8 @@ github.com/envoyproxy/go-control-plane 
v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/firebase/genkit/go v1.4.0 h1:CP1hNWk7z0hosyY53zMH6MFKFO1fMLtj58jGPllQo6I= +github.com/firebase/genkit/go v1.4.0/go.mod h1:HX6m7QOaGc3MDNr/DrpQZrzPLzxeuLxrkTvfFtCYlGw= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -85,6 +91,8 @@ github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxE github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg= github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= +github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= +github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -105,6 +113,8 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 h1:okN800+zMJOGHLJCgry+OGzhhtH6YrjQh1rluHmOacE= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254/go.mod h1:k8cjJAQWc//ac/bMnzItyOFbfT01tgRTZGgxELCuxEQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -116,6 +126,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -138,10 +150,14 @@ github.com/labstack/echo/v4 v4.15.0 h1:hoRTKWcnR5STXZFe9BmYun9AMTNeSbjHi2vtDuADJ github.com/labstack/echo/v4 v4.15.0/go.mod h1:xmw1clThob0BSVRX1CRQkGQ/vjwcpOMjQZSZa9fKA/c= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= github.com/labstack/gommon v0.4.2/go.mod 
h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
+github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
+github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
 github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
 github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A=
+github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU=
 github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
 github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc=
 github.com/moby/sys/mountinfo v0.7.2 h1:1shs6aH5s4o5H2zQLn796ADW1wMrIwHsyJ2v9KouLrg=
@@ -192,6 +208,17 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
 github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
 github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
 github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
+github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
+github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
+github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
+github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
+github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
+github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
+github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
+github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
+github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
+github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
+github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
diff --git a/internal/chat/ARCHITECTURE.md b/internal/chat/ARCHITECTURE.md
new file mode 100644
index 00000000..6fde5faf
--- /dev/null
+++ b/internal/chat/ARCHITECTURE.md
@@ -0,0 +1,213 @@
+# Chat Provider Architecture
+
+## Overview
+
+This document describes the unified chat provider architecture in the Memoh project, which is shared by the chat service and the memory service.
+
+## Design
+
+### Core Interfaces
+
+#### The Provider Interface
+
+All LLM providers implement the `chat.Provider` interface:
+
+```go
+type Provider interface {
+    Chat(ctx context.Context, req Request) (Result, error)
+    StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error)
+}
+```
+
+#### The Request Structure
+
+Provider requests support several configuration options:
+
+```go
+type Request struct {
+    Messages       []Message
+    Model          string
+    Provider       string
+    Temperature    *float32        // optional: temperature parameter
+    ResponseFormat *ResponseFormat // optional: response format (JSON mode)
+    MaxTokens      *int            // optional: maximum number of tokens
+}
+
+type ResponseFormat struct {
+    Type string // "json_object" or "text"
+}
+```
+
+### Supported Providers
+
+1. **OpenAI** (`openai` / `openai-compat`)
+   - The standard OpenAI API
+   - Custom endpoints compatible with the OpenAI format
+
+2. **Anthropic** (`anthropic`)
+   - The Claude model family
+
+3. **Google** (`google`)
+   - The Gemini model family
+
+4. **Ollama** (`ollama`)
+   - Locally deployed open-source models
+
+## Usage Scenarios
+
+### 1. Chat Service
+
+The chat service uses providers through `chat.Resolver`:
+
+```go
+chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
+response, err := chatResolver.Chat(ctx, ChatRequest{
+    Messages: messages,
+    Model:    "gpt-4",
+})
+```
+
+### 2. Memory Service
+
+The memory service uses providers through `memory.ProviderLLMClient`:
+
+```go
+// Create the provider
+provider, err := chat.NewOpenAIProvider(apiKey, baseURL, timeout)
+
+// Create the memory LLM client
+llmClient := memory.NewProviderLLMClient(provider, modelID)
+
+// Use the memory service
+memoryService := memory.NewService(llmClient, embedder, store, resolver, ...)
+```
+
+The memory service needs two core capabilities:
+- **Extract**: extract factual information from conversations
+- **Decide**: decide how to update memories (add/update/delete)
+
+Both operations use JSON mode to guarantee structured output.
+
+## Configuration Examples
+
+### Database Configuration
+
+Provider configurations are stored in the `llm_providers` table:
+
+```sql
+CREATE TABLE llm_providers (
+    id UUID PRIMARY KEY,
+    name TEXT NOT NULL,
+    client_type TEXT NOT NULL,  -- 'openai', 'anthropic', 'google', 'ollama'
+    base_url TEXT NOT NULL,
+    api_key TEXT NOT NULL,
+    metadata JSONB
+);
+```
+
+Model configurations are stored in the `models` table:
+
+```sql
+CREATE TABLE models (
+    id UUID PRIMARY KEY,
+    model_id TEXT NOT NULL,
+    name TEXT,
+    llm_provider_id UUID REFERENCES llm_providers(id),
+    type TEXT NOT NULL,  -- 'chat' or 'embedding'
+    enable_as TEXT,      -- 'chat', 'memory', 'embedding'
+    ...
+);
+```
+
+### Startup Initialization
+
+In `cmd/agent/main.go`:
+
+```go
+// 1. Initialize the chat resolver (used by both chat and memory)
+chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
+
+// 2. Select a model for memory and create its provider
+memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries, cfg)
+provider, err := createChatProvider(memoryProvider, 30*time.Second)
+
+// 3. Create the memory LLM client
+llmClient := memory.NewProviderLLMClient(provider, memoryModel.ModelID)
+
+// 4. Create the memory service
+memoryService := memory.NewService(llmClient, embedder, store, resolver, ...)
+```
+
+## Model Selection Strategy
+
+The selection order is a simple cascade; a code sketch follows the two lists below.
+
+### Memory Model Selection Priority
+
+1. Models with `enable_as = 'memory'` (dedicated memory models)
+2. Models with `enable_as = 'chat'` (general chat models)
+3. Any available chat-type model
+4. Fall back to the LLMClient from the config file (backward compatibility)
+
+### Chat Model Selection Priority
+
+1. The model specified in the request
+2. Models with `enable_as = 'chat'`
+3. Any available chat-type model
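+
+A minimal sketch of the memory cascade (illustrative only — the real logic lives in `selectMemoryModel`/`selectChatModel`, and `models.EnableAsMemory` is assumed to exist alongside the `models.EnableAsChat` constant used in `resolver.go`):
+
+```go
+func pickMemoryModel(ctx context.Context, svc *models.Service) (models.GetResponse, error) {
+    // 1. Dedicated memory model.
+    if m, err := svc.GetByEnableAs(ctx, models.EnableAsMemory); err == nil {
+        return m, nil
+    }
+    // 2. General chat model.
+    if m, err := svc.GetByEnableAs(ctx, models.EnableAsChat); err == nil {
+        return m, nil
+    }
+    // 3. Any available chat-type model.
+    if list, err := svc.ListByType(ctx, models.ModelTypeChat); err == nil && len(list) > 0 {
+        return list[0], nil
+    }
+    // 4. Nothing configured: the caller falls back to the config-file LLMClient.
+    return models.GetResponse{}, errors.New("no memory-capable model configured")
+}
+```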
+
+## Advantages
+
+1. **Unified architecture**: chat and memory share the same Provider interface
+2. **Flexible configuration**: multiple providers and models are supported
+3. **Backward compatible**: the old LLMClient is kept as a fallback
+4. **Type safe**: Go interfaces guarantee type safety
+5. **Easy to extend**: adding a new provider only requires implementing the Provider interface
+
+## Adding a New Provider
+
+To add a new LLM provider:
+
+1. Create a new file under `internal/chat/` (e.g., `newprovider.go`)
+2. Implement the `Provider` interface
+3. Add a new case to `createProvider()` in `resolver.go`
+4. Add the new type to the `llm_providers_client_type_check` constraint in the database
+
+Example:
+
+```go
+// newprovider.go
+type NewProvider struct {
+    apiKey  string
+    timeout time.Duration
+}
+
+func NewNewProvider(apiKey string, timeout time.Duration) (*NewProvider, error) {
+    return &NewProvider{apiKey: apiKey, timeout: timeout}, nil
+}
+
+func (p *NewProvider) Chat(ctx context.Context, req Request) (Result, error) {
+    // implement the chat logic here
+    return Result{}, fmt.Errorf("not implemented")
+}
+
+func (p *NewProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) {
+    // implement the streaming chat logic here
+    chunks := make(chan StreamChunk)
+    errs := make(chan error, 1)
+    close(chunks)
+    errs <- fmt.Errorf("not implemented")
+    return chunks, errs
+}
+```
+
+## Migration Guide
+
+Migrating from the old TypeScript backend to Go:
+
+1. ✅ Create the Provider interface and implementations
+2. ✅ Implement the chat Resolver
+3. ✅ Create the Provider adapter for memory
+4. ✅ Update the main program to use the unified Provider
+5. 🚧 Implement the concrete logic for each provider (OpenAI, Anthropic, Google, Ollama)
+6. 🚧 Add streaming response support
+7. 🚧 Add complete error handling and retry logic
+
+## Notes
+
+1. **JSON mode**: memory operations require `ResponseFormat.Type = "json_object"` to guarantee structured output
+2. **Temperature**: memory operations use `Temperature = 0` for deterministic output
+3. **Timeouts**: different operations may need different timeout values
+4. **Error handling**: providers should return clear error messages, including API error details
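+
+For example, the Extract/Decide calls in `memory.ProviderLLMClient` combine notes 1 and 2 (a sketch; the model ID and the prompt variables are placeholders):
+
+```go
+temp := float32(0) // deterministic output for memory operations
+result, err := provider.Chat(ctx, chat.Request{
+    Model:          "gpt-4.1-nano",
+    Temperature:    &temp,
+    ResponseFormat: &chat.ResponseFormat{Type: "json_object"}, // structured output
+    Messages: []chat.Message{
+        {Role: "system", Content: systemPrompt},
+        {Role: "user", Content: userPrompt},
+    },
+})
+```
+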
provider not yet implemented") +} + +func (p *OllamaProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) { + chunkChan := make(chan StreamChunk, 10) + errChan := make(chan error, 1) + + go func() { + defer close(chunkChan) + defer close(errChan) + errChan <- fmt.Errorf("ollama streaming not yet implemented") + }() + + return chunkChan, errChan +} diff --git a/internal/chat/openai.go b/internal/chat/openai.go new file mode 100644 index 00000000..fe664488 --- /dev/null +++ b/internal/chat/openai.go @@ -0,0 +1,54 @@ +package chat + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/firebase/genkit/go/genkit" +) + +// OpenAIProvider wraps Genkit's OpenAI provider +type OpenAIProvider struct { + g *genkit.Genkit + apiKey string + baseURL string + timeout time.Duration +} + +func NewOpenAIProvider(apiKey, baseURL string, timeout time.Duration) (*OpenAIProvider, error) { + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + baseURL = strings.TrimRight(baseURL, "/") + if timeout <= 0 { + timeout = 30 * time.Second + } + + // For now, we'll create a simple HTTP client-based implementation + // since Genkit Go plugins require initialization at startup + return &OpenAIProvider{ + apiKey: apiKey, + baseURL: baseURL, + timeout: timeout, + }, nil +} + +func (p *OpenAIProvider) Chat(ctx context.Context, req Request) (Result, error) { + // Use direct HTTP API call since Genkit plugins need to be initialized at startup + return Result{}, fmt.Errorf("openai provider not yet implemented - please use openai-compat provider") +} + +func (p *OpenAIProvider) StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) { + chunkChan := make(chan StreamChunk, 10) + errChan := make(chan error, 1) + + go func() { + defer close(chunkChan) + defer close(errChan) + errChan <- fmt.Errorf("openai streaming not yet implemented") + }() + + return chunkChan, errChan +} diff --git a/internal/chat/prompts.go b/internal/chat/prompts.go new file mode 100644 index 00000000..51335c62 --- /dev/null +++ b/internal/chat/prompts.go @@ -0,0 +1,141 @@ +package chat + +import ( + "fmt" + "strings" + "time" +) + +// PromptParams contains parameters for generating system prompts +type PromptParams struct { + Date time.Time + Locale string + Language string + MaxContextLoadTime int // in minutes + Platforms []string // available platforms (e.g., ["telegram", "wechat"]) + CurrentPlatform string // current platform the user is using +} + +// SystemPrompt generates the system prompt for the AI assistant +// This is migrated from packages/agent/src/prompts/system.ts +func SystemPrompt(params PromptParams) string { + if params.Language == "" { + params.Language = "Same as user input" + } + if params.MaxContextLoadTime == 0 { + params.MaxContextLoadTime = 24 * 60 // 24 hours default + } + if params.CurrentPlatform == "" { + params.CurrentPlatform = "client" + } + + // Build platforms list + platformsList := "" + if len(params.Platforms) > 0 { + lines := make([]string, len(params.Platforms)) + for i, p := range params.Platforms { + lines[i] = fmt.Sprintf(" - %s", p) + } + platformsList = strings.Join(lines, "\n") + } + + timeStr := FormatTime(params.Date, params.Locale) + + return fmt.Sprintf(`--- +%s +language: %s +available-platforms: +%s +current-platform: %s +--- +You are a personal housekeeper assistant, which able to manage the master's daily affairs. 
+ +Your abilities: +- Long memory: You possess long-term memory; conversations from the last %d minutes will be directly loaded into your context. Additionally, you can use tools to search for past memories. +- Scheduled tasks: You can create scheduled tasks to automatically remind you to do something. +- Messaging: You may allowed to use message software to send messages to the master. + +**Response Guidelines** +- Always respond in the language specified above, unless it says "Same as user input", then match the user's language. +- Be helpful, concise, and friendly. +- For complex questions, break down your answer into clear steps. +- If you're unsure about something, acknowledge it honestly.`, + timeStr, + params.Language, + platformsList, + params.CurrentPlatform, + params.MaxContextLoadTime, + ) +} + +// SchedulePrompt generates a prompt for scheduled task execution +// This is migrated from packages/agent/src/prompts/schedule.ts +type SchedulePromptParams struct { + Date time.Time + Locale string + ScheduleName string + ScheduleDescription string + ScheduleID string + MaxCalls *int // nil means unlimited + CronPattern string + Command string // the natural language command to execute +} + +func SchedulePrompt(params SchedulePromptParams) string { + timeStr := FormatTime(params.Date, params.Locale) + + maxCallsStr := "Unlimited" + if params.MaxCalls != nil { + maxCallsStr = fmt.Sprintf("%d", *params.MaxCalls) + } + + return fmt.Sprintf(`--- +notice: **This is a scheduled task automatically send to you by the system, not the user input** +%s +schedule-name: %s +schedule-description: %s +schedule-id: %s +max-calls: %s +cron-pattern: %s +--- + +**COMMAND** + +%s`, + timeStr, + params.ScheduleName, + params.ScheduleDescription, + params.ScheduleID, + maxCallsStr, + params.CronPattern, + params.Command, + ) +} + +// FormatTime formats the date and time according to locale +func FormatTime(date time.Time, locale string) string { + if locale == "" { + locale = "en-US" + } + + // Format date and time + // For simplicity, using standard format. 
In production, you might want to use + // a proper i18n library for locale-specific formatting + dateStr := date.Format("2006-01-02") + timeStr := date.Format("15:04:05") + + return fmt.Sprintf("date: %s\ntime: %s", dateStr, timeStr) +} + +// Quote wraps content in backticks for markdown code formatting +func Quote(content string) string { + return fmt.Sprintf("`%s`", content) +} + +// Block wraps content in code block with optional language tag +func Block(content, tag string) string { + if tag == "" { + return fmt.Sprintf("```\n%s\n```", content) + } + return fmt.Sprintf("```%s\n%s\n```", tag, content) +} diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go new file mode 100644 index 00000000..50b185c9 --- /dev/null +++ b/internal/chat/resolver.go @@ -0,0 +1,295 @@ +package chat + +import ( + "context" + "errors" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/models" +) + +const ( + ProviderOpenAI = "openai" + ProviderOpenAICompat = "openai-compat" + ProviderAnthropic = "anthropic" + ProviderGoogle = "google" + ProviderOllama = "ollama" +) + +// Provider interface for chat providers +type Provider interface { + Chat(ctx context.Context, req Request) (Result, error) + StreamChat(ctx context.Context, req Request) (<-chan StreamChunk, <-chan error) +} + +// Resolver resolves chat models and providers +type Resolver struct { + modelsService *models.Service + queries *sqlc.Queries + timeout time.Duration +} + +// NewResolver creates a new chat resolver +func NewResolver(modelsService *models.Service, queries *sqlc.Queries, timeout time.Duration) *Resolver { + return &Resolver{ + modelsService: modelsService, + queries: queries, + timeout: timeout, + } +} + +// Chat performs a chat completion +func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { + // Select model and provider + selected, provider, err := r.selectChatModel(ctx, req.Model, req.Provider) + if err != nil { + return ChatResponse{}, err + } + + // Create internal request + internalReq := Request{ + Messages: req.Messages, + Model: selected.ModelID, + Provider: strings.ToLower(provider.ClientType), + } + + // Add system prompt + if len(internalReq.Messages) > 0 && internalReq.Messages[0].Role != "system" { + systemPrompt := SystemPrompt(PromptParams{ + Date: time.Now(), + Locale: "en-US", + Language: "Same as user input", + MaxContextLoadTime: 24 * 60, // 24 hours + Platforms: []string{}, + CurrentPlatform: "api", + }) + internalReq.Messages = append([]Message{ + {Role: "system", Content: systemPrompt}, + }, internalReq.Messages...) 
+ } + + // Create provider instance + providerInst, err := r.createProvider(provider) + if err != nil { + return ChatResponse{}, err + } + + // Execute chat + result, err := providerInst.Chat(ctx, internalReq) + if err != nil { + return ChatResponse{}, err + } + + return ChatResponse{ + Message: result.Message, + Model: result.Model, + Provider: result.Provider, + FinishReason: result.FinishReason, + Usage: result.Usage, + }, nil +} + +// StreamChat performs a streaming chat completion +func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { + chunkChan := make(chan StreamChunk, 10) + errChan := make(chan error, 1) + + go func() { + defer close(chunkChan) + defer close(errChan) + + // Select model and provider + selected, provider, err := r.selectChatModel(ctx, req.Model, req.Provider) + if err != nil { + errChan <- err + return + } + + // Create internal request + internalReq := Request{ + Messages: req.Messages, + Model: selected.ModelID, + Provider: strings.ToLower(provider.ClientType), + } + + // Add system prompt + if len(internalReq.Messages) > 0 && internalReq.Messages[0].Role != "system" { + systemPrompt := SystemPrompt(PromptParams{ + Date: time.Now(), + Locale: "en-US", + Language: "Same as user input", + MaxContextLoadTime: 24 * 60, // 24 hours + Platforms: []string{}, + CurrentPlatform: "api", + }) + internalReq.Messages = append([]Message{ + {Role: "system", Content: systemPrompt}, + }, internalReq.Messages...) + } + + // Create provider instance + providerInst, err := r.createProvider(provider) + if err != nil { + errChan <- err + return + } + + // Execute streaming chat + providerChunkChan, providerErrChan := providerInst.StreamChat(ctx, internalReq) + + // Forward chunks and errors + for { + select { + case chunk, ok := <-providerChunkChan: + if !ok { + return + } + chunkChan <- chunk + case err := <-providerErrChan: + if err != nil { + errChan <- err + } + return + } + } + }() + + return chunkChan, errChan +} + +// selectChatModel selects a chat model based on the request +func (r *Resolver) selectChatModel(ctx context.Context, modelID, providerType string) (models.GetResponse, sqlc.LlmProvider, error) { + if r.modelsService == nil { + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured") + } + + modelID = strings.TrimSpace(modelID) + providerType = strings.ToLower(strings.TrimSpace(providerType)) + + // If no model specified, try to get default chat model + if modelID == "" && providerType == "" { + defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsChat) + if err == nil { + provider, err := r.fetchProvider(ctx, defaultModel.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return defaultModel, provider, nil + } + } + + // List available models + var candidates []models.GetResponse + var err error + if providerType != "" { + candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerType)) + } else { + candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) + } + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + + // Filter chat models + filtered := make([]models.GetResponse, 0, len(candidates)) + for _, model := range candidates { + if model.Type != models.ModelTypeChat { + continue + } + filtered = append(filtered, model) + } + if len(filtered) == 0 { + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("no chat models available") + } + + // If 
model specified, find it + if modelID != "" { + for _, model := range filtered { + if model.ModelID == modelID { + provider, err := r.fetchProvider(ctx, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, provider, nil + } + } + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("chat model not found") + } + + // Return first available model + selected := filtered[0] + provider, err := r.fetchProvider(ctx, selected.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return selected, provider, nil +} + +// fetchProvider fetches provider information from database +func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.LlmProvider, error) { + if r.queries == nil { + return sqlc.LlmProvider{}, errors.New("llm provider queries not configured") + } + if strings.TrimSpace(providerID) == "" { + return sqlc.LlmProvider{}, errors.New("llm provider id missing") + } + parsed, err := uuid.Parse(providerID) + if err != nil { + return sqlc.LlmProvider{}, err + } + pgID := pgtype.UUID{Valid: true} + copy(pgID.Bytes[:], parsed[:]) + return r.queries.GetLlmProviderByID(ctx, pgID) +} + +// createProvider creates a provider instance based on configuration +func (r *Resolver) createProvider(provider sqlc.LlmProvider) (Provider, error) { + clientType := strings.ToLower(strings.TrimSpace(provider.ClientType)) + timeout := r.timeout + if timeout <= 0 { + timeout = 30 * time.Second + } + + switch clientType { + case ProviderOpenAI, ProviderOpenAICompat: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, errors.New("openai api key is required") + } + p, err := NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout) + if err != nil { + return nil, err + } + return p, nil + case ProviderAnthropic: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, errors.New("anthropic api key is required") + } + p, err := NewAnthropicProvider(provider.ApiKey, timeout) + if err != nil { + return nil, err + } + return p, nil + case ProviderGoogle: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, errors.New("google api key is required") + } + p, err := NewGoogleProvider(provider.ApiKey, timeout) + if err != nil { + return nil, err + } + return p, nil + case ProviderOllama: + p, err := NewOllamaProvider(provider.BaseUrl, timeout) + if err != nil { + return nil, err + } + return p, nil + default: + return nil, errors.New("unsupported provider type: " + clientType) + } +} diff --git a/internal/chat/types.go b/internal/chat/types.go new file mode 100644 index 00000000..90e73138 --- /dev/null +++ b/internal/chat/types.go @@ -0,0 +1,65 @@ +package chat + +// Message represents a chat message +type Message struct { + Role string `json:"role"` // user, assistant, system + Content string `json:"content"` +} + +// ChatRequest represents an incoming chat request +type ChatRequest struct { + Messages []Message `json:"messages"` + Model string `json:"model,omitempty"` // optional: specific model to use + Provider string `json:"provider,omitempty"` // optional: specific provider to use + Stream bool `json:"stream,omitempty"` +} + +// ChatResponse represents a chat response +type ChatResponse struct { + Message Message `json:"message"` + Model string `json:"model"` + Provider string `json:"provider"` + FinishReason string `json:"finish_reason,omitempty"` + Usage Usage `json:"usage,omitempty"` +} + +// StreamChunk represents a chunk in streaming response +type StreamChunk 
struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` +} + +// Usage represents token usage information +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Request is the internal request structure +type Request struct { + Messages []Message + Model string + Provider string + Temperature *float32 // optional temperature + ResponseFormat *ResponseFormat // optional response format + MaxTokens *int // optional max tokens +} + +// ResponseFormat specifies the format of the response +type ResponseFormat struct { + Type string `json:"type"` // "json_object" or "text" +} + +// Result is the internal result structure +type Result struct { + Message Message + Model string + Provider string + FinishReason string + Usage Usage +} diff --git a/internal/config/config.go b/internal/config/config.go index ea2c6e2e..b5dd0c93 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,8 +20,6 @@ const ( DefaultPGUser = "postgres" DefaultPGDatabase = "memoh" DefaultPGSSLMode = "disable" - DefaultMemoryBaseURL = "https://api.openai.com" - DefaultMemoryTimeout = 10 DefaultQdrantURL = "http://127.0.0.1:6334" DefaultQdrantCollection = "memory" ) @@ -32,7 +30,6 @@ type Config struct { Containerd ContainerdConfig `toml:"containerd"` MCP MCPConfig `toml:"mcp"` Postgres PostgresConfig `toml:"postgres"` - Memory MemoryConfig `toml:"memory"` Qdrant QdrantConfig `toml:"qdrant"` } @@ -66,13 +63,6 @@ type PostgresConfig struct { SSLMode string `toml:"sslmode"` } -type MemoryConfig struct { - BaseURL string `toml:"base_url"` - APIKey string `toml:"api_key"` - Model string `toml:"model"` - TimeoutSeconds int `toml:"timeout_seconds"` -} - type QdrantConfig struct { BaseURL string `toml:"base_url"` APIKey string `toml:"api_key"` @@ -104,11 +94,6 @@ func Load(path string) (Config, error) { Database: DefaultPGDatabase, SSLMode: DefaultPGSSLMode, }, - Memory: MemoryConfig{ - BaseURL: DefaultMemoryBaseURL, - Model: "gpt-4.1-nano", - TimeoutSeconds: DefaultMemoryTimeout, - }, Qdrant: QdrantConfig{ BaseURL: DefaultQdrantURL, Collection: DefaultQdrantCollection, diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go new file mode 100644 index 00000000..f3371941 --- /dev/null +++ b/internal/handlers/chat.go @@ -0,0 +1,129 @@ +package handlers + +import ( + "bufio" + "encoding/json" + "fmt" + "net/http" + + "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/chat" +) + +type ChatHandler struct { + resolver *chat.Resolver +} + +func NewChatHandler(resolver *chat.Resolver) *ChatHandler { + return &ChatHandler{resolver: resolver} +} + +func (h *ChatHandler) Register(e *echo.Echo) { + group := e.Group("/chat") + group.POST("", h.Chat) + group.POST("/stream", h.StreamChat) +} + +// Chat godoc +// @Summary Chat with AI +// @Description Send a chat message and get a response. The system will automatically select an appropriate chat model from the database. 
+// @Tags chat +// @Accept json +// @Produce json +// @Param request body chat.ChatRequest true "Chat request" +// @Success 200 {object} chat.ChatResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /chat [post] +func (h *ChatHandler) Chat(c echo.Context) error { + var req chat.ChatRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + if len(req.Messages) == 0 { + return echo.NewHTTPError(http.StatusBadRequest, "messages are required") + } + + resp, err := h.resolver.Chat(c.Request().Context(), req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusOK, resp) +} + +// StreamChat godoc +// @Summary Stream chat with AI +// @Description Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database. +// @Tags chat +// @Accept json +// @Produce text/event-stream +// @Param request body chat.ChatRequest true "Chat request" +// @Success 200 {object} chat.StreamChunk +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /chat/stream [post] +func (h *ChatHandler) StreamChat(c echo.Context) error { + var req chat.ChatRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + if len(req.Messages) == 0 { + return echo.NewHTTPError(http.StatusBadRequest, "messages are required") + } + + // Set headers for SSE + c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") + c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") + c.Response().Header().Set(echo.HeaderConnection, "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + // Get streaming channels + chunkChan, errChan := h.resolver.StreamChat(c.Request().Context(), req) + + // Create a flusher + flusher, ok := c.Response().Writer.(http.Flusher) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") + } + + writer := bufio.NewWriter(c.Response().Writer) + + // Stream chunks + for { + select { + case chunk, ok := <-chunkChan: + if !ok { + // Channel closed, send done message + writer.WriteString("data: [DONE]\n\n") + writer.Flush() + flusher.Flush() + return nil + } + + // Marshal chunk to JSON + data, err := json.Marshal(chunk) + if err != nil { + continue + } + + // Write SSE format + writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) + writer.Flush() + flusher.Flush() + + case err := <-errChan: + if err != nil { + // Send error as SSE event + errData := map[string]string{"error": err.Error()} + data, _ := json.Marshal(errData) + writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) + writer.Flush() + flusher.Flush() + return nil + } + } + } +} diff --git a/internal/handlers/providers.go b/internal/handlers/providers.go new file mode 100644 index 00000000..58a76973 --- /dev/null +++ b/internal/handlers/providers.go @@ -0,0 +1,234 @@ +package handlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/providers" +) + +type ProvidersHandler struct { + service *providers.Service +} + +func NewProvidersHandler(service *providers.Service) *ProvidersHandler { + return &ProvidersHandler{service: service} +} + +func (h *ProvidersHandler) Register(e *echo.Echo) { + group := e.Group("/providers") + group.POST("", h.Create) + group.GET("", h.List) + group.GET("/:id", h.Get) + 
group.GET("/name/:name", h.GetByName) + group.PUT("/:id", h.Update) + group.DELETE("/:id", h.Delete) + group.GET("/count", h.Count) +} + +// Create godoc +// @Summary Create a new LLM provider +// @Description Create a new LLM provider configuration +// @Tags providers +// @Accept json +// @Produce json +// @Param request body providers.CreateRequest true "Provider configuration" +// @Success 201 {object} providers.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /providers [post] +func (h *ProvidersHandler) Create(c echo.Context) error { + var req providers.CreateRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + // Validate required fields + if req.Name == "" { + return echo.NewHTTPError(http.StatusBadRequest, "name is required") + } + if req.ClientType == "" { + return echo.NewHTTPError(http.StatusBadRequest, "client_type is required") + } + if req.BaseURL == "" { + return echo.NewHTTPError(http.StatusBadRequest, "base_url is required") + } + + resp, err := h.service.Create(c.Request().Context(), req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusCreated, resp) +} + +// List godoc +// @Summary List all LLM providers +// @Description Get a list of all configured LLM providers, optionally filtered by client type +// @Tags providers +// @Accept json +// @Produce json +// @Param client_type query string false "Client type filter (openai, anthropic, google, ollama)" +// @Success 200 {array} providers.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /providers [get] +func (h *ProvidersHandler) List(c echo.Context) error { + clientType := c.QueryParam("client_type") + + var resp []providers.GetResponse + var err error + + if clientType != "" { + resp, err = h.service.ListByClientType(c.Request().Context(), providers.ClientType(clientType)) + } else { + resp, err = h.service.List(c.Request().Context()) + } + + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusOK, resp) +} + +// Get godoc +// @Summary Get provider by ID +// @Description Get a provider configuration by its ID +// @Tags providers +// @Accept json +// @Produce json +// @Param id path string true "Provider ID (UUID)" +// @Success 200 {object} providers.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /providers/{id} [get] +func (h *ProvidersHandler) Get(c echo.Context) error { + id := c.Param("id") + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + + resp, err := h.service.Get(c.Request().Context(), id) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } + + return c.JSON(http.StatusOK, resp) +} + +// GetByName godoc +// @Summary Get provider by name +// @Description Get a provider configuration by its name +// @Tags providers +// @Accept json +// @Produce json +// @Param name path string true "Provider name" +// @Success 200 {object} providers.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /providers/name/{name} [get] +func (h *ProvidersHandler) GetByName(c echo.Context) error { + name := c.Param("name") + if name == "" { + return 
echo.NewHTTPError(http.StatusBadRequest, "name is required") + } + + resp, err := h.service.GetByName(c.Request().Context(), name) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } + + return c.JSON(http.StatusOK, resp) +} + +// Update godoc +// @Summary Update provider +// @Description Update an existing provider configuration +// @Tags providers +// @Accept json +// @Produce json +// @Param id path string true "Provider ID (UUID)" +// @Param request body providers.UpdateRequest true "Updated provider configuration" +// @Success 200 {object} providers.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /providers/{id} [put] +func (h *ProvidersHandler) Update(c echo.Context) error { + id := c.Param("id") + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + + var req providers.UpdateRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + resp, err := h.service.Update(c.Request().Context(), id, req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusOK, resp) +} + +// Delete godoc +// @Summary Delete provider +// @Description Delete a provider configuration +// @Tags providers +// @Accept json +// @Produce json +// @Param id path string true "Provider ID (UUID)" +// @Success 204 "No Content" +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /providers/{id} [delete] +func (h *ProvidersHandler) Delete(c echo.Context) error { + id := c.Param("id") + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + + if err := h.service.Delete(c.Request().Context(), id); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.NoContent(http.StatusNoContent) +} + +// Count godoc +// @Summary Count providers +// @Description Get the total count of providers, optionally filtered by client type +// @Tags providers +// @Accept json +// @Produce json +// @Param client_type query string false "Client type filter (openai, anthropic, google, ollama)" +// @Success 200 {object} providers.CountResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /providers/count [get] +func (h *ProvidersHandler) Count(c echo.Context) error { + clientType := c.QueryParam("client_type") + + var count int64 + var err error + + if clientType != "" { + count, err = h.service.CountByClientType(c.Request().Context(), providers.ClientType(clientType)) + } else { + count, err = h.service.Count(c.Request().Context()) + } + + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusOK, providers.CountResponse{Count: count}) +} + diff --git a/internal/memory/llm_provider_client.go b/internal/memory/llm_provider_client.go new file mode 100644 index 00000000..426d3694 --- /dev/null +++ b/internal/memory/llm_provider_client.go @@ -0,0 +1,129 @@ +package memory + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/memohai/memoh/internal/chat" +) + +// ProviderLLMClient uses chat.Provider to make LLM calls for memory operations +type ProviderLLMClient struct { + provider chat.Provider + model string +} + +// NewProviderLLMClient creates a new LLM client that uses 
chat.Provider +func NewProviderLLMClient(provider chat.Provider, model string) *ProviderLLMClient { + if model == "" { + model = "gpt-4.1-nano-2025-04-14" + } + return &ProviderLLMClient{ + provider: provider, + model: model, + } +} + +// Extract extracts facts from messages using the provider +func (c *ProviderLLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { + if len(req.Messages) == 0 { + return ExtractResponse{}, fmt.Errorf("messages is required") + } + + parsedMessages := parseMessages(formatMessages(req.Messages)) + systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages) + + // Call provider with JSON mode + temp := float32(0) + result, err := c.provider.Chat(ctx, chat.Request{ + Model: c.model, + Temperature: &temp, + ResponseFormat: &chat.ResponseFormat{ + Type: "json_object", + }, + Messages: []chat.Message{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: userPrompt}, + }, + }) + if err != nil { + return ExtractResponse{}, err + } + + content := result.Message.Content + var parsed ExtractResponse + if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil { + return ExtractResponse{}, err + } + return parsed, nil +} + +// Decide decides what actions to take based on facts and existing memories +func (c *ProviderLLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) { + if len(req.Facts) == 0 { + return DecideResponse{}, fmt.Errorf("facts is required") + } + + retrieved := make([]map[string]string, 0, len(req.Candidates)) + for _, candidate := range req.Candidates { + retrieved = append(retrieved, map[string]string{ + "id": candidate.ID, + "text": candidate.Memory, + }) + } + + prompt := getUpdateMemoryMessages(retrieved, req.Facts) + + // Call provider with JSON mode + temp := float32(0) + result, err := c.provider.Chat(ctx, chat.Request{ + Model: c.model, + Temperature: &temp, + ResponseFormat: &chat.ResponseFormat{ + Type: "json_object", + }, + Messages: []chat.Message{ + {Role: "user", Content: prompt}, + }, + }) + if err != nil { + return DecideResponse{}, err + } + + content := result.Message.Content + var raw map[string]interface{} + if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &raw); err != nil { + return DecideResponse{}, err + } + + memoryItems := normalizeMemoryItems(raw["memory"]) + actions := make([]DecisionAction, 0, len(memoryItems)) + for _, item := range memoryItems { + event := strings.ToUpper(asString(item["event"])) + if event == "" { + event = "ADD" + } + if event == "NONE" { + continue + } + + text := asString(item["text"]) + if text == "" { + text = asString(item["fact"]) + } + if strings.TrimSpace(text) == "" { + continue + } + + actions = append(actions, DecisionAction{ + Event: event, + ID: normalizeID(item["id"]), + Text: text, + OldMemory: asString(item["old_memory"]), + }) + } + return DecideResponse{Actions: actions}, nil +} + diff --git a/internal/memory/service.go b/internal/memory/service.go index f8bb0e7b..1331537f 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -16,20 +16,20 @@ import ( ) type Service struct { - llm *LLMClient - embedder embeddings.Embedder - store *QdrantStore - resolver *embeddings.Resolver + llm LLM + embedder embeddings.Embedder + store *QdrantStore + resolver *embeddings.Resolver defaultTextModelID string defaultMultimodalModelID string } -func NewService(llm *LLMClient, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, 
-func NewService(llm *LLMClient, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
+func NewService(llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
 	return &Service{
-		llm:      llm,
-		embedder: embedder,
-		store:    store,
-		resolver: resolver,
+		llm:                      llm,
+		embedder:                 embedder,
+		store:                    store,
+		resolver:                 resolver,
 		defaultTextModelID:       defaultTextModelID,
 		defaultMultimodalModelID: defaultMultimodalModelID,
 	}
 }
@@ -138,10 +138,10 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
 	}
 
 	var (
-		vector     []float32
-		store      *QdrantStore
+		vector     []float32
+		store      *QdrantStore
 		vectorName string
-		err        error
+		err        error
 	)
 	if modality == embeddings.TypeMultimodal {
 		if s.resolver == nil {
@@ -237,10 +237,10 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
 		metadata["model_id"] = result.Model
 	}
 	if err := s.store.Upsert(ctx, []qdrantPoint{{
-		ID:         id,
-		Vector:     result.Embedding,
+		ID:         id,
+		Vector:     result.Embedding,
 		VectorName: vectorName,
-		Payload:    payload,
+		Payload:    payload,
 	}}); err != nil {
 		return EmbedUpsertResponse{}, err
 	}
@@ -280,10 +280,10 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
 		return MemoryItem{}, err
 	}
 	if err := s.store.Upsert(ctx, []qdrantPoint{{
-		ID:         req.MemoryID,
-		Vector:     vector,
+		ID:         req.MemoryID,
+		Vector:     vector,
 		VectorName: s.vectorNameForText(),
-		Payload:    payload,
+		Payload:    payload,
 	}}); err != nil {
 		return MemoryItem{}, err
 	}
@@ -411,10 +411,10 @@ func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]
 	id := uuid.NewString()
 	payload := buildPayload(text, filters, metadata, "")
 	if err := s.store.Upsert(ctx, []qdrantPoint{{
-		ID:         id,
-		Vector:     vector,
+		ID:         id,
+		Vector:     vector,
 		VectorName: s.vectorNameForText(),
-		Payload:    payload,
+		Payload:    payload,
 	}}); err != nil {
 		return MemoryItem{}, err
 	}
@@ -448,10 +448,10 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
 		return MemoryItem{}, err
 	}
 	if err := s.store.Upsert(ctx, []qdrantPoint{{
-		ID:         id,
-		Vector:     vector,
+		ID:         id,
+		Vector:     vector,
 		VectorName: s.vectorNameForText(),
-		Payload:    payload,
+		Payload:    payload,
 	}}); err != nil {
 		return MemoryItem{}, err
 	}
@@ -756,4 +756,3 @@ func normalizeScore(score, minScore, maxScore float64) float64 {
 	}
 	return (score - minScore) / (maxScore - minScore)
 }
-
diff --git a/internal/memory/types.go b/internal/memory/types.go
index 0a20f61e..8c6ecf69 100644
--- a/internal/memory/types.go
+++ b/internal/memory/types.go
@@ -1,5 +1,13 @@
 package memory
 
+import "context"
+
+// LLM is the interface for LLM operations needed by the memory service
+type LLM interface {
+	Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
+	Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
+}
+
 type Message struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
diff --git a/internal/providers/service.go b/internal/providers/service.go
new file mode 100644
index 00000000..bf12ec2c
--- /dev/null
+++ b/internal/providers/service.go
@@ -0,0 +1,266 @@
+package providers
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/google/uuid"
+	"github.com/jackc/pgx/v5/pgtype"
+
+	"github.com/memohai/memoh/internal/db/sqlc"
+)
+
+// Service handles provider operations
+type Service struct {
+	queries *sqlc.Queries
+}
+
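+// Wiring sketch (assumes the sqlc-generated New constructor and a pgx pool;
+// adjust to the actual startup code):
+//
+//	queries := sqlc.New(pool) // pool is a *pgxpool.Pool
+//	svc := providers.NewService(queries)
+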
+// NewService creates a new provider service
+func NewService(queries *sqlc.Queries) *Service {
+	return &Service{queries: queries}
+}
+
+// Create creates a new LLM provider
+func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) {
+	// Validate client type
+	if !isValidClientType(req.ClientType) {
+		return GetResponse{}, fmt.Errorf("invalid client_type: %s", req.ClientType)
+	}
+
+	// Marshal metadata (note: a nil map serializes to JSON null)
+	metadataJSON, err := json.Marshal(req.Metadata)
+	if err != nil {
+		return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
+	}
+
+	// Create provider
+	provider, err := s.queries.CreateLlmProvider(ctx, sqlc.CreateLlmProviderParams{
+		Name:       req.Name,
+		ClientType: string(req.ClientType),
+		BaseUrl:    req.BaseURL,
+		ApiKey:     req.APIKey,
+		Metadata:   metadataJSON,
+	})
+	if err != nil {
+		return GetResponse{}, fmt.Errorf("create provider: %w", err)
+	}
+
+	return s.toGetResponse(provider), nil
+}
+
+// Get retrieves a provider by ID
+func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) {
+	providerID, err := parseUUID(id)
+	if err != nil {
+		return GetResponse{}, err
+	}
+
+	provider, err := s.queries.GetLlmProviderByID(ctx, providerID)
+	if err != nil {
+		return GetResponse{}, fmt.Errorf("get provider: %w", err)
+	}
+
+	return s.toGetResponse(provider), nil
+}
+
+// GetByName retrieves a provider by name
+func (s *Service) GetByName(ctx context.Context, name string) (GetResponse, error) {
+	provider, err := s.queries.GetLlmProviderByName(ctx, name)
+	if err != nil {
+		return GetResponse{}, fmt.Errorf("get provider by name: %w", err)
+	}
+
+	return s.toGetResponse(provider), nil
+}
+
+// List retrieves all providers
+func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
+	providers, err := s.queries.ListLlmProviders(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("list providers: %w", err)
+	}
+
+	results := make([]GetResponse, 0, len(providers))
+	for _, p := range providers {
+		results = append(results, s.toGetResponse(p))
+	}
+	return results, nil
+}
+
+// ListByClientType retrieves providers by client type
+func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) {
+	if !isValidClientType(clientType) {
+		return nil, fmt.Errorf("invalid client_type: %s", clientType)
+	}
+
+	providers, err := s.queries.ListLlmProvidersByClientType(ctx, string(clientType))
+	if err != nil {
+		return nil, fmt.Errorf("list providers by client type: %w", err)
+	}
+
+	results := make([]GetResponse, 0, len(providers))
+	for _, p := range providers {
+		results = append(results, s.toGetResponse(p))
+	}
+	return results, nil
+}
+
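+// Note on update semantics: UpdateRequest uses pointer fields, so only
+// non-nil fields overwrite stored values. For example, a request body of
+//
+//	{"api_key": "sk-new-key"}
+//
+// (an illustrative key) replaces the key while leaving name, client_type,
+// base_url and metadata untouched.
+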
+// Update updates an existing provider
+func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) {
+	providerID, err := parseUUID(id)
+	if err != nil {
+		return GetResponse{}, err
+	}
+
+	// Get existing provider
+	existing, err := s.queries.GetLlmProviderByID(ctx, providerID)
+	if err != nil {
+		return GetResponse{}, fmt.Errorf("get provider: %w", err)
+	}
+
+	// Apply updates
+	name := existing.Name
+	if req.Name != nil {
+		name = *req.Name
+	}
+
+	clientType := existing.ClientType
+	if req.ClientType != nil {
+		if !isValidClientType(*req.ClientType) {
+			return GetResponse{}, fmt.Errorf("invalid client_type: %s", *req.ClientType)
+		}
+		clientType = string(*req.ClientType)
+	}
+
+	baseURL := existing.BaseUrl
+	if req.BaseURL != nil {
+		baseURL = *req.BaseURL
+	}
+
+	apiKey := existing.ApiKey
+	if req.APIKey != nil {
+		apiKey = *req.APIKey
+	}
+
+	metadata := existing.Metadata
+	if req.Metadata != nil {
+		metadataJSON, err := json.Marshal(req.Metadata)
+		if err != nil {
+			return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
+		}
+		metadata = metadataJSON
+	}
+
+	// Update provider
+	updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{
+		ID:         providerID,
+		Name:       name,
+		ClientType: clientType,
+		BaseUrl:    baseURL,
+		ApiKey:     apiKey,
+		Metadata:   metadata,
+	})
+	if err != nil {
+		return GetResponse{}, fmt.Errorf("update provider: %w", err)
+	}
+
+	return s.toGetResponse(updated), nil
+}
+
+// Delete deletes a provider by ID
+func (s *Service) Delete(ctx context.Context, id string) error {
+	providerID, err := parseUUID(id)
+	if err != nil {
+		return err
+	}
+
+	if err := s.queries.DeleteLlmProvider(ctx, providerID); err != nil {
+		return fmt.Errorf("delete provider: %w", err)
+	}
+	return nil
+}
+
+// Count returns the total count of providers
+func (s *Service) Count(ctx context.Context) (int64, error) {
+	count, err := s.queries.CountLlmProviders(ctx)
+	if err != nil {
+		return 0, fmt.Errorf("count providers: %w", err)
+	}
+	return count, nil
+}
+
+// CountByClientType returns the count of providers by client type
+func (s *Service) CountByClientType(ctx context.Context, clientType ClientType) (int64, error) {
+	if !isValidClientType(clientType) {
+		return 0, fmt.Errorf("invalid client_type: %s", clientType)
+	}
+
+	count, err := s.queries.CountLlmProvidersByClientType(ctx, string(clientType))
+	if err != nil {
+		return 0, fmt.Errorf("count providers by client type: %w", err)
+	}
+	return count, nil
+}
+
+// toGetResponse converts a database provider to a response
+func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
+	var metadata map[string]interface{}
+	if len(provider.Metadata) > 0 {
+		_ = json.Unmarshal(provider.Metadata, &metadata)
+	}
+
+	// Mask API key (show only first 8 characters)
+	maskedAPIKey := maskAPIKey(provider.ApiKey)
+
+	// Convert the pgtype.UUID bytes directly to a uuid.UUID for formatting
+	idUUID := uuid.UUID(provider.ID.Bytes)
+
+	return GetResponse{
+		ID:         idUUID.String(),
+		Name:       provider.Name,
+		ClientType: provider.ClientType,
+		BaseURL:    provider.BaseUrl,
+		APIKey:     maskedAPIKey,
+		Metadata:   metadata,
+		CreatedAt:  provider.CreatedAt.Time,
+		UpdatedAt:  provider.UpdatedAt.Time,
+	}
+}
+
+// parseUUID parses a UUID string
+func parseUUID(id string) (pgtype.UUID, error) {
+	parsed, err := uuid.Parse(id)
+	if err != nil {
+		return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err)
+	}
+	var pgID pgtype.UUID
+	pgID.Valid = true
+	copy(pgID.Bytes[:], parsed[:])
+	return pgID, nil
+}
+
+// isValidClientType checks if a client type is valid
+func isValidClientType(clientType ClientType) bool {
+	switch clientType {
+	case ClientTypeOpenAI, ClientTypeOpenAICompat, ClientTypeAnthropic, ClientTypeGoogle, ClientTypeOllama:
+		return true
+	default:
+		return false
+	}
+}
+
+// maskAPIKey masks an API key for security
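+// (for example, an illustrative key "sk-abcd1234xyz" would come back as
+// "sk-abcd1******": the first 8 characters plus one '*' per hidden character)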
+func maskAPIKey(apiKey string) string {
+	if apiKey == "" {
+		return ""
+	}
+	if len(apiKey) <= 8 {
+		return strings.Repeat("*", len(apiKey))
+	}
+	return apiKey[:8] + strings.Repeat("*", len(apiKey)-8)
+}
+
diff --git a/internal/providers/types.go b/internal/providers/types.go
new file mode 100644
index 00000000..1a8ddd36
--- /dev/null
+++ b/internal/providers/types.go
@@ -0,0 +1,71 @@
+package providers
+
+import "time"
+
+// ClientType represents the type of LLM provider client
+type ClientType string
+
+const (
+	ClientTypeOpenAI       ClientType = "openai"
+	ClientTypeOpenAICompat ClientType = "openai-compat"
+	ClientTypeAnthropic    ClientType = "anthropic"
+	ClientTypeGoogle       ClientType = "google"
+	ClientTypeOllama       ClientType = "ollama"
+)
+
+// CreateRequest represents a request to create a new LLM provider
+type CreateRequest struct {
+	Name       string                 `json:"name" validate:"required"`
+	ClientType ClientType             `json:"client_type" validate:"required"`
+	BaseURL    string                 `json:"base_url" validate:"required,url"`
+	APIKey     string                 `json:"api_key"`
+	Metadata   map[string]interface{} `json:"metadata,omitempty"`
+}
+
+// UpdateRequest represents a request to update an existing LLM provider
+type UpdateRequest struct {
+	Name       *string                `json:"name,omitempty"`
+	ClientType *ClientType            `json:"client_type,omitempty"`
+	BaseURL    *string                `json:"base_url,omitempty"`
+	APIKey     *string                `json:"api_key,omitempty"`
+	Metadata   map[string]interface{} `json:"metadata,omitempty"`
+}
+
+// GetResponse represents the response for getting a provider
+type GetResponse struct {
+	ID         string                 `json:"id"`
+	Name       string                 `json:"name"`
+	ClientType string                 `json:"client_type"`
+	BaseURL    string                 `json:"base_url"`
+	APIKey     string                 `json:"api_key,omitempty"` // masked in response
+	Metadata   map[string]interface{} `json:"metadata,omitempty"`
+	CreatedAt  time.Time              `json:"created_at"`
+	UpdatedAt  time.Time              `json:"updated_at"`
+}
+
+// ListResponse represents the response for listing providers
+type ListResponse struct {
+	Providers []GetResponse `json:"providers"`
+	Total     int64         `json:"total"`
+}
+
+// CountResponse represents the count response
+type CountResponse struct {
+	Count int64 `json:"count"`
+}
+
+// TestRequest represents a request to test a provider connection
+type TestRequest struct {
+	ClientType ClientType `json:"client_type" validate:"required"`
+	BaseURL    string     `json:"base_url" validate:"required,url"`
+	APIKey     string     `json:"api_key"`
+	Model      string     `json:"model"` // optional test model
+}
+
+// TestResponse represents the result of testing a provider
+type TestResponse struct {
+	Success bool   `json:"success"`
+	Message string `json:"message,omitempty"`
+	Latency int64  `json:"latency_ms,omitempty"` // latency in milliseconds
+}
+
diff --git a/internal/server/server.go b/internal/server/server.go
index cdf3509b..2b1b68e7 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -15,7 +15,7 @@ type Server struct {
 	addr string
 }
 
-func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, fsHandler *handlers.FSHandler, swaggerHandler *handlers.SwaggerHandler) *Server {
+func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, fsHandler *handlers.FSHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler) *Server {
 	if addr == "" {
 		addr = ":8080"
 	}
@@ -53,6 +53,15 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
 	if swaggerHandler != nil {
 		swaggerHandler.Register(e)
 	}
+	if chatHandler != nil {
+		chatHandler.Register(e)
+	}
+	if providersHandler != nil {
+		providersHandler.Register(e)
+	}
+	if modelsHandler != nil {
+		modelsHandler.Register(e)
+	}
 
 	return &Server{
 		echo: e,