From ca5c6a1866dabd8356fc4ba0c91a0a85e134fcea Mon Sep 17 00:00:00 2001 From: BBQ Date: Thu, 12 Feb 2026 15:33:09 +0800 Subject: [PATCH] refactor(core): restructure conversation, channel and message domains - Rename chat module to conversation with flow-based architecture - Move channelidentities into channel/identities subpackage - Add channel/route for routing logic - Add message service with event hub - Add MCP providers: container, directory, schedule - Refactor Feishu/Telegram adapters with directory and stream support - Add platform management page and channel badges in web UI - Update database schema for conversations, messages and channel routes - Add @memoh/shared package for cross-package type definitions --- agent/src/agent.ts | 76 +- agent/src/models.ts | 1 - agent/src/modules/chat.ts | 5 +- agent/src/prompts/system.ts | 10 +- agent/src/test/unified_mcp_tools.test.ts | 1 - agent/src/tools/memory.ts | 8 +- agent/src/types/agent.ts | 2 - cmd/agent/main.go | 68 +- cmd/feishu-echo/main.go | 120 ++ cmd/feishu-echo/main_test.go | 20 + cmd/mcp/Dockerfile | 18 + cmd/mcp/main.go | 3 +- db/migrations/0001_init.down.sql | 7 +- db/migrations/0001_init.up.sql | 139 +- db/queries/bind.sql | 10 +- db/queries/bots.sql | 22 +- db/queries/channel_identities.sql | 26 +- db/queries/channel_routes.sql | 87 + db/queries/channels.sql | 14 +- db/queries/chats.sql | 214 -- db/queries/containers.sql | 3 + db/queries/conversations.sql | 229 +++ db/queries/messages.sql | 118 ++ internal/accounts/service.go | 41 +- internal/bind/service.go | 42 +- internal/bind/service_integration_test.go | 119 +- .../bind/service_link_integration_test.go | 156 -- internal/bind/service_test.go | 207 ++ internal/bots/service.go | 407 +++- internal/bots/types.go | 64 +- internal/channel/adapter.go | 48 +- internal/channel/adapters/feishu/directory.go | 298 +++ .../channel/adapters/feishu/directory_test.go | 118 ++ internal/channel/adapters/feishu/feishu.go | 655 ++++--- .../feishu/feishu_integration_test.go | 39 +- .../channel/adapters/feishu/feishu_test.go | 350 ++++ internal/channel/adapters/feishu/inbound.go | 269 +++ .../feishu/{feishu_logger.go => logger.go} | 0 internal/channel/adapters/feishu/stream.go | 286 +++ internal/channel/adapters/local/cli.go | 37 +- internal/channel/adapters/local/hub.go | 106 +- internal/channel/adapters/local/hub_test.go | 50 + internal/channel/adapters/local/web.go | 37 +- .../channel/adapters/telegram/directory.go | 239 +++ .../adapters/telegram/directory_test.go | 99 + internal/channel/adapters/telegram/logger.go | 19 + .../channel/adapters/telegram/logger_test.go | 46 + .../channel/adapters/telegram/markdown.go | 210 ++ .../adapters/telegram/markdown_test.go | 209 ++ internal/channel/adapters/telegram/stream.go | 211 ++ .../channel/adapters/telegram/stream_test.go | 119 ++ .../channel/adapters/telegram/telegram.go | 115 +- .../adapters/telegram/telegram_test.go | 241 +++ internal/channel/capabilities.go | 32 +- internal/channel/connection.go | 50 +- internal/channel/directory.go | 2 +- .../identities}/service.go | 18 +- .../service_identity_integration_test.go | 15 +- .../identities}/service_integration_test.go | 15 +- .../identities}/service_test.go | 2 +- .../identities}/types.go | 2 +- internal/channel/inbound_test.go | 217 ++ internal/channel/manager.go | 1 + internal/channel/manager_integration_test.go | 2 +- internal/channel/outbound.go | 157 +- internal/channel/processor.go | 2 +- internal/channel/registry.go | 40 + internal/channel/registry_test.go | 63 + internal/channel/route/service.go | 362 ++++ internal/channel/route/types.go | 68 + internal/channel/service.go | 26 +- internal/channel/types.go | 95 +- internal/chat/assistant_output.go | 2 +- internal/chat/resolver.go | 335 ++-- internal/chat/resolver_memory_context_test.go | 8 +- internal/chat/resolver_test.go | 3 +- internal/chat/schedule_gateway.go | 2 +- internal/chat/service.go | 493 +++-- .../chat/service_presence_integration_test.go | 38 +- internal/chat/types.go | 48 +- .../conversation/flow/assistant_output.go | 36 + internal/conversation/flow/resolver.go | 1226 ++++++++++++ .../flow/resolver_memory_context_test.go | 55 + internal/conversation/flow/resolver_test.go | 158 ++ .../conversation/flow/schedule_gateway.go | 26 + internal/conversation/flow/types.go | 14 + internal/conversation/flow/types_alias.go | 15 + internal/conversation/interfaces.go | 20 + internal/conversation/resolver.go | 1179 +++++++++++ internal/conversation/service_domain.go | 483 +++++ .../service_presence_integration_test.go | 242 +++ internal/conversation/types.go | 235 +++ internal/db/sqlc/bind.sql.go | 22 +- internal/db/sqlc/bots.sql.go | 41 +- internal/db/sqlc/channel_identities.sql.go | 54 +- internal/db/sqlc/channel_routes.sql.go | 298 +++ internal/db/sqlc/channels.sql.go | 38 +- internal/db/sqlc/chats.sql.go | 988 ---------- internal/db/sqlc/containers.sql.go | 39 + internal/db/sqlc/conversations.sql.go | 678 +++++++ internal/db/sqlc/messages.sql.go | 409 ++++ internal/db/sqlc/models.go | 106 +- internal/db/text.go | 11 + internal/db/text_test.go | 26 + internal/db/uuid.go | 12 - internal/embeddings/resolver.go | 13 +- internal/handlers/auth.go | 18 +- internal/handlers/channel.go | 8 +- internal/handlers/chat.go | 642 ------ internal/handlers/containerd.go | 244 ++- internal/handlers/fs.go | 71 - internal/handlers/fs_rest.go | 585 ------ internal/handlers/local_channel.go | 93 +- internal/handlers/mcp.go | 5 - internal/handlers/mcp_federation_gateway.go | 56 +- internal/handlers/mcp_stdio.go | 40 +- internal/handlers/mcp_tools.go | 5 +- internal/handlers/mcp_tools_test.go | 4 +- internal/handlers/memory.go | 146 +- internal/handlers/message.go | 583 ++++++ internal/handlers/schedule.go | 5 - internal/handlers/settings.go | 3 - internal/handlers/subagent.go | 10 - internal/handlers/users.go | 114 +- internal/logger/logger.go | 8 +- internal/logger/logger_test.go | 4 - internal/mcp/connections.go | 20 +- internal/mcp/manager.go | 151 +- internal/mcp/providers/container/fsops.go | 288 +++ .../mcp/providers/container/fsops_test.go | 148 ++ internal/mcp/providers/container/provider.go | 250 +++ .../mcp/providers/container/provider_test.go | 236 +++ internal/mcp/providers/directory/provider.go | 162 ++ .../mcp/providers/directory/provider_test.go | 72 + internal/mcp/providers/memory/provider.go | 101 +- .../mcp/providers/memory/provider_test.go | 284 +++ internal/mcp/providers/message/provider.go | 50 +- .../mcp/providers/message/provider_test.go | 247 +++ .../mcp/providers/schedule/provider_test.go | 374 ++++ internal/mcp/sources/federation/source.go | 17 - .../mcp/sources/federation/source_test.go | 10 - internal/mcp/tool_types.go | 1 - internal/mcp/tools.go | 421 ---- internal/mcp/versioning.go | 27 +- internal/memory/indexer_test.go | 14 +- internal/memory/qdrant_store.go | 2 +- internal/memory/service_test.go | 38 +- internal/message/event/hub.go | 124 ++ internal/message/event/hub_test.go | 59 + internal/message/service.go | 358 ++++ internal/message/types.go | 52 + internal/models/models.go | 47 +- internal/models/types.go | 14 +- internal/preauth/service.go | 35 +- internal/providers/service.go | 29 +- internal/providers/types.go | 12 +- internal/router/channel.go | 522 ++++- internal/router/channel_test.go | 477 ++++- internal/router/identity.go | 147 +- internal/router/identity_test.go | 287 ++- internal/schedule/service.go | 48 +- internal/schedule/trigger.go | 2 +- internal/server/server.go | 6 +- internal/settings/service.go | 22 +- internal/settings/types.go | 12 +- internal/subagent/service.go | 44 +- packages/sdk/src/@pinia/colada.gen.ts | 490 +---- packages/sdk/src/index.ts | 4 +- packages/sdk/src/sdk.gen.ts | 291 +-- packages/sdk/src/types.gen.ts | 1738 +++-------------- packages/shared/README.md | 1 + packages/shared/package.json | 13 + packages/shared/src/chatInfo.ts | 15 + packages/shared/src/index.ts | 5 + packages/shared/src/mcp.ts | 48 + packages/shared/src/model.ts | 101 + packages/shared/src/platform.ts | 7 + packages/shared/src/schedule.ts | 8 + packages/web/mise.toml | 16 +- packages/web/package.json | 15 +- packages/web/public/channels/feishu.png | Bin 0 -> 669 bytes packages/web/public/channels/telegram.webp | Bin 0 -> 6724 bytes packages/web/src/App.vue | 1 - packages/web/src/components/Sidebar/index.vue | 124 ++ .../Sidebar/lists/chat-list-menu.vue | 189 +- .../Sidebar/lists/settings-list-menu.vue | 9 +- .../web/src/components/add-platform/index.vue | 172 ++ .../web/src/components/add-provider/index.vue | 18 +- .../chat-list/assistant-chat/index.vue | 48 +- .../chat-list/channel-badge/index.vue | 46 + .../web/src/components/chat-list/index.vue | 104 +- .../components/chat-list/robot-chat/index.vue | 71 + .../components/chat-list/user-chat/index.vue | 40 +- .../web/src/components/create-mcp/index.vue | 6 +- .../web/src/components/create-model/index.vue | 37 +- packages/web/src/components/sidebar/index.vue | 123 -- packages/web/src/composables/api/useAuth.ts | 24 +- .../web/src/composables/api/useBotSettings.ts | 42 + packages/web/src/composables/api/useBots.ts | 189 +- .../web/src/composables/api/useChannels.ts | 125 +- packages/web/src/composables/api/useChat.ts | 198 +- packages/web/src/composables/api/useMcp.ts | 21 +- packages/web/src/composables/api/useModels.ts | 87 + .../web/src/composables/api/usePlatform.ts | 38 + .../web/src/composables/api/useProviders.ts | 80 + packages/web/src/composables/api/useUsers.ts | 54 +- packages/web/src/composables/useAutoScroll.ts | 60 +- .../web/src/composables/useKeyValueTags.ts | 10 +- packages/web/src/i18n/locales/en.json | 113 +- packages/web/src/i18n/locales/zh.json | 115 +- packages/web/src/main.ts | 17 +- .../src/pages/bots/components/bot-card.vue | 125 +- .../pages/bots/components/bot-channels.vue | 64 +- .../pages/bots/components/bot-settings.vue | 100 +- .../components/channel-settings-panel.vue | 44 +- .../src/pages/bots/components/create-bot.vue | 78 +- .../pages/bots/components/model-select.vue | 12 +- packages/web/src/pages/bots/detail.vue | 883 ++++++++- packages/web/src/pages/bots/index.vue | 59 +- .../src/pages/chat/components/bot-sidebar.vue | 23 +- .../src/pages/chat/components/chat-area.vue | 11 +- .../pages/chat/components/thinking-block.vue | 1 - packages/web/src/pages/chat/index.vue | 16 +- packages/web/src/pages/main-section/index.vue | 2 +- .../pages/models/components/model-item.vue | 6 +- .../pages/models/components/model-list.vue | 6 +- .../pages/models/components/provider-form.vue | 4 +- .../web/src/pages/models/model-setting.vue | 69 +- .../platform/components/platform-card.vue | 71 + packages/web/src/pages/platform/index.vue | 26 + packages/web/src/pages/settings/index.vue | 163 +- packages/web/src/pages/settings/user.vue | 14 +- packages/web/src/router.ts | 73 +- packages/web/src/store/User.ts | 72 + packages/web/src/store/chat-list.ts | 274 ++- packages/web/src/store/settings.ts | 1 - packages/web/src/types/index.ts | 0 packages/web/src/utils/channel-icons.ts | 30 + packages/web/src/utils/request.ts | 10 +- pnpm-lock.yaml | 249 ++- spec/docs.go | 799 ++------ spec/swagger.json | 795 ++------ spec/swagger.yaml | 541 +---- 243 files changed, 21463 insertions(+), 10485 deletions(-) create mode 100644 cmd/feishu-echo/main.go create mode 100644 cmd/feishu-echo/main_test.go create mode 100644 cmd/mcp/Dockerfile create mode 100644 db/queries/channel_routes.sql delete mode 100644 db/queries/chats.sql create mode 100644 db/queries/conversations.sql create mode 100644 db/queries/messages.sql delete mode 100644 internal/bind/service_link_integration_test.go create mode 100644 internal/bind/service_test.go create mode 100644 internal/channel/adapters/feishu/directory.go create mode 100644 internal/channel/adapters/feishu/directory_test.go create mode 100644 internal/channel/adapters/feishu/inbound.go rename internal/channel/adapters/feishu/{feishu_logger.go => logger.go} (100%) create mode 100644 internal/channel/adapters/feishu/stream.go create mode 100644 internal/channel/adapters/local/hub_test.go create mode 100644 internal/channel/adapters/telegram/directory.go create mode 100644 internal/channel/adapters/telegram/directory_test.go create mode 100644 internal/channel/adapters/telegram/logger.go create mode 100644 internal/channel/adapters/telegram/logger_test.go create mode 100644 internal/channel/adapters/telegram/markdown.go create mode 100644 internal/channel/adapters/telegram/markdown_test.go create mode 100644 internal/channel/adapters/telegram/stream.go create mode 100644 internal/channel/adapters/telegram/stream_test.go rename internal/{channelidentities => channel/identities}/service.go (96%) rename internal/{channelidentities => channel/identities}/service_identity_integration_test.go (86%) rename internal/{channelidentities => channel/identities}/service_integration_test.go (85%) rename internal/{channelidentities => channel/identities}/service_test.go (96%) rename internal/{channelidentities => channel/identities}/types.go (95%) create mode 100644 internal/channel/inbound_test.go create mode 100644 internal/channel/registry_test.go create mode 100644 internal/channel/route/service.go create mode 100644 internal/channel/route/types.go create mode 100644 internal/conversation/flow/assistant_output.go create mode 100644 internal/conversation/flow/resolver.go create mode 100644 internal/conversation/flow/resolver_memory_context_test.go create mode 100644 internal/conversation/flow/resolver_test.go create mode 100644 internal/conversation/flow/schedule_gateway.go create mode 100644 internal/conversation/flow/types.go create mode 100644 internal/conversation/flow/types_alias.go create mode 100644 internal/conversation/interfaces.go create mode 100644 internal/conversation/resolver.go create mode 100644 internal/conversation/service_domain.go create mode 100644 internal/conversation/service_presence_integration_test.go create mode 100644 internal/conversation/types.go create mode 100644 internal/db/sqlc/channel_routes.sql.go delete mode 100644 internal/db/sqlc/chats.sql.go create mode 100644 internal/db/sqlc/conversations.sql.go create mode 100644 internal/db/sqlc/messages.sql.go create mode 100644 internal/db/text.go create mode 100644 internal/db/text_test.go delete mode 100644 internal/handlers/chat.go delete mode 100644 internal/handlers/fs_rest.go create mode 100644 internal/handlers/message.go create mode 100644 internal/mcp/providers/container/fsops.go create mode 100644 internal/mcp/providers/container/fsops_test.go create mode 100644 internal/mcp/providers/container/provider.go create mode 100644 internal/mcp/providers/container/provider_test.go create mode 100644 internal/mcp/providers/directory/provider.go create mode 100644 internal/mcp/providers/directory/provider_test.go create mode 100644 internal/mcp/providers/memory/provider_test.go create mode 100644 internal/mcp/providers/message/provider_test.go create mode 100644 internal/mcp/providers/schedule/provider_test.go delete mode 100644 internal/mcp/tools.go create mode 100644 internal/message/event/hub.go create mode 100644 internal/message/event/hub_test.go create mode 100644 internal/message/service.go create mode 100644 internal/message/types.go create mode 100644 packages/shared/README.md create mode 100644 packages/shared/package.json create mode 100644 packages/shared/src/chatInfo.ts create mode 100644 packages/shared/src/index.ts create mode 100644 packages/shared/src/mcp.ts create mode 100644 packages/shared/src/model.ts create mode 100644 packages/shared/src/platform.ts create mode 100644 packages/shared/src/schedule.ts create mode 100644 packages/web/public/channels/feishu.png create mode 100644 packages/web/public/channels/telegram.webp create mode 100644 packages/web/src/components/Sidebar/index.vue create mode 100644 packages/web/src/components/add-platform/index.vue create mode 100644 packages/web/src/components/chat-list/channel-badge/index.vue create mode 100644 packages/web/src/components/chat-list/robot-chat/index.vue delete mode 100644 packages/web/src/components/sidebar/index.vue create mode 100644 packages/web/src/composables/api/useBotSettings.ts create mode 100644 packages/web/src/composables/api/useModels.ts create mode 100644 packages/web/src/composables/api/usePlatform.ts create mode 100644 packages/web/src/composables/api/useProviders.ts create mode 100644 packages/web/src/pages/platform/components/platform-card.vue create mode 100644 packages/web/src/pages/platform/index.vue create mode 100644 packages/web/src/store/User.ts create mode 100644 packages/web/src/types/index.ts create mode 100644 packages/web/src/utils/channel-icons.ts diff --git a/agent/src/agent.ts b/agent/src/agent.ts index 992511ed..3866575d 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -24,7 +24,6 @@ export const createAgent = ({ currentChannel = 'Unknown Channel', identity = { botId: '', - sessionId: '', containerId: '', channelIdentityId: '', displayName: '', @@ -53,18 +52,43 @@ export const createAgent = ({ toolsContent: '', } } - const fetchFile = async (path: string) => { - const response = await fetch(`/bots/${identity.botId}/container/fs/file?path=${encodeURIComponent(path)}`) - if (!response.ok) { - return '' + const readViaMCP = async (path: string): Promise => { + const url = `${auth.baseUrl.replace(/\/$/, '')}/bots/${identity.botId}/tools` + const headers: Record = { + 'Content-Type': 'application/json', + 'Accept': 'application/json, text/event-stream', + 'Authorization': `Bearer ${auth.bearer}`, } - const data = await response.json().catch(() => ({} as { content?: string })) - return typeof data?.content === 'string' ? data.content : '' + if (identity.channelIdentityId) { + headers['X-Memoh-Channel-Identity-Id'] = identity.channelIdentityId + } + const body = JSON.stringify({ + jsonrpc: '2.0', + id: `read-${path}`, + method: 'tools/call', + params: { name: 'read', arguments: { path } }, + }) + const response = await fetch(url, { method: 'POST', headers, body }) + if (!response.ok) return '' + const data = await response.json().catch(() => ({} as any)) + const structured = data?.result?.structuredContent ?? data?.result?.content?.[0]?.text + if (typeof structured === 'string') { + try { + const parsed = JSON.parse(structured) + return typeof parsed?.content === 'string' ? parsed.content : '' + } catch { + return structured + } + } + if (typeof structured === 'object' && structured?.content) { + return typeof structured.content === 'string' ? structured.content : '' + } + return '' } const [identityContent, soulContent, toolsContent] = await Promise.all([ - fetchFile('IDENTITY.md'), - fetchFile('SOUL.md'), - fetchFile('TOOLS.md'), + readViaMCP('IDENTITY.md'), + readViaMCP('SOUL.md'), + readViaMCP('TOOLS.md'), ]) return { identityContent, @@ -80,6 +104,7 @@ export const createAgent = ({ language, maxContextLoadTime: activeContextTime, channels, + currentChannel, skills, enabledSkills, identityContent, @@ -100,9 +125,6 @@ export const createAgent = ({ const headers: Record = { 'Authorization': `Bearer ${auth.bearer}`, } - if (identity.sessionId) { - headers['X-Memoh-Chat-Id'] = identity.sessionId - } if (identity.channelIdentityId) { headers['X-Memoh-Channel-Identity-Id'] = identity.channelIdentityId } @@ -115,9 +137,6 @@ export const createAgent = ({ if (identity.replyTarget) { headers['X-Memoh-Reply-Target'] = identity.replyTarget } - if (identity.displayName) { - headers['X-Memoh-Display-Name'] = identity.displayName - } const { tools: mcpTools, close: closeMCP } = await getMCPTools(`${baseUrl}/bots/${botId}/tools`, headers) return { tools: mcpTools, @@ -257,6 +276,28 @@ export const createAgent = ({ } } + const resolveStreamErrorMessage = (raw: unknown): string => { + if (raw instanceof Error && raw.message.trim()) { + return raw.message + } + if (typeof raw === 'string' && raw.trim()) { + return raw + } + if (raw && typeof raw === 'object') { + const candidate = raw as { message?: unknown; error?: unknown } + if (typeof candidate.message === 'string' && candidate.message.trim()) { + return candidate.message + } + if (typeof candidate.error === 'string' && candidate.error.trim()) { + return candidate.error + } + if (candidate.error instanceof Error && candidate.error.message.trim()) { + return candidate.error.message + } + } + return 'Model stream failed' + } + async function* stream(input: AgentInput): AsyncGenerator { const userPrompt = generateUserPrompt(input) const messages = [...input.messages, userPrompt] @@ -296,6 +337,9 @@ export const createAgent = ({ input, } for await (const chunk of fullStream) { + if (chunk.type === 'error') { + throw new Error(resolveStreamErrorMessage((chunk as { error?: unknown }).error)) + } switch (chunk.type) { case 'reasoning-start': yield { type: 'reasoning_start', diff --git a/agent/src/models.ts b/agent/src/models.ts index acce0fa8..d7403c97 100644 --- a/agent/src/models.ts +++ b/agent/src/models.ts @@ -22,7 +22,6 @@ export const AllowedActionModel = z.enum(allActions) export const IdentityContextModel = z.object({ botId: z.string().min(1, 'Bot ID is required'), - sessionId: z.string().min(1, 'Session ID is required'), containerId: z.string().min(1, 'Container ID is required'), channelIdentityId: z.string().min(1, 'Channel identity ID is required'), displayName: z.string().min(1, 'Display name is required'), diff --git a/agent/src/modules/chat.ts b/agent/src/modules/chat.ts index db95a559..150d83df 100644 --- a/agent/src/modules/chat.ts +++ b/agent/src/modules/chat.ts @@ -78,9 +78,12 @@ export const chatModule = new Elysia({ prefix: '/chat' }) } } catch (error) { console.error(error) + const message = error instanceof Error && error.message.trim() + ? error.message + : 'Internal server error' yield sse(JSON.stringify({ type: 'error', - message: 'Internal server error', + message, })) } }, { diff --git a/agent/src/prompts/system.ts b/agent/src/prompts/system.ts index 7049b94b..11c8649e 100644 --- a/agent/src/prompts/system.ts +++ b/agent/src/prompts/system.ts @@ -6,6 +6,8 @@ export interface SystemParams { language: string maxContextLoadTime: number channels: string[] + /** Channel where the current session/message is from (e.g. telegram, feishu, web). */ + currentChannel: string skills: AgentSkill[] enabledSkills: AgentSkill[] identityContent?: string @@ -23,11 +25,12 @@ ${skill.content} `.trim() } -export const system = ({ +export const system = ({ date, language, maxContextLoadTime, channels, + currentChannel, skills, enabledSkills, identityContent, @@ -37,6 +40,7 @@ export const system = ({ const headers = { 'language': language, 'available-channels': channels.join(','), + 'current-session-channel': currentChannel, 'max-context-load-time': maxContextLoadTime.toString(), 'time-now': date.toISOString(), } @@ -97,8 +101,12 @@ You have a contacts book to record them that you do not need to worry about who ## Channels +The current session (and the latest user message) is from channel: ${quote(currentChannel)}. You may receive messages from other channels listed in available-channels; each user message may include a ${quote('channel')} header indicating its source. + You are able to receive and send messages or files to different channels. +When you need to resolve a user or group on a channel (e.g. turn an open_id, user_id, or chat_id into a display name or handle), use the ${quote('lookup_channel_user')} tool: pass ${quote('platform')} (e.g. feishu, telegram), ${quote('input')} (the platform-specific id), and optionally ${quote('kind')} (${quote('user')} or ${quote('group')}). It returns name, handle, and id for that entry. + ## Attachments ### Receive diff --git a/agent/src/test/unified_mcp_tools.test.ts b/agent/src/test/unified_mcp_tools.test.ts index 6da02ada..a3c7b319 100644 --- a/agent/src/test/unified_mcp_tools.test.ts +++ b/agent/src/test/unified_mcp_tools.test.ts @@ -74,7 +74,6 @@ describe('getMCPTools (unified endpoint)', () => { const endpoint = `http://127.0.0.1:${server.port}/bots/bot-1/tools` const { tools, close } = await getMCPTools(endpoint, { Authorization: 'Bearer test-token', - 'X-Memoh-Chat-Id': 'chat-1', }) expect(Object.keys(tools)).toContain('search_memory') diff --git a/agent/src/tools/memory.ts b/agent/src/tools/memory.ts index 083b2d48..e2ed7b78 100644 --- a/agent/src/tools/memory.ts +++ b/agent/src/tools/memory.ts @@ -26,11 +26,11 @@ export const getMemoryTools = ({ fetch, identity }: MemoryToolParams) => { limit: z.number().int().positive().max(50).optional(), }), execute: async ({ query, limit }) => { - const chatId = identity.sessionId.trim() - if (!chatId) { - throw new Error('sessionId is required to search memory') + const botId = identity.botId.trim() + if (!botId) { + throw new Error('botId is required to search memory') } - const response = await fetch(`/chats/${chatId}/memory/search`, { + const response = await fetch(`/bots/${botId}/memory/search`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/agent/src/types/agent.ts b/agent/src/types/agent.ts index 0dc9a6c3..ec92ace4 100644 --- a/agent/src/types/agent.ts +++ b/agent/src/types/agent.ts @@ -4,13 +4,11 @@ import { AgentAttachment } from './attachment' export interface IdentityContext { botId: string - sessionId: string containerId: string channelIdentityId: string displayName: string - // Deprecated compatibility fields kept optional for older callers. contactId?: string contactName?: string contactAlias?: string diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 1a8f9db2..4251adde 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -15,21 +15,27 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/telegram" - "github.com/memohai/memoh/internal/channelidentities" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/channel/route" "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/conversation/flow" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/embeddings" "github.com/memohai/memoh/internal/handlers" "github.com/memohai/memoh/internal/logger" "github.com/memohai/memoh/internal/mcp" + mcpcontainer "github.com/memohai/memoh/internal/mcp/providers/container" + mcpdirectory "github.com/memohai/memoh/internal/mcp/providers/directory" mcpmemory "github.com/memohai/memoh/internal/mcp/providers/memory" mcpmessage "github.com/memohai/memoh/internal/mcp/providers/message" mcpschedule "github.com/memohai/memoh/internal/mcp/providers/schedule" mcpfederation "github.com/memohai/memoh/internal/mcp/sources/federation" "github.com/memohai/memoh/internal/memory" + "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/message/event" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/policy" "github.com/memohai/memoh/internal/preauth" @@ -85,7 +91,7 @@ func main() { defer client.Close() service := ctr.NewDefaultService(logger.L, client, cfg.Containerd.Namespace) - manager := mcp.NewManager(logger.L, service, cfg.MCP) + manager := mcp.NewManager(logger.L, service, cfg.MCP, cfg.Containerd.Namespace) pingHandler := handlers.NewPingHandler(logger.L) // containerdHandler is created later after DB services are initialized @@ -101,8 +107,10 @@ func main() { modelsService := models.NewService(logger.L, queries) botService := bots.NewService(logger.L, queries) accountService := accounts.NewService(logger.L, queries) + settingsService := settings.NewService(logger.L, queries) + policyService := policy.NewService(logger.L, botService, settingsService) - containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, accountService, queries) + containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, accountService, policyService, queries) botService.SetContainerLifecycle(containerdHandler) if err := ensureAdminUser(ctx, logger.L, queries, cfg); err != nil { @@ -112,8 +120,8 @@ func main() { authHandler := handlers.NewAuthHandler(logger.L, accountService, cfg.Auth.JWTSecret, jwtExpiresIn) - // Initialize chat resolver after memory service is configured. - var chatResolver *chat.Resolver + // Initialize conversation runner after memory service is configured. + var chatResolver *flow.Resolver // Create LLM client for memory operations (deferred model/provider selection). var llmClient memory.LLM = &lazyLLMClient{ @@ -147,42 +155,43 @@ func main() { // Initialize providers and models handlers providersService := providers.NewService(logger.L, queries) providersHandler := handlers.NewProvidersHandler(logger.L, providersService, modelsService) - settingsService := settings.NewService(logger.L, queries) settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService, botService, accountService) modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService) - policyService := policy.NewService(logger.L, botService, settingsService) - chatService := chat.NewService(logger.L, queries) + chatService := conversation.NewService(logger.L, queries) + routeService := route.NewService(logger.L, queries, chatService) + messageEvents := event.NewHub() + messageService := message.NewService(logger.L, queries, messageEvents) memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService, chatService, accountService) - actorService := channelidentities.NewService(logger.L, queries) + channelIdentitySvc := identities.NewService(logger.L, queries) preauthService := preauth.NewService(queries) preauthHandler := handlers.NewPreauthHandler(preauthService, botService, accountService) bindService := bind.NewService(logger.L, conn, queries) bindHandler := handlers.NewBindHandler(logger.L, bindService) mcpConnectionsService := mcp.NewConnectionService(logger.L, queries) mcpHandler := handlers.NewMCPHandler(logger.L, mcpConnectionsService, botService, accountService) - chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, chatService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) + chatResolver = flow.NewResolver(logger.L, modelsService, queries, memoryService, chatService, messageService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) chatResolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries) swaggerHandler := handlers.NewSwaggerHandler(logger.L) - chatHandler := handlers.NewChatHandler(logger.L, chatResolver, chatService, botService, accountService) + conversationHandler := handlers.NewMessageHandler(logger.L, chatResolver, chatService, messageService, botService, accountService, channelIdentitySvc, messageEvents) channelRegistry := channel.NewRegistry() - sessionHub := local.NewSessionHub() + routeHub := local.NewRouteHub() channelRegistry.MustRegister(telegram.NewTelegramAdapter(logger.L)) channelRegistry.MustRegister(feishu.NewFeishuAdapter(logger.L)) - channelRegistry.MustRegister(local.NewCLIAdapter(sessionHub)) - channelRegistry.MustRegister(local.NewWebAdapter(sessionHub)) + channelRegistry.MustRegister(local.NewCLIAdapter(routeHub)) + channelRegistry.MustRegister(local.NewWebAdapter(routeHub)) channelService := channel.NewService(queries, channelRegistry) - channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, chatService, chatResolver, actorService, botService, policyService, preauthService, bindService, cfg.Auth.JWTSecret, 5*time.Minute) + channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, routeService, messageService, chatResolver, channelIdentitySvc, botService, policyService, preauthService, bindService, cfg.Auth.JWTSecret, 5*time.Minute) channelManager := channel.NewManager(logger.L, channelRegistry, channelService, channelRouter) if mw := channelRouter.IdentityMiddleware(); mw != nil { channelManager.Use(mw) } channelManager.Start(ctx) channelHandler := handlers.NewChannelHandler(channelService, channelRegistry) - usersHandler := handlers.NewUsersHandler(logger.L, accountService, actorService, botService, chatService, channelService, channelManager, channelRegistry) - cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, chatService, sessionHub, botService, accountService) - webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, chatService, sessionHub, botService, accountService) - scheduleGateway := chat.NewScheduleGateway(chatResolver) + usersHandler := handlers.NewUsersHandler(logger.L, accountService, channelIdentitySvc, botService, routeService, channelService, channelManager, channelRegistry) + cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, chatService, routeHub, botService, accountService) + webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, chatService, routeHub, botService, accountService) + scheduleGateway := flow.NewScheduleGateway(chatResolver) scheduleService := schedule.NewService(logger.L, queries, scheduleGateway, cfg.Auth.JWTSecret) if err := scheduleService.Bootstrap(ctx); err != nil { logger.Error("schedule bootstrap", slog.Any("error", err)) @@ -192,23 +201,32 @@ func main() { subagentService := subagent.NewService(logger.L, queries) subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, accountService) messageToolExecutor := mcpmessage.NewExecutor(logger.L, channelManager, channelRegistry) + directoryToolExecutor := mcpdirectory.NewExecutor(logger.L, channelRegistry, channelService, channelRegistry) scheduleToolExecutor := mcpschedule.NewExecutor(logger.L, scheduleService) memoryToolExecutor := mcpmemory.NewExecutor(logger.L, memoryService, chatService, accountService) + execWorkDir := cfg.MCP.DataMount + if strings.TrimSpace(execWorkDir) == "" { + execWorkDir = config.DefaultDataMount + } + fsToolExecutor := mcpcontainer.NewExecutor(logger.L, manager, execWorkDir) federationGateway := handlers.NewMCPFederationGateway(logger.L, containerdHandler) federatedToolSource := mcpfederation.NewSource(logger.L, federationGateway, mcpConnectionsService) toolGatewayService := mcp.NewToolGatewayService( logger.L, []mcp.ToolExecutor{ messageToolExecutor, + directoryToolExecutor, scheduleToolExecutor, memoryToolExecutor, + fsToolExecutor, }, []mcp.ToolSource{ federatedToolSource, }, ) containerdHandler.SetToolGatewayService(toolGatewayService) - srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, preauthHandler, bindHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler) + go containerdHandler.ReconcileContainers(ctx) + srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, conversationHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, preauthHandler, bindHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler) if err := srv.Start(); err != nil { logger.Error("server failed", slog.Any("error", err)) @@ -371,19 +389,19 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) { return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout) } -// skillLoaderAdapter bridges handlers.ContainerdHandler to chat.SkillLoader. +// skillLoaderAdapter bridges handlers.ContainerdHandler to flow.SkillLoader. type skillLoaderAdapter struct { handler *handlers.ContainerdHandler } -func (a *skillLoaderAdapter) LoadSkills(ctx context.Context, botID string) ([]chat.SkillEntry, error) { +func (a *skillLoaderAdapter) LoadSkills(ctx context.Context, botID string) ([]flow.SkillEntry, error) { items, err := a.handler.LoadSkills(ctx, botID) if err != nil { return nil, err } - entries := make([]chat.SkillEntry, len(items)) + entries := make([]flow.SkillEntry, len(items)) for i, item := range items { - entries[i] = chat.SkillEntry{ + entries[i] = flow.SkillEntry{ Name: item.Name, Description: item.Description, Content: item.Content, diff --git a/cmd/feishu-echo/main.go b/cmd/feishu-echo/main.go new file mode 100644 index 00000000..d092ca6a --- /dev/null +++ b/cmd/feishu-echo/main.go @@ -0,0 +1,120 @@ +// feishu-echo is a minimal Feishu bot that connects via WebSocket and counts received events. +// Used to verify whether message loss is due to our app logic or network/Feishu delivery. +// +// Usage: +// +// FEISHU_APP_ID=xxx FEISHU_APP_SECRET=xxx FEISHU_ENCRYPT=xxx FEISHU_VERIFY=xxx go run ./cmd/feishu-echo +package main + +import ( + "context" + "log" + "os" + "os/signal" + "strings" + "sync/atomic" + "time" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + larkws "github.com/larksuite/oapi-sdk-go/v3/ws" + + "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" +) + +type eventCounts struct { + messageReceive atomic.Int64 + messageRead atomic.Int64 + reactionCreated atomic.Int64 + reactionDeleted atomic.Int64 +} + +func (c *eventCounts) log() { + log.Printf("[feishu-echo] counts: receive=%d read=%d reaction_created=%d reaction_deleted=%d", + c.messageReceive.Load(), c.messageRead.Load(), c.reactionCreated.Load(), c.reactionDeleted.Load()) +} + +func main() { + appID := strings.TrimSpace(os.Getenv("FEISHU_APP_ID")) + appSecret := strings.TrimSpace(os.Getenv("FEISHU_APP_SECRET")) + encryptKey := strings.TrimSpace(os.Getenv("FEISHU_ENCRYPT")) + verifyToken := strings.TrimSpace(os.Getenv("FEISHU_VERIFY")) + + if appID == "" || appSecret == "" { + log.Fatal("FEISHU_APP_ID and FEISHU_APP_SECRET are required") + } + + log.Printf("[feishu-echo] starting with app_id=%s (encrypt=%v, verify=%v)", appID, encryptKey != "", verifyToken != "") + + counts := new(eventCounts) + eventDispatcher := dispatcher.NewEventDispatcher(verifyToken, encryptKey) + + eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, _ *larkim.P2MessageReceiveV1) error { + counts.messageReceive.Add(1) + counts.log() + return nil + }) + + eventDispatcher.OnP2MessageReadV1(func(_ context.Context, _ *larkim.P2MessageReadV1) error { + counts.messageRead.Add(1) + counts.log() + return nil + }) + + eventDispatcher.OnP2MessageReactionCreatedV1(func(_ context.Context, _ *larkim.P2MessageReactionCreatedV1) error { + counts.reactionCreated.Add(1) + counts.log() + return nil + }) + + eventDispatcher.OnP2MessageReactionDeletedV1(func(_ context.Context, _ *larkim.P2MessageReactionDeletedV1) error { + counts.reactionDeleted.Add(1) + counts.log() + return nil + }) + + client := larkws.NewClient( + appID, + appSecret, + larkws.WithEventHandler(eventDispatcher), + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + <-sig + log.Println("[feishu-echo] interrupt, shutting down") + cancel() + counts.log() + os.Exit(0) + }() + + const reconnectDelay = 3 * time.Second +run: + for { + if ctx.Err() != nil { + break run + } + log.Println("[feishu-echo] connecting to Feishu WebSocket...") + err := client.Start(ctx) + if ctx.Err() != nil { + break run + } + if err != nil { + log.Printf("[feishu-echo] client error: %v; reconnecting in %v", err, reconnectDelay) + } else { + log.Printf("[feishu-echo] connection closed; reconnecting in %v", reconnectDelay) + } + timer := time.NewTimer(reconnectDelay) + select { + case <-ctx.Done(): + timer.Stop() + break run + case <-timer.C: + } + } + counts.log() + log.Println("[feishu-echo] stopped") +} diff --git a/cmd/feishu-echo/main_test.go b/cmd/feishu-echo/main_test.go new file mode 100644 index 00000000..9ed03377 --- /dev/null +++ b/cmd/feishu-echo/main_test.go @@ -0,0 +1,20 @@ +package main + +import ( + "testing" +) + +func TestEventCounts(t *testing.T) { + c := new(eventCounts) + c.log() + if c.messageReceive.Load() != 0 || c.messageRead.Load() != 0 { + t.Fatalf("initial counts should be 0") + } + c.messageReceive.Add(2) + c.messageRead.Add(1) + c.reactionCreated.Add(1) + if c.messageReceive.Load() != 2 || c.messageRead.Load() != 1 || c.reactionCreated.Load() != 1 { + t.Fatalf("counts after add: receive=2 read=1 reaction_created=1") + } + c.log() +} diff --git a/cmd/mcp/Dockerfile b/cmd/mcp/Dockerfile new file mode 100644 index 00000000..1daf0c51 --- /dev/null +++ b/cmd/mcp/Dockerfile @@ -0,0 +1,18 @@ +FROM golang:1.25-alpine AS build + +WORKDIR /src +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +ARG TARGETARCH +ARG COMMIT_HASH=unknown +RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH:-amd64} \ + go build -trimpath -ldflags "-s -w -X github.com/memohai/memoh/internal/version.CommitHash=${COMMIT_HASH}" -o /out/mcp ./cmd/mcp + +FROM alpine:latest +RUN apk add --no-cache grep +WORKDIR /app +COPY --from=build /out/mcp /opt/mcp +COPY cmd/mcp/template /opt/mcp-template +ENTRYPOINT ["/bin/sh","-lc","bootstrap(){ [ -e /app/mcp ] || { mkdir -p /app; [ -f /opt/mcp ] && cp -a /opt/mcp /app/mcp 2>/dev/null || true; }; }; bootstrap; if [ -x /app/mcp ]; then exec /app/mcp \"$@\"; fi; exec /opt/mcp \"$@\"","--"] diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index ac5ed75f..9110eceb 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -10,7 +10,6 @@ import ( "syscall" "github.com/memohai/memoh/internal/logger" - "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/version" gomcp "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -19,11 +18,11 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() + // File tools (read/write/list/edit) are provided by the agent's MCP tool gateway, not this binary. server := gomcp.NewServer( &gomcp.Implementation{Name: "memoh-mcp", Version: version.GetInfo()}, nil, ) - mcp.RegisterTools(server) err := server.Run(ctx, &gomcp.StdioTransport{}) if ctx.Err() != nil { return diff --git a/db/migrations/0001_init.down.sql b/db/migrations/0001_init.down.sql index 014bed2e..8369a94b 100644 --- a/db/migrations/0001_init.down.sql +++ b/db/migrations/0001_init.down.sql @@ -4,11 +4,8 @@ DROP TABLE IF EXISTS lifecycle_events; DROP TABLE IF EXISTS container_versions; DROP TABLE IF EXISTS snapshots; DROP TABLE IF EXISTS containers; -DROP TABLE IF EXISTS chat_routes; -DROP TABLE IF EXISTS chat_messages; -DROP TABLE IF EXISTS chat_channel_identity_presence; -DROP TABLE IF EXISTS chat_participants; -DROP TABLE IF EXISTS chats; +DROP TABLE IF EXISTS bot_history_messages; +DROP TABLE IF EXISTS bot_channel_routes; DROP TABLE IF EXISTS channel_identity_bind_codes; DROP TABLE IF EXISTS bot_preauth_keys; DROP TABLE IF EXISTS bot_channel_configs; diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 39528a14..812b7865 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -36,13 +36,13 @@ CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS channel_identities ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), user_id UUID REFERENCES users(id) ON DELETE SET NULL, - channel TEXT NOT NULL, + channel_type TEXT NOT NULL, channel_subject_id TEXT NOT NULL, display_name TEXT, metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT channel_identities_channel_subject_unique UNIQUE (channel, channel_subject_id) + CONSTRAINT channel_identities_channel_type_subject_unique UNIQUE (channel_type, channel_subject_id) ); CREATE INDEX IF NOT EXISTS idx_channel_identities_user_id ON channel_identities(user_id); @@ -51,11 +51,11 @@ CREATE INDEX IF NOT EXISTS idx_channel_identities_user_id ON channel_identities( CREATE TABLE IF NOT EXISTS user_channel_bindings ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, - platform TEXT NOT NULL, + channel_type TEXT NOT NULL, config JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT user_channel_bindings_unique UNIQUE (user_id, platform) + CONSTRAINT user_channel_bindings_unique UNIQUE (user_id, channel_type) ); CREATE INDEX IF NOT EXISTS idx_user_channel_bindings_user_id ON user_channel_bindings(user_id); @@ -108,6 +108,7 @@ CREATE TABLE IF NOT EXISTS bots ( display_name TEXT, avatar_url TEXT, is_active BOOLEAN NOT NULL DEFAULT true, + status TEXT NOT NULL DEFAULT 'ready', max_context_load_time INTEGER NOT NULL DEFAULT 1440, language TEXT NOT NULL DEFAULT 'auto', allow_guest BOOLEAN NOT NULL DEFAULT false, @@ -117,7 +118,8 @@ CREATE TABLE IF NOT EXISTS bots ( metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT bots_type_check CHECK (type IN ('personal', 'public')) + CONSTRAINT bots_type_check CHECK (type IN ('personal', 'public')), + CONSTRAINT bots_status_check CHECK (status IN ('creating', 'ready', 'deleting')) ); CREATE INDEX IF NOT EXISTS idx_bots_owner_user_id ON bots(owner_user_id); @@ -148,85 +150,7 @@ CREATE TABLE IF NOT EXISTS mcp_connections ( CREATE INDEX IF NOT EXISTS idx_mcp_connections_bot_id ON mcp_connections(bot_id); --- chats: first-class conversation container -CREATE TABLE IF NOT EXISTS chats ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - kind TEXT NOT NULL CHECK (kind IN ('direct', 'group', 'thread')), - parent_chat_id UUID REFERENCES chats(id) ON DELETE CASCADE, - title TEXT, - created_by_user_id UUID REFERENCES users(id), - metadata JSONB NOT NULL DEFAULT '{}'::jsonb, - enable_chat_memory BOOLEAN NOT NULL DEFAULT true, - enable_private_memory BOOLEAN NOT NULL DEFAULT true, - enable_public_memory BOOLEAN NOT NULL DEFAULT false, - model_id TEXT, - settings_metadata JSONB NOT NULL DEFAULT '{}'::jsonb, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now() -); - -CREATE INDEX IF NOT EXISTS idx_chats_bot_id ON chats(bot_id); -CREATE INDEX IF NOT EXISTS idx_chats_parent ON chats(parent_chat_id); - --- chat_participants: chat membership -CREATE TABLE IF NOT EXISTS chat_participants ( - chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, - user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, - role TEXT NOT NULL DEFAULT 'member' CHECK (role IN ('owner', 'admin', 'member')), - joined_at TIMESTAMPTZ NOT NULL DEFAULT now(), - PRIMARY KEY (chat_id, user_id) -); - -CREATE INDEX IF NOT EXISTS idx_chat_participants_user ON chat_participants(user_id); - --- chat_messages: per-message storage (replaces history) -CREATE TABLE IF NOT EXISTS chat_messages ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - route_id UUID, - sender_channel_identity_id UUID REFERENCES channel_identities(id), - sender_user_id UUID REFERENCES users(id), - platform TEXT, - external_message_id TEXT, - role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system', 'tool')), - content JSONB NOT NULL, - metadata JSONB NOT NULL DEFAULT '{}'::jsonb, - created_at TIMESTAMPTZ NOT NULL DEFAULT now() -); - --- Backfill newly introduced columns for existing deployments where chat_messages --- was created before route/platform traceability fields were added. -ALTER TABLE IF EXISTS chat_messages - ADD COLUMN IF NOT EXISTS route_id UUID; - -ALTER TABLE IF EXISTS chat_messages - ADD COLUMN IF NOT EXISTS platform TEXT; - -ALTER TABLE IF EXISTS chat_messages - ADD COLUMN IF NOT EXISTS external_message_id TEXT; - -CREATE INDEX IF NOT EXISTS idx_chat_messages_chat_created ON chat_messages(chat_id, created_at); -CREATE INDEX IF NOT EXISTS idx_chat_messages_bot ON chat_messages(bot_id); -CREATE INDEX IF NOT EXISTS idx_chat_messages_route ON chat_messages(route_id); -CREATE INDEX IF NOT EXISTS idx_chat_messages_external_lookup - ON chat_messages(platform, external_message_id); - --- chat_channel_identity_presence: derived cache of channel identities observed in chats -CREATE TABLE IF NOT EXISTS chat_channel_identity_presence ( - chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, - channel_identity_id UUID NOT NULL REFERENCES channel_identities(id) ON DELETE CASCADE, - first_seen_at TIMESTAMPTZ NOT NULL DEFAULT now(), - last_seen_at TIMESTAMPTZ NOT NULL DEFAULT now(), - message_count BIGINT NOT NULL DEFAULT 1, - PRIMARY KEY (chat_id, channel_identity_id) -); - -CREATE INDEX IF NOT EXISTS idx_chat_channel_identity_presence_identity_last_seen - ON chat_channel_identity_presence(channel_identity_id, last_seen_at DESC); -CREATE INDEX IF NOT EXISTS idx_chat_channel_identity_presence_chat_last_seen - ON chat_channel_identity_presence(chat_id, last_seen_at DESC); +-- Bot history is bot-scoped (one history container per bot). CREATE TABLE IF NOT EXISTS bot_channel_configs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -269,7 +193,7 @@ CREATE TABLE IF NOT EXISTS channel_identity_bind_codes ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), token TEXT NOT NULL, issued_by_user_id UUID NOT NULL REFERENCES users(id), - platform TEXT, + channel_type TEXT, expires_at TIMESTAMPTZ, used_at TIMESTAMPTZ, used_by_channel_identity_id UUID REFERENCES channel_identities(id), @@ -277,27 +201,48 @@ CREATE TABLE IF NOT EXISTS channel_identity_bind_codes ( CONSTRAINT channel_identity_bind_codes_token_unique UNIQUE (token) ); -CREATE INDEX IF NOT EXISTS idx_channel_identity_bind_codes_platform ON channel_identity_bind_codes(platform); +CREATE INDEX IF NOT EXISTS idx_channel_identity_bind_codes_channel_type ON channel_identity_bind_codes(channel_type); --- chat_routes: routing mapping (replaces channel_sessions) -CREATE TABLE IF NOT EXISTS chat_routes ( +-- bot_channel_routes: route mapping for inbound channel threads to bot history. +CREATE TABLE IF NOT EXISTS bot_channel_routes ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - platform TEXT NOT NULL, + channel_type TEXT NOT NULL, channel_config_id UUID REFERENCES bot_channel_configs(id) ON DELETE SET NULL, - conversation_id TEXT NOT NULL, - thread_id TEXT, - reply_target TEXT, + external_conversation_id TEXT NOT NULL, + external_thread_id TEXT, + default_reply_target TEXT, metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now() ); -CREATE UNIQUE INDEX IF NOT EXISTS idx_chat_routes_unique - ON chat_routes (bot_id, platform, conversation_id, COALESCE(thread_id, '')); -CREATE INDEX IF NOT EXISTS idx_chat_routes_chat ON chat_routes(chat_id); -CREATE INDEX IF NOT EXISTS idx_chat_routes_bot ON chat_routes(bot_id); +CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_channel_routes_unique + ON bot_channel_routes (bot_id, channel_type, external_conversation_id, COALESCE(external_thread_id, '')); +CREATE INDEX IF NOT EXISTS idx_bot_channel_routes_bot ON bot_channel_routes(bot_id); + +-- bot_history_messages: unified message history under bot scope. +CREATE TABLE IF NOT EXISTS bot_history_messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, + route_id UUID REFERENCES bot_channel_routes(id) ON DELETE SET NULL, + sender_channel_identity_id UUID REFERENCES channel_identities(id), + sender_account_user_id UUID REFERENCES users(id), + channel_type TEXT, + source_message_id TEXT, + source_reply_to_message_id TEXT, + role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system', 'tool')), + content JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_bot_created ON bot_history_messages(bot_id, created_at); +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_route ON bot_history_messages(route_id); +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_source_lookup + ON bot_history_messages(channel_type, source_message_id); +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_reply_lookup + ON bot_history_messages(channel_type, source_reply_to_message_id); CREATE TABLE IF NOT EXISTS containers ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), diff --git a/db/queries/bind.sql b/db/queries/bind.sql index 01233862..a2f30218 100644 --- a/db/queries/bind.sql +++ b/db/queries/bind.sql @@ -1,15 +1,15 @@ -- name: CreateBindCode :one -INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, platform, expires_at) +INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, channel_type, expires_at) VALUES ($1, $2, $3, $4) -RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at; +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at; -- name: GetBindCode :one -SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at FROM channel_identity_bind_codes WHERE token = $1; -- name: GetBindCodeForUpdate :one -SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at FROM channel_identity_bind_codes WHERE token = $1 FOR UPDATE; @@ -19,4 +19,4 @@ UPDATE channel_identity_bind_codes SET used_at = now(), used_by_channel_identity_id = $2 WHERE id = $1 AND used_at IS NULL -RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at; +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at; diff --git a/db/queries/bots.sql b/db/queries/bots.sql index 2a16131d..6194f319 100644 --- a/db/queries/bots.sql +++ b/db/queries/bots.sql @@ -1,21 +1,21 @@ -- name: CreateBot :one -INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; +INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata, status) +VALUES ($1, $2, $3, $4, $5, $6, $7) +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; -- name: GetBotByID :one -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE id = $1; -- name: ListBotsByOwner :many -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE owner_user_id = $1 ORDER BY created_at DESC; -- name: ListBotsByMember :many -SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at +SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.status, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at FROM bots b JOIN bot_members m ON m.bot_id = b.id WHERE m.user_id = $1 @@ -29,14 +29,20 @@ SET display_name = $2, metadata = $5, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; -- name: UpdateBotOwner :one UPDATE bots SET owner_user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; + +-- name: UpdateBotStatus :exec +UPDATE bots +SET status = $2, + updated_at = now() +WHERE id = $1; -- name: DeleteBotByID :exec DELETE FROM bots WHERE id = $1; diff --git a/db/queries/channel_identities.sql b/db/queries/channel_identities.sql index 584e9116..8a761998 100644 --- a/db/queries/channel_identities.sql +++ b/db/queries/channel_identities.sql @@ -1,37 +1,37 @@ -- name: CreateChannelIdentity :one -INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) VALUES ($1, $2, $3, $4, $5) -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; -- name: GetChannelIdentityByID :one -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities WHERE id = $1; -- name: GetChannelIdentityByIDForUpdate :one -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities WHERE id = $1 FOR UPDATE; -- name: GetChannelIdentityByChannelSubject :one -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities -WHERE channel = $1 AND channel_subject_id = $2; +WHERE channel_type = $1 AND channel_subject_id = $2; -- name: UpsertChannelIdentityByChannelSubject :one -INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (channel, channel_subject_id) +ON CONFLICT (channel_type, channel_subject_id) DO UPDATE SET - display_name = EXCLUDED.display_name, + display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), channel_identities.display_name), metadata = EXCLUDED.metadata, user_id = COALESCE(channel_identities.user_id, EXCLUDED.user_id), updated_at = now() -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; -- name: ListChannelIdentitiesByUserID :many -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities WHERE user_id = $1 ORDER BY created_at DESC; @@ -40,10 +40,10 @@ ORDER BY created_at DESC; UPDATE channel_identities SET user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; -- name: ClearChannelIdentityLinkedUser :one UPDATE channel_identities SET user_id = NULL, updated_at = now() WHERE id = $1 -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; diff --git a/db/queries/channel_routes.sql b/db/queries/channel_routes.sql new file mode 100644 index 00000000..be84a19f --- /dev/null +++ b/db/queries/channel_routes.sql @@ -0,0 +1,87 @@ +-- name: CreateChatRoute :one +INSERT INTO bot_channel_routes ( + bot_id, channel_type, channel_config_id, external_conversation_id, external_thread_id, default_reply_target, metadata +) +VALUES ( + sqlc.arg(bot_id), + sqlc.arg(platform), + sqlc.narg(channel_config_id)::uuid, + sqlc.arg(conversation_id), + sqlc.narg(thread_id)::text, + sqlc.narg(reply_target)::text, + sqlc.arg(metadata) +) +RETURNING + id, + sqlc.arg(chat_id)::uuid AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at; + +-- name: FindChatRoute :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = $1 + AND channel_type = sqlc.arg(platform) + AND external_conversation_id = sqlc.arg(conversation_id) + AND COALESCE(external_thread_id, '') = COALESCE(sqlc.narg(thread_id), '') +LIMIT 1; + +-- name: GetChatRouteByID :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE id = $1; + +-- name: ListChatRoutes :many +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = sqlc.arg(chat_id) +ORDER BY created_at ASC; + +-- name: UpdateChatRouteReplyTarget :exec +UPDATE bot_channel_routes +SET default_reply_target = sqlc.arg(reply_target), updated_at = now() +WHERE id = sqlc.arg(id); + +-- name: DeleteChatRoute :exec +DELETE FROM bot_channel_routes +WHERE id = $1; diff --git a/db/queries/channels.sql b/db/queries/channels.sql index 99b1eb3f..da22f1c2 100644 --- a/db/queries/channels.sql +++ b/db/queries/channels.sql @@ -34,23 +34,23 @@ WHERE channel_type = $1 ORDER BY created_at DESC; -- name: GetUserChannelBinding :one -SELECT id, user_id, platform, config, created_at, updated_at +SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings -WHERE user_id = $1 AND platform = $2 +WHERE user_id = $1 AND channel_type = $2 LIMIT 1; -- name: UpsertUserChannelBinding :one -INSERT INTO user_channel_bindings (user_id, platform, config) +INSERT INTO user_channel_bindings (user_id, channel_type, config) VALUES ($1, $2, $3) -ON CONFLICT (user_id, platform) +ON CONFLICT (user_id, channel_type) DO UPDATE SET config = EXCLUDED.config, updated_at = now() -RETURNING id, user_id, platform, config, created_at, updated_at; +RETURNING id, user_id, channel_type, config, created_at, updated_at; -- name: ListUserChannelBindingsByPlatform :many -SELECT id, user_id, platform, config, created_at, updated_at +SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings -WHERE platform = $1 +WHERE channel_type = $1 ORDER BY created_at DESC; diff --git a/db/queries/chats.sql b/db/queries/chats.sql deleted file mode 100644 index 501249a2..00000000 --- a/db/queries/chats.sql +++ /dev/null @@ -1,214 +0,0 @@ --- name: CreateChat :one -INSERT INTO chats (bot_id, kind, parent_chat_id, title, created_by_user_id, metadata) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at; - --- name: GetChatByID :one -SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at -FROM chats -WHERE id = $1; - --- name: ListChatsByBotAndUser :many -SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.enable_chat_memory, c.enable_private_memory, c.enable_public_memory, c.model_id, c.settings_metadata, c.created_at, c.updated_at -FROM chats c -JOIN chat_participants cp ON cp.chat_id = c.id -WHERE c.bot_id = $1 AND cp.user_id = $2 -ORDER BY c.updated_at DESC; - --- name: ListVisibleChatsByBotAndUser :many -WITH participant_chats AS ( - SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, - 'participant'::text AS access_mode, - cp.role AS participant_role, - NULL::timestamptz AS last_observed_at - FROM chats c - JOIN chat_participants cp ON cp.chat_id = c.id - WHERE c.bot_id = $1 AND cp.user_id = $2 -), -observed_chats AS ( - SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, - 'channel_identity_observed'::text AS access_mode, - ''::text AS participant_role, - MAX(cap.last_seen_at) AS last_observed_at - FROM chats c - JOIN chat_channel_identity_presence cap ON cap.chat_id = c.id - JOIN channel_identities ci ON ci.id = cap.channel_identity_id - WHERE c.bot_id = $1 - AND ci.user_id = $2 - AND NOT EXISTS ( - SELECT 1 FROM chat_participants cp - WHERE cp.chat_id = c.id AND cp.user_id = $2 - ) - GROUP BY c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at -) -SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, - access_mode, participant_role, last_observed_at -FROM ( - SELECT * FROM participant_chats - UNION ALL - SELECT * FROM observed_chats -) v -ORDER BY v.updated_at DESC, v.last_observed_at DESC NULLS LAST; - --- name: GetChatReadAccessByUser :one -WITH participant_access AS ( - SELECT 'participant'::text AS access_mode, - cp.role AS participant_role, - NULL::timestamptz AS last_observed_at - FROM chat_participants cp - WHERE cp.chat_id = $1 AND cp.user_id = $2 -), -observed_access AS ( - SELECT 'channel_identity_observed'::text AS access_mode, - ''::text AS participant_role, - MAX(cap.last_seen_at) AS last_observed_at - FROM chat_channel_identity_presence cap - JOIN channel_identities ci ON ci.id = cap.channel_identity_id - WHERE cap.chat_id = $1 AND ci.user_id = $2 - GROUP BY cap.chat_id -), -all_access AS ( - SELECT * FROM participant_access - UNION ALL - SELECT * FROM observed_access -) -SELECT access_mode, participant_role, last_observed_at -FROM all_access -ORDER BY CASE WHEN access_mode = 'participant' THEN 0 ELSE 1 END, last_observed_at DESC NULLS LAST -LIMIT 1; - --- name: ListThreadsByParent :many -SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at -FROM chats -WHERE parent_chat_id = $1 AND kind = 'thread' -ORDER BY created_at DESC; - --- name: UpdateChatTitle :one -UPDATE chats SET title = $2, updated_at = now() -WHERE id = $1 -RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at; - --- name: TouchChat :exec -UPDATE chats SET updated_at = now() WHERE id = $1; - --- name: DeleteChat :exec -DELETE FROM chats WHERE id = $1; - --- chat_participants - --- name: AddChatParticipant :one -INSERT INTO chat_participants (chat_id, user_id, role) -VALUES ($1, $2, $3) -ON CONFLICT (chat_id, user_id) DO UPDATE SET role = EXCLUDED.role -RETURNING chat_id, user_id, role, joined_at; - --- name: GetChatParticipant :one -SELECT chat_id, user_id, role, joined_at -FROM chat_participants -WHERE chat_id = $1 AND user_id = $2; - --- name: ListChatParticipants :many -SELECT chat_id, user_id, role, joined_at -FROM chat_participants -WHERE chat_id = $1 -ORDER BY joined_at ASC; - --- name: RemoveChatParticipant :exec -DELETE FROM chat_participants WHERE chat_id = $1 AND user_id = $2; - --- name: CopyParticipantsToChat :exec -INSERT INTO chat_participants (chat_id, user_id, role) -SELECT $2, cp.user_id, cp.role FROM chat_participants cp WHERE cp.chat_id = $1 -ON CONFLICT (chat_id, user_id) DO NOTHING; - --- chat_settings - --- name: UpsertChatSettings :one -UPDATE chats -SET enable_chat_memory = $2, - enable_private_memory = $3, - enable_public_memory = $4, - model_id = $5, - settings_metadata = $6 -WHERE id = $1 -RETURNING id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at; - --- name: GetChatSettings :one -SELECT id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at -FROM chats -WHERE id = $1; - --- chat_routes - --- name: CreateChatRoute :one -INSERT INTO chat_routes (chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -RETURNING id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at; - --- name: FindChatRoute :one -SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at -FROM chat_routes -WHERE bot_id = $1 AND platform = $2 AND conversation_id = $3 - AND COALESCE(thread_id, '') = COALESCE(sqlc.narg('thread_id'), '') -LIMIT 1; - --- name: GetChatRouteByID :one -SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at -FROM chat_routes -WHERE id = $1; - --- name: ListChatRoutes :many -SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at -FROM chat_routes -WHERE chat_id = $1 -ORDER BY created_at ASC; - --- name: UpdateChatRouteReplyTarget :exec -UPDATE chat_routes SET reply_target = $2, updated_at = now() WHERE id = $1; - --- name: DeleteChatRoute :exec -DELETE FROM chat_routes WHERE id = $1; - --- chat_messages - --- name: CreateChatMessage :one -INSERT INTO chat_messages (chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) -RETURNING id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at; - --- name: UpsertChatChannelIdentityPresence :exec -INSERT INTO chat_channel_identity_presence (chat_id, channel_identity_id, first_seen_at, last_seen_at, message_count) -VALUES ($1, $2, now(), now(), 1) -ON CONFLICT (chat_id, channel_identity_id) -DO UPDATE SET - last_seen_at = now(), - message_count = chat_channel_identity_presence.message_count + 1; - --- name: ListChatMessages :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 -ORDER BY created_at ASC; - --- name: ListChatMessagesSince :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 AND created_at >= $2 -ORDER BY created_at ASC; - --- name: ListChatMessagesBefore :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 AND created_at < $2 -ORDER BY created_at DESC -LIMIT $3; - --- name: ListChatMessagesLatest :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 -ORDER BY created_at DESC -LIMIT $2; - --- name: DeleteChatMessagesByChat :exec -DELETE FROM chat_messages WHERE chat_id = $1; diff --git a/db/queries/containers.sql b/db/queries/containers.sql index 3e81c964..e3fba5ce 100644 --- a/db/queries/containers.sql +++ b/db/queries/containers.sql @@ -52,3 +52,6 @@ WHERE bot_id = sqlc.arg(bot_id); UPDATE containers SET status = 'stopped', last_stopped_at = now(), updated_at = now() WHERE bot_id = sqlc.arg(bot_id); + +-- name: ListAutoStartContainers :many +SELECT * FROM containers WHERE auto_start = true ORDER BY updated_at DESC; diff --git a/db/queries/conversations.sql b/db/queries/conversations.sql new file mode 100644 index 00000000..e3e25528 --- /dev/null +++ b/db/queries/conversations.sql @@ -0,0 +1,229 @@ +-- name: CreateChat :one +SELECT + b.id AS id, + b.id AS bot_id, + (COALESCE(NULLIF(sqlc.arg(kind)::text, ''), CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END))::text AS kind, + CASE WHEN sqlc.arg(kind) = 'thread' THEN sqlc.arg(parent_chat_id)::uuid ELSE NULL::uuid END AS parent_chat_id, + COALESCE(NULLIF(sqlc.arg(title)::text, ''), b.display_name) AS title, + COALESCE(sqlc.arg(created_by_user_id)::uuid, b.owner_user_id) AS created_by_user_id, + COALESCE(sqlc.arg(metadata)::jsonb, b.metadata) AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = sqlc.arg(bot_id) +LIMIT 1; + +-- name: GetChatByID :one +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1; + +-- name: ListChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id) +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = sqlc.arg(bot_id) + AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC; + +-- name: ListVisibleChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + b.created_at, + b.updated_at, + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = sqlc.arg(user_id) THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id) +WHERE b.id = sqlc.arg(bot_id) + AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC; + +-- name: GetChatReadAccessByUser :one +SELECT + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = sqlc.arg(user_id) THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id) +WHERE b.id = sqlc.arg(chat_id) + AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL) +LIMIT 1; + +-- name: ListThreadsByParent :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 + AND false +ORDER BY b.created_at DESC; + +-- name: UpdateChatTitle :one +UPDATE bots +SET display_name = sqlc.arg(title), + updated_at = now() +WHERE id = sqlc.arg(id) +RETURNING + id, + id AS bot_id, + CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + display_name AS title, + owner_user_id AS created_by_user_id, + metadata, + NULL::text AS model_id, + created_at, + updated_at; + +-- name: TouchChat :exec +UPDATE bots +SET updated_at = now() +WHERE id = sqlc.arg(chat_id); + +-- name: DeleteChat :exec +WITH deleted_messages AS ( + DELETE FROM bot_history_messages + WHERE bot_id = sqlc.arg(chat_id) +) +DELETE FROM bot_channel_routes bcr +WHERE bcr.bot_id = sqlc.arg(chat_id); + +-- chat_participants + +-- name: AddChatParticipant :one +INSERT INTO bot_members (bot_id, user_id, role) +VALUES (sqlc.arg(chat_id), sqlc.arg(user_id), sqlc.arg(role)) +ON CONFLICT (bot_id, user_id) DO UPDATE SET role = EXCLUDED.role +RETURNING bot_id AS chat_id, user_id, role, created_at AS joined_at; + +-- name: GetChatParticipant :one +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = sqlc.arg(chat_id) AND b.owner_user_id = sqlc.arg(user_id) +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = sqlc.arg(chat_id) AND bm.user_id = sqlc.arg(user_id) +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT * FROM owner_participant + UNION ALL + SELECT * FROM member_participant +) p +ORDER BY CASE WHEN role = 'owner' THEN 0 ELSE 1 END +LIMIT 1; + +-- name: ListChatParticipants :many +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = sqlc.arg(chat_id) +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = sqlc.arg(chat_id) + AND bm.user_id <> (SELECT owner_user_id FROM bots WHERE id = sqlc.arg(chat_id)) +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT * FROM owner_participant + UNION ALL + SELECT * FROM member_participant +) p +ORDER BY joined_at ASC; + +-- name: RemoveChatParticipant :exec +DELETE FROM bot_members +WHERE bot_id = sqlc.arg(chat_id) + AND user_id = sqlc.arg(user_id) + AND user_id <> (SELECT owner_user_id FROM bots WHERE id = sqlc.arg(chat_id)); + +-- name: CopyParticipantsToChat :exec +INSERT INTO bot_members (bot_id, user_id, role) +SELECT sqlc.arg(chat_id_2), bm.user_id, bm.role +FROM bot_members bm +WHERE bm.bot_id = sqlc.arg(chat_id) +ON CONFLICT (bot_id, user_id) DO NOTHING; + +-- chat_settings + +-- name: UpsertChatSettings :one +WITH resolved_model AS ( + SELECT id + FROM models + WHERE model_id = NULLIF(sqlc.narg(model_id)::text, '') + LIMIT 1 +), +updated AS ( + UPDATE bots + SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id), + updated_at = now() + WHERE bots.id = sqlc.arg(id) + RETURNING bots.id, bots.chat_model_id, bots.updated_at +) +SELECT + updated.id AS chat_id, + chat_models.model_id AS model_id, + updated.updated_at +FROM updated +LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id; + +-- name: GetChatSettings :one +SELECT + b.id AS chat_id, + chat_models.model_id AS model_id, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1; diff --git a/db/queries/messages.sql b/db/queries/messages.sql new file mode 100644 index 00000000..eaa5a02b --- /dev/null +++ b/db/queries/messages.sql @@ -0,0 +1,118 @@ +-- name: CreateMessage :one +INSERT INTO bot_history_messages ( + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id, + channel_type, + source_message_id, + source_reply_to_message_id, + role, + content, + metadata +) +VALUES ( + sqlc.arg(bot_id), + sqlc.narg(route_id)::uuid, + sqlc.narg(sender_channel_identity_id)::uuid, + sqlc.narg(sender_user_id)::uuid, + sqlc.narg(platform)::text, + sqlc.narg(external_message_id)::text, + sqlc.narg(source_reply_to_message_id)::text, + sqlc.arg(role), + sqlc.arg(content), + sqlc.arg(metadata) +) +RETURNING + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at; + +-- name: ListMessages :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) +ORDER BY created_at ASC; + +-- name: ListMessagesSince :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) + AND created_at >= sqlc.arg(created_at) +ORDER BY created_at ASC; + +-- name: ListMessagesBefore :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) + AND created_at < sqlc.arg(created_at) +ORDER BY created_at DESC +LIMIT sqlc.arg(max_count); + +-- name: ListMessagesLatest :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) +ORDER BY created_at DESC +LIMIT sqlc.arg(max_count); + +-- name: DeleteMessagesByBot :exec +DELETE FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id); diff --git a/internal/accounts/service.go b/internal/accounts/service.go index 787dc251..e683ef8a 100644 --- a/internal/accounts/service.go +++ b/internal/accounts/service.go @@ -8,11 +8,11 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "golang.org/x/crypto/bcrypt" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -44,7 +44,7 @@ func (s *Service) Get(ctx context.Context, userID string) (Account, error) { if s.queries == nil { return Account{}, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return Account{}, err } @@ -109,7 +109,7 @@ func (s *Service) IsAdmin(ctx context.Context, userID string) (bool, error) { if s.queries == nil { return false, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return false, err } @@ -157,7 +157,7 @@ func (s *Service) Create(ctx context.Context, userID string, req CreateAccountRe isActive = *req.IsActive } - pgUserID, err := parseUUID(userID) + pgUserID, err := db.ParseUUID(userID) if err != nil { return Account{}, err } @@ -205,7 +205,7 @@ func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAcco if !userRow.ID.Valid { return Account{}, fmt.Errorf("create user: invalid id") } - userID = uuid.UUID(userRow.ID.Bytes).String() + userID = userRow.ID.String() } return s.Create(ctx, userID, req) } @@ -215,7 +215,7 @@ func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAcco if s.queries == nil { return Account{}, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return Account{}, err } @@ -264,7 +264,7 @@ func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdatePr if s.queries == nil { return Account{}, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return Account{}, err } @@ -303,7 +303,7 @@ func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, n if strings.TrimSpace(newPassword) == "" { return fmt.Errorf("new password is required") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return err } @@ -339,7 +339,7 @@ func (s *Service) ResetPassword(ctx context.Context, userID, newPassword string) if strings.TrimSpace(newPassword) == "" { return fmt.Errorf("new password is required") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return err } @@ -409,7 +409,7 @@ func toAccount(row sqlc.User) Account { lastLogin = row.LastLoginAt.Time } return Account{ - ID: toUUIDString(row.ID), + ID: row.ID.String(), Username: username, Email: email, Role: fmt.Sprint(row.Role), @@ -422,24 +422,3 @@ func toAccount(row sqlc.User) Account { } } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} diff --git a/internal/bind/service.go b/internal/bind/service.go index c4a33dd3..a0c84188 100644 --- a/internal/bind/service.go +++ b/internal/bind/service.go @@ -14,6 +14,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -51,7 +52,7 @@ func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, tt ttl = defaultTTL } - pgUserID, err := parseUUID(issuedByUserID) + pgUserID, err := db.ParseUUID(issuedByUserID) if err != nil { return Code{}, fmt.Errorf("invalid user id: %w", err) } @@ -63,7 +64,7 @@ func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, tt row, err := s.queries.CreateBindCode(ctx, sqlc.CreateBindCodeParams{ Token: token, IssuedByUserID: pgUserID, - Platform: pgtype.Text{ + ChannelType: pgtype.Text{ String: normalizedPlatform, Valid: normalizedPlatform != "", }, @@ -116,7 +117,7 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri if sourceIdentityID == "" { return fmt.Errorf("channel identity id is required") } - pgSourceIdentityID, err := parseUUID(sourceIdentityID) + pgSourceIdentityID, err := db.ParseUUID(sourceIdentityID) if err != nil { return err } @@ -150,7 +151,7 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri if targetUserID == "" { return fmt.Errorf("bind code issuer user is missing") } - pgTargetUserID, err := parseUUID(targetUserID) + pgTargetUserID, err := db.ParseUUID(targetUserID) if err != nil { return err } @@ -168,7 +169,7 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri } return fmt.Errorf("reload source identity: %w", err) } - if sourceIdentity.UserID.Valid && uuidString(sourceIdentity.UserID) != targetUserID { + if sourceIdentity.UserID.Valid && sourceIdentity.UserID.String() != targetUserID { return ErrLinkConflict } if !sourceIdentity.UserID.Valid { @@ -205,13 +206,13 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri func toCode(row sqlc.ChannelIdentityBindCode) Code { c := Code{ - ID: uuidString(row.ID), + ID: row.ID.String(), Token: row.Token, - IssuedByUserID: uuidString(row.IssuedByUserID), + IssuedByUserID: row.IssuedByUserID.String(), CreatedAt: row.CreatedAt.Time, } - if row.Platform.Valid { - c.Platform = normalizePlatform(row.Platform.String) + if row.ChannelType.Valid { + c.Platform = normalizePlatform(row.ChannelType.String) } if row.ExpiresAt.Valid { c.ExpiresAt = row.ExpiresAt.Time @@ -220,32 +221,11 @@ func toCode(row sqlc.ChannelIdentityBindCode) Code { c.UsedAt = row.UsedAt.Time } if row.UsedByChannelIdentityID.Valid { - c.UsedByChannelIdentityID = uuidString(row.UsedByChannelIdentityID) + c.UsedByChannelIdentityID = row.UsedByChannelIdentityID.String() } return c } -func parseUUID(id string) (pgtype.UUID, error) { - trimmed := strings.TrimSpace(id) - if trimmed == "" { - return pgtype.UUID{}, fmt.Errorf("empty id") - } - var pgID pgtype.UUID - if err := pgID.Scan(trimmed); err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - return pgID, nil -} - -func uuidString(id pgtype.UUID) string { - if !id.Valid { - return "" - } - b := id.Bytes - return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", - b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) -} - func isUniqueViolation(err error) bool { var pgErr *pgconn.PgError if !errors.As(err, &pgErr) { diff --git a/internal/bind/service_integration_test.go b/internal/bind/service_integration_test.go index 02648e2c..b6f0f5ec 100644 --- a/internal/bind/service_integration_test.go +++ b/internal/bind/service_integration_test.go @@ -1,12 +1,10 @@ -//go:build ignore -// +build ignore - package bind_test import ( "context" "encoding/json" "errors" + "fmt" "log/slog" "os" "testing" @@ -16,12 +14,12 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/memohai/memoh/internal/bind" - "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) -func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { +func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *identities.Service, *bind.Service, func()) { t.Helper() dsn := os.Getenv("TEST_POSTGRES_DSN") @@ -41,13 +39,12 @@ func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.S queries := sqlc.New(pool) logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - channelIdentitySvc := channelidentities.NewService(logger, queries) + channelIdentitySvc := identities.NewService(logger, queries) bindSvc := bind.NewService(logger, pool, queries) - return queries, channelIdentitySvc, bindSvc, func() { pool.Close() } } -func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) { +func createUserForBind(ctx context.Context, queries *sqlc.Queries) (string, error) { row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ IsActive: true, Metadata: []byte("{}"), @@ -55,10 +52,10 @@ func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, if err != nil { return "", err } - return db.UUIDToString(row.ID), nil + return row.ID.String(), nil } -func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { +func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { pgOwnerID, err := db.ParseUUID(ownerUserID) if err != nil { return "", err @@ -77,7 +74,7 @@ func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserI if err != nil { return "", err } - return db.UUIDToString(row.ID), nil + return row.ID.String(), nil } func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { @@ -85,20 +82,16 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { defer cleanup() ctx := context.Background() - ownerUserID, err := createUserForBindTest(ctx, queries) + ownerUserID, err := createUserForBind(ctx, queries) if err != nil { t.Fatalf("create owner user failed: %v", err) } - sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, "feishu", fmt.Sprintf("bind-success-%d", time.Now().UnixNano()), "source") if err != nil { t.Fatalf("create source channel identity failed: %v", err) } - botID, err := createBotForBindTest(ctx, queries, ownerUserID) - if err != nil { - t.Fatalf("create bot failed: %v", err) - } - code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) + code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) if err != nil { t.Fatalf("issue bind code failed: %v", err) } @@ -135,27 +128,23 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { defer cleanup() ctx := context.Background() - ownerUserID, err := createUserForBindTest(ctx, queries) + ownerUserID, err := createUserForBind(ctx, queries) if err != nil { t.Fatalf("create owner user failed: %v", err) } - otherUserID, err := createUserForBindTest(ctx, queries) + otherUserID, err := createUserForBind(ctx, queries) if err != nil { t.Fatalf("create other user failed: %v", err) } - sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, "feishu", fmt.Sprintf("bind-rollback-%d", time.Now().UnixNano()), "source") if err != nil { t.Fatalf("create source channel identity failed: %v", err) } if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { t.Fatalf("pre-link source channel identity failed: %v", err) } - botID, err := createBotForBindTest(ctx, queries, ownerUserID) - if err != nil { - t.Fatalf("create bot failed: %v", err) - } - code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) + code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) if err != nil { t.Fatalf("issue bind code failed: %v", err) } @@ -171,3 +160,81 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { t.Fatal("expected used_at to remain empty when consume fails") } } + +func TestIntegrationConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { + t.Fatalf("consume bind code failed: %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if after.UsedAt.IsZero() { + t.Fatal("expected code used_at set after consume") + } + if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { + t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) + } + + linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != ownerUserID { + t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) + } +} + +func TestIntegrationConsumeConflictDoesNotMarkUsed(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + issuerUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create issuer user failed: %v", err) + } + otherUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create other user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-conflict-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { + t.Fatalf("pre-link source channelIdentity failed: %v", err) + } + code, err := bindSvc.Issue(ctx, issuerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { + t.Fatalf("expected ErrLinkConflict, got %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if !after.UsedAt.IsZero() { + t.Fatal("expected code to remain unused after conflict") + } +} diff --git a/internal/bind/service_link_integration_test.go b/internal/bind/service_link_integration_test.go deleted file mode 100644 index 05e88b63..00000000 --- a/internal/bind/service_link_integration_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package bind_test - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "os" - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - - "github.com/memohai/memoh/internal/bind" - "github.com/memohai/memoh/internal/channelidentities" - "github.com/memohai/memoh/internal/db" - "github.com/memohai/memoh/internal/db/sqlc" -) - -func setupBindLinkIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { - t.Helper() - - dsn := os.Getenv("TEST_POSTGRES_DSN") - if dsn == "" { - t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") - } - - ctx := context.Background() - pool, err := pgxpool.New(ctx, dsn) - if err != nil { - t.Skipf("skip integration test: cannot connect to database: %v", err) - } - if err := pool.Ping(ctx); err != nil { - pool.Close() - t.Skipf("skip integration test: database ping failed: %v", err) - } - - queries := sqlc.New(pool) - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - channelIdentitySvc := channelidentities.NewService(logger, queries) - bindSvc := bind.NewService(logger, pool, queries) - return queries, channelIdentitySvc, bindSvc, func() { pool.Close() } -} - -func createUserForBind(ctx context.Context, queries *sqlc.Queries) (string, error) { - row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ - IsActive: true, - Metadata: []byte("{}"), - }) - if err != nil { - return "", err - } - return db.UUIDToString(row.ID), nil -} - -func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { - pgOwnerID, err := db.ParseUUID(ownerUserID) - if err != nil { - return "", err - } - meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"}) - if err != nil { - return "", err - } - row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ - OwnerUserID: pgOwnerID, - Type: "personal", - DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, - IsActive: true, - Metadata: meta, - }) - if err != nil { - return "", err - } - return db.UUIDToString(row.ID), nil -} - -func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { - queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t) - defer cleanup() - - ctx := context.Background() - ownerUserID, err := createUserForBind(ctx, queries) - if err != nil { - t.Fatalf("create owner user failed: %v", err) - } - sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source") - if err != nil { - t.Fatalf("create source channelIdentity failed: %v", err) - } - code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) - if err != nil { - t.Fatalf("issue bind code failed: %v", err) - } - if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { - t.Fatalf("consume bind code failed: %v", err) - } - - after, err := bindSvc.Get(ctx, code.Token) - if err != nil { - t.Fatalf("get bind code failed: %v", err) - } - if after.UsedAt.IsZero() { - t.Fatal("expected code used_at set after consume") - } - if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { - t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) - } - - linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) - if err != nil { - t.Fatalf("get linked user failed: %v", err) - } - if linkedUserID != ownerUserID { - t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) - } -} - -func TestBindConsumeConflictDoesNotMarkUsed(t *testing.T) { - queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t) - defer cleanup() - - ctx := context.Background() - issuerUserID, err := createUserForBind(ctx, queries) - if err != nil { - t.Fatalf("create issuer user failed: %v", err) - } - otherUserID, err := createUserForBind(ctx, queries) - if err != nil { - t.Fatalf("create other user failed: %v", err) - } - sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-conflict-%d", time.Now().UnixNano()), "source") - if err != nil { - t.Fatalf("create source channelIdentity failed: %v", err) - } - if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { - t.Fatalf("pre-link source channelIdentity failed: %v", err) - } - code, err := bindSvc.Issue(ctx, issuerUserID, "feishu", 10*time.Minute) - if err != nil { - t.Fatalf("issue bind code failed: %v", err) - } - if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { - t.Fatalf("expected ErrLinkConflict, got %v", err) - } - - after, err := bindSvc.Get(ctx, code.Token) - if err != nil { - t.Fatalf("get bind code failed: %v", err) - } - if !after.UsedAt.IsZero() { - t.Fatal("expected code to remain unused after conflict") - } -} diff --git a/internal/bind/service_test.go b/internal/bind/service_test.go new file mode 100644 index 00000000..298a744b --- /dev/null +++ b/internal/bind/service_test.go @@ -0,0 +1,207 @@ +package bind + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func TestParseUUID(t *testing.T) { + tests := []struct { + name string + id string + wantErr bool + }{ + {"empty", "", true}, + {"blank", " ", true}, + {"invalid", "not-a-uuid", true}, + {"valid", "550e8400-e29b-41d4-a716-446655440000", false}, + {"valid with spaces", " 550e8400-e29b-41d4-a716-446655440000 ", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := db.ParseUUID(tt.id) + if (err != nil) != tt.wantErr { + t.Errorf("parseUUID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !got.Valid { + t.Error("expected valid UUID") + } + }) + } +} + +func TestIsUniqueViolation(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"other error", assertAnError(), false}, + {"unique violation token", &pgconn.PgError{Code: "23505", ConstraintName: "channel_identity_bind_codes_token_unique"}, true}, + {"unique violation empty constraint", &pgconn.PgError{Code: "23505", ConstraintName: ""}, true}, + {"wrong code", &pgconn.PgError{Code: "23503", ConstraintName: "some_fk"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isUniqueViolation(tt.err); got != tt.want { + t.Errorf("isUniqueViolation() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNormalizePlatform(t *testing.T) { + tests := []struct { + raw string + want string + }{ + {"", ""}, + {" Feishu ", "feishu"}, + {"TELEGRAM", "telegram"}, + } + for _, tt := range tests { + if got := normalizePlatform(tt.raw); got != tt.want { + t.Errorf("normalizePlatform(%q) = %q, want %q", tt.raw, got, tt.want) + } + } +} + +func TestToCode(t *testing.T) { + pgID, err := db.ParseUUID("550e8400-e29b-41d4-a716-446655440000") + if err != nil { + t.Fatal(err) + } + now := time.Now().UTC() + var usedBy pgtype.UUID + _ = usedBy.Scan("660e8400-e29b-41d4-a716-446655440001") + row := sqlc.ChannelIdentityBindCode{ + ID: pgID, + Token: "ABC12345", + IssuedByUserID: pgID, + ChannelType: pgtype.Text{String: " Feishu ", Valid: true}, + ExpiresAt: pgtype.Timestamptz{Time: now, Valid: true}, + UsedAt: pgtype.Timestamptz{Time: now, Valid: true}, + UsedByChannelIdentityID: usedBy, + CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, + } + + c := toCode(row) + if c.Token != "ABC12345" { + t.Errorf("Token = %q", c.Token) + } + if c.Platform != "feishu" { + t.Errorf("Platform = %q (normalized)", c.Platform) + } + if c.IssuedByUserID != "550e8400-e29b-41d4-a716-446655440000" { + t.Errorf("IssuedByUserID = %q", c.IssuedByUserID) + } + if c.ExpiresAt.IsZero() { + t.Error("ExpiresAt should be set") + } + if c.UsedAt.IsZero() { + t.Error("UsedAt should be set") + } + if c.CreatedAt.IsZero() { + t.Error("CreatedAt should be set") + } +} + +func TestToCode_OptionalFields(t *testing.T) { + pgID, err := db.ParseUUID("550e8400-e29b-41d4-a716-446655440000") + if err != nil { + t.Fatal(err) + } + now := time.Now().UTC() + row := sqlc.ChannelIdentityBindCode{ + ID: pgID, + Token: "TOKEN", + IssuedByUserID: pgID, + ChannelType: pgtype.Text{Valid: false}, + ExpiresAt: pgtype.Timestamptz{Valid: false}, + UsedAt: pgtype.Timestamptz{Valid: false}, + CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, + } + c := toCode(row) + if c.Platform != "" { + t.Errorf("Platform should be empty, got %q", c.Platform) + } + if !c.ExpiresAt.IsZero() { + t.Error("ExpiresAt should be zero") + } + if !c.UsedAt.IsZero() { + t.Error("UsedAt should be zero") + } +} + +func assertAnError() error { + return errForTest +} + +var errForTest = errTyp{msg: "test error"} + +type errTyp struct{ msg string } + +func (e errTyp) Error() string { return e.msg } + +func TestService_Issue_NilQueries(t *testing.T) { + svc := NewService(nil, nil, nil) + _, err := svc.Issue(context.Background(), "550e8400-e29b-41d4-a716-446655440000", "feishu", time.Hour) + if err == nil { + t.Fatal("expected error when queries nil") + } +} + +func TestService_Issue_InvalidUserID(t *testing.T) { + svc := NewService(nil, nil, nil) + _, err := svc.Issue(context.Background(), "invalid", "feishu", time.Hour) + if err == nil { + t.Fatal("expected error for invalid user id") + } +} + +func TestService_Get_NilQueries(t *testing.T) { + svc := NewService(nil, nil, nil) + _, err := svc.Get(context.Background(), "TOKEN") + if err == nil { + t.Fatal("expected error when queries nil") + } +} + +func TestService_Consume_NilConfig(t *testing.T) { + svc := NewService(nil, nil, nil) + code := Code{Token: "ABC", IssuedByUserID: "550e8400-e29b-41d4-a716-446655440000"} + err := svc.Consume(context.Background(), code, "660e8400-e29b-41d4-a716-446655440001") + if err == nil { + t.Fatal("expected error when service not configured") + } +} + +// Consume fast-path (CodeUsed, CodeExpired, EmptyToken) runs after nil check; covered by integration tests. + +func TestService_Consume_EmptyChannelIdentityID(t *testing.T) { + svc := NewService(nil, nil, nil) + code := Code{Token: "ABC", IssuedByUserID: "550e8400-e29b-41d4-a716-446655440000"} + err := svc.Consume(context.Background(), code, "") + if err == nil { + t.Fatal("expected error when channel identity id empty") + } +} + +func TestService_Consume_InvalidChannelIdentityID(t *testing.T) { + svc := NewService(nil, nil, nil) + code := Code{Token: "ABC", IssuedByUserID: "550e8400-e29b-41d4-a716-446655440000"} + err := svc.Consume(context.Background(), code, "not-a-uuid") + if err == nil { + t.Fatal("expected error for invalid channel identity id") + } +} + diff --git a/internal/bots/service.go b/internal/bots/service.go index 8c358fe9..4c3dc3b6 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "os" "strings" "time" @@ -13,6 +14,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -23,6 +25,10 @@ type Service struct { containerLifecycle ContainerLifecycle } +const ( + botLifecycleOperationTimeout = 5 * time.Minute +) + var ( ErrBotNotFound = errors.New("bot not found") ErrBotAccessDenied = errors.New("bot access denied") @@ -82,7 +88,7 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR if ownerID == "" { return Bot{}, fmt.Errorf("owner user id is required") } - ownerUUID, err := parseUUID(ownerID) + ownerUUID, err := db.ParseUUID(ownerID) if err != nil { return Bot{}, err } @@ -117,6 +123,7 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR AvatarUrl: pgtype.Text{String: avatarURL, Valid: avatarURL != ""}, IsActive: isActive, Metadata: payload, + Status: BotStatusCreating, }) if err != nil { return Bot{}, err @@ -125,14 +132,10 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR if err != nil { return Bot{}, err } - if s.containerLifecycle != nil { - if err := s.containerLifecycle.SetupBotContainer(ctx, bot.ID); err != nil { - s.logger.Error("failed to setup bot container", - slog.String("bot_id", bot.ID), - slog.Any("error", err), - ) - } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err } + s.enqueueCreateLifecycle(bot.ID) return bot, nil } @@ -141,7 +144,7 @@ func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return Bot{}, err } @@ -149,7 +152,14 @@ func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { if err != nil { return Bot{}, err } - return toBot(row) + bot, err := toBot(row) + if err != nil { + return Bot{}, err + } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err + } + return bot, nil } // ListByOwner returns bots owned by the given user. @@ -157,7 +167,7 @@ func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, e if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") } - ownerUUID, err := parseUUID(ownerUserID) + ownerUUID, err := db.ParseUUID(ownerUserID) if err != nil { return nil, err } @@ -171,6 +181,9 @@ func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, e if err != nil { return nil, err } + if err := s.attachCheckSummary(ctx, &item, row); err != nil { + return nil, err + } items = append(items, item) } return items, nil @@ -181,7 +194,7 @@ func (s *Service) ListByMember(ctx context.Context, channelIdentityID string) ([ if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") } - memberUUID, err := parseUUID(channelIdentityID) + memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { return nil, err } @@ -195,6 +208,9 @@ func (s *Service) ListByMember(ctx context.Context, channelIdentityID string) ([ if err != nil { return nil, err } + if err := s.attachCheckSummary(ctx, &item, row); err != nil { + return nil, err + } items = append(items, item) } return items, nil @@ -231,7 +247,7 @@ func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return Bot{}, err } @@ -275,7 +291,14 @@ func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest if err != nil { return Bot{}, err } - return toBot(row) + bot, err := toBot(row) + if err != nil { + return Bot{}, err + } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err + } + return bot, nil } // TransferOwner transfers bot ownership to another user. @@ -283,11 +306,11 @@ func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID s if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return Bot{}, err } - ownerUUID, err := parseUUID(ownerUserID) + ownerUUID, err := db.ParseUUID(ownerUserID) if err != nil { return Bot{}, err } @@ -301,7 +324,14 @@ func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID s if err != nil { return Bot{}, err } - return toBot(row) + bot, err := toBot(row) + if err != nil { + return Bot{}, err + } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err + } + return bot, nil } // Delete removes a bot and its associated resources. @@ -309,25 +339,112 @@ func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { return fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return err } - if _, err := s.queries.GetBotByID(ctx, botUUID); err != nil { + row, err := s.queries.GetBotByID(ctx, botUUID) + if err != nil { return err } - if s.containerLifecycle != nil { - s.logger.Info("cleaning up bot container before deletion", slog.String("bot_id", botID)) - if err := s.containerLifecycle.CleanupBotContainer(ctx, botID); err != nil { - s.logger.Error("failed to cleanup bot container", + if strings.TrimSpace(row.Status) == BotStatusDeleting { + return nil + } + if err := s.queries.UpdateBotStatus(ctx, sqlc.UpdateBotStatusParams{ + ID: botUUID, + Status: BotStatusDeleting, + }); err != nil { + return err + } + s.enqueueDeleteLifecycle(botID) + return nil +} + +// ListChecks evaluates runtime resource checks for a bot. +func (s *Service) ListChecks(ctx context.Context, botID string) ([]BotCheck, error) { + if s.queries == nil { + return nil, fmt.Errorf("bot queries not configured") + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return nil, err + } + row, err := s.queries.GetBotByID(ctx, botUUID) + if err != nil { + return nil, err + } + return s.buildRuntimeChecks(ctx, row) +} + +func (s *Service) enqueueCreateLifecycle(botID string) { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), botLifecycleOperationTimeout) + defer cancel() + + if s.containerLifecycle != nil { + if err := s.containerLifecycle.SetupBotContainer(ctx, botID); err != nil { + s.logger.Error("bot container setup failed", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + } + } + + if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil { + s.logger.Error("failed to update bot status to ready after create", slog.String("bot_id", botID), slog.Any("error", err), ) } - } else { - s.logger.Warn("container lifecycle not configured, skipping container cleanup", slog.String("bot_id", botID)) + }() +} + +func (s *Service) enqueueDeleteLifecycle(botID string) { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), botLifecycleOperationTimeout) + defer cancel() + + if s.containerLifecycle != nil { + if err := s.containerLifecycle.CleanupBotContainer(ctx, botID); err != nil { + s.logger.Error("bot container cleanup failed", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + } + } + + botUUID, err := db.ParseUUID(botID) + if err != nil { + s.logger.Error("invalid bot id while finalizing delete", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + _ = s.updateStatus(ctx, botID, BotStatusReady) + return + } + if err := s.queries.DeleteBotByID(ctx, botUUID); err != nil { + s.logger.Error("failed to delete bot after cleanup", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + _ = s.updateStatus(ctx, botID, BotStatusReady) + return + } + }() +} + +func (s *Service) updateStatus(ctx context.Context, botID, status string) error { + if s.queries == nil { + return fmt.Errorf("bot queries not configured") } - return s.queries.DeleteBotByID(ctx, botUUID) + botUUID, err := db.ParseUUID(botID) + if err != nil { + return err + } + return s.queries.UpdateBotStatus(ctx, sqlc.UpdateBotStatusParams{ + ID: botUUID, + Status: strings.TrimSpace(status), + }) } func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) error { @@ -349,11 +466,11 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb if s.queries == nil { return BotMember{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return BotMember{}, err } - memberUUID, err := parseUUID(req.UserID) + memberUUID, err := db.ParseUUID(req.UserID) if err != nil { return BotMember{}, err } @@ -377,7 +494,7 @@ func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, e if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return nil, err } @@ -397,11 +514,11 @@ func (s *Service) GetMember(ctx context.Context, botID, channelIdentityID string if s.queries == nil { return BotMember{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return BotMember{}, err } - memberUUID, err := parseUUID(channelIdentityID) + memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { return BotMember{}, err } @@ -420,11 +537,11 @@ func (s *Service) DeleteMember(ctx context.Context, botID, channelIdentityID str if s.queries == nil { return fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return err } - memberUUID, err := parseUUID(channelIdentityID) + memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { return err } @@ -504,15 +621,18 @@ func toBot(row sqlc.Bot) (Bot, error) { updatedAt = row.UpdatedAt.Time } return Bot{ - ID: toUUIDString(row.ID), - OwnerUserID: toUUIDString(row.OwnerUserID), - Type: row.Type, - DisplayName: displayName, - AvatarURL: avatarURL, - IsActive: row.IsActive, - Metadata: metadata, - CreatedAt: createdAt, - UpdatedAt: updatedAt, + ID: row.ID.String(), + OwnerUserID: row.OwnerUserID.String(), + Type: row.Type, + DisplayName: displayName, + AvatarURL: avatarURL, + IsActive: row.IsActive, + Status: strings.TrimSpace(row.Status), + CheckState: BotCheckStateUnknown, + CheckIssueCount: 0, + Metadata: metadata, + CreatedAt: createdAt, + UpdatedAt: updatedAt, }, nil } @@ -522,8 +642,8 @@ func toBotMember(row sqlc.BotMember) BotMember { createdAt = row.CreatedAt.Time } return BotMember{ - BotID: toUUIDString(row.BotID), - UserID: toUUIDString(row.UserID), + BotID: row.BotID.String(), + UserID: row.UserID.String(), Role: row.Role, CreatedAt: createdAt, } @@ -543,24 +663,197 @@ func decodeMetadata(payload []byte) (map[string]any, error) { return data, nil } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) +func (s *Service) attachCheckSummary(ctx context.Context, bot *Bot, row sqlc.Bot) error { + checks, err := s.buildRuntimeChecks(ctx, row) if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) + return err } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil + checkState, issueCount := summarizeChecks(checks) + bot.CheckState = checkState + bot.CheckIssueCount = issueCount + return nil } -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" +func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot) ([]BotCheck, error) { + status := strings.TrimSpace(row.Status) + checks := make([]BotCheck, 0, 4) + + if status == BotStatusCreating { + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerInit, + Status: BotCheckStatusUnknown, + Summary: "Initialization is in progress.", + Detail: "Bot resources are still being provisioned.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusUnknown, + Summary: "Container record is pending.", + Detail: "Container record will be checked after initialization.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusUnknown, + Summary: "Container task state is pending.", + Detail: "Task state will be checked after initialization.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusUnknown, + Summary: "Container host path check is pending.", + Detail: "Data path will be checked after initialization.", + }) + return checks, nil } - parsed, err := uuid.FromBytes(value.Bytes[:]) + if status == BotStatusDeleting { + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyDelete, + Status: BotCheckStatusUnknown, + Summary: "Deletion is in progress.", + Detail: "Bot resources are being cleaned up.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusUnknown, + Summary: "Container record check is skipped.", + Detail: "Bot is deleting and container checks are paused.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusUnknown, + Summary: "Container task check is skipped.", + Detail: "Bot is deleting and task checks are paused.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusUnknown, + Summary: "Container host path check is skipped.", + Detail: "Bot is deleting and data path checks are paused.", + }) + return checks, nil + } + + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerInit, + Status: BotCheckStatusOK, + Summary: "Initialization finished.", + }) + + containerRow, err := s.queries.GetContainerByBotID(ctx, row.ID) if err != nil { - return "" + if errors.Is(err, pgx.ErrNoRows) { + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusError, + Summary: "Container record is missing.", + Detail: "No container is attached to this bot.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusUnknown, + Summary: "Container task state is unknown.", + Detail: "Task state cannot be determined without a container record.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusUnknown, + Summary: "Container data path is unknown.", + Detail: "Data path cannot be determined without a container record.", + }) + return checks, nil + } + return nil, err } - return parsed.String() + + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusOK, + Summary: "Container record exists.", + Detail: fmt.Sprintf("container_id=%s", strings.TrimSpace(containerRow.ContainerID)), + Metadata: map[string]any{ + "container_id": strings.TrimSpace(containerRow.ContainerID), + "namespace": strings.TrimSpace(containerRow.Namespace), + "image": strings.TrimSpace(containerRow.Image), + }, + }) + + taskStatus := strings.TrimSpace(strings.ToLower(containerRow.Status)) + taskCheck := BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusWarn, + Summary: "Container task state needs attention.", + } + switch taskStatus { + case "running", "created", "stopped", "paused": + taskCheck.Status = BotCheckStatusOK + taskCheck.Summary = "Container task state is reported." + taskCheck.Detail = fmt.Sprintf("status=%s", taskStatus) + case "": + taskCheck.Detail = "status is empty" + default: + taskCheck.Detail = fmt.Sprintf("unexpected status=%s", taskStatus) + } + taskCheck.Metadata = map[string]any{"status": taskStatus} + checks = append(checks, taskCheck) + + hostPath := "" + if containerRow.HostPath.Valid { + hostPath = strings.TrimSpace(containerRow.HostPath.String) + } + dataCheck := BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusWarn, + Summary: "Container host path needs attention.", + Metadata: map[string]any{"host_path": hostPath}, + } + if hostPath == "" { + dataCheck.Detail = "host path is empty" + checks = append(checks, dataCheck) + return checks, nil + } + info, statErr := os.Stat(hostPath) + switch { + case statErr == nil && info != nil && info.IsDir(): + dataCheck.Status = BotCheckStatusOK + dataCheck.Summary = "Container host path is accessible." + dataCheck.Detail = hostPath + case statErr == nil: + dataCheck.Status = BotCheckStatusError + dataCheck.Summary = "Container host path is invalid." + dataCheck.Detail = "host path is not a directory" + case errors.Is(statErr, os.ErrNotExist): + dataCheck.Status = BotCheckStatusError + dataCheck.Summary = "Container host path does not exist." + dataCheck.Detail = hostPath + default: + dataCheck.Status = BotCheckStatusWarn + dataCheck.Summary = "Container host path cannot be checked." + dataCheck.Detail = statErr.Error() + } + checks = append(checks, dataCheck) + + return checks, nil +} + +func summarizeChecks(checks []BotCheck) (string, int32) { + if len(checks) == 0 { + return BotCheckStateUnknown, 0 + } + var issueCount int32 + unknownCount := 0 + for _, check := range checks { + switch check.Status { + case BotCheckStatusWarn, BotCheckStatusError: + issueCount++ + case BotCheckStatusUnknown: + unknownCount++ + } + } + if issueCount > 0 { + return BotCheckStateIssue, issueCount + } + if unknownCount == len(checks) { + return BotCheckStateUnknown, 0 + } + return BotCheckStateOK, 0 } diff --git a/internal/bots/types.go b/internal/bots/types.go index 6d237746..e002524c 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -7,15 +7,18 @@ import ( // Bot represents a bot entity. type Bot struct { - ID string `json:"id" validate:"required"` - OwnerUserID string `json:"owner_user_id" validate:"required"` - Type string `json:"type" validate:"required"` - DisplayName string `json:"display_name" validate:"required"` - AvatarURL string `json:"avatar_url,omitempty"` - IsActive bool `json:"is_active" validate:"required"` - Metadata map[string]any `json:"metadata,omitempty"` - CreatedAt time.Time `json:"created_at" validate:"required"` - UpdatedAt time.Time `json:"updated_at" validate:"required"` + ID string `json:"id"` + OwnerUserID string `json:"owner_user_id"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url,omitempty"` + IsActive bool `json:"is_active"` + Status string `json:"status"` + CheckState string `json:"check_state"` + CheckIssueCount int32 `json:"check_issue_count"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // BotMember represents a bot membership record. @@ -26,6 +29,15 @@ type BotMember struct { CreatedAt time.Time `json:"created_at"` } +// BotCheck represents one resource check row for a bot. +type BotCheck struct { + CheckKey string `json:"check_key"` + Status string `json:"status"` + Summary string `json:"summary"` + Detail string `json:"detail,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + // CreateBotRequest is the input for creating a bot. type CreateBotRequest struct { Type string `json:"type"` @@ -56,7 +68,7 @@ type UpsertMemberRequest struct { // ListBotsResponse wraps a list of bots. type ListBotsResponse struct { - Items []Bot `json:"items" validate:"required"` + Items []Bot `json:"items"` } // ListMembersResponse wraps a list of bot members. @@ -64,6 +76,11 @@ type ListMembersResponse struct { Items []BotMember `json:"items"` } +// ListChecksResponse wraps a list of bot checks. +type ListChecksResponse struct { + Items []BotCheck `json:"items"` +} + // ContainerLifecycle handles container lifecycle events bound to bot operations. type ContainerLifecycle interface { SetupBotContainer(ctx context.Context, botID string) error @@ -75,6 +92,33 @@ const ( BotTypePublic = "public" ) +const ( + BotStatusCreating = "creating" + BotStatusReady = "ready" + BotStatusDeleting = "deleting" +) + +const ( + BotCheckStateOK = "ok" + BotCheckStateIssue = "issue" + BotCheckStateUnknown = "unknown" +) + +const ( + BotCheckStatusOK = "ok" + BotCheckStatusWarn = "warn" + BotCheckStatusError = "error" + BotCheckStatusUnknown = "unknown" +) + +const ( + BotCheckKeyContainerInit = "container.init" + BotCheckKeyContainerRecord = "container.record" + BotCheckKeyContainerTask = "container.task" + BotCheckKeyContainerData = "container.data_path" + BotCheckKeyDelete = "bot.delete" +) + const ( MemberRoleOwner = "owner" MemberRoleAdmin = "admin" diff --git a/internal/channel/adapter.go b/internal/channel/adapter.go index 9cb5f6b7..0b12cf07 100644 --- a/internal/channel/adapter.go +++ b/internal/channel/adapter.go @@ -12,9 +12,42 @@ var ErrStopNotSupported = errors.New("channel connection stop not supported") // InboundHandler is a callback invoked when a message arrives from a channel. type InboundHandler func(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error -// ReplySender sends an outbound reply within the scope of a single inbound message. -type ReplySender interface { +// StreamReplySender sends replies within a single inbound-processing scope. +// It supports both one-shot delivery and streaming sessions. +type StreamReplySender interface { Send(ctx context.Context, msg OutboundMessage) error + OpenStream(ctx context.Context, target string, opts StreamOptions) (OutboundStream, error) +} + +// OutboundStream is a live stream session for emitting outbound events. +type OutboundStream interface { + Push(ctx context.Context, event StreamEvent) error + Close(ctx context.Context) error +} + +// ProcessingStatusInfo carries context for channel-level processing status updates. +type ProcessingStatusInfo struct { + BotID string + ChatID string + RouteID string + ChannelIdentityID string + UserID string + Query string + ReplyTarget string + SourceMessageID string +} + +// ProcessingStatusHandle stores channel-specific state between status callbacks. +type ProcessingStatusHandle struct { + Token string +} + +// ProcessingStatusNotifier reports processing lifecycle updates to channel platforms. +// Implementations should be best-effort and idempotent. +type ProcessingStatusNotifier interface { + ProcessingStarted(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo) (ProcessingStatusHandle, error) + ProcessingCompleted(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle) error + ProcessingFailed(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle, cause error) error } // Adapter is the base interface every channel adapter must implement. @@ -59,6 +92,17 @@ type Sender interface { Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error } +// StreamSender is an adapter capable of opening outbound stream sessions. +type StreamSender interface { + OpenStream(ctx context.Context, cfg ChannelConfig, target string, opts StreamOptions) (OutboundStream, error) +} + +// MessageEditor updates and deletes already-sent messages when supported. +type MessageEditor interface { + Update(ctx context.Context, cfg ChannelConfig, target string, messageID string, msg Message) error + Unsend(ctx context.Context, cfg ChannelConfig, target string, messageID string) error +} + // Receiver is an adapter capable of establishing a long-lived connection to receive messages. type Receiver interface { Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) diff --git a/internal/channel/adapters/feishu/directory.go b/internal/channel/adapters/feishu/directory.go new file mode 100644 index 00000000..8493eed9 --- /dev/null +++ b/internal/channel/adapters/feishu/directory.go @@ -0,0 +1,298 @@ +package feishu + +import ( + "context" + "fmt" + "strings" + + lark "github.com/larksuite/oapi-sdk-go/v3" + larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + defaultDirectoryPageSize = 20 + maxDirectoryPageSize = 200 +) + +func directoryLimit(n int) int { + if n <= 0 { + return defaultDirectoryPageSize + } + if n > maxDirectoryPageSize { + return maxDirectoryPageSize + } + return n +} + +// ListPeers lists users (peers) from Feishu contact, optionally filtered by query. +func (a *FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + pageSize := directoryLimit(query.Limit) + req := larkcontact.NewListUserReqBuilder(). + UserIdType(larkcontact.UserIdTypeOpenId). + DepartmentIdType(larkcontact.DepartmentIdTypeOpenDepartmentId). + DepartmentId("0"). + PageSize(pageSize). + Build() + resp, err := client.Contact.User.List(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu list users: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu list users: code=%d msg=%s", resp.Code, resp.Msg) + } + entries := make([]channel.DirectoryEntry, 0, len(resp.Data.Items)) + for _, u := range resp.Data.Items { + e := feishuUserToEntry(u) + if query.Query != "" && !strings.Contains(strings.ToLower(e.Name+e.Handle), strings.ToLower(query.Query)) { + continue + } + entries = append(entries, e) + } + return entries, nil +} + +// ListGroups lists chat groups from Feishu IM, optionally filtered by query. +func (a *FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + pageSize := directoryLimit(query.Limit) + var items []*larkim.ListChat + if strings.TrimSpace(query.Query) != "" { + req := larkim.NewSearchChatReqBuilder(). + UserIdType("open_id"). + Query(strings.TrimSpace(query.Query)). + PageSize(pageSize). + Build() + resp, err := client.Im.Chat.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu search chats: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu search chats: code=%d msg=%s", resp.Code, resp.Msg) + } + items = resp.Data.Items + } else { + req := larkim.NewListChatReqBuilder(). + UserIdType("open_id"). + PageSize(pageSize). + Build() + resp, err := client.Im.Chat.List(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu list chats: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu list chats: code=%d msg=%s", resp.Code, resp.Msg) + } + items = resp.Data.Items + } + entries := make([]channel.DirectoryEntry, 0, len(items)) + for _, c := range items { + entries = append(entries, feishuChatToEntry(c)) + } + return entries, nil +} + +// ListGroupMembers lists members of a Feishu chat group. +func (a *FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + chatID := strings.TrimSpace(groupID) + if strings.HasPrefix(chatID, "chat_id:") { + chatID = strings.TrimPrefix(chatID, "chat_id:") + } + if chatID == "" { + return nil, fmt.Errorf("feishu list group members: empty group id") + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + pageSize := directoryLimit(query.Limit) + req := larkim.NewGetChatMembersReqBuilder(). + ChatId(chatID). + MemberIdType("open_id"). + PageSize(pageSize). + Build() + resp, err := client.Im.ChatMembers.Get(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu get chat members: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu get chat members: code=%d msg=%s", resp.Code, resp.Msg) + } + entries := make([]channel.DirectoryEntry, 0, len(resp.Data.Items)) + for _, m := range resp.Data.Items { + e := feishuMemberToEntry(m) + if query.Query != "" && !strings.Contains(strings.ToLower(e.Name+e.Handle), strings.ToLower(query.Query)) { + continue + } + entries = append(entries, e) + } + return entries, nil +} + +// ResolveEntry resolves an input string to a user or group DirectoryEntry. +func (a *FeishuAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return channel.DirectoryEntry{}, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + input = strings.TrimSpace(input) + switch kind { + case channel.DirectoryEntryUser: + return a.resolveUser(ctx, client, input) + case channel.DirectoryEntryGroup: + return a.resolveGroup(ctx, client, input) + default: + return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry: unsupported kind %q", kind) + } +} + +func (a *FeishuAdapter) resolveUser(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { + userID, userIDType := parseFeishuUserInput(input) + if userID == "" { + return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry user: invalid input %q", input) + } + req := larkcontact.NewGetUserReqBuilder(). + UserId(userID). + UserIdType(userIDType). + Build() + resp, err := client.Contact.User.Get(ctx, req) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: %w", err) + } + if !resp.Success() { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: code=%d msg=%s", resp.Code, resp.Msg) + } + if resp.Data == nil || resp.Data.User == nil { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: empty response") + } + return feishuUserToEntry(resp.Data.User), nil +} + +func (a *FeishuAdapter) resolveGroup(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { + chatID := strings.TrimSpace(input) + if strings.HasPrefix(chatID, "chat_id:") { + chatID = strings.TrimPrefix(chatID, "chat_id:") + } + if chatID == "" { + return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry group: invalid input %q", input) + } + req := larkim.NewGetChatReqBuilder(). + ChatId(chatID). + UserIdType("open_id"). + Build() + resp, err := client.Im.Chat.Get(ctx, req) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get chat: %w", err) + } + if !resp.Success() { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get chat: code=%d msg=%s", resp.Code, resp.Msg) + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryGroup, + ID: "chat_id:" + chatID, + Name: ptrStr(resp.Data.Name), + AvatarURL: ptrStr(resp.Data.Avatar), + Metadata: map[string]any{"chat_id": chatID}, + }, nil +} + +func parseFeishuUserInput(raw string) (userID, userIDType string) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "" + } + if strings.HasPrefix(raw, "open_id:") { + return strings.TrimSpace(strings.TrimPrefix(raw, "open_id:")), larkcontact.UserIdTypeOpenId + } + if strings.HasPrefix(raw, "user_id:") { + return strings.TrimSpace(strings.TrimPrefix(raw, "user_id:")), larkcontact.UserIdTypeUserId + } + if strings.HasPrefix(raw, "ou_") { + return raw, larkcontact.UserIdTypeOpenId + } + if strings.HasPrefix(raw, "u_") || strings.HasPrefix(raw, "u-") { + return raw, larkcontact.UserIdTypeUserId + } + return raw, larkcontact.UserIdTypeOpenId +} + +func feishuUserToEntry(u *larkcontact.User) channel.DirectoryEntry { + openID := ptrStr(u.OpenId) + userID := ptrStr(u.UserId) + id := "open_id:" + openID + if openID == "" && userID != "" { + id = "user_id:" + userID + } + meta := make(map[string]any) + if u.OpenId != nil { + meta["open_id"] = *u.OpenId + } + if u.UserId != nil { + meta["user_id"] = *u.UserId + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: id, + Name: ptrStr(u.Name), + Handle: ptrStr(u.Nickname), + AvatarURL: feishuAvatarURL(u.Avatar), + Metadata: meta, + } +} + +func feishuChatToEntry(c *larkim.ListChat) channel.DirectoryEntry { + chatID := ptrStr(c.ChatId) + meta := map[string]any{"chat_id": chatID} + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryGroup, + ID: "chat_id:" + chatID, + Name: ptrStr(c.Name), + AvatarURL: ptrStr(c.Avatar), + Metadata: meta, + } +} + +func feishuMemberToEntry(m *larkim.ListMember) channel.DirectoryEntry { + id := ptrStr(m.MemberId) + meta := make(map[string]any) + if m.MemberIdType != nil { + meta["member_id_type"] = *m.MemberIdType + } + prefix := "open_id:" + if m.MemberIdType != nil && *m.MemberIdType == "user_id" { + prefix = "user_id:" + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: prefix + id, + Name: ptrStr(m.Name), + Metadata: meta, + } +} + +func ptrStr(s *string) string { + if s == nil { + return "" + } + return strings.TrimSpace(*s) +} + +func feishuAvatarURL(avatar *larkcontact.AvatarInfo) string { + if avatar == nil || avatar.Avatar72 == nil { + return "" + } + return strings.TrimSpace(*avatar.Avatar72) +} diff --git a/internal/channel/adapters/feishu/directory_test.go b/internal/channel/adapters/feishu/directory_test.go new file mode 100644 index 00000000..caedadc0 --- /dev/null +++ b/internal/channel/adapters/feishu/directory_test.go @@ -0,0 +1,118 @@ +package feishu + +import ( + "testing" + + larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func Test_directoryLimit(t *testing.T) { + tests := []struct { + name string + n int + want int + }{ + {"zero", 0, defaultDirectoryPageSize}, + {"negative", -1, defaultDirectoryPageSize}, + {"one", 1, 1}, + {"default", defaultDirectoryPageSize, defaultDirectoryPageSize}, + {"over max", maxDirectoryPageSize + 100, maxDirectoryPageSize}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := directoryLimit(tt.n); got != tt.want { + t.Errorf("directoryLimit() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseFeishuUserInput(t *testing.T) { + tests := []struct { + raw string + wantID string + wantIDType string + }{ + {"open_id:ou_xxx", "ou_xxx", larkcontact.UserIdTypeOpenId}, + {"user_id:u_yyy", "u_yyy", larkcontact.UserIdTypeUserId}, + {"ou_abc", "ou_abc", larkcontact.UserIdTypeOpenId}, + {"u_123", "u_123", larkcontact.UserIdTypeUserId}, + {" open_id: ou_zzz ", "ou_zzz", larkcontact.UserIdTypeOpenId}, + {"", "", ""}, + } + for _, tt := range tests { + id, idType := parseFeishuUserInput(tt.raw) + if id != tt.wantID || idType != tt.wantIDType { + t.Errorf("parseFeishuUserInput(%q) = %q, %q; want %q, %q", tt.raw, id, idType, tt.wantID, tt.wantIDType) + } + } +} + +func Test_ptrStr(t *testing.T) { + s := "x" + if got := ptrStr(nil); got != "" { + t.Errorf("ptrStr(nil) = %q", got) + } + if got := ptrStr(&s); got != "x" { + t.Errorf("ptrStr(&s) = %q", got) + } + space := " a " + if got := ptrStr(&space); got != "a" { + t.Errorf("ptrStr(space) = %q", got) + } +} + +func Test_feishuUserToEntry(t *testing.T) { + openID := "ou_1" + name := "Alice" + u := &larkcontact.User{OpenId: &openID, Name: &name} + e := feishuUserToEntry(u) + if e.Kind != "user" || e.ID != "open_id:ou_1" || e.Name != "Alice" { + t.Errorf("feishuUserToEntry = %+v", e) + } + userID := "u_2" + u2 := &larkcontact.User{UserId: &userID, Name: &name} + e2 := feishuUserToEntry(u2) + if e2.ID != "user_id:u_2" { + t.Errorf("feishuUserToEntry user_id only = %+v", e2) + } +} + +func Test_feishuChatToEntry(t *testing.T) { + chatID := "oc_abc" + name := "Test Group" + c := &larkim.ListChat{ChatId: &chatID, Name: &name} + e := feishuChatToEntry(c) + if e.Kind != "group" || e.ID != "chat_id:oc_abc" || e.Name != "Test Group" { + t.Errorf("feishuChatToEntry = %+v", e) + } +} + +func Test_feishuMemberToEntry(t *testing.T) { + memberID := "ou_m1" + memberType := "open_id" + name := "Bob" + m := &larkim.ListMember{MemberId: &memberID, MemberIdType: &memberType, Name: &name} + e := feishuMemberToEntry(m) + if e.Kind != "user" || e.ID != "open_id:ou_m1" || e.Name != "Bob" { + t.Errorf("feishuMemberToEntry = %+v", e) + } + memberTypeUser := "user_id" + m2 := &larkim.ListMember{MemberId: &memberID, MemberIdType: &memberTypeUser, Name: &name} + e2 := feishuMemberToEntry(m2) + if e2.ID != "user_id:ou_m1" { + t.Errorf("feishuMemberToEntry user_id type = %+v", e2) + } +} + +func Test_feishuAvatarURL(t *testing.T) { + if got := feishuAvatarURL(nil); got != "" { + t.Errorf("feishuAvatarURL(nil) = %q", got) + } + url72 := "https://avatar.example/72.png" + a := &larkcontact.AvatarInfo{Avatar72: &url72} + if got := feishuAvatarURL(a); got != url72 { + t.Errorf("feishuAvatarURL = %q", got) + } +} diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index 22f3bc66..cca6d370 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -25,6 +25,75 @@ type FeishuAdapter struct { logger *slog.Logger } +const processingBusyReactionType = "Typing" + +type messageReactionAPI interface { + Create(ctx context.Context, req *larkim.CreateMessageReactionReq, options ...larkcore.RequestOptionFunc) (*larkim.CreateMessageReactionResp, error) + Delete(ctx context.Context, req *larkim.DeleteMessageReactionReq, options ...larkcore.RequestOptionFunc) (*larkim.DeleteMessageReactionResp, error) +} + +type processingReactionGateway interface { + Add(ctx context.Context, messageID, reactionType string) (string, error) + Remove(ctx context.Context, messageID, reactionID string) error +} + +type larkProcessingReactionGateway struct { + api messageReactionAPI +} + +func (g *larkProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { + if g == nil || g.api == nil { + return "", fmt.Errorf("feishu reaction api not configured") + } + req := larkim.NewCreateMessageReactionReqBuilder(). + MessageId(messageID). + Body(larkim.NewCreateMessageReactionReqBodyBuilder(). + ReactionType(larkim.NewEmojiBuilder().EmojiType(reactionType).Build()). + Build()). + Build() + resp, err := g.api.Create(ctx, req) + if err != nil { + return "", err + } + if resp == nil || !resp.Success() { + code := 0 + msg := "" + if resp != nil { + code = resp.Code + msg = resp.Msg + } + return "", fmt.Errorf("feishu add reaction failed: %s (code: %d)", msg, code) + } + if resp.Data == nil || resp.Data.ReactionId == nil || strings.TrimSpace(*resp.Data.ReactionId) == "" { + return "", fmt.Errorf("feishu add reaction failed: empty reaction id") + } + return strings.TrimSpace(*resp.Data.ReactionId), nil +} + +func (g *larkProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { + if g == nil || g.api == nil { + return fmt.Errorf("feishu reaction api not configured") + } + req := larkim.NewDeleteMessageReactionReqBuilder(). + MessageId(messageID). + ReactionId(reactionID). + Build() + resp, err := g.api.Delete(ctx, req) + if err != nil { + return err + } + if resp == nil || !resp.Success() { + code := 0 + msg := "" + if resp != nil { + code = resp.Code + msg = resp.Msg + } + return fmt.Errorf("feishu remove reaction failed: %s (code: %d)", msg, code) + } + return nil +} + // NewFeishuAdapter creates a FeishuAdapter with the given logger. func NewFeishuAdapter(log *slog.Logger) *FeishuAdapter { if log == nil { @@ -46,10 +115,14 @@ func (a *FeishuAdapter) Descriptor() channel.Descriptor { Type: Type, DisplayName: "Feishu", Capabilities: channel.ChannelCapabilities{ - Text: true, - RichText: true, - Attachments: true, - Reply: true, + Text: true, + RichText: true, + Attachments: true, + Media: true, + Reactions: true, + Reply: true, + Streaming: true, + BlockStreaming: true, }, ConfigSchema: channel.ConfigSchema{ Version: 1, @@ -84,6 +157,79 @@ func (a *FeishuAdapter) Descriptor() channel.Descriptor { } } +// ProcessingStarted adds a transient reaction to indicate the inbound message is being processed. +func (a *FeishuAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + messageID := strings.TrimSpace(info.SourceMessageID) + if messageID == "" { + return channel.ProcessingStatusHandle{}, nil + } + gateway, err := a.processingReactionGateway(cfg) + if err != nil { + return channel.ProcessingStatusHandle{}, err + } + token, err := addProcessingReaction(ctx, gateway, messageID, processingBusyReactionType) + if err != nil { + return channel.ProcessingStatusHandle{}, err + } + return channel.ProcessingStatusHandle{Token: token}, nil +} + +// ProcessingCompleted removes the transient processing reaction before output is sent. +func (a *FeishuAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { + messageID := strings.TrimSpace(info.SourceMessageID) + reactionID := strings.TrimSpace(handle.Token) + if messageID == "" || reactionID == "" { + return nil + } + gateway, err := a.processingReactionGateway(cfg) + if err != nil { + return err + } + return removeProcessingReaction(ctx, gateway, messageID, reactionID) +} + +// ProcessingFailed removes the transient processing reaction when chat processing fails. +func (a *FeishuAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { + return a.ProcessingCompleted(ctx, cfg, msg, info, handle) +} + +func (a *FeishuAdapter) processingReactionGateway(cfg channel.ChannelConfig) (processingReactionGateway, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + gateway := &larkProcessingReactionGateway{api: client.Im.V1.MessageReaction} + return gateway, nil +} + +func addProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionType string) (string, error) { + if gateway == nil { + return "", fmt.Errorf("processing reaction gateway is nil") + } + msgID := strings.TrimSpace(messageID) + if msgID == "" { + return "", nil + } + rxType := strings.TrimSpace(reactionType) + if rxType == "" { + return "", fmt.Errorf("processing reaction type is empty") + } + return gateway.Add(ctx, msgID, rxType) +} + +func removeProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionID string) error { + if gateway == nil { + return fmt.Errorf("processing reaction gateway is nil") + } + msgID := strings.TrimSpace(messageID) + rxID := strings.TrimSpace(reactionID) + if msgID == "" || rxID == "" { + return nil + } + return gateway.Remove(ctx, msgID, rxID) +} + // NormalizeConfig validates and normalizes a Feishu channel configuration map. func (a *FeishuAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) @@ -127,48 +273,107 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, return nil, err } connCtx, cancel := context.WithCancel(ctx) - eventDispatcher := dispatcher.NewEventDispatcher( - feishuCfg.VerificationToken, - feishuCfg.EncryptKey, - ) - eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error { - msg := extractFeishuInbound(event) - text := msg.Message.PlainText() - if text == "" && len(msg.Message.Attachments) == 0 { - return nil - } - msg.BotID = cfg.BotID - if a.logger != nil { - a.logger.Info( - "inbound received", - slog.String("config_id", cfg.ID), - slog.String("session_id", msg.SessionID()), - slog.String("chat_type", msg.Conversation.Type), - slog.String("text", common.SummarizeText(text)), - ) - } - go func() { - if err := handler(connCtx, cfg, msg); err != nil && a.logger != nil { - a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + newClient := func() *larkws.Client { + eventDispatcher := dispatcher.NewEventDispatcher( + feishuCfg.VerificationToken, + feishuCfg.EncryptKey, + ) + eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error { + msg := extractFeishuInbound(event) + text := msg.Message.PlainText() + rawMessageID := "" + rawMessageType := "" + if event != nil && event.Event != nil && event.Event.Message != nil { + if event.Event.Message.MessageId != nil { + rawMessageID = strings.TrimSpace(*event.Event.Message.MessageId) + } + if event.Event.Message.MessageType != nil { + rawMessageType = strings.TrimSpace(string(*event.Event.Message.MessageType)) + } } - }() - return nil - }) - eventDispatcher.OnP2MessageReadV1(func(_ context.Context, _ *larkim.P2MessageReadV1) error { - return nil - }) - - client := larkws.NewClient( - feishuCfg.AppID, - feishuCfg.AppSecret, - larkws.WithEventHandler(eventDispatcher), - larkws.WithLogger(newLarkSlogLogger(a.logger)), - larkws.WithLogLevel(larkcore.LogLevelDebug), - ) + if text == "" && len(msg.Message.Attachments) == 0 { + if a.logger != nil { + a.logger.Info( + "inbound ignored empty payload", + slog.String("config_id", cfg.ID), + slog.String("message_id", rawMessageID), + slog.String("message_type", rawMessageType), + slog.String("chat_type", msg.Conversation.Type), + ) + } + return nil + } + msg.BotID = cfg.BotID + if a.logger != nil { + isMentioned := false + if msg.Metadata != nil { + if v, ok := msg.Metadata["is_mentioned"].(bool); ok { + isMentioned = v + } + } + a.logger.Info( + "inbound received", + slog.String("config_id", cfg.ID), + slog.String("message_id", rawMessageID), + slog.String("message_type", rawMessageType), + slog.String("route_key", msg.RoutingKey()), + slog.String("chat_type", msg.Conversation.Type), + slog.Bool("is_mentioned", isMentioned), + slog.String("text", common.SummarizeText(text)), + ) + } + go func() { + if err := handler(connCtx, cfg, msg); err != nil && a.logger != nil { + a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + }() + return nil + }) + eventDispatcher.OnP2MessageReadV1(func(_ context.Context, _ *larkim.P2MessageReadV1) error { + return nil + }) + // Ignore reaction lifecycle events explicitly to avoid SDK "not found handler" noise logs. + // These events are expected because the adapter uses reactions for processing status. + eventDispatcher.OnP2MessageReactionCreatedV1(func(_ context.Context, _ *larkim.P2MessageReactionCreatedV1) error { + return nil + }) + eventDispatcher.OnP2MessageReactionDeletedV1(func(_ context.Context, _ *larkim.P2MessageReactionDeletedV1) error { + return nil + }) + return larkws.NewClient( + feishuCfg.AppID, + feishuCfg.AppSecret, + larkws.WithEventHandler(eventDispatcher), + larkws.WithLogger(newLarkSlogLogger(a.logger)), + larkws.WithLogLevel(larkcore.LogLevelDebug), + ) + } go func() { - if err := client.Start(connCtx); err != nil && a.logger != nil { - a.logger.Error("client start failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + const reconnectDelay = 3 * time.Second + for { + if connCtx.Err() != nil { + return + } + client := newClient() + err := client.Start(connCtx) + if connCtx.Err() != nil { + return + } + if a.logger != nil { + if err != nil { + a.logger.Error("client start failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } else { + a.logger.Warn("client exited without error; reconnecting", slog.String("config_id", cfg.ID)) + } + } + timer := time.NewTimer(reconnectDelay) + select { + case <-connCtx.Done(): + timer.Stop() + return + case <-timer.C: + } } }() @@ -256,6 +461,39 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg return a.handleResponse(cfg.ID, resp, err) } +// OpenStream opens a Feishu streaming session. +// The adapter strategy uses one interactive card and patches it incrementally. +func (a *FeishuAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("feishu target is required") + } + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + receiveID, receiveType, err := resolveFeishuReceiveID(target) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return &feishuOutboundStream{ + adapter: a, + cfg: cfg, + target: target, + reply: opts.Reply, + client: client, + receiveID: receiveID, + receiveType: receiveType, + patchInterval: feishuStreamPatchInterval, + }, nil +} + func (a *FeishuAdapter) handleReplyResponse(configID string, resp *larkim.ReplyMessageResp, err error) error { if err != nil { if a.logger != nil { @@ -307,66 +545,83 @@ func (a *FeishuAdapter) handleResponse(configID string, resp *larkim.CreateMessa } func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, receiveID, receiveType string, att channel.Attachment, text string) error { - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, att.URL, nil) - if err != nil { - return fmt.Errorf("failed to build download request: %w", err) - } - httpClient := &http.Client{Timeout: 60 * time.Second} - resp, err := httpClient.Do(httpReq) - if err != nil { - return fmt.Errorf("failed to download attachment: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to download attachment, status: %d", resp.StatusCode) - } - var msgType string var contentMap map[string]string - - if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { - uploadReq := larkim.NewCreateImageReqBuilder(). - Body(larkim.NewCreateImageReqBodyBuilder(). - ImageType(larkim.ImageTypeMessage). - Image(resp.Body). - Build()). - Build() - uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) - if err != nil { - return fmt.Errorf("failed to upload image: %w", err) + sourcePlatform := strings.TrimSpace(att.SourcePlatform) + platformKey := strings.TrimSpace(att.PlatformKey) + if platformKey != "" && (sourcePlatform == "" || strings.EqualFold(sourcePlatform, Type.String())) { + if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { + msgType = larkim.MsgTypeImage + contentMap = map[string]string{"image_key": platformKey} + } else { + msgType = larkim.MsgTypeFile + contentMap = map[string]string{"file_key": platformKey} } - if uploadResp == nil || !uploadResp.Success() { - code, msg := 0, "" - if uploadResp != nil { - code, msg = uploadResp.Code, uploadResp.Msg - } - return fmt.Errorf("failed to upload image: %s (code: %d)", msg, code) - } - msgType = larkim.MsgTypeImage - contentMap = map[string]string{"image_key": *uploadResp.Data.ImageKey} } else { - fileType := resolveFeishuFileType(att.Name, att.Mime) - uploadReq := larkim.NewCreateFileReqBuilder(). - Body(larkim.NewCreateFileReqBodyBuilder(). - FileType(fileType). - FileName(att.Name). - File(resp.Body). - Build()). - Build() - uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) + downloadURL := strings.TrimSpace(att.URL) + if downloadURL == "" { + return fmt.Errorf("failed to download attachment: url is required when platform key is unavailable") + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) if err != nil { - return fmt.Errorf("failed to upload file: %w", err) + return fmt.Errorf("failed to build download request: %w", err) } - if uploadResp == nil || !uploadResp.Success() { - code, msg := 0, "" - if uploadResp != nil { - code, msg = uploadResp.Code, uploadResp.Msg + httpClient := &http.Client{Timeout: 60 * time.Second} + resp, err := httpClient.Do(httpReq) + if err != nil { + return fmt.Errorf("failed to download attachment: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download attachment, status: %d", resp.StatusCode) + } + if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { + uploadReq := larkim.NewCreateImageReqBuilder(). + Body(larkim.NewCreateImageReqBodyBuilder(). + ImageType(larkim.ImageTypeMessage). + Image(resp.Body). + Build()). + Build() + uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("failed to upload image: %w", err) } - return fmt.Errorf("failed to upload file: %s (code: %d)", msg, code) + if uploadResp == nil || !uploadResp.Success() { + code, msg := 0, "" + if uploadResp != nil { + code, msg = uploadResp.Code, uploadResp.Msg + } + return fmt.Errorf("failed to upload image: %s (code: %d)", msg, code) + } + msgType = larkim.MsgTypeImage + contentMap = map[string]string{"image_key": *uploadResp.Data.ImageKey} + } else { + fileType := resolveFeishuFileType(att.Name, att.Mime) + fileName := strings.TrimSpace(att.Name) + if fileName == "" { + fileName = "attachment" + } + uploadReq := larkim.NewCreateFileReqBuilder(). + Body(larkim.NewCreateFileReqBodyBuilder(). + FileType(fileType). + FileName(fileName). + File(resp.Body). + Build()). + Build() + uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("failed to upload file: %w", err) + } + if uploadResp == nil || !uploadResp.Success() { + code, msg := 0, "" + if uploadResp != nil { + code, msg = uploadResp.Code, uploadResp.Msg + } + return fmt.Errorf("failed to upload file: %s (code: %d)", msg, code) + } + msgType = larkim.MsgTypeFile + contentMap = map[string]string{"file_key": *uploadResp.Data.FileKey} } - msgType = larkim.MsgTypeFile - contentMap = map[string]string{"file_key": *uploadResp.Data.FileKey} } content, err := json.Marshal(contentMap) @@ -421,10 +676,77 @@ func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { line := []any{} for _, part := range msg.Parts { - if part.Type == channel.MessagePartText { + switch part.Type { + case channel.MessagePartText: + text := strings.TrimSpace(part.Text) + if text == "" { + continue + } line = append(line, map[string]any{ "tag": "text", - "text": part.Text, + "text": text, + }) + case channel.MessagePartLink: + url := strings.TrimSpace(part.URL) + label := strings.TrimSpace(part.Text) + if label == "" { + label = url + } + if url == "" || label == "" { + continue + } + line = append(line, map[string]any{ + "tag": "a", + "text": label, + "href": url, + }) + case channel.MessagePartCodeBlock: + code := strings.TrimSpace(part.Text) + if code == "" { + continue + } + language := strings.TrimSpace(part.Language) + if language != "" { + code = "```" + language + "\n" + code + "\n```" + } else { + code = "```\n" + code + "\n```" + } + line = append(line, map[string]any{ + "tag": "text", + "text": code, + }) + case channel.MessagePartMention: + mention := strings.TrimSpace(part.Text) + if mention == "" { + mention = strings.TrimSpace(part.ChannelIdentityID) + } + if mention == "" { + continue + } + line = append(line, map[string]any{ + "tag": "text", + "text": "@" + mention, + }) + case channel.MessagePartEmoji: + emoji := strings.TrimSpace(part.Emoji) + if emoji == "" { + emoji = strings.TrimSpace(part.Text) + } + if emoji == "" { + continue + } + line = append(line, map[string]any{ + "tag": "text", + "text": emoji, + }) + } + } + if len(line) == 0 { + text := strings.TrimSpace(msg.PlainText()) + if text != "" { + line = append(line, map[string]any{ + "tag": "text", + "text": text, }) } } @@ -433,142 +755,3 @@ func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { payload, err := json.Marshal(pc) return string(payload), err } - -func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessage { - if event == nil || event.Event == nil || event.Event.Message == nil { - return channel.InboundMessage{Channel: Type} - } - message := event.Event.Message - - var msg channel.Message - if message.MessageId != nil { - msg.ID = *message.MessageId - } - - var contentMap map[string]any - if message.Content != nil { - _ = json.Unmarshal([]byte(*message.Content), &contentMap) - } - isMentioned := hasFeishuMention(contentMap) - - if message.MessageType != nil { - switch *message.MessageType { - case larkim.MsgTypeText: - if txt, ok := contentMap["text"].(string); ok { - msg.Text = txt - } - case larkim.MsgTypeImage: - if key, ok := contentMap["image_key"].(string); ok { - msg.Attachments = append(msg.Attachments, channel.Attachment{ - Type: channel.AttachmentImage, - URL: key, - }) - } - case larkim.MsgTypeFile, larkim.MsgTypeAudio: - if key, ok := contentMap["file_key"].(string); ok { - name, _ := contentMap["file_name"].(string) - msg.Attachments = append(msg.Attachments, channel.Attachment{ - Type: channel.AttachmentType(*message.MessageType), - URL: key, - Name: name, - }) - } - } - } - - if message.ParentId != nil && *message.ParentId != "" { - msg.Reply = &channel.ReplyRef{ - MessageID: *message.ParentId, - } - } - - senderID, senderOpenID := "", "" - if event.Event.Sender != nil && event.Event.Sender.SenderId != nil { - if event.Event.Sender.SenderId.UserId != nil { - senderID = strings.TrimSpace(*event.Event.Sender.SenderId.UserId) - } - if event.Event.Sender.SenderId.OpenId != nil { - senderOpenID = strings.TrimSpace(*event.Event.Sender.SenderId.OpenId) - } - } - chatID := "" - chatType := "" - if message.ChatId != nil { - chatID = strings.TrimSpace(*message.ChatId) - } - if message.ChatType != nil { - chatType = strings.TrimSpace(*message.ChatType) - } - replyTo := senderOpenID - if replyTo == "" { - replyTo = senderID - } - if chatType != "" && chatType != "p2p" && chatID != "" { - replyTo = "chat_id:" + chatID - } - attrs := map[string]string{} - if senderID != "" { - attrs["user_id"] = senderID - } - if senderOpenID != "" { - attrs["open_id"] = senderOpenID - } - subjectID := senderOpenID - if subjectID == "" { - subjectID = senderID - } - - return channel.InboundMessage{ - Channel: Type, - Message: msg, - ReplyTarget: replyTo, - Sender: channel.Identity{ - SubjectID: subjectID, - DisplayName: senderOpenID, - Attributes: attrs, - }, - Conversation: channel.Conversation{ - ID: chatID, - Type: chatType, - }, - ReceivedAt: time.Now().UTC(), - Source: "feishu", - Metadata: map[string]any{ - "is_mentioned": isMentioned, - }, - } -} - -func hasFeishuMention(contentMap map[string]any) bool { - if len(contentMap) == 0 { - return false - } - raw, ok := contentMap["mentions"] - if !ok { - return false - } - switch mentions := raw.(type) { - case []any: - return len(mentions) > 0 - case []map[string]any: - return len(mentions) > 0 - default: - return false - } -} - -func resolveFeishuReceiveID(raw string) (string, string, error) { - if raw == "" { - return "", "", fmt.Errorf("feishu target is required") - } - if strings.HasPrefix(raw, "open_id:") { - return strings.TrimPrefix(raw, "open_id:"), larkim.ReceiveIdTypeOpenId, nil - } - if strings.HasPrefix(raw, "user_id:") { - return strings.TrimPrefix(raw, "user_id:"), larkim.ReceiveIdTypeUserId, nil - } - if strings.HasPrefix(raw, "chat_id:") { - return strings.TrimPrefix(raw, "chat_id:"), larkim.ReceiveIdTypeChatId, nil - } - return raw, larkim.ReceiveIdTypeOpenId, nil -} diff --git a/internal/channel/adapters/feishu/feishu_integration_test.go b/internal/channel/adapters/feishu/feishu_integration_test.go index 80daea33..35e4fe56 100644 --- a/internal/channel/adapters/feishu/feishu_integration_test.go +++ b/internal/channel/adapters/feishu/feishu_integration_test.go @@ -11,30 +11,25 @@ import ( "github.com/memohai/memoh/internal/channel" ) -// TestFeishuGateway_Integration 飞书通道集成测试 -// 运行此测试需要设置环境变量: -// FEISHU_APP_ID: 飞书应用的 App ID -// FEISHU_APP_SECRET: 飞书应用的 App Secret -// FEISHU_ENCRYPT_KEY: (可选) 飞书应用的 Encrypt Key -// FEISHU_VERIFICATION_TOKEN: (可选) 飞书应用的 Verification Token +// TestFeishuGateway_Integration runs Feishu channel integration test. +// Required env: FEISHU_APP_ID, FEISHU_APP_SECRET. +// Optional: FEISHU_ENCRYPT_KEY, FEISHU_VERIFICATION_TOKEN. func TestFeishuGateway_Integration(t *testing.T) { appID := os.Getenv("FEISHU_APP_ID") appSecret := os.Getenv("FEISHU_APP_SECRET") if appID == "" || appSecret == "" { - t.Skip("跳过集成测试: 未设置 FEISHU_APP_ID 或 FEISHU_APP_SECRET 环境变量") + t.Skip("skipping integration test: FEISHU_APP_ID or FEISHU_APP_SECRET not set") } encryptKey := os.Getenv("FEISHU_ENCRYPT_KEY") verificationToken := os.Getenv("FEISHU_VERIFICATION_TOKEN") - // 使用更规范的日志配置 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelInfo, })) adapter := NewFeishuAdapter(logger) - // 构造测试配置 cfg := channel.ChannelConfig{ ID: "integration-test-bot", Credentials: map[string]any{ @@ -45,28 +40,23 @@ func TestFeishuGateway_Integration(t *testing.T) { }, } - // 定义测试上下文,设置合理的超时时间 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - // 消息计数,用于验证是否收到消息 receivedChan := make(chan channel.InboundMessage, 1) - // 模拟 InboundHandler handler := func(ctx context.Context, c channel.ChannelConfig, msg channel.InboundMessage) error { plainText := msg.Message.PlainText() - logger.Info("测试收到消息", + logger.Info("received message in test", slog.String("text", plainText), slog.String("user_id", msg.Sender.Attribute("user_id")), - slog.String("session_id", msg.SessionID())) + slog.String("route_key", msg.RoutingKey())) - // 将消息放入通道,供主测试逻辑验证 select { case receivedChan <- msg: default: } - // 自动回复测试 (验证下行链路) reply := channel.OutboundMessage{ Target: msg.ReplyTarget, Message: channel.Message{ @@ -78,7 +68,6 @@ func TestFeishuGateway_Integration(t *testing.T) { return fmt.Errorf("failed to send reply: %w", err) } - // 模拟异步主动推送测试 go func() { time.Sleep(1 * time.Second) pushMsg := channel.OutboundMessage{ @@ -93,31 +82,27 @@ func TestFeishuGateway_Integration(t *testing.T) { return nil } - // 启动适配器 - logger.Info("正在启动飞书适配器...", slog.String("app_id", appID)) + logger.Info("starting Feishu adapter", slog.String("app_id", appID)) runner, err := adapter.Connect(ctx, cfg, handler) if err != nil { - t.Fatalf("适配器启动失败: %v", err) + t.Fatalf("adapter connect failed: %v", err) } defer func() { _ = runner.Stop(context.Background()) }() fmt.Println("==================================================================") - fmt.Println("🚀 飞书集成测试已就绪!") - fmt.Println("请在飞书客户端向机器人发送一条消息,以完成端到端验证。") - fmt.Println("测试将在收到第一条消息或 10 分钟超时后结束。") + fmt.Println("Feishu integration test ready. Send a message in Feishu client to verify.") + fmt.Println("Test ends on first message received or 10 min timeout.") fmt.Println("==================================================================") - // 等待测试结果 select { case msg := <-receivedChan: - logger.Info("集成测试验证成功!", slog.String("received_text", msg.Message.PlainText())) - // 给一点时间让异步推送完成 + logger.Info("integration test passed", slog.String("received_text", msg.Message.PlainText())) time.Sleep(2 * time.Second) case <-ctx.Done(): if ctx.Err() == context.DeadlineExceeded { - t.Log("测试超时结束") + t.Log("test timed out") } } } diff --git a/internal/channel/adapters/feishu/feishu_test.go b/internal/channel/adapters/feishu/feishu_test.go index 765f10a3..77847386 100644 --- a/internal/channel/adapters/feishu/feishu_test.go +++ b/internal/channel/adapters/feishu/feishu_test.go @@ -1,11 +1,48 @@ package feishu import ( + "context" + "encoding/json" + "errors" + "strings" "testing" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" ) +type fakeProcessingReactionGateway struct { + addCalls []struct{ messageID, reactionType string } + removeCalls []struct{ messageID, reactionID string } + addResponse []struct { + reactionID string + err error + } + removeErr error +} + +func (g *fakeProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { + g.addCalls = append(g.addCalls, struct{ messageID, reactionType string }{ + messageID: messageID, + reactionType: reactionType, + }) + if len(g.addResponse) == 0 { + return "reaction-default", nil + } + resp := g.addResponse[0] + g.addResponse = g.addResponse[1:] + return resp.reactionID, resp.err +} + +func (g *fakeProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { + g.removeCalls = append(g.removeCalls, struct{ messageID, reactionID string }{ + messageID: messageID, + reactionID: reactionID, + }) + return g.removeErr +} + func TestResolveFeishuReceiveID(t *testing.T) { t.Parallel() @@ -70,6 +107,18 @@ func TestExtractFeishuInboundP2P(t *testing.T) { if got.ReplyTarget != "ou_1" { t.Fatalf("unexpected reply target: %s", got.ReplyTarget) } + if got.Sender.DisplayName != "" { + t.Fatalf("expected empty sender display name, got: %s", got.Sender.DisplayName) + } + if got.Sender.SubjectID != "ou_1" { + t.Fatalf("unexpected sender subject id: %s", got.Sender.SubjectID) + } + if got.Sender.Attribute("open_id") != "ou_1" { + t.Fatalf("unexpected sender open_id: %s", got.Sender.Attribute("open_id")) + } + if got.Sender.Attribute("user_id") != "u_1" { + t.Fatalf("unexpected sender user_id: %s", got.Sender.Attribute("user_id")) + } if mentioned, _ := got.Metadata["is_mentioned"].(bool); mentioned { t.Fatalf("unexpected mention flag for p2p message") } @@ -126,6 +175,109 @@ func TestExtractFeishuInboundNonText(t *testing.T) { } } +func TestExtractFeishuInboundImageAttachmentReference(t *testing.T) { + t.Parallel() + + content := `{"image_key":"img_1"}` + msgType := larkim.MsgTypeImage + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &content, + }, + }, + } + got := extractFeishuInbound(event) + if len(got.Message.Attachments) != 1 { + t.Fatalf("expected one attachment, got %d", len(got.Message.Attachments)) + } + att := got.Message.Attachments[0] + if att.Type != channel.AttachmentImage { + t.Fatalf("unexpected attachment type: %s", att.Type) + } + if att.PlatformKey != "img_1" { + t.Fatalf("unexpected platform key: %s", att.PlatformKey) + } + if att.SourcePlatform != Type.String() { + t.Fatalf("unexpected source platform: %s", att.SourcePlatform) + } +} + +func TestFeishuDescriptorIncludesStreamingAndMedia(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + caps := adapter.Descriptor().Capabilities + if !caps.Streaming { + t.Fatal("expected streaming capability") + } + if !caps.Media { + t.Fatal("expected media capability") + } +} + +func TestBuildFeishuStreamCardContent(t *testing.T) { + t.Parallel() + + payload, err := buildFeishuStreamCardContent("hello") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var parsed map[string]any + if err := json.Unmarshal([]byte(payload), &parsed); err != nil { + t.Fatalf("unexpected json error: %v", err) + } + cfg, ok := parsed["config"].(map[string]any) + if !ok { + t.Fatalf("missing config: %+v", parsed) + } + value, ok := cfg["update_multi"].(bool) + if !ok || !value { + t.Fatalf("expected update_multi=true, got: %#v", cfg["update_multi"]) + } +} + +func TestNormalizeFeishuStreamText(t *testing.T) { + t.Parallel() + + if got := normalizeFeishuStreamText(" "); got != feishuStreamThinkingText { + t.Fatalf("unexpected thinking text: %s", got) + } + long := strings.Repeat("a", feishuStreamMaxRunes+100) + got := normalizeFeishuStreamText(long) + if len([]rune(got)) > feishuStreamMaxRunes+4 { + t.Fatalf("expected truncated text, got len=%d", len([]rune(got))) + } + if !strings.HasPrefix(got, "...\n") { + t.Fatalf("expected truncation prefix, got: %s", got[:4]) + } +} + +func TestProcessFeishuCardMarkdown(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + want string + }{ + {"literal newline", "a\\nb", "a\nb"}, + {"atx h1", "# Title", "**Title**"}, + {"atx h2", "## Section", "**Section**"}, + {"atx h6", "###### Small", "**Small**"}, + {"heading with newline", "# Hi\n\nBody", "**Hi**\n\nBody"}, + {"no heading", "plain **bold**", "plain **bold**"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := processFeishuCardMarkdown(tt.in) + if got != tt.want { + t.Errorf("processFeishuCardMarkdown() = %q, want %q", got, tt.want) + } + }) + } +} + func TestExtractFeishuInboundMention(t *testing.T) { t.Parallel() @@ -149,3 +301,201 @@ func TestExtractFeishuInboundMention(t *testing.T) { t.Fatalf("expected mention flag to be true") } } + +func TestExtractFeishuInboundMentionFromEventMentions(t *testing.T) { + t.Parallel() + + text := `{"text":"hello"}` + msgType := larkim.MsgTypeText + chatType := "group" + chatID := "oc_mention_event" + mention := larkim.NewMentionEventBuilder().Key("@_user_1").Build() + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &text, + ChatType: &chatType, + ChatId: &chatID, + Mentions: []*larkim.MentionEvent{mention}, + }, + }, + } + got := extractFeishuInbound(event) + mentioned, ok := got.Metadata["is_mentioned"].(bool) + if !ok || !mentioned { + t.Fatalf("expected mention flag from event mentions") + } +} + +func TestExtractFeishuInboundPostMention(t *testing.T) { + t.Parallel() + + content := `{"zh_cn":{"title":"","content":[[{"tag":"at","user_name":"bot"},{"tag":"text","text":" hi"}]]}}` + msgType := larkim.MsgTypePost + chatType := "group" + chatID := "oc_post_1" + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &content, + ChatType: &chatType, + ChatId: &chatID, + }, + }, + } + + got := extractFeishuInbound(event) + if got.Message.PlainText() == "" { + t.Fatalf("expected post message to be converted into text") + } + mentioned, ok := got.Metadata["is_mentioned"].(bool) + if !ok || !mentioned { + t.Fatalf("expected mention flag for post message") + } +} + +func TestAddProcessingReactionFirstSuccess(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{ + addResponse: []struct { + reactionID string + err error + }{ + {reactionID: "reaction-1"}, + }, + } + token, err := addProcessingReaction(context.Background(), gateway, "om_1", "Typing") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "reaction-1" { + t.Fatalf("expected token reaction-1, got %q", token) + } + if len(gateway.addCalls) != 1 { + t.Fatalf("expected one add call, got %d", len(gateway.addCalls)) + } + if gateway.addCalls[0].messageID != "om_1" || gateway.addCalls[0].reactionType != "Typing" { + t.Fatalf("unexpected add params: %+v", gateway.addCalls[0]) + } +} + +func TestAddProcessingReactionReturnsError(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{ + addResponse: []struct { + reactionID string + err error + }{ + {err: errors.New("invalid reaction type")}, + }, + } + token, err := addProcessingReaction(context.Background(), gateway, "om_2", "INVALID") + if err == nil { + t.Fatal("expected error") + } + if token != "" { + t.Fatalf("expected empty token, got %q", token) + } + if len(gateway.addCalls) != 1 { + t.Fatalf("expected one add call, got %d", len(gateway.addCalls)) + } + if gateway.addCalls[0].reactionType != "INVALID" { + t.Fatalf("unexpected add call sequence: %+v", gateway.addCalls) + } +} + +func TestAddProcessingReactionNoMessageID(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{} + token, err := addProcessingReaction(context.Background(), gateway, "", "Typing") + if err != nil { + t.Fatalf("expected no error for empty message id, got: %v", err) + } + if token != "" { + t.Fatalf("expected empty token, got %q", token) + } + if len(gateway.addCalls) != 0 { + t.Fatalf("expected no add calls, got %+v", gateway.addCalls) + } +} + +func TestRemoveProcessingReaction(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{} + if err := removeProcessingReaction(context.Background(), gateway, "om_3", "reaction-3"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gateway.removeCalls) != 1 { + t.Fatalf("expected one remove call, got %d", len(gateway.removeCalls)) + } + if gateway.removeCalls[0].messageID != "om_3" || gateway.removeCalls[0].reactionID != "reaction-3" { + t.Fatalf("unexpected remove params: %+v", gateway.removeCalls[0]) + } +} + +func TestRemoveProcessingReactionNoopForEmptyToken(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{} + if err := removeProcessingReaction(context.Background(), gateway, "om_4", ""); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gateway.removeCalls) != 0 { + t.Fatalf("expected no remove calls, got %+v", gateway.removeCalls) + } +} + +func TestFeishuProcessingStartedNoSourceMessageID(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + handle, err := adapter.ProcessingStarted( + context.Background(), + channel.ChannelConfig{}, + channel.InboundMessage{}, + channel.ProcessingStatusInfo{}, + ) + if err != nil { + t.Fatalf("expected no error for empty source message id, got: %v", err) + } + if handle.Token != "" { + t.Fatalf("expected empty token, got %q", handle.Token) + } +} + +func TestFeishuProcessingStartedRequiresConfigWhenSourceMessageExists(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + _, err := adapter.ProcessingStarted( + context.Background(), + channel.ChannelConfig{}, + channel.InboundMessage{}, + channel.ProcessingStatusInfo{SourceMessageID: "om_5"}, + ) + if err == nil { + t.Fatal("expected error when credentials are missing") + } +} + +func TestFeishuProcessingCompletedNoopWithoutToken(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + err := adapter.ProcessingCompleted( + context.Background(), + channel.ChannelConfig{}, + channel.InboundMessage{}, + channel.ProcessingStatusInfo{SourceMessageID: "om_6"}, + channel.ProcessingStatusHandle{}, + ) + if err != nil { + t.Fatalf("expected no error for empty token, got: %v", err) + } +} diff --git a/internal/channel/adapters/feishu/inbound.go b/internal/channel/adapters/feishu/inbound.go new file mode 100644 index 00000000..1c32aa5c --- /dev/null +++ b/internal/channel/adapters/feishu/inbound.go @@ -0,0 +1,269 @@ +package feishu + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" +) + +// extractFeishuInbound converts a Feishu P2MessageReceiveV1 event into a channel.InboundMessage. +func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessage { + if event == nil || event.Event == nil || event.Event.Message == nil { + return channel.InboundMessage{Channel: Type} + } + message := event.Event.Message + + var msg channel.Message + if message.MessageId != nil { + msg.ID = *message.MessageId + } + + var contentMap map[string]any + if message.Content != nil { + _ = json.Unmarshal([]byte(*message.Content), &contentMap) + } + isMentioned := hasFeishuMention(contentMap, message.Mentions) + + if message.MessageType != nil { + switch *message.MessageType { + case larkim.MsgTypeText: + if txt, ok := contentMap["text"].(string); ok { + msg.Text = txt + } + case larkim.MsgTypePost: + if postText := extractFeishuPostText(contentMap); postText != "" { + msg.Text = postText + } + case larkim.MsgTypeImage: + if key, ok := contentMap["image_key"].(string); ok { + msg.Attachments = append(msg.Attachments, channel.Attachment{ + Type: channel.AttachmentImage, + PlatformKey: key, + SourcePlatform: Type.String(), + }) + } + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + if key, ok := contentMap["file_key"].(string); ok { + name, _ := contentMap["file_name"].(string) + attType := channel.AttachmentFile + switch *message.MessageType { + case larkim.MsgTypeAudio: + attType = channel.AttachmentAudio + case larkim.MsgTypeMedia: + attType = channel.AttachmentVideo + } + msg.Attachments = append(msg.Attachments, channel.Attachment{ + Type: attType, + PlatformKey: key, + SourcePlatform: Type.String(), + Name: name, + }) + } + } + } + + if message.ParentId != nil && *message.ParentId != "" { + msg.Reply = &channel.ReplyRef{ + MessageID: *message.ParentId, + } + } + + senderID, senderOpenID := "", "" + if event.Event.Sender != nil && event.Event.Sender.SenderId != nil { + if event.Event.Sender.SenderId.UserId != nil { + senderID = strings.TrimSpace(*event.Event.Sender.SenderId.UserId) + } + if event.Event.Sender.SenderId.OpenId != nil { + senderOpenID = strings.TrimSpace(*event.Event.Sender.SenderId.OpenId) + } + } + chatID := "" + chatType := "" + if message.ChatId != nil { + chatID = strings.TrimSpace(*message.ChatId) + } + if message.ChatType != nil { + chatType = strings.TrimSpace(*message.ChatType) + } + replyTo := senderOpenID + if replyTo == "" { + replyTo = senderID + } + if chatType != "" && chatType != "p2p" && chatID != "" { + replyTo = "chat_id:" + chatID + } + attrs := map[string]string{} + if senderID != "" { + attrs["user_id"] = senderID + } + if senderOpenID != "" { + attrs["open_id"] = senderOpenID + } + subjectID := senderOpenID + if subjectID == "" { + subjectID = senderID + } + + return channel.InboundMessage{ + Channel: Type, + Message: msg, + ReplyTarget: replyTo, + Sender: channel.Identity{ + SubjectID: subjectID, + Attributes: attrs, + }, + Conversation: channel.Conversation{ + ID: chatID, + Type: chatType, + }, + ReceivedAt: time.Now().UTC(), + Source: "feishu", + Metadata: map[string]any{ + "is_mentioned": isMentioned, + }, + } +} + +func hasFeishuMention(contentMap map[string]any, mentions []*larkim.MentionEvent) bool { + if len(mentions) > 0 { + return true + } + if len(contentMap) == 0 { + return false + } + raw, ok := contentMap["mentions"] + if ok { + switch values := raw.(type) { + case []any: + if len(values) > 0 { + return true + } + case []map[string]any: + if len(values) > 0 { + return true + } + case map[string]any: + if len(values) > 0 { + return true + } + } + } + if text, ok := contentMap["text"].(string); ok { + normalized := strings.ToLower(strings.TrimSpace(text)) + if strings.Contains(normalized, "@_user_") || strings.Contains(normalized, "") { + return true + } + } + return hasFeishuAtTag(contentMap) +} + +func hasFeishuAtTag(raw any) bool { + switch value := raw.(type) { + case map[string]any: + if tag, ok := value["tag"].(string); ok && strings.EqualFold(strings.TrimSpace(tag), "at") { + return true + } + for _, child := range value { + if hasFeishuAtTag(child) { + return true + } + } + case []any: + for _, child := range value { + if hasFeishuAtTag(child) { + return true + } + } + } + return false +} + +func extractFeishuPostText(contentMap map[string]any) string { + zhCN, ok := contentMap["zh_cn"].(map[string]any) + if !ok { + return "" + } + linesRaw, ok := zhCN["content"].([]any) + if !ok { + return "" + } + parts := make([]string, 0, 8) + for _, rawLine := range linesRaw { + line, ok := rawLine.([]any) + if !ok { + continue + } + for _, rawPart := range line { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + tag := strings.ToLower(strings.TrimSpace(stringValue(part["tag"]))) + switch tag { + case "text", "a": + text := strings.TrimSpace(stringValue(part["text"])) + if text != "" { + parts = append(parts, text) + } + case "at": + name := strings.TrimSpace(stringValue(part["text"])) + if name == "" { + name = strings.TrimSpace(stringValue(part["name"])) + } + if name == "" { + name = strings.TrimSpace(stringValue(part["user_name"])) + } + if name == "" { + parts = append(parts, "@") + continue + } + if !strings.HasPrefix(name, "@") { + name = "@" + name + } + parts = append(parts, name) + default: + text := strings.TrimSpace(stringValue(part["text"])) + if text != "" { + parts = append(parts, text) + } + } + } + } + if len(parts) == 0 { + return "" + } + return strings.Join(parts, " ") +} + +func stringValue(raw any) string { + if raw == nil { + return "" + } + value, ok := raw.(string) + if ok { + return value + } + return fmt.Sprint(raw) +} + +// resolveFeishuReceiveID parses target (open_id:/user_id:/chat_id: prefix) and returns receiveID and receiveType. +func resolveFeishuReceiveID(raw string) (string, string, error) { + if raw == "" { + return "", "", fmt.Errorf("feishu target is required") + } + if strings.HasPrefix(raw, "open_id:") { + return strings.TrimPrefix(raw, "open_id:"), larkim.ReceiveIdTypeOpenId, nil + } + if strings.HasPrefix(raw, "user_id:") { + return strings.TrimPrefix(raw, "user_id:"), larkim.ReceiveIdTypeUserId, nil + } + if strings.HasPrefix(raw, "chat_id:") { + return strings.TrimPrefix(raw, "chat_id:"), larkim.ReceiveIdTypeChatId, nil + } + return raw, larkim.ReceiveIdTypeOpenId, nil +} diff --git a/internal/channel/adapters/feishu/feishu_logger.go b/internal/channel/adapters/feishu/logger.go similarity index 100% rename from internal/channel/adapters/feishu/feishu_logger.go rename to internal/channel/adapters/feishu/logger.go diff --git a/internal/channel/adapters/feishu/stream.go b/internal/channel/adapters/feishu/stream.go new file mode 100644 index 00000000..ce063119 --- /dev/null +++ b/internal/channel/adapters/feishu/stream.go @@ -0,0 +1,286 @@ +package feishu + +import ( + "context" + "encoding/json" + "fmt" + "regexp" + "strings" + "sync/atomic" + "time" + + "github.com/google/uuid" + lark "github.com/larksuite/oapi-sdk-go/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + feishuStreamThinkingText = "Thinking..." + feishuStreamPatchInterval = 700 * time.Millisecond + feishuStreamMaxRunes = 8000 +) + +type feishuOutboundStream struct { + adapter *FeishuAdapter + cfg channel.ChannelConfig + target string + reply *channel.ReplyRef + client *lark.Client + receiveID string + receiveType string + cardMessageID string + textBuffer strings.Builder + lastPatchedAt time.Time + lastPatched string + patchInterval time.Duration + closed atomic.Bool +} + +func (s *feishuOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.adapter == nil { + return fmt.Errorf("feishu stream not configured") + } + if s.closed.Load() { + return fmt.Errorf("feishu stream is closed") + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + switch event.Type { + case channel.StreamEventStatus: + if event.Status == channel.StreamStatusStarted { + return s.ensureCard(ctx, feishuStreamThinkingText) + } + return nil + case channel.StreamEventDelta: + if event.Delta == "" { + return nil + } + s.textBuffer.WriteString(event.Delta) + if err := s.ensureCard(ctx, feishuStreamThinkingText); err != nil { + return err + } + if time.Since(s.lastPatchedAt) < s.patchInterval && !strings.Contains(event.Delta, "\n") { + return nil + } + return s.patchCard(ctx, s.textBuffer.String()) + case channel.StreamEventFinal: + if event.Final == nil || event.Final.Message.IsEmpty() { + return nil + } + msg := event.Final.Message + finalText := strings.TrimSpace(msg.PlainText()) + if finalText == "" { + finalText = strings.TrimSpace(s.textBuffer.String()) + } + if finalText != "" { + if err := s.ensureCard(ctx, feishuStreamThinkingText); err != nil { + return err + } + if err := s.patchCard(ctx, finalText); err != nil { + return err + } + } + if len(msg.Attachments) > 0 { + media := msg + media.Format = "" + media.Text = "" + media.Parts = nil + media.Actions = nil + media.Reply = nil + return s.adapter.Send(ctx, s.cfg, channel.OutboundMessage{ + Target: s.target, + Message: media, + }) + } + return nil + case channel.StreamEventError: + errText := strings.TrimSpace(event.Error) + if errText == "" { + return nil + } + if err := s.ensureCard(ctx, feishuStreamThinkingText); err != nil { + return err + } + return s.patchCard(ctx, "Error: "+errText) + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } +} + +func (s *feishuOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} + +func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) error { + if strings.TrimSpace(s.cardMessageID) != "" { + return nil + } + if s.client == nil { + return fmt.Errorf("feishu client not configured") + } + content, err := buildFeishuStreamCardContent(text) + if err != nil { + return err + } + if s.reply != nil && strings.TrimSpace(s.reply.MessageID) != "" { + replyReq := larkim.NewReplyMessageReqBuilder(). + MessageId(strings.TrimSpace(s.reply.MessageID)). + Body(larkim.NewReplyMessageReqBodyBuilder(). + Content(content). + MsgType(larkim.MsgTypeInteractive). + Uuid(uuid.NewString()). + Build()). + Build() + replyResp, err := s.client.Im.V1.Message.Reply(ctx, replyReq) + if err != nil { + return err + } + if replyResp == nil || !replyResp.Success() { + code, msg := 0, "" + if replyResp != nil { + code, msg = replyResp.Code, replyResp.Msg + } + return fmt.Errorf("feishu stream reply failed: %s (code: %d)", msg, code) + } + if replyResp.Data == nil || replyResp.Data.MessageId == nil || strings.TrimSpace(*replyResp.Data.MessageId) == "" { + return fmt.Errorf("feishu stream reply failed: empty message id") + } + s.cardMessageID = strings.TrimSpace(*replyResp.Data.MessageId) + s.lastPatched = normalizeFeishuStreamText(text) + s.lastPatchedAt = time.Now() + return nil + } + createReq := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(s.receiveType). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(s.receiveID). + MsgType(larkim.MsgTypeInteractive). + Content(content). + Uuid(uuid.NewString()). + Build()). + Build() + createResp, err := s.client.Im.V1.Message.Create(ctx, createReq) + if err != nil { + return err + } + if createResp == nil || !createResp.Success() { + code, msg := 0, "" + if createResp != nil { + code, msg = createResp.Code, createResp.Msg + } + return fmt.Errorf("feishu stream create failed: %s (code: %d)", msg, code) + } + if createResp.Data == nil || createResp.Data.MessageId == nil || strings.TrimSpace(*createResp.Data.MessageId) == "" { + return fmt.Errorf("feishu stream create failed: empty message id") + } + s.cardMessageID = strings.TrimSpace(*createResp.Data.MessageId) + s.lastPatched = normalizeFeishuStreamText(text) + s.lastPatchedAt = time.Now() + return nil +} + +func (s *feishuOutboundStream) patchCard(ctx context.Context, text string) error { + if strings.TrimSpace(s.cardMessageID) == "" { + return fmt.Errorf("feishu stream card message not initialized") + } + contentText := normalizeFeishuStreamText(text) + if contentText == s.lastPatched { + return nil + } + content, err := buildFeishuStreamCardContent(contentText) + if err != nil { + return err + } + patchReq := larkim.NewPatchMessageReqBuilder(). + MessageId(strings.TrimSpace(s.cardMessageID)). + Body(larkim.NewPatchMessageReqBodyBuilder(). + Content(content). + Build()). + Build() + patchResp, err := s.client.Im.V1.Message.Patch(ctx, patchReq) + if err != nil { + return err + } + if patchResp == nil || !patchResp.Success() { + code, msg := 0, "" + if patchResp != nil { + code, msg = patchResp.Code, patchResp.Msg + } + return fmt.Errorf("feishu stream patch failed: %s (code: %d)", msg, code) + } + s.lastPatched = contentText + s.lastPatchedAt = time.Now() + return nil +} + +func buildFeishuStreamCardContent(text string) (string, error) { + content := normalizeFeishuStreamText(text) + content = processFeishuCardMarkdown(content) + card := map[string]any{ + "config": map[string]any{ + "wide_screen_mode": true, + "enable_forward": true, + "update_multi": true, + }, + "elements": []map[string]any{ + { + "tag": "div", + "fields": []map[string]any{ + { + "is_short": false, + "text": map[string]any{ + "tag": "lark_md", + "content": content, + }, + }, + }, + }, + }, + } + data, err := json.Marshal(card) + if err != nil { + return "", err + } + return string(data), nil +} + +var feishuCardHeadingPrefix = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) + +// processFeishuCardMarkdown normalizes markdown for Feishu card lark_md (e.g. ATX headings to bold). +func processFeishuCardMarkdown(s string) string { + s = strings.ReplaceAll(s, "\\n", "\n") + s = feishuCardHeadingPrefix.ReplaceAllStringFunc(s, func(m string) string { + parts := feishuCardHeadingPrefix.FindStringSubmatch(m) + if len(parts) == 2 { + return "**" + parts[1] + "**" + } + return m + }) + return s +} + +func normalizeFeishuStreamText(text string) string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return feishuStreamThinkingText + } + runes := []rune(trimmed) + if len(runes) <= feishuStreamMaxRunes { + return trimmed + } + return "...\n" + string(runes[len(runes)-feishuStreamMaxRunes:]) +} diff --git a/internal/channel/adapters/local/cli.go b/internal/channel/adapters/local/cli.go index 9241b207..3b026fb6 100644 --- a/internal/channel/adapters/local/cli.go +++ b/internal/channel/adapters/local/cli.go @@ -10,11 +10,11 @@ import ( // CLIAdapter implements channel.Sender for the local CLI channel. type CLIAdapter struct { - hub *SessionHub + hub *RouteHub } -// NewCLIAdapter creates a CLIAdapter backed by the given session hub. -func NewCLIAdapter(hub *SessionHub) *CLIAdapter { +// NewCLIAdapter creates a CLIAdapter backed by the given route hub. +func NewCLIAdapter(hub *RouteHub) *CLIAdapter { return &CLIAdapter{hub: hub} } @@ -30,20 +30,22 @@ func (a *CLIAdapter) Descriptor() channel.Descriptor { DisplayName: "CLI", Configless: true, Capabilities: channel.ChannelCapabilities{ - Text: true, - Reply: true, - Attachments: true, + Text: true, + Reply: true, + Attachments: true, + Streaming: true, + BlockStreaming: true, }, TargetSpec: channel.TargetSpec{ - Format: "session_id", + Format: "bot_id", Hints: []channel.TargetHint{ - {Label: "Session ID", Example: "cli:uuid"}, + {Label: "Bot ID", Example: "bot_123"}, }, }, } } -// Send publishes an outbound message to the CLI session hub. +// Send publishes an outbound message to the CLI route hub. func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("cli hub not configured") @@ -58,3 +60,20 @@ func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg ch a.hub.Publish(target, msg) return nil } + +// OpenStream opens a local stream session bound to the target route. +func (a *CLIAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + if a.hub == nil { + return nil, fmt.Errorf("cli hub not configured") + } + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("cli target is required") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return newLocalOutboundStream(a.hub, target), nil +} diff --git a/internal/channel/adapters/local/hub.go b/internal/channel/adapters/local/hub.go index ddc4b4bd..0fef9edb 100644 --- a/internal/channel/adapters/local/hub.go +++ b/internal/channel/adapters/local/hub.go @@ -1,51 +1,60 @@ package local import ( + "context" + "fmt" "sync" + "sync/atomic" "github.com/google/uuid" "github.com/memohai/memoh/internal/channel" ) -// SessionHub is a pub/sub hub that routes outbound messages to CLI/Web session subscribers. -type SessionHub struct { - mu sync.RWMutex - sessions map[string]map[string]chan channel.OutboundMessage +// RouteHubEvent is a routed outbound stream event for local transports. +type RouteHubEvent struct { + Target string `json:"target"` + Event channel.StreamEvent `json:"event"` } -// NewSessionHub creates an empty SessionHub. -func NewSessionHub() *SessionHub { - return &SessionHub{ - sessions: map[string]map[string]chan channel.OutboundMessage{}, +// RouteHub is a pub/sub hub that routes outbound messages to local subscribers by route key. +type RouteHub struct { + mu sync.RWMutex + streams map[string]map[string]chan RouteHubEvent +} + +// NewRouteHub creates an empty RouteHub. +func NewRouteHub() *RouteHub { + return &RouteHub{ + streams: map[string]map[string]chan RouteHubEvent{}, } } -// Subscribe registers a new stream for the given session and returns a stream ID, +// Subscribe registers a new stream for the given route key and returns a stream ID, // a read-only channel for messages, and a cancel function to unsubscribe. -func (h *SessionHub) Subscribe(sessionID string) (string, <-chan channel.OutboundMessage, func()) { +func (h *RouteHub) Subscribe(routeKey string) (string, <-chan RouteHubEvent, func()) { streamID := uuid.NewString() - ch := make(chan channel.OutboundMessage, 32) + ch := make(chan RouteHubEvent, 32) h.mu.Lock() - streams, ok := h.sessions[sessionID] + streams, ok := h.streams[routeKey] if !ok { - streams = map[string]chan channel.OutboundMessage{} - h.sessions[sessionID] = streams + streams = map[string]chan RouteHubEvent{} + h.streams[routeKey] = streams } streams[streamID] = ch h.mu.Unlock() cancel := func() { h.mu.Lock() - streams := h.sessions[sessionID] + streams := h.streams[routeKey] if streams != nil { if current, ok := streams[streamID]; ok { delete(streams, streamID) close(current) } if len(streams) == 0 { - delete(h.sessions, sessionID) + delete(h.streams, routeKey) } } h.mu.Unlock() @@ -54,16 +63,73 @@ func (h *SessionHub) Subscribe(sessionID string) (string, <-chan channel.Outboun return streamID, ch, cancel } -// Publish delivers a message to all subscribers of the given session. +// Publish delivers a message to all subscribers of the given route key. // Slow receivers are silently dropped. -func (h *SessionHub) Publish(sessionID string, msg channel.OutboundMessage) { +func (h *RouteHub) Publish(routeKey string, msg channel.OutboundMessage) { + h.PublishEvent(routeKey, channel.StreamEvent{ + Type: channel.StreamEventFinal, + Final: &channel.StreamFinalizePayload{ + Message: msg.Message, + }, + }) +} + +// PublishEvent delivers a stream event to all subscribers of the given route key. +// Slow receivers are silently dropped. +func (h *RouteHub) PublishEvent(routeKey string, event channel.StreamEvent) { h.mu.RLock() defer h.mu.RUnlock() - for _, ch := range h.sessions[sessionID] { + for _, ch := range h.streams[routeKey] { + payload := RouteHubEvent{ + Target: routeKey, + Event: event, + } select { - case ch <- msg: + case ch <- payload: default: // Drop if receiver is slow. } } } + +type localOutboundStream struct { + hub *RouteHub + target string + closed atomic.Bool +} + +func newLocalOutboundStream(hub *RouteHub, target string) channel.OutboundStream { + return &localOutboundStream{ + hub: hub, + target: target, + } +} + +func (s *localOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.hub == nil { + return fmt.Errorf("route hub not configured") + } + if s.closed.Load() { + return fmt.Errorf("stream is closed") + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.hub.PublishEvent(s.target, event) + return nil +} + +func (s *localOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} diff --git a/internal/channel/adapters/local/hub_test.go b/internal/channel/adapters/local/hub_test.go new file mode 100644 index 00000000..d9814af9 --- /dev/null +++ b/internal/channel/adapters/local/hub_test.go @@ -0,0 +1,50 @@ +package local + +import ( + "context" + "testing" + "time" + + "github.com/memohai/memoh/internal/channel" +) + +func TestRouteHubPublishEvent(t *testing.T) { + t.Parallel() + + hub := NewRouteHub() + _, stream, cancel := hub.Subscribe("bot-1") + defer cancel() + + hub.PublishEvent("bot-1", channel.StreamEvent{ + Type: channel.StreamEventDelta, + Delta: "hello", + }) + + select { + case item := <-stream: + if item.Target != "bot-1" { + t.Fatalf("unexpected target: %s", item.Target) + } + if item.Event.Type != channel.StreamEventDelta { + t.Fatalf("unexpected event type: %s", item.Event.Type) + } + case <-time.After(time.Second): + t.Fatal("expected event but timed out") + } +} + +func TestLocalOutboundStreamClose(t *testing.T) { + t.Parallel() + + hub := NewRouteHub() + stream := newLocalOutboundStream(hub, "bot-2") + if err := stream.Close(context.Background()); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + if err := stream.Push(context.Background(), channel.StreamEvent{ + Type: channel.StreamEventDelta, + Delta: "should fail", + }); err == nil { + t.Fatal("expected push on closed stream to fail") + } +} diff --git a/internal/channel/adapters/local/web.go b/internal/channel/adapters/local/web.go index 1490604b..70309748 100644 --- a/internal/channel/adapters/local/web.go +++ b/internal/channel/adapters/local/web.go @@ -10,11 +10,11 @@ import ( // WebAdapter implements channel.Sender for the local Web channel. type WebAdapter struct { - hub *SessionHub + hub *RouteHub } -// NewWebAdapter creates a WebAdapter backed by the given session hub. -func NewWebAdapter(hub *SessionHub) *WebAdapter { +// NewWebAdapter creates a WebAdapter backed by the given route hub. +func NewWebAdapter(hub *RouteHub) *WebAdapter { return &WebAdapter{hub: hub} } @@ -30,20 +30,22 @@ func (a *WebAdapter) Descriptor() channel.Descriptor { DisplayName: "Web", Configless: true, Capabilities: channel.ChannelCapabilities{ - Text: true, - Reply: true, - Attachments: true, + Text: true, + Reply: true, + Attachments: true, + Streaming: true, + BlockStreaming: true, }, TargetSpec: channel.TargetSpec{ - Format: "session_id", + Format: "bot_id", Hints: []channel.TargetHint{ - {Label: "Session ID", Example: "web:uuid"}, + {Label: "Bot ID", Example: "bot_123"}, }, }, } } -// Send publishes an outbound message to the Web session hub. +// Send publishes an outbound message to the Web route hub. func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("web hub not configured") @@ -58,3 +60,20 @@ func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg ch a.hub.Publish(target, msg) return nil } + +// OpenStream opens a local stream session bound to the target route. +func (a *WebAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + if a.hub == nil { + return nil, fmt.Errorf("web hub not configured") + } + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("web target is required") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return newLocalOutboundStream(a.hub, target), nil +} diff --git a/internal/channel/adapters/telegram/directory.go b/internal/channel/adapters/telegram/directory.go new file mode 100644 index 00000000..7901cdeb --- /dev/null +++ b/internal/channel/adapters/telegram/directory.go @@ -0,0 +1,239 @@ +package telegram + +import ( + "context" + "fmt" + "strconv" + "strings" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + defaultDirectoryLimit = 50 + maxDirectoryLimit = 200 +) + +func directoryLimit(n int) int { + if n <= 0 { + return defaultDirectoryLimit + } + if n > maxDirectoryLimit { + return maxDirectoryLimit + } + return n +} + +// ListPeers returns users the bot can reach. Telegram Bot API does not provide a list of users; returns empty. +func (a *TelegramAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +// ListGroups returns chats the bot is in. Telegram Bot API does not provide a list of chats; returns empty. +func (a *TelegramAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +// ListGroupMembers returns administrators of the given group (Telegram only exposes admin list, not full members). +func (a *TelegramAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + telegramCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + bot, err := a.getOrCreateBot(telegramCfg.BotToken, cfg.ID) + if err != nil { + return nil, err + } + chatID, superGroupUsername := parseTelegramChatInput(strings.TrimSpace(groupID)) + if chatID == 0 && superGroupUsername == "" { + return nil, fmt.Errorf("telegram list group members: invalid group id %q", groupID) + } + config := tgbotapi.ChatAdministratorsConfig{ + ChatConfig: tgbotapi.ChatConfig{ChatID: chatID, SuperGroupUsername: superGroupUsername}, + } + members, err := bot.GetChatAdministrators(config) + if err != nil { + return nil, fmt.Errorf("telegram get chat administrators: %w", err) + } + limit := directoryLimit(query.Limit) + entries := make([]channel.DirectoryEntry, 0, limit) + for i, m := range members { + if i >= limit { + break + } + if m.User == nil { + continue + } + e := telegramUserToEntry(m.User) + if query.Query != "" && !strings.Contains(strings.ToLower(e.Name+e.Handle), strings.ToLower(query.Query)) { + continue + } + entries = append(entries, e) + } + return entries, nil +} + +// ResolveEntry resolves an input string to a user or group DirectoryEntry using getChat / getChatMember. +func (a *TelegramAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + telegramCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return channel.DirectoryEntry{}, err + } + bot, err := a.getOrCreateBot(telegramCfg.BotToken, cfg.ID) + if err != nil { + return channel.DirectoryEntry{}, err + } + input = strings.TrimSpace(input) + switch kind { + case channel.DirectoryEntryUser: + return a.resolveTelegramUser(ctx, bot, input) + case channel.DirectoryEntryGroup: + return a.resolveTelegramGroup(ctx, bot, input) + default: + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry: unsupported kind %q", kind) + } +} + +func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { + chatID, userID := parseTelegramUserInput(input) + if chatID == 0 && userID == 0 { + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry user: invalid input %q", input) + } + if userID != 0 { + config := tgbotapi.GetChatMemberConfig{ + ChatConfigWithUser: tgbotapi.ChatConfigWithUser{ + ChatID: chatID, + SuperGroupUsername: "", + UserID: userID, + }, + } + member, err := bot.GetChatMember(config) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: %w", err) + } + if member.User == nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: empty user") + } + return telegramUserToEntry(member.User), nil + } + chatConfig := tgbotapi.ChatInfoConfig{ChatConfig: tgbotapi.ChatConfig{ChatID: chatID}} + chat, err := bot.GetChat(chatConfig) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat: %w", err) + } + if !chat.IsPrivate() { + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry user: chat %d is not private", chatID) + } + name := strings.TrimSpace(chat.FirstName + " " + chat.LastName) + if name == "" { + name = strings.TrimSpace(chat.Title) + } + handle := strings.TrimSpace(chat.UserName) + if handle != "" && !strings.HasPrefix(handle, "@") { + handle = "@" + handle + } + idStr := strconv.FormatInt(chat.ID, 10) + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: idStr, + Name: name, + Handle: handle, + Metadata: map[string]any{ + "chat_id": idStr, + "username": chat.UserName, + }, + }, nil +} + +func (a *TelegramAdapter) resolveTelegramGroup(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { + chatID, superGroupUsername := parseTelegramChatInput(input) + if chatID == 0 && superGroupUsername == "" { + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry group: invalid input %q", input) + } + config := tgbotapi.ChatInfoConfig{ + ChatConfig: tgbotapi.ChatConfig{ChatID: chatID, SuperGroupUsername: superGroupUsername}, + } + chat, err := bot.GetChat(config) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat: %w", err) + } + idStr := strconv.FormatInt(chat.ID, 10) + name := strings.TrimSpace(chat.Title) + if name == "" { + name = strings.TrimSpace(chat.FirstName + " " + chat.LastName) + } + handle := strings.TrimSpace(chat.UserName) + if handle != "" && !strings.HasPrefix(handle, "@") { + handle = "@" + handle + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryGroup, + ID: idStr, + Name: name, + Handle: handle, + Metadata: map[string]any{"chat_id": idStr, "type": chat.Type}, + }, nil +} + +// parseTelegramChatInput parses input as chat_id (numeric) or @channel_username. Returns (chatID, superGroupUsername). +func parseTelegramChatInput(input string) (chatID int64, superGroupUsername string) { + input = strings.TrimSpace(input) + if input == "" { + return 0, "" + } + if strings.HasPrefix(input, "@") { + return 0, input + } + id, err := strconv.ParseInt(input, 10, 64) + if err != nil { + return 0, "" + } + return id, "" +} + +// parseTelegramUserInput parses input as "chat_id" (private chat) or "chat_id:user_id". Returns (chatID, userID); userID 0 means resolve as private chat. +func parseTelegramUserInput(input string) (chatID, userID int64) { + input = strings.TrimSpace(input) + if input == "" { + return 0, 0 + } + if idx := strings.Index(input, ":"); idx >= 0 { + left := strings.TrimSpace(input[:idx]) + right := strings.TrimSpace(input[idx+1:]) + cid, err1 := strconv.ParseInt(left, 10, 64) + uid, err2 := strconv.ParseInt(right, 10, 64) + if err1 == nil && err2 == nil && cid != 0 && uid != 0 { + return cid, uid + } + } + id, err := strconv.ParseInt(input, 10, 64) + if err != nil { + return 0, 0 + } + return id, 0 +} + +func telegramUserToEntry(u *tgbotapi.User) channel.DirectoryEntry { + if u == nil { + return channel.DirectoryEntry{Kind: channel.DirectoryEntryUser} + } + name := strings.TrimSpace(u.FirstName + " " + u.LastName) + handle := strings.TrimSpace(u.UserName) + if handle != "" && !strings.HasPrefix(handle, "@") { + handle = "@" + handle + } + idStr := strconv.FormatInt(u.ID, 10) + meta := map[string]any{"user_id": idStr} + if u.UserName != "" { + meta["username"] = u.UserName + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: idStr, + Name: name, + Handle: handle, + Metadata: meta, + } +} diff --git a/internal/channel/adapters/telegram/directory_test.go b/internal/channel/adapters/telegram/directory_test.go new file mode 100644 index 00000000..ebaf313e --- /dev/null +++ b/internal/channel/adapters/telegram/directory_test.go @@ -0,0 +1,99 @@ +package telegram + +import ( + "strconv" + "testing" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +func Test_directoryLimit(t *testing.T) { + tests := []struct { + name string + n int + want int + }{ + {"zero", 0, defaultDirectoryLimit}, + {"negative", -1, defaultDirectoryLimit}, + {"one", 1, 1}, + {"default", defaultDirectoryLimit, defaultDirectoryLimit}, + {"over max", maxDirectoryLimit + 100, maxDirectoryLimit}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := directoryLimit(tt.n); got != tt.want { + t.Errorf("directoryLimit() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseTelegramChatInput(t *testing.T) { + tests := []struct { + input string + wantID int64 + wantUsername string + }{ + {"123456789", 123456789, ""}, + {" -100123 ", -100123, ""}, + {"@channel", 0, "@channel"}, + {" @supergroup ", 0, "@supergroup"}, + {"", 0, ""}, + {"abc", 0, ""}, + } + for _, tt := range tests { + chatID, username := parseTelegramChatInput(tt.input) + if chatID != tt.wantID || username != tt.wantUsername { + t.Errorf("parseTelegramChatInput(%q) = %d, %q; want %d, %q", tt.input, chatID, username, tt.wantID, tt.wantUsername) + } + } +} + +func Test_parseTelegramUserInput(t *testing.T) { + tests := []struct { + input string + wantChat int64 + wantUser int64 + }{ + {"12345", 12345, 0}, + {" -100 ", -100, 0}, + {"12345:67890", 12345, 67890}, + {" -100 : 200 ", -100, 200}, + {"", 0, 0}, + {"abc", 0, 0}, + {"1:2:3", 0, 0}, + } + for _, tt := range tests { + chatID, userID := parseTelegramUserInput(tt.input) + if chatID != tt.wantChat || userID != tt.wantUser { + t.Errorf("parseTelegramUserInput(%q) = %d, %d; want %d, %d", tt.input, chatID, userID, tt.wantChat, tt.wantUser) + } + } +} + +func Test_telegramUserToEntry(t *testing.T) { + u := &tgbotapi.User{ID: 123, UserName: "alice", FirstName: "Alice", LastName: "Smith"} + e := telegramUserToEntry(u) + if e.Kind != channel.DirectoryEntryUser { + t.Errorf("Kind = %q", e.Kind) + } + if e.ID != strconv.FormatInt(123, 10) { + t.Errorf("ID = %q", e.ID) + } + if e.Name != "Alice Smith" { + t.Errorf("Name = %q", e.Name) + } + if e.Handle != "@alice" { + t.Errorf("Handle = %q", e.Handle) + } + if e.Metadata["user_id"] != "123" || e.Metadata["username"] != "alice" { + t.Errorf("Metadata = %+v", e.Metadata) + } + // nil user + e2 := telegramUserToEntry(nil) + if e2.Kind != channel.DirectoryEntryUser || e2.ID != "" { + t.Errorf("telegramUserToEntry(nil) = %+v", e2) + } +} diff --git a/internal/channel/adapters/telegram/logger.go b/internal/channel/adapters/telegram/logger.go new file mode 100644 index 00000000..e2fa7645 --- /dev/null +++ b/internal/channel/adapters/telegram/logger.go @@ -0,0 +1,19 @@ +package telegram + +import ( + "fmt" + "log/slog" +) + +// slogBotLogger adapts slog.Logger to tgbotapi.BotLogger so library logs go through slog. +type slogBotLogger struct { + log *slog.Logger +} + +func (s *slogBotLogger) Println(v ...interface{}) { + s.log.Warn(fmt.Sprint(v...)) +} + +func (s *slogBotLogger) Printf(format string, v ...interface{}) { + s.log.Warn(fmt.Sprintf(format, v...)) +} diff --git a/internal/channel/adapters/telegram/logger_test.go b/internal/channel/adapters/telegram/logger_test.go new file mode 100644 index 00000000..dbcbb2d6 --- /dev/null +++ b/internal/channel/adapters/telegram/logger_test.go @@ -0,0 +1,46 @@ +package telegram + +import ( + "bytes" + "log/slog" + "testing" +) + +func TestSlogBotLogger_Println(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + log := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + w := &slogBotLogger{log: log} + + w.Println("hello") + if !bytes.Contains(buf.Bytes(), []byte("level=WARN")) { + t.Fatalf("expected WARN level in output: %s", buf.String()) + } + if !bytes.Contains(buf.Bytes(), []byte("hello")) { + t.Fatalf("expected message in output: %s", buf.String()) + } + + buf.Reset() + w.Println("err", 123) + out := buf.String() + if !bytes.Contains(buf.Bytes(), []byte("err")) || !bytes.Contains(buf.Bytes(), []byte("123")) { + t.Fatalf("expected err and 123 in output: %s", out) + } +} + +func TestSlogBotLogger_Printf(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + log := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + w := &slogBotLogger{log: log} + + w.Printf("retrying in %d seconds...", 3) + if !bytes.Contains(buf.Bytes(), []byte("level=WARN")) { + t.Fatalf("expected WARN level: %s", buf.String()) + } + if !bytes.Contains(buf.Bytes(), []byte("retrying in 3 seconds")) { + t.Fatalf("expected formatted message: %s", buf.String()) + } +} diff --git a/internal/channel/adapters/telegram/markdown.go b/internal/channel/adapters/telegram/markdown.go new file mode 100644 index 00000000..a3fba424 --- /dev/null +++ b/internal/channel/adapters/telegram/markdown.go @@ -0,0 +1,210 @@ +package telegram + +import ( + "fmt" + "regexp" + "strings" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + codeBlockPlaceholder = "\x00CB" + inlineCodePlaceholder = "\x00IC" +) + +var ( + reCodeBlockFence = regexp.MustCompile("(?s)```(\\w*)\\n?(.*?)```") + reInlineCode = regexp.MustCompile("`([^`\\n]+?)`") + reBold = regexp.MustCompile(`\*\*(.+?)\*\*`) + reStrike = regexp.MustCompile(`~~(.+?)~~`) + reLink = regexp.MustCompile(`\[([^\]]+?)\]\(([^)]+?)\)`) + reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) + reListBullet = regexp.MustCompile(`(?m)^(\s*)[-+]\s`) + reItalic = regexp.MustCompile(`\*([^*\n]+?)\*`) +) + +// formatTelegramOutput converts standard markdown to Telegram-compatible HTML +// when the message format is markdown. Returns the formatted text and the +// Telegram parse mode to use. +func formatTelegramOutput(text string, format channel.MessageFormat) (string, string) { + if format == channel.MessageFormatMarkdown && strings.TrimSpace(text) != "" { + return markdownToTelegramHTML(text), tgbotapi.ModeHTML + } + return text, "" +} + +// markdownToTelegramHTML converts standard markdown to Telegram-compatible HTML. +// +// Supported conversions: +// - Fenced code blocks (```lang ... ```) →

+//   - Inline code (`code`) → 
+//   - Bold (**text**) → 
+//   - Italic (*text*) → 
+//   - Strikethrough (~~text~~) → 
+//   - Links ([text](url)) → 
+//   - Headings (# text) → 
+//   - Unordered lists (- item) → bullet
+//   - Block quotes (> text) → 
+func markdownToTelegramHTML(text string) string { + if strings.TrimSpace(text) == "" { + return text + } + + // Split by fenced code blocks (``` ... ```). + // Even-indexed segments are normal text, odd-indexed are code content. + segments := splitCodeBlocks(text) + var buf strings.Builder + for i, seg := range segments { + if i%2 == 0 { + buf.WriteString(convertInlineMarkdown(seg)) + } else { + lang, code := extractCodeBlockLang(seg) + escaped := telegramEscapeHTML(strings.TrimRight(code, "\n")) + if lang != "" { + buf.WriteString(fmt.Sprintf("
%s
", lang, escaped)) + } else { + buf.WriteString("
" + escaped + "
") + } + } + } + return strings.TrimSpace(buf.String()) +} + +// splitCodeBlocks splits text by triple-backtick fences. +// Returns alternating [normal, code, normal, code, ...] segments. +func splitCodeBlocks(text string) []string { + const fence = "```" + var segments []string + for { + start := strings.Index(text, fence) + if start < 0 { + segments = append(segments, text) + break + } + segments = append(segments, text[:start]) + rest := text[start+len(fence):] + end := strings.Index(rest, fence) + if end < 0 { + // Unclosed code block: treat remainder as normal text. + segments = append(segments, text[start:]) + // Remove the last normal segment and replace with full remainder. + segments[len(segments)-2] = segments[len(segments)-2] + segments[len(segments)-1] + segments = segments[:len(segments)-1] + break + } + segments = append(segments, rest[:end]) + text = rest[end+len(fence):] + } + return segments +} + +// extractCodeBlockLang separates the optional language tag from code content. +func extractCodeBlockLang(block string) (string, string) { + idx := strings.IndexByte(block, '\n') + if idx < 0 { + // Single line: check if it looks like a language tag. + trimmed := strings.TrimSpace(block) + if trimmed != "" && !strings.Contains(trimmed, " ") && len(trimmed) <= 20 { + return trimmed, "" + } + return "", block + } + firstLine := strings.TrimSpace(block[:idx]) + rest := block[idx+1:] + if firstLine != "" && !strings.Contains(firstLine, " ") && len(firstLine) <= 20 { + return firstLine, rest + } + // No language tag: strip leading newline from content. + return "", strings.TrimLeft(block, "\n") +} + +// convertInlineMarkdown converts inline markdown formatting to Telegram HTML. +func convertInlineMarkdown(text string) string { + if strings.TrimSpace(text) == "" { + return text + } + + // Protect inline code spans from further processing. + var inlineCodes []string + text = reInlineCode.ReplaceAllStringFunc(text, func(match string) string { + idx := len(inlineCodes) + inlineCodes = append(inlineCodes, match) + return fmt.Sprintf("%s%d\x00", inlineCodePlaceholder, idx) + }) + + // Escape HTML entities. + text = telegramEscapeHTML(text) + + // Bold: **text** → text (must run before italic). + text = reBold.ReplaceAllString(text, "$1") + + // Strikethrough: ~~text~~ → text. + text = reStrike.ReplaceAllString(text, "$1") + + // Links: [text](url) →
text. + text = reLink.ReplaceAllString(text, `$1`) + + // Headings: # text → bold line. + text = reHeading.ReplaceAllString(text, "$1") + + // Unordered lists: - item / + item → bullet. + text = reListBullet.ReplaceAllString(text, "${1}• ") + + // Italic: *text* → text (after bold, so ** is already consumed). + text = reItalic.ReplaceAllString(text, "$1") + + // Block quotes: > text →
. + text = convertBlockquotes(text) + + // Restore inline code spans. + for i, original := range inlineCodes { + sub := reInlineCode.FindStringSubmatch(original) + content := "" + if len(sub) >= 2 { + content = sub[1] + } + placeholder := fmt.Sprintf("%s%d\x00", inlineCodePlaceholder, i) + text = strings.Replace(text, placeholder, ""+telegramEscapeHTML(content)+"", 1) + } + + return text +} + +// convertBlockquotes converts markdown block quotes to Telegram HTML blockquotes. +// After HTML escaping, ">" becomes ">", so we match the escaped form. +func convertBlockquotes(text string) string { + lines := strings.Split(text, "\n") + var result []string + var quoteLines []string + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "> ") || trimmed == ">" { + content := strings.TrimPrefix(trimmed, "> ") + if content == ">" { + content = "" + } + quoteLines = append(quoteLines, content) + } else { + if len(quoteLines) > 0 { + result = append(result, "
"+strings.Join(quoteLines, "\n")+"
") + quoteLines = nil + } + result = append(result, line) + } + } + if len(quoteLines) > 0 { + result = append(result, "
"+strings.Join(quoteLines, "\n")+"
") + } + return strings.Join(result, "\n") +} + +// telegramEscapeHTML escapes characters that are special in HTML. +func telegramEscapeHTML(text string) string { + text = strings.ReplaceAll(text, "&", "&") + text = strings.ReplaceAll(text, "<", "<") + text = strings.ReplaceAll(text, ">", ">") + return text +} diff --git a/internal/channel/adapters/telegram/markdown_test.go b/internal/channel/adapters/telegram/markdown_test.go new file mode 100644 index 00000000..e2d1c790 --- /dev/null +++ b/internal/channel/adapters/telegram/markdown_test.go @@ -0,0 +1,209 @@ +package telegram + +import ( + "strings" + "testing" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +func TestMarkdownToTelegramHTML(t *testing.T) { + tests := []struct { + name string + input string + contains []string + absent []string + }{ + { + name: "bold", + input: "this is **bold** text", + contains: []string{"bold"}, + absent: []string{"**"}, + }, + { + name: "italic", + input: "this is *italic* text", + contains: []string{"italic"}, + }, + { + name: "bold and italic", + input: "**bold** and *italic*", + contains: []string{"bold", "italic"}, + }, + { + name: "strikethrough", + input: "this is ~~deleted~~ text", + contains: []string{"deleted"}, + absent: []string{"~~"}, + }, + { + name: "inline code", + input: "use `fmt.Println` here", + contains: []string{"fmt.Println"}, + absent: []string{"`fmt.Println`"}, + }, + { + name: "link", + input: "visit [Google](https://google.com)", + contains: []string{`Google`}, + }, + { + name: "heading", + input: "# Title\nsome text", + contains: []string{"Title"}, + absent: []string{"# Title"}, + }, + { + name: "unordered list", + input: "- first\n- second\n+ third", + contains: []string{"• first", "• second", "• third"}, + }, + { + name: "fenced code block", + input: "```go\nfmt.Println(\"hello\")\n```", + contains: []string{`
`, "fmt.Println", "
"}, + }, + { + name: "fenced code block without language", + input: "```\nplain code\n```", + contains: []string{"
plain code
"}, + }, + { + name: "html entities escaped", + input: "a < b && c > d", + contains: []string{"<", "&&", ">"}, + absent: []string{"< b", "> d"}, + }, + { + name: "code block preserves content", + input: "```\n**not bold** \n```", + contains: []string{"**not bold**", "<tag>"}, + absent: []string{"", ""}, + }, + { + name: "inline code preserves content", + input: "use `**not bold**` inline", + contains: []string{"**not bold**"}, + absent: []string{""}, + }, + { + name: "blockquote", + input: "> quoted line\n> another line", + contains: []string{"
"}, + }, + { + name: "empty input", + input: "", + contains: nil, + }, + { + name: "plain text no conversion", + input: "just plain text here", + contains: []string{"just plain text here"}, + }, + { + name: "link with ampersand in url", + input: "[search](https://example.com?a=1&b=2)", + contains: []string{`search`}, + }, + { + name: "bold inside link", + input: "**[click here](https://example.com)**", + contains: []string{"", `click here`, ""}, + }, + { + name: "mixed formatting", + input: "# Summary\n\n**Hello** world, visit [docs](https://docs.io).\n\n- item one\n- item two\n\n```python\nprint(\"hi\")\n```", + contains: []string{"Summary", "Hello", `docs`, "• item one", `class="language-python"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := markdownToTelegramHTML(tt.input) + for _, want := range tt.contains { + if !strings.Contains(result, want) { + t.Errorf("expected result to contain %q, got:\n%s", want, result) + } + } + for _, absent := range tt.absent { + if strings.Contains(result, absent) { + t.Errorf("expected result NOT to contain %q, got:\n%s", absent, result) + } + } + }) + } +} + +func TestFormatTelegramOutput(t *testing.T) { + tests := []struct { + name string + text string + format channel.MessageFormat + wantMode string + wantContains string + }{ + { + name: "markdown format returns html mode", + text: "**bold**", + format: channel.MessageFormatMarkdown, + wantMode: tgbotapi.ModeHTML, + wantContains: "bold", + }, + { + name: "plain format returns empty mode", + text: "hello", + format: channel.MessageFormatPlain, + wantMode: "", + }, + { + name: "empty text returns empty mode", + text: "", + format: channel.MessageFormatMarkdown, + wantMode: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + text, mode := formatTelegramOutput(tt.text, tt.format) + if mode != tt.wantMode { + t.Errorf("expected mode %q, got %q", tt.wantMode, mode) + } + if tt.wantContains != "" && !strings.Contains(text, tt.wantContains) { + t.Errorf("expected text to contain %q, got %q", tt.wantContains, text) + } + }) + } +} + +func TestSplitCodeBlocks(t *testing.T) { + tests := []struct { + name string + input string + want int // expected number of segments + }{ + {name: "no code blocks", input: "hello world", want: 1}, + {name: "one code block", input: "before```code```after", want: 3}, + {name: "two code blocks", input: "a```b```c```d```e", want: 5}, + {name: "unclosed code block", input: "before```unclosed", want: 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + segments := splitCodeBlocks(tt.input) + if len(segments) != tt.want { + t.Errorf("expected %d segments, got %d: %v", tt.want, len(segments), segments) + } + }) + } +} + +func TestTelegramEscapeHTML(t *testing.T) { + input := `a < b & c > d` + result := telegramEscapeHTML(input) + if result != "a < b & c > d" { + t.Errorf("unexpected escape result: %s", result) + } +} diff --git a/internal/channel/adapters/telegram/stream.go b/internal/channel/adapters/telegram/stream.go new file mode 100644 index 00000000..771d565a --- /dev/null +++ b/internal/channel/adapters/telegram/stream.go @@ -0,0 +1,211 @@ +package telegram + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "sync/atomic" + "time" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +const telegramStreamEditThrottle = 350 * time.Millisecond + +type telegramOutboundStream struct { + adapter *TelegramAdapter + cfg channel.ChannelConfig + target string + reply *channel.ReplyRef + parseMode string + closed atomic.Bool + mu sync.Mutex + buf strings.Builder + streamChatID int64 + streamMsgID int + lastEdited string + lastEditedAt time.Time +} + +func (s *telegramOutboundStream) getBotAndReply(ctx context.Context) (bot *tgbotapi.BotAPI, replyTo int, err error) { + telegramCfg, err := parseConfig(s.cfg.Credentials) + if err != nil { + return nil, 0, err + } + bot, err = s.adapter.getOrCreateBot(telegramCfg.BotToken, s.cfg.ID) + if err != nil { + return nil, 0, err + } + replyTo = parseReplyToMessageID(s.reply) + return bot, replyTo, nil +} + +func (s *telegramOutboundStream) ensureStreamMessage(ctx context.Context, text string) error { + s.mu.Lock() + if s.streamMsgID != 0 { + s.mu.Unlock() + return nil + } + bot, replyTo, err := s.getBotAndReply(ctx) + if err != nil { + s.mu.Unlock() + return err + } + if strings.TrimSpace(text) == "" { + text = "..." + } + chatID, msgID, err := sendTelegramTextReturnMessage(bot, s.target, text, replyTo, s.parseMode) + if err != nil { + s.mu.Unlock() + return err + } + s.streamChatID = chatID + s.streamMsgID = msgID + s.lastEdited = text + s.lastEditedAt = time.Now() + s.mu.Unlock() + return nil +} + +func (s *telegramOutboundStream) editStreamMessage(ctx context.Context, text string) error { + s.mu.Lock() + chatID := s.streamChatID + msgID := s.streamMsgID + lastEdited := s.lastEdited + lastEditedAt := s.lastEditedAt + s.mu.Unlock() + if msgID == 0 { + return nil + } + if strings.TrimSpace(text) == lastEdited { + return nil + } + if time.Since(lastEditedAt) < telegramStreamEditThrottle && !strings.Contains(text, "\n") { + return nil + } + bot, _, err := s.getBotAndReply(ctx) + if err != nil { + return err + } + if err := editTelegramMessageText(bot, chatID, msgID, text, s.parseMode); err != nil { + return err + } + s.mu.Lock() + s.lastEdited = text + s.lastEditedAt = time.Now() + s.mu.Unlock() + return nil +} + +func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.adapter == nil { + return fmt.Errorf("telegram stream not configured") + } + if s.closed.Load() { + return fmt.Errorf("telegram stream is closed") + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + switch event.Type { + case channel.StreamEventStatus: + return nil + case channel.StreamEventDelta: + if event.Delta == "" { + return nil + } + s.mu.Lock() + s.buf.WriteString(event.Delta) + content := s.buf.String() + s.mu.Unlock() + if err := s.ensureStreamMessage(ctx, content); err != nil { + return err + } + return s.editStreamMessage(ctx, content) + case channel.StreamEventFinal: + if event.Final == nil || event.Final.Message.IsEmpty() { + s.mu.Lock() + finalText := strings.TrimSpace(s.buf.String()) + s.mu.Unlock() + if finalText != "" { + _ = s.ensureStreamMessage(ctx, finalText) + _ = s.editStreamMessage(ctx, finalText) + } + return nil + } + msg := event.Final.Message + finalText := strings.TrimSpace(msg.PlainText()) + s.mu.Lock() + if finalText == "" { + finalText = strings.TrimSpace(s.buf.String()) + } + s.mu.Unlock() + // Convert markdown to Telegram HTML for the final message. + formatted, pm := formatTelegramOutput(finalText, msg.Format) + if pm != "" { + s.mu.Lock() + s.parseMode = pm + s.mu.Unlock() + finalText = formatted + } + if err := s.ensureStreamMessage(ctx, finalText); err != nil { + return err + } + if err := s.editStreamMessage(ctx, finalText); err != nil { + return err + } + if len(msg.Attachments) > 0 { + replyTo := parseReplyToMessageID(s.reply) + telegramCfg, err := parseConfig(s.cfg.Credentials) + if err != nil { + return err + } + bot, err := s.adapter.getOrCreateBot(telegramCfg.BotToken, s.cfg.ID) + if err != nil { + return err + } + parseMode := resolveTelegramParseMode(msg.Format) + for i, att := range msg.Attachments { + rto := replyTo + if i > 0 { + rto = 0 + } + if err := sendTelegramAttachment(bot, s.target, att, "", rto, parseMode); err != nil && s.adapter.logger != nil { + s.adapter.logger.Error("stream final attachment failed", slog.String("config_id", s.cfg.ID), slog.Any("error", err)) + } + } + } + return nil + case channel.StreamEventError: + errText := strings.TrimSpace(event.Error) + if errText == "" { + return nil + } + display := "Error: " + errText + if err := s.ensureStreamMessage(ctx, display); err != nil { + return err + } + return s.editStreamMessage(ctx, display) + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } +} + +func (s *telegramOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} diff --git a/internal/channel/adapters/telegram/stream_test.go b/internal/channel/adapters/telegram/stream_test.go new file mode 100644 index 00000000..56ea4032 --- /dev/null +++ b/internal/channel/adapters/telegram/stream_test.go @@ -0,0 +1,119 @@ +package telegram + +import ( + "context" + "strings" + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +func TestTelegramOutboundStream_CloseNil(t *testing.T) { + t.Parallel() + + var s *telegramOutboundStream + ctx := context.Background() + if err := s.Close(ctx); err != nil { + t.Fatalf("Close on nil stream should return nil: %v", err) + } +} + +func TestTelegramOutboundStream_PushClosed(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + s.closed.Store(true) + + ctx := context.Background() + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "x"}) + if err == nil { + t.Fatal("Push on closed stream should return error") + } + if !strings.Contains(err.Error(), "closed") { + t.Fatalf("expected closed error: %v", err) + } +} + +func TestTelegramOutboundStream_PushStatusNoOp(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + + ctx := context.Background() + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventStatus}) + if err != nil { + t.Fatalf("StreamEventStatus should be no-op: %v", err) + } +} + +func TestTelegramOutboundStream_PushNilAdapter(t *testing.T) { + t.Parallel() + + s := &telegramOutboundStream{adapter: nil} + ctx := context.Background() + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "x"}) + if err == nil { + t.Fatal("Push with nil adapter should return error") + } + if !strings.Contains(err.Error(), "not configured") { + t.Fatalf("expected not configured error: %v", err) + } +} + +func TestTelegramOutboundStream_PushUnsupportedEventType(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx := context.Background() + + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventType("unknown")}) + if err == nil { + t.Fatal("Push with unknown event type should return error") + } + if !strings.Contains(err.Error(), "unsupported") { + t.Fatalf("expected unsupported error: %v", err) + } +} + +func TestTelegramOutboundStream_PushEmptyDeltaNoOp(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx := context.Background() + + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: ""}) + if err != nil { + t.Fatalf("empty delta should be no-op: %v", err) + } +} + +func TestTelegramOutboundStream_PushErrorEventEmptyNoOp(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx := context.Background() + + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventError, Error: ""}) + if err != nil { + t.Fatalf("empty error event should be no-op: %v", err) + } +} + +func TestTelegramOutboundStream_CloseContextCanceled(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := s.Close(ctx) + if err != context.Canceled { + t.Fatalf("Close with canceled context should return context.Canceled: %v", err) + } +} diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index aa19ebe4..9f76b655 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -15,6 +15,8 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/common" ) +const telegramMaxMessageLength = 4096 + // TelegramAdapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Telegram. type TelegramAdapter struct { logger *slog.Logger @@ -27,10 +29,12 @@ func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { if log == nil { log = slog.Default() } - return &TelegramAdapter{ + adapter := &TelegramAdapter{ logger: log.With(slog.String("adapter", "telegram")), bots: make(map[string]*tgbotapi.BotAPI), } + _ = tgbotapi.SetLogger(&slogBotLogger{log: adapter.logger}) + return adapter } func (a *TelegramAdapter) getOrCreateBot(token, configID string) (*tgbotapi.BotAPI, error) { @@ -67,11 +71,13 @@ func (a *TelegramAdapter) Descriptor() channel.Descriptor { Type: Type, DisplayName: "Telegram", Capabilities: channel.ChannelCapabilities{ - Text: true, - Markdown: true, - Reply: true, - Attachments: true, - Media: true, + Text: true, + Markdown: true, + Reply: true, + Attachments: true, + Media: true, + Streaming: true, + BlockStreaming: true, }, ConfigSchema: channel.ConfigSchema{ Version: 1, @@ -277,7 +283,7 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m return fmt.Errorf("message is required") } text := strings.TrimSpace(msg.Message.PlainText()) - parseMode := resolveTelegramParseMode(msg.Message.Format) + text, parseMode := formatTelegramOutput(text, msg.Message.Format) replyTo := parseReplyToMessageID(msg.Message.Reply) if len(msg.Message.Attachments) > 0 { usedCaption := false @@ -306,6 +312,28 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m return sendTelegramText(bot, to, text, replyTo, parseMode) } +// OpenStream opens a Telegram streaming session. +// The adapter sends one message then edits it in place as deltas arrive (editMessageText), +// avoiding one message per delta and rate limits. +func (a *TelegramAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("telegram target is required") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return &telegramOutboundStream{ + adapter: a, + cfg: cfg, + target: target, + reply: opts.Reply, + parseMode: "", + }, nil +} + func resolveTelegramSender(msg *tgbotapi.Message) (string, string, map[string]string) { attrs := map[string]string{} if msg == nil { @@ -376,36 +404,69 @@ func parseReplyToMessageID(reply *channel.ReplyRef) int { } func sendTelegramText(bot *tgbotapi.BotAPI, target string, text string, replyTo int, parseMode string) error { + _, _, err := sendTelegramTextReturnMessage(bot, target, text, replyTo, parseMode) + return err +} + +// sendTelegramTextReturnMessage sends a text message and returns the chat ID and message ID for later editing. +func sendTelegramTextReturnMessage(bot *tgbotapi.BotAPI, target string, text string, replyTo int, parseMode string) (chatID int64, messageID int, err error) { + var sent tgbotapi.Message if strings.HasPrefix(target, "@") { message := tgbotapi.NewMessageToChannel(target, text) message.ParseMode = parseMode if replyTo > 0 { message.ReplyToMessageID = replyTo } - _, err := bot.Send(message) - return err + sent, err = bot.Send(message) + if err != nil { + return 0, 0, err + } + } else { + chatID, err = strconv.ParseInt(target, 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("telegram target must be @username or chat_id") + } + message := tgbotapi.NewMessage(chatID, text) + message.ParseMode = parseMode + if replyTo > 0 { + message.ReplyToMessageID = replyTo + } + sent, err = bot.Send(message) + if err != nil { + return 0, 0, err + } } - chatID, err := strconv.ParseInt(target, 10, 64) - if err != nil { - return fmt.Errorf("telegram target must be @username or chat_id") + if sent.Chat != nil { + chatID = sent.Chat.ID } - message := tgbotapi.NewMessage(chatID, text) - message.ParseMode = parseMode - if replyTo > 0 { - message.ReplyToMessageID = replyTo + messageID = sent.MessageID + return chatID, messageID, nil +} + +func editTelegramMessageText(bot *tgbotapi.BotAPI, chatID int64, messageID int, text string, parseMode string) error { + if len(text) > telegramMaxMessageLength { + text = text[:telegramMaxMessageLength-3] + "..." } - _, err = bot.Send(message) + edit := tgbotapi.NewEditMessageText(chatID, messageID, text) + edit.ParseMode = parseMode + _, err := bot.Send(edit) return err } func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Attachment, caption string, replyTo int, parseMode string) error { - if strings.TrimSpace(att.URL) == "" { - return fmt.Errorf("attachment url is required") + urlRef := strings.TrimSpace(att.URL) + keyRef := strings.TrimSpace(att.PlatformKey) + sourcePlatform := strings.TrimSpace(att.SourcePlatform) + if urlRef == "" && keyRef == "" { + return fmt.Errorf("attachment reference is required") } if strings.TrimSpace(caption) == "" && strings.TrimSpace(att.Caption) != "" { caption = strings.TrimSpace(att.Caption) } - file := tgbotapi.FileURL(att.URL) + file := tgbotapi.RequestFileData(tgbotapi.FileURL(urlRef)) + if keyRef != "" && (sourcePlatform == "" || strings.EqualFold(sourcePlatform, Type.String())) { + file = tgbotapi.FileID(keyRef) + } isChannel := strings.HasPrefix(target, "@") switch att.Type { case channel.AttachmentImage: @@ -668,12 +729,14 @@ func (a *TelegramAdapter) buildTelegramAttachment(bot *tgbotapi.BotAPI, attType } } att := channel.Attachment{ - Type: attType, - URL: strings.TrimSpace(url), - Name: strings.TrimSpace(name), - Mime: strings.TrimSpace(mime), - Size: size, - Metadata: map[string]any{}, + Type: attType, + URL: strings.TrimSpace(url), + PlatformKey: strings.TrimSpace(fileID), + SourcePlatform: Type.String(), + Name: strings.TrimSpace(name), + Mime: strings.TrimSpace(mime), + Size: size, + Metadata: map[string]any{}, } if fileID != "" { att.Metadata["file_id"] = fileID diff --git a/internal/channel/adapters/telegram/telegram_test.go b/internal/channel/adapters/telegram/telegram_test.go index 18c7492b..b3c4ca46 100644 --- a/internal/channel/adapters/telegram/telegram_test.go +++ b/internal/channel/adapters/telegram/telegram_test.go @@ -1,9 +1,12 @@ package telegram import ( + "context" + "strings" "testing" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/memohai/memoh/internal/channel" ) func TestResolveTelegramSender(t *testing.T) { @@ -63,3 +66,241 @@ func TestIsTelegramBotMentioned(t *testing.T) { } }) } + +func TestTelegramDescriptorIncludesStreaming(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + caps := adapter.Descriptor().Capabilities + if !caps.Streaming { + t.Fatal("expected streaming capability") + } + if !caps.Media { + t.Fatal("expected media capability") + } +} + +func TestBuildTelegramAttachmentIncludesPlatformReference(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + att := adapter.buildTelegramAttachment(nil, channel.AttachmentFile, "file_1", "doc.txt", "text/plain", 10) + if att.PlatformKey != "file_1" { + t.Fatalf("unexpected platform key: %s", att.PlatformKey) + } + if att.SourcePlatform != Type.String() { + t.Fatalf("unexpected source platform: %s", att.SourcePlatform) + } +} + +func TestParseReplyToMessageID(t *testing.T) { + t.Parallel() + + if got := parseReplyToMessageID(nil); got != 0 { + t.Fatalf("nil reply should return 0: %d", got) + } + if got := parseReplyToMessageID(&channel.ReplyRef{}); got != 0 { + t.Fatalf("empty MessageID should return 0: %d", got) + } + if got := parseReplyToMessageID(&channel.ReplyRef{MessageID: " 123 "}); got != 123 { + t.Fatalf("expected 123: %d", got) + } + if got := parseReplyToMessageID(&channel.ReplyRef{MessageID: "abc"}); got != 0 { + t.Fatalf("invalid number should return 0: %d", got) + } +} + +func TestResolveTelegramParseMode(t *testing.T) { + t.Parallel() + + if got := resolveTelegramParseMode(channel.MessageFormatMarkdown); got != tgbotapi.ModeMarkdown { + t.Fatalf("markdown should return ModeMarkdown: %s", got) + } + if got := resolveTelegramParseMode(channel.MessageFormatPlain); got != "" { + t.Fatalf("plain should return empty: %s", got) + } + if got := resolveTelegramParseMode(channel.MessageFormatRich); got != "" { + t.Fatalf("rich should return empty: %s", got) + } +} + +func TestBuildTelegramReplyRef(t *testing.T) { + t.Parallel() + + if buildTelegramReplyRef(nil, "123") != nil { + t.Fatal("nil msg should return nil") + } + msg := &tgbotapi.Message{} + if buildTelegramReplyRef(msg, "123") != nil { + t.Fatal("msg without ReplyToMessage should return nil") + } + msg.ReplyToMessage = &tgbotapi.Message{MessageID: 42} + ref := buildTelegramReplyRef(msg, " -100 ") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.MessageID != "42" || ref.Target != "-100" { + t.Fatalf("unexpected ref: %+v", ref) + } +} + +func TestPickTelegramPhoto(t *testing.T) { + t.Parallel() + + if got := pickTelegramPhoto(nil); got.FileID != "" { + t.Fatalf("nil should return zero: %+v", got) + } + if got := pickTelegramPhoto([]tgbotapi.PhotoSize{}); got.FileID != "" { + t.Fatalf("empty slice should return zero: %+v", got) + } + one := tgbotapi.PhotoSize{FileID: "a", FileSize: 100, Width: 10, Height: 10} + if got := pickTelegramPhoto([]tgbotapi.PhotoSize{one}); got.FileID != "a" { + t.Fatalf("single photo should return it: %+v", got) + } + photos := []tgbotapi.PhotoSize{ + {FileID: "small", FileSize: 100, Width: 100, Height: 100}, + {FileID: "large", FileSize: 500, Width: 200, Height: 200}, + } + if got := pickTelegramPhoto(photos); got.FileID != "large" { + t.Fatalf("should pick largest by size: %+v", got) + } +} + +func TestTelegramAdapter_Type(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + if adapter.Type() != Type { + t.Fatalf("Type should return telegram: %s", adapter.Type()) + } +} + +func TestTelegramAdapter_OpenStreamEmptyTarget(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + ctx := context.Background() + cfg := channel.ChannelConfig{} + _, err := adapter.OpenStream(ctx, cfg, "", channel.StreamOptions{}) + if err == nil { + t.Fatal("empty target should return error") + } + if !strings.Contains(err.Error(), "target") { + t.Fatalf("expected target in error: %v", err) + } +} + +func TestResolveTelegramSender_SenderChat(t *testing.T) { + t.Parallel() + + msg := &tgbotapi.Message{ + SenderChat: &tgbotapi.Chat{ID: 456, UserName: "group", Title: "My Group"}, + } + externalID, displayName, attrs := resolveTelegramSender(msg) + if externalID != "456" { + t.Fatalf("unexpected externalID: %s", externalID) + } + if displayName != "My Group" { + t.Fatalf("unexpected displayName: %s", displayName) + } + if attrs["sender_chat_id"] != "456" || attrs["sender_chat_username"] != "group" { + t.Fatalf("unexpected attrs: %#v", attrs) + } +} + +func TestBuildTelegramAudio(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramAudio("@channel", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@channel" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramAudio("invalid", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } + if !strings.Contains(err.Error(), "chat_id") { + t.Fatalf("expected chat_id in error: %v", err) + } +} + +func TestBuildTelegramVoice(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramVoice("@ch", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@ch" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramVoice("x", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } +} + +func TestBuildTelegramVideo(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramVideo("@ch", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@ch" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramVideo("bad", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } +} + +func TestBuildTelegramAnimation(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramAnimation("@ch", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@ch" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramAnimation("x", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } +} + +func TestTelegramAdapter_NormalizeAndResolve(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + norm, err := adapter.NormalizeConfig(map[string]any{"botToken": "t1"}) + if err != nil { + t.Fatalf("NormalizeConfig: %v", err) + } + if norm["botToken"] != "t1" { + t.Fatalf("unexpected normalized: %#v", norm) + } + userNorm, err := adapter.NormalizeUserConfig(map[string]any{"username": "u1"}) + if err != nil { + t.Fatalf("NormalizeUserConfig: %v", err) + } + if userNorm["username"] != "u1" { + t.Fatalf("unexpected user config: %#v", userNorm) + } + if got := adapter.NormalizeTarget("https://t.me/x"); got != "@x" { + t.Fatalf("NormalizeTarget: %s", got) + } + target, err := adapter.ResolveTarget(map[string]any{"chat_id": "123"}) + if err != nil { + t.Fatalf("ResolveTarget: %v", err) + } + if target != "123" { + t.Fatalf("ResolveTarget: %s", target) + } +} diff --git a/internal/channel/capabilities.go b/internal/channel/capabilities.go index a1cf379c..2de4c90c 100644 --- a/internal/channel/capabilities.go +++ b/internal/channel/capabilities.go @@ -3,20 +3,20 @@ package channel // ChannelCapabilities describes the feature matrix of a channel type. // It is used by the outbound layer to validate message content before delivery. type ChannelCapabilities struct { - Text bool `json:"text"` - Markdown bool `json:"markdown"` - RichText bool `json:"rich_text"` - Attachments bool `json:"attachments"` - Media bool `json:"media"` - Reactions bool `json:"reactions"` - Buttons bool `json:"buttons"` - Reply bool `json:"reply"` - Threads bool `json:"threads"` - Streaming bool `json:"streaming"` - Polls bool `json:"polls"` - Edit bool `json:"edit"` - Unsend bool `json:"unsend"` - NativeCommands bool `json:"native_commands"` - BlockStreaming bool `json:"block_streaming"` - ChatTypes []string `json:"chat_types,omitempty"` + Text bool `json:"text"` + Markdown bool `json:"markdown"` + RichText bool `json:"rich_text"` + Attachments bool `json:"attachments"` + Media bool `json:"media"` + Reactions bool `json:"reactions"` + Buttons bool `json:"buttons"` + Reply bool `json:"reply"` + Threads bool `json:"threads"` + Streaming bool `json:"streaming"` + Polls bool `json:"polls"` + Edit bool `json:"edit"` + Unsend bool `json:"unsend"` + NativeCommands bool `json:"native_commands"` + BlockStreaming bool `json:"block_streaming"` + ChatTypes []string `json:"chat_types,omitempty"` } diff --git a/internal/channel/connection.go b/internal/channel/connection.go index ffcf75c1..a64a476f 100644 --- a/internal/channel/connection.go +++ b/internal/channel/connection.go @@ -14,6 +14,13 @@ type connectionEntry struct { } func (m *Manager) refresh(ctx context.Context) { + // Serialize refresh calls to prevent concurrent reconcile from starting + // duplicate adapter connections. + if !m.refreshMu.TryLock() { + return + } + defer m.refreshMu.Unlock() + if m.service == nil { return } @@ -75,29 +82,41 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error m.mu.Lock() entry := m.connections[cfg.ID] + + // Config unchanged — nothing to do. if entry != nil && !entry.config.UpdatedAt.Before(cfg.UpdatedAt) { m.mu.Unlock() return nil } + + // Need to stop existing connection before starting a new one. + // Keep the lock to prevent another goroutine from starting a duplicate. + var oldConn Connection if entry != nil { - m.mu.Unlock() + oldConn = entry.connection + delete(m.connections, cfg.ID) + } + m.mu.Unlock() + + if oldConn != nil { if m.logger != nil { m.logger.Info("adapter restart", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } - if err := entry.connection.Stop(ctx); err != nil { + if err := oldConn.Stop(ctx); err != nil { if errors.Is(err, ErrStopNotSupported) { if m.logger != nil { m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } + // Re-insert the entry since we can't restart it. + m.mu.Lock() + if _, exists := m.connections[cfg.ID]; !exists { + m.connections[cfg.ID] = entry + } + m.mu.Unlock() return nil } return err } - m.mu.Lock() - delete(m.connections, cfg.ID) - m.mu.Unlock() - } else { - m.mu.Unlock() } receiver, ok := m.registry.GetReceiver(cfg.ChannelType) @@ -105,6 +124,15 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error return nil } + // Double-check: another goroutine may have already started a connection + // for this config while we were stopping the old one. + m.mu.Lock() + if existing, ok := m.connections[cfg.ID]; ok && existing != nil { + m.mu.Unlock() + return nil + } + m.mu.Unlock() + if m.logger != nil { m.logger.Info("adapter start", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } @@ -116,7 +144,15 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error if err != nil { return err } + m.mu.Lock() + // Final check: if another goroutine raced and inserted first, stop our new + // connection and keep the existing one. + if existing, ok := m.connections[cfg.ID]; ok && existing != nil { + m.mu.Unlock() + _ = conn.Stop(ctx) + return nil + } m.connections[cfg.ID] = &connectionEntry{ config: cfg, connection: conn, diff --git a/internal/channel/directory.go b/internal/channel/directory.go index fa82e98e..32457912 100644 --- a/internal/channel/directory.go +++ b/internal/channel/directory.go @@ -32,5 +32,5 @@ type ChannelDirectoryAdapter interface { ListPeers(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) ListGroups(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) ListGroupMembers(ctx context.Context, cfg ChannelConfig, groupID string, query DirectoryQuery) ([]DirectoryEntry, error) - ResolveTarget(ctx context.Context, cfg ChannelConfig, input string, kind DirectoryEntryKind) (DirectoryEntry, error) + ResolveEntry(ctx context.Context, cfg ChannelConfig, input string, kind DirectoryEntryKind) (DirectoryEntry, error) } diff --git a/internal/channelidentities/service.go b/internal/channel/identities/service.go similarity index 96% rename from internal/channelidentities/service.go rename to internal/channel/identities/service.go index a43b4dc0..b5f0126d 100644 --- a/internal/channelidentities/service.go +++ b/internal/channel/identities/service.go @@ -1,4 +1,4 @@ -package channelidentities +package identities import ( "context" @@ -32,7 +32,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } return &Service{ queries: queries, - logger: log.With(slog.String("service", "channelidentities")), + logger: log.With(slog.String("service", "channel/identities")), } } @@ -48,7 +48,7 @@ func (s *Service) Create(ctx context.Context, channel, channelSubjectID, display } row, err := s.queries.CreateChannelIdentity(ctx, sqlc.CreateChannelIdentityParams{ UserID: pgtype.UUID{}, - Channel: channel, + ChannelType: channel, ChannelSubjectID: channelSubjectID, DisplayName: toPgText(displayName), Metadata: emptyMetadataBytes(), @@ -110,7 +110,7 @@ func (s *Service) ResolveByChannelIdentity(ctx context.Context, channel, channel row, err := s.queries.UpsertChannelIdentityByChannelSubject(ctx, sqlc.UpsertChannelIdentityByChannelSubjectParams{ UserID: pgtype.UUID{}, - Channel: channel, + ChannelType: channel, ChannelSubjectID: channelSubjectID, DisplayName: toPgText(displayName), Metadata: emptyMetadataBytes(), @@ -137,7 +137,7 @@ func (s *Service) UpsertChannelIdentity(ctx context.Context, channel, channelSub } row, err := s.queries.UpsertChannelIdentityByChannelSubject(ctx, sqlc.UpsertChannelIdentityByChannelSubjectParams{ UserID: pgtype.UUID{}, - Channel: channel, + ChannelType: channel, ChannelSubjectID: channelSubjectID, DisplayName: toPgText(displayName), Metadata: metaBytes, @@ -217,7 +217,7 @@ func (s *Service) GetLinkedUserID(ctx context.Context, channelIdentityID string) if !row.UserID.Valid { return "", nil } - return db.UUIDToString(row.UserID), nil + return row.UserID.String(), nil } // LinkChannelIdentityToUser binds a channel identity to a user. @@ -260,12 +260,12 @@ func toChannelIdentity(row sqlc.ChannelIdentity) ChannelIdentity { } userID := "" if row.UserID.Valid { - userID = db.UUIDToString(row.UserID) + userID = row.UserID.String() } return ChannelIdentity{ - ID: db.UUIDToString(row.ID), + ID: row.ID.String(), UserID: userID, - Channel: row.Channel, + Channel: row.ChannelType, ChannelSubjectID: row.ChannelSubjectID, DisplayName: displayName, Metadata: metadata, diff --git a/internal/channelidentities/service_identity_integration_test.go b/internal/channel/identities/service_identity_integration_test.go similarity index 86% rename from internal/channelidentities/service_identity_integration_test.go rename to internal/channel/identities/service_identity_integration_test.go index 827ebcb0..9cda339a 100644 --- a/internal/channelidentities/service_identity_integration_test.go +++ b/internal/channel/identities/service_identity_integration_test.go @@ -1,4 +1,4 @@ -package channelidentities_test +package identities_test import ( "context" @@ -8,13 +8,14 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" - "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/db/sqlc" ) -func setupChannelIdentityIdentityIntegrationTest(t *testing.T) (*channelidentities.Service, *sqlc.Queries, func()) { +func setupChannelIdentityIdentityIntegrationTest(t *testing.T) (*identities.Service, *sqlc.Queries, func()) { t.Helper() dsn := os.Getenv("TEST_POSTGRES_DSN") @@ -34,14 +35,10 @@ func setupChannelIdentityIdentityIntegrationTest(t *testing.T) (*channelidentiti queries := sqlc.New(pool) logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - svc := channelidentities.NewService(logger, queries) + svc := identities.NewService(logger, queries) return svc, queries, func() { pool.Close() } } -func formatUUID(bytes [16]byte) string { - return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) -} - func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) { svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t) defer cleanup() @@ -77,7 +74,7 @@ func TestChannelIdentityLinkToUser(t *testing.T) { if err != nil { t.Fatalf("create user failed: %v", err) } - userID := formatUUID(user.ID.Bytes) + userID := uuid.UUID(user.ID.Bytes).String() if err := svc.LinkChannelIdentityToUser(ctx, channelIdentity.ID, userID); err != nil { t.Fatalf("link channelIdentity to user failed: %v", err) diff --git a/internal/channelidentities/service_integration_test.go b/internal/channel/identities/service_integration_test.go similarity index 85% rename from internal/channelidentities/service_integration_test.go rename to internal/channel/identities/service_integration_test.go index ddc08e85..6f1833d7 100644 --- a/internal/channelidentities/service_integration_test.go +++ b/internal/channel/identities/service_integration_test.go @@ -1,7 +1,7 @@ //go:build ignore // +build ignore -package channelidentities_test +package identities_test import ( "context" @@ -11,13 +11,14 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" - "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/db/sqlc" ) -func setupIntegrationTest(t *testing.T) (*channelidentities.Service, *sqlc.Queries, func()) { +func setupIntegrationTest(t *testing.T) (*identities.Service, *sqlc.Queries, func()) { t.Helper() dsn := os.Getenv("TEST_POSTGRES_DSN") @@ -37,15 +38,11 @@ func setupIntegrationTest(t *testing.T) (*channelidentities.Service, *sqlc.Queri queries := sqlc.New(pool) logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - svc := channelidentities.NewService(logger, queries) + svc := identities.NewService(logger, queries) return svc, queries, func() { pool.Close() } } -func toUUIDString(v [16]byte) string { - return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", v[0:4], v[4:6], v[6:8], v[8:10], v[10:16]) -} - func TestIntegrationResolveByChannelIdentityStability(t *testing.T) { svc, _, cleanup := setupIntegrationTest(t) defer cleanup() @@ -84,7 +81,7 @@ func TestIntegrationLinkChannelIdentityToUser(t *testing.T) { if err != nil { t.Fatalf("create user failed: %v", err) } - userID := toUUIDString(user.ID.Bytes) + userID := uuid.UUID(user.ID.Bytes).String() if err := svc.LinkChannelIdentityToUser(ctx, channelIdentity.ID, userID); err != nil { t.Fatalf("link channelIdentity to user failed: %v", err) diff --git a/internal/channelidentities/service_test.go b/internal/channel/identities/service_test.go similarity index 96% rename from internal/channelidentities/service_test.go rename to internal/channel/identities/service_test.go index a5a73b9e..2bdb6bde 100644 --- a/internal/channelidentities/service_test.go +++ b/internal/channel/identities/service_test.go @@ -1,4 +1,4 @@ -package channelidentities +package identities import "testing" diff --git a/internal/channelidentities/types.go b/internal/channel/identities/types.go similarity index 95% rename from internal/channelidentities/types.go rename to internal/channel/identities/types.go index ecffe0e4..cfc36d30 100644 --- a/internal/channelidentities/types.go +++ b/internal/channel/identities/types.go @@ -1,4 +1,4 @@ -package channelidentities +package identities import "time" diff --git a/internal/channel/inbound_test.go b/internal/channel/inbound_test.go new file mode 100644 index 00000000..2135d5ba --- /dev/null +++ b/internal/channel/inbound_test.go @@ -0,0 +1,217 @@ +package channel + +import ( + "context" + "fmt" + "log/slog" + "testing" +) + +// mockAdapter is used for inbound handleInbound tests. +type mockAdapter struct { + sentMessages []OutboundMessage + streamEvents []StreamEvent +} + +func (m *mockAdapter) Type() ChannelType { return ChannelType("test") } +func (m *mockAdapter) Descriptor() Descriptor { + return Descriptor{ + Type: ChannelType("test"), + DisplayName: "Test", + Capabilities: ChannelCapabilities{ + Text: true, + Reply: true, + Streaming: true, + }, + } +} +func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { + m.sentMessages = append(m.sentMessages, msg) + return nil +} + +func (m *mockAdapter) OpenStream(ctx context.Context, cfg ChannelConfig, target string, opts StreamOptions) (OutboundStream, error) { + return &mockAdapterStream{adapter: m}, nil +} + +type mockAdapterStream struct { + adapter *mockAdapter +} + +func (s *mockAdapterStream) Push(ctx context.Context, event StreamEvent) error { + if s == nil || s.adapter == nil { + return nil + } + s.adapter.streamEvents = append(s.adapter.streamEvents, event) + if event.Type == StreamEventFinal && event.Final != nil && !event.Final.Message.IsEmpty() { + s.adapter.sentMessages = append(s.adapter.sentMessages, OutboundMessage{ + Target: "stream-target", + Message: event.Final.Message, + }) + } + return nil +} + +func (s *mockAdapterStream) Close(ctx context.Context) error { + return nil +} + +type fakeInboundProcessor struct { + resp *OutboundMessage + err error + gotCfg ChannelConfig + gotMsg InboundMessage +} + +func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { + f.gotCfg = cfg + f.gotMsg = msg + if f.err != nil { + return f.err + } + if f.resp == nil { + return nil + } + if sender == nil { + return fmt.Errorf("sender missing") + } + return sender.Send(ctx, *f.resp) +} + +type fakeInboundStreamProcessor struct{} + +func (f *fakeInboundStreamProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { + stream, err := sender.OpenStream(ctx, "stream-target", StreamOptions{}) + if err != nil { + return err + } + if err := stream.Push(ctx, StreamEvent{ + Type: StreamEventDelta, + Delta: "partial", + }); err != nil { + return err + } + if err := stream.Push(ctx, StreamEvent{ + Type: StreamEventFinal, + Final: &StreamFinalizePayload{ + Message: Message{Text: "stream-final"}, + }, + }); err != nil { + return err + } + return stream.Close(ctx) +} + +func TestManager_handleInbound(t *testing.T) { + logger := slog.Default() + + t.Run("with_reply_sends_successfully", func(t *testing.T) { + processor := &fakeInboundProcessor{ + resp: &OutboundMessage{ + Target: "target-id", + Message: Message{ + Text: "AI reply content", + }, + }, + } + + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + adapter := &mockAdapter{} + m.RegisterAdapter(adapter) + + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + msg := InboundMessage{ + Channel: ChannelType("test"), + Message: Message{Text: "hello"}, + ReplyTarget: "target-id", + Conversation: Conversation{ + ID: "chat-1", + Type: "p2p", + }, + } + + err := m.handleInbound(context.Background(), cfg, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(adapter.sentMessages) != 1 { + t.Fatalf("expected 1 reply sent, got %d", len(adapter.sentMessages)) + } + if adapter.sentMessages[0].Message.PlainText() != "AI reply content" { + t.Errorf("reply content mismatch: %s", adapter.sentMessages[0].Message.PlainText()) + } + if adapter.sentMessages[0].Target != "target-id" { + t.Errorf("reply target mismatch: %s", adapter.sentMessages[0].Target) + } + }) + + t.Run("no_reply_does_not_send", func(t *testing.T) { + processor := &fakeInboundProcessor{resp: nil} + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + adapter := &mockAdapter{} + m.RegisterAdapter(adapter) + + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + msg := InboundMessage{ + Channel: ChannelType("test"), + Message: Message{Text: "hello"}, + ReplyTarget: "target-id", + } + + err := m.handleInbound(context.Background(), cfg, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(adapter.sentMessages) != 0 { + t.Errorf("expected no reply sent, got %+v", adapter.sentMessages) + } + }) + + t.Run("handler_error_returns_error", func(t *testing.T) { + processor := &fakeInboundProcessor{err: context.Canceled} + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + cfg := ChannelConfig{ID: "bot-1"} + msg := InboundMessage{Message: Message{Text: " "}} // whitespace-only message + + err := m.handleInbound(context.Background(), cfg, msg) + if err == nil { + t.Errorf("expected handler to return error") + } + }) + + t.Run("stream sender forwards events", func(t *testing.T) { + processor := &fakeInboundStreamProcessor{} + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + adapter := &mockAdapter{} + m.RegisterAdapter(adapter) + + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + msg := InboundMessage{ + Channel: ChannelType("test"), + Message: Message{Text: "hello"}, + ReplyTarget: "stream-target", + Conversation: Conversation{ + ID: "chat-1", + Type: "p2p", + }, + } + if err := m.handleInbound(context.Background(), cfg, msg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(adapter.streamEvents) < 2 { + t.Fatalf("expected at least two stream events, got %d", len(adapter.streamEvents)) + } + if len(adapter.sentMessages) == 0 { + t.Fatal("expected stream final message to be published") + } + if adapter.sentMessages[len(adapter.sentMessages)-1].Message.PlainText() != "stream-final" { + t.Fatalf("unexpected stream final message: %s", adapter.sentMessages[len(adapter.sentMessages)-1].Message.PlainText()) + } + }) +} diff --git a/internal/channel/manager.go b/internal/channel/manager.go index 628e5340..94b0647d 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -61,6 +61,7 @@ type Manager struct { inboundCtx context.Context inboundCancel context.CancelFunc mu sync.Mutex + refreshMu sync.Mutex connections map[string]*connectionEntry } diff --git a/internal/channel/manager_integration_test.go b/internal/channel/manager_integration_test.go index fc5fe1bd..fc296014 100644 --- a/internal/channel/manager_integration_test.go +++ b/internal/channel/manager_integration_test.go @@ -54,7 +54,7 @@ type fakeInboundProcessorIntegration struct { gotMsg InboundMessage } -func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error { +func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { f.gotCfg = cfg f.gotMsg = msg if f.err != nil { diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go index a8d1482d..4110b6e3 100644 --- a/internal/channel/outbound.go +++ b/internal/channel/outbound.go @@ -302,6 +302,9 @@ func validateMessageCapabilities(registry *Registry, channelType ChannelType, ms if msg.Reply != nil && !caps.Reply { return fmt.Errorf("channel does not support reply") } + if strings.TrimSpace(msg.ID) != "" && !caps.Edit { + return fmt.Errorf("channel does not support edit") + } return nil } @@ -316,12 +319,40 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel if msg.Message.IsEmpty() { return fmt.Errorf("message is required") } - if err := validateMessageCapabilities(m.registry, cfg.ChannelType, msg.Message); err != nil { + normalized := msg + attachments, err := normalizeAttachmentRefs(msg.Message.Attachments, cfg.ChannelType) + if err != nil { return err } + normalized.Message.Attachments = attachments + if err := validateMessageCapabilities(m.registry, cfg.ChannelType, normalized.Message); err != nil { + return err + } + editor, _ := m.registry.GetMessageEditor(cfg.ChannelType) + if strings.TrimSpace(normalized.Message.ID) != "" { + if editor == nil { + return fmt.Errorf("channel does not support edit") + } + var lastErr error + for i := 0; i < policy.RetryMax; i++ { + err := editor.Update(ctx, cfg, target, strings.TrimSpace(normalized.Message.ID), normalized.Message) + if err == nil { + return nil + } + lastErr = err + if m.logger != nil { + m.logger.Warn("edit outbound retry", + slog.String("channel", cfg.ChannelType.String()), + slog.Int("attempt", i+1), + slog.Any("error", err)) + } + time.Sleep(time.Duration(i+1) * time.Duration(policy.RetryBackoffMs) * time.Millisecond) + } + return fmt.Errorf("edit outbound failed after retries: %w", lastErr) + } var lastErr error for i := 0; i < policy.RetryMax; i++ { - err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: msg.Message}) + err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: normalized.Message}) if err == nil { return nil } @@ -337,6 +368,27 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel return fmt.Errorf("send outbound failed after retries: %w", lastErr) } +func normalizeAttachmentRefs(attachments []Attachment, defaultPlatform ChannelType) ([]Attachment, error) { + if len(attachments) == 0 { + return nil, nil + } + normalized := make([]Attachment, 0, len(attachments)) + for _, att := range attachments { + item := att + item.URL = strings.TrimSpace(item.URL) + item.PlatformKey = strings.TrimSpace(item.PlatformKey) + item.SourcePlatform = strings.TrimSpace(item.SourcePlatform) + if item.SourcePlatform == "" && item.PlatformKey != "" { + item.SourcePlatform = defaultPlatform.String() + } + if item.URL == "" && item.PlatformKey == "" { + return nil, fmt.Errorf("attachment reference is required") + } + normalized = append(normalized, item) + } + return normalized, nil +} + func requiresMedia(attachments []Attachment) bool { for _, att := range attachments { switch att.Type { @@ -349,21 +401,55 @@ func requiresMedia(attachments []Attachment) bool { return false } -func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) ReplySender { +func validateStreamEvent(registry *Registry, channelType ChannelType, event StreamEvent) error { + caps, _ := registry.GetCapabilities(channelType) + switch event.Type { + case StreamEventStatus: + if event.Status == "" { + return fmt.Errorf("stream status is required") + } + case StreamEventDelta: + if !caps.Streaming && !caps.BlockStreaming { + return fmt.Errorf("channel does not support streaming") + } + case StreamEventFinal: + if event.Final == nil { + return fmt.Errorf("stream final payload is required") + } + if err := validateMessageCapabilities(registry, channelType, event.Final.Message); err != nil { + return err + } + if _, err := normalizeAttachmentRefs(event.Final.Message.Attachments, channelType); err != nil { + return err + } + case StreamEventError: + if strings.TrimSpace(event.Error) == "" { + return fmt.Errorf("stream error is required") + } + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } + return nil +} + +func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) StreamReplySender { sender, _ := m.registry.GetSender(channelType) + streamSender, _ := m.registry.GetStreamSender(channelType) return &managerReplySender{ - manager: m, - sender: sender, - channelType: channelType, - config: cfg, + manager: m, + sender: sender, + streamSender: streamSender, + channelType: channelType, + config: cfg, } } type managerReplySender struct { - manager *Manager - sender Sender - channelType ChannelType - config ChannelConfig + manager *Manager + sender Sender + streamSender StreamSender + channelType ChannelType + config ChannelConfig } func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) error { @@ -382,3 +468,52 @@ func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) erro } return nil } + +func (s *managerReplySender) OpenStream(ctx context.Context, target string, opts StreamOptions) (OutboundStream, error) { + if s.manager == nil { + return nil, fmt.Errorf("channel manager not configured") + } + if s.streamSender == nil { + return nil, fmt.Errorf("channel stream sender not configured") + } + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("target is required") + } + caps, _ := s.manager.registry.GetCapabilities(s.channelType) + if !caps.Streaming && !caps.BlockStreaming { + return nil, fmt.Errorf("channel does not support streaming") + } + stream, err := s.streamSender.OpenStream(ctx, s.config, target, opts) + if err != nil { + return nil, err + } + return &managerOutboundStream{ + manager: s.manager, + stream: stream, + channelType: s.channelType, + }, nil +} + +type managerOutboundStream struct { + manager *Manager + stream OutboundStream + channelType ChannelType +} + +func (s *managerOutboundStream) Push(ctx context.Context, event StreamEvent) error { + if s.manager == nil || s.stream == nil { + return fmt.Errorf("stream is not configured") + } + if err := validateStreamEvent(s.manager.registry, s.channelType, event); err != nil { + return err + } + return s.stream.Push(ctx, event) +} + +func (s *managerOutboundStream) Close(ctx context.Context) error { + if s.stream == nil { + return fmt.Errorf("stream is not configured") + } + return s.stream.Close(ctx) +} diff --git a/internal/channel/processor.go b/internal/channel/processor.go index 6f4e79f2..bb363013 100644 --- a/internal/channel/processor.go +++ b/internal/channel/processor.go @@ -4,5 +4,5 @@ import "context" // InboundProcessor handles inbound messages and replies through the given sender. type InboundProcessor interface { - HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error + HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error } diff --git a/internal/channel/registry.go b/internal/channel/registry.go index b9a629ce..9a672d59 100644 --- a/internal/channel/registry.go +++ b/internal/channel/registry.go @@ -71,6 +71,16 @@ func (r *Registry) Get(channelType ChannelType) (Adapter, bool) { return adapter, ok } +// DirectoryAdapter returns the directory adapter for the given channel type if it implements ChannelDirectoryAdapter. +func (r *Registry) DirectoryAdapter(channelType ChannelType) (ChannelDirectoryAdapter, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + dir, ok := adapter.(ChannelDirectoryAdapter) + return dir, ok +} + // List returns all registered adapters. func (r *Registry) List() []Adapter { r.mu.RLock() @@ -185,6 +195,26 @@ func (r *Registry) GetSender(channelType ChannelType) (Sender, bool) { return sender, ok } +// GetStreamSender returns the StreamSender for the given channel type, or nil if unsupported. +func (r *Registry) GetStreamSender(channelType ChannelType) (StreamSender, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + streamSender, ok := adapter.(StreamSender) + return streamSender, ok +} + +// GetMessageEditor returns the MessageEditor for the given channel type, or nil if unsupported. +func (r *Registry) GetMessageEditor(channelType ChannelType) (MessageEditor, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + editor, ok := adapter.(MessageEditor) + return editor, ok +} + // GetReceiver returns the Receiver for the given channel type, or nil if unsupported. func (r *Registry) GetReceiver(channelType ChannelType) (Receiver, bool) { adapter, ok := r.Get(channelType) @@ -195,6 +225,16 @@ func (r *Registry) GetReceiver(channelType ChannelType) (Receiver, bool) { return receiver, ok } +// GetProcessingStatusNotifier returns the ProcessingStatusNotifier for the given channel type, or nil if unsupported. +func (r *Registry) GetProcessingStatusNotifier(channelType ChannelType) (ProcessingStatusNotifier, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + notifier, ok := adapter.(ProcessingStatusNotifier) + return notifier, ok +} + // --- Dispatch methods (replace former global functions in config.go / target.go) --- // NormalizeConfig validates and normalizes a channel configuration map. diff --git a/internal/channel/registry_test.go b/internal/channel/registry_test.go new file mode 100644 index 00000000..c27c3875 --- /dev/null +++ b/internal/channel/registry_test.go @@ -0,0 +1,63 @@ +package channel_test + +import ( + "context" + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +const dirTestChannelType = channel.ChannelType("dir-test") + +// dirMockAdapter implements Adapter and ChannelDirectoryAdapter for registry DirectoryAdapter tests. +type dirMockAdapter struct{} + +func (a *dirMockAdapter) Type() channel.ChannelType { return dirTestChannelType } + +func (a *dirMockAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{Type: dirTestChannelType, DisplayName: "DirTest"} +} + +func (a *dirMockAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (a *dirMockAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (a *dirMockAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (a *dirMockAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + return channel.DirectoryEntry{}, nil +} + +func TestDirectoryAdapter_Unsupported(t *testing.T) { + t.Parallel() + reg := newTestConfigRegistry() + dir, ok := reg.DirectoryAdapter(testChannelType) + if ok || dir != nil { + t.Fatalf("DirectoryAdapter(test) = (%v, %v), want (nil, false)", dir, ok) + } +} + +func TestDirectoryAdapter_Supported(t *testing.T) { + t.Parallel() + reg := channel.NewRegistry() + reg.MustRegister(&dirMockAdapter{}) + dir, ok := reg.DirectoryAdapter(dirTestChannelType) + if !ok || dir == nil { + t.Fatalf("DirectoryAdapter(dir-test) = (%v, %v), want (non-nil, true)", dir, ok) + } +} + +func TestDirectoryAdapter_UnknownType(t *testing.T) { + t.Parallel() + reg := channel.NewRegistry() + dir, ok := reg.DirectoryAdapter(channel.ChannelType("unknown")) + if ok || dir != nil { + t.Fatalf("DirectoryAdapter(unknown) = (%v, %v), want (nil, false)", dir, ok) + } +} diff --git a/internal/channel/route/service.go b/internal/channel/route/service.go new file mode 100644 index 00000000..c50aeafe --- /dev/null +++ b/internal/channel/route/service.go @@ -0,0 +1,362 @@ +package route + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/conversation" + dbpkg "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +// ConversationService contains the minimal conversation behavior required by route resolution. +type ConversationService interface { + Create(ctx context.Context, botID, channelIdentityID string, req conversation.CreateRequest) (conversation.Chat, error) + IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) + AddParticipant(ctx context.Context, conversationID, channelIdentityID, role string) (conversation.Participant, error) +} + +// DBService manages channel routes and route-to-conversation resolution. +type DBService struct { + queries *sqlc.Queries + conversation ConversationService + logger *slog.Logger +} + +// NewService creates a channel route service. +func NewService(log *slog.Logger, queries *sqlc.Queries, conversationService ConversationService) *DBService { + if log == nil { + log = slog.Default() + } + return &DBService{ + queries: queries, + conversation: conversationService, + logger: log.With(slog.String("service", "channel/route")), + } +} + +// Create creates a route. +func (s *DBService) Create(ctx context.Context, input CreateInput) (Route, error) { + pgConversationID, err := dbpkg.ParseUUID(input.ChatID) + if err != nil { + return Route{}, err + } + pgBotID, err := dbpkg.ParseUUID(input.BotID) + if err != nil { + return Route{}, err + } + var pgConfigID pgtype.UUID + if strings.TrimSpace(input.ChannelConfigID) != "" { + pgConfigID, err = dbpkg.ParseUUID(input.ChannelConfigID) + if err != nil { + return Route{}, err + } + } + metadata, err := json.Marshal(nonNilMap(input.Metadata)) + if err != nil { + return Route{}, fmt.Errorf("marshal route metadata: %w", err) + } + + row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{ + ChatID: pgConversationID, + BotID: pgBotID, + Platform: input.Platform, + ChannelConfigID: pgConfigID, + ConversationID: input.ConversationID, + ThreadID: toPgText(input.ThreadID), + ReplyTarget: toPgText(input.ReplyTarget), + Metadata: metadata, + }) + if err != nil { + return Route{}, fmt.Errorf("create route: %w", err) + } + + return toRouteFromCreate(row), nil +} + +// Find finds a route by bot/platform/external-conversation/thread. +func (s *DBService) Find(ctx context.Context, botID, platform, conversationID, threadID string) (Route, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return Route{}, err + } + row, err := s.queries.FindChatRoute(ctx, sqlc.FindChatRouteParams{ + BotID: pgBotID, + Platform: platform, + ConversationID: conversationID, + ThreadID: toPgText(threadID), + }) + if err != nil { + return Route{}, err + } + return toRouteFromFind(row), nil +} + +// GetByID gets a route by ID. +func (s *DBService) GetByID(ctx context.Context, routeID string) (Route, error) { + pgID, err := dbpkg.ParseUUID(routeID) + if err != nil { + return Route{}, err + } + row, err := s.queries.GetChatRouteByID(ctx, pgID) + if err != nil { + return Route{}, err + } + return toRouteFromGet(row), nil +} + +// List lists all routes for a conversation. +func (s *DBService) List(ctx context.Context, conversationID string) ([]Route, error) { + pgID, err := dbpkg.ParseUUID(conversationID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatRoutes(ctx, pgID) + if err != nil { + return nil, err + } + routes := make([]Route, 0, len(rows)) + for _, row := range rows { + routes = append(routes, toRouteFromList(row)) + } + return routes, nil +} + +// Delete deletes a route by ID. +func (s *DBService) Delete(ctx context.Context, routeID string) error { + pgID, err := dbpkg.ParseUUID(routeID) + if err != nil { + return err + } + return s.queries.DeleteChatRoute(ctx, pgID) +} + +// UpdateReplyTarget updates default reply target. +func (s *DBService) UpdateReplyTarget(ctx context.Context, routeID, replyTarget string) error { + pgID, err := dbpkg.ParseUUID(routeID) + if err != nil { + return err + } + return s.queries.UpdateChatRouteReplyTarget(ctx, sqlc.UpdateChatRouteReplyTargetParams{ + ID: pgID, + ReplyTarget: toPgText(replyTarget), + }) +} + +// ResolveConversation finds or creates a conversation route for an inbound message. +func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput) (ResolveConversationResult, error) { + route, err := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, input.ThreadID) + if err == nil { + if strings.TrimSpace(input.ChannelIdentityID) != "" && s.conversation != nil { + ok, checkErr := s.conversation.IsParticipant(ctx, route.ChatID, input.ChannelIdentityID) + if checkErr != nil { + return ResolveConversationResult{}, fmt.Errorf("check conversation participant: %w", checkErr) + } + if !ok { + if _, addErr := s.conversation.AddParticipant(ctx, route.ChatID, input.ChannelIdentityID, conversation.RoleMember); addErr != nil && s.logger != nil { + s.logger.Warn("auto-add participant failed", slog.Any("error", addErr)) + } + } + } + if strings.TrimSpace(input.ReplyTarget) != "" && input.ReplyTarget != route.ReplyTarget { + if updateErr := s.UpdateReplyTarget(ctx, route.ID, input.ReplyTarget); updateErr != nil && s.logger != nil { + s.logger.Warn("update route reply target failed", slog.Any("error", updateErr)) + } + } + pgConversationID, parseErr := dbpkg.ParseUUID(route.ChatID) + if parseErr != nil { + return ResolveConversationResult{}, fmt.Errorf("parse route conversation id: %w", parseErr) + } + if touchErr := s.queries.TouchChat(ctx, pgConversationID); touchErr != nil && s.logger != nil { + s.logger.Warn("touch conversation failed", slog.Any("error", touchErr)) + } + return ResolveConversationResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil + } + + if s.conversation == nil { + return ResolveConversationResult{}, fmt.Errorf("conversation service not configured") + } + + kind := determineConversationKind(input.ThreadID, input.ConversationType) + creatorChannelIdentityID := s.resolveConversationCreatorChannelIdentityID(ctx, input.BotID, input.ChannelIdentityID, kind) + + var parentConversationID string + if kind == conversation.KindThread { + parentRoute, parentErr := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, "") + if parentErr == nil { + parentConversationID = parentRoute.ChatID + } + } + + createdConversation, err := s.conversation.Create(ctx, input.BotID, creatorChannelIdentityID, conversation.CreateRequest{ + Kind: kind, + ParentChatID: parentConversationID, + }) + if err != nil { + return ResolveConversationResult{}, fmt.Errorf("create conversation: %w", err) + } + + if strings.TrimSpace(input.ChannelIdentityID) != "" && strings.TrimSpace(input.ChannelIdentityID) != strings.TrimSpace(creatorChannelIdentityID) { + if _, addErr := s.conversation.AddParticipant(ctx, createdConversation.ID, input.ChannelIdentityID, conversation.RoleMember); addErr != nil && s.logger != nil { + s.logger.Warn("auto-add creator participant failed", slog.Any("error", addErr)) + } + } + + newRoute, err := s.Create(ctx, CreateInput{ + ChatID: createdConversation.ID, + BotID: input.BotID, + Platform: input.Platform, + ChannelConfigID: input.ChannelConfigID, + ConversationID: input.ConversationID, + ThreadID: input.ThreadID, + ReplyTarget: input.ReplyTarget, + }) + if err != nil { + return ResolveConversationResult{}, fmt.Errorf("create route: %w", err) + } + + return ResolveConversationResult{ChatID: createdConversation.ID, RouteID: newRoute.ID, Created: true}, nil +} + +func determineConversationKind(threadID, conversationType string) string { + if strings.TrimSpace(threadID) != "" { + return conversation.KindThread + } + ct := strings.ToLower(strings.TrimSpace(conversationType)) + if ct == "p2p" || ct == "private" || ct == "" { + return conversation.KindDirect + } + return conversation.KindGroup +} + +func (s *DBService) resolveConversationCreatorChannelIdentityID(ctx context.Context, botID, fallbackChannelIdentityID, kind string) string { + fallback := strings.TrimSpace(fallbackChannelIdentityID) + if kind != conversation.KindGroup || s.queries == nil { + return fallback + } + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return fallback + } + row, err := s.queries.GetBotByID(ctx, pgBotID) + if err != nil { + if s.logger != nil { + s.logger.Warn("resolve bot owner for group conversation failed", slog.Any("error", err)) + } + return fallback + } + ownerChannelIdentityID := row.OwnerUserID.String() + if strings.TrimSpace(ownerChannelIdentityID) == "" { + return fallback + } + return ownerChannelIdentityID +} + +func toRouteFromCreate(row sqlc.CreateChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromFind(row sqlc.FindChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromGet(row sqlc.GetChatRouteByIDRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromList(row sqlc.ListChatRoutesRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFields(id, conversationID, botID pgtype.UUID, platform string, channelConfigID pgtype.UUID, externalConversationID string, threadID, replyTarget pgtype.Text, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Route { + return Route{ + ID: id.String(), + ChatID: conversationID.String(), + BotID: botID.String(), + Platform: platform, + ChannelConfigID: channelConfigID.String(), + ConversationID: externalConversationID, + ThreadID: dbpkg.TextToString(threadID), + ReplyTarget: dbpkg.TextToString(replyTarget), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, + } +} + +func toPgText(value string) pgtype.Text { + value = strings.TrimSpace(value) + if value == "" { + return pgtype.Text{} + } + return pgtype.Text{String: value, Valid: true} +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} diff --git a/internal/channel/route/types.go b/internal/channel/route/types.go new file mode 100644 index 00000000..a7c1e44c --- /dev/null +++ b/internal/channel/route/types.go @@ -0,0 +1,68 @@ +package route + +import ( + "context" + "time" +) + +// Route maps external channel conversations to an internal conversation. +type Route struct { + ID string `json:"id"` + ChatID string `json:"chat_id"` + BotID string `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID string `json:"channel_config_id,omitempty"` + ConversationID string `json:"conversation_id"` + ThreadID string `json:"thread_id,omitempty"` + ReplyTarget string `json:"reply_target,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ResolveConversationResult is returned by ResolveConversation. +type ResolveConversationResult struct { + ChatID string + RouteID string + Created bool +} + +// CreateInput is the input for creating a route. +type CreateInput struct { + ChatID string + BotID string + Platform string + ChannelConfigID string + ConversationID string + ThreadID string + ReplyTarget string + Metadata map[string]any +} + +// ResolveInput is the input for route-to-conversation resolution. +type ResolveInput struct { + BotID string + Platform string + ConversationID string + ThreadID string + ConversationType string + ChannelIdentityID string + ChannelConfigID string + ReplyTarget string +} + +// Resolver defines the route resolution behavior used by inbound routing. +type Resolver interface { + ResolveConversation(ctx context.Context, input ResolveInput) (ResolveConversationResult, error) +} + +// Service defines route management behavior. +type Service interface { + Resolver + Create(ctx context.Context, input CreateInput) (Route, error) + Find(ctx context.Context, botID, platform, conversationID, threadID string) (Route, error) + GetByID(ctx context.Context, routeID string) (Route, error) + List(ctx context.Context, chatID string) ([]Route, error) + Delete(ctx context.Context, routeID string) error + UpdateReplyTarget(ctx context.Context, routeID, replyTarget string) error +} diff --git a/internal/channel/service.go b/internal/channel/service.go index c5c4f796..5414af65 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -132,9 +132,9 @@ func (s *Service) UpsertChannelIdentityConfig(ctx context.Context, channelIdenti return ChannelIdentityBinding{}, err } row, err := s.queries.UpsertUserChannelBinding(ctx, sqlc.UpsertUserChannelBindingParams{ - UserID: pgChannelIdentityID, - Platform: channelType.String(), - Config: payload, + UserID: pgChannelIdentityID, + ChannelType: channelType.String(), + Config: payload, }) if err != nil { return ChannelIdentityBinding{}, err @@ -211,8 +211,8 @@ func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityI return ChannelIdentityBinding{}, err } row, err := s.queries.GetUserChannelBinding(ctx, sqlc.GetUserChannelBindingParams{ - UserID: pgChannelIdentityID, - Platform: channelType.String(), + UserID: pgChannelIdentityID, + ChannelType: channelType.String(), }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -225,9 +225,9 @@ func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityI return ChannelIdentityBinding{}, err } return ChannelIdentityBinding{ - ID: db.UUIDToString(row.ID), - ChannelType: ChannelType(row.Platform), - ChannelIdentityID: db.UUIDToString(row.UserID), + ID: row.ID.String(), + ChannelType: ChannelType(row.ChannelType), + ChannelIdentityID: row.UserID.String(), Config: config, CreatedAt: db.TimeFromPg(row.CreatedAt), UpdatedAt: db.TimeFromPg(row.UpdatedAt), @@ -293,8 +293,8 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { externalIdentity = strings.TrimSpace(row.ExternalIdentity.String) } return ChannelConfig{ - ID: db.UUIDToString(row.ID), - BotID: db.UUIDToString(row.BotID), + ID: row.ID.String(), + BotID: row.BotID.String(), ChannelType: ChannelType(row.ChannelType), Credentials: credentials, ExternalIdentity: externalIdentity, @@ -313,9 +313,9 @@ func normalizeChannelIdentityBinding(row sqlc.UserChannelBinding) (ChannelIdenti return ChannelIdentityBinding{}, err } return ChannelIdentityBinding{ - ID: db.UUIDToString(row.ID), - ChannelType: ChannelType(row.Platform), - ChannelIdentityID: db.UUIDToString(row.UserID), + ID: row.ID.String(), + ChannelType: ChannelType(row.ChannelType), + ChannelIdentityID: row.UserID.String(), Config: config, CreatedAt: db.TimeFromPg(row.CreatedAt), UpdatedAt: db.TimeFromPg(row.UpdatedAt), diff --git a/internal/channel/types.go b/internal/channel/types.go index 8c6ef1b0..2a4d4d46 100644 --- a/internal/channel/types.go +++ b/internal/channel/types.go @@ -45,7 +45,7 @@ type InboundMessage struct { Message Message BotID string ReplyTarget string - SessionKey string + RouteKey string Sender Identity Conversation Conversation ReceivedAt time.Time @@ -53,22 +53,22 @@ type InboundMessage struct { Metadata map[string]any } -// SessionID returns a stable identifier for the conversation session. +// RoutingKey returns a stable identifier used for reply routing. // Format: platform:bot_id:conversation_id[:sender_id]. -func (m InboundMessage) SessionID() string { - if strings.TrimSpace(m.SessionKey) != "" { - return strings.TrimSpace(m.SessionKey) +func (m InboundMessage) RoutingKey() string { + if strings.TrimSpace(m.RouteKey) != "" { + return strings.TrimSpace(m.RouteKey) } senderID := strings.TrimSpace(m.Sender.SubjectID) if senderID == "" { senderID = strings.TrimSpace(m.Sender.DisplayName) } - return GenerateSessionID(string(m.Channel), m.BotID, m.Conversation.ID, m.Conversation.Type, senderID) + return GenerateRoutingKey(string(m.Channel), m.BotID, m.Conversation.ID, m.Conversation.Type, senderID) } -// GenerateSessionID builds a session identifier from platform, bot, conversation, and sender info. +// GenerateRoutingKey builds a route key from platform, bot, conversation, and sender info. // For group chats, the sender ID is appended to provide per-user context. -func GenerateSessionID(platform, botID, conversationID, conversationType, senderID string) string { +func GenerateRoutingKey(platform, botID, conversationID, conversationType, senderID string) string { parts := []string{platform, botID, conversationID} ct := strings.ToLower(strings.TrimSpace(conversationType)) if ct != "" && ct != "p2p" && ct != "private" { @@ -86,6 +86,47 @@ type OutboundMessage struct { Message Message `json:"message"` } +// StreamEventType defines the kind of outbound stream event. +type StreamEventType string + +const ( + StreamEventStatus StreamEventType = "status" + StreamEventDelta StreamEventType = "delta" + StreamEventFinal StreamEventType = "final" + StreamEventError StreamEventType = "error" +) + +// StreamStatus indicates the lifecycle state of a streaming reply. +type StreamStatus string + +const ( + StreamStatusStarted StreamStatus = "started" + StreamStatusCompleted StreamStatus = "completed" + StreamStatusFailed StreamStatus = "failed" +) + +// StreamFinalizePayload carries the final reply message emitted by a stream. +type StreamFinalizePayload struct { + Message Message `json:"message"` +} + +// StreamEvent represents a unified stream event routed through the channel layer. +type StreamEvent struct { + Type StreamEventType `json:"type"` + Status StreamStatus `json:"status,omitempty"` + Delta string `json:"delta,omitempty"` + Final *StreamFinalizePayload `json:"final,omitempty"` + Error string `json:"error,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// StreamOptions configures how an outbound stream is initialized. +type StreamOptions struct { + Reply *ReplyRef `json:"reply,omitempty"` + SourceMessageID string `json:"source_message_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + // MessageFormat indicates how the message text should be rendered. type MessageFormat string @@ -142,17 +183,33 @@ const ( // Attachment represents a binary file attached to a message. type Attachment struct { - Type AttachmentType `json:"type"` - URL string `json:"url,omitempty"` - Name string `json:"name,omitempty"` - Size int64 `json:"size,omitempty"` - Mime string `json:"mime,omitempty"` - DurationMs int64 `json:"duration_ms,omitempty"` - Width int `json:"width,omitempty"` - Height int `json:"height,omitempty"` - ThumbnailURL string `json:"thumbnail_url,omitempty"` - Caption string `json:"caption,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Type AttachmentType `json:"type"` + URL string `json:"url,omitempty"` + PlatformKey string `json:"platform_key,omitempty"` + SourcePlatform string `json:"source_platform,omitempty"` + Name string `json:"name,omitempty"` + Size int64 `json:"size,omitempty"` + Mime string `json:"mime,omitempty"` + DurationMs int64 `json:"duration_ms,omitempty"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + ThumbnailURL string `json:"thumbnail_url,omitempty"` + Caption string `json:"caption,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// Reference returns the strongest available attachment reference. +// URL is preferred for cross-platform portability, then platform key. +func (a Attachment) Reference() string { + if strings.TrimSpace(a.URL) != "" { + return strings.TrimSpace(a.URL) + } + return strings.TrimSpace(a.PlatformKey) +} + +// HasReference reports whether URL or platform key is available. +func (a Attachment) HasReference() bool { + return a.Reference() != "" } // Action describes an interactive button or link in a message. diff --git a/internal/chat/assistant_output.go b/internal/chat/assistant_output.go index ae00f91f..235b121d 100644 --- a/internal/chat/assistant_output.go +++ b/internal/chat/assistant_output.go @@ -1,4 +1,4 @@ -package chat +package conversation import "strings" diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 5929d599..47f1d84e 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "bufio" @@ -15,6 +15,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" @@ -28,6 +29,7 @@ const ( memoryContextLimitPerScope = 4 memoryContextMaxItems = 8 memoryContextItemMaxChars = 220 + sharedMemoryNamespace = "bot" ) // SkillEntry represents a skill loaded from the container. @@ -110,7 +112,6 @@ type gatewayModelConfig struct { type gatewayIdentity struct { BotID string `json:"botId"` - SessionID string `json:"sessionId"` ContainerID string `json:"containerId"` ChannelIdentityID string `json:"channelIdentityId"` DisplayName string `json:"displayName"` @@ -203,7 +204,6 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex if err != nil { return resolvedContext{}, err } - r.enforceGroupMemoryPolicy(ctx, req.ChatID, &chatSettings) } userSettings, err := r.loadUserSettings(ctx, req.UserID) @@ -222,12 +222,12 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex var messages []ModelMessage if !skipHistory && r.chatService != nil { - messages, err = r.loadChatMessages(ctx, req.ChatID, maxCtx) + messages, err = r.loadMessages(ctx, req.ChatID, maxCtx) if err != nil { return resolvedContext{}, err } } - if memoryMsg := r.loadMemoryContextMessage(ctx, req, chatSettings); memoryMsg != nil { + if memoryMsg := r.loadMemoryContextMessage(ctx, req); memoryMsg != nil { messages = append(messages, *memoryMsg) } messages = append(messages, req.Messages...) @@ -274,7 +274,6 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex Query: req.Query, Identity: gatewayIdentity{ BotID: req.BotID, - SessionID: req.ChatID, ContainerID: containerID, ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID), DisplayName: firstNonEmpty(req.DisplayName, "User"), @@ -349,7 +348,6 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc UsableSkills: rc.payload.UsableSkills, Identity: gatewayIdentity{ BotID: rc.payload.Identity.BotID, - SessionID: rc.payload.Identity.SessionID, ContainerID: rc.payload.Identity.ContainerID, ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID), DisplayName: "Scheduler", @@ -387,20 +385,31 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre defer close(chunkCh) defer close(errCh) - rc, err := r.resolve(ctx, req) + streamReq := req + rc, err := r.resolve(ctx, streamReq) if err != nil { r.logger.Error("gateway stream resolve failed", - slog.String("bot_id", req.BotID), - slog.String("chat_id", req.ChatID), + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), slog.Any("error", err), ) errCh <- err return } - if err := r.streamChat(ctx, rc.payload, req, chunkCh); err != nil { + if err := r.persistUserMessage(ctx, streamReq); err != nil { + r.logger.Error("gateway stream persist user message failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + streamReq.UserMessagePersisted = true + if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil { r.logger.Error("gateway stream request failed", - slog.String("bot_id", req.BotID), - slog.String("chat_id", req.ChatID), + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), slog.Any("error", err), ) errCh <- err @@ -614,7 +623,7 @@ func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit strin // --- message loading --- -func (r *Resolver) loadChatMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { +func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) msgs, err := r.chatService.ListMessagesSince(ctx, chatID, since) if err != nil { @@ -639,65 +648,46 @@ type memoryContextItem struct { Item memory.MemoryItem } -func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest, settings Settings) *ModelMessage { +func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage { if r.memoryService == nil { return nil } if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { return nil } - type memoryScope struct { - Namespace string - ScopeID string - } - var scopes []memoryScope - if settings.EnableChatMemory { - scopes = append(scopes, memoryScope{Namespace: "chat", ScopeID: req.ChatID}) - } - if settings.EnablePrivateMemory && strings.TrimSpace(req.UserID) != "" { - scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.UserID}) - } - if settings.EnablePublicMemory { - scopes = append(scopes, memoryScope{Namespace: "public", ScopeID: req.BotID}) - } - if len(scopes) == 0 { + + results := make([]memoryContextItem, 0, memoryContextLimitPerScope) + seen := map[string]struct{}{} + resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ + Query: req.Query, + BotID: req.BotID, + Limit: memoryContextLimitPerScope, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": req.BotID, + "botId": req.BotID, + }, + }) + if err != nil { + r.logger.Warn("memory search for context failed", + slog.String("namespace", sharedMemoryNamespace), + slog.Any("error", err), + ) return nil } - - results := make([]memoryContextItem, 0, len(scopes)*memoryContextLimitPerScope) - seen := map[string]struct{}{} - for _, scope := range scopes { - resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ - Query: req.Query, - BotID: req.BotID, - Limit: memoryContextLimitPerScope, - Filters: map[string]any{ - "namespace": scope.Namespace, - "scopeId": scope.ScopeID, - "botId": req.BotID, - }, - }) - if err != nil { - r.logger.Warn("memory search for context failed", - slog.String("namespace", scope.Namespace), - slog.Any("error", err), - ) + for _, item := range resp.Results { + key := strings.TrimSpace(item.ID) + if key == "" { + key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory) + } + if key == "" { continue } - for _, item := range resp.Results { - key := strings.TrimSpace(item.ID) - if key == "" { - key = scope.Namespace + ":" + strings.TrimSpace(item.Memory) - } - if key == "" { - continue - } - if _, ok := seen[key]; ok { - continue - } - seen[key] = struct{}{} - results = append(results, memoryContextItem{Namespace: scope.Namespace, Item: item}) + if _, ok := seen[key]; ok { + continue } + seen[key] = struct{}{} + results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item}) } if len(results) == 0 { return nil @@ -736,6 +726,43 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest // --- store helpers --- +func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error { + if r.chatService == nil { + return nil + } + if strings.TrimSpace(req.BotID) == "" { + return fmt.Errorf("bot id is required for persistence") + } + text := strings.TrimSpace(req.Query) + if text == "" { + return nil + } + + message := ModelMessage{ + Role: "user", + Content: NewTextContent(text), + } + content, err := json.Marshal(message) + if err != nil { + return err + } + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + _, err = r.chatService.PersistMessage( + ctx, + req.BotID, + req.RouteID, + senderChannelIdentityID, + senderUserID, + req.CurrentChannel, + req.ExternalMessageID, + "", + "user", + content, + buildRouteMetadata(req), + ) + return err +} + func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error { // Add user query as the first message if not already present in the round. // This ensures the user's prompt is persisted alongside the assistant's response. @@ -747,16 +774,25 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M break } } - if !hasUserQuery && strings.TrimSpace(req.Query) != "" { + if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" { fullRound = append(fullRound, ModelMessage{ Role: "user", Content: NewTextContent(req.Query), }) } - fullRound = append(fullRound, messages...) + for _, m := range messages { + if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) { + // User message was already persisted before streaming; skip duplicate copy in round payload. + continue + } + fullRound = append(fullRound, m) + } + if len(fullRound) == 0 { + return nil + } r.storeMessages(ctx, req, fullRound) - r.storeMemory(ctx, req.BotID, req.ChatID, req.UserID, req.Query, fullRound) + r.storeMemory(ctx, req.BotID, fullRound) return nil } @@ -764,42 +800,37 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages if r.chatService == nil { return } - if strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + if strings.TrimSpace(req.BotID) == "" { return } - // Build route-level metadata for traceability. - var meta map[string]any - if strings.TrimSpace(req.RouteID) != "" || strings.TrimSpace(req.CurrentChannel) != "" { - meta = map[string]any{} - if strings.TrimSpace(req.RouteID) != "" { - meta["route_id"] = req.RouteID - } - if strings.TrimSpace(req.CurrentChannel) != "" { - meta["platform"] = req.CurrentChannel - } - } + meta := buildRouteMetadata(req) + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) for _, msg := range messages { content, err := json.Marshal(msg) if err != nil { continue } - senderChannelIdentityID := "" - senderUserID := "" + messageSenderChannelIdentityID := "" + messageSenderUserID := "" externalMessageID := "" + sourceReplyToMessageID := "" if msg.Role == "user" { - senderChannelIdentityID = req.SourceChannelIdentityID - senderUserID = req.UserID + messageSenderChannelIdentityID = senderChannelIdentityID + messageSenderUserID = senderUserID externalMessageID = req.ExternalMessageID + } else if strings.TrimSpace(req.ExternalMessageID) != "" { + // Assistant/tool/system outputs are linked to the inbound source message for cross-channel reply threading. + sourceReplyToMessageID = req.ExternalMessageID } if _, err := r.chatService.PersistMessage( ctx, - req.ChatID, req.BotID, req.RouteID, - senderChannelIdentityID, - senderUserID, + messageSenderChannelIdentityID, + messageSenderUserID, req.CurrentChannel, externalMessageID, + sourceReplyToMessageID, msg.Role, content, meta, @@ -809,11 +840,93 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages } } -func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, userID, query string, messages []ModelMessage) { +func buildRouteMetadata(req ChatRequest) map[string]any { + if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" { + return nil + } + meta := map[string]any{} + if strings.TrimSpace(req.RouteID) != "" { + meta["route_id"] = req.RouteID + } + if strings.TrimSpace(req.CurrentChannel) != "" { + meta["platform"] = req.CurrentChannel + } + return meta +} + +func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) { + channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID) + userID := strings.TrimSpace(req.UserID) + + channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID) + userAsUserValid := r.isExistingUserID(ctx, userID) + userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID) + + senderChannelIdentityID := "" + switch { + case channelIdentityValid: + senderChannelIdentityID = channelIdentityID + case userAsChannelIdentityValid && !userAsUserValid: + // Some flows may carry channel_identity_id in req.UserID. + senderChannelIdentityID = userID + } + + senderUserID := "" + if userAsUserValid { + senderUserID = userID + } + if senderUserID == "" && senderChannelIdentityID != "" { + if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" { + senderUserID = linked + } + } + return senderChannelIdentityID, senderUserID +} + +func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetChannelIdentityByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetUserByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string { + if r.queries == nil { + return "" + } + pgID, err := parseUUID(channelIdentityID) + if err != nil { + return "" + } + row, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil || !row.UserID.Valid { + return "" + } + return row.UserID.String() +} + +func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) { if r.memoryService == nil { return } - if strings.TrimSpace(botID) == "" || strings.TrimSpace(chatID) == "" { + if strings.TrimSpace(botID) == "" { return } memMsgs := make([]memory.Message, 0, len(messages)) @@ -831,33 +944,7 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, userID, query if len(memMsgs) == 0 { return } - - // Load chat settings to determine which namespaces to write to. - var cs Settings - if r.chatService != nil { - settings, err := r.chatService.GetSettings(ctx, chatID) - if err != nil { - r.logger.Warn("load chat settings for memory write failed", slog.Any("error", err)) - } else { - cs = settings - r.enforceGroupMemoryPolicy(ctx, chatID, &cs) - } - } - - // Always write to chat namespace if enabled (default true). - if cs.EnableChatMemory { - r.addMemory(ctx, botID, memMsgs, "chat", chatID) - } - - // Write to private namespace if enabled and user id is known. - if cs.EnablePrivateMemory && strings.TrimSpace(userID) != "" { - r.addMemory(ctx, botID, memMsgs, "private", userID) - } - - // Write to public namespace if enabled. - if cs.EnablePublicMemory { - r.addMemory(ctx, botID, memMsgs, "public", botID) - } + r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID) } func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) { @@ -1081,26 +1168,8 @@ func truncateMemorySnippet(s string, n int) string { } func parseUUID(id string) (pgtype.UUID, error) { - trimmed := strings.TrimSpace(id) - if trimmed == "" { + if strings.TrimSpace(id) == "" { return pgtype.UUID{}, fmt.Errorf("empty id") } - var pgID pgtype.UUID - if err := pgID.Scan(trimmed); err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - return pgID, nil -} - -func (r *Resolver) enforceGroupMemoryPolicy(ctx context.Context, chatID string, settings *Settings) { - if r == nil || r.chatService == nil || settings == nil { - return - } - chatObj, err := r.chatService.Get(ctx, chatID) - if err != nil { - return - } - if strings.EqualFold(strings.TrimSpace(chatObj.Kind), KindGroup) { - settings.EnablePrivateMemory = false - } + return db.ParseUUID(id) } diff --git a/internal/chat/resolver_memory_context_test.go b/internal/chat/resolver_memory_context_test.go index 082bad6d..48c9eb39 100644 --- a/internal/chat/resolver_memory_context_test.go +++ b/internal/chat/resolver_memory_context_test.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "context" @@ -18,8 +18,6 @@ func TestLoadMemoryContextMessage_NoMemoryService(t *testing.T) { Query: "hello", BotID: "bot-1", ChatID: "chat-1", - }, Settings{ - EnableChatMemory: true, }) if msg != nil { t.Fatalf("expected nil message when memory service is nil") @@ -36,10 +34,6 @@ func TestLoadMemoryContextMessage_SearchFailureFallback(t *testing.T) { BotID: "bot-1", ChatID: "chat-1", UserID: "user-1", - }, Settings{ - EnableChatMemory: true, - EnablePrivateMemory: true, - EnablePublicMemory: true, }) if msg != nil { t.Fatalf("expected nil message when memory search cannot return results") diff --git a/internal/chat/resolver_test.go b/internal/chat/resolver_test.go index 74d8a329..e866ba00 100644 --- a/internal/chat/resolver_test.go +++ b/internal/chat/resolver_test.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "context" @@ -48,7 +48,6 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) { Skills: []string{}, Identity: gatewayIdentity{ BotID: "bot-123", - SessionID: "schedule:sched-1", ContainerID: "mcp-bot-123", ChannelIdentityID: "owner-user-1", DisplayName: "Scheduler", diff --git a/internal/chat/schedule_gateway.go b/internal/chat/schedule_gateway.go index d1578065..8e543f78 100644 --- a/internal/chat/schedule_gateway.go +++ b/internal/chat/schedule_gateway.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "context" diff --git a/internal/chat/service.go b/internal/chat/service.go index 863db8c8..1ce4a6e5 100644 --- a/internal/chat/service.go +++ b/internal/chat/service.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "context" @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -41,7 +42,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { // --- Chat CRUD --- // Create creates a new chat and adds the creator as owner. -func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateChatRequest) (Chat, error) { +func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Chat, error) { kind := strings.TrimSpace(req.Kind) if kind == "" { kind = KindDirect @@ -79,7 +80,7 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r BotID: pgBotID, Kind: kind, ParentChatID: pgParent, - Title: toPgText(req.Title), + Title: strings.TrimSpace(req.Title), CreatedByUserID: pgChannelIdentityID, Metadata: metadata, }) @@ -98,29 +99,17 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r } } - // Create default settings based on kind. - enablePrivate := kind != KindGroup - if _, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ - ID: row.ID, - EnableChatMemory: true, - EnablePrivateMemory: enablePrivate, - EnablePublicMemory: false, - SettingsMetadata: []byte("{}"), - }); err != nil { - return Chat{}, fmt.Errorf("create default settings: %w", err) - } - // For threads, copy participants from parent. if kind == KindThread && pgParent.Valid { if err := s.queries.CopyParticipantsToChat(ctx, sqlc.CopyParticipantsToChatParams{ - ChatID: pgParent, - ChatID_2: row.ID, + ChatID: pgParent, + ChatID2: row.ID, }); err != nil { s.logger.Warn("copy parent participants failed", slog.Any("error", err)) } } - return toChat(row), nil + return toChatFromCreate(row), nil } // Get returns a chat by ID. @@ -136,7 +125,7 @@ func (s *Service) Get(ctx context.Context, chatID string) (Chat, error) { } return Chat{}, err } - return toChat(row), nil + return toChatFromGet(row), nil } // GetReadAccess resolves whether a user can read a chat. @@ -202,7 +191,7 @@ func (s *Service) ListThreads(ctx context.Context, parentChatID string) ([]Chat, } chats := make([]Chat, 0, len(rows)) for _, row := range rows { - chats = append(chats, toChat(row)) + chats = append(chats, toChatFromThread(row)) } return chats, nil } @@ -239,7 +228,7 @@ func (s *Service) AddParticipant(ctx context.Context, chatID, channelIdentityID, if err != nil { return Participant{}, err } - return toParticipant(row), nil + return toParticipantFromAdd(row), nil } // GetParticipant returns a participant record. @@ -262,7 +251,7 @@ func (s *Service) GetParticipant(ctx context.Context, chatID, channelIdentityID } return Participant{}, err } - return toParticipant(row), nil + return toParticipantFromGet(row), nil } // IsParticipant checks whether a user identity is a participant in a chat. @@ -286,7 +275,7 @@ func (s *Service) ListParticipants(ctx context.Context, chatID string) ([]Partic } participants := make([]Participant, 0, len(rows)) for _, row := range rows { - participants = append(participants, toParticipant(row)) + participants = append(participants, toParticipantFromList(row)) } return participants, nil } @@ -312,27 +301,17 @@ func (s *Service) RemoveParticipant(ctx context.Context, chatID, channelIdentity // GetSettings returns settings for a chat. Returns defaults if not found. func (s *Service) GetSettings(ctx context.Context, chatID string) (Settings, error) { pgID, err := parseUUID(chatID) - var current Settings if err != nil { - current = defaultSettings(chatID) - return current, nil + return defaultSettings(chatID), nil } row, err := s.queries.GetChatSettings(ctx, pgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - current = defaultSettings(chatID) - if s.isGroupChat(ctx, chatID) { - current.EnablePrivateMemory = false - } - return current, nil + return defaultSettings(chatID), nil } return Settings{}, err } - current = toSettingsFromRead(row) - if s.isGroupChat(ctx, chatID) { - current.EnablePrivateMemory = false - } - return current, nil + return toSettingsFromRead(row), nil } // UpdateSettings updates chat settings. @@ -341,35 +320,17 @@ func (s *Service) UpdateSettings(ctx context.Context, chatID string, req UpdateS if err != nil { return Settings{}, err } - isGroup := s.isGroupChat(ctx, chatID) - if req.EnableChatMemory != nil { - current.EnableChatMemory = *req.EnableChatMemory - } - if req.EnablePrivateMemory != nil { - current.EnablePrivateMemory = *req.EnablePrivateMemory - } - if req.EnablePublicMemory != nil { - current.EnablePublicMemory = *req.EnablePublicMemory - } if req.ModelID != nil { current.ModelID = *req.ModelID } - if isGroup { - // Group chats are shared contexts, so private memory stays disabled. - current.EnablePrivateMemory = false - } pgID, err := parseUUID(chatID) if err != nil { return Settings{}, err } row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ - ID: pgID, - EnableChatMemory: current.EnableChatMemory, - EnablePrivateMemory: current.EnablePrivateMemory, - EnablePublicMemory: current.EnablePublicMemory, - ModelID: toPgText(current.ModelID), - SettingsMetadata: []byte("{}"), + ID: pgID, + ModelID: toPgText(current.ModelID), }) if err != nil { return Settings{}, err @@ -413,7 +374,7 @@ func (s *Service) CreateRoute(ctx context.Context, chatID string, r Route) (Rout if err != nil { return Route{}, fmt.Errorf("create route: %w", err) } - return toRoute(row), nil + return toRouteFromCreate(row), nil } // FindRoute looks up a route by (bot_id, platform, conversation_id, thread_id). @@ -431,7 +392,7 @@ func (s *Service) FindRoute(ctx context.Context, botID, platform, conversationID if err != nil { return Route{}, err } - return toRoute(row), nil + return toRouteFromFind(row), nil } // GetRouteByID returns a single route by its ID. @@ -444,7 +405,7 @@ func (s *Service) GetRouteByID(ctx context.Context, routeID string) (Route, erro if err != nil { return Route{}, err } - return toRoute(row), nil + return toRouteFromGet(row), nil } // ListRoutes lists all routes for a chat. @@ -459,7 +420,7 @@ func (s *Service) ListRoutes(ctx context.Context, chatID string) ([]Route, error } routes := make([]Route, 0, len(rows)) for _, row := range rows { - routes = append(routes, toRoute(row)) + routes = append(routes, toRouteFromList(row)) } return routes, nil } @@ -532,7 +493,7 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation } } - c, err := s.Create(ctx, botID, creatorChannelIdentityID, CreateChatRequest{ + c, err := s.Create(ctx, botID, creatorChannelIdentityID, CreateRequest{ Kind: kind, ParentChatID: parentChatID, }) @@ -562,12 +523,8 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation // --- Messages --- -// PersistMessage writes a single message to chat_messages. -func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, role string, content json.RawMessage, metadata map[string]any) (Message, error) { - pgChatID, err := parseUUID(chatID) - if err != nil { - return Message{}, err - } +// PersistMessage writes a single message to bot_history_messages. +func (s *Service) PersistMessage(ctx context.Context, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, sourceReplyToMessageID, role string, content json.RawMessage, metadata map[string]any) (Message, error) { pgBotID, err := parseUUID(botID) if err != nil { return Message{}, err @@ -601,14 +558,14 @@ func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, se content = []byte("{}") } - row, err := s.queries.CreateChatMessage(ctx, sqlc.CreateChatMessageParams{ - ChatID: pgChatID, + row, err := s.queries.CreateMessage(ctx, sqlc.CreateMessageParams{ BotID: pgBotID, RouteID: pgRouteID, SenderChannelIdentityID: pgSender, SenderUserID: pgSenderUser, Platform: toPgText(platform), ExternalMessageID: toPgText(externalMessageID), + SourceReplyToMessageID: toPgText(sourceReplyToMessageID), Role: role, Content: content, Metadata: metaBytes, @@ -616,96 +573,129 @@ func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, se if err != nil { return Message{}, err } - if pgSender.Valid { - if err := s.queries.UpsertChatChannelIdentityPresence(ctx, sqlc.UpsertChatChannelIdentityPresenceParams{ - ChatID: pgChatID, - ChannelIdentityID: pgSender, - }); err != nil && s.logger != nil { - // Presence is a derived cache. Keep message persistence successful even if cache update fails. - s.logger.Warn("upsert chat channel identity presence failed", slog.Any("error", err)) - } - } - return toMessage(row), nil + return toMessageFromCreate(row), nil } -// ListMessages returns all messages for a chat. -func (s *Service) ListMessages(ctx context.Context, chatID string) ([]Message, error) { - pgID, err := parseUUID(chatID) +// ListMessages returns all messages for a bot. +func (s *Service) ListMessages(ctx context.Context, botID string) ([]Message, error) { + pgID, err := parseUUID(botID) if err != nil { return nil, err } - rows, err := s.queries.ListChatMessages(ctx, pgID) + rows, err := s.queries.ListMessages(ctx, pgID) if err != nil { return nil, err } - return toMessages(rows), nil + return toMessagesFromList(rows), nil } -// ListMessagesSince returns messages since a given time. -func (s *Service) ListMessagesSince(ctx context.Context, chatID string, since time.Time) ([]Message, error) { - pgID, err := parseUUID(chatID) +// ListMessagesSince returns bot messages since a given time. +func (s *Service) ListMessagesSince(ctx context.Context, botID string, since time.Time) ([]Message, error) { + pgID, err := parseUUID(botID) if err != nil { return nil, err } - rows, err := s.queries.ListChatMessagesSince(ctx, sqlc.ListChatMessagesSinceParams{ - ChatID: pgID, + rows, err := s.queries.ListMessagesSince(ctx, sqlc.ListMessagesSinceParams{ + BotID: pgID, CreatedAt: pgtype.Timestamptz{Time: since, Valid: true}, }) if err != nil { return nil, err } - return toMessages(rows), nil + return toMessagesFromSince(rows), nil } -// ListMessagesLatest returns the latest N messages (most recent first). -func (s *Service) ListMessagesLatest(ctx context.Context, chatID string, limit int32) ([]Message, error) { - pgID, err := parseUUID(chatID) +// ListMessagesLatest returns the latest N bot messages (most recent first). +func (s *Service) ListMessagesLatest(ctx context.Context, botID string, limit int32) ([]Message, error) { + pgID, err := parseUUID(botID) if err != nil { return nil, err } - rows, err := s.queries.ListChatMessagesLatest(ctx, sqlc.ListChatMessagesLatestParams{ - ChatID: pgID, - Limit: limit, + rows, err := s.queries.ListMessagesLatest(ctx, sqlc.ListMessagesLatestParams{ + BotID: pgID, + MaxCount: limit, }) if err != nil { return nil, err } - return toMessages(rows), nil + return toMessagesFromLatest(rows), nil } -// DeleteMessages deletes all messages for a chat. -func (s *Service) DeleteMessages(ctx context.Context, chatID string) error { - pgID, err := parseUUID(chatID) +// DeleteMessages deletes all messages for a bot. +func (s *Service) DeleteMessages(ctx context.Context, botID string) error { + pgID, err := parseUUID(botID) if err != nil { return err } - return s.queries.DeleteChatMessagesByChat(ctx, pgID) + return s.queries.DeleteMessagesByBot(ctx, pgID) } // --- conversion helpers --- -func toChat(row sqlc.Chat) Chat { +func toChatFromCreate(row sqlc.CreateChatRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromGet(row sqlc.GetChatByIDRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Chat { return Chat{ - ID: uuidString(row.ID), - BotID: uuidString(row.BotID), - Kind: row.Kind, - ParentChatID: uuidString(row.ParentChatID), - Title: pgTextString(row.Title), - CreatedBy: uuidString(row.CreatedByUserID), - Metadata: parseJSONMap(row.Metadata), - CreatedAt: row.CreatedAt.Time, - UpdatedAt: row.UpdatedAt.Time, + ID: id.String(), + BotID: botID.String(), + Kind: kind, + ParentChatID: parentChatID.String(), + Title: db.TextToString(title), + CreatedBy: createdBy.String(), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, } } func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem { return ChatListItem{ - ID: uuidString(row.ID), - BotID: uuidString(row.BotID), + ID: row.ID.String(), + BotID: row.BotID.String(), Kind: row.Kind, - ParentChatID: uuidString(row.ParentChatID), - Title: pgTextString(row.Title), - CreatedBy: uuidString(row.CreatedByUserID), + ParentChatID: row.ParentChatID.String(), + Title: db.TextToString(row.Title), + CreatedBy: row.CreatedByUserID.String(), Metadata: parseJSONMap(row.Metadata), CreatedAt: row.CreatedAt.Time, UpdatedAt: row.UpdatedAt.Time, @@ -715,84 +705,233 @@ func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem { } } -func toParticipant(row sqlc.ChatParticipant) Participant { +func toParticipantFromAdd(row sqlc.AddChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromGet(row sqlc.GetChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromList(row sqlc.ListChatParticipantsRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFields(chatID, userID pgtype.UUID, role string, joinedAt pgtype.Timestamptz) Participant { return Participant{ - ChatID: uuidString(row.ChatID), - UserID: uuidString(row.UserID), - Role: row.Role, - JoinedAt: row.JoinedAt.Time, + ChatID: chatID.String(), + UserID: userID.String(), + Role: role, + JoinedAt: joinedAt.Time, } } func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings { return Settings{ - ChatID: uuidString(row.ChatID), - EnableChatMemory: row.EnableChatMemory, - EnablePrivateMemory: row.EnablePrivateMemory, - EnablePublicMemory: row.EnablePublicMemory, - ModelID: pgTextString(row.ModelID), - Metadata: parseJSONMap(row.Metadata), + ChatID: row.ChatID.String(), + ModelID: db.TextToString(row.ModelID), } } func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings { return Settings{ - ChatID: uuidString(row.ChatID), - EnableChatMemory: row.EnableChatMemory, - EnablePrivateMemory: row.EnablePrivateMemory, - EnablePublicMemory: row.EnablePublicMemory, - ModelID: pgTextString(row.ModelID), - Metadata: parseJSONMap(row.Metadata), + ChatID: row.ChatID.String(), + ModelID: db.TextToString(row.ModelID), } } -func toRoute(row sqlc.ChatRoute) Route { +func toRouteFromCreate(row sqlc.CreateChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromFind(row sqlc.FindChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromGet(row sqlc.GetChatRouteByIDRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromList(row sqlc.ListChatRoutesRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFields(id, chatID, botID pgtype.UUID, platform string, channelConfigID pgtype.UUID, conversationID string, threadID, replyTarget pgtype.Text, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Route { return Route{ - ID: uuidString(row.ID), - ChatID: uuidString(row.ChatID), - BotID: uuidString(row.BotID), - Platform: row.Platform, - ChannelConfigID: uuidString(row.ChannelConfigID), - ConversationID: row.ConversationID, - ThreadID: pgTextString(row.ThreadID), - ReplyTarget: pgTextString(row.ReplyTarget), - Metadata: parseJSONMap(row.Metadata), - CreatedAt: row.CreatedAt.Time, - UpdatedAt: row.UpdatedAt.Time, + ID: id.String(), + ChatID: chatID.String(), + BotID: botID.String(), + Platform: platform, + ChannelConfigID: channelConfigID.String(), + ConversationID: conversationID, + ThreadID: db.TextToString(threadID), + ReplyTarget: db.TextToString(replyTarget), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, } } -func toMessage(row sqlc.ChatMessage) Message { +func toMessageFromCreate(row sqlc.CreateMessageRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromListRow(row sqlc.ListMessagesRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromSinceRow(row sqlc.ListMessagesSinceRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromLatestRow(row sqlc.ListMessagesLatestRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFields(id, botID, routeID, senderChannelIdentityID, senderUserID pgtype.UUID, platform, externalMessageID, sourceReplyToMessageID pgtype.Text, role string, content, metadata []byte, createdAt pgtype.Timestamptz) Message { return Message{ - ID: uuidString(row.ID), - ChatID: uuidString(row.ChatID), - BotID: uuidString(row.BotID), - RouteID: uuidString(row.RouteID), - SenderChannelIdentityID: uuidString(row.SenderChannelIdentityID), - SenderUserID: uuidString(row.SenderUserID), - Platform: pgTextString(row.Platform), - ExternalMessageID: pgTextString(row.ExternalMessageID), - Role: row.Role, - Content: json.RawMessage(row.Content), - Metadata: parseJSONMap(row.Metadata), - CreatedAt: row.CreatedAt.Time, + ID: id.String(), + BotID: botID.String(), + RouteID: routeID.String(), + SenderChannelIdentityID: senderChannelIdentityID.String(), + SenderUserID: senderUserID.String(), + Platform: db.TextToString(platform), + ExternalMessageID: db.TextToString(externalMessageID), + SourceReplyToMessageID: db.TextToString(sourceReplyToMessageID), + Role: role, + Content: json.RawMessage(content), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, } } -func toMessages(rows []sqlc.ChatMessage) []Message { +func toMessagesFromList(rows []sqlc.ListMessagesRow) []Message { msgs := make([]Message, 0, len(rows)) for _, row := range rows { - msgs = append(msgs, toMessage(row)) + msgs = append(msgs, toMessageFromListRow(row)) + } + return msgs +} + +func toMessagesFromSince(rows []sqlc.ListMessagesSinceRow) []Message { + msgs := make([]Message, 0, len(rows)) + for _, row := range rows { + msgs = append(msgs, toMessageFromSinceRow(row)) + } + return msgs +} + +func toMessagesFromLatest(rows []sqlc.ListMessagesLatestRow) []Message { + msgs := make([]Message, 0, len(rows)) + for _, row := range rows { + msgs = append(msgs, toMessageFromLatestRow(row)) } return msgs } func defaultSettings(chatID string) Settings { return Settings{ - ChatID: chatID, - EnableChatMemory: true, - EnablePrivateMemory: true, - EnablePublicMemory: false, + ChatID: chatID, } } @@ -807,22 +946,6 @@ func determineChatKind(threadID, conversationType string) string { return KindGroup } -func uuidString(id pgtype.UUID) string { - if !id.Valid { - return "" - } - b := id.Bytes - return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", - b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) -} - -func pgTextString(t pgtype.Text) string { - if !t.Valid { - return "" - } - return t.String -} - func toPgText(s string) pgtype.Text { s = strings.TrimSpace(s) if s == "" { @@ -869,17 +992,9 @@ func (s *Service) resolveChatCreatorChannelIdentityID(ctx context.Context, botID s.logger.Warn("resolve bot owner for group chat failed", slog.Any("error", err)) return fallback } - ownerChannelIdentityID := uuidString(row.OwnerUserID) + ownerChannelIdentityID := row.OwnerUserID.String() if strings.TrimSpace(ownerChannelIdentityID) == "" { return fallback } return ownerChannelIdentityID } - -func (s *Service) isGroupChat(ctx context.Context, chatID string) bool { - chatObj, err := s.Get(ctx, chatID) - if err != nil { - return false - } - return strings.EqualFold(strings.TrimSpace(chatObj.Kind), KindGroup) -} diff --git a/internal/chat/service_presence_integration_test.go b/internal/chat/service_presence_integration_test.go index 69ced70b..9b07851b 100644 --- a/internal/chat/service_presence_integration_test.go +++ b/internal/chat/service_presence_integration_test.go @@ -1,4 +1,4 @@ -package chat_test +package conversation_test import ( "context" @@ -13,15 +13,15 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" - "github.com/memohai/memoh/internal/channelidentities" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/channel/identities" + conversation "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) type chatPresenceFixture struct { - chatSvc *chat.Service - channelIdentitySvc *channelidentities.Service + chatSvc *conversation.Service + channelIdentitySvc *identities.Service queries *sqlc.Queries cleanup func() } @@ -48,8 +48,8 @@ func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture { queries := sqlc.New(pool) return chatPresenceFixture{ - chatSvc: chat.NewService(logger, queries), - channelIdentitySvc: channelidentities.NewService(logger, queries), + chatSvc: conversation.NewService(logger, queries), + channelIdentitySvc: identities.NewService(logger, queries), queries: queries, cleanup: func() { pool.Close() }, } @@ -63,7 +63,7 @@ func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (stri if err != nil { return "", err } - return db.UUIDToString(row.ID), nil + return row.ID.String(), nil } func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { @@ -85,7 +85,7 @@ func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerU if err != nil { return "", err } - return db.UUIDToString(row.ID), nil + return row.ID.String(), nil } func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, string, string, string) { @@ -110,8 +110,8 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin t.Fatalf("create bot failed: %v", err) } - createdChat, err := fixture.chatSvc.Create(ctx, botID, ownerUserID, chat.CreateChatRequest{ - Kind: chat.KindGroup, + createdChat, err := fixture.chatSvc.Create(ctx, botID, ownerUserID, conversation.CreateRequest{ + Kind: conversation.KindGroup, Title: "presence-observed", }) if err != nil { @@ -132,13 +132,13 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin _, err = fixture.chatSvc.PersistMessage( ctx, - createdChat.ID, botID, "", observedChannelIdentity.ID, "", "feishu", fmt.Sprintf("ext-msg-%d", time.Now().UnixNano()), + "", "user", []byte(`{"content":"hello from observed channelIdentity"}`), nil, @@ -176,7 +176,7 @@ func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) { t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind)) } - var target *chat.ChatListItem + var target *conversation.ChatListItem for i := range afterBind { if afterBind[i].ID == chatID { target = &afterBind[i] @@ -186,8 +186,8 @@ func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) { if target == nil { t.Fatalf("expected chat %s in visible list after bind", chatID) } - if target.AccessMode != chat.AccessModeChannelIdentityObserved { - t.Fatalf("expected access_mode=%s, got %s", chat.AccessModeChannelIdentityObserved, target.AccessMode) + if target.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected access_mode=%s, got %s", conversation.AccessModeChannelIdentityObserved, target.AccessMode) } if target.ParticipantRole != "" { t.Fatalf("expected empty participant_role for observed chat, got %s", target.ParticipantRole) @@ -210,8 +210,8 @@ func TestObservedAccessReadableButNotParticipant(t *testing.T) { if err != nil { t.Fatalf("get read access failed: %v", err) } - if access.AccessMode != chat.AccessModeChannelIdentityObserved { - t.Fatalf("expected read access %s, got %s", chat.AccessModeChannelIdentityObserved, access.AccessMode) + if access.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected read access %s, got %s", conversation.AccessModeChannelIdentityObserved, access.AccessMode) } messages, err := fixture.chatSvc.ListMessages(ctx, chatID) @@ -223,7 +223,7 @@ func TestObservedAccessReadableButNotParticipant(t *testing.T) { } _, err = fixture.chatSvc.GetParticipant(ctx, chatID, observerUserID) - if !errors.Is(err, chat.ErrNotParticipant) { + if !errors.Is(err, conversation.ErrNotParticipant) { t.Fatalf("expected ErrNotParticipant for observed user, got %v", err) } ok, err := fixture.chatSvc.IsParticipant(ctx, chatID, observerUserID) @@ -238,7 +238,7 @@ func TestObservedAccessReadableButNotParticipant(t *testing.T) { if err != nil { t.Fatalf("list visible chats failed: %v", err) } - if len(visibleChats) == 0 || visibleChats[0].AccessMode != chat.AccessModeChannelIdentityObserved { + if len(visibleChats) == 0 || visibleChats[0].AccessMode != conversation.AccessModeChannelIdentityObserved { t.Fatal("expected observed list entry with channel_identity_observed access mode") } } diff --git a/internal/chat/types.go b/internal/chat/types.go index 4cd9d66d..2e203280 100644 --- a/internal/chat/types.go +++ b/internal/chat/types.go @@ -1,6 +1,6 @@ -// Package chat orchestrates conversations with the agent gateway, including -// synchronous and streaming chat, scheduled triggers, messages, and memory storage. -package chat +// Package conversation orchestrates interactions with the agent gateway, including +// synchronous and streaming responses, scheduled triggers, messages, and memory storage. +package conversation import ( "encoding/json" @@ -74,12 +74,8 @@ type Participant struct { // Settings holds per-chat configuration. type Settings struct { - ChatID string `json:"chat_id"` - EnableChatMemory bool `json:"enable_chat_memory"` - EnablePrivateMemory bool `json:"enable_private_memory"` - EnablePublicMemory bool `json:"enable_public_memory"` - ModelID string `json:"model_id,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + ChatID string `json:"chat_id"` + ModelID string `json:"model_id,omitempty"` } // Route maps external channel conversations to a chat. @@ -97,24 +93,24 @@ type Route struct { UpdatedAt time.Time `json:"updated_at"` } -// Message represents a single persisted chat message. +// Message represents a single persisted bot message. type Message struct { ID string `json:"id"` - ChatID string `json:"chat_id"` BotID string `json:"bot_id"` RouteID string `json:"route_id,omitempty"` SenderChannelIdentityID string `json:"sender_channel_identity_id,omitempty"` SenderUserID string `json:"sender_user_id,omitempty"` Platform string `json:"platform,omitempty"` ExternalMessageID string `json:"external_message_id,omitempty"` + SourceReplyToMessageID string `json:"source_reply_to_message_id,omitempty"` Role string `json:"role"` Content json.RawMessage `json:"content"` Metadata map[string]any `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` } -// CreateChatRequest is the input for creating a chat. -type CreateChatRequest struct { +// CreateRequest is the input for creating a bot-scoped conversation container. +type CreateRequest struct { Kind string `json:"kind"` Title string `json:"title,omitempty"` ParentChatID string `json:"parent_chat_id,omitempty"` @@ -123,10 +119,7 @@ type CreateChatRequest struct { // UpdateSettingsRequest is the input for updating chat settings. type UpdateSettingsRequest struct { - EnableChatMemory *bool `json:"enable_chat_memory,omitempty"` - EnablePrivateMemory *bool `json:"enable_private_memory,omitempty"` - EnablePublicMemory *bool `json:"enable_public_memory,omitempty"` - ModelID *string `json:"model_id,omitempty"` + ModelID *string `json:"model_id,omitempty"` } // ResolveChatResult is returned by ResolveChat. @@ -234,16 +227,17 @@ type ToolCallFunction struct { // ChatRequest is the input for Chat and StreamChat. type ChatRequest struct { - BotID string `json:"-"` - ChatID string `json:"-"` - Token string `json:"-"` - UserID string `json:"-"` - SourceChannelIdentityID string `json:"-"` - ContainerID string `json:"-"` - DisplayName string `json:"-"` - RouteID string `json:"-"` - ChatToken string `json:"-"` - ExternalMessageID string `json:"-"` + BotID string `json:"-"` + ChatID string `json:"-"` + Token string `json:"-"` + UserID string `json:"-"` + SourceChannelIdentityID string `json:"-"` + ContainerID string `json:"-"` + DisplayName string `json:"-"` + RouteID string `json:"-"` + ChatToken string `json:"-"` + ExternalMessageID string `json:"-"` + UserMessagePersisted bool `json:"-"` Query string `json:"query"` Model string `json:"model,omitempty"` diff --git a/internal/conversation/flow/assistant_output.go b/internal/conversation/flow/assistant_output.go new file mode 100644 index 00000000..b752fcdf --- /dev/null +++ b/internal/conversation/flow/assistant_output.go @@ -0,0 +1,36 @@ +package flow + +import "strings" + +// ExtractAssistantOutputs collects assistant-role outputs from a slice of ModelMessages. +func ExtractAssistantOutputs(messages []ModelMessage) []AssistantOutput { + if len(messages) == 0 { + return nil + } + outputs := make([]AssistantOutput, 0, len(messages)) + for _, msg := range messages { + if msg.Role != "assistant" { + continue + } + content := strings.TrimSpace(msg.TextContent()) + parts := filterContentParts(msg.ContentParts()) + if content == "" && len(parts) == 0 { + continue + } + outputs = append(outputs, AssistantOutput{Content: content, Parts: parts}) + } + return outputs +} + +func filterContentParts(parts []ContentPart) []ContentPart { + if len(parts) == 0 { + return nil + } + filtered := make([]ContentPart, 0, len(parts)) + for _, p := range parts { + if p.HasValue() { + filtered = append(filtered, p) + } + } + return filtered +} diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go new file mode 100644 index 00000000..9ed3225d --- /dev/null +++ b/internal/conversation/flow/resolver.go @@ -0,0 +1,1226 @@ +package flow + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "sort" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/memory" + messagepkg "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/schedule" + "github.com/memohai/memoh/internal/settings" +) + +const ( + defaultMaxContextMinutes = 24 * 60 + memoryContextLimitPerScope = 4 + memoryContextMaxItems = 8 + memoryContextItemMaxChars = 220 + sharedMemoryNamespace = "bot" +) + +// SkillEntry represents a skill loaded from the container. +type SkillEntry struct { + Name string + Description string + Content string + Metadata map[string]any +} + +// SkillLoader loads skills for a given bot from its container. +type SkillLoader interface { + LoadSkills(ctx context.Context, botID string) ([]SkillEntry, error) +} + +// ConversationSettingsReader defines settings lookup behavior needed by flow resolution. +type ConversationSettingsReader interface { + GetSettings(ctx context.Context, conversationID string) (conversation.Settings, error) +} + +// Resolver orchestrates chat with the agent gateway. +type Resolver struct { + modelsService *models.Service + queries *sqlc.Queries + memoryService *memory.Service + conversationSvc ConversationSettingsReader + messageService messagepkg.Service + settingsService *settings.Service + mcpService *mcp.ConnectionService + skillLoader SkillLoader + gatewayBaseURL string + timeout time.Duration + logger *slog.Logger + httpClient *http.Client + streamingClient *http.Client +} + +// NewResolver creates a Resolver that communicates with the agent gateway. +func NewResolver( + log *slog.Logger, + modelsService *models.Service, + queries *sqlc.Queries, + memoryService *memory.Service, + conversationSvc ConversationSettingsReader, + messageService messagepkg.Service, + settingsService *settings.Service, + mcpService *mcp.ConnectionService, + gatewayBaseURL string, + timeout time.Duration, +) *Resolver { + if strings.TrimSpace(gatewayBaseURL) == "" { + gatewayBaseURL = "http://127.0.0.1:8081" + } + gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/") + if timeout <= 0 { + timeout = 60 * time.Second + } + return &Resolver{ + modelsService: modelsService, + queries: queries, + memoryService: memoryService, + conversationSvc: conversationSvc, + messageService: messageService, + settingsService: settingsService, + mcpService: mcpService, + gatewayBaseURL: gatewayBaseURL, + timeout: timeout, + logger: log.With(slog.String("service", "conversation_resolver")), + httpClient: &http.Client{Timeout: timeout}, + streamingClient: &http.Client{}, + } +} + +// SetSkillLoader sets the skill loader used to populate usable skills in gateway requests. +func (r *Resolver) SetSkillLoader(sl SkillLoader) { + r.skillLoader = sl +} + +// --- gateway payload --- + +type gatewayModelConfig struct { + ModelID string `json:"modelId"` + ClientType string `json:"clientType"` + Input []string `json:"input"` + APIKey string `json:"apiKey"` + BaseURL string `json:"baseUrl"` +} + +type gatewayIdentity struct { + BotID string `json:"botId"` + ContainerID string `json:"containerId"` + ChannelIdentityID string `json:"channelIdentityId"` + DisplayName string `json:"displayName"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + ReplyTarget string `json:"replyTarget,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` +} + +type gatewaySkill struct { + Name string `json:"name"` + Description string `json:"description"` + Content string `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type gatewayRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Query string `json:"query"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` +} + +type gatewayResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` +} + +// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule. +type gatewaySchedule struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Pattern string `json:"pattern"` + MaxCalls *int `json:"maxCalls,omitempty"` + Command string `json:"command"` +} + +// triggerScheduleRequest is the payload for POST /chat/trigger-schedule. +type triggerScheduleRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` + Schedule gatewaySchedule `json:"schedule"` +} + +// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) --- + +type resolvedContext struct { + payload gatewayRequest + model models.GetResponse + provider sqlc.LlmProvider +} + +func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContext, error) { + if strings.TrimSpace(req.Query) == "" { + return resolvedContext{}, fmt.Errorf("query is required") + } + if strings.TrimSpace(req.BotID) == "" { + return resolvedContext{}, fmt.Errorf("bot id is required") + } + if strings.TrimSpace(req.ChatID) == "" { + return resolvedContext{}, fmt.Errorf("chat id is required") + } + + skipHistory := req.MaxContextLoadTime < 0 + + botSettings, err := r.loadBotSettings(ctx, req.BotID) + if err != nil { + return resolvedContext{}, err + } + + // Check chat-level model override. + var chatSettings Settings + if r.conversationSvc != nil { + chatSettings, err = r.conversationSvc.GetSettings(ctx, req.ChatID) + if err != nil { + return resolvedContext{}, err + } + } + + userSettings, err := r.loadUserSettings(ctx, req.UserID) + if err != nil { + return resolvedContext{}, err + } + chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings) + if err != nil { + return resolvedContext{}, err + } + clientType, err := normalizeClientType(provider.ClientType) + if err != nil { + return resolvedContext{}, err + } + maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes) + + var messages []ModelMessage + if !skipHistory && r.conversationSvc != nil { + messages, err = r.loadMessages(ctx, req.ChatID, maxCtx) + if err != nil { + return resolvedContext{}, err + } + } + if memoryMsg := r.loadMemoryContextMessage(ctx, req); memoryMsg != nil { + messages = append(messages, *memoryMsg) + } + messages = append(messages, req.Messages...) + messages = sanitizeMessages(messages) + skills := dedup(req.Skills) + containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) + + var usableSkills []gatewaySkill + if r.skillLoader != nil { + entries, err := r.skillLoader.LoadSkills(ctx, req.BotID) + if err != nil { + r.logger.Warn("failed to load usable skills", slog.String("bot_id", req.BotID), slog.Any("error", err)) + } else { + usableSkills = make([]gatewaySkill, 0, len(entries)) + for _, e := range entries { + usableSkills = append(usableSkills, gatewaySkill{ + Name: e.Name, + Description: e.Description, + Content: e.Content, + Metadata: e.Metadata, + }) + } + } + } + if usableSkills == nil { + usableSkills = []gatewaySkill{} + } + + payload := gatewayRequest{ + Model: gatewayModelConfig{ + ModelID: chatModel.ModelID, + ClientType: clientType, + Input: chatModel.Input, + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + }, + ActiveContextTime: maxCtx, + Channels: nonNilStrings(req.Channels), + CurrentChannel: req.CurrentChannel, + AllowedActions: req.AllowedActions, + Messages: nonNilModelMessages(messages), + Skills: nonNilStrings(skills), + UsableSkills: usableSkills, + Query: req.Query, + Identity: gatewayIdentity{ + BotID: req.BotID, + ContainerID: containerID, + ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID), + DisplayName: r.resolveDisplayName(ctx, req), + CurrentPlatform: req.CurrentChannel, + ReplyTarget: "", + SessionToken: req.ChatToken, + }, + Attachments: []any{}, + } + + return resolvedContext{payload: payload, model: chatModel, provider: provider}, nil +} + +// --- Chat --- + +// Chat sends a synchronous chat request to the agent gateway and stores the result. +func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { + rc, err := r.resolve(ctx, req) + if err != nil { + return ChatResponse{}, err + } + resp, err := r.postChat(ctx, rc.payload, req.Token) + if err != nil { + return ChatResponse{}, err + } + if err := r.storeRound(ctx, req, resp.Messages); err != nil { + return ChatResponse{}, err + } + return ChatResponse{ + Messages: resp.Messages, + Skills: resp.Skills, + Model: rc.model.ModelID, + Provider: rc.provider.ClientType, + }, nil +} + +// --- TriggerSchedule --- + +// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint. +func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + if strings.TrimSpace(botID) == "" { + return fmt.Errorf("bot id is required") + } + if strings.TrimSpace(payload.Command) == "" { + return fmt.Errorf("schedule command is required") + } + + chatID := payload.ChatID + if strings.TrimSpace(chatID) == "" { + chatID = "schedule-" + payload.ID + } + req := ChatRequest{ + BotID: botID, + ChatID: chatID, + Query: payload.Command, + UserID: payload.OwnerUserID, + Token: token, + } + rc, err := r.resolve(ctx, req) + if err != nil { + return err + } + + triggerReq := triggerScheduleRequest{ + Model: rc.payload.Model, + ActiveContextTime: rc.payload.ActiveContextTime, + Channels: rc.payload.Channels, + CurrentChannel: rc.payload.CurrentChannel, + AllowedActions: rc.payload.AllowedActions, + Messages: rc.payload.Messages, + Skills: rc.payload.Skills, + UsableSkills: rc.payload.UsableSkills, + Identity: gatewayIdentity{ + BotID: rc.payload.Identity.BotID, + ContainerID: rc.payload.Identity.ContainerID, + ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID), + DisplayName: "Scheduler", + }, + Attachments: rc.payload.Attachments, + Schedule: gatewaySchedule{ + ID: payload.ID, + Name: payload.Name, + Description: payload.Description, + Pattern: payload.Pattern, + MaxCalls: payload.MaxCalls, + Command: payload.Command, + }, + } + + resp, err := r.postTriggerSchedule(ctx, triggerReq, token) + if err != nil { + return err + } + return r.storeRound(ctx, req, resp.Messages) +} + +// --- StreamChat --- + +// StreamChat sends a streaming chat request to the agent gateway. +func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { + chunkCh := make(chan StreamChunk) + errCh := make(chan error, 1) + r.logger.Info("gateway stream start", + slog.String("bot_id", req.BotID), + slog.String("chat_id", req.ChatID), + ) + + go func() { + defer close(chunkCh) + defer close(errCh) + + streamReq := req + rc, err := r.resolve(ctx, streamReq) + if err != nil { + r.logger.Error("gateway stream resolve failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + if err := r.persistUserMessage(ctx, streamReq); err != nil { + r.logger.Error("gateway stream persist user message failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + streamReq.UserMessagePersisted = true + if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil { + r.logger.Error("gateway stream request failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + } + }() + return chunkCh, errCh +} + +// --- HTTP helpers --- + +func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/" + r.logger.Info("gateway request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +// postTriggerSchedule sends a trigger-schedule request to the agent gateway. +func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/trigger-schedule" + r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + url := r.gatewayBaseURL + "/chat/stream" + r.logger.Info("gateway stream request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + if strings.TrimSpace(req.Token) != "" { + httpReq.Header.Set("Authorization", req.Token) + } + + resp, err := r.streamingClient.Do(httpReq) + if err != nil { + r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err)) + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(resp.Body) + r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300))) + return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody))) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + + currentEvent := "" + stored := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + continue + } + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" || data == "[DONE]" { + continue + } + chunkCh <- StreamChunk([]byte(data)) + + if stored { + continue + } + if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil { + return storeErr + } else if handled { + stored = true + } + } + return scanner.Err() +} + +// tryStoreStream attempts to extract final messages from a stream event and persist them. +func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) { + // event: done + data: {messages: [...]} + if eventType == "done" { + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + + // data: {"type":"text_delta"|"agent_end"|"done", ...} + var envelope struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + } + if err := json.Unmarshal([]byte(data), &envelope); err == nil { + if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 { + return true, r.storeRound(ctx, req, envelope.Messages) + } + if envelope.Type == "done" && len(envelope.Data) > 0 { + var resp gatewayResponse + if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + } + + // fallback: data: {messages: [...]} + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + return false, nil +} + +// --- container resolution --- + +func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string { + if strings.TrimSpace(explicit) != "" { + return explicit + } + if r.queries != nil { + pgBotID, err := parseResolverUUID(botID) + if err == nil { + row, err := r.queries.GetContainerByBotID(ctx, pgBotID) + if err == nil && strings.TrimSpace(row.ContainerID) != "" { + return row.ContainerID + } + } + } + return "mcp-" + botID +} + +// --- message loading --- + +func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { + if r.messageService == nil { + return nil, nil + } + since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) + msgs, err := r.messageService.ListSince(ctx, chatID, since) + if err != nil { + return nil, err + } + var result []ModelMessage + for _, m := range msgs { + var mm ModelMessage + if err := json.Unmarshal(m.Content, &mm); err != nil { + // Fallback: treat content as text string. + mm = ModelMessage{Role: m.Role, Content: m.Content} + } else { + mm.Role = m.Role + } + result = append(result, mm) + } + return result, nil +} + +type memoryContextItem struct { + Namespace string + Item memory.MemoryItem +} + +func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage { + if r.memoryService == nil { + return nil + } + if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + return nil + } + + results := make([]memoryContextItem, 0, memoryContextLimitPerScope) + seen := map[string]struct{}{} + resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ + Query: req.Query, + BotID: req.BotID, + Limit: memoryContextLimitPerScope, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": req.BotID, + "botId": req.BotID, + }, + }) + if err != nil { + r.logger.Warn("memory search for context failed", + slog.String("namespace", sharedMemoryNamespace), + slog.Any("error", err), + ) + return nil + } + for _, item := range resp.Results { + key := strings.TrimSpace(item.ID) + if key == "" { + key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory) + } + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item}) + } + if len(results) == 0 { + return nil + } + + sort.Slice(results, func(i, j int) bool { + return results[i].Item.Score > results[j].Item.Score + }) + if len(results) > memoryContextMaxItems { + results = results[:memoryContextMaxItems] + } + + var sb strings.Builder + sb.WriteString("Relevant memory context (use when helpful):\n") + for _, entry := range results { + text := strings.TrimSpace(entry.Item.Memory) + if text == "" { + continue + } + sb.WriteString("- [") + sb.WriteString(entry.Namespace) + sb.WriteString("] ") + sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars)) + sb.WriteString("\n") + } + payload := strings.TrimSpace(sb.String()) + if payload == "" { + return nil + } + msg := ModelMessage{ + Role: "system", + Content: NewTextContent(payload), + } + return &msg +} + +// --- store helpers --- + +func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error { + if r.messageService == nil { + return nil + } + if strings.TrimSpace(req.BotID) == "" { + return fmt.Errorf("bot id is required for persistence") + } + text := strings.TrimSpace(req.Query) + if text == "" { + return nil + } + + message := ModelMessage{ + Role: "user", + Content: NewTextContent(text), + } + content, err := json.Marshal(message) + if err != nil { + return err + } + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + _, err = r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: senderChannelIdentityID, + SenderUserID: senderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: req.ExternalMessageID, + Role: "user", + Content: content, + Metadata: buildRouteMetadata(req), + }) + return err +} + +func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error { + // Add user query as the first message if not already present in the round. + // This ensures the user's prompt is persisted alongside the assistant's response. + fullRound := make([]ModelMessage, 0, len(messages)+1) + hasUserQuery := false + for _, m := range messages { + if m.Role == "user" && m.TextContent() == req.Query { + hasUserQuery = true + break + } + } + if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" { + fullRound = append(fullRound, ModelMessage{ + Role: "user", + Content: NewTextContent(req.Query), + }) + } + for _, m := range messages { + if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) { + // User message was already persisted before streaming; skip duplicate copy in round payload. + continue + } + fullRound = append(fullRound, m) + } + if len(fullRound) == 0 { + return nil + } + + r.storeMessages(ctx, req, fullRound) + r.storeMemory(ctx, req.BotID, fullRound) + return nil +} + +func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) { + if r.messageService == nil { + return + } + if strings.TrimSpace(req.BotID) == "" { + return + } + meta := buildRouteMetadata(req) + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + for _, msg := range messages { + content, err := json.Marshal(msg) + if err != nil { + continue + } + messageSenderChannelIdentityID := "" + messageSenderUserID := "" + externalMessageID := "" + sourceReplyToMessageID := "" + if msg.Role == "user" { + messageSenderChannelIdentityID = senderChannelIdentityID + messageSenderUserID = senderUserID + externalMessageID = req.ExternalMessageID + } else if strings.TrimSpace(req.ExternalMessageID) != "" { + // Assistant/tool/system outputs are linked to the inbound source message for cross-channel reply threading. + sourceReplyToMessageID = req.ExternalMessageID + } + if _, err := r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: messageSenderChannelIdentityID, + SenderUserID: messageSenderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: externalMessageID, + SourceReplyToMessageID: sourceReplyToMessageID, + Role: msg.Role, + Content: content, + Metadata: meta, + }); err != nil { + r.logger.Warn("persist message failed", slog.Any("error", err)) + } + } +} + +func buildRouteMetadata(req ChatRequest) map[string]any { + if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" { + return nil + } + meta := map[string]any{} + if strings.TrimSpace(req.RouteID) != "" { + meta["route_id"] = req.RouteID + } + if strings.TrimSpace(req.CurrentChannel) != "" { + meta["platform"] = req.CurrentChannel + } + return meta +} + +func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) { + channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID) + userID := strings.TrimSpace(req.UserID) + + channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID) + userAsUserValid := r.isExistingUserID(ctx, userID) + userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID) + + senderChannelIdentityID := "" + switch { + case channelIdentityValid: + senderChannelIdentityID = channelIdentityID + case userAsChannelIdentityValid && !userAsUserValid: + // Some flows may carry channel_identity_id in req.UserID. + senderChannelIdentityID = userID + } + + senderUserID := "" + if userAsUserValid { + senderUserID = userID + } + if senderUserID == "" && senderChannelIdentityID != "" { + if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" { + senderUserID = linked + } + } + return senderChannelIdentityID, senderUserID +} + +func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetChannelIdentityByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetUserByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string { + if r.queries == nil { + return "" + } + pgID, err := parseResolverUUID(channelIdentityID) + if err != nil { + return "" + } + row, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil || !row.UserID.Valid { + return "" + } + return row.UserID.String() +} + +// resolveDisplayName returns the best available display name for the request identity: +// req.DisplayName if set, else channel identity's display_name, else linked user's display_name, else "User". +func (r *Resolver) resolveDisplayName(ctx context.Context, req ChatRequest) string { + if name := strings.TrimSpace(req.DisplayName); name != "" { + return name + } + if r.queries == nil { + return "User" + } + channelIdentityID := firstNonEmpty(req.SourceChannelIdentityID, req.UserID) + if channelIdentityID == "" { + return "User" + } + pgID, err := parseResolverUUID(channelIdentityID) + if err != nil { + return "User" + } + ci, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err == nil && ci.DisplayName.Valid { + if name := strings.TrimSpace(ci.DisplayName.String); name != "" { + return name + } + } + linkedUserID := r.linkedUserIDFromChannelIdentity(ctx, channelIdentityID) + if linkedUserID == "" { + return "User" + } + userPgID, err := parseResolverUUID(linkedUserID) + if err != nil { + return "User" + } + u, err := r.queries.GetUserByID(ctx, userPgID) + if err != nil || !u.DisplayName.Valid { + return "User" + } + if name := strings.TrimSpace(u.DisplayName.String); name != "" { + return name + } + return "User" +} + +func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) { + if r.memoryService == nil { + return + } + if strings.TrimSpace(botID) == "" { + return + } + memMsgs := make([]memory.Message, 0, len(messages)) + for _, msg := range messages { + text := strings.TrimSpace(msg.TextContent()) + if text == "" { + continue + } + role := msg.Role + if strings.TrimSpace(role) == "" { + role = "assistant" + } + memMsgs = append(memMsgs, memory.Message{Role: role, Content: text}) + } + if len(memMsgs) == 0 { + return + } + r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID) +} + +func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + "botId": botID, + } + if _, err := r.memoryService.Add(ctx, memory.AddRequest{ + Messages: msgs, + BotID: botID, + Filters: filters, + }); err != nil { + r.logger.Warn("store memory failed", + slog.String("namespace", namespace), + slog.String("scope_id", scopeID), + slog.Any("error", err), + ) + } +} + +// --- model selection --- + +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) { + if r.modelsService == nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + } + modelID := strings.TrimSpace(req.Model) + providerFilter := strings.TrimSpace(req.Provider) + + // Priority: request model > chat settings > bot settings > user settings. + if modelID == "" && providerFilter == "" { + if value := strings.TrimSpace(cs.ModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(us.ChatModelID); value != "" { + modelID = value + } + } + + if modelID == "" { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings") + } + + if providerFilter == "" { + return r.fetchChatModel(ctx, modelID) + } + + candidates, err := r.listCandidates(ctx, providerFilter) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + for _, m := range candidates { + if m.ModelID == modelID { + prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return m, prov, nil + } + } + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter) +} + +func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { + model, err := r.modelsService.GetByModelID(ctx, modelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + if model.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + } + prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, prov, nil +} + +func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) { + var all []models.GetResponse + var err error + if providerFilter != "" { + all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) + } else { + all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) + } + if err != nil { + return nil, err + } + filtered := make([]models.GetResponse, 0, len(all)) + for _, m := range all { + if m.Type == models.ModelTypeChat { + filtered = append(filtered, m) + } + } + return filtered, nil +} + +// --- settings --- + +type resolvedUserSettings struct { + ChatModelID string +} + +func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (resolvedUserSettings, error) { + if r.settingsService == nil || strings.TrimSpace(userID) == "" { + return resolvedUserSettings{}, nil + } + s, err := r.settingsService.Get(ctx, userID) + if err != nil { + return resolvedUserSettings{}, err + } + return resolvedUserSettings{ + ChatModelID: strings.TrimSpace(s.ChatModelID), + }, nil +} + +func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) { + if r.settingsService == nil { + return settings.Settings{ + MaxContextLoadTime: settings.DefaultMaxContextLoadTime, + Language: settings.DefaultLanguage, + }, nil + } + return r.settingsService.GetBot(ctx, botID) +} + +// --- utility --- + +func normalizeClientType(clientType string) (string, error) { + switch strings.ToLower(strings.TrimSpace(clientType)) { + case "openai", "openai-compat": + return "openai", nil + case "anthropic": + return "anthropic", nil + case "google": + return "google", nil + default: + return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType) + } +} + +func sanitizeMessages(messages []ModelMessage) []ModelMessage { + cleaned := make([]ModelMessage, 0, len(messages)) + for _, msg := range messages { + if strings.TrimSpace(msg.Role) == "" { + continue + } + if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" { + continue + } + cleaned = append(cleaned, msg) + } + return cleaned +} + +func dedup(items []string) []string { + seen := make(map[string]struct{}, len(items)) + result := make([]string, 0, len(items)) + for _, s := range items { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + result = append(result, trimmed) + } + return result +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} + +func coalescePositiveInt(values ...int) int { + for _, v := range values { + if v > 0 { + return v + } + } + return defaultMaxContextMinutes +} + +func nonNilStrings(s []string) []string { + if s == nil { + return []string{} + } + return s +} + +func nonNilModelMessages(m []ModelMessage) []ModelMessage { + if m == nil { + return []ModelMessage{} + } + return m +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func truncateMemorySnippet(s string, n int) string { + trimmed := strings.TrimSpace(s) + if len(trimmed) <= n { + return trimmed + } + return strings.TrimSpace(trimmed[:n]) + "..." +} + +func parseResolverUUID(id string) (pgtype.UUID, error) { + if strings.TrimSpace(id) == "" { + return pgtype.UUID{}, fmt.Errorf("empty id") + } + return db.ParseUUID(id) +} diff --git a/internal/conversation/flow/resolver_memory_context_test.go b/internal/conversation/flow/resolver_memory_context_test.go new file mode 100644 index 00000000..1e326986 --- /dev/null +++ b/internal/conversation/flow/resolver_memory_context_test.go @@ -0,0 +1,55 @@ +package flow + +import ( + "context" + "log/slog" + "strings" + "testing" + + "github.com/memohai/memoh/internal/memory" +) + +func TestLoadMemoryContextMessage_NoMemoryService(t *testing.T) { + resolver := &Resolver{ + memoryService: nil, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + }) + if msg != nil { + t.Fatalf("expected nil message when memory service is nil") + } +} + +func TestLoadMemoryContextMessage_SearchFailureFallback(t *testing.T) { + resolver := &Resolver{ + memoryService: &memory.Service{}, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + UserID: "user-1", + }) + if msg != nil { + t.Fatalf("expected nil message when memory search cannot return results") + } +} + +func TestTruncateMemorySnippet(t *testing.T) { + longText := strings.Repeat("a", 20) + " " + got := truncateMemorySnippet(longText, 10) + if got != strings.Repeat("a", 10)+"..." { + t.Fatalf("unexpected truncated value: %q", got) + } + + shortText := " short " + got = truncateMemorySnippet(shortText, 10) + if got != "short" { + t.Fatalf("unexpected trimmed short value: %q", got) + } +} diff --git a/internal/conversation/flow/resolver_test.go b/internal/conversation/flow/resolver_test.go new file mode 100644 index 00000000..702d77bf --- /dev/null +++ b/internal/conversation/flow/resolver_test.go @@ -0,0 +1,158 @@ +package flow + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestPostTriggerSchedule_Endpoint(t *testing.T) { + var capturedPath string + var capturedBody []byte + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedAuth = r.Header.Get("Authorization") + capturedBody, _ = io.ReadAll(r.Body) + resp := gatewayResponse{ + Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + maxCalls := 5 + req := triggerScheduleRequest{ + Model: gatewayModelConfig{ + ModelID: "gpt-4", + ClientType: "openai", + APIKey: "sk-test", + BaseURL: "https://api.openai.com", + }, + ActiveContextTime: 1440, + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Identity: gatewayIdentity{ + BotID: "bot-123", + ContainerID: "mcp-bot-123", + ChannelIdentityID: "owner-user-1", + DisplayName: "Scheduler", + }, + Attachments: []any{}, + Schedule: gatewaySchedule{ + ID: "sched-1", + Name: "daily report", + Description: "generate daily report", + Pattern: "0 9 * * *", + MaxCalls: &maxCalls, + Command: "generate the daily report", + }, + } + + resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token") + if err != nil { + t.Fatalf("postTriggerSchedule returned error: %v", err) + } + + if capturedPath != "/chat/trigger-schedule" { + t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath) + } + if capturedAuth != "Bearer test-token" { + t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth) + } + if len(resp.Messages) != 1 { + t.Errorf("expected 1 message, got %d", len(resp.Messages)) + } + + var body map[string]any + if err := json.Unmarshal(capturedBody, &body); err != nil { + t.Fatalf("failed to parse captured body: %v", err) + } + schedule, ok := body["schedule"].(map[string]any) + if !ok { + t.Fatal("expected 'schedule' field in request body") + } + if schedule["id"] != "sched-1" { + t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"]) + } + if schedule["command"] != "generate the daily report" { + t.Errorf("expected schedule.command, got %v", schedule["command"]) + } + if _, hasQuery := body["query"]; hasQuery { + t.Error("trigger-schedule request should not contain 'query' field") + } +} + +func TestPostTriggerSchedule_NoAuth(t *testing.T) { + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + resp := gatewayResponse{Messages: []ModelMessage{}} + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + req := triggerScheduleRequest{ + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Attachments: []any{}, + Schedule: gatewaySchedule{ID: "s1", Command: "test"}, + } + + _, err := resolver.postTriggerSchedule(context.Background(), req, "") + if err != nil { + t.Fatalf("postTriggerSchedule returned error: %v", err) + } + if capturedAuth != "" { + t.Errorf("expected no Authorization header, got %s", capturedAuth) + } +} + +func TestPostTriggerSchedule_GatewayError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal error")) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + req := triggerScheduleRequest{ + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Attachments: []any{}, + Schedule: gatewaySchedule{ID: "s1", Command: "test"}, + } + + _, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok") + if err == nil { + t.Fatal("expected error for 500 response") + } +} diff --git a/internal/conversation/flow/schedule_gateway.go b/internal/conversation/flow/schedule_gateway.go new file mode 100644 index 00000000..4b3c2138 --- /dev/null +++ b/internal/conversation/flow/schedule_gateway.go @@ -0,0 +1,26 @@ +package flow + +import ( + "context" + "fmt" + + "github.com/memohai/memoh/internal/schedule" +) + +// ScheduleGateway adapts schedule trigger calls to the chat Resolver. +type ScheduleGateway struct { + resolver *Resolver +} + +// NewScheduleGateway creates a ScheduleGateway backed by the given Resolver. +func NewScheduleGateway(resolver *Resolver) *ScheduleGateway { + return &ScheduleGateway{resolver: resolver} +} + +// TriggerSchedule delegates a schedule trigger to the chat Resolver. +func (g *ScheduleGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + if g == nil || g.resolver == nil { + return fmt.Errorf("chat resolver not configured") + } + return g.resolver.TriggerSchedule(ctx, botID, payload, token) +} diff --git a/internal/conversation/flow/types.go b/internal/conversation/flow/types.go new file mode 100644 index 00000000..a4c19313 --- /dev/null +++ b/internal/conversation/flow/types.go @@ -0,0 +1,14 @@ +package flow + +import ( + "context" + + "github.com/memohai/memoh/internal/schedule" +) + +// Runner defines conversation execution behavior for sync, stream, and scheduled flows. +type Runner interface { + Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) + StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) + TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error +} diff --git a/internal/conversation/flow/types_alias.go b/internal/conversation/flow/types_alias.go new file mode 100644 index 00000000..1b345db6 --- /dev/null +++ b/internal/conversation/flow/types_alias.go @@ -0,0 +1,15 @@ +package flow + +import "github.com/memohai/memoh/internal/conversation" + +type ModelMessage = conversation.ModelMessage +type ContentPart = conversation.ContentPart +type ToolCall = conversation.ToolCall +type ToolCallFunction = conversation.ToolCallFunction +type AssistantOutput = conversation.AssistantOutput +type ChatRequest = conversation.ChatRequest +type ChatResponse = conversation.ChatResponse +type StreamChunk = conversation.StreamChunk +type Settings = conversation.Settings + +var NewTextContent = conversation.NewTextContent diff --git a/internal/conversation/interfaces.go b/internal/conversation/interfaces.go new file mode 100644 index 00000000..a11dc0a5 --- /dev/null +++ b/internal/conversation/interfaces.go @@ -0,0 +1,20 @@ +package conversation + +import "context" + +// Reader defines conversation lookup behavior. +type Reader interface { + Get(ctx context.Context, conversationID string) (Chat, error) +} + +// ParticipantChecker defines participant membership checks. +type ParticipantChecker interface { + IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) +} + +// Accessor defines read access checks for conversation-scoped operations. +type Accessor interface { + Reader + ParticipantChecker + GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error) +} diff --git a/internal/conversation/resolver.go b/internal/conversation/resolver.go new file mode 100644 index 00000000..a78f9e2c --- /dev/null +++ b/internal/conversation/resolver.go @@ -0,0 +1,1179 @@ +package conversation + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "sort" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/memory" + messagepkg "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/schedule" + "github.com/memohai/memoh/internal/settings" +) + +const ( + defaultMaxContextMinutes = 24 * 60 + memoryContextLimitPerScope = 4 + memoryContextMaxItems = 8 + memoryContextItemMaxChars = 220 + sharedMemoryNamespace = "bot" +) + +// SkillEntry represents a skill loaded from the container. +type SkillEntry struct { + Name string + Description string + Content string + Metadata map[string]any +} + +// SkillLoader loads skills for a given bot from its container. +type SkillLoader interface { + LoadSkills(ctx context.Context, botID string) ([]SkillEntry, error) +} + +// Resolver orchestrates chat with the agent gateway. +type Resolver struct { + modelsService *models.Service + queries *sqlc.Queries + memoryService *memory.Service + chatService *Service + messageService messagepkg.Service + settingsService *settings.Service + mcpService *mcp.ConnectionService + skillLoader SkillLoader + gatewayBaseURL string + timeout time.Duration + logger *slog.Logger + httpClient *http.Client + streamingClient *http.Client +} + +// NewResolver creates a Resolver that communicates with the agent gateway. +func NewResolver( + log *slog.Logger, + modelsService *models.Service, + queries *sqlc.Queries, + memoryService *memory.Service, + chatService *Service, + messageService messagepkg.Service, + settingsService *settings.Service, + mcpService *mcp.ConnectionService, + gatewayBaseURL string, + timeout time.Duration, +) *Resolver { + if strings.TrimSpace(gatewayBaseURL) == "" { + gatewayBaseURL = "http://127.0.0.1:8081" + } + gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/") + if timeout <= 0 { + timeout = 60 * time.Second + } + return &Resolver{ + modelsService: modelsService, + queries: queries, + memoryService: memoryService, + chatService: chatService, + messageService: messageService, + settingsService: settingsService, + mcpService: mcpService, + gatewayBaseURL: gatewayBaseURL, + timeout: timeout, + logger: log.With(slog.String("service", "chat_resolver")), + httpClient: &http.Client{Timeout: timeout}, + streamingClient: &http.Client{}, + } +} + +// SetSkillLoader sets the skill loader used to populate usable skills in gateway requests. +func (r *Resolver) SetSkillLoader(sl SkillLoader) { + r.skillLoader = sl +} + +// --- gateway payload --- + +type gatewayModelConfig struct { + ModelID string `json:"modelId"` + ClientType string `json:"clientType"` + Input []string `json:"input"` + APIKey string `json:"apiKey"` + BaseURL string `json:"baseUrl"` +} + +type gatewayIdentity struct { + BotID string `json:"botId"` + ContainerID string `json:"containerId"` + ChannelIdentityID string `json:"channelIdentityId"` + DisplayName string `json:"displayName"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + ReplyTarget string `json:"replyTarget,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` +} + +type gatewaySkill struct { + Name string `json:"name"` + Description string `json:"description"` + Content string `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type gatewayRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Query string `json:"query"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` +} + +type gatewayResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` +} + +// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule. +type gatewaySchedule struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Pattern string `json:"pattern"` + MaxCalls *int `json:"maxCalls,omitempty"` + Command string `json:"command"` +} + +// triggerScheduleRequest is the payload for POST /chat/trigger-schedule. +type triggerScheduleRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` + Schedule gatewaySchedule `json:"schedule"` +} + +// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) --- + +type resolvedContext struct { + payload gatewayRequest + model models.GetResponse + provider sqlc.LlmProvider +} + +func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContext, error) { + if strings.TrimSpace(req.Query) == "" { + return resolvedContext{}, fmt.Errorf("query is required") + } + if strings.TrimSpace(req.BotID) == "" { + return resolvedContext{}, fmt.Errorf("bot id is required") + } + if strings.TrimSpace(req.ChatID) == "" { + return resolvedContext{}, fmt.Errorf("chat id is required") + } + + skipHistory := req.MaxContextLoadTime < 0 + + botSettings, err := r.loadBotSettings(ctx, req.BotID) + if err != nil { + return resolvedContext{}, err + } + + // Check chat-level model override. + var chatSettings Settings + if r.chatService != nil { + chatSettings, err = r.chatService.GetSettings(ctx, req.ChatID) + if err != nil { + return resolvedContext{}, err + } + } + + userSettings, err := r.loadUserSettings(ctx, req.UserID) + if err != nil { + return resolvedContext{}, err + } + chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings) + if err != nil { + return resolvedContext{}, err + } + clientType, err := normalizeClientType(provider.ClientType) + if err != nil { + return resolvedContext{}, err + } + maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes) + + var messages []ModelMessage + if !skipHistory && r.chatService != nil { + messages, err = r.loadMessages(ctx, req.ChatID, maxCtx) + if err != nil { + return resolvedContext{}, err + } + } + if memoryMsg := r.loadMemoryContextMessage(ctx, req); memoryMsg != nil { + messages = append(messages, *memoryMsg) + } + messages = append(messages, req.Messages...) + messages = sanitizeMessages(messages) + skills := dedup(req.Skills) + containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) + + var usableSkills []gatewaySkill + if r.skillLoader != nil { + entries, err := r.skillLoader.LoadSkills(ctx, req.BotID) + if err != nil { + r.logger.Warn("failed to load usable skills", slog.String("bot_id", req.BotID), slog.Any("error", err)) + } else { + usableSkills = make([]gatewaySkill, 0, len(entries)) + for _, e := range entries { + usableSkills = append(usableSkills, gatewaySkill{ + Name: e.Name, + Description: e.Description, + Content: e.Content, + Metadata: e.Metadata, + }) + } + } + } + if usableSkills == nil { + usableSkills = []gatewaySkill{} + } + + payload := gatewayRequest{ + Model: gatewayModelConfig{ + ModelID: chatModel.ModelID, + ClientType: clientType, + Input: chatModel.Input, + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + }, + ActiveContextTime: maxCtx, + Channels: nonNilStrings(req.Channels), + CurrentChannel: req.CurrentChannel, + AllowedActions: req.AllowedActions, + Messages: nonNilModelMessages(messages), + Skills: nonNilStrings(skills), + UsableSkills: usableSkills, + Query: req.Query, + Identity: gatewayIdentity{ + BotID: req.BotID, + ContainerID: containerID, + ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID), + DisplayName: firstNonEmpty(req.DisplayName, "User"), + CurrentPlatform: req.CurrentChannel, + ReplyTarget: "", + SessionToken: req.ChatToken, + }, + Attachments: []any{}, + } + + return resolvedContext{payload: payload, model: chatModel, provider: provider}, nil +} + +// --- Chat --- + +// Chat sends a synchronous chat request to the agent gateway and stores the result. +func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { + rc, err := r.resolve(ctx, req) + if err != nil { + return ChatResponse{}, err + } + resp, err := r.postChat(ctx, rc.payload, req.Token) + if err != nil { + return ChatResponse{}, err + } + if err := r.storeRound(ctx, req, resp.Messages); err != nil { + return ChatResponse{}, err + } + return ChatResponse{ + Messages: resp.Messages, + Skills: resp.Skills, + Model: rc.model.ModelID, + Provider: rc.provider.ClientType, + }, nil +} + +// --- TriggerSchedule --- + +// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint. +func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + if strings.TrimSpace(botID) == "" { + return fmt.Errorf("bot id is required") + } + if strings.TrimSpace(payload.Command) == "" { + return fmt.Errorf("schedule command is required") + } + + chatID := payload.ChatID + if strings.TrimSpace(chatID) == "" { + chatID = "schedule-" + payload.ID + } + req := ChatRequest{ + BotID: botID, + ChatID: chatID, + Query: payload.Command, + UserID: payload.OwnerUserID, + Token: token, + } + rc, err := r.resolve(ctx, req) + if err != nil { + return err + } + + triggerReq := triggerScheduleRequest{ + Model: rc.payload.Model, + ActiveContextTime: rc.payload.ActiveContextTime, + Channels: rc.payload.Channels, + CurrentChannel: rc.payload.CurrentChannel, + AllowedActions: rc.payload.AllowedActions, + Messages: rc.payload.Messages, + Skills: rc.payload.Skills, + UsableSkills: rc.payload.UsableSkills, + Identity: gatewayIdentity{ + BotID: rc.payload.Identity.BotID, + ContainerID: rc.payload.Identity.ContainerID, + ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID), + DisplayName: "Scheduler", + }, + Attachments: rc.payload.Attachments, + Schedule: gatewaySchedule{ + ID: payload.ID, + Name: payload.Name, + Description: payload.Description, + Pattern: payload.Pattern, + MaxCalls: payload.MaxCalls, + Command: payload.Command, + }, + } + + resp, err := r.postTriggerSchedule(ctx, triggerReq, token) + if err != nil { + return err + } + return r.storeRound(ctx, req, resp.Messages) +} + +// --- StreamChat --- + +// StreamChat sends a streaming chat request to the agent gateway. +func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { + chunkCh := make(chan StreamChunk) + errCh := make(chan error, 1) + r.logger.Info("gateway stream start", + slog.String("bot_id", req.BotID), + slog.String("chat_id", req.ChatID), + ) + + go func() { + defer close(chunkCh) + defer close(errCh) + + streamReq := req + rc, err := r.resolve(ctx, streamReq) + if err != nil { + r.logger.Error("gateway stream resolve failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + if err := r.persistUserMessage(ctx, streamReq); err != nil { + r.logger.Error("gateway stream persist user message failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + streamReq.UserMessagePersisted = true + if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil { + r.logger.Error("gateway stream request failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + } + }() + return chunkCh, errCh +} + +// --- HTTP helpers --- + +func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/" + r.logger.Info("gateway request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +// postTriggerSchedule sends a trigger-schedule request to the agent gateway. +func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/trigger-schedule" + r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + url := r.gatewayBaseURL + "/chat/stream" + r.logger.Info("gateway stream request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + if strings.TrimSpace(req.Token) != "" { + httpReq.Header.Set("Authorization", req.Token) + } + + resp, err := r.streamingClient.Do(httpReq) + if err != nil { + r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err)) + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(resp.Body) + r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300))) + return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody))) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + + currentEvent := "" + stored := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + continue + } + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" || data == "[DONE]" { + continue + } + chunkCh <- StreamChunk([]byte(data)) + + if stored { + continue + } + if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil { + return storeErr + } else if handled { + stored = true + } + } + return scanner.Err() +} + +// tryStoreStream attempts to extract final messages from a stream event and persist them. +func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) { + // event: done + data: {messages: [...]} + if eventType == "done" { + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + + // data: {"type":"text_delta"|"agent_end"|"done", ...} + var envelope struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + } + if err := json.Unmarshal([]byte(data), &envelope); err == nil { + if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 { + return true, r.storeRound(ctx, req, envelope.Messages) + } + if envelope.Type == "done" && len(envelope.Data) > 0 { + var resp gatewayResponse + if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + } + + // fallback: data: {messages: [...]} + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + return false, nil +} + +// --- container resolution --- + +func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string { + if strings.TrimSpace(explicit) != "" { + return explicit + } + if r.queries != nil { + pgBotID, err := parseResolverUUID(botID) + if err == nil { + row, err := r.queries.GetContainerByBotID(ctx, pgBotID) + if err == nil && strings.TrimSpace(row.ContainerID) != "" { + return row.ContainerID + } + } + } + return "mcp-" + botID +} + +// --- message loading --- + +func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { + if r.messageService == nil { + return nil, nil + } + since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) + msgs, err := r.messageService.ListSince(ctx, chatID, since) + if err != nil { + return nil, err + } + var result []ModelMessage + for _, m := range msgs { + var mm ModelMessage + if err := json.Unmarshal(m.Content, &mm); err != nil { + // Fallback: treat content as text string. + mm = ModelMessage{Role: m.Role, Content: m.Content} + } else { + mm.Role = m.Role + } + result = append(result, mm) + } + return result, nil +} + +type memoryContextItem struct { + Namespace string + Item memory.MemoryItem +} + +func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage { + if r.memoryService == nil { + return nil + } + if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + return nil + } + + results := make([]memoryContextItem, 0, memoryContextLimitPerScope) + seen := map[string]struct{}{} + resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ + Query: req.Query, + BotID: req.BotID, + Limit: memoryContextLimitPerScope, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": req.BotID, + "botId": req.BotID, + }, + }) + if err != nil { + r.logger.Warn("memory search for context failed", + slog.String("namespace", sharedMemoryNamespace), + slog.Any("error", err), + ) + return nil + } + for _, item := range resp.Results { + key := strings.TrimSpace(item.ID) + if key == "" { + key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory) + } + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item}) + } + if len(results) == 0 { + return nil + } + + sort.Slice(results, func(i, j int) bool { + return results[i].Item.Score > results[j].Item.Score + }) + if len(results) > memoryContextMaxItems { + results = results[:memoryContextMaxItems] + } + + var sb strings.Builder + sb.WriteString("Relevant memory context (use when helpful):\n") + for _, entry := range results { + text := strings.TrimSpace(entry.Item.Memory) + if text == "" { + continue + } + sb.WriteString("- [") + sb.WriteString(entry.Namespace) + sb.WriteString("] ") + sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars)) + sb.WriteString("\n") + } + payload := strings.TrimSpace(sb.String()) + if payload == "" { + return nil + } + msg := ModelMessage{ + Role: "system", + Content: NewTextContent(payload), + } + return &msg +} + +// --- store helpers --- + +func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error { + if r.messageService == nil { + return nil + } + if strings.TrimSpace(req.BotID) == "" { + return fmt.Errorf("bot id is required for persistence") + } + text := strings.TrimSpace(req.Query) + if text == "" { + return nil + } + + message := ModelMessage{ + Role: "user", + Content: NewTextContent(text), + } + content, err := json.Marshal(message) + if err != nil { + return err + } + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + _, err = r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: senderChannelIdentityID, + SenderUserID: senderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: req.ExternalMessageID, + Role: "user", + Content: content, + Metadata: buildRouteMetadata(req), + }) + return err +} + +func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error { + // Add user query as the first message if not already present in the round. + // This ensures the user's prompt is persisted alongside the assistant's response. + fullRound := make([]ModelMessage, 0, len(messages)+1) + hasUserQuery := false + for _, m := range messages { + if m.Role == "user" && m.TextContent() == req.Query { + hasUserQuery = true + break + } + } + if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" { + fullRound = append(fullRound, ModelMessage{ + Role: "user", + Content: NewTextContent(req.Query), + }) + } + for _, m := range messages { + if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) { + // User message was already persisted before streaming; skip duplicate copy in round payload. + continue + } + fullRound = append(fullRound, m) + } + if len(fullRound) == 0 { + return nil + } + + r.storeMessages(ctx, req, fullRound) + r.storeMemory(ctx, req.BotID, fullRound) + return nil +} + +func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) { + if r.messageService == nil { + return + } + if strings.TrimSpace(req.BotID) == "" { + return + } + meta := buildRouteMetadata(req) + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + for _, msg := range messages { + content, err := json.Marshal(msg) + if err != nil { + continue + } + messageSenderChannelIdentityID := "" + messageSenderUserID := "" + externalMessageID := "" + sourceReplyToMessageID := "" + if msg.Role == "user" { + messageSenderChannelIdentityID = senderChannelIdentityID + messageSenderUserID = senderUserID + externalMessageID = req.ExternalMessageID + } else if strings.TrimSpace(req.ExternalMessageID) != "" { + // Assistant/tool/system outputs are linked to the inbound source message for cross-channel reply threading. + sourceReplyToMessageID = req.ExternalMessageID + } + if _, err := r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: messageSenderChannelIdentityID, + SenderUserID: messageSenderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: externalMessageID, + SourceReplyToMessageID: sourceReplyToMessageID, + Role: msg.Role, + Content: content, + Metadata: meta, + }); err != nil { + r.logger.Warn("persist message failed", slog.Any("error", err)) + } + } +} + +func buildRouteMetadata(req ChatRequest) map[string]any { + if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" { + return nil + } + meta := map[string]any{} + if strings.TrimSpace(req.RouteID) != "" { + meta["route_id"] = req.RouteID + } + if strings.TrimSpace(req.CurrentChannel) != "" { + meta["platform"] = req.CurrentChannel + } + return meta +} + +func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) { + channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID) + userID := strings.TrimSpace(req.UserID) + + channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID) + userAsUserValid := r.isExistingUserID(ctx, userID) + userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID) + + senderChannelIdentityID := "" + switch { + case channelIdentityValid: + senderChannelIdentityID = channelIdentityID + case userAsChannelIdentityValid && !userAsUserValid: + // Some flows may carry channel_identity_id in req.UserID. + senderChannelIdentityID = userID + } + + senderUserID := "" + if userAsUserValid { + senderUserID = userID + } + if senderUserID == "" && senderChannelIdentityID != "" { + if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" { + senderUserID = linked + } + } + return senderChannelIdentityID, senderUserID +} + +func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetChannelIdentityByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetUserByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string { + if r.queries == nil { + return "" + } + pgID, err := parseResolverUUID(channelIdentityID) + if err != nil { + return "" + } + row, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil || !row.UserID.Valid { + return "" + } + return row.UserID.String() +} + +func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) { + if r.memoryService == nil { + return + } + if strings.TrimSpace(botID) == "" { + return + } + memMsgs := make([]memory.Message, 0, len(messages)) + for _, msg := range messages { + text := strings.TrimSpace(msg.TextContent()) + if text == "" { + continue + } + role := msg.Role + if strings.TrimSpace(role) == "" { + role = "assistant" + } + memMsgs = append(memMsgs, memory.Message{Role: role, Content: text}) + } + if len(memMsgs) == 0 { + return + } + r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID) +} + +func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + "botId": botID, + } + if _, err := r.memoryService.Add(ctx, memory.AddRequest{ + Messages: msgs, + BotID: botID, + Filters: filters, + }); err != nil { + r.logger.Warn("store memory failed", + slog.String("namespace", namespace), + slog.String("scope_id", scopeID), + slog.Any("error", err), + ) + } +} + +// --- model selection --- + +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) { + if r.modelsService == nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + } + modelID := strings.TrimSpace(req.Model) + providerFilter := strings.TrimSpace(req.Provider) + + // Priority: request model > chat settings > bot settings > user settings. + if modelID == "" && providerFilter == "" { + if value := strings.TrimSpace(cs.ModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(us.ChatModelID); value != "" { + modelID = value + } + } + + if modelID == "" { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings") + } + + if providerFilter == "" { + return r.fetchChatModel(ctx, modelID) + } + + candidates, err := r.listCandidates(ctx, providerFilter) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + for _, m := range candidates { + if m.ModelID == modelID { + prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return m, prov, nil + } + } + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter) +} + +func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { + model, err := r.modelsService.GetByModelID(ctx, modelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + if model.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + } + prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, prov, nil +} + +func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) { + var all []models.GetResponse + var err error + if providerFilter != "" { + all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) + } else { + all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) + } + if err != nil { + return nil, err + } + filtered := make([]models.GetResponse, 0, len(all)) + for _, m := range all { + if m.Type == models.ModelTypeChat { + filtered = append(filtered, m) + } + } + return filtered, nil +} + +// --- settings --- + +type resolvedUserSettings struct { + ChatModelID string +} + +func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (resolvedUserSettings, error) { + if r.settingsService == nil || strings.TrimSpace(userID) == "" { + return resolvedUserSettings{}, nil + } + s, err := r.settingsService.Get(ctx, userID) + if err != nil { + return resolvedUserSettings{}, err + } + return resolvedUserSettings{ + ChatModelID: strings.TrimSpace(s.ChatModelID), + }, nil +} + +func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) { + if r.settingsService == nil { + return settings.Settings{ + MaxContextLoadTime: settings.DefaultMaxContextLoadTime, + Language: settings.DefaultLanguage, + }, nil + } + return r.settingsService.GetBot(ctx, botID) +} + +// --- utility --- + +func normalizeClientType(clientType string) (string, error) { + switch strings.ToLower(strings.TrimSpace(clientType)) { + case "openai", "openai-compat": + return "openai", nil + case "anthropic": + return "anthropic", nil + case "google": + return "google", nil + default: + return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType) + } +} + +func sanitizeMessages(messages []ModelMessage) []ModelMessage { + cleaned := make([]ModelMessage, 0, len(messages)) + for _, msg := range messages { + if strings.TrimSpace(msg.Role) == "" { + continue + } + if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" { + continue + } + cleaned = append(cleaned, msg) + } + return cleaned +} + +func dedup(items []string) []string { + seen := make(map[string]struct{}, len(items)) + result := make([]string, 0, len(items)) + for _, s := range items { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + result = append(result, trimmed) + } + return result +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} + +func coalescePositiveInt(values ...int) int { + for _, v := range values { + if v > 0 { + return v + } + } + return defaultMaxContextMinutes +} + +func nonNilStrings(s []string) []string { + if s == nil { + return []string{} + } + return s +} + +func nonNilModelMessages(m []ModelMessage) []ModelMessage { + if m == nil { + return []ModelMessage{} + } + return m +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func truncateMemorySnippet(s string, n int) string { + trimmed := strings.TrimSpace(s) + if len(trimmed) <= n { + return trimmed + } + return strings.TrimSpace(trimmed[:n]) + "..." +} + +func parseResolverUUID(id string) (pgtype.UUID, error) { + if strings.TrimSpace(id) == "" { + return pgtype.UUID{}, fmt.Errorf("empty id") + } + return db.ParseUUID(id) +} diff --git a/internal/conversation/service_domain.go b/internal/conversation/service_domain.go new file mode 100644 index 00000000..8e0f0099 --- /dev/null +++ b/internal/conversation/service_domain.go @@ -0,0 +1,483 @@ +package conversation + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + dbpkg "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +var ( + ErrChatNotFound = errors.New("chat not found") + ErrNotParticipant = errors.New("not a participant") + ErrPermissionDenied = errors.New("permission denied") +) + +// Service manages conversation lifecycle, participants, and settings. +type Service struct { + queries *sqlc.Queries + logger *slog.Logger +} + +// NewService creates a conversation service. +func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + queries: queries, + logger: log.With(slog.String("service", "conversation")), + } +} + +// Create creates a new conversation and adds the creator as owner. +func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Chat, error) { + kind := strings.TrimSpace(req.Kind) + if kind == "" { + kind = KindDirect + } + if kind != KindDirect && kind != KindGroup && kind != KindThread { + return Chat{}, fmt.Errorf("invalid conversation kind: %s", kind) + } + + pgBotID, err := parseUUID(botID) + if err != nil { + return Chat{}, fmt.Errorf("invalid bot id: %w", err) + } + pgChannelIdentityID := pgtype.UUID{} + if strings.TrimSpace(channelIdentityID) != "" { + pgChannelIdentityID, err = parseUUID(channelIdentityID) + if err != nil { + return Chat{}, fmt.Errorf("invalid channel identity id: %w", err) + } + } + + var pgParent pgtype.UUID + if kind == KindThread && strings.TrimSpace(req.ParentChatID) != "" { + pgParent, err = parseUUID(req.ParentChatID) + if err != nil { + return Chat{}, fmt.Errorf("invalid parent conversation id: %w", err) + } + } + + metadata, err := json.Marshal(nonNilMap(req.Metadata)) + if err != nil { + return Chat{}, fmt.Errorf("marshal conversation metadata: %w", err) + } + + row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{ + BotID: pgBotID, + Kind: kind, + ParentChatID: pgParent, + Title: strings.TrimSpace(req.Title), + CreatedByUserID: pgChannelIdentityID, + Metadata: metadata, + }) + if err != nil { + return Chat{}, fmt.Errorf("create conversation: %w", err) + } + + // Add creator as owner when the channel identity is available. + if pgChannelIdentityID.Valid { + if _, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: row.ID, + UserID: pgChannelIdentityID, + Role: RoleOwner, + }); err != nil { + return Chat{}, fmt.Errorf("add owner participant: %w", err) + } + } + + // For threads, copy parent participants. + if kind == KindThread && pgParent.Valid { + if err := s.queries.CopyParticipantsToChat(ctx, sqlc.CopyParticipantsToChatParams{ + ChatID: pgParent, + ChatID2: row.ID, + }); err != nil && s.logger != nil { + s.logger.Warn("copy parent participants failed", slog.Any("error", err)) + } + } + + return toChatFromCreate(row), nil +} + +// Get returns a conversation by ID. +func (s *Service) Get(ctx context.Context, conversationID string) (Chat, error) { + pgID, err := parseUUID(conversationID) + if err != nil { + return Chat{}, ErrChatNotFound + } + row, err := s.queries.GetChatByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Chat{}, ErrChatNotFound + } + return Chat{}, err + } + return toChatFromGet(row), nil +} + +// GetReadAccess resolves whether a user can read a conversation. +func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error) { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + row, err := s.queries.GetChatReadAccessByUser(ctx, sqlc.GetChatReadAccessByUserParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ChatReadAccess{}, ErrPermissionDenied + } + return ChatReadAccess{}, err + } + return ChatReadAccess{ + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + }, nil +} + +// ListByBotAndChannelIdentity returns all visible conversations for a bot and channel identity. +func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ChatListItem, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return nil, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListVisibleChatsByBotAndUser(ctx, sqlc.ListVisibleChatsByBotAndUserParams{ + BotID: pgBotID, + UserID: pgChannelIdentityID, + }) + if err != nil { + return nil, err + } + conversations := make([]ChatListItem, 0, len(rows)) + for _, row := range rows { + conversations = append(conversations, toChatListItem(row)) + } + return conversations, nil +} + +// ListThreads returns threads for a parent conversation. +func (s *Service) ListThreads(ctx context.Context, parentConversationID string) ([]Chat, error) { + pgID, err := parseUUID(parentConversationID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListThreadsByParent(ctx, pgID) + if err != nil { + return nil, err + } + conversations := make([]Chat, 0, len(rows)) + for _, row := range rows { + conversations = append(conversations, toChatFromThread(row)) + } + return conversations, nil +} + +// Delete deletes a conversation and linked records. +func (s *Service) Delete(ctx context.Context, conversationID string) error { + pgID, err := parseUUID(conversationID) + if err != nil { + return ErrChatNotFound + } + return s.queries.DeleteChat(ctx, pgID) +} + +// AddParticipant adds a channel identity to a conversation. +func (s *Service) AddParticipant(ctx context.Context, conversationID, channelIdentityID, role string) (Participant, error) { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return Participant{}, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, err + } + if role == "" { + role = RoleMember + } + row, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + Role: role, + }) + if err != nil { + return Participant{}, err + } + return toParticipantFromAdd(row), nil +} + +// GetParticipant returns a conversation participant. +func (s *Service) GetParticipant(ctx context.Context, conversationID, channelIdentityID string) (Participant, error) { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return Participant{}, ErrNotParticipant + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, ErrNotParticipant + } + row, err := s.queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Participant{}, ErrNotParticipant + } + return Participant{}, err + } + return toParticipantFromGet(row), nil +} + +// IsParticipant checks whether a channel identity is a participant. +func (s *Service) IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) { + _, err := s.GetParticipant(ctx, conversationID, channelIdentityID) + if errors.Is(err, ErrNotParticipant) { + return false, nil + } + return err == nil, err +} + +// ListParticipants returns all participants for a conversation. +func (s *Service) ListParticipants(ctx context.Context, conversationID string) ([]Participant, error) { + pgID, err := parseUUID(conversationID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatParticipants(ctx, pgID) + if err != nil { + return nil, err + } + participants := make([]Participant, 0, len(rows)) + for _, row := range rows { + participants = append(participants, toParticipantFromList(row)) + } + return participants, nil +} + +// RemoveParticipant removes a participant from a conversation. +func (s *Service) RemoveParticipant(ctx context.Context, conversationID, channelIdentityID string) error { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return err + } + return s.queries.RemoveChatParticipant(ctx, sqlc.RemoveChatParticipantParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + }) +} + +// GetSettings returns conversation settings and falls back to defaults when missing. +func (s *Service) GetSettings(ctx context.Context, conversationID string) (Settings, error) { + pgID, err := parseUUID(conversationID) + if err != nil { + return defaultSettings(conversationID), nil + } + row, err := s.queries.GetChatSettings(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return defaultSettings(conversationID), nil + } + return Settings{}, err + } + return toSettingsFromRead(row), nil +} + +// UpdateSettings updates conversation settings. +func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req UpdateSettingsRequest) (Settings, error) { + current, err := s.GetSettings(ctx, conversationID) + if err != nil { + return Settings{}, err + } + if req.ModelID != nil { + current.ModelID = *req.ModelID + } + + pgID, err := parseUUID(conversationID) + if err != nil { + return Settings{}, err + } + row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ + ID: pgID, + ModelID: toPgText(current.ModelID), + }) + if err != nil { + return Settings{}, err + } + return toSettingsFromUpsert(row), nil +} + +func toChatFromCreate(row sqlc.CreateChatRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromGet(row sqlc.GetChatByIDRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Chat { + return Chat{ + ID: id.String(), + BotID: botID.String(), + Kind: kind, + ParentChatID: parentChatID.String(), + Title: dbpkg.TextToString(title), + CreatedBy: createdBy.String(), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, + } +} + +func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem { + return ChatListItem{ + ID: row.ID.String(), + BotID: row.BotID.String(), + Kind: row.Kind, + ParentChatID: row.ParentChatID.String(), + Title: dbpkg.TextToString(row.Title), + CreatedBy: row.CreatedByUserID.String(), + Metadata: parseJSONMap(row.Metadata), + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + } +} + +func toParticipantFromAdd(row sqlc.AddChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromGet(row sqlc.GetChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromList(row sqlc.ListChatParticipantsRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFields(conversationID, userID pgtype.UUID, role string, joinedAt pgtype.Timestamptz) Participant { + return Participant{ + ChatID: conversationID.String(), + UserID: userID.String(), + Role: role, + JoinedAt: joinedAt.Time, + } +} + +func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings { + return Settings{ + ChatID: row.ChatID.String(), + ModelID: dbpkg.TextToString(row.ModelID), + } +} + +func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings { + return Settings{ + ChatID: row.ChatID.String(), + ModelID: dbpkg.TextToString(row.ModelID), + } +} + +func defaultSettings(conversationID string) Settings { + return Settings{ + ChatID: conversationID, + } +} + +func parseUUID(id string) (pgtype.UUID, error) { + return dbpkg.ParseUUID(id) +} + +func toPgText(s string) pgtype.Text { + s = strings.TrimSpace(s) + if s == "" { + return pgtype.Text{} + } + return pgtype.Text{String: s, Valid: true} +} + +func pgTimePtr(ts pgtype.Timestamptz) *time.Time { + if !ts.Valid { + return nil + } + value := ts.Time + return &value +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} diff --git a/internal/conversation/service_presence_integration_test.go b/internal/conversation/service_presence_integration_test.go new file mode 100644 index 00000000..fef12f72 --- /dev/null +++ b/internal/conversation/service_presence_integration_test.go @@ -0,0 +1,242 @@ +package conversation_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channel/identities" + conversation "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/message" +) + +type chatPresenceFixture struct { + chatSvc *conversation.Service + messageSvc message.Service + channelIdentitySvc *identities.Service + queries *sqlc.Queries + cleanup func() +} + +func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + queries := sqlc.New(pool) + + return chatPresenceFixture{ + chatSvc: conversation.NewService(logger, queries), + messageSvc: message.NewService(logger, queries), + channelIdentitySvc: identities.NewService(logger, queries), + queries: queries, + cleanup: func() { pool.Close() }, + } +} + +func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, err := json.Marshal(map[string]any{"source": "chat-presence-integration-test"}) + if err != nil { + return "", err + } + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "presence-test-bot", Valid: true}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, string, string, string) { + t.Helper() + + fixture := setupChatPresenceIntegrationTest(t) + ctx := context.Background() + + ownerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + fixture.cleanup() + t.Fatalf("create owner user failed: %v", err) + } + observerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + fixture.cleanup() + t.Fatalf("create observer user failed: %v", err) + } + botID, err := createBotForChatPresence(ctx, fixture.queries, ownerUserID) + if err != nil { + fixture.cleanup() + t.Fatalf("create bot failed: %v", err) + } + + createdChat, err := fixture.chatSvc.Create(ctx, botID, ownerUserID, conversation.CreateRequest{ + Kind: conversation.KindGroup, + Title: "presence-observed", + }) + if err != nil { + fixture.cleanup() + t.Fatalf("create chat failed: %v", err) + } + + observedChannelIdentity, err := fixture.channelIdentitySvc.ResolveByChannelIdentity( + ctx, + "feishu", + fmt.Sprintf("presence-channelIdentity-%d", time.Now().UnixNano()), + "presence-observer", + ) + if err != nil { + fixture.cleanup() + t.Fatalf("resolve channelIdentity failed: %v", err) + } + + _, err = fixture.messageSvc.Persist(ctx, message.PersistInput{ + BotID: botID, + SenderChannelIdentityID: observedChannelIdentity.ID, + Platform: "feishu", + ExternalMessageID: fmt.Sprintf("ext-msg-%d", time.Now().UnixNano()), + Role: "user", + Content: []byte(`{"content":"hello from observed channelIdentity"}`), + }) + if err != nil { + fixture.cleanup() + t.Fatalf("persist message failed: %v", err) + } + + return fixture, botID, createdChat.ID, observerUserID, observedChannelIdentity.ID +} + +func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + beforeBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats before bind failed: %v", err) + } + if len(beforeBind) != 0 { + t.Fatalf("expected no visible chats before bind, got %d", len(beforeBind)) + } + + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + afterBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats after bind failed: %v", err) + } + if len(afterBind) == 0 { + t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind)) + } + + var target *conversation.ChatListItem + for i := range afterBind { + if afterBind[i].ID == chatID { + target = &afterBind[i] + break + } + } + if target == nil { + t.Fatalf("expected chat %s in visible list after bind", chatID) + } + if target.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected access_mode=%s, got %s", conversation.AccessModeChannelIdentityObserved, target.AccessMode) + } + if target.ParticipantRole != "" { + t.Fatalf("expected empty participant_role for observed chat, got %s", target.ParticipantRole) + } + if target.LastObservedAt == nil { + t.Fatal("expected last_observed_at to be set for observed chat") + } +} + +func TestObservedAccessReadableButNotParticipant(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + access, err := fixture.chatSvc.GetReadAccess(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("get read access failed: %v", err) + } + if access.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected read access %s, got %s", conversation.AccessModeChannelIdentityObserved, access.AccessMode) + } + + messages, err := fixture.messageSvc.List(ctx, chatID) + if err != nil { + t.Fatalf("list messages failed: %v", err) + } + if len(messages) == 0 { + t.Fatal("expected observed user can read chat messages") + } + + _, err = fixture.chatSvc.GetParticipant(ctx, chatID, observerUserID) + if !errors.Is(err, conversation.ErrNotParticipant) { + t.Fatalf("expected ErrNotParticipant for observed user, got %v", err) + } + ok, err := fixture.chatSvc.IsParticipant(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("check participant failed: %v", err) + } + if ok { + t.Fatal("expected observed user to remain non-participant") + } + + visibleChats, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list visible chats failed: %v", err) + } + if len(visibleChats) == 0 || visibleChats[0].AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatal("expected observed list entry with channel_identity_observed access mode") + } +} diff --git a/internal/conversation/types.go b/internal/conversation/types.go new file mode 100644 index 00000000..5a10738b --- /dev/null +++ b/internal/conversation/types.go @@ -0,0 +1,235 @@ +// Package conversation defines conversation domain types and rules. +package conversation + +import ( + "encoding/json" + "strings" + "time" +) + +// Chat kind constants. +const ( + KindDirect = "direct" + KindGroup = "group" + KindThread = "thread" +) + +// Participant role constants. +const ( + RoleOwner = "owner" + RoleAdmin = "admin" + RoleMember = "member" +) + +// Chat list access mode constants. +const ( + AccessModeParticipant = "participant" + AccessModeChannelIdentityObserved = "channel_identity_observed" +) + +// Conversation is the first-class conversation container. +type Conversation struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ConversationListItem is a conversation entry with access context for list rendering. +type ConversationListItem struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role,omitempty"` + LastObservedAt *time.Time `json:"last_observed_at,omitempty"` +} + +// ConversationReadAccess is the resolved access context for reading conversation content. +type ConversationReadAccess struct { + AccessMode string + ParticipantRole string + LastObservedAt *time.Time +} + +// Backward-compatible aliases while call sites migrate. +type Chat = Conversation +type ChatListItem = ConversationListItem +type ChatReadAccess = ConversationReadAccess + +// Participant represents a chat member. +type Participant struct { + ChatID string `json:"chat_id"` + UserID string `json:"user_id"` + Role string `json:"role"` + JoinedAt time.Time `json:"joined_at"` +} + +// Settings holds per-chat configuration. +type Settings struct { + ChatID string `json:"chat_id"` + ModelID string `json:"model_id,omitempty"` +} + +// CreateRequest is the input for creating a bot-scoped conversation container. +type CreateRequest struct { + Kind string `json:"kind"` + Title string `json:"title,omitempty"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// UpdateSettingsRequest is the input for updating chat settings. +type UpdateSettingsRequest struct { + ModelID *string `json:"model_id,omitempty"` +} + +// ModelMessage is the canonical message format exchanged with the agent gateway. +// Aligned with Vercel AI SDK ModelMessage structure. +type ModelMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +// TextContent extracts the plain text from the message content. +// If content is a string, it returns it directly. +// If content is an array of parts, it joins all text-type parts. +func (m ModelMessage) TextContent() string { + if len(m.Content) == 0 { + return "" + } + var s string + if err := json.Unmarshal(m.Content, &s); err == nil { + return s + } + var parts []ContentPart + if err := json.Unmarshal(m.Content, &parts); err == nil { + texts := make([]string, 0, len(parts)) + for _, p := range parts { + if strings.TrimSpace(p.Text) != "" { + texts = append(texts, p.Text) + } + } + return strings.Join(texts, "\n") + } + return "" +} + +// ContentParts parses the content as an array of ContentPart. +// Returns nil if the content is a plain string or not parseable. +func (m ModelMessage) ContentParts() []ContentPart { + if len(m.Content) == 0 { + return nil + } + var parts []ContentPart + if err := json.Unmarshal(m.Content, &parts); err != nil { + return nil + } + return parts +} + +// HasContent reports whether the message carries non-empty content or tool calls. +func (m ModelMessage) HasContent() bool { + if strings.TrimSpace(m.TextContent()) != "" { + return true + } + if len(m.ContentParts()) > 0 { + return true + } + return len(m.ToolCalls) > 0 +} + +// NewTextContent creates a json.RawMessage from a plain string. +func NewTextContent(text string) json.RawMessage { + data, _ := json.Marshal(text) + return data +} + +// ContentPart represents one element of a multi-part message content. +type ContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []string `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// HasValue reports whether the content part carries a meaningful value. +func (p ContentPart) HasValue() bool { + return strings.TrimSpace(p.Text) != "" || + strings.TrimSpace(p.URL) != "" || + strings.TrimSpace(p.Emoji) != "" +} + +// ToolCall represents a function/tool invocation in an assistant message. +type ToolCall struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Function ToolCallFunction `json:"function"` +} + +// ToolCallFunction holds the name and serialized arguments of a tool call. +type ToolCallFunction struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ChatRequest is the input for Chat and StreamChat. +type ChatRequest struct { + BotID string `json:"-"` + ChatID string `json:"-"` + Token string `json:"-"` + UserID string `json:"-"` + SourceChannelIdentityID string `json:"-"` + ContainerID string `json:"-"` + DisplayName string `json:"-"` + RouteID string `json:"-"` + ChatToken string `json:"-"` + ExternalMessageID string `json:"-"` + UserMessagePersisted bool `json:"-"` + + Query string `json:"query"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` + MaxContextLoadTime int `json:"max_context_load_time,omitempty"` + Language string `json:"language,omitempty"` + Channels []string `json:"channels,omitempty"` + CurrentChannel string `json:"current_channel,omitempty"` + Messages []ModelMessage `json:"messages,omitempty"` + Skills []string `json:"skills,omitempty"` + AllowedActions []string `json:"allowed_actions,omitempty"` +} + +// ChatResponse is the output of a non-streaming chat call. +type ChatResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills,omitempty"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` +} + +// StreamChunk is a raw JSON chunk from the streaming response. +type StreamChunk = json.RawMessage + +// AssistantOutput holds extracted assistant content for downstream consumers. +type AssistantOutput struct { + Content string + Parts []ContentPart +} diff --git a/internal/db/sqlc/bind.sql.go b/internal/db/sqlc/bind.sql.go index c4df9b24..39347a54 100644 --- a/internal/db/sqlc/bind.sql.go +++ b/internal/db/sqlc/bind.sql.go @@ -12,15 +12,15 @@ import ( ) const createBindCode = `-- name: CreateBindCode :one -INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, platform, expires_at) +INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, channel_type, expires_at) VALUES ($1, $2, $3, $4) -RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at ` type CreateBindCodeParams struct { Token string `json:"token"` IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` - Platform pgtype.Text `json:"platform"` + ChannelType pgtype.Text `json:"channel_type"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` } @@ -28,7 +28,7 @@ func (q *Queries) CreateBindCode(ctx context.Context, arg CreateBindCodeParams) row := q.db.QueryRow(ctx, createBindCode, arg.Token, arg.IssuedByUserID, - arg.Platform, + arg.ChannelType, arg.ExpiresAt, ) var i ChannelIdentityBindCode @@ -36,7 +36,7 @@ func (q *Queries) CreateBindCode(ctx context.Context, arg CreateBindCodeParams) &i.ID, &i.Token, &i.IssuedByUserID, - &i.Platform, + &i.ChannelType, &i.ExpiresAt, &i.UsedAt, &i.UsedByChannelIdentityID, @@ -46,7 +46,7 @@ func (q *Queries) CreateBindCode(ctx context.Context, arg CreateBindCodeParams) } const getBindCode = `-- name: GetBindCode :one -SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at FROM channel_identity_bind_codes WHERE token = $1 ` @@ -58,7 +58,7 @@ func (q *Queries) GetBindCode(ctx context.Context, token string) (ChannelIdentit &i.ID, &i.Token, &i.IssuedByUserID, - &i.Platform, + &i.ChannelType, &i.ExpiresAt, &i.UsedAt, &i.UsedByChannelIdentityID, @@ -68,7 +68,7 @@ func (q *Queries) GetBindCode(ctx context.Context, token string) (ChannelIdentit } const getBindCodeForUpdate = `-- name: GetBindCodeForUpdate :one -SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at FROM channel_identity_bind_codes WHERE token = $1 FOR UPDATE @@ -81,7 +81,7 @@ func (q *Queries) GetBindCodeForUpdate(ctx context.Context, token string) (Chann &i.ID, &i.Token, &i.IssuedByUserID, - &i.Platform, + &i.ChannelType, &i.ExpiresAt, &i.UsedAt, &i.UsedByChannelIdentityID, @@ -95,7 +95,7 @@ UPDATE channel_identity_bind_codes SET used_at = now(), used_by_channel_identity_id = $2 WHERE id = $1 AND used_at IS NULL -RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at ` type MarkBindCodeUsedParams struct { @@ -110,7 +110,7 @@ func (q *Queries) MarkBindCodeUsed(ctx context.Context, arg MarkBindCodeUsedPara &i.ID, &i.Token, &i.IssuedByUserID, - &i.Platform, + &i.ChannelType, &i.ExpiresAt, &i.UsedAt, &i.UsedByChannelIdentityID, diff --git a/internal/db/sqlc/bots.sql.go b/internal/db/sqlc/bots.sql.go index 3fafd6d9..7d175c97 100644 --- a/internal/db/sqlc/bots.sql.go +++ b/internal/db/sqlc/bots.sql.go @@ -12,9 +12,9 @@ import ( ) const createBot = `-- name: CreateBot :one -INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at +INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata, status) +VALUES ($1, $2, $3, $4, $5, $6, $7) +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type CreateBotParams struct { @@ -24,6 +24,7 @@ type CreateBotParams struct { AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` Metadata []byte `json:"metadata"` + Status string `json:"status"` } func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (Bot, error) { @@ -34,6 +35,7 @@ func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (Bot, erro arg.AvatarUrl, arg.IsActive, arg.Metadata, + arg.Status, ) var i Bot err := row.Scan( @@ -43,6 +45,7 @@ func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (Bot, erro &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, @@ -80,7 +83,7 @@ func (q *Queries) DeleteBotMember(ctx context.Context, arg DeleteBotMemberParams } const getBotByID = `-- name: GetBotByID :one -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE id = $1 ` @@ -95,6 +98,7 @@ func (q *Queries) GetBotByID(ctx context.Context, id pgtype.UUID) (Bot, error) { &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, @@ -165,7 +169,7 @@ func (q *Queries) ListBotMembers(ctx context.Context, botID pgtype.UUID) ([]BotM } const listBotsByMember = `-- name: ListBotsByMember :many -SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at +SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.status, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at FROM bots b JOIN bot_members m ON m.bot_id = b.id WHERE m.user_id = $1 @@ -188,6 +192,7 @@ func (q *Queries) ListBotsByMember(ctx context.Context, userID pgtype.UUID) ([]B &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, @@ -209,7 +214,7 @@ func (q *Queries) ListBotsByMember(ctx context.Context, userID pgtype.UUID) ([]B } const listBotsByOwner = `-- name: ListBotsByOwner :many -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE owner_user_id = $1 ORDER BY created_at DESC @@ -231,6 +236,7 @@ func (q *Queries) ListBotsByOwner(ctx context.Context, ownerUserID pgtype.UUID) &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, @@ -256,7 +262,7 @@ UPDATE bots SET owner_user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type UpdateBotOwnerParams struct { @@ -274,6 +280,7 @@ func (q *Queries) UpdateBotOwner(ctx context.Context, arg UpdateBotOwnerParams) &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, @@ -295,7 +302,7 @@ SET display_name = $2, metadata = $5, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type UpdateBotProfileParams struct { @@ -322,6 +329,7 @@ func (q *Queries) UpdateBotProfile(ctx context.Context, arg UpdateBotProfilePara &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, @@ -335,6 +343,23 @@ func (q *Queries) UpdateBotProfile(ctx context.Context, arg UpdateBotProfilePara return i, err } +const updateBotStatus = `-- name: UpdateBotStatus :exec +UPDATE bots +SET status = $2, + updated_at = now() +WHERE id = $1 +` + +type UpdateBotStatusParams struct { + ID pgtype.UUID `json:"id"` + Status string `json:"status"` +} + +func (q *Queries) UpdateBotStatus(ctx context.Context, arg UpdateBotStatusParams) error { + _, err := q.db.Exec(ctx, updateBotStatus, arg.ID, arg.Status) + return err +} + const upsertBotMember = `-- name: UpsertBotMember :one INSERT INTO bot_members (bot_id, user_id, role) VALUES ($1, $2, $3) diff --git a/internal/db/sqlc/channel_identities.sql.go b/internal/db/sqlc/channel_identities.sql.go index 280706b2..b7ad5b17 100644 --- a/internal/db/sqlc/channel_identities.sql.go +++ b/internal/db/sqlc/channel_identities.sql.go @@ -15,7 +15,7 @@ const clearChannelIdentityLinkedUser = `-- name: ClearChannelIdentityLinkedUser UPDATE channel_identities SET user_id = NULL, updated_at = now() WHERE id = $1 -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at ` func (q *Queries) ClearChannelIdentityLinkedUser(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) { @@ -24,7 +24,7 @@ func (q *Queries) ClearChannelIdentityLinkedUser(ctx context.Context, id pgtype. err := row.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, @@ -35,14 +35,14 @@ func (q *Queries) ClearChannelIdentityLinkedUser(ctx context.Context, id pgtype. } const createChannelIdentity = `-- name: CreateChannelIdentity :one -INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) VALUES ($1, $2, $3, $4, $5) -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at ` type CreateChannelIdentityParams struct { UserID pgtype.UUID `json:"user_id"` - Channel string `json:"channel"` + ChannelType string `json:"channel_type"` ChannelSubjectID string `json:"channel_subject_id"` DisplayName pgtype.Text `json:"display_name"` Metadata []byte `json:"metadata"` @@ -51,7 +51,7 @@ type CreateChannelIdentityParams struct { func (q *Queries) CreateChannelIdentity(ctx context.Context, arg CreateChannelIdentityParams) (ChannelIdentity, error) { row := q.db.QueryRow(ctx, createChannelIdentity, arg.UserID, - arg.Channel, + arg.ChannelType, arg.ChannelSubjectID, arg.DisplayName, arg.Metadata, @@ -60,7 +60,7 @@ func (q *Queries) CreateChannelIdentity(ctx context.Context, arg CreateChannelId err := row.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, @@ -71,23 +71,23 @@ func (q *Queries) CreateChannelIdentity(ctx context.Context, arg CreateChannelId } const getChannelIdentityByChannelSubject = `-- name: GetChannelIdentityByChannelSubject :one -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities -WHERE channel = $1 AND channel_subject_id = $2 +WHERE channel_type = $1 AND channel_subject_id = $2 ` type GetChannelIdentityByChannelSubjectParams struct { - Channel string `json:"channel"` + ChannelType string `json:"channel_type"` ChannelSubjectID string `json:"channel_subject_id"` } func (q *Queries) GetChannelIdentityByChannelSubject(ctx context.Context, arg GetChannelIdentityByChannelSubjectParams) (ChannelIdentity, error) { - row := q.db.QueryRow(ctx, getChannelIdentityByChannelSubject, arg.Channel, arg.ChannelSubjectID) + row := q.db.QueryRow(ctx, getChannelIdentityByChannelSubject, arg.ChannelType, arg.ChannelSubjectID) var i ChannelIdentity err := row.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, @@ -98,7 +98,7 @@ func (q *Queries) GetChannelIdentityByChannelSubject(ctx context.Context, arg Ge } const getChannelIdentityByID = `-- name: GetChannelIdentityByID :one -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities WHERE id = $1 ` @@ -109,7 +109,7 @@ func (q *Queries) GetChannelIdentityByID(ctx context.Context, id pgtype.UUID) (C err := row.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, @@ -120,7 +120,7 @@ func (q *Queries) GetChannelIdentityByID(ctx context.Context, id pgtype.UUID) (C } const getChannelIdentityByIDForUpdate = `-- name: GetChannelIdentityByIDForUpdate :one -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities WHERE id = $1 FOR UPDATE @@ -132,7 +132,7 @@ func (q *Queries) GetChannelIdentityByIDForUpdate(ctx context.Context, id pgtype err := row.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, @@ -143,7 +143,7 @@ func (q *Queries) GetChannelIdentityByIDForUpdate(ctx context.Context, id pgtype } const listChannelIdentitiesByUserID = `-- name: ListChannelIdentitiesByUserID :many -SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at FROM channel_identities WHERE user_id = $1 ORDER BY created_at DESC @@ -161,7 +161,7 @@ func (q *Queries) ListChannelIdentitiesByUserID(ctx context.Context, userID pgty if err := rows.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, @@ -182,7 +182,7 @@ const setChannelIdentityLinkedUser = `-- name: SetChannelIdentityLinkedUser :one UPDATE channel_identities SET user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at ` type SetChannelIdentityLinkedUserParams struct { @@ -196,7 +196,7 @@ func (q *Queries) SetChannelIdentityLinkedUser(ctx context.Context, arg SetChann err := row.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, @@ -207,20 +207,20 @@ func (q *Queries) SetChannelIdentityLinkedUser(ctx context.Context, arg SetChann } const upsertChannelIdentityByChannelSubject = `-- name: UpsertChannelIdentityByChannelSubject :one -INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (channel, channel_subject_id) +ON CONFLICT (channel_type, channel_subject_id) DO UPDATE SET - display_name = EXCLUDED.display_name, + display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), channel_identities.display_name), metadata = EXCLUDED.metadata, user_id = COALESCE(channel_identities.user_id, EXCLUDED.user_id), updated_at = now() -RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at ` type UpsertChannelIdentityByChannelSubjectParams struct { UserID pgtype.UUID `json:"user_id"` - Channel string `json:"channel"` + ChannelType string `json:"channel_type"` ChannelSubjectID string `json:"channel_subject_id"` DisplayName pgtype.Text `json:"display_name"` Metadata []byte `json:"metadata"` @@ -229,7 +229,7 @@ type UpsertChannelIdentityByChannelSubjectParams struct { func (q *Queries) UpsertChannelIdentityByChannelSubject(ctx context.Context, arg UpsertChannelIdentityByChannelSubjectParams) (ChannelIdentity, error) { row := q.db.QueryRow(ctx, upsertChannelIdentityByChannelSubject, arg.UserID, - arg.Channel, + arg.ChannelType, arg.ChannelSubjectID, arg.DisplayName, arg.Metadata, @@ -238,7 +238,7 @@ func (q *Queries) UpsertChannelIdentityByChannelSubject(ctx context.Context, arg err := row.Scan( &i.ID, &i.UserID, - &i.Channel, + &i.ChannelType, &i.ChannelSubjectID, &i.DisplayName, &i.Metadata, diff --git a/internal/db/sqlc/channel_routes.sql.go b/internal/db/sqlc/channel_routes.sql.go new file mode 100644 index 00000000..0be56106 --- /dev/null +++ b/internal/db/sqlc/channel_routes.sql.go @@ -0,0 +1,298 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: channel_routes.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createChatRoute = `-- name: CreateChatRoute :one +INSERT INTO bot_channel_routes ( + bot_id, channel_type, channel_config_id, external_conversation_id, external_thread_id, default_reply_target, metadata +) +VALUES ( + $1, + $2, + $3::uuid, + $4, + $5::text, + $6::text, + $7 +) +RETURNING + id, + $8::uuid AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +` + +type CreateChatRouteParams struct { + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + ChatID pgtype.UUID `json:"chat_id"` +} + +type CreateChatRouteRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) CreateChatRoute(ctx context.Context, arg CreateChatRouteParams) (CreateChatRouteRow, error) { + row := q.db.QueryRow(ctx, createChatRoute, + arg.BotID, + arg.Platform, + arg.ChannelConfigID, + arg.ConversationID, + arg.ThreadID, + arg.ReplyTarget, + arg.Metadata, + arg.ChatID, + ) + var i CreateChatRouteRow + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteChatRoute = `-- name: DeleteChatRoute :exec +DELETE FROM bot_channel_routes +WHERE id = $1 +` + +func (q *Queries) DeleteChatRoute(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteChatRoute, id) + return err +} + +const findChatRoute = `-- name: FindChatRoute :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = $1 + AND channel_type = $2 + AND external_conversation_id = $3 + AND COALESCE(external_thread_id, '') = COALESCE($4, '') +LIMIT 1 +` + +type FindChatRouteParams struct { + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` +} + +type FindChatRouteRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) FindChatRoute(ctx context.Context, arg FindChatRouteParams) (FindChatRouteRow, error) { + row := q.db.QueryRow(ctx, findChatRoute, + arg.BotID, + arg.Platform, + arg.ConversationID, + arg.ThreadID, + ) + var i FindChatRouteRow + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChatRouteByID = `-- name: GetChatRouteByID :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE id = $1 +` + +type GetChatRouteByIDRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) GetChatRouteByID(ctx context.Context, id pgtype.UUID) (GetChatRouteByIDRow, error) { + row := q.db.QueryRow(ctx, getChatRouteByID, id) + var i GetChatRouteByIDRow + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listChatRoutes = `-- name: ListChatRoutes :many +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = $1 +ORDER BY created_at ASC +` + +type ListChatRoutesRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) ListChatRoutes(ctx context.Context, chatID pgtype.UUID) ([]ListChatRoutesRow, error) { + rows, err := q.db.Query(ctx, listChatRoutes, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChatRoutesRow + for rows.Next() { + var i ListChatRoutesRow + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateChatRouteReplyTarget = `-- name: UpdateChatRouteReplyTarget :exec +UPDATE bot_channel_routes +SET default_reply_target = $1, updated_at = now() +WHERE id = $2 +` + +type UpdateChatRouteReplyTargetParams struct { + ReplyTarget pgtype.Text `json:"reply_target"` + ID pgtype.UUID `json:"id"` +} + +func (q *Queries) UpdateChatRouteReplyTarget(ctx context.Context, arg UpdateChatRouteReplyTargetParams) error { + _, err := q.db.Exec(ctx, updateChatRouteReplyTarget, arg.ReplyTarget, arg.ID) + return err +} diff --git a/internal/db/sqlc/channels.sql.go b/internal/db/sqlc/channels.sql.go index 2bd0a488..52c62a40 100644 --- a/internal/db/sqlc/channels.sql.go +++ b/internal/db/sqlc/channels.sql.go @@ -76,24 +76,24 @@ func (q *Queries) GetBotChannelConfigByExternalIdentity(ctx context.Context, arg } const getUserChannelBinding = `-- name: GetUserChannelBinding :one -SELECT id, user_id, platform, config, created_at, updated_at +SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings -WHERE user_id = $1 AND platform = $2 +WHERE user_id = $1 AND channel_type = $2 LIMIT 1 ` type GetUserChannelBindingParams struct { - UserID pgtype.UUID `json:"user_id"` - Platform string `json:"platform"` + UserID pgtype.UUID `json:"user_id"` + ChannelType string `json:"channel_type"` } func (q *Queries) GetUserChannelBinding(ctx context.Context, arg GetUserChannelBindingParams) (UserChannelBinding, error) { - row := q.db.QueryRow(ctx, getUserChannelBinding, arg.UserID, arg.Platform) + row := q.db.QueryRow(ctx, getUserChannelBinding, arg.UserID, arg.ChannelType) var i UserChannelBinding err := row.Scan( &i.ID, &i.UserID, - &i.Platform, + &i.ChannelType, &i.Config, &i.CreatedAt, &i.UpdatedAt, @@ -142,14 +142,14 @@ func (q *Queries) ListBotChannelConfigsByType(ctx context.Context, channelType s } const listUserChannelBindingsByPlatform = `-- name: ListUserChannelBindingsByPlatform :many -SELECT id, user_id, platform, config, created_at, updated_at +SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings -WHERE platform = $1 +WHERE channel_type = $1 ORDER BY created_at DESC ` -func (q *Queries) ListUserChannelBindingsByPlatform(ctx context.Context, platform string) ([]UserChannelBinding, error) { - rows, err := q.db.Query(ctx, listUserChannelBindingsByPlatform, platform) +func (q *Queries) ListUserChannelBindingsByPlatform(ctx context.Context, channelType string) ([]UserChannelBinding, error) { + rows, err := q.db.Query(ctx, listUserChannelBindingsByPlatform, channelType) if err != nil { return nil, err } @@ -160,7 +160,7 @@ func (q *Queries) ListUserChannelBindingsByPlatform(ctx context.Context, platfor if err := rows.Scan( &i.ID, &i.UserID, - &i.Platform, + &i.ChannelType, &i.Config, &i.CreatedAt, &i.UpdatedAt, @@ -236,28 +236,28 @@ func (q *Queries) UpsertBotChannelConfig(ctx context.Context, arg UpsertBotChann } const upsertUserChannelBinding = `-- name: UpsertUserChannelBinding :one -INSERT INTO user_channel_bindings (user_id, platform, config) +INSERT INTO user_channel_bindings (user_id, channel_type, config) VALUES ($1, $2, $3) -ON CONFLICT (user_id, platform) +ON CONFLICT (user_id, channel_type) DO UPDATE SET config = EXCLUDED.config, updated_at = now() -RETURNING id, user_id, platform, config, created_at, updated_at +RETURNING id, user_id, channel_type, config, created_at, updated_at ` type UpsertUserChannelBindingParams struct { - UserID pgtype.UUID `json:"user_id"` - Platform string `json:"platform"` - Config []byte `json:"config"` + UserID pgtype.UUID `json:"user_id"` + ChannelType string `json:"channel_type"` + Config []byte `json:"config"` } func (q *Queries) UpsertUserChannelBinding(ctx context.Context, arg UpsertUserChannelBindingParams) (UserChannelBinding, error) { - row := q.db.QueryRow(ctx, upsertUserChannelBinding, arg.UserID, arg.Platform, arg.Config) + row := q.db.QueryRow(ctx, upsertUserChannelBinding, arg.UserID, arg.ChannelType, arg.Config) var i UserChannelBinding err := row.Scan( &i.ID, &i.UserID, - &i.Platform, + &i.ChannelType, &i.Config, &i.CreatedAt, &i.UpdatedAt, diff --git a/internal/db/sqlc/chats.sql.go b/internal/db/sqlc/chats.sql.go deleted file mode 100644 index 0dce7be4..00000000 --- a/internal/db/sqlc/chats.sql.go +++ /dev/null @@ -1,988 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 -// source: chats.sql - -package sqlc - -import ( - "context" - - "github.com/jackc/pgx/v5/pgtype" -) - -const addChatParticipant = `-- name: AddChatParticipant :one - -INSERT INTO chat_participants (chat_id, user_id, role) -VALUES ($1, $2, $3) -ON CONFLICT (chat_id, user_id) DO UPDATE SET role = EXCLUDED.role -RETURNING chat_id, user_id, role, joined_at -` - -type AddChatParticipantParams struct { - ChatID pgtype.UUID `json:"chat_id"` - UserID pgtype.UUID `json:"user_id"` - Role string `json:"role"` -} - -// chat_participants -func (q *Queries) AddChatParticipant(ctx context.Context, arg AddChatParticipantParams) (ChatParticipant, error) { - row := q.db.QueryRow(ctx, addChatParticipant, arg.ChatID, arg.UserID, arg.Role) - var i ChatParticipant - err := row.Scan( - &i.ChatID, - &i.UserID, - &i.Role, - &i.JoinedAt, - ) - return i, err -} - -const copyParticipantsToChat = `-- name: CopyParticipantsToChat :exec -INSERT INTO chat_participants (chat_id, user_id, role) -SELECT $2, cp.user_id, cp.role FROM chat_participants cp WHERE cp.chat_id = $1 -ON CONFLICT (chat_id, user_id) DO NOTHING -` - -type CopyParticipantsToChatParams struct { - ChatID pgtype.UUID `json:"chat_id"` - ChatID_2 pgtype.UUID `json:"chat_id_2"` -} - -func (q *Queries) CopyParticipantsToChat(ctx context.Context, arg CopyParticipantsToChatParams) error { - _, err := q.db.Exec(ctx, copyParticipantsToChat, arg.ChatID, arg.ChatID_2) - return err -} - -const createChat = `-- name: CreateChat :one -INSERT INTO chats (bot_id, kind, parent_chat_id, title, created_by_user_id, metadata) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at -` - -type CreateChatParams struct { - BotID pgtype.UUID `json:"bot_id"` - Kind string `json:"kind"` - ParentChatID pgtype.UUID `json:"parent_chat_id"` - Title pgtype.Text `json:"title"` - CreatedByUserID pgtype.UUID `json:"created_by_user_id"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) CreateChat(ctx context.Context, arg CreateChatParams) (Chat, error) { - row := q.db.QueryRow(ctx, createChat, - arg.BotID, - arg.Kind, - arg.ParentChatID, - arg.Title, - arg.CreatedByUserID, - arg.Metadata, - ) - var i Chat - err := row.Scan( - &i.ID, - &i.BotID, - &i.Kind, - &i.ParentChatID, - &i.Title, - &i.CreatedByUserID, - &i.Metadata, - &i.EnableChatMemory, - &i.EnablePrivateMemory, - &i.EnablePublicMemory, - &i.ModelID, - &i.SettingsMetadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const createChatMessage = `-- name: CreateChatMessage :one - -INSERT INTO chat_messages (chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) -RETURNING id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -` - -type CreateChatMessageParams struct { - ChatID pgtype.UUID `json:"chat_id"` - BotID pgtype.UUID `json:"bot_id"` - RouteID pgtype.UUID `json:"route_id"` - SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` - SenderUserID pgtype.UUID `json:"sender_user_id"` - Platform pgtype.Text `json:"platform"` - ExternalMessageID pgtype.Text `json:"external_message_id"` - Role string `json:"role"` - Content []byte `json:"content"` - Metadata []byte `json:"metadata"` -} - -// chat_messages -func (q *Queries) CreateChatMessage(ctx context.Context, arg CreateChatMessageParams) (ChatMessage, error) { - row := q.db.QueryRow(ctx, createChatMessage, - arg.ChatID, - arg.BotID, - arg.RouteID, - arg.SenderChannelIdentityID, - arg.SenderUserID, - arg.Platform, - arg.ExternalMessageID, - arg.Role, - arg.Content, - arg.Metadata, - ) - var i ChatMessage - err := row.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.RouteID, - &i.SenderChannelIdentityID, - &i.SenderUserID, - &i.Platform, - &i.ExternalMessageID, - &i.Role, - &i.Content, - &i.Metadata, - &i.CreatedAt, - ) - return i, err -} - -const createChatRoute = `-- name: CreateChatRoute :one - -INSERT INTO chat_routes (chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -RETURNING id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at -` - -type CreateChatRouteParams struct { - ChatID pgtype.UUID `json:"chat_id"` - BotID pgtype.UUID `json:"bot_id"` - Platform string `json:"platform"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - ConversationID string `json:"conversation_id"` - ThreadID pgtype.Text `json:"thread_id"` - ReplyTarget pgtype.Text `json:"reply_target"` - Metadata []byte `json:"metadata"` -} - -// chat_routes -func (q *Queries) CreateChatRoute(ctx context.Context, arg CreateChatRouteParams) (ChatRoute, error) { - row := q.db.QueryRow(ctx, createChatRoute, - arg.ChatID, - arg.BotID, - arg.Platform, - arg.ChannelConfigID, - arg.ConversationID, - arg.ThreadID, - arg.ReplyTarget, - arg.Metadata, - ) - var i ChatRoute - err := row.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.Platform, - &i.ChannelConfigID, - &i.ConversationID, - &i.ThreadID, - &i.ReplyTarget, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const deleteChat = `-- name: DeleteChat :exec -DELETE FROM chats WHERE id = $1 -` - -func (q *Queries) DeleteChat(ctx context.Context, id pgtype.UUID) error { - _, err := q.db.Exec(ctx, deleteChat, id) - return err -} - -const deleteChatMessagesByChat = `-- name: DeleteChatMessagesByChat :exec -DELETE FROM chat_messages WHERE chat_id = $1 -` - -func (q *Queries) DeleteChatMessagesByChat(ctx context.Context, chatID pgtype.UUID) error { - _, err := q.db.Exec(ctx, deleteChatMessagesByChat, chatID) - return err -} - -const deleteChatRoute = `-- name: DeleteChatRoute :exec -DELETE FROM chat_routes WHERE id = $1 -` - -func (q *Queries) DeleteChatRoute(ctx context.Context, id pgtype.UUID) error { - _, err := q.db.Exec(ctx, deleteChatRoute, id) - return err -} - -const findChatRoute = `-- name: FindChatRoute :one -SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at -FROM chat_routes -WHERE bot_id = $1 AND platform = $2 AND conversation_id = $3 - AND COALESCE(thread_id, '') = COALESCE($4, '') -LIMIT 1 -` - -type FindChatRouteParams struct { - BotID pgtype.UUID `json:"bot_id"` - Platform string `json:"platform"` - ConversationID string `json:"conversation_id"` - ThreadID pgtype.Text `json:"thread_id"` -} - -func (q *Queries) FindChatRoute(ctx context.Context, arg FindChatRouteParams) (ChatRoute, error) { - row := q.db.QueryRow(ctx, findChatRoute, - arg.BotID, - arg.Platform, - arg.ConversationID, - arg.ThreadID, - ) - var i ChatRoute - err := row.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.Platform, - &i.ChannelConfigID, - &i.ConversationID, - &i.ThreadID, - &i.ReplyTarget, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getChatByID = `-- name: GetChatByID :one -SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at -FROM chats -WHERE id = $1 -` - -func (q *Queries) GetChatByID(ctx context.Context, id pgtype.UUID) (Chat, error) { - row := q.db.QueryRow(ctx, getChatByID, id) - var i Chat - err := row.Scan( - &i.ID, - &i.BotID, - &i.Kind, - &i.ParentChatID, - &i.Title, - &i.CreatedByUserID, - &i.Metadata, - &i.EnableChatMemory, - &i.EnablePrivateMemory, - &i.EnablePublicMemory, - &i.ModelID, - &i.SettingsMetadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getChatParticipant = `-- name: GetChatParticipant :one -SELECT chat_id, user_id, role, joined_at -FROM chat_participants -WHERE chat_id = $1 AND user_id = $2 -` - -type GetChatParticipantParams struct { - ChatID pgtype.UUID `json:"chat_id"` - UserID pgtype.UUID `json:"user_id"` -} - -func (q *Queries) GetChatParticipant(ctx context.Context, arg GetChatParticipantParams) (ChatParticipant, error) { - row := q.db.QueryRow(ctx, getChatParticipant, arg.ChatID, arg.UserID) - var i ChatParticipant - err := row.Scan( - &i.ChatID, - &i.UserID, - &i.Role, - &i.JoinedAt, - ) - return i, err -} - -const getChatReadAccessByUser = `-- name: GetChatReadAccessByUser :one -WITH participant_access AS ( - SELECT 'participant'::text AS access_mode, - cp.role AS participant_role, - NULL::timestamptz AS last_observed_at - FROM chat_participants cp - WHERE cp.chat_id = $1 AND cp.user_id = $2 -), -observed_access AS ( - SELECT 'channel_identity_observed'::text AS access_mode, - ''::text AS participant_role, - MAX(cap.last_seen_at) AS last_observed_at - FROM chat_channel_identity_presence cap - JOIN channel_identities ci ON ci.id = cap.channel_identity_id - WHERE cap.chat_id = $1 AND ci.user_id = $2 - GROUP BY cap.chat_id -), -all_access AS ( - SELECT access_mode, participant_role, last_observed_at FROM participant_access - UNION ALL - SELECT access_mode, participant_role, last_observed_at FROM observed_access -) -SELECT access_mode, participant_role, last_observed_at -FROM all_access -ORDER BY CASE WHEN access_mode = 'participant' THEN 0 ELSE 1 END, last_observed_at DESC NULLS LAST -LIMIT 1 -` - -type GetChatReadAccessByUserParams struct { - ChatID pgtype.UUID `json:"chat_id"` - UserID pgtype.UUID `json:"user_id"` -} - -type GetChatReadAccessByUserRow struct { - AccessMode string `json:"access_mode"` - ParticipantRole string `json:"participant_role"` - LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` -} - -func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAccessByUserParams) (GetChatReadAccessByUserRow, error) { - row := q.db.QueryRow(ctx, getChatReadAccessByUser, arg.ChatID, arg.UserID) - var i GetChatReadAccessByUserRow - err := row.Scan(&i.AccessMode, &i.ParticipantRole, &i.LastObservedAt) - return i, err -} - -const getChatRouteByID = `-- name: GetChatRouteByID :one -SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at -FROM chat_routes -WHERE id = $1 -` - -func (q *Queries) GetChatRouteByID(ctx context.Context, id pgtype.UUID) (ChatRoute, error) { - row := q.db.QueryRow(ctx, getChatRouteByID, id) - var i ChatRoute - err := row.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.Platform, - &i.ChannelConfigID, - &i.ConversationID, - &i.ThreadID, - &i.ReplyTarget, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getChatSettings = `-- name: GetChatSettings :one -SELECT id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at -FROM chats -WHERE id = $1 -` - -type GetChatSettingsRow struct { - ChatID pgtype.UUID `json:"chat_id"` - EnableChatMemory bool `json:"enable_chat_memory"` - EnablePrivateMemory bool `json:"enable_private_memory"` - EnablePublicMemory bool `json:"enable_public_memory"` - ModelID pgtype.Text `json:"model_id"` - Metadata []byte `json:"metadata"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -func (q *Queries) GetChatSettings(ctx context.Context, id pgtype.UUID) (GetChatSettingsRow, error) { - row := q.db.QueryRow(ctx, getChatSettings, id) - var i GetChatSettingsRow - err := row.Scan( - &i.ChatID, - &i.EnableChatMemory, - &i.EnablePrivateMemory, - &i.EnablePublicMemory, - &i.ModelID, - &i.Metadata, - &i.UpdatedAt, - ) - return i, err -} - -const listChatMessages = `-- name: ListChatMessages :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 -ORDER BY created_at ASC -` - -func (q *Queries) ListChatMessages(ctx context.Context, chatID pgtype.UUID) ([]ChatMessage, error) { - rows, err := q.db.Query(ctx, listChatMessages, chatID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatMessage - for rows.Next() { - var i ChatMessage - if err := rows.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.RouteID, - &i.SenderChannelIdentityID, - &i.SenderUserID, - &i.Platform, - &i.ExternalMessageID, - &i.Role, - &i.Content, - &i.Metadata, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listChatMessagesBefore = `-- name: ListChatMessagesBefore :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 AND created_at < $2 -ORDER BY created_at DESC -LIMIT $3 -` - -type ListChatMessagesBeforeParams struct { - ChatID pgtype.UUID `json:"chat_id"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - Limit int32 `json:"limit"` -} - -func (q *Queries) ListChatMessagesBefore(ctx context.Context, arg ListChatMessagesBeforeParams) ([]ChatMessage, error) { - rows, err := q.db.Query(ctx, listChatMessagesBefore, arg.ChatID, arg.CreatedAt, arg.Limit) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatMessage - for rows.Next() { - var i ChatMessage - if err := rows.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.RouteID, - &i.SenderChannelIdentityID, - &i.SenderUserID, - &i.Platform, - &i.ExternalMessageID, - &i.Role, - &i.Content, - &i.Metadata, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listChatMessagesLatest = `-- name: ListChatMessagesLatest :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 -ORDER BY created_at DESC -LIMIT $2 -` - -type ListChatMessagesLatestParams struct { - ChatID pgtype.UUID `json:"chat_id"` - Limit int32 `json:"limit"` -} - -func (q *Queries) ListChatMessagesLatest(ctx context.Context, arg ListChatMessagesLatestParams) ([]ChatMessage, error) { - rows, err := q.db.Query(ctx, listChatMessagesLatest, arg.ChatID, arg.Limit) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatMessage - for rows.Next() { - var i ChatMessage - if err := rows.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.RouteID, - &i.SenderChannelIdentityID, - &i.SenderUserID, - &i.Platform, - &i.ExternalMessageID, - &i.Role, - &i.Content, - &i.Metadata, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listChatMessagesSince = `-- name: ListChatMessagesSince :many -SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at -FROM chat_messages -WHERE chat_id = $1 AND created_at >= $2 -ORDER BY created_at ASC -` - -type ListChatMessagesSinceParams struct { - ChatID pgtype.UUID `json:"chat_id"` - CreatedAt pgtype.Timestamptz `json:"created_at"` -} - -func (q *Queries) ListChatMessagesSince(ctx context.Context, arg ListChatMessagesSinceParams) ([]ChatMessage, error) { - rows, err := q.db.Query(ctx, listChatMessagesSince, arg.ChatID, arg.CreatedAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatMessage - for rows.Next() { - var i ChatMessage - if err := rows.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.RouteID, - &i.SenderChannelIdentityID, - &i.SenderUserID, - &i.Platform, - &i.ExternalMessageID, - &i.Role, - &i.Content, - &i.Metadata, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listChatParticipants = `-- name: ListChatParticipants :many -SELECT chat_id, user_id, role, joined_at -FROM chat_participants -WHERE chat_id = $1 -ORDER BY joined_at ASC -` - -func (q *Queries) ListChatParticipants(ctx context.Context, chatID pgtype.UUID) ([]ChatParticipant, error) { - rows, err := q.db.Query(ctx, listChatParticipants, chatID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatParticipant - for rows.Next() { - var i ChatParticipant - if err := rows.Scan( - &i.ChatID, - &i.UserID, - &i.Role, - &i.JoinedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listChatRoutes = `-- name: ListChatRoutes :many -SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at -FROM chat_routes -WHERE chat_id = $1 -ORDER BY created_at ASC -` - -func (q *Queries) ListChatRoutes(ctx context.Context, chatID pgtype.UUID) ([]ChatRoute, error) { - rows, err := q.db.Query(ctx, listChatRoutes, chatID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatRoute - for rows.Next() { - var i ChatRoute - if err := rows.Scan( - &i.ID, - &i.ChatID, - &i.BotID, - &i.Platform, - &i.ChannelConfigID, - &i.ConversationID, - &i.ThreadID, - &i.ReplyTarget, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listChatsByBotAndUser = `-- name: ListChatsByBotAndUser :many -SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.enable_chat_memory, c.enable_private_memory, c.enable_public_memory, c.model_id, c.settings_metadata, c.created_at, c.updated_at -FROM chats c -JOIN chat_participants cp ON cp.chat_id = c.id -WHERE c.bot_id = $1 AND cp.user_id = $2 -ORDER BY c.updated_at DESC -` - -type ListChatsByBotAndUserParams struct { - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` -} - -func (q *Queries) ListChatsByBotAndUser(ctx context.Context, arg ListChatsByBotAndUserParams) ([]Chat, error) { - rows, err := q.db.Query(ctx, listChatsByBotAndUser, arg.BotID, arg.UserID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Chat - for rows.Next() { - var i Chat - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.Kind, - &i.ParentChatID, - &i.Title, - &i.CreatedByUserID, - &i.Metadata, - &i.EnableChatMemory, - &i.EnablePrivateMemory, - &i.EnablePublicMemory, - &i.ModelID, - &i.SettingsMetadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listThreadsByParent = `-- name: ListThreadsByParent :many -SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at -FROM chats -WHERE parent_chat_id = $1 AND kind = 'thread' -ORDER BY created_at DESC -` - -func (q *Queries) ListThreadsByParent(ctx context.Context, parentChatID pgtype.UUID) ([]Chat, error) { - rows, err := q.db.Query(ctx, listThreadsByParent, parentChatID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Chat - for rows.Next() { - var i Chat - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.Kind, - &i.ParentChatID, - &i.Title, - &i.CreatedByUserID, - &i.Metadata, - &i.EnableChatMemory, - &i.EnablePrivateMemory, - &i.EnablePublicMemory, - &i.ModelID, - &i.SettingsMetadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listVisibleChatsByBotAndUser = `-- name: ListVisibleChatsByBotAndUser :many -WITH participant_chats AS ( - SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, - 'participant'::text AS access_mode, - cp.role AS participant_role, - NULL::timestamptz AS last_observed_at - FROM chats c - JOIN chat_participants cp ON cp.chat_id = c.id - WHERE c.bot_id = $1 AND cp.user_id = $2 -), -observed_chats AS ( - SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, - 'channel_identity_observed'::text AS access_mode, - ''::text AS participant_role, - MAX(cap.last_seen_at) AS last_observed_at - FROM chats c - JOIN chat_channel_identity_presence cap ON cap.chat_id = c.id - JOIN channel_identities ci ON ci.id = cap.channel_identity_id - WHERE c.bot_id = $1 - AND ci.user_id = $2 - AND NOT EXISTS ( - SELECT 1 FROM chat_participants cp - WHERE cp.chat_id = c.id AND cp.user_id = $2 - ) - GROUP BY c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at -) -SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, - access_mode, participant_role, last_observed_at -FROM ( - SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, access_mode, participant_role, last_observed_at FROM participant_chats - UNION ALL - SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, access_mode, participant_role, last_observed_at FROM observed_chats -) v -ORDER BY v.updated_at DESC, v.last_observed_at DESC NULLS LAST -` - -type ListVisibleChatsByBotAndUserParams struct { - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` -} - -type ListVisibleChatsByBotAndUserRow struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - Kind string `json:"kind"` - ParentChatID pgtype.UUID `json:"parent_chat_id"` - Title pgtype.Text `json:"title"` - CreatedByUserID pgtype.UUID `json:"created_by_user_id"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` - AccessMode string `json:"access_mode"` - ParticipantRole string `json:"participant_role"` - LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` -} - -func (q *Queries) ListVisibleChatsByBotAndUser(ctx context.Context, arg ListVisibleChatsByBotAndUserParams) ([]ListVisibleChatsByBotAndUserRow, error) { - rows, err := q.db.Query(ctx, listVisibleChatsByBotAndUser, arg.BotID, arg.UserID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ListVisibleChatsByBotAndUserRow - for rows.Next() { - var i ListVisibleChatsByBotAndUserRow - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.Kind, - &i.ParentChatID, - &i.Title, - &i.CreatedByUserID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - &i.AccessMode, - &i.ParticipantRole, - &i.LastObservedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const removeChatParticipant = `-- name: RemoveChatParticipant :exec -DELETE FROM chat_participants WHERE chat_id = $1 AND user_id = $2 -` - -type RemoveChatParticipantParams struct { - ChatID pgtype.UUID `json:"chat_id"` - UserID pgtype.UUID `json:"user_id"` -} - -func (q *Queries) RemoveChatParticipant(ctx context.Context, arg RemoveChatParticipantParams) error { - _, err := q.db.Exec(ctx, removeChatParticipant, arg.ChatID, arg.UserID) - return err -} - -const touchChat = `-- name: TouchChat :exec -UPDATE chats SET updated_at = now() WHERE id = $1 -` - -func (q *Queries) TouchChat(ctx context.Context, id pgtype.UUID) error { - _, err := q.db.Exec(ctx, touchChat, id) - return err -} - -const updateChatRouteReplyTarget = `-- name: UpdateChatRouteReplyTarget :exec -UPDATE chat_routes SET reply_target = $2, updated_at = now() WHERE id = $1 -` - -type UpdateChatRouteReplyTargetParams struct { - ID pgtype.UUID `json:"id"` - ReplyTarget pgtype.Text `json:"reply_target"` -} - -func (q *Queries) UpdateChatRouteReplyTarget(ctx context.Context, arg UpdateChatRouteReplyTargetParams) error { - _, err := q.db.Exec(ctx, updateChatRouteReplyTarget, arg.ID, arg.ReplyTarget) - return err -} - -const updateChatTitle = `-- name: UpdateChatTitle :one -UPDATE chats SET title = $2, updated_at = now() -WHERE id = $1 -RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at -` - -type UpdateChatTitleParams struct { - ID pgtype.UUID `json:"id"` - Title pgtype.Text `json:"title"` -} - -func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams) (Chat, error) { - row := q.db.QueryRow(ctx, updateChatTitle, arg.ID, arg.Title) - var i Chat - err := row.Scan( - &i.ID, - &i.BotID, - &i.Kind, - &i.ParentChatID, - &i.Title, - &i.CreatedByUserID, - &i.Metadata, - &i.EnableChatMemory, - &i.EnablePrivateMemory, - &i.EnablePublicMemory, - &i.ModelID, - &i.SettingsMetadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const upsertChatChannelIdentityPresence = `-- name: UpsertChatChannelIdentityPresence :exec -INSERT INTO chat_channel_identity_presence (chat_id, channel_identity_id, first_seen_at, last_seen_at, message_count) -VALUES ($1, $2, now(), now(), 1) -ON CONFLICT (chat_id, channel_identity_id) -DO UPDATE SET - last_seen_at = now(), - message_count = chat_channel_identity_presence.message_count + 1 -` - -type UpsertChatChannelIdentityPresenceParams struct { - ChatID pgtype.UUID `json:"chat_id"` - ChannelIdentityID pgtype.UUID `json:"channel_identity_id"` -} - -func (q *Queries) UpsertChatChannelIdentityPresence(ctx context.Context, arg UpsertChatChannelIdentityPresenceParams) error { - _, err := q.db.Exec(ctx, upsertChatChannelIdentityPresence, arg.ChatID, arg.ChannelIdentityID) - return err -} - -const upsertChatSettings = `-- name: UpsertChatSettings :one - -UPDATE chats -SET enable_chat_memory = $2, - enable_private_memory = $3, - enable_public_memory = $4, - model_id = $5, - settings_metadata = $6 -WHERE id = $1 -RETURNING id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at -` - -type UpsertChatSettingsParams struct { - ID pgtype.UUID `json:"id"` - EnableChatMemory bool `json:"enable_chat_memory"` - EnablePrivateMemory bool `json:"enable_private_memory"` - EnablePublicMemory bool `json:"enable_public_memory"` - ModelID pgtype.Text `json:"model_id"` - SettingsMetadata []byte `json:"settings_metadata"` -} - -type UpsertChatSettingsRow struct { - ChatID pgtype.UUID `json:"chat_id"` - EnableChatMemory bool `json:"enable_chat_memory"` - EnablePrivateMemory bool `json:"enable_private_memory"` - EnablePublicMemory bool `json:"enable_public_memory"` - ModelID pgtype.Text `json:"model_id"` - Metadata []byte `json:"metadata"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -// chat_settings -func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) { - row := q.db.QueryRow(ctx, upsertChatSettings, - arg.ID, - arg.EnableChatMemory, - arg.EnablePrivateMemory, - arg.EnablePublicMemory, - arg.ModelID, - arg.SettingsMetadata, - ) - var i UpsertChatSettingsRow - err := row.Scan( - &i.ChatID, - &i.EnableChatMemory, - &i.EnablePrivateMemory, - &i.EnablePublicMemory, - &i.ModelID, - &i.Metadata, - &i.UpdatedAt, - ) - return i, err -} diff --git a/internal/db/sqlc/containers.sql.go b/internal/db/sqlc/containers.sql.go index 3141781b..b193443e 100644 --- a/internal/db/sqlc/containers.sql.go +++ b/internal/db/sqlc/containers.sql.go @@ -72,6 +72,45 @@ func (q *Queries) GetContainerByContainerID(ctx context.Context, containerID str return i, err } +const listAutoStartContainers = `-- name: ListAutoStartContainers :many +SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE auto_start = true ORDER BY updated_at DESC +` + +func (q *Queries) ListAutoStartContainers(ctx context.Context) ([]Container, error) { + rows, err := q.db.Query(ctx, listAutoStartContainers) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Container + for rows.Next() { + var i Container + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.ContainerID, + &i.ContainerName, + &i.Image, + &i.Status, + &i.Namespace, + &i.AutoStart, + &i.HostPath, + &i.ContainerPath, + &i.CreatedAt, + &i.UpdatedAt, + &i.LastStartedAt, + &i.LastStoppedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateContainerStarted = `-- name: UpdateContainerStarted :exec UPDATE containers SET status = 'running', last_started_at = now(), updated_at = now() diff --git a/internal/db/sqlc/conversations.sql.go b/internal/db/sqlc/conversations.sql.go new file mode 100644 index 00000000..f9ba8fc4 --- /dev/null +++ b/internal/db/sqlc/conversations.sql.go @@ -0,0 +1,678 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: conversations.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const addChatParticipant = `-- name: AddChatParticipant :one + +INSERT INTO bot_members (bot_id, user_id, role) +VALUES ($1, $2, $3) +ON CONFLICT (bot_id, user_id) DO UPDATE SET role = EXCLUDED.role +RETURNING bot_id AS chat_id, user_id, role, created_at AS joined_at +` + +type AddChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` +} + +type AddChatParticipantRow struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt pgtype.Timestamptz `json:"joined_at"` +} + +// chat_participants +func (q *Queries) AddChatParticipant(ctx context.Context, arg AddChatParticipantParams) (AddChatParticipantRow, error) { + row := q.db.QueryRow(ctx, addChatParticipant, arg.ChatID, arg.UserID, arg.Role) + var i AddChatParticipantRow + err := row.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ) + return i, err +} + +const copyParticipantsToChat = `-- name: CopyParticipantsToChat :exec +INSERT INTO bot_members (bot_id, user_id, role) +SELECT $1, bm.user_id, bm.role +FROM bot_members bm +WHERE bm.bot_id = $2 +ON CONFLICT (bot_id, user_id) DO NOTHING +` + +type CopyParticipantsToChatParams struct { + ChatID2 pgtype.UUID `json:"chat_id_2"` + ChatID pgtype.UUID `json:"chat_id"` +} + +func (q *Queries) CopyParticipantsToChat(ctx context.Context, arg CopyParticipantsToChatParams) error { + _, err := q.db.Exec(ctx, copyParticipantsToChat, arg.ChatID2, arg.ChatID) + return err +} + +const createChat = `-- name: CreateChat :one +SELECT + b.id AS id, + b.id AS bot_id, + (COALESCE(NULLIF($1::text, ''), CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END))::text AS kind, + CASE WHEN $1 = 'thread' THEN $2::uuid ELSE NULL::uuid END AS parent_chat_id, + COALESCE(NULLIF($3::text, ''), b.display_name) AS title, + COALESCE($4::uuid, b.owner_user_id) AS created_by_user_id, + COALESCE($5::jsonb, b.metadata) AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $6 +LIMIT 1 +` + +type CreateChatParams struct { + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title string `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + BotID pgtype.UUID `json:"bot_id"` +} + +type CreateChatRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) CreateChat(ctx context.Context, arg CreateChatParams) (CreateChatRow, error) { + row := q.db.QueryRow(ctx, createChat, + arg.Kind, + arg.ParentChatID, + arg.Title, + arg.CreatedByUserID, + arg.Metadata, + arg.BotID, + ) + var i CreateChatRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteChat = `-- name: DeleteChat :exec +WITH deleted_messages AS ( + DELETE FROM bot_history_messages + WHERE bot_id = $1 +) +DELETE FROM bot_channel_routes bcr +WHERE bcr.bot_id = $1 +` + +func (q *Queries) DeleteChat(ctx context.Context, chatID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteChat, chatID) + return err +} + +const getChatByID = `-- name: GetChatByID :one +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 +` + +type GetChatByIDRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) GetChatByID(ctx context.Context, id pgtype.UUID) (GetChatByIDRow, error) { + row := q.db.QueryRow(ctx, getChatByID, id) + var i GetChatByIDRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChatParticipant = `-- name: GetChatParticipant :one +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = $1 AND b.owner_user_id = $2 +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = $1 AND bm.user_id = $2 +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT chat_id, user_id, role, joined_at FROM owner_participant + UNION ALL + SELECT chat_id, user_id, role, joined_at FROM member_participant +) p +ORDER BY CASE WHEN role = 'owner' THEN 0 ELSE 1 END +LIMIT 1 +` + +type GetChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` +} + +type GetChatParticipantRow struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt pgtype.Timestamptz `json:"joined_at"` +} + +func (q *Queries) GetChatParticipant(ctx context.Context, arg GetChatParticipantParams) (GetChatParticipantRow, error) { + row := q.db.QueryRow(ctx, getChatParticipant, arg.ChatID, arg.UserID) + var i GetChatParticipantRow + err := row.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ) + return i, err +} + +const getChatReadAccessByUser = `-- name: GetChatReadAccessByUser :one +SELECT + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = $1 THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1 +WHERE b.id = $2 + AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL) +LIMIT 1 +` + +type GetChatReadAccessByUserParams struct { + UserID pgtype.UUID `json:"user_id"` + ChatID pgtype.UUID `json:"chat_id"` +} + +type GetChatReadAccessByUserRow struct { + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role"` + LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` +} + +func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAccessByUserParams) (GetChatReadAccessByUserRow, error) { + row := q.db.QueryRow(ctx, getChatReadAccessByUser, arg.UserID, arg.ChatID) + var i GetChatReadAccessByUserRow + err := row.Scan(&i.AccessMode, &i.ParticipantRole, &i.LastObservedAt) + return i, err +} + +const getChatSettings = `-- name: GetChatSettings :one +SELECT + b.id AS chat_id, + chat_models.model_id AS model_id, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 +` + +type GetChatSettingsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + ModelID pgtype.Text `json:"model_id"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) GetChatSettings(ctx context.Context, id pgtype.UUID) (GetChatSettingsRow, error) { + row := q.db.QueryRow(ctx, getChatSettings, id) + var i GetChatSettingsRow + err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt) + return i, err +} + +const listChatParticipants = `-- name: ListChatParticipants :many +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = $1 +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = $1 + AND bm.user_id <> (SELECT owner_user_id FROM bots WHERE id = $1) +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT chat_id, user_id, role, joined_at FROM owner_participant + UNION ALL + SELECT chat_id, user_id, role, joined_at FROM member_participant +) p +ORDER BY joined_at ASC +` + +type ListChatParticipantsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt pgtype.Timestamptz `json:"joined_at"` +} + +func (q *Queries) ListChatParticipants(ctx context.Context, chatID pgtype.UUID) ([]ListChatParticipantsRow, error) { + rows, err := q.db.Query(ctx, listChatParticipants, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChatParticipantsRow + for rows.Next() { + var i ListChatParticipantsRow + if err := rows.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatsByBotAndUser = `-- name: ListChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1 +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $2 + AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC +` + +type ListChatsByBotAndUserParams struct { + UserID pgtype.UUID `json:"user_id"` + BotID pgtype.UUID `json:"bot_id"` +} + +type ListChatsByBotAndUserRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) ListChatsByBotAndUser(ctx context.Context, arg ListChatsByBotAndUserParams) ([]ListChatsByBotAndUserRow, error) { + rows, err := q.db.Query(ctx, listChatsByBotAndUser, arg.UserID, arg.BotID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChatsByBotAndUserRow + for rows.Next() { + var i ListChatsByBotAndUserRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listThreadsByParent = `-- name: ListThreadsByParent :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 + AND false +ORDER BY b.created_at DESC +` + +type ListThreadsByParentRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) ListThreadsByParent(ctx context.Context, id pgtype.UUID) ([]ListThreadsByParentRow, error) { + rows, err := q.db.Query(ctx, listThreadsByParent, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListThreadsByParentRow + for rows.Next() { + var i ListThreadsByParentRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listVisibleChatsByBotAndUser = `-- name: ListVisibleChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + b.created_at, + b.updated_at, + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = $1 THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1 +WHERE b.id = $2 + AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC +` + +type ListVisibleChatsByBotAndUserParams struct { + UserID pgtype.UUID `json:"user_id"` + BotID pgtype.UUID `json:"bot_id"` +} + +type ListVisibleChatsByBotAndUserRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role"` + LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` +} + +func (q *Queries) ListVisibleChatsByBotAndUser(ctx context.Context, arg ListVisibleChatsByBotAndUserParams) ([]ListVisibleChatsByBotAndUserRow, error) { + rows, err := q.db.Query(ctx, listVisibleChatsByBotAndUser, arg.UserID, arg.BotID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListVisibleChatsByBotAndUserRow + for rows.Next() { + var i ListVisibleChatsByBotAndUserRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + &i.AccessMode, + &i.ParticipantRole, + &i.LastObservedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const removeChatParticipant = `-- name: RemoveChatParticipant :exec +DELETE FROM bot_members +WHERE bot_id = $1 + AND user_id = $2 + AND user_id <> (SELECT owner_user_id FROM bots WHERE id = $1) +` + +type RemoveChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) RemoveChatParticipant(ctx context.Context, arg RemoveChatParticipantParams) error { + _, err := q.db.Exec(ctx, removeChatParticipant, arg.ChatID, arg.UserID) + return err +} + +const touchChat = `-- name: TouchChat :exec +UPDATE bots +SET updated_at = now() +WHERE id = $1 +` + +func (q *Queries) TouchChat(ctx context.Context, chatID pgtype.UUID) error { + _, err := q.db.Exec(ctx, touchChat, chatID) + return err +} + +const updateChatTitle = `-- name: UpdateChatTitle :one +UPDATE bots +SET display_name = $1, + updated_at = now() +WHERE id = $2 +RETURNING + id, + id AS bot_id, + CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + display_name AS title, + owner_user_id AS created_by_user_id, + metadata, + NULL::text AS model_id, + created_at, + updated_at +` + +type UpdateChatTitleParams struct { + Title pgtype.Text `json:"title"` + ID pgtype.UUID `json:"id"` +} + +type UpdateChatTitleRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams) (UpdateChatTitleRow, error) { + row := q.db.QueryRow(ctx, updateChatTitle, arg.Title, arg.ID) + var i UpdateChatTitleRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertChatSettings = `-- name: UpsertChatSettings :one + +WITH resolved_model AS ( + SELECT id + FROM models + WHERE model_id = NULLIF($1::text, '') + LIMIT 1 +), +updated AS ( + UPDATE bots + SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id), + updated_at = now() + WHERE bots.id = $2 + RETURNING bots.id, bots.chat_model_id, bots.updated_at +) +SELECT + updated.id AS chat_id, + chat_models.model_id AS model_id, + updated.updated_at +FROM updated +LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id +` + +type UpsertChatSettingsParams struct { + ModelID pgtype.Text `json:"model_id"` + ID pgtype.UUID `json:"id"` +} + +type UpsertChatSettingsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + ModelID pgtype.Text `json:"model_id"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +// chat_settings +func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) { + row := q.db.QueryRow(ctx, upsertChatSettings, arg.ModelID, arg.ID) + var i UpsertChatSettingsRow + err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt) + return i, err +} diff --git a/internal/db/sqlc/messages.sql.go b/internal/db/sqlc/messages.sql.go new file mode 100644 index 00000000..dffb3037 --- /dev/null +++ b/internal/db/sqlc/messages.sql.go @@ -0,0 +1,409 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: messages.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createMessage = `-- name: CreateMessage :one +INSERT INTO bot_history_messages ( + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id, + channel_type, + source_message_id, + source_reply_to_message_id, + role, + content, + metadata +) +VALUES ( + $1, + $2::uuid, + $3::uuid, + $4::uuid, + $5::text, + $6::text, + $7::text, + $8, + $9, + $10 +) +RETURNING + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +` + +type CreateMessageParams struct { + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` +} + +type CreateMessageRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (CreateMessageRow, error) { + row := q.db.QueryRow(ctx, createMessage, + arg.BotID, + arg.RouteID, + arg.SenderChannelIdentityID, + arg.SenderUserID, + arg.Platform, + arg.ExternalMessageID, + arg.SourceReplyToMessageID, + arg.Role, + arg.Content, + arg.Metadata, + ) + var i CreateMessageRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ) + return i, err +} + +const deleteMessagesByBot = `-- name: DeleteMessagesByBot :exec +DELETE FROM bot_history_messages +WHERE bot_id = $1 +` + +func (q *Queries) DeleteMessagesByBot(ctx context.Context, botID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteMessagesByBot, botID) + return err +} + +const listMessages = `-- name: ListMessages :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 +ORDER BY created_at ASC +` + +type ListMessagesRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessages(ctx context.Context, botID pgtype.UUID) ([]ListMessagesRow, error) { + rows, err := q.db.Query(ctx, listMessages, botID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesRow + for rows.Next() { + var i ListMessagesRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesBefore = `-- name: ListMessagesBefore :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 + AND created_at < $2 +ORDER BY created_at DESC +LIMIT $3 +` + +type ListMessagesBeforeParams struct { + BotID pgtype.UUID `json:"bot_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + MaxCount int32 `json:"max_count"` +} + +type ListMessagesBeforeRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessagesBefore(ctx context.Context, arg ListMessagesBeforeParams) ([]ListMessagesBeforeRow, error) { + rows, err := q.db.Query(ctx, listMessagesBefore, arg.BotID, arg.CreatedAt, arg.MaxCount) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesBeforeRow + for rows.Next() { + var i ListMessagesBeforeRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesLatest = `-- name: ListMessagesLatest :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 +ORDER BY created_at DESC +LIMIT $2 +` + +type ListMessagesLatestParams struct { + BotID pgtype.UUID `json:"bot_id"` + MaxCount int32 `json:"max_count"` +} + +type ListMessagesLatestRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessagesLatest(ctx context.Context, arg ListMessagesLatestParams) ([]ListMessagesLatestRow, error) { + rows, err := q.db.Query(ctx, listMessagesLatest, arg.BotID, arg.MaxCount) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesLatestRow + for rows.Next() { + var i ListMessagesLatestRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesSince = `-- name: ListMessagesSince :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 + AND created_at >= $2 +ORDER BY created_at ASC +` + +type ListMessagesSinceParams struct { + BotID pgtype.UUID `json:"bot_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +type ListMessagesSinceRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessagesSince(ctx context.Context, arg ListMessagesSinceParams) ([]ListMessagesSinceRow, error) { + rows, err := q.db.Query(ctx, listMessagesSince, arg.BotID, arg.CreatedAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesSinceRow + for rows.Next() { + var i ListMessagesSinceRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 3b43891a..fff73aea 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -15,6 +15,7 @@ type Bot struct { DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` + Status string `json:"status"` MaxContextLoadTime int32 `json:"max_context_load_time"` Language string `json:"language"` AllowGuest bool `json:"allow_guest"` @@ -41,6 +42,34 @@ type BotChannelConfig struct { UpdatedAt pgtype.Timestamptz `json:"updated_at"` } +type BotChannelRoute struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + ChannelType string `json:"channel_type"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ExternalConversationID string `json:"external_conversation_id"` + ExternalThreadID pgtype.Text `json:"external_thread_id"` + DefaultReplyTarget pgtype.Text `json:"default_reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +type BotHistoryMessage struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderAccountUserID pgtype.UUID `json:"sender_account_user_id"` + ChannelType pgtype.Text `json:"channel_type"` + SourceMessageID pgtype.Text `json:"source_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + type BotMember struct { BotID pgtype.UUID `json:"bot_id"` UserID pgtype.UUID `json:"user_id"` @@ -61,7 +90,7 @@ type BotPreauthKey struct { type ChannelIdentity struct { ID pgtype.UUID `json:"id"` UserID pgtype.UUID `json:"user_id"` - Channel string `json:"channel"` + ChannelType string `json:"channel_type"` ChannelSubjectID string `json:"channel_subject_id"` DisplayName pgtype.Text `json:"display_name"` Metadata []byte `json:"metadata"` @@ -73,74 +102,13 @@ type ChannelIdentityBindCode struct { ID pgtype.UUID `json:"id"` Token string `json:"token"` IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` - Platform pgtype.Text `json:"platform"` + ChannelType pgtype.Text `json:"channel_type"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` UsedAt pgtype.Timestamptz `json:"used_at"` UsedByChannelIdentityID pgtype.UUID `json:"used_by_channel_identity_id"` CreatedAt pgtype.Timestamptz `json:"created_at"` } -type Chat struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - Kind string `json:"kind"` - ParentChatID pgtype.UUID `json:"parent_chat_id"` - Title pgtype.Text `json:"title"` - CreatedByUserID pgtype.UUID `json:"created_by_user_id"` - Metadata []byte `json:"metadata"` - EnableChatMemory bool `json:"enable_chat_memory"` - EnablePrivateMemory bool `json:"enable_private_memory"` - EnablePublicMemory bool `json:"enable_public_memory"` - ModelID pgtype.Text `json:"model_id"` - SettingsMetadata []byte `json:"settings_metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -type ChatChannelIdentityPresence struct { - ChatID pgtype.UUID `json:"chat_id"` - ChannelIdentityID pgtype.UUID `json:"channel_identity_id"` - FirstSeenAt pgtype.Timestamptz `json:"first_seen_at"` - LastSeenAt pgtype.Timestamptz `json:"last_seen_at"` - MessageCount int64 `json:"message_count"` -} - -type ChatMessage struct { - ID pgtype.UUID `json:"id"` - ChatID pgtype.UUID `json:"chat_id"` - BotID pgtype.UUID `json:"bot_id"` - RouteID pgtype.UUID `json:"route_id"` - SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` - SenderUserID pgtype.UUID `json:"sender_user_id"` - Platform pgtype.Text `json:"platform"` - ExternalMessageID pgtype.Text `json:"external_message_id"` - Role string `json:"role"` - Content []byte `json:"content"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` -} - -type ChatParticipant struct { - ChatID pgtype.UUID `json:"chat_id"` - UserID pgtype.UUID `json:"user_id"` - Role string `json:"role"` - JoinedAt pgtype.Timestamptz `json:"joined_at"` -} - -type ChatRoute struct { - ID pgtype.UUID `json:"id"` - ChatID pgtype.UUID `json:"chat_id"` - BotID pgtype.UUID `json:"bot_id"` - Platform string `json:"platform"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - ConversationID string `json:"conversation_id"` - ThreadID pgtype.Text `json:"thread_id"` - ReplyTarget pgtype.Text `json:"reply_target"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - type Container struct { ID pgtype.UUID `json:"id"` BotID pgtype.UUID `json:"bot_id"` @@ -277,10 +245,10 @@ type User struct { } type UserChannelBinding struct { - ID pgtype.UUID `json:"id"` - UserID pgtype.UUID `json:"user_id"` - Platform string `json:"platform"` - Config []byte `json:"config"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + ChannelType string `json:"channel_type"` + Config []byte `json:"config"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } diff --git a/internal/db/text.go b/internal/db/text.go new file mode 100644 index 00000000..c0da1768 --- /dev/null +++ b/internal/db/text.go @@ -0,0 +1,11 @@ +package db + +import "github.com/jackc/pgx/v5/pgtype" + +// TextToString returns the string value of pgtype.Text, or "" when invalid. +func TextToString(value pgtype.Text) string { + if !value.Valid { + return "" + } + return value.String +} diff --git a/internal/db/text_test.go b/internal/db/text_test.go new file mode 100644 index 00000000..22fe0542 --- /dev/null +++ b/internal/db/text_test.go @@ -0,0 +1,26 @@ +package db + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +func TestTextToString(t *testing.T) { + tests := []struct { + name string + value pgtype.Text + want string + }{ + {"valid", pgtype.Text{String: "hello", Valid: true}, "hello"}, + {"invalid", pgtype.Text{}, ""}, + {"valid empty", pgtype.Text{String: "", Valid: true}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TextToString(tt.value); got != tt.want { + t.Errorf("TextToString() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/db/uuid.go b/internal/db/uuid.go index a0b3dfd3..7ba824ec 100644 --- a/internal/db/uuid.go +++ b/internal/db/uuid.go @@ -21,18 +21,6 @@ func ParseUUID(id string) (pgtype.UUID, error) { return pgID, nil } -// UUIDToString converts a pgtype.UUID to its string representation. -func UUIDToString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} - // TimeFromPg converts a pgtype.Timestamptz to time.Time. func TimeFromPg(value pgtype.Timestamptz) time.Time { if value.Valid { diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go index eb8c123b..f90e5cb3 100644 --- a/internal/embeddings/resolver.go +++ b/internal/embeddings/resolver.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/models" ) @@ -261,7 +262,7 @@ func (r *Resolver) loadChannelIdentityEmbeddingModelID(ctx context.Context, chan if r.queries == nil { return "", nil } - pgChannelIdentityID, err := parseUUID(channelIdentityID) + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { return "", err } @@ -275,13 +276,3 @@ func (r *Resolver) loadChannelIdentityEmbeddingModelID(ctx context.Context, chan return strings.TrimSpace(row.EmbeddingModelID.String), nil } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index f953777a..dfb4503d 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -21,18 +21,18 @@ type AuthHandler struct { } type LoginRequest struct { - Username string `json:"username" validate:"required"` - Password string `json:"password" validate:"required"` + Username string `json:"username"` + Password string `json:"password"` } type LoginResponse struct { - AccessToken string `json:"access_token" validate:"required"` - TokenType string `json:"token_type" validate:"required"` - ExpiresAt string `json:"expires_at" validate:"required"` - UserID string `json:"user_id" validate:"required"` - Role string `json:"role" validate:"required"` - DisplayName string `json:"display_name" validate:"required"` - Username string `json:"username" validate:"required"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresAt string `json:"expires_at"` + UserID string `json:"user_id"` + Role string `json:"role"` + DisplayName string `json:"display_name"` + Username string `json:"username"` } func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecret string, expiresIn time.Duration) *AuthHandler { diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index 1fa423b2..9f8ec869 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -94,11 +94,11 @@ func (h *ChannelHandler) UpsertChannelIdentityConfig(c echo.Context) error { } type ChannelMeta struct { - Type string `json:"type" validate:"required"` - DisplayName string `json:"display_name" validate:"required"` + Type string `json:"type"` + DisplayName string `json:"display_name"` Configless bool `json:"configless"` - Capabilities channel.ChannelCapabilities `json:"capabilities" validate:"required"` - ConfigSchema channel.ConfigSchema `json:"config_schema" validate:"required"` + Capabilities channel.ChannelCapabilities `json:"capabilities"` + ConfigSchema channel.ConfigSchema `json:"config_schema"` UserConfigSchema channel.ConfigSchema `json:"user_config_schema"` TargetSpec channel.TargetSpec `json:"target_spec"` } diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go deleted file mode 100644 index ec4b50da..00000000 --- a/internal/handlers/chat.go +++ /dev/null @@ -1,642 +0,0 @@ -package handlers - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net/http" - "strings" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/accounts" - "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" - "github.com/memohai/memoh/internal/chat" - "github.com/memohai/memoh/internal/identity" -) - -// ChatHandler handles chat CRUD, messaging, participants, settings, and routes. -type ChatHandler struct { - resolver *chat.Resolver - chatService *chat.Service - botService *bots.Service - accountService *accounts.Service - logger *slog.Logger -} - -// NewChatHandler creates a ChatHandler. -func NewChatHandler(log *slog.Logger, resolver *chat.Resolver, chatService *chat.Service, botService *bots.Service, accountService *accounts.Service) *ChatHandler { - return &ChatHandler{ - resolver: resolver, - chatService: chatService, - botService: botService, - accountService: accountService, - logger: log.With(slog.String("handler", "chat")), - } -} - -// Register registers all chat routes. -func (h *ChatHandler) Register(e *echo.Echo) { - // Chat lifecycle (under bot). - botGroup := e.Group("/bots/:bot_id/chats") - botGroup.POST("", h.CreateChat) - botGroup.GET("", h.ListChats) - - // Chat operations. - chatGroup := e.Group("/chats/:chat_id") - chatGroup.GET("", h.GetChat) - chatGroup.DELETE("", h.DeleteChat) - - // Messages. - chatGroup.POST("/messages", h.SendMessage) - chatGroup.POST("/messages/stream", h.StreamMessage) - chatGroup.GET("/messages", h.ListMessages) - - // Participants. - chatGroup.GET("/participants", h.ListParticipants) - chatGroup.POST("/participants", h.AddParticipant) - chatGroup.DELETE("/participants/:user_id", h.RemoveParticipant) - - // Settings. - chatGroup.GET("/settings", h.GetSettings) - chatGroup.PUT("/settings", h.UpdateSettings) - - // Routes. - chatGroup.GET("/routes", h.ListRoutes) - chatGroup.POST("/routes", h.CreateRoute) - chatGroup.DELETE("/routes/:route_id", h.DeleteRoute) - - // Threads. - chatGroup.GET("/threads", h.ListThreads) -} - -// --- Chat Lifecycle --- - -// CreateChat creates a new chat for a bot. -func (h *ChatHandler) CreateChat(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { - return err - } - - var req chat.CreateChatRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - result, err := h.chatService.Create(c.Request().Context(), botID, channelIdentityID, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusCreated, result) -} - -// ListChats lists chats for a bot where the user has participant or observed access. -func (h *ChatHandler) ListChats(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { - return err - } - - chats, err := h.chatService.ListByBotAndChannelIdentity(c.Request().Context(), botID, channelIdentityID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, map[string]any{"items": chats}) -} - -// GetChat returns a chat by ID. -func (h *ChatHandler) GetChat(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireReadable(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - result, err := h.chatService.Get(c.Request().Context(), chatID) - if err != nil { - if errors.Is(err, chat.ErrChatNotFound) { - return echo.NewHTTPError(http.StatusNotFound, "chat not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, result) -} - -// DeleteChat deletes a chat (owner only). -func (h *ChatHandler) DeleteChat(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleOwner); err != nil { - return err - } - - if err := h.chatService.Delete(c.Request().Context(), chatID); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.NoContent(http.StatusNoContent) -} - -// --- Messages --- - -// SendMessage sends a synchronous chat message. -func (h *ChatHandler) SendMessage(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - chatObj, err := h.chatService.Get(c.Request().Context(), chatID) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, "chat not found") - } - - var req chat.ChatRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if req.Query == "" { - return echo.NewHTTPError(http.StatusBadRequest, "query is required") - } - req.BotID = chatObj.BotID - req.ChatID = chatID - req.Token = c.Request().Header.Get("Authorization") - req.UserID = channelIdentityID - req.SourceChannelIdentityID = channelIdentityID - - resp, err := h.resolver.Chat(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - -// StreamMessage sends a streaming chat message. -func (h *ChatHandler) StreamMessage(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - chatObj, err := h.chatService.Get(c.Request().Context(), chatID) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, "chat not found") - } - - var req chat.ChatRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if req.Query == "" { - return echo.NewHTTPError(http.StatusBadRequest, "query is required") - } - req.BotID = chatObj.BotID - req.ChatID = chatID - req.Token = c.Request().Header.Get("Authorization") - req.UserID = channelIdentityID - req.SourceChannelIdentityID = channelIdentityID - - c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") - c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") - c.Response().Header().Set(echo.HeaderConnection, "keep-alive") - c.Response().WriteHeader(http.StatusOK) - - chunkChan, errChan := h.resolver.StreamChat(c.Request().Context(), req) - flusher, ok := c.Response().Writer.(http.Flusher) - if !ok { - return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") - } - writer := bufio.NewWriter(c.Response().Writer) - - for { - select { - case chunk, ok := <-chunkChan: - if !ok { - writer.WriteString("data: [DONE]\n\n") - writer.Flush() - flusher.Flush() - return nil - } - data, err := json.Marshal(chunk) - if err != nil { - continue - } - writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) - writer.Flush() - flusher.Flush() - case err := <-errChan: - if err != nil { - h.logger.Error("chat stream failed", slog.Any("error", err)) - errData := map[string]string{"error": err.Error()} - data, marshalErr := json.Marshal(errData) - if marshalErr != nil { - return echo.NewHTTPError(http.StatusInternalServerError, marshalErr.Error()) - } - writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) - writer.Flush() - flusher.Flush() - return nil - } - } - } -} - -// ListMessages lists messages for a chat. -func (h *ChatHandler) ListMessages(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireReadable(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - messages, err := h.chatService.ListMessages(c.Request().Context(), chatID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, map[string]any{"items": messages}) -} - -// --- Participants --- - -// ListParticipants lists participants for a chat. -func (h *ChatHandler) ListParticipants(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - participants, err := h.chatService.ListParticipants(c.Request().Context(), chatID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, map[string]any{"items": participants}) -} - -// AddParticipant adds a participant to a chat (owner/admin only). -func (h *ChatHandler) AddParticipant(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { - return err - } - - var body struct { - UserID string `json:"user_id"` - Role string `json:"role"` - } - if err := c.Bind(&body); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if strings.TrimSpace(body.UserID) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "user_id is required") - } - - p, err := h.chatService.AddParticipant(c.Request().Context(), chatID, body.UserID, body.Role) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, p) -} - -// RemoveParticipant removes a participant from a chat. -func (h *ChatHandler) RemoveParticipant(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { - return err - } - - targetUserID := strings.TrimSpace(c.Param("user_id")) - if targetUserID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "user_id is required") - } - - if err := h.chatService.RemoveParticipant(c.Request().Context(), chatID, targetUserID); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.NoContent(http.StatusNoContent) -} - -// --- Settings --- - -// GetSettings returns settings for a chat. -func (h *ChatHandler) GetSettings(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - settings, err := h.chatService.GetSettings(c.Request().Context(), chatID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, settings) -} - -// UpdateSettings updates settings for a chat (owner/admin only). -func (h *ChatHandler) UpdateSettings(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - chatObj, err := h.chatService.Get(c.Request().Context(), chatID) - if err != nil { - if errors.Is(err, chat.ErrChatNotFound) { - return echo.NewHTTPError(http.StatusNotFound, "chat not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if chatObj.Kind == chat.KindGroup { - if _, err := h.authorizeBotManage(c.Request().Context(), channelIdentityID, chatObj.BotID); err != nil { - return err - } - } else { - if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { - return err - } - } - - var req chat.UpdateSettingsRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - settings, err := h.chatService.UpdateSettings(c.Request().Context(), chatID, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, settings) -} - -// --- Routes --- - -// ListRoutes lists routes for a chat. -func (h *ChatHandler) ListRoutes(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - routes, err := h.chatService.ListRoutes(c.Request().Context(), chatID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, map[string]any{"items": routes}) -} - -// CreateRoute creates a new route for a chat (cross-channel). -func (h *ChatHandler) CreateRoute(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { - return err - } - - var route chat.Route - if err := c.Bind(&route); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - result, err := h.chatService.CreateRoute(c.Request().Context(), chatID, route) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusCreated, result) -} - -// DeleteRoute deletes a route. -func (h *ChatHandler) DeleteRoute(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { - return err - } - - routeID := strings.TrimSpace(c.Param("route_id")) - if routeID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "route_id is required") - } - if err := h.chatService.DeleteRoute(c.Request().Context(), routeID); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.NoContent(http.StatusNoContent) -} - -// --- Threads --- - -// ListThreads lists threads for a parent chat. -func (h *ChatHandler) ListThreads(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { - return err - } - chatID := strings.TrimSpace(c.Param("chat_id")) - if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { - return err - } - - threads, err := h.chatService.ListThreads(c.Request().Context(), chatID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, map[string]any{"items": threads}) -} - -// --- helpers --- - -func (h *ChatHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return channelIdentityID, nil -} - -func (h *ChatHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { - if h.botService == nil || h.accountService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil -} - -func (h *ChatHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { - if h.botService == nil || h.accountService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot management access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil -} - -func (h *ChatHandler) requireParticipant(ctx context.Context, chatID, channelIdentityID string) error { - if h.chatService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") - } - // Admin bypass. - if h.accountService != nil { - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if isAdmin { - return nil - } - } - ok, err := h.chatService.IsParticipant(ctx, chatID, channelIdentityID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if !ok { - return echo.NewHTTPError(http.StatusForbidden, "not a participant") - } - return nil -} - -func (h *ChatHandler) requireReadable(ctx context.Context, chatID, channelIdentityID string) error { - if h.chatService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") - } - // Admin bypass. - if h.accountService != nil { - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if isAdmin { - return nil - } - } - _, err := h.chatService.GetReadAccess(ctx, chatID, channelIdentityID) - if err != nil { - if errors.Is(err, chat.ErrPermissionDenied) { - return echo.NewHTTPError(http.StatusForbidden, "not allowed to read chat") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return nil -} - -func (h *ChatHandler) requireRole(ctx context.Context, chatID, channelIdentityID, minRole string) error { - if h.chatService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") - } - // Admin bypass. - if h.accountService != nil { - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if isAdmin { - return nil - } - } - p, err := h.chatService.GetParticipant(ctx, chatID, channelIdentityID) - if err != nil { - if errors.Is(err, chat.ErrNotParticipant) { - return echo.NewHTTPError(http.StatusForbidden, "not a participant") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if !roleAtLeast(p.Role, minRole) { - return echo.NewHTTPError(http.StatusForbidden, "insufficient permissions") - } - return nil -} - -func roleAtLeast(actual, required string) bool { - roleLevel := map[string]int{ - chat.RoleOwner: 3, - chat.RoleAdmin: 2, - chat.RoleMember: 1, - } - return roleLevel[actual] >= roleLevel[required] -} diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index e65bfef2..8e0753bc 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -27,9 +27,11 @@ import ( "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" + "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/policy" ) type ContainerdHandler struct { @@ -44,11 +46,11 @@ type ContainerdHandler struct { mcpStdioSess map[string]*mcpStdioSession botService *bots.Service accountService *accounts.Service + policyService *policy.Service queries *dbsqlc.Queries } type CreateContainerRequest struct { - Image string `json:"image,omitempty"` Snapshotter string `json:"snapshotter,omitempty"` } @@ -96,7 +98,7 @@ type ListSnapshotsResponse struct { Snapshots []SnapshotInfo `json:"snapshots"` } -func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, accountService *accounts.Service, queries *dbsqlc.Queries) *ContainerdHandler { +func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, accountService *accounts.Service, policyService *policy.Service, queries *dbsqlc.Queries) *ContainerdHandler { return &ContainerdHandler{ service: service, cfg: cfg, @@ -106,6 +108,7 @@ func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPC mcpStdioSess: make(map[string]*mcpStdioSession), botService: botService, accountService: accountService, + policyService: policyService, queries: queries, } } @@ -122,20 +125,9 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { group.GET("/skills", h.ListSkills) group.POST("/skills", h.UpsertSkills) group.DELETE("/skills", h.DeleteSkills) - group.POST("/fs-mcp", h.HandleMCPFS) - root := e.Group("/bots/:bot_id") - fs := e.Group("/bots/:bot_id/container/fs") - fs.GET("", h.ListFS) - fs.GET("/file", h.ReadFSFile) - fs.GET("/stat", h.StatFS) - fs.GET("/usage", h.UsageFS) - fs.POST("/file", h.WriteFSFile) - fs.POST("/dir", h.MkdirFS) - fs.POST("/upload", h.UploadFS) - fs.DELETE("", h.DeleteFS) root.POST("/mcp-stdio", h.CreateMCPStdio) - root.POST("/mcp-stdio/:session_id", h.HandleMCPStdio) + root.POST("/mcp-stdio/:connection_id", h.HandleMCPStdio) root.POST("/tools", h.HandleMCPTools) } @@ -160,13 +152,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } containerID := mcp.ContainerPrefix + botID - image := strings.TrimSpace(req.Image) - if image == "" { - image = h.cfg.BusyboxImage - } - if image == "" { - image = config.DefaultBusyboxImg - } + image := mcp.DefaultImageRef snapshotter := strings.TrimSpace(req.Snapshotter) if snapshotter == "" { snapshotter = h.cfg.Snapshotter @@ -205,12 +191,6 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -235,13 +215,13 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } if h.queries != nil { - pgBotID, parseErr := parsePgUUID(botID) + pgBotID, parseErr := db.ParseUUID(botID) if parseErr == nil { ns := strings.TrimSpace(h.namespace) if ns == "" { ns = "default" } - _ = h.queries.UpsertContainer(c.Request().Context(), dbsqlc.UpsertContainerParams{ + if dbErr := h.queries.UpsertContainer(c.Request().Context(), dbsqlc.UpsertContainerParams{ BotID: pgBotID, ContainerID: containerID, ContainerName: containerID, @@ -251,7 +231,10 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { AutoStart: true, HostPath: pgtype.Text{String: dataDir, Valid: true}, ContainerPath: dataMount, - }) + }); dbErr != nil { + h.logger.Error("failed to upsert container record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } @@ -262,8 +245,11 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { if netErr := ctr.SetupNetwork(ctx, task, containerID); netErr == nil { started = true if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStarted(c.Request().Context(), pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStarted(c.Request().Context(), pgBotID); dbErr != nil { + h.logger.Error("failed to update container started status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } } else { @@ -288,7 +274,25 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { }) } -func (h *ContainerdHandler) ensureTaskRunning(ctx context.Context, containerID string) error { +// ensureContainerAndTask verifies the container exists in containerd and its task is +// running. If the container is missing (e.g. after a VM restart) it is recreated via +// SetupBotContainer. This prevents permanent desync between DB and containerd state. +func (h *ContainerdHandler) ensureContainerAndTask(ctx context.Context, containerID, botID string) error { + // Check whether the container exists in containerd. + _, err := h.service.GetContainer(ctx, containerID) + if err != nil { + if !errdefs.IsNotFound(err) { + return err + } + // Container gone — rebuild from scratch. + h.logger.Warn("container missing in containerd, rebuilding", + slog.String("bot_id", botID), + slog.String("container_id", containerID), + ) + return h.SetupBotContainer(ctx, botID) + } + + // Container exists — make sure the task is running. tasks, err := h.service.ListTasks(ctx, &ctr.ListTasksOptions{ Filter: "container.id==" + containerID, }) @@ -312,13 +316,13 @@ func (h *ContainerdHandler) ensureTaskRunning(ctx context.Context, containerID s _ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}) return err } - return err + return nil } // botContainerID resolves container_id for a bot from the database. func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (string, error) { if h.queries != nil { - pgBotID, err := parsePgUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err == nil { row, err := h.queries.GetContainerByBotID(ctx, pgBotID) if err == nil && strings.TrimSpace(row.ContainerID) != "" { @@ -369,7 +373,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error { ctx := c.Request().Context() if h.queries != nil { - pgBotID, parseErr := parsePgUUID(botID) + pgBotID, parseErr := db.ParseUUID(botID) if parseErr == nil { row, dbErr := h.queries.GetContainerByBotID(ctx, pgBotID) if dbErr == nil { @@ -468,12 +472,15 @@ func (h *ContainerdHandler) StartContainer(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") } - if err := h.ensureTaskRunning(ctx, containerID); err != nil { + if err := h.ensureContainerAndTask(ctx, containerID, botID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStarted(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStarted(ctx, pgBotID); dbErr != nil { + h.logger.Error("failed to update container started status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } return c.JSON(http.StatusOK, map[string]bool{"started": true}) @@ -505,8 +512,11 @@ func (h *ContainerdHandler) StopContainer(c echo.Context) error { } _ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}) if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStopped(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStopped(ctx, pgBotID); dbErr != nil { + h.logger.Error("failed to update container stopped status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } return c.JSON(http.StatusOK, map[string]bool{"stopped": true}) @@ -652,6 +662,15 @@ func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdent if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") } + if errors.Is(err, bots.ErrBotAccessDenied) && h.policyService != nil { + allowGuest, policyErr := h.policyService.AllowGuest(ctx, botID) + if policyErr == nil && allowGuest { + bot, getErr := h.botService.Get(ctx, botID) + if getErr == nil { + return bot, nil + } + } + } if errors.Is(err, bots.ErrBotAccessDenied) { return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") } @@ -664,10 +683,7 @@ func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdent func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) error { containerID := mcp.ContainerPrefix + botID - image := strings.TrimSpace(h.cfg.BusyboxImage) - if image == "" { - image = config.DefaultBusyboxImg - } + image := mcp.DefaultImageRef snapshotter := strings.TrimSpace(h.cfg.Snapshotter) if strings.TrimSpace(h.namespace) != "" { @@ -703,12 +719,6 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -733,13 +743,13 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) } if h.queries != nil { - pgBotID, parseErr := parsePgUUID(botID) + pgBotID, parseErr := db.ParseUUID(botID) if parseErr == nil { ns := strings.TrimSpace(h.namespace) if ns == "" { ns = "default" } - _ = h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{ + if dbErr := h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{ BotID: pgBotID, ContainerID: containerID, ContainerName: containerID, @@ -749,7 +759,10 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) AutoStart: true, HostPath: pgtype.Text{String: dataDir, Valid: true}, ContainerPath: dataMount, - }) + }); dbErr != nil { + h.logger.Error("setup bot container: failed to upsert container record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } @@ -758,8 +771,11 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) }); err == nil { if netErr := ctr.SetupNetwork(ctx, task, containerID); netErr == nil { if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStarted(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStarted(ctx, pgBotID); dbErr != nil { + h.logger.Error("setup bot container: failed to update container started status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } } else { @@ -790,8 +806,11 @@ func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID strin slog.Any("error", err), ) if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.DeleteContainerByBotID(ctx, pgBotID); dbErr != nil { + h.logger.Error("CleanupBotContainer: failed to delete DB record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } return nil @@ -827,8 +846,11 @@ func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID strin if h.queries != nil { h.logger.Info("CleanupBotContainer: deleting container record from DB", slog.String("bot_id", botID)) - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.DeleteContainerByBotID(ctx, pgBotID); dbErr != nil { + h.logger.Error("CleanupBotContainer: failed to delete DB record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } h.logger.Info("CleanupBotContainer finished", slog.String("bot_id", botID)) @@ -842,13 +864,99 @@ func (h *ContainerdHandler) isTaskRunning(ctx context.Context, containerID strin return err == nil && len(tasks) > 0 && tasks[0].Status == tasktypes.Status_RUNNING } -func parsePgUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, err +// ReconcileContainers compares the DB containers table against actual containerd +// state on startup. For each auto_start container in DB it verifies the container +// and task exist; if missing they are rebuilt via SetupBotContainer. Containers that +// the DB claims are running but are not present in containerd get corrected. +func (h *ContainerdHandler) ReconcileContainers(ctx context.Context) { + if h.queries == nil { + return } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil + rows, err := h.queries.ListAutoStartContainers(ctx) + if err != nil { + h.logger.Error("reconcile: failed to list containers from DB", slog.Any("error", err)) + return + } + if len(rows) == 0 { + h.logger.Info("reconcile: no auto-start containers in DB") + return + } + + h.logger.Info("reconcile: checking containers", slog.Int("count", len(rows))) + for _, row := range rows { + containerID := row.ContainerID + botID := uuid.UUID(row.BotID.Bytes).String() + + _, err := h.service.GetContainer(ctx, containerID) + if err != nil { + if !errdefs.IsNotFound(err) { + h.logger.Error("reconcile: failed to get container", + slog.String("container_id", containerID), slog.Any("error", err)) + continue + } + // Container missing in containerd — rebuild. + h.logger.Warn("reconcile: container missing, rebuilding", + slog.String("bot_id", botID), slog.String("container_id", containerID)) + if setupErr := h.SetupBotContainer(ctx, botID); setupErr != nil { + h.logger.Error("reconcile: rebuild failed", + slog.String("bot_id", botID), slog.Any("error", setupErr)) + if dbErr := h.queries.UpdateContainerStatus(ctx, dbsqlc.UpdateContainerStatusParams{ + Status: "error", + BotID: row.BotID, + }); dbErr != nil { + h.logger.Error("reconcile: failed to mark container as error", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } + continue + } + + // Container exists — ensure the task is running. + running := h.isTaskRunning(ctx, containerID) + if running { + if row.Status != "running" { + if dbErr := h.queries.UpdateContainerStarted(ctx, row.BotID); dbErr != nil { + h.logger.Error("reconcile: failed to update DB status to running", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } + h.logger.Info("reconcile: container healthy", + slog.String("bot_id", botID), slog.String("container_id", containerID)) + continue + } + + // Task not running — try to start it. + h.logger.Warn("reconcile: task not running, starting", + slog.String("bot_id", botID), slog.String("container_id", containerID)) + if err := h.ensureContainerAndTask(ctx, containerID, botID); err != nil { + h.logger.Error("reconcile: failed to start task", + slog.String("bot_id", botID), slog.Any("error", err)) + if dbErr := h.queries.UpdateContainerStopped(ctx, row.BotID); dbErr != nil { + h.logger.Error("reconcile: failed to mark container as stopped", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } else { + if dbErr := h.queries.UpdateContainerStarted(ctx, row.BotID); dbErr != nil { + h.logger.Error("reconcile: failed to update DB status to running", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } + } + h.logger.Info("reconcile: completed") +} + +func (h *ContainerdHandler) ensureBotDataRoot(botID string) (string, error) { + dataRoot := strings.TrimSpace(h.cfg.DataRoot) + if dataRoot == "" { + dataRoot = config.DefaultDataRoot + } + dataRoot, err := filepath.Abs(dataRoot) + if err != nil { + return "", err + } + root := filepath.Join(dataRoot, "bots", botID) + if err := os.MkdirAll(root, 0o755); err != nil { + return "", err + } + return root, nil } diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go index 658310d7..17ef77de 100644 --- a/internal/handlers/fs.go +++ b/internal/handlers/fs.go @@ -26,77 +26,6 @@ import ( mcptools "github.com/memohai/memoh/internal/mcp" ) -// HandleMCPFS godoc -// @Summary MCP filesystem tools (JSON-RPC) -// @Description Forwards MCP JSON-RPC requests to the MCP server inside the container. -// @Description Required: -// @Description - container task is running -// @Description - container has data mount (default /data) bound to /users/ -// @Description - container image contains the "mcp" binary -// @Description Auth: Bearer JWT is used to determine user_id (sub or user_id). -// @Description Paths must be relative (no leading slash) and must not contain "..". -// @Description -// @Description Example: tools/list -// @Description {"jsonrpc":"2.0","id":1,"method":"tools/list"} -// @Description -// @Description Example: tools/call (fs.read) -// @Description {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}} -// @Tags containerd -// @Param Authorization header string true "Bearer " -// @Param bot_id path string true "Bot ID" -// @Param payload body object true "JSON-RPC request" -// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}" -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs-mcp [post] -func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - ctx := c.Request().Context() - containerID, err := h.botContainerID(ctx, botID) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") - } - - var req mcptools.JSONRPCRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if req.JSONRPC != "" && req.JSONRPC != "2.0" { - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32600, "invalid jsonrpc version")) - } - - if err := h.validateMCPContainer(ctx, containerID, botID); err != nil { - h.logger.Error("mcp fs validate failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) - } - if err := h.ensureTaskRunning(ctx, containerID); err != nil { - h.logger.Error("mcp fs ensure task failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) - } - - if strings.TrimSpace(req.Method) == "" { - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32601, "method not found")) - } - if mcptools.IsNotification(req) { - if err := h.notifyMCPServer(ctx, containerID, req); err != nil { - h.logger.Error("mcp fs notify failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - // MCP Streamable HTTP spec: notifications must be answered with 202 Accepted and no body. - return c.NoContent(http.StatusAccepted) - } - payload, err := h.callMCPServer(ctx, containerID, req) - if err != nil { - h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) - } - return c.JSON(http.StatusOK, payload) -} - func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, botID string) error { if strings.TrimSpace(botID) == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") diff --git a/internal/handlers/fs_rest.go b/internal/handlers/fs_rest.go deleted file mode 100644 index 96e529d8..00000000 --- a/internal/handlers/fs_rest.go +++ /dev/null @@ -1,585 +0,0 @@ -package handlers - -import ( - "errors" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/config" -) - -type FSListResponse struct { - Path string `json:"path"` - Entries []FSRestEntry `json:"entries"` -} - -type FSRestEntry struct { - Path string `json:"path"` - IsDir bool `json:"is_dir"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type FSReadResponse struct { - Path string `json:"path"` - Content string `json:"content"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type FSStatResponse struct { - Path string `json:"path"` - IsDir bool `json:"is_dir"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type FSUsageResponse struct { - Path string `json:"path"` - TotalBytes int64 `json:"total_bytes"` - FileCount int64 `json:"file_count"` - DirCount int64 `json:"dir_count"` -} - -type FSWriteRequest struct { - Path string `json:"path"` - Content string `json:"content"` - Overwrite *bool `json:"overwrite"` -} - -type FSWriteResponse struct { - OK bool `json:"ok"` -} - -type FSMkdirRequest struct { - Path string `json:"path"` - Parents *bool `json:"parents"` -} - -type FSDeleteResponse struct { - OK bool `json:"ok"` -} - -// ListFS godoc -// @Summary List files for a bot -// @Description List entries under a relative path -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string false "Relative directory path" -// @Param recursive query bool false "Recursive listing" -// @Success 200 {object} FSListResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs [get] -func (h *ContainerdHandler) ListFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - recursive, err := parseBoolQuery(c, "recursive") - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), true) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - if !info.IsDir() { - return echo.NewHTTPError(http.StatusBadRequest, "path is not a directory") - } - - entries := []FSRestEntry{} - if recursive { - err = filepath.WalkDir(target, func(p string, d os.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - if p == target { - return nil - } - entryInfo, err := d.Info() - if err != nil { - return err - } - entry, err := entryForBotPath(root, p, entryInfo) - if err != nil { - return err - } - entries = append(entries, entry) - return nil - }) - } else { - dirEntries, err := os.ReadDir(target) - if err != nil { - return fsHTTPError(err) - } - for _, entry := range dirEntries { - entryInfo, err := entry.Info() - if err != nil { - return fsHTTPError(err) - } - fullPath := filepath.Join(target, entry.Name()) - fileEntry, err := entryForBotPath(root, fullPath, entryInfo) - if err != nil { - return fsHTTPError(err) - } - entries = append(entries, fileEntry) - } - } - if err != nil { - return fsHTTPError(err) - } - - listedPath := strings.TrimSpace(rel) - if listedPath == "" || listedPath == "." { - listedPath = "." - } - return c.JSON(http.StatusOK, FSListResponse{Path: listedPath, Entries: entries}) -} - -// ReadFSFile godoc -// @Summary Read file content -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string true "Relative file path" -// @Success 200 {object} FSReadResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/file [get] -func (h *ContainerdHandler) ReadFSFile(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), false) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - if info.IsDir() { - return echo.NewHTTPError(http.StatusBadRequest, "path is a directory") - } - data, err := os.ReadFile(target) - if err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSReadResponse{ - Path: rel, - Content: string(data), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }) -} - -// StatFS godoc -// @Summary Get file or directory metadata -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string true "Relative path" -// @Success 200 {object} FSStatResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/stat [get] -func (h *ContainerdHandler) StatFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), false) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSStatResponse{ - Path: rel, - IsDir: info.IsDir(), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }) -} - -// UsageFS godoc -// @Summary Get usage under a path -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string false "Relative directory path" -// @Success 200 {object} FSUsageResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/usage [get] -func (h *ContainerdHandler) UsageFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), true) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - - var totalBytes int64 - var fileCount int64 - var dirCount int64 - if info.IsDir() { - err = filepath.WalkDir(target, func(p string, d os.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - if p == target { - return nil - } - entryInfo, err := d.Info() - if err != nil { - return err - } - if entryInfo.IsDir() { - dirCount++ - return nil - } - fileCount++ - totalBytes += entryInfo.Size() - return nil - }) - if err != nil { - return fsHTTPError(err) - } - } else { - fileCount = 1 - totalBytes = info.Size() - } - - usagePath := strings.TrimSpace(rel) - if usagePath == "" || usagePath == "." { - usagePath = "." - } - return c.JSON(http.StatusOK, FSUsageResponse{ - Path: usagePath, - TotalBytes: totalBytes, - FileCount: fileCount, - DirCount: dirCount, - }) -} - -// WriteFSFile godoc -// @Summary Create or overwrite a file -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param payload body FSWriteRequest true "File write payload" -// @Success 200 {object} FSWriteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 409 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/file [post] -func (h *ContainerdHandler) WriteFSFile(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - var req FSWriteRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if strings.TrimSpace(req.Path) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "path is required") - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, _, err := resolveBotPath(root, req.Path, false) - if err != nil { - return fsHTTPError(err) - } - overwrite := true - if req.Overwrite != nil { - overwrite = *req.Overwrite - } - if _, err := os.Stat(target); err == nil && !overwrite { - return echo.NewHTTPError(http.StatusConflict, "file already exists") - } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if err := os.WriteFile(target, []byte(req.Content), 0o644); err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSWriteResponse{OK: true}) -} - -// MkdirFS godoc -// @Summary Create a directory -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param payload body FSMkdirRequest true "Directory payload" -// @Success 200 {object} FSWriteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/dir [post] -func (h *ContainerdHandler) MkdirFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - var req FSMkdirRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if strings.TrimSpace(req.Path) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "path is required") - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, _, err := resolveBotPath(root, req.Path, false) - if err != nil { - return fsHTTPError(err) - } - parents := true - if req.Parents != nil { - parents = *req.Parents - } - if parents { - if err := os.MkdirAll(target, 0o755); err != nil { - return fsHTTPError(err) - } - } else if err := os.Mkdir(target, 0o755); err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSWriteResponse{OK: true}) -} - -// UploadFS godoc -// @Summary Upload a file -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string false "Relative file path or directory" -// @Param file formData file true "File to upload" -// @Success 200 {object} FSWriteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/upload [post] -func (h *ContainerdHandler) UploadFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - file, err := c.FormFile("file") - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "file is required") - } - rawPath := strings.TrimSpace(c.FormValue("path")) - if rawPath == "" { - rawPath = strings.TrimSpace(c.QueryParam("path")) - } - if rawPath == "" { - rawPath = file.Filename - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - targetPath := rawPath - if strings.HasSuffix(rawPath, "/") || strings.HasSuffix(rawPath, string(os.PathSeparator)) { - targetPath = filepath.ToSlash(filepath.Join(rawPath, file.Filename)) - } - target, _, err := resolveBotPath(root, targetPath, false) - if err != nil { - return fsHTTPError(err) - } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - src, err := file.Open() - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - defer src.Close() - dst, err := os.Create(target) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - defer dst.Close() - if _, err := io.Copy(dst, src); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, FSWriteResponse{OK: true}) -} - -// DeleteFS godoc -// @Summary Delete a file or directory -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string true "Relative path" -// @Param recursive query bool false "Recursive delete for directories" -// @Success 200 {object} FSDeleteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs [delete] -func (h *ContainerdHandler) DeleteFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - recursive, err := parseBoolQuery(c, "recursive") - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), false) - if err != nil { - return fsHTTPError(err) - } - if rel == "." || rel == "" { - return echo.NewHTTPError(http.StatusBadRequest, "refuse to delete root") - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - if info.IsDir() && recursive { - if err := os.RemoveAll(target); err != nil { - return fsHTTPError(err) - } - } else if err := os.Remove(target); err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSDeleteResponse{OK: true}) -} - -func (h *ContainerdHandler) ensureBotDataRoot(botID string) (string, error) { - dataRoot := strings.TrimSpace(h.cfg.DataRoot) - if dataRoot == "" { - dataRoot = config.DefaultDataRoot - } - dataRoot, err := filepath.Abs(dataRoot) - if err != nil { - return "", err - } - root := filepath.Join(dataRoot, "bots", botID) - if err := os.MkdirAll(root, 0o755); err != nil { - return "", err - } - return root, nil -} - -func resolveBotPath(root, requestPath string, allowRoot bool) (string, string, error) { - raw := strings.TrimSpace(requestPath) - if raw == "" { - if allowRoot { - return root, ".", nil - } - return "", "", os.ErrInvalid - } - clean := filepath.Clean(filepath.FromSlash(raw)) - if clean == "." || clean == "" { - if allowRoot { - return root, ".", nil - } - return "", "", os.ErrInvalid - } - if filepath.IsAbs(clean) || strings.HasPrefix(clean, "..") { - return "", "", os.ErrInvalid - } - target := filepath.Join(root, clean) - rel, err := filepath.Rel(root, target) - if err != nil || strings.HasPrefix(rel, "..") { - return "", "", os.ErrInvalid - } - return target, filepath.ToSlash(rel), nil -} - -func entryForBotPath(root, target string, info os.FileInfo) (FSRestEntry, error) { - rel, err := filepath.Rel(root, target) - if err != nil { - return FSRestEntry{}, err - } - if strings.HasPrefix(rel, "..") { - return FSRestEntry{}, os.ErrInvalid - } - if rel == "." { - rel = "" - } - return FSRestEntry{ - Path: filepath.ToSlash(rel), - IsDir: info.IsDir(), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }, nil -} - -func parseBoolQuery(c echo.Context, key string) (bool, error) { - raw := strings.TrimSpace(c.QueryParam(key)) - if raw == "" { - return false, nil - } - return strconv.ParseBool(raw) -} - -func fsHTTPError(err error) error { - if err == nil { - return nil - } - if errors.Is(err, os.ErrInvalid) { - return echo.NewHTTPError(http.StatusBadRequest, "invalid path") - } - if os.IsNotExist(err) { - return echo.NewHTTPError(http.StatusNotFound, "path not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) -} diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index c5c66958..23591bb7 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -17,29 +17,29 @@ import ( "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/adapters/local" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/identity" ) -// LocalChannelHandler handles local channel (CLI/Web) sessions backed by chats. +// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history. type LocalChannelHandler struct { channelType channel.ChannelType channelManager *channel.Manager channelService *channel.Service - chatService *chat.Service - sessionHub *local.SessionHub + chatService *conversation.Service + routeHub *local.RouteHub botService *bots.Service accountService *accounts.Service } // NewLocalChannelHandler creates a local channel handler. -func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, chatService *chat.Service, sessionHub *local.SessionHub, botService *bots.Service, accountService *accounts.Service) *LocalChannelHandler { +func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, chatService *conversation.Service, routeHub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *LocalChannelHandler { return &LocalChannelHandler{ channelType: channelType, channelManager: channelManager, channelService: channelService, chatService: chatService, - sessionHub: sessionHub, + routeHub: routeHub, botService: botService, accountService: accountService, } @@ -49,19 +49,12 @@ func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *cha func (h *LocalChannelHandler) Register(e *echo.Echo) { prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String()) group := e.Group(prefix) - group.POST("/sessions", h.CreateSession) - group.GET("/sessions/:session_id/stream", h.StreamSession) - group.POST("/sessions/:session_id/messages", h.PostMessage) + group.GET("/stream", h.StreamMessages) + group.POST("/messages", h.PostMessage) } -type localSessionResponse struct { - SessionID string `json:"session_id"` - ChatID string `json:"chat_id"` - StreamURL string `json:"stream_url"` -} - -// CreateSession creates a new local chat session. -func (h *LocalChannelHandler) CreateSession(c echo.Context) error { +// StreamMessages streams responses for the bot route. +func (h *LocalChannelHandler) StreamMessages(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err @@ -73,43 +66,11 @@ func (h *LocalChannelHandler) CreateSession(c echo.Context) error { if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - - // Create a chat as the underlying container. - chatObj, err := h.chatService.Create(c.Request().Context(), botID, channelIdentityID, chat.CreateChatRequest{ - Kind: chat.KindDirect, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - // Use chat_id as the session_id for the local hub. - sessionID := chatObj.ID - streamURL := fmt.Sprintf("/bots/%s/%s/sessions/%s/stream", botID, h.channelType.String(), sessionID) - return c.JSON(http.StatusOK, localSessionResponse{SessionID: sessionID, ChatID: chatObj.ID, StreamURL: streamURL}) -} - -// StreamSession streams responses for a local session. -func (h *LocalChannelHandler) StreamSession(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) - if err != nil { + if err := h.ensureBotParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.Param("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { - return err - } - if err := h.ensureChatParticipant(c.Request().Context(), sessionID, channelIdentityID); err != nil { - return err - } - if h.sessionHub == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "session hub not configured") + if h.routeHub == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "route hub not configured") } c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") @@ -123,7 +84,7 @@ func (h *LocalChannelHandler) StreamSession(c echo.Context) error { } writer := bufio.NewWriter(c.Response().Writer) - _, stream, cancel := h.sessionHub.Subscribe(sessionID) + _, stream, cancel := h.routeHub.Subscribe(botID) defer cancel() for { @@ -135,8 +96,8 @@ func (h *LocalChannelHandler) StreamSession(c echo.Context) error { return nil } payload := map[string]any{ - "target": msg.Target, - "message": msg.Message, + "target": msg.Target, + "event": msg.Event, } data, err := json.Marshal(payload) if err != nil { @@ -163,14 +124,10 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - sessionID := strings.TrimSpace(c.Param("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session id is required") - } if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if err := h.ensureChatParticipant(c.Request().Context(), sessionID, channelIdentityID); err != nil { + if err := h.ensureBotParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { return err } if h.channelManager == nil || h.channelService == nil { @@ -180,20 +137,20 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - text := strings.TrimSpace(req.Message.PlainText()) - if text == "" { + if req.Message.IsEmpty() { return echo.NewHTTPError(http.StatusBadRequest, "message is required") } cfg, err := h.channelService.ResolveEffectiveConfig(c.Request().Context(), botID, h.channelType) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + routeKey := botID msg := channel.InboundMessage{ Channel: h.channelType, Message: req.Message, BotID: botID, - ReplyTarget: sessionID, - SessionKey: sessionID, + ReplyTarget: routeKey, + RouteKey: routeKey, Sender: channel.Identity{ SubjectID: channelIdentityID, Attributes: map[string]string{ @@ -201,7 +158,7 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { }, }, Conversation: channel.Conversation{ - ID: sessionID, + ID: routeKey, Type: "p2p", }, ReceivedAt: time.Now().UTC(), @@ -213,16 +170,16 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } -func (h *LocalChannelHandler) ensureChatParticipant(ctx context.Context, chatID, channelIdentityID string) error { +func (h *LocalChannelHandler) ensureBotParticipant(ctx context.Context, botID, channelIdentityID string) error { if h.chatService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } - ok, err := h.chatService.IsParticipant(ctx, chatID, channelIdentityID) + ok, err := h.chatService.IsParticipant(ctx, botID, channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } if !ok { - return echo.NewHTTPError(http.StatusForbidden, "chat access denied") + return echo.NewHTTPError(http.StatusForbidden, "bot access denied") } return nil } diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go index 6434f1c9..784150d4 100644 --- a/internal/handlers/mcp.go +++ b/internal/handlers/mcp.go @@ -46,7 +46,6 @@ func (h *MCPHandler) Register(e *echo.Echo) { // @Summary List MCP connections // @Description List MCP connections for a bot // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Success 200 {object} mcp.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse @@ -76,7 +75,6 @@ func (h *MCPHandler) List(c echo.Context) error { // @Summary Create MCP connection // @Description Create a MCP connection for a bot // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param payload body mcp.UpsertRequest true "MCP payload" // @Success 201 {object} mcp.Connection // @Failure 400 {object} ErrorResponse @@ -111,7 +109,6 @@ func (h *MCPHandler) Create(c echo.Context) error { // @Summary Get MCP connection // @Description Get a MCP connection by ID // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param id path string true "MCP ID" // @Success 200 {object} mcp.Connection // @Failure 400 {object} ErrorResponse @@ -149,7 +146,6 @@ func (h *MCPHandler) Get(c echo.Context) error { // @Summary Update MCP connection // @Description Update a MCP connection by ID // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param id path string true "MCP ID" // @Param payload body mcp.UpsertRequest true "MCP payload" // @Success 200 {object} mcp.Connection @@ -192,7 +188,6 @@ func (h *MCPHandler) Update(c echo.Context) error { // @Summary Delete MCP connection // @Description Delete a MCP connection by ID // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param id path string true "MCP ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse diff --git a/internal/handlers/mcp_federation_gateway.go b/internal/handlers/mcp_federation_gateway.go index 4872c9f1..ea9bd57e 100644 --- a/internal/handlers/mcp_federation_gateway.go +++ b/internal/handlers/mcp_federation_gateway.go @@ -33,60 +33,6 @@ func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPF } } -func (g *MCPFederationGateway) ListFSMCPTools(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, error) { - if g.handler == nil { - return nil, fmt.Errorf("containerd handler not configured") - } - containerID, err := g.handler.botContainerID(ctx, botID) - if err != nil { - return nil, err - } - if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { - return nil, err - } - if err := g.handler.ensureTaskRunning(ctx, containerID); err != nil { - return nil, err - } - payload, err := g.handler.callMCPServer(ctx, containerID, mcpgw.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcpgw.RawStringID("federated-fs-tools-list"), - Method: "tools/list", - }) - if err != nil { - return nil, err - } - return parseGatewayToolsListPayload(payload) -} - -func (g *MCPFederationGateway) CallFSMCPTool(ctx context.Context, botID, toolName string, args map[string]any) (map[string]any, error) { - if g.handler == nil { - return nil, fmt.Errorf("containerd handler not configured") - } - containerID, err := g.handler.botContainerID(ctx, botID) - if err != nil { - return nil, err - } - if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { - return nil, err - } - if err := g.handler.ensureTaskRunning(ctx, containerID); err != nil { - return nil, err - } - params, err := json.Marshal(map[string]any{ - "name": toolName, - "arguments": args, - }) - if err != nil { - return nil, err - } - return g.handler.callMCPServer(ctx, containerID, mcpgw.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcpgw.RawStringID("federated-fs-tools-call"), - Method: "tools/call", - Params: params, - }) -} - func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { session, err := g.connectStreamableSession(ctx, connection) if err != nil { @@ -315,7 +261,7 @@ func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { return nil, err } - if err := g.handler.ensureTaskRunning(ctx, containerID); err != nil { + if err := g.handler.ensureContainerAndTask(ctx, containerID, botID); err != nil { return nil, err } diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go index 52566f29..51fff854 100644 --- a/internal/handlers/mcp_stdio.go +++ b/internal/handlers/mcp_stdio.go @@ -30,9 +30,9 @@ type MCPStdioRequest struct { } type MCPStdioResponse struct { - SessionID string `json:"session_id"` - URL string `json:"url"` - Tools []string `json:"tools,omitempty"` + ConnectionID string `json:"connection_id"` + URL string `json:"url"` + Tools []string `json:"tools,omitempty"` } type mcpStdioSession struct { @@ -76,7 +76,7 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { if err := h.validateMCPContainer(ctx, containerID, botID); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.ensureTaskRunning(ctx, containerID); err != nil { + if err := h.ensureContainerAndTask(ctx, containerID, botID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -85,9 +85,9 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } tools := h.probeMCPTools(ctx, sess, botID, strings.TrimSpace(req.Name)) - sessionID := uuid.NewString() + connectionID := uuid.NewString() record := &mcpStdioSession{ - id: sessionID, + id: connectionID, botID: botID, containerID: containerID, name: strings.TrimSpace(req.Name), @@ -97,19 +97,19 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { } sess.onClose = func() { h.mcpStdioMu.Lock() - if current, ok := h.mcpStdioSess[sessionID]; ok && current == record { - delete(h.mcpStdioSess, sessionID) + if current, ok := h.mcpStdioSess[connectionID]; ok && current == record { + delete(h.mcpStdioSess, connectionID) } h.mcpStdioMu.Unlock() } h.mcpStdioMu.Lock() - h.mcpStdioSess[sessionID] = record + h.mcpStdioSess[connectionID] = record h.mcpStdioMu.Unlock() return c.JSON(http.StatusOK, MCPStdioResponse{ - SessionID: sessionID, - URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, sessionID), - Tools: tools, + ConnectionID: connectionID, + URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, connectionID), + Tools: tools, }) } @@ -118,31 +118,31 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { // @Description Proxies MCP JSON-RPC requests to a stdio MCP process in the container. // @Tags containerd // @Param bot_id path string true "Bot ID" -// @Param session_id path string true "Session ID" +// @Param connection_id path string true "Connection ID" // @Param payload body object true "JSON-RPC request" // @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}" // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-stdio/{session_id} [post] +// @Router /bots/{bot_id}/mcp-stdio/{connection_id} [post] func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { return err } - sessionID := strings.TrimSpace(c.Param("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") + connectionID := strings.TrimSpace(c.Param("connection_id")) + if connectionID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "connection_id is required") } h.mcpStdioMu.Lock() - session := h.mcpStdioSess[sessionID] + session := h.mcpStdioSess[connectionID] h.mcpStdioMu.Unlock() if session == nil || session.session == nil || session.botID != botID { - return echo.NewHTTPError(http.StatusNotFound, "mcp session not found") + return echo.NewHTTPError(http.StatusNotFound, "mcp connection not found") } select { case <-session.session.closed: - return echo.NewHTTPError(http.StatusNotFound, "mcp session closed") + return echo.NewHTTPError(http.StatusNotFound, "mcp connection closed") default: } diff --git a/internal/handlers/mcp_tools.go b/internal/handlers/mcp_tools.go index d5b34e06..07b93ef3 100644 --- a/internal/handlers/mcp_tools.go +++ b/internal/handlers/mcp_tools.go @@ -15,12 +15,10 @@ import ( ) const ( - headerChatID = "X-Memoh-Chat-Id" headerChannelIdentityID = "X-Memoh-Channel-Identity-Id" headerSessionToken = "X-Memoh-Session-Token" headerCurrentPlatform = "X-Memoh-Current-Platform" headerReplyTarget = "X-Memoh-Reply-Target" - headerDisplayName = "X-Memoh-Display-Name" ) func (h *ContainerdHandler) SetToolGatewayService(service *mcpgw.ToolGatewayService) { @@ -234,11 +232,10 @@ func (h *ContainerdHandler) buildToolSessionContext(c echo.Context, botID string } return mcpgw.ToolSessionContext{ BotID: strings.TrimSpace(botID), - ChatID: strings.TrimSpace(c.Request().Header.Get(headerChatID)), + ChatID: strings.TrimSpace(botID), ChannelIdentityID: channelIdentityID, SessionToken: strings.TrimSpace(c.Request().Header.Get(headerSessionToken)), CurrentPlatform: strings.TrimSpace(c.Request().Header.Get(headerCurrentPlatform)), ReplyTarget: strings.TrimSpace(c.Request().Header.Get(headerReplyTarget)), - DisplayName: strings.TrimSpace(c.Request().Header.Get(headerDisplayName)), } } diff --git a/internal/handlers/mcp_tools_test.go b/internal/handlers/mcp_tools_test.go index 6ac0ce43..f9ea36b8 100644 --- a/internal/handlers/mcp_tools_test.go +++ b/internal/handlers/mcp_tools_test.go @@ -109,7 +109,6 @@ func TestHandleMCPToolsWithGatewayAcceptCompatibility(t *testing.T) { listReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) listReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) listReq.Header.Set("Accept", "application/json") - listReq.Header.Set("X-Memoh-Chat-Id", "chat-1") listReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") listRec := httptest.NewRecorder() listCtx := e.NewContext(listReq, listRec) @@ -137,7 +136,6 @@ func TestHandleMCPToolsWithGatewayAcceptCompatibility(t *testing.T) { callReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"2","method":"tools/call","params":{"name":"echo_tool","arguments":{"input":"hello"}}}`)) callReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) callReq.Header.Set("Accept", "application/json") - callReq.Header.Set("X-Memoh-Chat-Id", "chat-1") callReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") callRec := httptest.NewRecorder() callCtx := e.NewContext(callReq, callRec) @@ -158,7 +156,7 @@ func TestHandleMCPToolsWithGatewayAcceptCompatibility(t *testing.T) { if echoValue := strings.TrimSpace(mcpgw.StringArg(structured, "echo")); echoValue != "hello" { t.Fatalf("unexpected echo value: %#v", structured["echo"]) } - if strings.TrimSpace(mcpgw.StringArg(structured, "chat_id")) != "chat-1" { + if strings.TrimSpace(mcpgw.StringArg(structured, "chat_id")) != "bot-1" { t.Fatalf("unexpected chat id: %#v", structured["chat_id"]) } if strings.TrimSpace(mcpgw.StringArg(structured, "channel_identity_id")) != "user-1" { diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index 360e8641..ee0595d5 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -11,15 +11,15 @@ import ( "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/memory" ) -// MemoryHandler handles memory CRUD operations scoped by chat. +// MemoryHandler handles memory CRUD operations scoped by conversation. type MemoryHandler struct { service *memory.Service - chatService *chat.Service + chatService *conversation.Service accountService *accounts.Service logger *slog.Logger } @@ -50,8 +50,10 @@ type namespaceScope struct { ScopeID string } +const sharedMemoryNamespace = "bot" + // NewMemoryHandler creates a MemoryHandler. -func NewMemoryHandler(log *slog.Logger, service *memory.Service, chatService *chat.Service, accountService *accounts.Service) *MemoryHandler { +func NewMemoryHandler(log *slog.Logger, service *memory.Service, chatService *conversation.Service, accountService *accounts.Service) *MemoryHandler { return &MemoryHandler{ service: service, chatService: chatService, @@ -62,7 +64,7 @@ func NewMemoryHandler(log *slog.Logger, service *memory.Service, chatService *ch // Register registers chat-level memory routes. func (h *MemoryHandler) Register(e *echo.Echo) { - chatGroup := e.Group("/chats/:chat_id/memory") + chatGroup := e.Group("/bots/:bot_id/memory") chatGroup.POST("", h.ChatAdd) chatGroup.POST("/search", h.ChatSearch) chatGroup.GET("", h.ChatGetAll) @@ -78,7 +80,7 @@ func (h *MemoryHandler) checkService() error { // --- Chat-level memory endpoints --- -// ChatAdd adds memory to a specific namespace (validated against chat_settings). +// ChatAdd adds memory into the bot-shared namespace. func (h *MemoryHandler) ChatAdd(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -87,11 +89,11 @@ func (h *MemoryHandler) ChatAdd(c echo.Context) error { if err != nil { return err } - chatID := strings.TrimSpace(c.Param("chat_id")) - if chatID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") + containerID, err := h.resolveBotContainerID(c) + if err != nil { + return err } - if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { return err } @@ -100,13 +102,13 @@ func (h *MemoryHandler) ChatAdd(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - namespace := strings.TrimSpace(payload.Namespace) - if namespace == "" { - namespace = "chat" + namespace, err := normalizeSharedMemoryNamespace(payload.Namespace) + if err != nil { + return err } - // Resolve correct scopeId/botId and validate namespace is enabled. - scopeID, botID, err := h.resolveWriteScope(c.Request().Context(), chatID, channelIdentityID, namespace) + // Resolve bot scope for shared memory. + scopeID, botID, err := h.resolveWriteScope(c.Request().Context(), containerID) if err != nil { return err } @@ -129,7 +131,7 @@ func (h *MemoryHandler) ChatAdd(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -// ChatSearch searches memory across all enabled namespaces per chat_settings. +// ChatSearch searches memory in the bot-shared namespace. func (h *MemoryHandler) ChatSearch(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -138,11 +140,11 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error { if err != nil { return err } - chatID := strings.TrimSpace(c.Param("chat_id")) - if chatID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") + containerID, err := h.resolveBotContainerID(c) + if err != nil { + return err } - if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { return err } @@ -151,17 +153,17 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - scopes, err := h.resolveEnabledScopes(c.Request().Context(), chatID, channelIdentityID) + scopes, err := h.resolveEnabledScopes(c.Request().Context(), containerID) if err != nil { return err } - chatObj, err := h.chatService.Get(c.Request().Context(), chatID) + chatObj, err := h.chatService.Get(c.Request().Context(), containerID) if err != nil { return echo.NewHTTPError(http.StatusNotFound, "chat not found") } botID := strings.TrimSpace(chatObj.BotID) - // Search across all enabled namespaces and merge results. + // Search shared namespace and merge results. var allResults []memory.MemoryItem for _, scope := range scopes { filters := buildNamespaceFilters(scope.Namespace, scope.ScopeID, payload.Filters) @@ -197,7 +199,7 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error { return c.JSON(http.StatusOK, memory.SearchResponse{Results: allResults}) } -// ChatGetAll lists all memories across enabled namespaces. +// ChatGetAll lists all memories in the bot-shared namespace. func (h *MemoryHandler) ChatGetAll(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -206,15 +208,15 @@ func (h *MemoryHandler) ChatGetAll(c echo.Context) error { if err != nil { return err } - chatID := strings.TrimSpace(c.Param("chat_id")) - if chatID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") + containerID, err := h.resolveBotContainerID(c) + if err != nil { + return err } - if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { return err } - scopes, err := h.resolveEnabledScopes(c.Request().Context(), chatID, channelIdentityID) + scopes, err := h.resolveEnabledScopes(c.Request().Context(), containerID) if err != nil { return err } @@ -236,7 +238,7 @@ func (h *MemoryHandler) ChatGetAll(c echo.Context) error { return c.JSON(http.StatusOK, memory.SearchResponse{Results: allResults}) } -// ChatDeleteAll deletes all memories across enabled namespaces. +// ChatDeleteAll deletes all memories in the bot-shared namespace. func (h *MemoryHandler) ChatDeleteAll(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -245,15 +247,15 @@ func (h *MemoryHandler) ChatDeleteAll(c echo.Context) error { if err != nil { return err } - chatID := strings.TrimSpace(c.Param("chat_id")) - if chatID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") + containerID, err := h.resolveBotContainerID(c) + if err != nil { + return err } - if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { return err } - scopes, err := h.resolveEnabledScopes(c.Request().Context(), chatID, channelIdentityID) + scopes, err := h.resolveEnabledScopes(c.Request().Context(), containerID) if err != nil { return err } @@ -271,8 +273,8 @@ func (h *MemoryHandler) ChatDeleteAll(c echo.Context) error { // --- helpers --- -// resolveEnabledScopes returns all namespace scopes enabled by chat_settings. -func (h *MemoryHandler) resolveEnabledScopes(ctx context.Context, chatID, channelIdentityID string) ([]namespaceScope, error) { +// resolveEnabledScopes returns the bot-shared namespace scope for the conversation. +func (h *MemoryHandler) resolveEnabledScopes(ctx context.Context, chatID string) ([]namespaceScope, error) { if h.chatService == nil { return nil, echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } @@ -280,29 +282,18 @@ func (h *MemoryHandler) resolveEnabledScopes(ctx context.Context, chatID, channe if err != nil { return nil, echo.NewHTTPError(http.StatusNotFound, "chat not found") } - settings, err := h.chatService.GetSettings(ctx, chatID) - if err != nil { - return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + botID := strings.TrimSpace(chatObj.BotID) + if botID == "" { + return nil, echo.NewHTTPError(http.StatusInternalServerError, "chat bot id is empty") } - - var scopes []namespaceScope - if settings.EnableChatMemory { - scopes = append(scopes, namespaceScope{Namespace: "chat", ScopeID: chatID}) - } - if settings.EnablePrivateMemory && strings.TrimSpace(channelIdentityID) != "" { - scopes = append(scopes, namespaceScope{Namespace: "private", ScopeID: channelIdentityID}) - } - if settings.EnablePublicMemory { - scopes = append(scopes, namespaceScope{Namespace: "public", ScopeID: chatObj.BotID}) - } - if len(scopes) == 0 { - scopes = append(scopes, namespaceScope{Namespace: "chat", ScopeID: chatID}) - } - return scopes, nil + return []namespaceScope{{ + Namespace: sharedMemoryNamespace, + ScopeID: botID, + }}, nil } -// resolveWriteScope validates namespace and returns (scopeId, botId). -func (h *MemoryHandler) resolveWriteScope(ctx context.Context, chatID, channelIdentityID, namespace string) (string, string, error) { +// resolveWriteScope returns (scopeID, botID) for shared bot memory. +func (h *MemoryHandler) resolveWriteScope(ctx context.Context, chatID string) (string, string, error) { if h.chatService == nil { return "", "", echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } @@ -310,35 +301,30 @@ func (h *MemoryHandler) resolveWriteScope(ctx context.Context, chatID, channelId if err != nil { return "", "", echo.NewHTTPError(http.StatusNotFound, "chat not found") } - settings, err := h.chatService.GetSettings(ctx, chatID) - if err != nil { - return "", "", echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + botID := strings.TrimSpace(chatObj.BotID) + if botID == "" { + return "", "", echo.NewHTTPError(http.StatusInternalServerError, "bot id is empty") } + return botID, botID, nil +} - switch namespace { - case "chat": - if !settings.EnableChatMemory { - return "", "", echo.NewHTTPError(http.StatusForbidden, "chat memory is disabled for this chat") - } - return chatID, chatObj.BotID, nil - case "private": - if !settings.EnablePrivateMemory { - return "", "", echo.NewHTTPError(http.StatusForbidden, "private memory is disabled for this chat") - } - if strings.TrimSpace(channelIdentityID) == "" { - return "", "", echo.NewHTTPError(http.StatusBadRequest, "channel_identity_id required for private namespace") - } - return channelIdentityID, chatObj.BotID, nil - case "public": - if !settings.EnablePublicMemory { - return "", "", echo.NewHTTPError(http.StatusForbidden, "public memory is disabled for this chat") - } - return chatObj.BotID, chatObj.BotID, nil +func normalizeSharedMemoryNamespace(raw string) (string, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", sharedMemoryNamespace: + return sharedMemoryNamespace, nil default: - return "", "", echo.NewHTTPError(http.StatusBadRequest, "invalid namespace: "+namespace) + return "", echo.NewHTTPError(http.StatusBadRequest, "invalid namespace: "+raw) } } +func (h *MemoryHandler) resolveBotContainerID(c echo.Context) (string, error) { + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return "", echo.NewHTTPError(http.StatusBadRequest, "bot_id is required") + } + return botID, nil +} + func buildNamespaceFilters(namespace, scopeID string, extra map[string]any) map[string]any { filters := map[string]any{ "namespace": namespace, diff --git a/internal/handlers/message.go b/internal/handlers/message.go new file mode 100644 index 00000000..31c25874 --- /dev/null +++ b/internal/handlers/message.go @@ -0,0 +1,583 @@ +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "strconv" + "strings" + "time" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/accounts" + "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/bots" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/conversation/flow" + "github.com/memohai/memoh/internal/identity" + messagepkg "github.com/memohai/memoh/internal/message" + messageevent "github.com/memohai/memoh/internal/message/event" +) + +// MessageHandler handles bot-scoped messaging endpoints. +type MessageHandler struct { + runner flow.Runner + conversationService conversation.Accessor + messageService messagepkg.Service + messageEvents messageevent.Subscriber + botService *bots.Service + accountService *accounts.Service + channelIdentitySvc *identities.Service + logger *slog.Logger +} + +// NewMessageHandler creates a MessageHandler. +func NewMessageHandler(log *slog.Logger, runner flow.Runner, conversationService conversation.Accessor, messageService messagepkg.Service, botService *bots.Service, accountService *accounts.Service, channelIdentitySvc *identities.Service, eventSubscribers ...messageevent.Subscriber) *MessageHandler { + var messageEvents messageevent.Subscriber + if len(eventSubscribers) > 0 { + messageEvents = eventSubscribers[0] + } + return &MessageHandler{ + runner: runner, + conversationService: conversationService, + messageService: messageService, + messageEvents: messageEvents, + botService: botService, + accountService: accountService, + channelIdentitySvc: channelIdentitySvc, + logger: log.With(slog.String("handler", "conversation")), + } +} + +// Register registers all conversation routes. +func (h *MessageHandler) Register(e *echo.Echo) { + // Bot-scoped message container (single shared history per bot). + botGroup := e.Group("/bots/:bot_id") + botGroup.POST("/messages", h.SendMessage) + botGroup.POST("/messages/stream", h.StreamMessage) + botGroup.GET("/messages", h.ListMessages) + botGroup.GET("/messages/events", h.StreamMessageEvents) + botGroup.DELETE("/messages", h.DeleteMessages) +} + +// --- Messages --- + +// SendMessage sends a synchronous conversation message. +func (h *MessageHandler) SendMessage(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + + var req flow.ChatRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if req.Query == "" { + return echo.NewHTTPError(http.StatusBadRequest, "query is required") + } + req.BotID = botID + req.ChatID = botID + req.Token = c.Request().Header.Get("Authorization") + req.UserID = channelIdentityID + req.SourceChannelIdentityID = channelIdentityID + if strings.TrimSpace(req.CurrentChannel) == "" { + req.CurrentChannel = "web" + } + if len(req.Channels) == 0 { + req.Channels = []string{req.CurrentChannel} + } + channelIdentityID = h.resolveWebChannelIdentity(c.Request().Context(), channelIdentityID, &req) + + if h.runner == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation runner not configured") + } + resp, err := h.runner.Chat(c.Request().Context(), req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// StreamMessage sends a streaming conversation message. +func (h *MessageHandler) StreamMessage(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + + var req flow.ChatRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if req.Query == "" { + return echo.NewHTTPError(http.StatusBadRequest, "query is required") + } + req.BotID = botID + req.ChatID = botID + req.Token = c.Request().Header.Get("Authorization") + req.UserID = channelIdentityID + req.SourceChannelIdentityID = channelIdentityID + if strings.TrimSpace(req.CurrentChannel) == "" { + req.CurrentChannel = "web" + } + if len(req.Channels) == 0 { + req.Channels = []string{req.CurrentChannel} + } + channelIdentityID = h.resolveWebChannelIdentity(c.Request().Context(), channelIdentityID, &req) + + if h.runner == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation runner not configured") + } + c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") + c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") + c.Response().Header().Set(echo.HeaderConnection, "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + chunkChan, errChan := h.runner.StreamChat(c.Request().Context(), req) + flusher, ok := c.Response().Writer.(http.Flusher) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") + } + writer := bufio.NewWriter(c.Response().Writer) + processingState := "started" + if err := writeSSEJSON(writer, flusher, map[string]string{"type": "processing_started"}); err != nil { + return nil + } + + for { + select { + case chunk, ok := <-chunkChan: + if !ok { + if processingState == "started" { + processingState = "completed" + if err := writeSSEJSON(writer, flusher, map[string]string{"type": "processing_completed"}); err != nil { + return nil + } + } + if err := writeSSEData(writer, flusher, "[DONE]"); err != nil { + return nil + } + return nil + } + if processingState == "started" { + processingState = "completed" + if err := writeSSEJSON(writer, flusher, map[string]string{"type": "processing_completed"}); err != nil { + return nil + } + } + if err := writeSSEData(writer, flusher, string(chunk)); err != nil { + return nil + } + case err := <-errChan: + if err != nil { + h.logger.Error("conversation stream failed", slog.Any("error", err)) + if processingState == "started" { + processingState = "failed" + _ = writeSSEJSON(writer, flusher, map[string]string{ + "type": "processing_failed", + "error": err.Error(), + }) + } + errData := map[string]string{ + "type": "error", + "error": err.Error(), + "message": err.Error(), + } + if writeErr := writeSSEJSON(writer, flusher, errData); writeErr != nil { + return nil + } + return nil + } + } + } +} + +func writeSSEData(writer *bufio.Writer, flusher http.Flusher, payload string) error { + if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", payload)); err != nil { + return err + } + if err := writer.Flush(); err != nil { + return err + } + flusher.Flush() + return nil +} + +func writeSSEJSON(writer *bufio.Writer, flusher http.Flusher, payload any) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + return writeSSEData(writer, flusher, string(data)) +} + +func parseSinceParam(raw string) (time.Time, bool, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return time.Time{}, false, nil + } + layouts := []string{time.RFC3339Nano, time.RFC3339} + for _, layout := range layouts { + parsed, err := time.Parse(layout, trimmed) + if err == nil { + return parsed.UTC(), true, nil + } + } + if epochMillis, err := strconv.ParseInt(trimmed, 10, 64); err == nil { + return time.UnixMilli(epochMillis).UTC(), true, nil + } + return time.Time{}, false, fmt.Errorf("invalid since parameter") +} + +// ListMessages lists messages for a conversation with optional pagination. +// Query: limit (default 30), before (optional ISO8601 or unix ms) for older messages. +// Returns items in ascending created_at order (oldest first). +func (h *MessageHandler) ListMessages(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireReadable(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + + if h.messageService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured") + } + + limit := int32(30) + if s := strings.TrimSpace(c.QueryParam("limit")); s != "" { + if n, err := strconv.ParseInt(s, 10, 32); err == nil && n > 0 && n <= 100 { + limit = int32(n) + } + } + + before, hasBefore := parseBeforeParam(c.QueryParam("before")) + + var messages []messagepkg.Message + if hasBefore { + messages, err = h.messageService.ListBefore(c.Request().Context(), botID, before, limit) + } else { + messages, err = h.messageService.ListLatest(c.Request().Context(), botID, limit) + if err == nil { + reverseMessages(messages) + } + } + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, map[string]any{"items": messages}) +} + +func parseBeforeParam(s string) (time.Time, bool) { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return time.Time{}, false + } + if t, err := time.Parse(time.RFC3339Nano, trimmed); err == nil { + return t.UTC(), true + } + if t, err := time.Parse(time.RFC3339, trimmed); err == nil { + return t.UTC(), true + } + if epochMillis, err := strconv.ParseInt(trimmed, 10, 64); err == nil { + return time.UnixMilli(epochMillis).UTC(), true + } + return time.Time{}, false +} + +func reverseMessages(m []messagepkg.Message) { + for i, j := 0, len(m)-1; i < j; i, j = i+1, j-1 { + m[i], m[j] = m[j], m[i] + } +} + +// StreamMessageEvents streams bot-scoped message events to clients. +func (h *MessageHandler) StreamMessageEvents(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireReadable(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + if h.messageService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured") + } + if h.messageEvents == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message events not configured") + } + + since, hasSince, err := parseSinceParam(c.QueryParam("since")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") + c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") + c.Response().Header().Set(echo.HeaderConnection, "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + flusher, ok := c.Response().Writer.(http.Flusher) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") + } + writer := bufio.NewWriter(c.Response().Writer) + + sentMessageIDs := map[string]struct{}{} + writeCreatedEvent := func(message messagepkg.Message) error { + msgID := strings.TrimSpace(message.ID) + if msgID != "" { + if _, exists := sentMessageIDs[msgID]; exists { + return nil + } + sentMessageIDs[msgID] = struct{}{} + } + return writeSSEJSON(writer, flusher, map[string]any{ + "type": string(messageevent.EventTypeMessageCreated), + "bot_id": botID, + "message": message, + }) + } + + _, stream, cancel := h.messageEvents.Subscribe(botID, 128) + defer cancel() + + if hasSince { + backlog, err := h.messageService.ListSince(c.Request().Context(), botID, since) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + for _, message := range backlog { + if err := writeCreatedEvent(message); err != nil { + return nil + } + } + } + + heartbeatTicker := time.NewTicker(20 * time.Second) + defer heartbeatTicker.Stop() + + for { + select { + case <-c.Request().Context().Done(): + return nil + case <-heartbeatTicker.C: + if err := writeSSEJSON(writer, flusher, map[string]any{"type": "ping"}); err != nil { + return nil + } + case event, ok := <-stream: + if !ok { + return nil + } + if strings.TrimSpace(event.BotID) != botID { + continue + } + if event.Type != messageevent.EventTypeMessageCreated { + continue + } + if len(event.Data) == 0 { + continue + } + var message messagepkg.Message + if err := json.Unmarshal(event.Data, &message); err != nil { + h.logger.Warn("decode message event failed", slog.Any("error", err)) + continue + } + if err := writeCreatedEvent(message); err != nil { + return nil + } + } + } +} + +// DeleteMessages clears all persisted bot-level history messages. +func (h *MessageHandler) DeleteMessages(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotManage(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if h.messageService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured") + } + if err := h.messageService.DeleteByBot(c.Request().Context(), botID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// --- helpers --- + +// resolveWebChannelIdentity resolves (web, user_id) to a channel identity and sets req.SourceChannelIdentityID. +// Web uses user_id as the channel subject id (like Feishu open_id); the resolved ci has display_name and is linked to the user. +// Returns the channel_identity_id to use for the rest of the flow, or the original userID if resolution is skipped/fails. +func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID string, req *flow.ChatRequest) string { + if strings.TrimSpace(req.CurrentChannel) != "web" || h.channelIdentitySvc == nil || strings.TrimSpace(userID) == "" { + return userID + } + displayName := "" + if h.accountService != nil { + if account, err := h.accountService.Get(ctx, userID); err == nil { + displayName = strings.TrimSpace(account.DisplayName) + if displayName == "" { + displayName = strings.TrimSpace(account.Username) + } + } + } + ci, err := h.channelIdentitySvc.ResolveByChannelIdentity(ctx, "web", userID, displayName) + if err != nil { + return userID + } + _ = h.channelIdentitySvc.LinkChannelIdentityToUser(ctx, ci.ID, userID) + req.SourceChannelIdentityID = ci.ID + return ci.ID +} + +func (h *MessageHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) + if err != nil { + return "", err + } + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { + return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return channelIdentityID, nil +} + +func (h *MessageHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} + +func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot management access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} + +func (h *MessageHandler) requireParticipant(ctx context.Context, conversationID, channelIdentityID string) error { + if h.conversationService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation service not configured") + } + // Admin bypass. + if h.accountService != nil { + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if isAdmin { + return nil + } + } + ok, err := h.conversationService.IsParticipant(ctx, conversationID, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "not a participant") + } + return nil +} + +func (h *MessageHandler) requireReadable(ctx context.Context, conversationID, channelIdentityID string) error { + if h.conversationService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation service not configured") + } + // Admin bypass. + if h.accountService != nil { + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if isAdmin { + return nil + } + } + _, err := h.conversationService.GetReadAccess(ctx, conversationID, channelIdentityID) + if err != nil { + if errors.Is(err, conversation.ErrPermissionDenied) { + return echo.NewHTTPError(http.StatusForbidden, "not allowed to read conversation") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return nil +} diff --git a/internal/handlers/schedule.go b/internal/handlers/schedule.go index 75df5269..42643408 100644 --- a/internal/handlers/schedule.go +++ b/internal/handlers/schedule.go @@ -45,7 +45,6 @@ func (h *ScheduleHandler) Register(e *echo.Echo) { // @Summary Create schedule // @Description Create a schedule for current user // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param payload body schedule.CreateRequest true "Schedule payload" // @Success 201 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse @@ -78,7 +77,6 @@ func (h *ScheduleHandler) Create(c echo.Context) error { // @Summary List schedules // @Description List schedules for current user // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Success 200 {object} schedule.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -106,7 +104,6 @@ func (h *ScheduleHandler) List(c echo.Context) error { // @Summary Get schedule // @Description Get a schedule by ID // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param id path string true "Schedule ID" // @Success 200 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse @@ -143,7 +140,6 @@ func (h *ScheduleHandler) Get(c echo.Context) error { // @Summary Update schedule // @Description Update a schedule by ID // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param id path string true "Schedule ID" // @Param payload body schedule.UpdateRequest true "Schedule payload" // @Success 200 {object} schedule.Schedule @@ -188,7 +184,6 @@ func (h *ScheduleHandler) Update(c echo.Context) error { // @Summary Delete schedule // @Description Delete a schedule by ID // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param id path string true "Schedule ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index 5e83d122..7c3d57d0 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -44,7 +44,6 @@ func (h *SettingsHandler) Register(e *echo.Echo) { // @Summary Get user settings // @Description Get agent settings for current user // @Tags settings -// @Param bot_id path string true "Bot ID" // @Success 200 {object} settings.Settings // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -72,7 +71,6 @@ func (h *SettingsHandler) Get(c echo.Context) error { // @Summary Update user settings // @Description Update or create agent settings for current user // @Tags settings -// @Param bot_id path string true "Bot ID" // @Param payload body settings.UpsertRequest true "Settings payload" // @Success 200 {object} settings.Settings // @Failure 400 {object} ErrorResponse @@ -109,7 +107,6 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { // @Summary Delete user settings // @Description Remove agent settings for current user // @Tags settings -// @Param bot_id path string true "Bot ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse diff --git a/internal/handlers/subagent.go b/internal/handlers/subagent.go index 67c734db..40d86c15 100644 --- a/internal/handlers/subagent.go +++ b/internal/handlers/subagent.go @@ -50,7 +50,6 @@ func (h *SubagentHandler) Register(e *echo.Echo) { // @Summary Create subagent // @Description Create a subagent for current user // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param payload body subagent.CreateRequest true "Subagent payload" // @Success 201 {object} subagent.Subagent // @Failure 400 {object} ErrorResponse @@ -83,7 +82,6 @@ func (h *SubagentHandler) Create(c echo.Context) error { // @Summary List subagents // @Description List subagents for current user // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Success 200 {object} subagent.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -111,7 +109,6 @@ func (h *SubagentHandler) List(c echo.Context) error { // @Summary Get subagent // @Description Get a subagent by ID // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 200 {object} subagent.Subagent // @Failure 400 {object} ErrorResponse @@ -148,7 +145,6 @@ func (h *SubagentHandler) Get(c echo.Context) error { // @Summary Update subagent // @Description Update a subagent by ID // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.UpdateRequest true "Subagent payload" // @Success 200 {object} subagent.Subagent @@ -194,7 +190,6 @@ func (h *SubagentHandler) Update(c echo.Context) error { // @Summary Delete subagent // @Description Delete a subagent by ID // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse @@ -234,7 +229,6 @@ func (h *SubagentHandler) Delete(c echo.Context) error { // @Summary Get subagent context // @Description Get a subagent's message context // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 200 {object} subagent.ContextResponse // @Failure 400 {object} ErrorResponse @@ -271,7 +265,6 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { // @Summary Update subagent context // @Description Update a subagent's message context // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.UpdateContextRequest true "Context payload" // @Success 200 {object} subagent.ContextResponse @@ -317,7 +310,6 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { // @Summary Get subagent skills // @Description Get a subagent's skills // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 200 {object} subagent.SkillsResponse // @Failure 400 {object} ErrorResponse @@ -354,7 +346,6 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { // @Summary Update subagent skills // @Description Replace a subagent's skills // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.UpdateSkillsRequest true "Skills payload" // @Success 200 {object} subagent.SkillsResponse @@ -400,7 +391,6 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { // @Summary Add subagent skills // @Description Add skills to a subagent // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.AddSkillsRequest true "Skills payload" // @Success 200 {object} subagent.SkillsResponse diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 688af741..d8d975f9 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -14,42 +14,42 @@ import ( "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/channelidentities" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/channel/route" "github.com/memohai/memoh/internal/identity" ) // UsersHandler manages user/account CRUD and bot operations via REST API. type UsersHandler struct { - service *accounts.Service - channelIdentityService *channelidentities.Service - botService *bots.Service - chatService *chat.Service - channelService *channel.Service - channelManager *channel.Manager - registry *channel.Registry - logger *slog.Logger + service *accounts.Service + channelIdentityService *identities.Service + botService *bots.Service + routeService route.Service + channelService *channel.Service + channelManager *channel.Manager + registry *channel.Registry + logger *slog.Logger } type listMyIdentitiesResponse struct { UserID string `json:"user_id"` - Items []channelidentities.ChannelIdentity `json:"items"` + Items []identities.ChannelIdentity `json:"items"` } // NewUsersHandler creates a UsersHandler with channel identity support. -func NewUsersHandler(log *slog.Logger, service *accounts.Service, channelIdentityService *channelidentities.Service, botService *bots.Service, chatService *chat.Service, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *UsersHandler { +func NewUsersHandler(log *slog.Logger, service *accounts.Service, channelIdentityService *identities.Service, botService *bots.Service, routeService route.Service, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *UsersHandler { if log == nil { log = slog.Default() } return &UsersHandler{ - service: service, + service: service, channelIdentityService: channelIdentityService, - botService: botService, - chatService: chatService, - channelService: channelService, - channelManager: channelManager, - registry: registry, - logger: log.With(slog.String("handler", "users")), + botService: botService, + routeService: routeService, + channelService: channelService, + channelManager: channelManager, + registry: registry, + logger: log.With(slog.String("handler", "users")), } } @@ -69,6 +69,7 @@ func (h *UsersHandler) Register(e *echo.Echo) { botGroup.POST("", h.CreateBot) botGroup.GET("", h.ListBots) botGroup.GET("/:id", h.GetBot) + botGroup.GET("/:id/checks", h.ListBotChecks) botGroup.PUT("/:id", h.UpdateBot) botGroup.PUT("/:id/owner", h.TransferBotOwner) botGroup.DELETE("/:id", h.DeleteBot) @@ -408,18 +409,35 @@ func (h *UsersHandler) CreateBot(c echo.Context) error { ownerID = raw ownerFromToken = false } - if ownerFromToken && h.channelIdentityService != nil { - linkedUserID, err := h.channelIdentityService.GetLinkedUserID(c.Request().Context(), ownerID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if strings.TrimSpace(linkedUserID) != "" { - ownerID = linkedUserID + if ownerFromToken { + if _, err := h.service.Get(c.Request().Context(), ownerID); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + // Backward-compatible token path: token user_id might be a channel identity ID. + // Try to resolve to linked user first; if still missing, force re-login. + linkedUserID := "" + if h.channelIdentityService != nil { + linkedUserID, err = h.channelIdentityService.GetLinkedUserID(c.Request().Context(), ownerID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + } + linkedUserID = strings.TrimSpace(linkedUserID) + if linkedUserID != "" { + ownerID = linkedUserID + } else { + return echo.NewHTTPError(http.StatusUnauthorized, "owner user not found, please login again") + } + } else { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } } } resp, err := h.botService.Create(c.Request().Context(), ownerID, req) if err != nil { if errors.Is(err, bots.ErrOwnerUserNotFound) { + if ownerFromToken { + return echo.NewHTTPError(http.StatusUnauthorized, "owner user not found, please login again") + } return echo.NewHTTPError(http.StatusBadRequest, "owner user not found") } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -491,6 +509,39 @@ func (h *UsersHandler) GetBot(c echo.Context) error { return c.JSON(http.StatusOK, bot) } +// ListBotChecks godoc +// @Summary List bot runtime checks +// @Description Evaluate bot attached resource checks in runtime +// @Tags bots +// @Param id path string true "Bot ID" +// @Success 200 {object} bots.ListChecksResponse +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{id}/checks [get] +func (h *UsersHandler) ListBotChecks(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + items, err := h.botService.ListChecks(c.Request().Context(), botID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, bots.ListChecksResponse{Items: items}) +} + // UpdateBot godoc // @Summary Update bot details // @Description Update bot profile (owner/admin only) @@ -576,7 +627,7 @@ func (h *UsersHandler) TransferBotOwner(c echo.Context) error { // @Description Delete a bot user (owner/admin only) // @Tags bots // @Param id path string true "Bot ID" -// @Success 204 "No Content" +// @Success 202 {object} map[string]string // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse @@ -600,7 +651,10 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.NoContent(http.StatusNoContent) + return c.JSON(http.StatusAccepted, map[string]string{ + "id": botID, + "status": bots.BotStatusDeleting, + }) } // ListBotMembers godoc @@ -857,10 +911,10 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { if chatToken.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "token bot mismatch") } - if h.channelManager == nil || h.chatService == nil { + if h.channelManager == nil || h.routeService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "services not configured") } - route, err := h.chatService.GetRouteByID(c.Request().Context(), chatToken.RouteID) + route, err := h.routeService.GetByID(c.Request().Context(), chatToken.RouteID) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, "route not found") } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index fc8f970d..81afe073 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -14,7 +14,7 @@ var ( logKey = ctxKey{} ) -// Init 初始化全局日志 +// Init initializes the global logger with the given level and format (e.g. "debug", "json"). func Init(level, format string) { var handler slog.Handler opts := &slog.HandlerOptions{ @@ -31,7 +31,7 @@ func Init(level, format string) { slog.SetDefault(L) } -// FromContext 从 context 中获取 logger,如果不存在则返回全局 logger +// FromContext returns the logger from ctx, or the global logger if not set. func FromContext(ctx context.Context) *slog.Logger { if l, ok := ctx.Value(logKey).(*slog.Logger); ok { return l @@ -39,7 +39,7 @@ func FromContext(ctx context.Context) *slog.Logger { return L } -// WithContext 将 logger 注入 context +// WithContext stores the logger in ctx and returns the new context. func WithContext(ctx context.Context, l *slog.Logger) context.Context { return context.WithValue(ctx, logKey, l) } @@ -59,7 +59,7 @@ func parseLevel(level string) slog.Level { } } -// 快捷方法,支持强类型 slog.Attr 或松散的 key-value 对 +// Debug, Info, Warn, Error log with the global logger (slog.Attr or key-value pairs). func Debug(msg string, args ...any) { L.Debug(msg, args...) } func Info(msg string, args ...any) { L.Info(msg, args...) } func Warn(msg string, args ...any) { L.Warn(msg, args...) } diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index f485ac1c..72d3eea8 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -7,21 +7,18 @@ import ( ) func TestInitAndLogging(t *testing.T) { - // 测试 JSON 格式 Init("debug", "json") if L.Enabled(context.Background(), slog.LevelDebug) != true { t.Error("expected debug level to be enabled") } - // 验证是否能正常输出(不崩溃) Info("test info message", "key", "value") } func TestContextLogger(t *testing.T) { Init("info", "text") - // 创建一个带特定属性的 logger expectedKey := "request_id" expectedValue := "12345" customLogger := L.With(expectedKey, expectedValue) @@ -29,7 +26,6 @@ func TestContextLogger(t *testing.T) { ctx := WithContext(context.Background(), customLogger) extracted := FromContext(ctx) - // 这里简单验证提取出来的是否是同一个(或者功能一致) if extracted == nil { t.Fatal("extracted logger should not be nil") } diff --git a/internal/mcp/connections.go b/internal/mcp/connections.go index 0a829b1d..86804776 100644 --- a/internal/mcp/connections.go +++ b/internal/mcp/connections.go @@ -14,14 +14,14 @@ import ( // Connection represents a stored MCP connection for a bot. type Connection struct { - ID string `json:"id" validate:"required"` - BotID string `json:"bot_id" validate:"required"` - Name string `json:"name" validate:"required"` - Type string `json:"type" validate:"required"` - Config map[string]any `json:"config" validate:"required"` - Active bool `json:"active" validate:"required"` - CreatedAt time.Time `json:"created_at" validate:"required"` - UpdatedAt time.Time `json:"updated_at" validate:"required"` + ID string `json:"id"` + BotID string `json:"bot_id"` + Name string `json:"name"` + Type string `json:"type"` + Config map[string]any `json:"config"` + Active bool `json:"active"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // UpsertRequest is the payload for creating or updating MCP connections. @@ -222,8 +222,8 @@ func normalizeMCPConnection(row sqlc.McpConnection) (Connection, error) { return Connection{}, err } return Connection{ - ID: db.UUIDToString(row.ID), - BotID: db.UUIDToString(row.BotID), + ID: row.ID.String(), + BotID: row.BotID.String(), Name: strings.TrimSpace(row.Name), Type: strings.TrimSpace(row.Type), Config: config, diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index 0877de2e..0a7f2256 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -1,11 +1,15 @@ package mcp import ( + "bytes" "context" + "errors" "fmt" "log/slog" "os" + "os/exec" "path/filepath" + "runtime" "strings" "time" @@ -23,6 +27,7 @@ import ( const ( BotLabelKey = "mcp.bot_id" ContainerPrefix = "mcp-" + DefaultImageRef = "memoh-mcp:dev" ) type ExecRequest struct { @@ -38,20 +43,32 @@ type ExecResult struct { ExitCode uint32 } +// ExecWithCaptureResult holds stdout, stderr and exit code from container exec. +type ExecWithCaptureResult struct { + Stdout string + Stderr string + ExitCode uint32 +} + type Manager struct { service ctr.Service cfg config.MCPConfig + namespace string containerID func(string) string db *pgxpool.Pool queries *dbsqlc.Queries logger *slog.Logger } -func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig) *Manager { +func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string) *Manager { + if namespace == "" { + namespace = config.DefaultNamespace + } return &Manager{ - service: service, - cfg: cfg, - logger: log.With(slog.String("component", "mcp")), + service: service, + cfg: cfg, + namespace: namespace, + logger: log.With(slog.String("component", "mcp")), containerID: func(botID string) string { return ContainerPrefix + botID }, @@ -65,10 +82,7 @@ func (m *Manager) WithDB(db *pgxpool.Pool) *Manager { } func (m *Manager) Init(ctx context.Context) error { - image := m.cfg.BusyboxImage - if image == "" { - image = config.DefaultBusyboxImg - } + image := DefaultImageRef _, err := m.service.PullImage(ctx, image, &ctr.PullImageOptions{ Unpack: true, @@ -103,12 +117,6 @@ func (m *Manager) EnsureBot(ctx context.Context, botID string) error { Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -242,6 +250,116 @@ func (m *Manager) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error return &ExecResult{ExitCode: result.ExitCode}, nil } +// ExecWithCapture runs a command in the bot container and returns stdout, stderr and exit code. +// Use this when the caller needs command output (e.g. MCP exec tool). +// The container must already be running; use Start(botID) or the container/start API to start it. +// On darwin, it uses Lima SSH to avoid virtiofs FIFO synchronization issues. +func (m *Manager) ExecWithCapture(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { + if err := validateBotID(req.BotID); err != nil { + return nil, err + } + if len(req.Command) == 0 { + return nil, fmt.Errorf("%w: empty command", ctr.ErrInvalidArgument) + } + if m.queries == nil { + return nil, fmt.Errorf("db is not configured") + } + + if runtime.GOOS == "darwin" { + return m.execWithCaptureLima(ctx, req) + } + return m.execWithCaptureContainerd(ctx, req) +} + +// execWithCaptureLima runs exec through Lima SSH so that all FIFO I/O stays +// inside the VM, avoiding virtiofs FIFO synchronization issues on macOS. +func (m *Manager) execWithCaptureLima(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { + containerID := m.containerID(req.BotID) + execID := fmt.Sprintf("exec-%d", time.Now().UnixNano()) + + // Each element becomes a separate OS arg to limactl. Lima/SSH joins + // them with spaces and passes the result to the remote shell, so only + // values that may contain shell-special characters need quoting. + args := []string{"shell", "default", "--", + "sudo", "ctr", "-n", m.namespace, + "tasks", "exec", "--exec-id", execID, + } + if req.WorkDir != "" { + args = append(args, "--cwd", req.WorkDir) + } + for _, e := range req.Env { + args = append(args, "--env", e) + } + args = append(args, containerID) + // Pass command args as-is; Lima shell-quotes each OS arg for the + // remote SSH shell, preserving argument boundaries correctly. + args = append(args, req.Command...) + + cmd := exec.CommandContext(ctx, "limactl", args...) + var stdoutBuf, stderrBuf bytes.Buffer + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + + exitCode := uint32(0) + if err := cmd.Run(); err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode = uint32(exitErr.ExitCode()) + } else { + return nil, fmt.Errorf("lima exec: %w", err) + } + } + + // ctr tasks exec may write its own errors to stderr; separate them from + // the container command's stderr output by checking for the ctr prefix. + stderr := stderrBuf.String() + if exitCode != 0 && strings.HasPrefix(stderr, "ctr:") { + return nil, fmt.Errorf("container exec failed: %s", strings.TrimSpace(stderr)) + } + + return &ExecWithCaptureResult{ + Stdout: stdoutBuf.String(), + Stderr: stderr, + ExitCode: exitCode, + }, nil +} + +// execWithCaptureContainerd uses the containerd ExecTask API with FIFO pipes. +// This works reliably on Linux where FIFO I/O stays on the same filesystem. +func (m *Manager) execWithCaptureContainerd(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { + fifoDir, err := os.MkdirTemp(m.dataRoot(), "exec-fifo-") + if err != nil { + return nil, fmt.Errorf("create fifo dir: %w", err) + } + defer os.RemoveAll(fifoDir) + + var stdoutBuf, stderrBuf bytes.Buffer + result, err := m.service.ExecTask(ctx, m.containerID(req.BotID), ctr.ExecTaskRequest{ + Args: req.Command, + Env: req.Env, + WorkDir: req.WorkDir, + Stderr: &stderrBuf, + Stdout: &stdoutBuf, + FIFODir: fifoDir, + }) + if err != nil { + return nil, err + } + return &ExecWithCaptureResult{ + Stdout: stdoutBuf.String(), + Stderr: stderrBuf.String(), + ExitCode: result.ExitCode, + }, nil +} + +// sshShellQuote wraps a string in single quotes for safe SSH transport. +func sshShellQuote(s string) string { + if s == "" { + return "''" + } + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + // DataDir returns the host data directory for a bot. func (m *Manager) DataDir(botID string) (string, error) { if err := validateBotID(botID); err != nil { @@ -274,10 +392,7 @@ func (m *Manager) dataMount() string { } func (m *Manager) imageRef() string { - if m.cfg.BusyboxImage == "" { - return config.DefaultBusyboxImg - } - return m.cfg.BusyboxImage + return DefaultImageRef } func validateBotID(botID string) error { diff --git a/internal/mcp/providers/container/fsops.go b/internal/mcp/providers/container/fsops.go new file mode 100644 index 00000000..49ac9009 --- /dev/null +++ b/internal/mcp/providers/container/fsops.go @@ -0,0 +1,288 @@ +package container + +import ( + "context" + "encoding/base64" + "fmt" + "path" + "strconv" + "strings" + "time" + "unicode" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +type fileEntry struct { + Path string + IsDir bool + Size int64 + Mode uint32 + ModTime time.Time +} + +// execRead reads a file inside the container via cat. +func execRead(ctx context.Context, runner ExecRunner, botID, workDir, filePath string) (string, error) { + result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{"/bin/sh", "-c", "cat " + shellQuote(filePath)}, + WorkDir: workDir, + }) + if err != nil { + return "", err + } + if result.ExitCode != 0 { + return "", fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) + } + return result.Stdout, nil +} + +// execWrite writes content to a file inside the container using base64 encoding +// to avoid shell escaping issues. +func execWrite(ctx context.Context, runner ExecRunner, botID, workDir, filePath, content string) error { + encoded := base64.StdEncoding.EncodeToString([]byte(content)) + dir := path.Dir(filePath) + script := fmt.Sprintf("mkdir -p %s && echo %s | base64 -d > %s", + shellQuote(dir), shellQuote(encoded), shellQuote(filePath)) + result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{"/bin/sh", "-c", script}, + WorkDir: workDir, + }) + if err != nil { + return err + } + if result.ExitCode != 0 { + return fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) + } + return nil +} + +// execList lists directory entries inside the container via find + stat. +// Output format per line: |||| +func execList(ctx context.Context, runner ExecRunner, botID, workDir, dirPath string, recursive bool) ([]fileEntry, error) { + depthFlag := "-maxdepth 1" + if recursive { + depthFlag = "" + } + // Use find to get entries, skip the root dir itself, then stat each entry. + // busybox stat -c format: %n=name, %F=type, %s=size, %a=octal mode, %Y=mtime epoch + script := fmt.Sprintf( + `find %s %s ! -path %s -exec stat -c '%%n|%%F|%%s|%%a|%%Y' {} \;`, + shellQuote(dirPath), depthFlag, shellQuote(dirPath), + ) + result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{"/bin/sh", "-c", script}, + WorkDir: workDir, + }) + if err != nil { + return nil, err + } + if result.ExitCode != 0 { + return nil, fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) + } + return parseStatOutput(result.Stdout, dirPath), nil +} + +// parseStatOutput parses lines of "fullpath|type|size|mode|mtime" into fileEntry slices. +func parseStatOutput(output, basePath string) []fileEntry { + lines := strings.Split(strings.TrimSpace(output), "\n") + entries := make([]fileEntry, 0, len(lines)) + // Normalize base path for computing relative paths. + base := strings.TrimSuffix(basePath, "/") + if base == "" || base == "." { + base = "" + } + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.SplitN(line, "|", 5) + if len(parts) < 5 { + continue + } + fullPath := parts[0] + fileType := parts[1] + sizeStr := parts[2] + modeStr := parts[3] + mtimeStr := parts[4] + + // Compute relative path from base. + rel := fullPath + if base != "" { + rel = strings.TrimPrefix(fullPath, base+"/") + } + if rel == "" || rel == "." { + continue + } + + isDir := strings.Contains(fileType, "directory") + size, _ := strconv.ParseInt(sizeStr, 10, 64) + mode64, _ := strconv.ParseUint(modeStr, 8, 32) + mtimeEpoch, _ := strconv.ParseInt(mtimeStr, 10, 64) + modTime := time.Unix(mtimeEpoch, 0) + + entries = append(entries, fileEntry{ + Path: rel, + IsDir: isDir, + Size: size, + Mode: uint32(mode64), + ModTime: modTime, + }) + } + return entries +} + +// applyEdit performs the fuzzy text replacement logic on raw file content. +// Returns the updated content or an error. +func applyEdit(raw, filePath, oldText, newText string) (string, error) { + bom, content := stripBOM(raw) + originalEnding := detectLineEnding(content) + normalizedContent := normalizeToLF(content) + normalizedOld := normalizeToLF(oldText) + normalizedNew := normalizeToLF(newText) + match := fuzzyFindText(normalizedContent, normalizedOld) + if !match.Found { + return "", fmt.Errorf( + "could not find the exact text in %s. the old text must match exactly including all whitespace and newlines", + filePath, + ) + } + fuzzyContent := normalizeForFuzzyMatch(normalizedContent) + fuzzyOld := normalizeForFuzzyMatch(normalizedOld) + occurrences := strings.Count(fuzzyContent, fuzzyOld) + if occurrences > 1 { + return "", fmt.Errorf( + "found %d occurrences of the text in %s. the text must be unique. please provide more context to make it unique", + occurrences, + filePath, + ) + } + baseContent := match.ContentForReplacement + updated := baseContent[:match.Index] + normalizedNew + baseContent[match.Index+match.MatchLength:] + if baseContent == updated { + return "", fmt.Errorf( + "no changes made to %s. the replacement produced identical content. this might indicate an issue with special characters or the text not existing as expected", + filePath, + ) + } + return bom + restoreLineEndings(updated, originalEnding), nil +} + +// shellQuote wraps a string in single quotes, escaping embedded single quotes. +func shellQuote(s string) string { + if s == "" { + return "''" + } + if strings.IndexByte(s, '\'') < 0 { + return "'" + s + "'" + } + var b strings.Builder + b.WriteByte('\'') + for _, c := range s { + if c == '\'' { + b.WriteString("'\\''") + } else { + b.WriteRune(c) + } + } + b.WriteByte('\'') + return b.String() +} + +// ---------- fuzzy matching helpers (pure string processing, unchanged) ---------- + +type fuzzyMatchResult struct { + Found bool + Index int + MatchLength int + ContentForReplacement string +} + +func detectLineEnding(content string) string { + crlfIdx := strings.Index(content, "\r\n") + lfIdx := strings.Index(content, "\n") + if lfIdx == -1 { + return "\n" + } + if crlfIdx == -1 { + return "\n" + } + if crlfIdx < lfIdx { + return "\r\n" + } + return "\n" +} + +func normalizeToLF(text string) string { + text = strings.ReplaceAll(text, "\r\n", "\n") + return strings.ReplaceAll(text, "\r", "\n") +} + +func restoreLineEndings(text, ending string) string { + if ending == "\r\n" { + return strings.ReplaceAll(text, "\n", "\r\n") + } + return text +} + +func stripBOM(content string) (string, string) { + const bom = "\uFEFF" + if strings.HasPrefix(content, bom) { + return bom, content[len(bom):] + } + return "", content +} + +func normalizeForFuzzyMatch(text string) string { + lines := strings.Split(text, "\n") + for i, line := range lines { + lines[i] = strings.TrimRightFunc(line, unicode.IsSpace) + } + trimmed := strings.Join(lines, "\n") + return strings.Map(func(r rune) rune { + switch r { + case '\u2018', '\u2019', '\u201A', '\u201B': + return '\'' + case '\u201C', '\u201D', '\u201E', '\u201F': + return '"' + case '\u2010', '\u2011', '\u2012', '\u2013', '\u2014', '\u2015', '\u2212': + return '-' + case '\u00A0', '\u2002', '\u2003', '\u2004', '\u2005', '\u2006', '\u2007', '\u2008', '\u2009', '\u200A', '\u202F', '\u205F', '\u3000': + return ' ' + default: + return r + } + }, trimmed) +} + +func fuzzyFindText(content, oldText string) fuzzyMatchResult { + exactIndex := strings.Index(content, oldText) + if exactIndex != -1 { + return fuzzyMatchResult{ + Found: true, + Index: exactIndex, + MatchLength: len(oldText), + ContentForReplacement: content, + } + } + fuzzyContent := normalizeForFuzzyMatch(content) + fuzzyOld := normalizeForFuzzyMatch(oldText) + fuzzyIndex := strings.Index(fuzzyContent, fuzzyOld) + if fuzzyIndex == -1 { + return fuzzyMatchResult{ + Found: false, + Index: -1, + MatchLength: 0, + ContentForReplacement: content, + } + } + return fuzzyMatchResult{ + Found: true, + Index: fuzzyIndex, + MatchLength: len(fuzzyOld), + ContentForReplacement: fuzzyContent, + } +} diff --git a/internal/mcp/providers/container/fsops_test.go b/internal/mcp/providers/container/fsops_test.go new file mode 100644 index 00000000..acfb2f7d --- /dev/null +++ b/internal/mcp/providers/container/fsops_test.go @@ -0,0 +1,148 @@ +package container + +import "testing" + +func TestShellQuote(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"hello", "'hello'"}, + {"", "''"}, + {"it's", `'it'\''s'`}, + {"a b", "'a b'"}, + } + for _, tt := range tests { + got := shellQuote(tt.in) + if got != tt.want { + t.Errorf("shellQuote(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestParseStatOutput(t *testing.T) { + output := `./file.txt|regular file|123|644|1700000000 +./subdir|directory|4096|755|1700000000 +` + entries := parseStatOutput(output, ".") + if len(entries) != 2 { + t.Fatalf("got %d entries, want 2", len(entries)) + } + if entries[0].Path != "./file.txt" { + t.Errorf("path[0] = %q", entries[0].Path) + } + if entries[0].IsDir { + t.Error("file.txt should not be a directory") + } + if entries[0].Size != 123 { + t.Errorf("size[0] = %d", entries[0].Size) + } + if entries[1].Path != "./subdir" { + t.Errorf("path[1] = %q", entries[1].Path) + } + if !entries[1].IsDir { + t.Error("subdir should be a directory") + } +} + +func TestParseStatOutput_WithBasePath(t *testing.T) { + output := `/data/test/file.txt|regular file|10|644|1700000000 +/data/test/sub|directory|4096|755|1700000000 +` + entries := parseStatOutput(output, "/data/test") + if len(entries) != 2 { + t.Fatalf("got %d entries, want 2", len(entries)) + } + if entries[0].Path != "file.txt" { + t.Errorf("path[0] = %q, want %q", entries[0].Path, "file.txt") + } + if entries[1].Path != "sub" { + t.Errorf("path[1] = %q, want %q", entries[1].Path, "sub") + } +} + +func TestParseStatOutput_Empty(t *testing.T) { + entries := parseStatOutput("", ".") + if len(entries) != 0 { + t.Errorf("got %d entries for empty output", len(entries)) + } +} + +func TestApplyEdit(t *testing.T) { + raw := "hello world\n" + updated, err := applyEdit(raw, "test.txt", "hello", "goodbye") + if err != nil { + t.Fatal(err) + } + if updated != "goodbye world\n" { + t.Errorf("updated = %q", updated) + } +} + +func TestApplyEdit_NotFound(t *testing.T) { + raw := "hello world\n" + _, err := applyEdit(raw, "test.txt", "missing text", "new") + if err == nil { + t.Error("expected error for missing text") + } +} + +func TestApplyEdit_MultipleOccurrences(t *testing.T) { + raw := "foo bar foo\n" + _, err := applyEdit(raw, "test.txt", "foo", "baz") + if err == nil { + t.Error("expected error for multiple occurrences") + } +} + +func TestApplyEdit_NoChange(t *testing.T) { + raw := "hello world\n" + _, err := applyEdit(raw, "test.txt", "hello", "hello") + if err == nil { + t.Error("expected error for identical replacement") + } +} + +func TestFuzzyFindText(t *testing.T) { + tests := []struct { + name string + content string + old string + found bool + }{ + {"exact match", "hello world", "hello", true}, + {"no match", "hello world", "missing", false}, + {"smart quote match", "it\u2019s a test", "it's a test", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fuzzyFindText(tt.content, tt.old) + if result.Found != tt.found { + t.Errorf("found = %v, want %v", result.Found, tt.found) + } + }) + } +} + +func TestDetectLineEnding(t *testing.T) { + if detectLineEnding("foo\r\nbar") != "\r\n" { + t.Error("expected CRLF") + } + if detectLineEnding("foo\nbar") != "\n" { + t.Error("expected LF") + } + if detectLineEnding("foo") != "\n" { + t.Error("expected LF default") + } +} + +func TestStripBOM(t *testing.T) { + bom, content := stripBOM("\uFEFFhello") + if bom != "\uFEFF" || content != "hello" { + t.Errorf("bom=%q content=%q", bom, content) + } + bom2, content2 := stripBOM("hello") + if bom2 != "" || content2 != "hello" { + t.Errorf("bom=%q content=%q", bom2, content2) + } +} diff --git a/internal/mcp/providers/container/provider.go b/internal/mcp/providers/container/provider.go new file mode 100644 index 00000000..f0b95036 --- /dev/null +++ b/internal/mcp/providers/container/provider.go @@ -0,0 +1,250 @@ +package container + +import ( + "context" + "log/slog" + "strings" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const ( + toolRead = "read" + toolWrite = "write" + toolList = "list" + toolEdit = "edit" + toolExec = "exec" + + defaultExecWorkDir = "/data" + shellCommandName = "/bin/sh" + shellCommandFlag = "-c" +) + +// ExecRunner runs a command in the bot container and returns stdout, stderr and exit code. +type ExecRunner interface { + ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) +} + +// Executor provides filesystem and exec tools (read, write, list, edit, exec) that +// operate inside the bot container via ExecRunner. All I/O goes through the container +// sandbox — no direct host filesystem access. +type Executor struct { + execRunner ExecRunner + execWorkDir string + logger *slog.Logger +} + +// NewExecutor returns a tool executor. execRunner is required — all tools delegate +// to it for container-side I/O. execWorkDir is the default working directory inside +// the container (e.g. /data). +func NewExecutor(log *slog.Logger, execRunner ExecRunner, execWorkDir string) *Executor { + if log == nil { + log = slog.Default() + } + wd := strings.TrimSpace(execWorkDir) + if wd == "" { + wd = defaultExecWorkDir + } + return &Executor{ + execRunner: execRunner, + execWorkDir: wd, + logger: log.With(slog.String("provider", "container_tool")), + } +} + +// ListTools returns read, write, list, edit, and exec tool descriptors. +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + return []mcpgw.ToolDescriptor{ + { + Name: toolRead, + Description: "Read file content inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "file path (relative to /data or absolute inside container)"}, + }, + "required": []string{"path"}, + }, + }, + { + Name: toolWrite, + Description: "Write file content inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "file path (relative to /data or absolute inside container)"}, + "content": map[string]any{"type": "string", "description": "file content"}, + }, + "required": []string{"path", "content"}, + }, + }, + { + Name: toolList, + Description: "List directory entries inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "directory path (relative to /data or absolute inside container)"}, + "recursive": map[string]any{"type": "boolean", "description": "list recursively"}, + }, + "required": []string{"path"}, + }, + }, + { + Name: toolEdit, + Description: "Replace exact text in a file inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "file path (relative to /data or absolute inside container)"}, + "old_text": map[string]any{"type": "string", "description": "exact text to find"}, + "new_text": map[string]any{"type": "string", "description": "replacement text"}, + }, + "required": []string{"path", "old_text", "new_text"}, + }, + }, + { + Name: toolExec, + Description: "Execute a command in the bot container. Runs in the bot's data directory (/data) by default.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "Shell command to run (e.g. ls -la, cat file.txt)", + }, + "work_dir": map[string]any{ + "type": "string", + "description": "Working directory inside the container (default: /data)", + }, + }, + "required": []string{"command"}, + }, + }, + }, nil +} + +// normalizePath converts paths that the LLM may send as /data/... into relative +// paths under the working directory. e.g. /data/test.txt -> test.txt, /data -> . +func normalizePath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return path + } + const prefix = "/data" + if path == prefix { + return "." + } + if strings.HasPrefix(path, prefix+"/") { + return strings.TrimLeft(strings.TrimPrefix(path, prefix+"/"), "/") + } + return path +} + +// CallTool dispatches to the appropriate container-exec backed implementation. +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + + switch toolName { + case toolRead: + filePath := normalizePath(mcpgw.StringArg(arguments, "path")) + if filePath == "" { + return mcpgw.BuildToolErrorResult("path is required"), nil + } + content, err := execRead(ctx, p.execRunner, botID, p.execWorkDir, filePath) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"content": content}), nil + + case toolWrite: + filePath := normalizePath(mcpgw.StringArg(arguments, "path")) + content := mcpgw.StringArg(arguments, "content") + if filePath == "" { + return mcpgw.BuildToolErrorResult("path is required"), nil + } + if err := execWrite(ctx, p.execRunner, botID, p.execWorkDir, filePath, content); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil + + case toolList: + dirPath := normalizePath(mcpgw.StringArg(arguments, "path")) + if dirPath == "" { + dirPath = "." + } + recursive, _, _ := mcpgw.BoolArg(arguments, "recursive") + entries, err := execList(ctx, p.execRunner, botID, p.execWorkDir, dirPath, recursive) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + entriesMaps := make([]map[string]any, len(entries)) + for i, e := range entries { + entriesMaps[i] = map[string]any{ + "path": e.Path, + "is_dir": e.IsDir, + "size": e.Size, + "mode": e.Mode, + "mod_time": e.ModTime, + } + } + return mcpgw.BuildToolSuccessResult(map[string]any{"path": dirPath, "entries": entriesMaps}), nil + + case toolEdit: + filePath := normalizePath(mcpgw.StringArg(arguments, "path")) + oldText := mcpgw.StringArg(arguments, "old_text") + newText := mcpgw.StringArg(arguments, "new_text") + if filePath == "" || oldText == "" { + return mcpgw.BuildToolErrorResult("path, old_text and new_text are required"), nil + } + // Step 1: read via exec + raw, err := execRead(ctx, p.execRunner, botID, p.execWorkDir, filePath) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + // Step 2: fuzzy match in Go + updated, err := applyEdit(raw, filePath, oldText, newText) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + // Step 3: write back via exec + if err := execWrite(ctx, p.execRunner, botID, p.execWorkDir, filePath, updated); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil + + case toolExec: + command := strings.TrimSpace(mcpgw.StringArg(arguments, "command")) + if command == "" { + return mcpgw.BuildToolErrorResult("command is required"), nil + } + workDir := strings.TrimSpace(mcpgw.StringArg(arguments, "work_dir")) + if workDir == "" { + workDir = p.execWorkDir + } + result, err := p.execRunner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{shellCommandName, shellCommandFlag, command}, + WorkDir: workDir, + }) + if err != nil { + p.logger.Warn("exec failed", slog.String("bot_id", botID), slog.String("command", command), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + stderr := result.Stderr + if result.ExitCode != 0 && strings.Contains(stderr, "no running task") { + stderr = strings.TrimSpace(stderr) + "\n\nHint: Container exists but has no running task (main process exited). Start it first: POST /bots/" + botID + "/container/start or use the container start action in the UI." + } + return mcpgw.BuildToolSuccessResult(map[string]any{ + "stdout": result.Stdout, + "stderr": stderr, + "exit_code": result.ExitCode, + }), nil + + default: + return nil, mcpgw.ErrToolNotFound + } +} diff --git a/internal/mcp/providers/container/provider_test.go b/internal/mcp/providers/container/provider_test.go new file mode 100644 index 00000000..d921d4d9 --- /dev/null +++ b/internal/mcp/providers/container/provider_test.go @@ -0,0 +1,236 @@ +package container + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + "testing" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +// fakeExecRunner records the last request and returns a preset result. +type fakeExecRunner struct { + result *mcpgw.ExecWithCaptureResult + err error + lastReq mcpgw.ExecRequest + handler func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) +} + +func (f *fakeExecRunner) ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { + f.lastReq = req + if f.handler != nil { + return f.handler(req) + } + if f.err != nil { + return nil, f.err + } + return f.result, nil +} + +func TestExecutor_ListTools(t *testing.T) { + runner := &fakeExecRunner{result: &mcpgw.ExecWithCaptureResult{}} + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "test-bot"} + tools, err := exec.ListTools(ctx, session) + if err != nil { + t.Fatal(err) + } + want := map[string]bool{"read": true, "write": true, "list": true, "edit": true, "exec": true} + if len(tools) != len(want) { + t.Errorf("got %d tools, want %d", len(tools), len(want)) + } + for _, tool := range tools { + if !want[tool.Name] { + t.Errorf("unexpected tool %q", tool.Name) + } + } +} + +func TestExecutor_CallTool_Read(t *testing.T) { + runner := &fakeExecRunner{ + result: &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0}, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "read", map[string]any{"path": "test.txt"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content["content"] != "hello world" { + t.Errorf("content = %v", content["content"]) + } + // Verify the exec command contains cat. + cmd := strings.Join(runner.lastReq.Command, " ") + if !strings.Contains(cmd, "cat") { + t.Errorf("expected cat command, got %q", cmd) + } +} + +func TestExecutor_CallTool_Write(t *testing.T) { + runner := &fakeExecRunner{ + handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { + cmd := strings.Join(req.Command, " ") + if !strings.Contains(cmd, "base64 -d") { + return nil, fmt.Errorf("expected base64 write, got %q", cmd) + } + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "write", map[string]any{ + "path": "hello.txt", "content": "world", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } +} + +func TestExecutor_CallTool_List(t *testing.T) { + runner := &fakeExecRunner{ + result: &mcpgw.ExecWithCaptureResult{ + Stdout: "./test.txt|regular file|42|644|1700000000\n./subdir|directory|4096|755|1700000000\n", + ExitCode: 0, + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "list", map[string]any{"path": "."}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + entries, ok := content["entries"].([]map[string]any) + if !ok { + t.Fatalf("entries type = %T", content["entries"]) + } + if len(entries) != 2 { + t.Fatalf("got %d entries, want 2", len(entries)) + } +} + +func TestExecutor_CallTool_Edit(t *testing.T) { + callCount := 0 + runner := &fakeExecRunner{ + handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { + callCount++ + cmd := strings.Join(req.Command, " ") + if strings.Contains(cmd, "cat") { + // Read step: return original content. + return &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0}, nil + } + if strings.Contains(cmd, "base64 -d") { + // Write step: verify the written content contains the replacement. + // Extract base64 from: echo '' | base64 -d > 'path' + parts := strings.Split(cmd, "'") + for _, p := range parts { + decoded, err := base64.StdEncoding.DecodeString(p) + if err == nil && strings.Contains(string(decoded), "goodbye world") { + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + } + } + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + } + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "edit", map[string]any{ + "path": "test.txt", "old_text": "hello", "new_text": "goodbye", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + if callCount < 2 { + t.Errorf("expected at least 2 exec calls (read+write), got %d", callCount) + } +} + +func TestExecutor_CallTool_Exec(t *testing.T) { + runner := &fakeExecRunner{ + result: &mcpgw.ExecWithCaptureResult{ + Stdout: "hello\n", + Stderr: "", + ExitCode: 0, + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(ctx, session, toolExec, map[string]any{"command": "echo hello"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if content["stdout"] != "hello\n" { + t.Errorf("stdout = %v", content["stdout"]) + } + if content["exit_code"].(uint32) != 0 { + t.Errorf("exit_code = %v", content["exit_code"]) + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + runner := &fakeExecRunner{result: &mcpgw.ExecWithCaptureResult{}} + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{} + result, err := exec.CallTool(ctx, session, "read", map[string]any{"path": "x"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestNormalizePath(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"/data/test.txt", "test.txt"}, + {"/data/foo/bar.txt", "foo/bar.txt"}, + {"/data", "."}, + {"test.txt", "test.txt"}, + {"", ""}, + {".", "."}, + } + for _, tt := range tests { + got := normalizePath(tt.in) + if got != tt.want { + t.Errorf("normalizePath(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} diff --git a/internal/mcp/providers/directory/provider.go b/internal/mcp/providers/directory/provider.go new file mode 100644 index 00000000..92fa0753 --- /dev/null +++ b/internal/mcp/providers/directory/provider.go @@ -0,0 +1,162 @@ +package directory + +import ( + "context" + "log/slog" + "strings" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const toolLookupChannelUser = "lookup_channel_user" + +// ConfigResolver resolves effective channel config for a bot (used to call directory APIs). +type ConfigResolver interface { + ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) +} + +// ChannelTypeResolver parses platform name to channel type. +type ChannelTypeResolver interface { + ParseChannelType(raw string) (channel.ChannelType, error) +} + +// Executor exposes channel directory lookup as an MCP tool for the LLM. +type Executor struct { + registry *channel.Registry + configResolver ConfigResolver + typeResolver ChannelTypeResolver + logger *slog.Logger +} + +// NewExecutor creates a directory tool executor. +func NewExecutor(log *slog.Logger, registry *channel.Registry, configResolver ConfigResolver, typeResolver ChannelTypeResolver) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + registry: registry, + configResolver: configResolver, + typeResolver: typeResolver, + logger: log.With(slog.String("provider", "directory_tool")), + } +} + +// ListTools returns the lookup_channel_user tool descriptor. +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.registry == nil || p.configResolver == nil || p.typeResolver == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolLookupChannelUser, + Description: "Look up a user or group on a channel by platform identifier (e.g. open_id, user_id, chat_id). Returns display name, handle, and id.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "platform": map[string]any{ + "type": "string", + "description": "Channel platform (e.g. feishu, telegram). Defaults to current session platform.", + }, + "bot_id": map[string]any{ + "type": "string", + "description": "Bot ID. Defaults to current session bot.", + }, + "input": map[string]any{ + "type": "string", + "description": "Platform-specific identifier: user id (feishu open_id/user_id, telegram chat_id for private), or \"chat_id:user_id\" for a user in a group (telegram).", + }, + "kind": map[string]any{ + "type": "string", + "description": "Entry kind: \"user\" or \"group\". Default \"user\".", + "enum": []any{"user", "group"}, + }, + }, + "required": []string{"input"}, + }, + }, + }, nil +} + +// CallTool runs lookup_channel_user. +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if toolName != toolLookupChannelUser { + return nil, mcpgw.ErrToolNotFound + } + if p.registry == nil || p.configResolver == nil || p.typeResolver == nil { + return mcpgw.BuildToolErrorResult("directory lookup not available"), nil + } + + botID := mcpgw.FirstStringArg(arguments, "bot_id") + if botID == "" { + botID = strings.TrimSpace(session.BotID) + } + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + if strings.TrimSpace(session.BotID) != "" && botID != strings.TrimSpace(session.BotID) { + return mcpgw.BuildToolErrorResult("bot_id mismatch"), nil + } + + platform := mcpgw.FirstStringArg(arguments, "platform") + if platform == "" { + platform = strings.TrimSpace(session.CurrentPlatform) + } + if platform == "" { + return mcpgw.BuildToolErrorResult("platform is required"), nil + } + channelType, err := p.typeResolver.ParseChannelType(platform) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + dirAdapter, ok := p.registry.DirectoryAdapter(channelType) + if !ok || dirAdapter == nil { + return mcpgw.BuildToolErrorResult("channel does not support directory lookup"), nil + } + + input := strings.TrimSpace(mcpgw.FirstStringArg(arguments, "input")) + if input == "" { + return mcpgw.BuildToolErrorResult("input is required"), nil + } + + kindStr := strings.ToLower(strings.TrimSpace(mcpgw.FirstStringArg(arguments, "kind"))) + if kindStr == "" { + kindStr = "user" + } + var kind channel.DirectoryEntryKind + switch kindStr { + case "user": + kind = channel.DirectoryEntryUser + case "group": + kind = channel.DirectoryEntryGroup + default: + return mcpgw.BuildToolErrorResult("kind must be user or group"), nil + } + + cfg, err := p.configResolver.ResolveEffectiveConfig(ctx, botID, channelType) + if err != nil { + p.logger.Warn("resolve config failed", slog.String("bot_id", botID), slog.String("platform", platform), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + entry, err := dirAdapter.ResolveEntry(ctx, cfg, input, kind) + if err != nil { + p.logger.Warn("resolve entry failed", slog.String("input", input), slog.String("kind", kindStr), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + payload := map[string]any{ + "ok": true, + "platform": channelType.String(), + "kind": string(entry.Kind), + "id": entry.ID, + "name": entry.Name, + "handle": entry.Handle, + "metadata": entry.Metadata, + } + if entry.AvatarURL != "" { + payload["avatar_url"] = entry.AvatarURL + } + return mcpgw.BuildToolSuccessResult(payload), nil +} diff --git a/internal/mcp/providers/directory/provider_test.go b/internal/mcp/providers/directory/provider_test.go new file mode 100644 index 00000000..1f42969a --- /dev/null +++ b/internal/mcp/providers/directory/provider_test.go @@ -0,0 +1,72 @@ +package directory + +import ( + "context" + "testing" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +func TestExecutor_ListTools(t *testing.T) { + reg := channel.NewRegistry() + reg.MustRegister(&dirMockAdapter{channelType: "dir-test"}) + svc := &fakeConfigResolver{} + exec := NewExecutor(nil, reg, svc, reg) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatalf("ListTools: %v", err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != toolLookupChannelUser { + t.Errorf("tool name = %q", tools[0].Name) + } +} + +func TestExecutor_ListTools_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatalf("ListTools: %v", err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when deps nil, got %d", len(tools)) + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + exec := NewExecutor(nil, channel.NewRegistry(), &fakeConfigResolver{}, channel.NewRegistry()) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +type dirMockAdapter struct { + channelType channel.ChannelType +} + +func (d *dirMockAdapter) Type() channel.ChannelType { return d.channelType } +func (d *dirMockAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{Type: d.channelType, DisplayName: "DirTest"} +} +func (d *dirMockAdapter) ListPeers(context.Context, channel.ChannelConfig, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} +func (d *dirMockAdapter) ListGroups(context.Context, channel.ChannelConfig, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} +func (d *dirMockAdapter) ListGroupMembers(context.Context, channel.ChannelConfig, string, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} +func (d *dirMockAdapter) ResolveEntry(context.Context, channel.ChannelConfig, string, channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + return channel.DirectoryEntry{Kind: channel.DirectoryEntryUser, ID: "id1", Name: "Test User"}, nil +} + +type fakeConfigResolver struct{} + +func (f *fakeConfigResolver) ResolveEffectiveConfig(context.Context, string, channel.ChannelType) (channel.ChannelConfig, error) { + return channel.ChannelConfig{}, nil +} diff --git a/internal/mcp/providers/memory/provider.go b/internal/mcp/providers/memory/provider.go index 137e8b19..3ba0e975 100644 --- a/internal/mcp/providers/memory/provider.go +++ b/internal/mcp/providers/memory/provider.go @@ -6,7 +6,7 @@ import ( "sort" "strings" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/conversation" mcpgw "github.com/memohai/memoh/internal/mcp" mem "github.com/memohai/memoh/internal/memory" ) @@ -15,30 +15,25 @@ const ( toolSearchMemory = "search_memory" defaultMemoryToolLimit = 8 maxMemoryToolLimit = 50 + sharedMemoryNamespace = "bot" ) type MemorySearcher interface { Search(ctx context.Context, req mem.SearchRequest) (mem.SearchResponse, error) } -type ChatAccessor interface { - Get(ctx context.Context, chatID string) (chat.Chat, error) - GetSettings(ctx context.Context, chatID string) (chat.Settings, error) - IsParticipant(ctx context.Context, chatID, channelIdentityID string) (bool, error) -} - type AdminChecker interface { IsAdmin(ctx context.Context, channelIdentityID string) (bool, error) } type Executor struct { searcher MemorySearcher - chatAccessor ChatAccessor + chatAccessor conversation.Accessor adminChecker AdminChecker logger *slog.Logger } -func NewExecutor(log *slog.Logger, searcher MemorySearcher, chatAccessor ChatAccessor, adminChecker AdminChecker) *Executor { +func NewExecutor(log *slog.Logger, searcher MemorySearcher, chatAccessor conversation.Accessor, adminChecker AdminChecker) *Executor { if log == nil { log = slog.Default() } @@ -91,8 +86,11 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex botID := strings.TrimSpace(session.BotID) chatID := strings.TrimSpace(session.ChatID) channelIdentityID := strings.TrimSpace(session.ChannelIdentityID) - if botID == "" || chatID == "" { - return mcpgw.BuildToolErrorResult("bot_id and chat_id are required"), nil + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + if chatID == "" { + chatID = botID } limit := defaultMemoryToolLimit @@ -108,64 +106,43 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex limit = maxMemoryToolLimit } - chatObj, err := p.chatAccessor.Get(ctx, chatID) - if err != nil { - return mcpgw.BuildToolErrorResult("chat not found"), nil - } - if strings.TrimSpace(chatObj.BotID) != botID { - return mcpgw.BuildToolErrorResult("bot mismatch"), nil - } - if channelIdentityID != "" { - allowed, err := p.canAccessChat(ctx, chatID, channelIdentityID) + // When ChatID equals BotID (e.g. tools called without conversation context), search by bot scope only. + // Otherwise require the conversation to exist and the caller to be a participant. + if chatID != botID { + chatObj, err := p.chatAccessor.Get(ctx, chatID) if err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil + return mcpgw.BuildToolErrorResult("chat not found"), nil } - if !allowed { - return mcpgw.BuildToolErrorResult("not a chat participant"), nil + if strings.TrimSpace(chatObj.BotID) != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + if channelIdentityID != "" { + allowed, err := p.canAccessChat(ctx, chatID, channelIdentityID) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if !allowed { + return mcpgw.BuildToolErrorResult("not a chat participant"), nil + } } } - settings, err := p.chatAccessor.GetSettings(ctx, chatID) + resp, err := p.searcher.Search(ctx, mem.SearchRequest{ + Query: query, + BotID: botID, + Limit: limit, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": botID, + "botId": botID, + }, + }) if err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - - type memoryScope struct { - namespace string - scopeID string - } - scopes := make([]memoryScope, 0, 3) - if settings.EnableChatMemory { - scopes = append(scopes, memoryScope{namespace: "chat", scopeID: chatID}) - } - if settings.EnablePrivateMemory && channelIdentityID != "" { - scopes = append(scopes, memoryScope{namespace: "private", scopeID: channelIdentityID}) - } - if settings.EnablePublicMemory { - scopes = append(scopes, memoryScope{namespace: "public", scopeID: botID}) - } - if len(scopes) == 0 { - scopes = append(scopes, memoryScope{namespace: "chat", scopeID: chatID}) - } - - allResults := make([]mem.MemoryItem, 0, len(scopes)*limit) - for _, scope := range scopes { - resp, err := p.searcher.Search(ctx, mem.SearchRequest{ - Query: query, - BotID: botID, - Limit: limit, - Filters: map[string]any{ - "namespace": scope.namespace, - "scopeId": scope.scopeID, - "botId": botID, - }, - }) - if err != nil { - p.logger.Warn("memory search namespace failed", slog.String("namespace", scope.namespace), slog.Any("error", err)) - continue - } - allResults = append(allResults, resp.Results...) + p.logger.Warn("memory search namespace failed", slog.String("namespace", sharedMemoryNamespace), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult("memory search failed"), nil } + allResults := make([]mem.MemoryItem, 0, len(resp.Results)) + allResults = append(allResults, resp.Results...) allResults = deduplicateMemoryItems(allResults) sort.Slice(allResults, func(i, j int) bool { diff --git a/internal/mcp/providers/memory/provider_test.go b/internal/mcp/providers/memory/provider_test.go new file mode 100644 index 00000000..edc0d22e --- /dev/null +++ b/internal/mcp/providers/memory/provider_test.go @@ -0,0 +1,284 @@ +package memory + +import ( + "context" + "errors" + "testing" + + "github.com/memohai/memoh/internal/conversation" + mcpgw "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/memory" +) + +type fakeSearcher struct { + resp memory.SearchResponse + err error +} + +func (f *fakeSearcher) Search(ctx context.Context, req memory.SearchRequest) (memory.SearchResponse, error) { + if f.err != nil { + return memory.SearchResponse{}, f.err + } + return f.resp, nil +} + +type fakeChatAccessor struct { + chat conversation.Chat + getErr error + participant bool + participantErr error +} + +func (f *fakeChatAccessor) Get(ctx context.Context, conversationID string) (conversation.Chat, error) { + if f.getErr != nil { + return conversation.Chat{}, f.getErr + } + return f.chat, nil +} + +func (f *fakeChatAccessor) IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) { + if f.participantErr != nil { + return false, f.participantErr + } + return f.participant, nil +} + +func (f *fakeChatAccessor) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (conversation.ChatReadAccess, error) { + return conversation.ChatReadAccess{}, nil +} + +type fakeAdminChecker struct { + admin bool + err error +} + +func (f *fakeAdminChecker) IsAdmin(ctx context.Context, channelIdentityID string) (bool, error) { + if f.err != nil { + return false, f.err + } + return f.admin, nil +} + +func TestExecutor_ListTools_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when deps nil, got %d", len(tools)) + } +} + +func TestExecutor_ListTools(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != toolSearchMemory { + t.Errorf("tool name = %q, want %q", tools[0].Name, toolSearchMemory) + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +func TestExecutor_CallTool_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolSearchMemory, map[string]any{"query": "x"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error result when deps nil") + } +} + +func TestExecutor_CallTool_NoQuery(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolSearchMemory, map[string]any{}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when query is empty") + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestExecutor_CallTool_Success_BotScope(t *testing.T) { + searcher := &fakeSearcher{ + resp: memory.SearchResponse{ + Results: []memory.MemoryItem{ + {ID: "id1", Memory: "mem1", Score: 0.9}, + }, + }, + } + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "bot1"} + result, err := exec.CallTool(ctx, session, toolSearchMemory, map[string]any{"query": "test"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if content["query"] != "test" { + t.Errorf("query = %v", content["query"]) + } + if content["total"] != 1 { + t.Errorf("total = %v", content["total"]) + } +} + +func TestExecutor_CallTool_ChatNotFound(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{getErr: errors.New("not found")} + exec := NewExecutor(nil, searcher, accessor, nil) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "chat-other"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when chat not found") + } +} + +func TestExecutor_CallTool_BotMismatch(t *testing.T) { + accessor := &fakeChatAccessor{ + chat: conversation.Chat{BotID: "other-bot", ID: "c1"}, + } + searcher := &fakeSearcher{} + exec := NewExecutor(nil, searcher, accessor, nil) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "c1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot mismatch") + } +} + +func TestExecutor_CallTool_NotParticipant(t *testing.T) { + accessor := &fakeChatAccessor{ + chat: conversation.Chat{BotID: "bot1", ID: "c1"}, + participant: false, + } + searcher := &fakeSearcher{} + exec := NewExecutor(nil, searcher, accessor, &fakeAdminChecker{admin: false}) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "c1", ChannelIdentityID: "user1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when not participant") + } +} + +func TestExecutor_CallTool_AdminBypass(t *testing.T) { + searcher := &fakeSearcher{ + resp: memory.SearchResponse{Results: []memory.MemoryItem{{ID: "id1", Memory: "m1", Score: 0.8}}}, + } + accessor := &fakeChatAccessor{ + chat: conversation.Chat{BotID: "bot1", ID: "c1"}, + participant: false, + } + admin := &fakeAdminChecker{admin: true} + exec := NewExecutor(nil, searcher, accessor, admin) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "c1", ChannelIdentityID: "admin1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if v, ok := content["total"].(int); !ok || v != 1 { + t.Errorf("total = %v", content["total"]) + } +} + +func TestExecutor_CallTool_SearchError(t *testing.T) { + searcher := &fakeSearcher{err: errors.New("search failed")} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when search fails") + } +} + +func TestDeduplicateMemoryItems(t *testing.T) { + tests := []struct { + name string + items []memory.MemoryItem + wantLen int + }{ + {"empty", nil, 0}, + {"single", []memory.MemoryItem{{ID: "a", Memory: "m", Score: 1}}, 1}, + {"dedup by id", []memory.MemoryItem{ + {ID: "a", Memory: "m1", Score: 1}, + {ID: "a", Memory: "m2", Score: 0.9}, + }, 1}, + {"dedup by memory when id empty", []memory.MemoryItem{ + {ID: "", Memory: "same", Score: 1}, + {ID: "", Memory: "same", Score: 0.9}, + }, 1}, + {"no dedup", []memory.MemoryItem{ + {ID: "a", Memory: "m1", Score: 1}, + {ID: "b", Memory: "m2", Score: 0.9}, + }, 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := deduplicateMemoryItems(tt.items) + if len(got) != tt.wantLen { + t.Errorf("deduplicateMemoryItems() length = %d, want %d", len(got), tt.wantLen) + } + }) + } +} diff --git a/internal/mcp/providers/message/provider.go b/internal/mcp/providers/message/provider.go index ad964a6c..81911da2 100644 --- a/internal/mcp/providers/message/provider.go +++ b/internal/mcp/providers/message/provider.go @@ -2,6 +2,8 @@ package message import ( "context" + "encoding/json" + "fmt" "log/slog" "strings" @@ -67,12 +69,16 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte "type": "string", "description": "Alias for channel_identity_id", }, - "message": map[string]any{ + "text": map[string]any{ "type": "string", - "description": "Message text content", + "description": "Message text shortcut when message object is omitted", + }, + "message": map[string]any{ + "type": "object", + "description": "Structured message payload with text/parts/attachments", }, }, - "required": []string{"message"}, + "required": []string{}, }, }, }, nil @@ -109,9 +115,10 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex return mcpgw.BuildToolErrorResult(err.Error()), nil } - messageText := mcpgw.FirstStringArg(arguments, "message") - if messageText == "" { - return mcpgw.BuildToolErrorResult("message is required"), nil + messageText := mcpgw.FirstStringArg(arguments, "text") + outboundMessage, parseErr := parseOutboundMessage(arguments, messageText) + if parseErr != nil { + return mcpgw.BuildToolErrorResult(parseErr.Error()), nil } target := mcpgw.FirstStringArg(arguments, "target") @@ -126,9 +133,7 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex sendReq := channel.SendRequest{ Target: target, ChannelIdentityID: channelIdentityID, - Message: channel.Message{ - Text: messageText, - }, + Message: outboundMessage, } if err := p.sender.Send(ctx, botID, channelType, sendReq); err != nil { p.logger.Warn("send message failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("platform", platform)) @@ -145,3 +150,30 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex } return mcpgw.BuildToolSuccessResult(payload), nil } + +func parseOutboundMessage(arguments map[string]any, fallbackText string) (channel.Message, error) { + var msg channel.Message + if raw, ok := arguments["message"]; ok && raw != nil { + switch value := raw.(type) { + case string: + msg.Text = strings.TrimSpace(value) + case map[string]any: + data, err := json.Marshal(value) + if err != nil { + return channel.Message{}, err + } + if err := json.Unmarshal(data, &msg); err != nil { + return channel.Message{}, err + } + default: + return channel.Message{}, fmt.Errorf("message must be object or string") + } + } + if msg.IsEmpty() && strings.TrimSpace(fallbackText) != "" { + msg.Text = strings.TrimSpace(fallbackText) + } + if msg.IsEmpty() { + return channel.Message{}, fmt.Errorf("message is required") + } + return msg, nil +} diff --git a/internal/mcp/providers/message/provider_test.go b/internal/mcp/providers/message/provider_test.go new file mode 100644 index 00000000..df6b9fa3 --- /dev/null +++ b/internal/mcp/providers/message/provider_test.go @@ -0,0 +1,247 @@ +package message + +import ( + "context" + "errors" + "testing" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +type fakeSender struct { + err error +} + +func (f *fakeSender) Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error { + return f.err +} + +type fakeResolver struct { + ct channel.ChannelType + err error +} + +func (f *fakeResolver) ParseChannelType(raw string) (channel.ChannelType, error) { + if f.err != nil { + return "", f.err + } + return f.ct, nil +} + +func TestExecutor_ListTools_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when deps nil, got %d", len(tools)) + } +} + +func TestExecutor_ListTools(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != toolSendMessage { + t.Errorf("tool name = %q, want %q", tools[0].Name, toolSendMessage) + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +func TestExecutor_CallTool_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolSendMessage, map[string]any{ + "platform": "feishu", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error result when deps nil") + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, toolSendMessage, map[string]any{ + "platform": "feishu", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestExecutor_CallTool_BotIDMismatch(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "bot_id": "other", "platform": "feishu", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id mismatch") + } +} + +func TestExecutor_CallTool_NoPlatform(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when platform is missing") + } +} + +func TestExecutor_CallTool_PlatformParseError(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{err: errors.New("unknown platform")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "feishu"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "bad", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when platform parse fails") + } +} + +func TestExecutor_CallTool_NoMessage(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "feishu", "target": "t1", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when message/text is missing") + } +} + +func TestExecutor_CallTool_NoTarget(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "feishu", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when target and channel_identity_id are missing") + } +} + +func TestExecutor_CallTool_SendError(t *testing.T) { + sender := &fakeSender{err: errors.New("send failed")} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1", ReplyTarget: "t1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "feishu", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when Send fails") + } +} + +func TestExecutor_CallTool_Success(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "feishu", ReplyTarget: "chat1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "text": "hello", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if content["ok"] != true { + t.Errorf("ok = %v", content["ok"]) + } + if content["platform"] != "feishu" { + t.Errorf("platform = %v", content["platform"]) + } +} + +func TestParseOutboundMessage(t *testing.T) { + tests := []struct { + name string + args map[string]any + fallback string + wantEmpty bool + wantErr bool + }{ + {"text fallback", map[string]any{}, "hello", false, false}, + {"message string", map[string]any{"message": "msg"}, "", false, false}, + {"message object", map[string]any{"message": map[string]any{"text": "obj"}}, "", false, false}, + {"empty", map[string]any{}, "", true, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg, err := parseOutboundMessage(tt.args, tt.fallback) + if (err != nil) != tt.wantErr { + t.Errorf("parseOutboundMessage() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantEmpty && !msg.IsEmpty() { + t.Error("expected empty message") + } + if !tt.wantEmpty && msg.IsEmpty() { + t.Error("expected non-empty message") + } + }) + } +} diff --git a/internal/mcp/providers/schedule/provider_test.go b/internal/mcp/providers/schedule/provider_test.go new file mode 100644 index 00000000..43d7d544 --- /dev/null +++ b/internal/mcp/providers/schedule/provider_test.go @@ -0,0 +1,374 @@ +package schedule + +import ( + "context" + "errors" + "testing" + "time" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sched "github.com/memohai/memoh/internal/schedule" +) + +type fakeScheduler struct { + list []sched.Schedule + get sched.Schedule + getErr error + create sched.Schedule + createErr error + update sched.Schedule + updateErr error + deleteErr error +} + +func (f *fakeScheduler) List(ctx context.Context, botID string) ([]sched.Schedule, error) { + return f.list, nil +} + +func (f *fakeScheduler) Get(ctx context.Context, id string) (sched.Schedule, error) { + if f.getErr != nil { + return sched.Schedule{}, f.getErr + } + return f.get, nil +} + +func (f *fakeScheduler) Create(ctx context.Context, botID string, req sched.CreateRequest) (sched.Schedule, error) { + if f.createErr != nil { + return sched.Schedule{}, f.createErr + } + return f.create, nil +} + +func (f *fakeScheduler) Update(ctx context.Context, id string, req sched.UpdateRequest) (sched.Schedule, error) { + if f.updateErr != nil { + return sched.Schedule{}, f.updateErr + } + return f.update, nil +} + +func (f *fakeScheduler) Delete(ctx context.Context, id string) error { + return f.deleteErr +} + +func TestExecutor_ListTools_NilService(t *testing.T) { + exec := NewExecutor(nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when service nil, got %d", len(tools)) + } +} + +func TestExecutor_ListTools(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + wantNames := []string{toolScheduleList, toolScheduleGet, toolScheduleCreate, toolScheduleUpdate, toolScheduleDelete} + if len(tools) != len(wantNames) { + t.Fatalf("expected %d tools, got %d", len(wantNames), len(tools)) + } + for i, name := range wantNames { + if tools[i].Name != name { + t.Errorf("tools[%d].Name = %q, want %q", i, tools[i].Name, name) + } + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +func TestExecutor_CallTool_NilService(t *testing.T) { + exec := NewExecutor(nil, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolScheduleList, nil) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when service nil") + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, toolScheduleList, nil) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestExecutor_CallTool_List(t *testing.T) { + svc := &fakeScheduler{ + list: []sched.Schedule{ + {ID: "id1", Name: "n1", BotID: "bot1"}, + }, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleList, nil) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + items, _ := content["items"].([]sched.Schedule) + if len(items) != 1 { + t.Errorf("items length = %d", len(items)) + } +} + +func TestExecutor_CallTool_Get_IdRequired(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when id is missing") + } +} + +func TestExecutor_CallTool_Get_BotMismatch(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", BotID: "other-bot"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot mismatch") + } +} + +func TestExecutor_CallTool_Get_Success(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", Name: "job1", BotID: "bot1"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + item, ok := result["structuredContent"].(sched.Schedule) + if !ok { + t.Fatal("structuredContent is not Schedule") + } + if item.ID != "s1" { + t.Errorf("id = %v", item.ID) + } +} + +func TestExecutor_CallTool_Create_RequiredFields(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleCreate, map[string]any{ + "name": "n", "description": "d", "pattern": "* * * * *", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when command is missing") + } +} + +func TestExecutor_CallTool_Create_Success(t *testing.T) { + svc := &fakeScheduler{ + create: sched.Schedule{ + ID: "new1", Name: "n1", Description: "d1", Pattern: "* * * * *", Command: "echo", + BotID: "bot1", CreatedAt: time.Now(), UpdatedAt: time.Now(), + }, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleCreate, map[string]any{ + "name": "n1", "description": "d1", "pattern": "* * * * *", "command": "echo", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + item, ok := result["structuredContent"].(sched.Schedule) + if !ok { + t.Fatal("structuredContent is not Schedule") + } + if item.ID != "new1" { + t.Errorf("id = %v", item.ID) + } +} + +func TestExecutor_CallTool_Update_IdRequired(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleUpdate, map[string]any{"name": "n"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when id is missing") + } +} + +func TestExecutor_CallTool_Update_Success(t *testing.T) { + svc := &fakeScheduler{ + update: sched.Schedule{ID: "s1", Name: "updated", BotID: "bot1"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleUpdate, map[string]any{ + "id": "s1", "name": "updated", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } +} + +func TestExecutor_CallTool_Delete_IdRequired(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleDelete, map[string]any{}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when id is missing") + } +} + +func TestExecutor_CallTool_Delete_BotMismatch(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", BotID: "other-bot"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleDelete, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot mismatch on delete") + } +} + +func TestExecutor_CallTool_Delete_Success(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", BotID: "bot1"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleDelete, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if success, _ := content["success"].(bool); !success { + t.Errorf("success = %v", content["success"]) + } +} + +func TestExecutor_CallTool_Get_ServiceError(t *testing.T) { + svc := &fakeScheduler{getErr: errors.New("not found")} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{"id": "missing"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when Get fails") + } +} + +func TestParseNullableIntArg(t *testing.T) { + tests := []struct { + name string + args map[string]any + key string + wantSet bool + wantVal *int + wantErr bool + }{ + {"nil args", nil, "x", false, nil, false}, + {"missing key", map[string]any{}, "x", false, nil, false}, + {"null value", map[string]any{"x": nil}, "x", true, nil, false}, + {"int value", map[string]any{"x": 5}, "x", true, intPtr(5), false}, + {"invalid type", map[string]any{"x": "bad"}, "x", false, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseNullableIntArg(tt.args, tt.key) + if (err != nil) != tt.wantErr { + t.Errorf("parseNullableIntArg() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got.Set != tt.wantSet { + t.Errorf("Set = %v, want %v", got.Set, tt.wantSet) + } + if tt.wantVal == nil { + if got.Value != nil { + t.Errorf("Value = %v, want nil", got.Value) + } + } else { + if got.Value == nil || *got.Value != *tt.wantVal { + t.Errorf("Value = %v, want %v", got.Value, tt.wantVal) + } + } + }) + } +} + +func TestEmptyObjectSchema(t *testing.T) { + m := emptyObjectSchema() + if m["type"] != "object" { + t.Errorf("type = %v", m["type"]) + } + if m["properties"] == nil { + t.Error("properties should be non-nil") + } +} + +func intPtr(n int) *int { + return &n +} diff --git a/internal/mcp/sources/federation/source.go b/internal/mcp/sources/federation/source.go index a444cd48..e7e34059 100644 --- a/internal/mcp/sources/federation/source.go +++ b/internal/mcp/sources/federation/source.go @@ -20,9 +20,6 @@ type ConnectionLister interface { } type Gateway interface { - ListFSMCPTools(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, error) - CallFSMCPTool(ctx context.Context, botID, toolName string, args map[string]any) (map[string]any, error) - ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) @@ -109,8 +106,6 @@ func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, err error ) switch route.sourceType { - case "fs": - payload, err = s.gateway.CallFSMCPTool(ctx, botID, route.originalName, arguments) case "http": payload, err = s.gateway.CallHTTPConnectionTool(ctx, route.connection, route.originalName, arguments) case "sse": @@ -161,18 +156,6 @@ func (s *Source) buildToolsAndRoutes(ctx context.Context, botID string) ([]mcpgw tools = append(tools, descriptor) } - fsTools, err := s.gateway.ListFSMCPTools(ctx, botID) - if err != nil { - s.logger.Warn("list fs mcp tools failed", slog.String("bot_id", botID), slog.Any("error", err)) - } else { - for _, tool := range fsTools { - addTool(tool, toolRoute{ - sourceType: "fs", - originalName: tool.Name, - }) - } - } - if s.connections != nil { items, err := s.connections.ListActiveByBot(ctx, botID) if err != nil { diff --git a/internal/mcp/sources/federation/source_test.go b/internal/mcp/sources/federation/source_test.go index ab4902f9..b591ef44 100644 --- a/internal/mcp/sources/federation/source_test.go +++ b/internal/mcp/sources/federation/source_test.go @@ -21,7 +21,6 @@ func (l *testConnectionLister) ListActiveByBot(ctx context.Context, botID string } type testGateway struct { - listFS []mcpgw.ToolDescriptor listHTTP []mcpgw.ToolDescriptor listSSE []mcpgw.ToolDescriptor listStdio []mcpgw.ToolDescriptor @@ -29,15 +28,6 @@ type testGateway struct { lastCallType string } -func (g *testGateway) ListFSMCPTools(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, error) { - return g.listFS, nil -} - -func (g *testGateway) CallFSMCPTool(ctx context.Context, botID, toolName string, args map[string]any) (map[string]any, error) { - g.lastCallType = "fs" - return map[string]any{"result": map[string]any{"ok": true, "route": "fs"}}, nil -} - func (g *testGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { return g.listHTTP, nil } diff --git a/internal/mcp/tool_types.go b/internal/mcp/tool_types.go index 89e59c78..9a556ec5 100644 --- a/internal/mcp/tool_types.go +++ b/internal/mcp/tool_types.go @@ -16,7 +16,6 @@ type ToolSessionContext struct { SessionToken string CurrentPlatform string ReplyTarget string - DisplayName string } // ToolDescriptor is the MCP tools/list item shape used by the gateway. diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go deleted file mode 100644 index 7c440871..00000000 --- a/internal/mcp/tools.go +++ /dev/null @@ -1,421 +0,0 @@ -package mcp - -import ( - "bytes" - "context" - "fmt" - "io/fs" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - "unicode" - - sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type FSReadInput struct { - Path string `json:"path" jsonschema:"relative file path"` -} - -type FSReadOutput struct { - Content string `json:"content" jsonschema:"file content"` -} - -type FSWriteInput struct { - Path string `json:"path" jsonschema:"relative file path"` - Content string `json:"content" jsonschema:"file content"` -} - -type FSWriteOutput struct { - OK bool `json:"ok" jsonschema:"write result"` -} - -type FSListInput struct { - Path string `json:"path" jsonschema:"relative directory path"` - Recursive bool `json:"recursive" jsonschema:"recursive listing"` -} - -type FSFileEntry struct { - Path string `json:"path" jsonschema:"relative entry path"` - IsDir bool `json:"is_dir" jsonschema:"is directory"` - Size int64 `json:"size" jsonschema:"entry size"` - Mode uint32 `json:"mode" jsonschema:"file mode"` - ModTime time.Time `json:"mod_time" jsonschema:"modification time"` -} - -type FSListOutput struct { - Path string `json:"path" jsonschema:"listed path"` - Entries []FSFileEntry `json:"entries" jsonschema:"entries"` -} - -type FSEditInput struct { - Path string `json:"path" jsonschema:"relative file path"` - OldText string `json:"old_text" jsonschema:"exact text to find"` - NewText string `json:"new_text" jsonschema:"replacement text"` -} - -type FSEditOutput struct { - OK bool `json:"ok" jsonschema:"apply result"` -} - -type ExecInput struct { - Command string `json:"command" jsonschema:"command to run"` - Args []string `json:"args" jsonschema:"command arguments"` -} - -type ExecOutput struct { - OK bool `json:"ok" jsonschema:"execution success"` - ExitCode int `json:"exit_code" jsonschema:"process exit code"` - Stdout string `json:"stdout" jsonschema:"standard output"` - Stderr string `json:"stderr" jsonschema:"standard error"` -} - -func RegisterTools(server *sdkmcp.Server) { - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "read", Description: "read file content"}, fsReadTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "write", Description: "write file content"}, fsWriteTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "list", Description: "list directory entries"}, fsListTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "edit", Description: "replace exact text in a file"}, fsEditTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "exec", Description: "execute command"}, execTool) -} - -func fsReadTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSReadInput) ( - *sdkmcp.CallToolResult, - FSReadOutput, - error, -) { - root := dataRoot() - target, err := resolvePath(root, input.Path) - if err != nil { - return nil, FSReadOutput{}, err - } - data, err := os.ReadFile(target) - if err != nil { - return nil, FSReadOutput{}, err - } - return nil, FSReadOutput{Content: string(data)}, nil -} - -func fsWriteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSWriteInput) ( - *sdkmcp.CallToolResult, - FSWriteOutput, - error, -) { - root := dataRoot() - target, err := resolvePath(root, input.Path) - if err != nil { - return nil, FSWriteOutput{}, err - } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { - return nil, FSWriteOutput{}, err - } - if err := os.WriteFile(target, []byte(input.Content), 0o644); err != nil { - return nil, FSWriteOutput{}, err - } - return nil, FSWriteOutput{OK: true}, nil -} - -func fsListTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSListInput) ( - *sdkmcp.CallToolResult, - FSListOutput, - error, -) { - root := dataRoot() - target, err := resolvePathAllowRoot(root, input.Path) - if err != nil { - return nil, FSListOutput{}, err - } - info, err := os.Stat(target) - if err != nil { - return nil, FSListOutput{}, err - } - if !info.IsDir() { - return nil, FSListOutput{}, fmt.Errorf("path is not a directory") - } - - entries := []FSFileEntry{} - if input.Recursive { - err = filepath.WalkDir(target, func(p string, d fs.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - if p == target { - return nil - } - entryInfo, err := d.Info() - if err != nil { - return err - } - entry, err := entryForPath(root, p, entryInfo) - if err != nil { - return err - } - entries = append(entries, entry) - return nil - }) - } else { - dirEntries, err := os.ReadDir(target) - if err != nil { - return nil, FSListOutput{}, err - } - for _, entry := range dirEntries { - entryInfo, err := entry.Info() - if err != nil { - return nil, FSListOutput{}, err - } - fullPath := filepath.Join(target, entry.Name()) - fileEntry, err := entryForPath(root, fullPath, entryInfo) - if err != nil { - return nil, FSListOutput{}, err - } - entries = append(entries, fileEntry) - } - } - if err != nil { - return nil, FSListOutput{}, err - } - - listedPath := strings.TrimSpace(input.Path) - if listedPath == "" { - listedPath = "." - } - return nil, FSListOutput{Path: listedPath, Entries: entries}, nil -} - -func fsEditTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSEditInput) ( - *sdkmcp.CallToolResult, - FSEditOutput, - error, -) { - root := dataRoot() - target, err := resolvePath(root, input.Path) - if err != nil { - return nil, FSEditOutput{}, err - } - orig, err := os.ReadFile(target) - if err != nil { - return nil, FSEditOutput{}, err - } - raw := string(orig) - bom, content := stripBOM(raw) - originalEnding := detectLineEnding(content) - normalizedContent := normalizeToLF(content) - normalizedOld := normalizeToLF(input.OldText) - normalizedNew := normalizeToLF(input.NewText) - - match := fuzzyFindText(normalizedContent, normalizedOld) - if !match.Found { - return nil, FSEditOutput{}, fmt.Errorf( - "could not find the exact text in %s. the old text must match exactly including all whitespace and newlines", - input.Path, - ) - } - - fuzzyContent := normalizeForFuzzyMatch(normalizedContent) - fuzzyOld := normalizeForFuzzyMatch(normalizedOld) - occurrences := strings.Count(fuzzyContent, fuzzyOld) - if occurrences > 1 { - return nil, FSEditOutput{}, fmt.Errorf( - "found %d occurrences of the text in %s. the text must be unique. please provide more context to make it unique", - occurrences, - input.Path, - ) - } - - baseContent := match.ContentForReplacement - updated := baseContent[:match.Index] + normalizedNew + baseContent[match.Index+match.MatchLength:] - if baseContent == updated { - return nil, FSEditOutput{}, fmt.Errorf( - "no changes made to %s. the replacement produced identical content. this might indicate an issue with special characters or the text not existing as expected", - input.Path, - ) - } - - finalContent := bom + restoreLineEndings(updated, originalEnding) - info, err := os.Stat(target) - if err != nil { - return nil, FSEditOutput{}, err - } - if err := os.WriteFile(target, []byte(finalContent), info.Mode().Perm()); err != nil { - return nil, FSEditOutput{}, err - } - return nil, FSEditOutput{OK: true}, nil -} - -func execTool(ctx context.Context, req *sdkmcp.CallToolRequest, input ExecInput) ( - *sdkmcp.CallToolResult, - ExecOutput, - error, -) { - if strings.TrimSpace(input.Command) == "" { - return nil, ExecOutput{}, fmt.Errorf("command is required") - } - cmd := exec.CommandContext(ctx, input.Command, input.Args...) - var stdout bytes.Buffer - var stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - return nil, ExecOutput{ - OK: false, - ExitCode: exitErr.ExitCode(), - Stdout: stdout.String(), - Stderr: stderr.String(), - }, nil - } - return nil, ExecOutput{}, err - } - - return nil, ExecOutput{ - OK: true, - ExitCode: 0, - Stdout: stdout.String(), - Stderr: stderr.String(), - }, nil -} - -func dataRoot() string { - root := strings.TrimSpace(os.Getenv("MCP_DATA_DIR")) - if root == "" { - root = "/data" - } - return root -} - -func resolvePathAllowRoot(root, requestPath string) (string, error) { - if strings.TrimSpace(requestPath) == "" { - return root, nil - } - return resolvePath(root, requestPath) -} - -func resolvePath(root, requestPath string) (string, error) { - clean := filepath.Clean(requestPath) - if clean == "." || clean == "" { - return "", os.ErrInvalid - } - if filepath.IsAbs(clean) || strings.HasPrefix(clean, "..") { - return "", os.ErrInvalid - } - return filepath.Join(root, clean), nil -} - -func entryForPath(root, target string, info os.FileInfo) (FSFileEntry, error) { - rel, err := filepath.Rel(root, target) - if err != nil { - return FSFileEntry{}, err - } - if strings.HasPrefix(rel, "..") { - return FSFileEntry{}, os.ErrInvalid - } - if rel == "." { - rel = "" - } - return FSFileEntry{ - Path: filepath.ToSlash(rel), - IsDir: info.IsDir(), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }, nil -} - -type FuzzyMatchResult struct { - Found bool - Index int - MatchLength int - UsedFuzzyMatch bool - ContentForReplacement string -} - -func detectLineEnding(content string) string { - crlfIdx := strings.Index(content, "\r\n") - lfIdx := strings.Index(content, "\n") - if lfIdx == -1 { - return "\n" - } - if crlfIdx == -1 { - return "\n" - } - if crlfIdx < lfIdx { - return "\r\n" - } - return "\n" -} - -func normalizeToLF(text string) string { - text = strings.ReplaceAll(text, "\r\n", "\n") - return strings.ReplaceAll(text, "\r", "\n") -} - -func restoreLineEndings(text, ending string) string { - if ending == "\r\n" { - return strings.ReplaceAll(text, "\n", "\r\n") - } - return text -} - -func stripBOM(content string) (string, string) { - if strings.HasPrefix(content, "\uFEFF") { - return "\uFEFF", content[1:] - } - return "", content -} - -func normalizeForFuzzyMatch(text string) string { - lines := strings.Split(text, "\n") - for i, line := range lines { - lines[i] = strings.TrimRightFunc(line, unicode.IsSpace) - } - trimmed := strings.Join(lines, "\n") - return strings.Map(func(r rune) rune { - switch r { - case '\u2018', '\u2019', '\u201A', '\u201B': - return '\'' - case '\u201C', '\u201D', '\u201E', '\u201F': - return '"' - case '\u2010', '\u2011', '\u2012', '\u2013', '\u2014', '\u2015', '\u2212': - return '-' - case '\u00A0', '\u2002', '\u2003', '\u2004', '\u2005', '\u2006', '\u2007', '\u2008', '\u2009', '\u200A', '\u202F', '\u205F', '\u3000': - return ' ' - default: - return r - } - }, trimmed) -} - -func fuzzyFindText(content, oldText string) FuzzyMatchResult { - exactIndex := strings.Index(content, oldText) - if exactIndex != -1 { - return FuzzyMatchResult{ - Found: true, - Index: exactIndex, - MatchLength: len(oldText), - UsedFuzzyMatch: false, - ContentForReplacement: content, - } - } - - fuzzyContent := normalizeForFuzzyMatch(content) - fuzzyOld := normalizeForFuzzyMatch(oldText) - fuzzyIndex := strings.Index(fuzzyContent, fuzzyOld) - if fuzzyIndex == -1 { - return FuzzyMatchResult{ - Found: false, - Index: -1, - MatchLength: 0, - UsedFuzzyMatch: false, - ContentForReplacement: content, - } - } - return FuzzyMatchResult{ - Found: true, - Index: fuzzyIndex, - MatchLength: len(fuzzyOld), - UsedFuzzyMatch: true, - ContentForReplacement: fuzzyContent, - } -} diff --git a/internal/mcp/versioning.go b/internal/mcp/versioning.go index 0ceb603d..f9de5043 100644 --- a/internal/mcp/versioning.go +++ b/internal/mcp/versioning.go @@ -4,18 +4,17 @@ import ( "context" "encoding/json" "fmt" - "strings" "time" "github.com/containerd/containerd/v2/pkg/oci" "github.com/containerd/errdefs" - "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" "github.com/opencontainers/runtime-spec/specs-go" "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" + "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" ) @@ -88,12 +87,6 @@ func (m *Manager) CreateVersion(ctx context.Context, userID string) (*VersionInf Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -224,12 +217,6 @@ func (m *Manager) RollbackVersion(ctx context.Context, userID string, version in Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -291,7 +278,7 @@ func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, runti if err != nil { return pgtype.UUID{}, err } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return pgtype.UUID{}, err } @@ -383,13 +370,3 @@ func (m *Manager) insertEvent(ctx context.Context, containerID, eventType string }) } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/memory/indexer_test.go b/internal/memory/indexer_test.go index 2fc6a19c..e6fdfa09 100644 --- a/internal/memory/indexer_test.go +++ b/internal/memory/indexer_test.go @@ -62,18 +62,16 @@ func TestBM25Indexer_TermFrequencies(t *testing.T) { func TestBM25Indexer_BM25Logic(t *testing.T) { indexer := NewBM25Indexer(nil) - // 1. 添加一个包含 "golang" 的文档 lang := "en" tf1 := map[string]int{"golang": 1, "programming": 1} len1 := 2 indices1, values1 := indexer.AddDocument(lang, tf1, len1) - // 2. 添加另一个包含 "golang" 但更长的文档 tf2 := map[string]int{"golang": 1, "tutorial": 1, "advanced": 1, "topics": 1} len2 := 4 indices2, values2 := indexer.AddDocument(lang, tf2, len2) - // 验证:在 BM25 中,相同词项在短文档中的权重应该比在长文档中高(惩罚长文档) + // In BM25, same term in a shorter doc should have higher weight than in a longer doc. var weight1, weight2 float32 for i, idx := range indices1 { if idx == termHash("golang") { @@ -90,11 +88,10 @@ func TestBM25Indexer_BM25Logic(t *testing.T) { t.Errorf("Expected weight in shorter doc (%f) to be higher than in longer doc (%f)", weight1, weight2) } - // 3. 添加一个不包含 "golang" 的文档,增加文档总数,验证 IDF 变化 - // IDF 应该随着包含该词的文档比例减少而增加 + // Add a doc without "golang" to increase doc count; IDF should increase. oldWeight1 := weight1 indexer.AddDocument(lang, map[string]int{"rust": 1}, 1) - indices3, values3 := indexer.AddDocument(lang, tf1, len1) // 再次生成相同文档的向量 + indices3, values3 := indexer.AddDocument(lang, tf1, len1) for i, idx := range indices3 { if idx == termHash("golang") { @@ -112,7 +109,6 @@ func TestBM25Indexer_RemoveDocument(t *testing.T) { lang := "en" term := "test" - // 添加文档 tf, docLen, _ := indexer.TermFrequencies(lang, term) indexer.AddDocument(lang, tf, docLen) @@ -123,7 +119,6 @@ func TestBM25Indexer_RemoveDocument(t *testing.T) { } indexer.mu.RUnlock() - // 删除文档 indexer.RemoveDocument(lang, tf, docLen) indexer.mu.RLock() @@ -134,7 +129,7 @@ func TestBM25Indexer_RemoveDocument(t *testing.T) { } func TestTermHash_CollisionResistance(t *testing.T) { - // 验证不同词项生成的哈希索引在 20bit 空间内是否分布合理(简单检查不冲突) + // Check that different terms get distinct hashes in 20-bit space (no collision in small sample). h1 := termHash("apple") h2 := termHash("orange") h3 := termHash("banana") @@ -143,7 +138,6 @@ func TestTermHash_CollisionResistance(t *testing.T) { t.Errorf("Detected unexpected hash collision in small sample: %d, %d, %d", h1, h2, h3) } - // 验证掩码是否生效 if h1 > sparseDimMask { t.Errorf("Hash %d exceeds mask %d", h1, sparseDimMask) } diff --git a/internal/memory/qdrant_store.go b/internal/memory/qdrant_store.go index ce5a826e..2aada92c 100644 --- a/internal/memory/qdrant_store.go +++ b/internal/memory/qdrant_store.go @@ -506,7 +506,7 @@ func (s *QdrantStore) ensurePayloadIndexes(ctx context.Context) error { if s.client == nil { return nil } - fields := []string{"botId", "sessionId", "runId"} + fields := []string{"botId", "runId"} wait := true for _, field := range fields { _, err := s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{ diff --git a/internal/memory/service_test.go b/internal/memory/service_test.go index 92db7d63..bec57bfa 100644 --- a/internal/memory/service_test.go +++ b/internal/memory/service_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -// MockLLM 模拟 LLM 行为 +// MockLLM mocks LLM for tests. type MockLLM struct { ExtractFunc func(ctx context.Context, req ExtractRequest) (ExtractResponse, error) DecideFunc func(ctx context.Context, req DecideRequest) (DecideResponse, error) @@ -25,11 +25,9 @@ func (m *MockLLM) DetectLanguage(ctx context.Context, text string) (string, erro } func TestService_Add_FullFlow(t *testing.T) { - // 这是一个高质量的集成逻辑测试,验证 Service.Add 的完整决策流 ctx := context.Background() logger := slog.Default() - // 1. Mock LLM: 模拟从对话中提取事实,并决定添加新记忆 mockLLM := &MockLLM{ ExtractFunc: func(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { return ExtractResponse{Facts: []string{"User likes Go"}}, nil @@ -46,20 +44,7 @@ func TestService_Add_FullFlow(t *testing.T) { }, } - // 2. 初始化依赖 - // 注意:由于 QdrantStore 涉及网络,我们这里仅测试逻辑流。 - // 如果要跑通,需要一个 MockStore,但为了保持示例简洁且高质量, - // 我们重点展示如何组织 Service 的测试架构。 - - // 假设我们有一个内存版的 Store 或者 MockStore (此处略,实际项目中建议实现 MockStore) - // 这里演示逻辑链路的正确性 - t.Run("Decision Flow - ADD", func(t *testing.T) { - // 验证 Service 是否正确调用了 LLM 的 Extract 和 Decide - // 并且根据 Decide 的结果执行了相应的 Action - - // 提示:在实际代码中,Service.Add 会依次调用 Extract -> collectCandidates -> Decide -> applyAdd - // 我们可以通过在 Mock 中增加计数器来验证调用链路。 extractCalled := false decideCalled := false @@ -75,24 +60,17 @@ func TestService_Add_FullFlow(t *testing.T) { return DecideResponse{Actions: []DecisionAction{{Event: "ADD", Text: "Fact 1"}}}, nil } - // 由于 Service 结构体字段是私有的且依赖较多, - // 高质量的测试通常会配合接口或构造函数注入。 - // 这里我们验证核心逻辑:Decide 的 Action 映射 - s := &Service{ llm: mockLLM, logger: logger, bm25: NewBM25Indexer(nil), - // store: mockStore, // 实际测试中需要注入 MockStore } - // 模拟一个 Add 请求 req := AddRequest{ Message: "I love coding in Go", BotID: "bot-123", } - // 由于没有注入真实的 Store,这里会报错,但我们可以验证到报错前的逻辑 _, err := s.Add(ctx, req) if !extractCalled { @@ -102,26 +80,21 @@ func TestService_Add_FullFlow(t *testing.T) { t.Error("Expected LLM.Decide to be called") } - // 如果 err 是因为 store 为 nil 导致的,说明前面的 LLM 链路已经跑通 if err == nil || !reflectContains(err.Error(), "qdrant store") { - // 如果没报错或者报了别的错,说明逻辑有误 + // Expected either nil (if mock store added) or qdrant store error. } }) } func reflectContains(s, substr string) bool { - return fmt.Sprintf("%s", s) != "" // 简化逻辑 + return fmt.Sprintf("%s", s) != "" } func TestRankFusion_Logic(t *testing.T) { - // 测试 RRF (Reciprocal Rank Fusion) 逻辑 - // 验证不同来源的结果是否能被正确合并和排序 - p1 := qdrantPoint{ID: "1", Payload: map[string]any{"data": "result 1"}} p2 := qdrantPoint{ID: "2", Payload: map[string]any{"data": "result 2"}} - // 来源 A: 1 号排第一,2 号排第二 - // 来源 B: 2 号排第一,1 号排第二 + // Source A: 1 first, 2 second; Source B: 2 first, 1 second. pointsBySource := map[string][]qdrantPoint{ "source_a": {p1, p2}, "source_b": {p2, p1}, @@ -137,8 +110,7 @@ func TestRankFusion_Logic(t *testing.T) { t.Fatalf("Expected 2 results, got %d", len(results)) } - // 在这个对称的情况下,两者的 RRF 分数应该相同 if results[0].Score != results[1].Score { - // 理论上 1/(60+1) + 1/(60+2) + // Symmetric case: both get same RRF score (e.g. 1/(k+1)+1/(k+2) for k=60). } } diff --git a/internal/message/event/hub.go b/internal/message/event/hub.go new file mode 100644 index 00000000..eec37cb0 --- /dev/null +++ b/internal/message/event/hub.go @@ -0,0 +1,124 @@ +package event + +import ( + "encoding/json" + "strings" + "sync" + + "github.com/google/uuid" +) + +const ( + // DefaultBufferSize is the default per-subscriber channel buffer. + DefaultBufferSize = 64 +) + +// EventType identifies the event category published by the message event hub. +type EventType string + +const ( + // EventTypeMessageCreated is emitted after a message is persisted successfully. + EventTypeMessageCreated EventType = "message_created" +) + +// Event is the normalized payload emitted by the in-process message event hub. +type Event struct { + Type EventType `json:"type"` + BotID string `json:"bot_id"` + Data json.RawMessage `json:"data,omitempty"` +} + +// Publisher publishes events to subscribers. +type Publisher interface { + Publish(event Event) +} + +// Subscriber subscribes to bot-scoped events. +type Subscriber interface { + Subscribe(botID string, buffer int) (string, <-chan Event, func()) +} + +// Hub is an in-process pub/sub dispatcher for bot-scoped message events. +type Hub struct { + mu sync.RWMutex + streams map[string]map[string]chan Event +} + +// NewHub creates an empty message event hub. +func NewHub() *Hub { + return &Hub{ + streams: map[string]map[string]chan Event{}, + } +} + +// Publish broadcasts one event to all subscribers under the same bot ID. +// Slow subscribers are dropped in a non-blocking way. +func (h *Hub) Publish(event Event) { + if h == nil { + return + } + botID := strings.TrimSpace(event.BotID) + if botID == "" { + return + } + h.mu.RLock() + defer h.mu.RUnlock() + for _, ch := range h.streams[botID] { + select { + case ch <- event: + default: + // Drop if receiver is slow to avoid blocking persistence path. + } + } +} + +// Subscribe registers one subscriber under a bot ID. +// It returns a stream ID, read-only event channel, and a cancel function. +func (h *Hub) Subscribe(botID string, buffer int) (string, <-chan Event, func()) { + if h == nil { + ch := make(chan Event) + close(ch) + return "", ch, func() {} + } + botID = strings.TrimSpace(botID) + if botID == "" { + ch := make(chan Event) + close(ch) + return "", ch, func() {} + } + if buffer <= 0 { + buffer = DefaultBufferSize + } + + streamID := uuid.NewString() + ch := make(chan Event, buffer) + + h.mu.Lock() + streams, ok := h.streams[botID] + if !ok { + streams = map[string]chan Event{} + h.streams[botID] = streams + } + streams[streamID] = ch + h.mu.Unlock() + + var once sync.Once + cancel := func() { + once.Do(func() { + h.mu.Lock() + streams := h.streams[botID] + if streams != nil { + if current, ok := streams[streamID]; ok { + delete(streams, streamID) + close(current) + } + if len(streams) == 0 { + delete(h.streams, botID) + } + } + h.mu.Unlock() + }) + } + + return streamID, ch, cancel +} diff --git a/internal/message/event/hub_test.go b/internal/message/event/hub_test.go new file mode 100644 index 00000000..987c3861 --- /dev/null +++ b/internal/message/event/hub_test.go @@ -0,0 +1,59 @@ +package event + +import ( + "testing" + "time" +) + +func TestHubPublishScopedByBotID(t *testing.T) { + hub := NewHub() + _, botAStream, cancelA := hub.Subscribe("bot-a", 8) + defer cancelA() + _, botBStream, cancelB := hub.Subscribe("bot-b", 8) + defer cancelB() + + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + + select { + case <-botAStream: + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected event for bot-a subscriber") + } + + select { + case <-botBStream: + t.Fatalf("did not expect bot-b subscriber to receive bot-a event") + case <-time.After(120 * time.Millisecond): + } +} + +func TestHubCancelUnsubscribe(t *testing.T) { + hub := NewHub() + _, stream, cancel := hub.Subscribe("bot-a", 8) + cancel() + + select { + case _, ok := <-stream: + if ok { + t.Fatalf("expected stream to be closed after cancel") + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("timed out waiting for stream close") + } +} + +func TestHubSlowSubscriberDoesNotBlockPublish(t *testing.T) { + hub := NewHub() + _, stream, cancel := hub.Subscribe("bot-a", 1) + defer cancel() + + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + + select { + case <-stream: + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected at least one event in buffer") + } +} diff --git a/internal/message/service.go b/internal/message/service.go new file mode 100644 index 00000000..60051230 --- /dev/null +++ b/internal/message/service.go @@ -0,0 +1,358 @@ +package message + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgtype" + + dbpkg "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/message/event" +) + +// DBService persists and reads bot history messages. +type DBService struct { + queries *sqlc.Queries + logger *slog.Logger + publisher event.Publisher +} + +// NewService creates a message service. +func NewService(log *slog.Logger, queries *sqlc.Queries, publishers ...event.Publisher) *DBService { + if log == nil { + log = slog.Default() + } + var publisher event.Publisher + if len(publishers) > 0 { + publisher = publishers[0] + } + return &DBService{ + queries: queries, + logger: log.With(slog.String("service", "message")), + publisher: publisher, + } +} + +// Persist writes a single message to bot_history_messages. +func (s *DBService) Persist(ctx context.Context, input PersistInput) (Message, error) { + pgBotID, err := dbpkg.ParseUUID(input.BotID) + if err != nil { + return Message{}, fmt.Errorf("invalid bot id: %w", err) + } + + pgRouteID, err := parseOptionalUUID(input.RouteID) + if err != nil { + return Message{}, fmt.Errorf("invalid route id: %w", err) + } + pgSenderChannelIdentityID, err := parseOptionalUUID(input.SenderChannelIdentityID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender channel identity id: %w", err) + } + pgSenderUserID, err := parseOptionalUUID(input.SenderUserID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender user id: %w", err) + } + + metaBytes, err := json.Marshal(nonNilMap(input.Metadata)) + if err != nil { + return Message{}, fmt.Errorf("marshal message metadata: %w", err) + } + + content := input.Content + if len(content) == 0 { + content = []byte("{}") + } + + row, err := s.queries.CreateMessage(ctx, sqlc.CreateMessageParams{ + BotID: pgBotID, + RouteID: pgRouteID, + SenderChannelIdentityID: pgSenderChannelIdentityID, + SenderUserID: pgSenderUserID, + Platform: toPgText(input.Platform), + ExternalMessageID: toPgText(input.ExternalMessageID), + SourceReplyToMessageID: toPgText(input.SourceReplyToMessageID), + Role: input.Role, + Content: content, + Metadata: metaBytes, + }) + if err != nil { + return Message{}, err + } + + result := toMessageFromCreate(row) + s.publishMessageCreated(result) + return result, nil +} + +// List returns all messages for a bot. +func (s *DBService) List(ctx context.Context, botID string) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessages(ctx, pgBotID) + if err != nil { + return nil, err + } + return toMessagesFromList(rows), nil +} + +// ListSince returns bot messages since a given time. +func (s *DBService) ListSince(ctx context.Context, botID string, since time.Time) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesSince(ctx, sqlc.ListMessagesSinceParams{ + BotID: pgBotID, + CreatedAt: pgtype.Timestamptz{Time: since, Valid: true}, + }) + if err != nil { + return nil, err + } + return toMessagesFromSince(rows), nil +} + +// ListLatest returns the latest N bot messages (newest first in DB; caller may reverse for ASC). +func (s *DBService) ListLatest(ctx context.Context, botID string, limit int32) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesLatest(ctx, sqlc.ListMessagesLatestParams{ + BotID: pgBotID, + MaxCount: limit, + }) + if err != nil { + return nil, err + } + return toMessagesFromLatest(rows), nil +} + +// ListBefore returns up to limit messages older than before (created_at < before), ordered oldest-first. +func (s *DBService) ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesBefore(ctx, sqlc.ListMessagesBeforeParams{ + BotID: pgBotID, + CreatedAt: pgtype.Timestamptz{Time: before, Valid: true}, + MaxCount: limit, + }) + if err != nil { + return nil, err + } + return toMessagesFromBefore(rows), nil +} + +// DeleteByBot deletes all messages for a bot. +func (s *DBService) DeleteByBot(ctx context.Context, botID string) error { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return err + } + return s.queries.DeleteMessagesByBot(ctx, pgBotID) +} + +func toMessageFromCreate(row sqlc.CreateMessageRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromListRow(row sqlc.ListMessagesRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromSinceRow(row sqlc.ListMessagesSinceRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromLatestRow(row sqlc.ListMessagesLatestRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFields( + id pgtype.UUID, + botID pgtype.UUID, + routeID pgtype.UUID, + senderChannelIdentityID pgtype.UUID, + senderUserID pgtype.UUID, + platform pgtype.Text, + externalMessageID pgtype.Text, + sourceReplyToMessageID pgtype.Text, + role string, + content []byte, + metadata []byte, + createdAt pgtype.Timestamptz, +) Message { + return Message{ + ID: id.String(), + BotID: botID.String(), + RouteID: routeID.String(), + SenderChannelIdentityID: senderChannelIdentityID.String(), + SenderUserID: senderUserID.String(), + Platform: dbpkg.TextToString(platform), + ExternalMessageID: dbpkg.TextToString(externalMessageID), + SourceReplyToMessageID: dbpkg.TextToString(sourceReplyToMessageID), + Role: role, + Content: json.RawMessage(content), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + } +} + +func toMessagesFromList(rows []sqlc.ListMessagesRow) []Message { + messages := make([]Message, 0, len(rows)) + for _, row := range rows { + messages = append(messages, toMessageFromListRow(row)) + } + return messages +} + +func toMessagesFromSince(rows []sqlc.ListMessagesSinceRow) []Message { + messages := make([]Message, 0, len(rows)) + for _, row := range rows { + messages = append(messages, toMessageFromSinceRow(row)) + } + return messages +} + +func toMessagesFromLatest(rows []sqlc.ListMessagesLatestRow) []Message { + messages := make([]Message, 0, len(rows)) + for _, row := range rows { + messages = append(messages, toMessageFromLatestRow(row)) + } + return messages +} + +func toMessageFromBeforeRow(row sqlc.ListMessagesBeforeRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +// toMessagesFromBefore returns messages in oldest-first order (ListMessagesBefore returns DESC; we reverse). +func toMessagesFromBefore(rows []sqlc.ListMessagesBeforeRow) []Message { + messages := make([]Message, 0, len(rows)) + for i := len(rows) - 1; i >= 0; i-- { + messages = append(messages, toMessageFromBeforeRow(rows[i])) + } + return messages +} + +func parseOptionalUUID(id string) (pgtype.UUID, error) { + if strings.TrimSpace(id) == "" { + return pgtype.UUID{}, nil + } + return dbpkg.ParseUUID(id) +} + +func toPgText(value string) pgtype.Text { + value = strings.TrimSpace(value) + if value == "" { + return pgtype.Text{} + } + return pgtype.Text{String: value, Valid: true} +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} + +func (s *DBService) publishMessageCreated(message Message) { + if s.publisher == nil { + return + } + payload, err := json.Marshal(message) + if err != nil { + if s.logger != nil { + s.logger.Warn("marshal message event failed", slog.Any("error", err)) + } + return + } + s.publisher.Publish(event.Event{ + Type: event.EventTypeMessageCreated, + BotID: strings.TrimSpace(message.BotID), + Data: payload, + }) +} diff --git a/internal/message/types.go b/internal/message/types.go new file mode 100644 index 00000000..f5938474 --- /dev/null +++ b/internal/message/types.go @@ -0,0 +1,52 @@ +package message + +import ( + "context" + "encoding/json" + "time" +) + +// Message represents a single persisted bot message. +type Message struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + RouteID string `json:"route_id,omitempty"` + SenderChannelIdentityID string `json:"sender_channel_identity_id,omitempty"` + SenderUserID string `json:"sender_user_id,omitempty"` + Platform string `json:"platform,omitempty"` + ExternalMessageID string `json:"external_message_id,omitempty"` + SourceReplyToMessageID string `json:"source_reply_to_message_id,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// PersistInput is the input for persisting a message. +type PersistInput struct { + BotID string + RouteID string + SenderChannelIdentityID string + SenderUserID string + Platform string + ExternalMessageID string + SourceReplyToMessageID string + Role string + Content json.RawMessage + Metadata map[string]any +} + +// Writer defines write behavior needed by the inbound router. +type Writer interface { + Persist(ctx context.Context, input PersistInput) (Message, error) +} + +// Service defines message read/write behavior. +type Service interface { + Writer + List(ctx context.Context, botID string) ([]Message, error) + ListSince(ctx context.Context, botID string, since time.Time) ([]Message, error) + ListLatest(ctx context.Context, botID string, limit int32) ([]Message, error) + ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]Message, error) + DeleteByBot(ctx context.Context, botID string) error +} diff --git a/internal/models/models.go b/internal/models/models.go index 45993883..e648c8f0 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -33,7 +34,7 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro } // Convert to sqlc params - llmProviderID, err := parseUUID(model.LlmProviderID) + llmProviderID, err := db.ParseUUID(model.LlmProviderID) if err != nil { return AddResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) } @@ -78,7 +79,7 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro // GetByID retrieves a model by its internal UUID func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) { - uuid, err := parseUUID(id) + uuid, err := db.ParseUUID(id) if err != nil { return GetResponse{}, fmt.Errorf("invalid ID: %w", err) } @@ -148,7 +149,7 @@ func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]Ge if strings.TrimSpace(providerID) == "" { return nil, fmt.Errorf("provider id is required") } - uuid, err := parseUUID(providerID) + uuid, err := db.ParseUUID(providerID) if err != nil { return nil, fmt.Errorf("invalid provider id: %w", err) } @@ -167,7 +168,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string if strings.TrimSpace(providerID) == "" { return nil, fmt.Errorf("provider id is required") } - uuid, err := parseUUID(providerID) + uuid, err := db.ParseUUID(providerID) if err != nil { return nil, fmt.Errorf("invalid provider id: %w", err) } @@ -183,7 +184,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string // UpdateByID updates a model by its internal UUID func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { - uuid, err := parseUUID(id) + uuid, err := db.ParseUUID(id) if err != nil { return GetResponse{}, fmt.Errorf("invalid ID: %w", err) } @@ -199,7 +200,7 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) Type: string(model.Type), } - llmProviderID, err := parseUUID(model.LlmProviderID) + llmProviderID, err := db.ParseUUID(model.LlmProviderID) if err != nil { return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) } @@ -238,7 +239,7 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat Type: string(model.Type), } - llmProviderID, err := parseUUID(model.LlmProviderID) + llmProviderID, err := db.ParseUUID(model.LlmProviderID) if err != nil { return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) } @@ -262,7 +263,7 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat // DeleteByID deletes a model by its internal UUID func (s *Service) DeleteByID(ctx context.Context, id string) error { - uuid, err := parseUUID(id) + uuid, err := db.ParseUUID(id) if err != nil { return fmt.Errorf("invalid ID: %w", err) } @@ -311,19 +312,6 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, // Helper functions -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID format: %w", err) - } - - var pgUUID pgtype.UUID - copy(pgUUID.Bytes[:], parsed[:]) - pgUUID.Valid = true - - return pgUUID, nil -} - func convertToGetResponse(dbModel sqlc.Model) GetResponse { resp := GetResponse{ ModelId: dbModel.ModelID, @@ -335,8 +323,8 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { }, } - if llmProviderID, ok := uuidStringFromPgUUID(dbModel.LlmProviderID); ok { - resp.Model.LlmProviderID = llmProviderID + if dbModel.LlmProviderID.Valid { + resp.Model.LlmProviderID = dbModel.LlmProviderID.String() } if dbModel.Name.Valid { @@ -382,17 +370,6 @@ func isValidClientType(clientType ClientType) bool { } } -func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) { - if !value.Valid { - return "", false - } - id, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "", false - } - return id.String(), true -} - // SelectMemoryModel selects a chat model for memory operations. func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) { if modelsService == nil { @@ -415,7 +392,7 @@ func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID st if strings.TrimSpace(providerID) == "" { return sqlc.LlmProvider{}, fmt.Errorf("provider id missing") } - parsed, err := parseUUID(providerID) + parsed, err := db.ParseUUID(providerID) if err != nil { return sqlc.LlmProvider{}, err } diff --git a/internal/models/types.go b/internal/models/types.go index c8252bb4..c0ef4df2 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -32,13 +32,13 @@ const ( ) type Model struct { - ModelID string `json:"model_id" validate:"required"` - Name string `json:"name" validate:"required"` - LlmProviderID string `json:"llm_provider_id" validate:"required"` - IsMultimodal bool `json:"is_multimodal"` - Input []string `json:"input"` - Type ModelType `json:"type" validate:"required"` - Dimensions int `json:"dimensions"` + ModelID string `json:"model_id"` + Name string `json:"name"` + LlmProviderID string `json:"llm_provider_id"` + IsMultimodal bool `json:"is_multimodal"` + Input []string `json:"input"` + Type ModelType `json:"type"` + Dimensions int `json:"dimensions"` } func (m *Model) Validate() error { diff --git a/internal/preauth/service.go b/internal/preauth/service.go index e267521b..fe017eee 100644 --- a/internal/preauth/service.go +++ b/internal/preauth/service.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -32,13 +33,13 @@ func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl t if ttl <= 0 { ttl = 24 * time.Hour } - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return Key{}, err } pgIssuedBy := pgtype.UUID{Valid: false} if strings.TrimSpace(issuedByUserID) != "" { - parsed, err := parseUUID(issuedByUserID) + parsed, err := db.ParseUUID(issuedByUserID) if err != nil { return Key{}, err } @@ -76,7 +77,7 @@ func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { if s.queries == nil { return Key{}, fmt.Errorf("preauth queries not configured") } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Key{}, err } @@ -89,38 +90,16 @@ func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { func normalizeKey(row sqlc.BotPreauthKey) Key { return Key{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), + ID: row.ID.String(), + BotID: row.BotID.String(), Token: strings.TrimSpace(row.Token), - IssuedByChannelIdentityID: toUUIDString(row.IssuedByUserID), + IssuedByChannelIdentityID: row.IssuedByUserID.String(), ExpiresAt: timeFromPg(row.ExpiresAt), UsedAt: timeFromPg(row.UsedAt), CreatedAt: timeFromPg(row.CreatedAt), } } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} - func timeFromPg(value pgtype.Timestamptz) time.Time { if value.Valid { return value.Time diff --git a/internal/providers/service.go b/internal/providers/service.go index d2118f5f..dba23335 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -7,9 +7,7 @@ import ( "log/slog" "strings" - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgtype" - + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -57,7 +55,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e // Get retrieves a provider by ID func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { - providerID, err := parseUUID(id) + providerID, err := db.ParseUUID(id) if err != nil { return GetResponse{}, err } @@ -114,7 +112,7 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ( // Update updates an existing provider func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { - providerID, err := parseUUID(id) + providerID, err := db.ParseUUID(id) if err != nil { return GetResponse{}, err } @@ -176,7 +174,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get // Delete deletes a provider by ID func (s *Service) Delete(ctx context.Context, id string) error { - providerID, err := parseUUID(id) + providerID, err := db.ParseUUID(id) if err != nil { return err } @@ -219,13 +217,8 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { // Mask API key (show only first 8 characters) maskedAPIKey := maskAPIKey(provider.ApiKey) - // Convert pgtype.UUID to string - var id [16]byte - copy(id[:], provider.ID.Bytes[:]) - idUUID := uuid.UUID(id) - return GetResponse{ - ID: idUUID.String(), + ID: provider.ID.String(), Name: provider.Name, ClientType: provider.ClientType, BaseURL: provider.BaseUrl, @@ -236,18 +229,6 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { } } -// parseUUID parses a UUID string -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - // isValidClientType checks if a client type is valid func isValidClientType(clientType ClientType) bool { switch clientType { diff --git a/internal/providers/types.go b/internal/providers/types.go index e2556f64..4eec664e 100644 --- a/internal/providers/types.go +++ b/internal/providers/types.go @@ -33,14 +33,14 @@ type UpdateRequest struct { // GetResponse represents the response for getting a provider type GetResponse struct { - ID string `json:"id" validate:"required"` - Name string `json:"name" validate:"required"` - ClientType string `json:"client_type" validate:"required"` - BaseURL string `json:"base_url" validate:"required"` + ID string `json:"id"` + Name string `json:"name"` + ClientType string `json:"client_type"` + BaseURL string `json:"base_url"` APIKey string `json:"api_key,omitempty"` // masked in response Metadata map[string]any `json:"metadata,omitempty"` - CreatedAt time.Time `json:"created_at" validate:"required"` - UpdatedAt time.Time `json:"updated_at" validate:"required"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // ListResponse represents the response for listing providers diff --git a/internal/router/channel.go b/internal/router/channel.go index b5c23e40..cefc25d3 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -12,46 +12,46 @@ import ( "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/conversation/flow" + messagepkg "github.com/memohai/memoh/internal/message" ) -// ChatGateway abstracts the chat capability to avoid direct coupling in the router. -type ChatGateway interface { - Chat(ctx context.Context, req chat.ChatRequest) (chat.ChatResponse, error) -} - const ( - silentReplyToken = "NO_REPLY" - minDuplicateTextLength = 10 + silentReplyToken = "NO_REPLY" + minDuplicateTextLength = 10 + processingStatusTimeout = 60 * time.Second ) var ( whitespacePattern = regexp.MustCompile(`\s+`) ) -// ChatService resolves and manages chats. -type ChatService interface { - ResolveChat(ctx context.Context, botID, platform, conversationID, threadID, conversationType, userID, channelConfigID, replyTarget string) (chat.ResolveChatResult, error) - PersistMessage(ctx context.Context, chatID, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, role string, content json.RawMessage, metadata map[string]any) (chat.Message, error) +// RouteResolver resolves and manages channel routes. +type RouteResolver interface { + ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error) } // ChannelInboundProcessor routes channel inbound messages to the chat gateway. type ChannelInboundProcessor struct { - chat ChatGateway - chatService ChatService - registry *channel.Registry - logger *slog.Logger - jwtSecret string - tokenTTL time.Duration - identity *IdentityResolver + runner flow.Runner + routeResolver RouteResolver + message messagepkg.Writer + registry *channel.Registry + logger *slog.Logger + jwtSecret string + tokenTTL time.Duration + identity *IdentityResolver } // NewChannelInboundProcessor creates a processor with channel identity-based resolution. func NewChannelInboundProcessor( log *slog.Logger, registry *channel.Registry, - chatService ChatService, - chatGateway ChatGateway, + routeResolver RouteResolver, + messageWriter messagepkg.Writer, + runner flow.Runner, channelIdentityService ChannelIdentityService, memberService BotMemberService, policyService PolicyService, @@ -68,13 +68,14 @@ func NewChannelInboundProcessor( } identityResolver := NewIdentityResolver(log, registry, channelIdentityService, memberService, policyService, preauthService, bindService, "", "") return &ChannelInboundProcessor{ - chat: chatGateway, - chatService: chatService, - registry: registry, - logger: log.With(slog.String("component", "channel_router")), - jwtSecret: strings.TrimSpace(jwtSecret), - tokenTTL: tokenTTL, - identity: identityResolver, + runner: runner, + routeResolver: routeResolver, + message: messageWriter, + registry: registry, + logger: log.With(slog.String("component", "channel_router")), + jwtSecret: strings.TrimSpace(jwtSecret), + tokenTTL: tokenTTL, + identity: identityResolver, } } @@ -87,8 +88,8 @@ func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware { } // HandleInbound processes an inbound channel message through identity resolution and chat gateway. -func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.ReplySender) error { - if p.chat == nil { +func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error { + if p.runner == nil { return fmt.Errorf("channel inbound processor not configured") } if sender == nil { @@ -109,33 +110,66 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel Message: state.Decision.Reply, }) } + if p.logger != nil { + p.logger.Info( + "inbound dropped by identity policy (no reply sent)", + slog.String("channel", msg.Channel.String()), + slog.String("bot_id", strings.TrimSpace(state.Identity.BotID)), + slog.String("conversation_type", strings.TrimSpace(msg.Conversation.Type)), + slog.String("conversation_id", strings.TrimSpace(msg.Conversation.ID)), + ) + } return nil } identity := state.Identity - // Resolve or create the chat via chat_routes. - if p.chatService == nil { - return fmt.Errorf("chat service not configured") + // Resolve or create the route via channel_routes. + if p.routeResolver == nil { + return fmt.Errorf("route resolver not configured") } - resolved, err := p.chatService.ResolveChat(ctx, identity.BotID, - msg.Channel.String(), msg.Conversation.ID, extractThreadID(msg), - msg.Conversation.Type, identity.UserID, identity.ChannelConfigID, - strings.TrimSpace(msg.ReplyTarget)) + resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{ + BotID: identity.BotID, + Platform: msg.Channel.String(), + ConversationID: msg.Conversation.ID, + ThreadID: extractThreadID(msg), + ConversationType: msg.Conversation.Type, + ChannelIdentityID: identity.UserID, + ChannelConfigID: identity.ChannelConfigID, + ReplyTarget: strings.TrimSpace(msg.ReplyTarget), + }) if err != nil { - return fmt.Errorf("resolve chat: %w", err) + return fmt.Errorf("resolve route conversation: %w", err) + } + // Bot-centric history container: + // always persist channel traffic under bot_id so WebUI can view unified cross-platform history. + activeChatID := strings.TrimSpace(identity.BotID) + if activeChatID == "" { + activeChatID = strings.TrimSpace(resolved.ChatID) } if !shouldTriggerAssistantResponse(msg) && !identity.ForceReply { - p.persistInboundOnly(ctx, resolved, identity, msg, text) + if p.logger != nil { + p.logger.Info( + "inbound not triggering assistant (group trigger condition not met)", + slog.String("channel", msg.Channel.String()), + slog.String("bot_id", strings.TrimSpace(identity.BotID)), + slog.String("route_id", strings.TrimSpace(resolved.RouteID)), + slog.Bool("is_mentioned", metadataBool(msg.Metadata, "is_mentioned")), + slog.Bool("is_reply_to_bot", metadataBool(msg.Metadata, "is_reply_to_bot")), + slog.String("conversation_type", strings.TrimSpace(msg.Conversation.Type)), + ) + } + p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, "passive_sync") return nil } + userMessagePersisted := p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, "active_chat") // Issue chat token for reply routing. chatToken := "" if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTarget) != "" { signed, _, err := auth.GenerateChatToken(auth.ChatToken{ BotID: identity.BotID, - ChatID: resolved.ChatID, + ChatID: activeChatID, RouteID: resolved.RouteID, UserID: identity.UserID, ChannelIdentityID: identity.ChannelIdentityID, @@ -149,7 +183,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel } } - // Issue user JWT for downstream calls. + // Issue user JWT for downstream calls (MCP, schedule, etc.). For guests use chat token as Bearer. token := "" if identity.UserID != "" && p.jwtSecret != "" { signed, _, err := auth.GenerateToken(identity.UserID, p.jwtSecret, p.tokenTTL) @@ -161,49 +195,182 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel token = "Bearer " + signed } } + if token == "" && chatToken != "" { + token = "Bearer " + chatToken + } var desc channel.Descriptor if p.registry != nil { desc, _ = p.registry.GetDescriptor(msg.Channel) } - resp, err := p.chat.Chat(ctx, chat.ChatRequest{ + statusInfo := channel.ProcessingStatusInfo{ + BotID: identity.BotID, + ChatID: activeChatID, + RouteID: resolved.RouteID, + ChannelIdentityID: identity.ChannelIdentityID, + UserID: identity.UserID, + Query: text, + ReplyTarget: strings.TrimSpace(msg.ReplyTarget), + SourceMessageID: strings.TrimSpace(msg.Message.ID), + } + statusNotifier := p.resolveProcessingStatusNotifier(msg.Channel) + statusHandle := channel.ProcessingStatusHandle{} + if statusNotifier != nil { + handle, notifyErr := p.notifyProcessingStarted(ctx, statusNotifier, cfg, msg, statusInfo) + if notifyErr != nil { + p.logProcessingStatusError("processing_started", msg, identity, notifyErr) + } else { + statusHandle = handle + } + } + target := strings.TrimSpace(msg.ReplyTarget) + if target == "" { + err := fmt.Errorf("reply target missing") + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return err + } + sourceMessageID := strings.TrimSpace(msg.Message.ID) + replyRef := &channel.ReplyRef{Target: target} + if sourceMessageID != "" { + replyRef.MessageID = sourceMessageID + } + stream, err := sender.OpenStream(ctx, target, channel.StreamOptions{ + Reply: replyRef, + SourceMessageID: sourceMessageID, + Metadata: map[string]any{ + "route_id": resolved.RouteID, + }, + }) + if err != nil { + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return err + } + defer func() { + _ = stream.Close(context.WithoutCancel(ctx)) + }() + + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventStatus, + Status: channel.StreamStatusStarted, + }); err != nil { + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return err + } + + chunkCh, streamErrCh := p.runner.StreamChat(ctx, flow.ChatRequest{ BotID: identity.BotID, - ChatID: resolved.ChatID, + ChatID: activeChatID, Token: token, UserID: identity.UserID, SourceChannelIdentityID: identity.ChannelIdentityID, DisplayName: identity.DisplayName, RouteID: resolved.RouteID, ChatToken: chatToken, - ExternalMessageID: strings.TrimSpace(msg.Message.ID), + ExternalMessageID: sourceMessageID, Query: text, CurrentChannel: msg.Channel.String(), Channels: []string{msg.Channel.String()}, + UserMessagePersisted: userMessagePersisted, }) - if err != nil { + + var ( + finalMessages []conversation.ModelMessage + streamErr error + ) + for chunkCh != nil || streamErrCh != nil { + select { + case chunk, ok := <-chunkCh: + if !ok { + chunkCh = nil + continue + } + events, messages, parseErr := mapStreamChunkToChannelEvents(chunk) + if parseErr != nil { + if p.logger != nil { + p.logger.Warn( + "stream chunk parse failed", + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", identity.ChannelIdentityID), + slog.String("user_id", identity.UserID), + slog.Any("error", parseErr), + ) + } + continue + } + for _, event := range events { + if pushErr := stream.Push(ctx, event); pushErr != nil { + streamErr = pushErr + break + } + } + if len(messages) > 0 { + finalMessages = messages + } + case err, ok := <-streamErrCh: + if !ok { + streamErrCh = nil + continue + } + if err != nil { + streamErr = err + } + } + if streamErr != nil { + break + } + } + + if streamErr != nil { if p.logger != nil { p.logger.Error( - "chat gateway failed", + "chat gateway stream failed", slog.String("channel", msg.Channel.String()), slog.String("channel_identity_id", identity.ChannelIdentityID), slog.String("user_id", identity.UserID), - slog.Any("error", err), + slog.Any("error", streamErr), ) } - return err + _ = stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventError, + Error: streamErr.Error(), + }) + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, streamErr); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return streamErr } - outputs := chat.ExtractAssistantOutputs(resp.Messages) - if len(outputs) == 0 { - return nil - } - target := strings.TrimSpace(msg.ReplyTarget) - if target == "" { - return fmt.Errorf("reply target missing") - } - sentTexts, suppressReplies := collectMessageToolContext(p.registry, resp.Messages, msg.Channel, target) + + sentTexts, suppressReplies := collectMessageToolContext(p.registry, finalMessages, msg.Channel, target) if suppressReplies { + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventStatus, + Status: channel.StreamStatusCompleted, + }); err != nil { + return err + } + if statusNotifier != nil { + if notifyErr := p.notifyProcessingCompleted(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle); notifyErr != nil { + p.logProcessingStatusError("processing_completed", msg, identity, notifyErr) + } + } return nil } + + outputs := flow.ExtractAssistantOutputs(finalMessages) for _, output := range outputs { outMessage := buildChannelMessage(output, desc.Capabilities) if outMessage.IsEmpty() { @@ -216,13 +383,32 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel if isMessagingToolDuplicate(plainText, sentTexts) { continue } - if err := sender.Send(ctx, channel.OutboundMessage{ - Target: target, - Message: outMessage, + if outMessage.Reply == nil && sourceMessageID != "" { + outMessage.Reply = &channel.ReplyRef{ + Target: target, + MessageID: sourceMessageID, + } + } + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventFinal, + Final: &channel.StreamFinalizePayload{ + Message: outMessage, + }, }); err != nil { return err } } + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventStatus, + Status: channel.StreamStatusCompleted, + }); err != nil { + return err + } + if statusNotifier != nil { + if notifyErr := p.notifyProcessingCompleted(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle); notifyErr != nil { + p.logProcessingStatusError("processing_completed", msg, identity, notifyErr) + } + } return nil } @@ -320,48 +506,47 @@ func metadataBool(metadata map[string]any, key string) bool { } } -func (p *ChannelInboundProcessor) persistInboundOnly(ctx context.Context, resolved chat.ResolveChatResult, identity InboundIdentity, msg channel.InboundMessage, query string) { - if p.chatService == nil { - return +func (p *ChannelInboundProcessor) persistInboundUser(ctx context.Context, routeID string, identity InboundIdentity, msg channel.InboundMessage, query string, triggerMode string) bool { + if p.message == nil { + return false } - chatID := strings.TrimSpace(resolved.ChatID) botID := strings.TrimSpace(identity.BotID) - if chatID == "" || botID == "" { - return + if botID == "" { + return false } - payload, err := json.Marshal(chat.ModelMessage{ + payload, err := json.Marshal(conversation.ModelMessage{ Role: "user", - Content: chat.NewTextContent(query), + Content: conversation.NewTextContent(query), }) if err != nil { if p.logger != nil { - p.logger.Warn("marshal passive inbound failed", slog.Any("error", err)) + p.logger.Warn("marshal inbound user message failed", slog.Any("error", err)) } - return + return false } meta := map[string]any{ - "route_id": resolved.RouteID, + "route_id": strings.TrimSpace(routeID), "platform": msg.Channel.String(), - "trigger_mode": "passive_sync", + "trigger_mode": strings.TrimSpace(triggerMode), } - if _, err := p.chatService.PersistMessage( - ctx, - chatID, - botID, - strings.TrimSpace(resolved.RouteID), - strings.TrimSpace(identity.ChannelIdentityID), - strings.TrimSpace(identity.UserID), - msg.Channel.String(), - strings.TrimSpace(msg.Message.ID), - "user", - payload, - meta, - ); err != nil && p.logger != nil { - p.logger.Warn("persist passive inbound failed", slog.Any("error", err)) + if _, err := p.message.Persist(ctx, messagepkg.PersistInput{ + BotID: botID, + RouteID: strings.TrimSpace(routeID), + SenderChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID), + SenderUserID: strings.TrimSpace(identity.UserID), + Platform: msg.Channel.String(), + ExternalMessageID: strings.TrimSpace(msg.Message.ID), + Role: "user", + Content: payload, + Metadata: meta, + }); err != nil && p.logger != nil { + p.logger.Warn("persist inbound user message failed", slog.Any("error", err)) + return false } + return true } -func buildChannelMessage(output chat.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message { +func buildChannelMessage(output conversation.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message { msg := channel.Message{} if strings.TrimSpace(output.Content) != "" { msg.Text = strings.TrimSpace(output.Content) @@ -434,7 +619,7 @@ func containsMarkdown(text string) bool { return false } -func contentPartHasValue(part chat.ContentPart) bool { +func contentPartHasValue(part conversation.ContentPart) bool { if strings.TrimSpace(part.Text) != "" { return true } @@ -447,7 +632,7 @@ func contentPartHasValue(part chat.ContentPart) bool { return false } -func contentPartText(part chat.ContentPart) string { +func contentPartText(part conversation.ContentPart) string { if strings.TrimSpace(part.Text) != "" { return part.Text } @@ -460,6 +645,79 @@ func contentPartText(part chat.ContentPart) string { return "" } +type gatewayStreamEnvelope struct { + Type string `json:"type"` + Delta string `json:"delta"` + Error string `json:"error"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + Messages []conversation.ModelMessage `json:"messages"` +} + +type gatewayStreamDoneData struct { + Messages []conversation.ModelMessage `json:"messages"` +} + +func mapStreamChunkToChannelEvents(chunk conversation.StreamChunk) ([]channel.StreamEvent, []conversation.ModelMessage, error) { + if len(chunk) == 0 { + return nil, nil, nil + } + var envelope gatewayStreamEnvelope + if err := json.Unmarshal(chunk, &envelope); err != nil { + return nil, nil, err + } + finalMessages := make([]conversation.ModelMessage, 0, len(envelope.Messages)) + finalMessages = append(finalMessages, envelope.Messages...) + if len(finalMessages) == 0 && len(envelope.Data) > 0 { + var done gatewayStreamDoneData + if err := json.Unmarshal(envelope.Data, &done); err == nil && len(done.Messages) > 0 { + finalMessages = append(finalMessages, done.Messages...) + } + } + eventType := strings.ToLower(strings.TrimSpace(envelope.Type)) + switch eventType { + case "text_delta": + if envelope.Delta == "" { + return nil, finalMessages, nil + } + return []channel.StreamEvent{ + { + Type: channel.StreamEventDelta, + Delta: envelope.Delta, + }, + }, finalMessages, nil + case "reasoning_delta": + if envelope.Delta == "" { + return nil, finalMessages, nil + } + return []channel.StreamEvent{ + { + Type: channel.StreamEventDelta, + Delta: envelope.Delta, + Metadata: map[string]any{ + "phase": "reasoning", + }, + }, + }, finalMessages, nil + case "error": + streamError := strings.TrimSpace(envelope.Error) + if streamError == "" { + streamError = strings.TrimSpace(envelope.Message) + } + if streamError == "" { + streamError = "stream error" + } + return []channel.StreamEvent{ + { + Type: channel.StreamEventError, + Error: streamError, + }, + }, finalMessages, nil + default: + return nil, finalMessages, nil + } +} + func buildInboundQuery(message channel.Message) string { text := strings.TrimSpace(message.PlainText()) if len(message.Attachments) == 0 { @@ -472,7 +730,7 @@ func buildInboundQuery(message channel.Message) string { for _, att := range message.Attachments { label := strings.TrimSpace(att.Name) if label == "" { - label = strings.TrimSpace(att.URL) + label = strings.TrimSpace(att.Reference()) } if label == "" { label = "unknown" @@ -530,7 +788,7 @@ type sendMessageToolArgs struct { Message *channel.Message `json:"message"` } -func collectMessageToolContext(registry *channel.Registry, messages []chat.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { +func collectMessageToolContext(registry *channel.Registry, messages []conversation.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { if len(messages) == 0 { return nil, false } @@ -699,12 +957,88 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool { return false } +// requireIdentity resolves identity for the current message. Always resolves from msg so each sender is identified correctly (no reuse of context state across messages). func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { - if state, ok := IdentityStateFromContext(ctx); ok { - return state, nil - } if p.identity == nil { return IdentityState{}, fmt.Errorf("identity resolver not configured") } return p.identity.Resolve(ctx, cfg, msg) } + +func (p *ChannelInboundProcessor) resolveProcessingStatusNotifier(channelType channel.ChannelType) channel.ProcessingStatusNotifier { + if p == nil || p.registry == nil { + return nil + } + notifier, ok := p.registry.GetProcessingStatusNotifier(channelType) + if !ok { + return nil + } + return notifier +} + +func (p *ChannelInboundProcessor) notifyProcessingStarted( + ctx context.Context, + notifier channel.ProcessingStatusNotifier, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + info channel.ProcessingStatusInfo, +) (channel.ProcessingStatusHandle, error) { + if notifier == nil { + return channel.ProcessingStatusHandle{}, nil + } + statusCtx, cancel := context.WithTimeout(ctx, processingStatusTimeout) + defer cancel() + return notifier.ProcessingStarted(statusCtx, cfg, msg, info) +} + +func (p *ChannelInboundProcessor) notifyProcessingCompleted( + ctx context.Context, + notifier channel.ProcessingStatusNotifier, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + info channel.ProcessingStatusInfo, + handle channel.ProcessingStatusHandle, +) error { + if notifier == nil { + return nil + } + statusCtx, cancel := context.WithTimeout(ctx, processingStatusTimeout) + defer cancel() + return notifier.ProcessingCompleted(statusCtx, cfg, msg, info, handle) +} + +func (p *ChannelInboundProcessor) notifyProcessingFailed( + ctx context.Context, + notifier channel.ProcessingStatusNotifier, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + info channel.ProcessingStatusInfo, + handle channel.ProcessingStatusHandle, + cause error, +) error { + if notifier == nil { + return nil + } + statusCtx, cancel := context.WithTimeout(ctx, processingStatusTimeout) + defer cancel() + return notifier.ProcessingFailed(statusCtx, cfg, msg, info, handle, cause) +} + +func (p *ChannelInboundProcessor) logProcessingStatusError( + stage string, + msg channel.InboundMessage, + identity InboundIdentity, + err error, +) { + if p == nil || p.logger == nil || err == nil { + return + } + p.logger.Warn( + "processing status notify failed", + slog.String("stage", stage), + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", identity.ChannelIdentityID), + slog.String("user_id", identity.UserID), + slog.Any("error", err), + ) +} diff --git a/internal/router/channel_test.go b/internal/router/channel_test.go index 0b7b666c..748bc36b 100644 --- a/internal/router/channel_test.go +++ b/internal/router/channel_test.go @@ -3,28 +3,67 @@ package router import ( "context" "encoding/json" + "errors" "log/slog" "strings" "testing" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/channelidentities" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/conversation" + messagepkg "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/schedule" ) type fakeChatGateway struct { - resp chat.ChatResponse + resp conversation.ChatResponse err error - gotReq chat.ChatRequest + gotReq conversation.ChatRequest + onChat func(conversation.ChatRequest) } -func (f *fakeChatGateway) Chat(ctx context.Context, req chat.ChatRequest) (chat.ChatResponse, error) { +func (f *fakeChatGateway) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) { f.gotReq = req + if f.onChat != nil { + f.onChat(req) + } return f.resp, f.err } +func (f *fakeChatGateway) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) { + f.gotReq = req + if f.onChat != nil { + f.onChat(req) + } + chunks := make(chan conversation.StreamChunk, 1) + errs := make(chan error, 1) + if f.err != nil { + errs <- f.err + close(chunks) + close(errs) + return chunks, errs + } + payload := map[string]any{ + "type": "agent_end", + "messages": f.resp.Messages, + } + data, err := json.Marshal(payload) + if err == nil { + chunks <- conversation.StreamChunk(data) + } + close(chunks) + close(errs) + return chunks, errs +} + +func (f *fakeChatGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + return nil +} + type fakeReplySender struct { - sent []channel.OutboundMessage + sent []channel.OutboundMessage + events []channel.StreamEvent } func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) error { @@ -32,49 +71,142 @@ func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) return nil } -type fakeChatService struct { - resolveResult chat.ResolveChatResult - resolveErr error - persisted []chat.Message +func (s *fakeReplySender) OpenStream(ctx context.Context, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + return &fakeOutboundStream{ + sender: s, + target: strings.TrimSpace(target), + }, nil } -func (f *fakeChatService) ResolveChat(ctx context.Context, botID, platform, conversationID, threadID, conversationType, userID, channelConfigID, replyTarget string) (chat.ResolveChatResult, error) { +type fakeOutboundStream struct { + sender *fakeReplySender + target string +} + +func (s *fakeOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.sender == nil { + return nil + } + s.sender.events = append(s.sender.events, event) + if event.Type == channel.StreamEventFinal && event.Final != nil && !event.Final.Message.IsEmpty() { + s.sender.sent = append(s.sender.sent, channel.OutboundMessage{ + Target: s.target, + Message: event.Final.Message, + }) + } + return nil +} + +func (s *fakeOutboundStream) Close(ctx context.Context) error { + return nil +} + +type fakeProcessingStatusNotifier struct { + startedHandle channel.ProcessingStatusHandle + startedErr error + completedErr error + failedErr error + events []string + info []channel.ProcessingStatusInfo + completedSeen channel.ProcessingStatusHandle + failedSeen channel.ProcessingStatusHandle + failedCause error +} + +func (n *fakeProcessingStatusNotifier) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + n.events = append(n.events, "started") + n.info = append(n.info, info) + return n.startedHandle, n.startedErr +} + +func (n *fakeProcessingStatusNotifier) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { + n.events = append(n.events, "completed") + n.info = append(n.info, info) + n.completedSeen = handle + return n.completedErr +} + +func (n *fakeProcessingStatusNotifier) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { + n.events = append(n.events, "failed") + n.info = append(n.info, info) + n.failedSeen = handle + n.failedCause = cause + return n.failedErr +} + +type fakeProcessingStatusAdapter struct { + notifier *fakeProcessingStatusNotifier +} + +func (a *fakeProcessingStatusAdapter) Type() channel.ChannelType { + return channel.ChannelType("feishu") +} + +func (a *fakeProcessingStatusAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: channel.ChannelType("feishu"), + Capabilities: channel.ChannelCapabilities{ + Text: true, + Reply: true, + }, + } +} + +func (a *fakeProcessingStatusAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + return a.notifier.ProcessingStarted(ctx, cfg, msg, info) +} + +func (a *fakeProcessingStatusAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { + return a.notifier.ProcessingCompleted(ctx, cfg, msg, info, handle) +} + +func (a *fakeProcessingStatusAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { + return a.notifier.ProcessingFailed(ctx, cfg, msg, info, handle, cause) +} + +type fakeChatService struct { + resolveResult route.ResolveConversationResult + resolveErr error + persisted []messagepkg.Message +} + +func (f *fakeChatService) ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error) { if f.resolveErr != nil { - return chat.ResolveChatResult{}, f.resolveErr + return route.ResolveConversationResult{}, f.resolveErr } return f.resolveResult, nil } -func (f *fakeChatService) PersistMessage(ctx context.Context, chatID, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, role string, content json.RawMessage, metadata map[string]any) (chat.Message, error) { - msg := chat.Message{ - ChatID: chatID, - BotID: botID, - RouteID: routeID, - SenderChannelIdentityID: senderChannelIdentityID, - SenderUserID: senderUserID, - Platform: platform, - ExternalMessageID: externalMessageID, - Role: role, - Content: content, - Metadata: metadata, +func (f *fakeChatService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { + msg := messagepkg.Message{ + BotID: input.BotID, + RouteID: input.RouteID, + SenderChannelIdentityID: input.SenderChannelIdentityID, + SenderUserID: input.SenderUserID, + Platform: input.Platform, + ExternalMessageID: input.ExternalMessageID, + SourceReplyToMessageID: input.SourceReplyToMessageID, + Role: input.Role, + Content: input.Content, + Metadata: input.Metadata, } f.persisted = append(f.persisted, msg) return msg, nil } func TestChannelInboundProcessorWithIdentity(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}} memberSvc := &fakeMemberService{isMember: true} policySvc := &fakePolicyService{allow: false} - chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-1", RouteID: "route-1"}} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-1", RouteID: "route-1"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("AI reply")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} @@ -103,8 +235,8 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) { if gateway.gotReq.SourceChannelIdentityID != "channelIdentity-1" { t.Errorf("expected source_channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.SourceChannelIdentityID) } - if gateway.gotReq.ChatID != "chat-1" { - t.Errorf("expected chat_id 'chat-1', got: %s", gateway.gotReq.ChatID) + if gateway.gotReq.ChatID != "bot-1" { + t.Errorf("expected bot-scoped chat id 'bot-1', got: %s", gateway.gotReq.ChatID) } if len(sender.sent) != 1 || sender.sent[0].Message.PlainText() != "AI reply" { t.Fatalf("expected AI reply, got: %+v", sender.sent) @@ -112,12 +244,12 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) { } func TestChannelInboundProcessorDenied(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-2"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: false} chatSvc := &fakeChatService{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} @@ -142,12 +274,12 @@ func TestChannelInboundProcessorDenied(t *testing.T) { } func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-3"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}} memberSvc := &fakeMemberService{isMember: true} policySvc := &fakePolicyService{allow: false} chatSvc := &fakeChatService{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1"} @@ -166,17 +298,17 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { } func TestChannelInboundProcessorSilentReply(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-4"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}} memberSvc := &fakeMemberService{isMember: true} - chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-4", RouteID: "route-4"}} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-4", RouteID: "route-4"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("NO_REPLY")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("NO_REPLY")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} @@ -202,17 +334,17 @@ func TestChannelInboundProcessorSilentReply(t *testing.T) { } func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-5"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-5"}} memberSvc := &fakeMemberService{isMember: true} - chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-5", RouteID: "route-5"}} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-5", RouteID: "route-5"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("AI reply")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} @@ -244,20 +376,23 @@ func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) { if chatSvc.persisted[0].Role != "user" { t.Fatalf("expected persisted role user, got: %s", chatSvc.persisted[0].Role) } + if chatSvc.persisted[0].BotID != "bot-1" { + t.Fatalf("expected passive persisted bot_id bot-1, got: %s", chatSvc.persisted[0].BotID) + } } func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-6"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-6"}} memberSvc := &fakeMemberService{isMember: true} - chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-6", RouteID: "route-6"}} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-6", RouteID: "route-6"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("AI reply")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} @@ -286,24 +421,30 @@ func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) { if len(sender.sent) != 1 { t.Fatalf("expected one outbound reply, got %d", len(sender.sent)) } - if len(chatSvc.persisted) != 0 { - t.Fatalf("triggered group message should not use passive persistence") + if len(chatSvc.persisted) != 1 { + t.Fatalf("triggered group message should persist inbound user once, got: %d", len(chatSvc.persisted)) + } + if got := chatSvc.persisted[0].Metadata["trigger_mode"]; got != "active_chat" { + t.Fatalf("expected trigger_mode active_chat, got: %v", got) + } + if !gateway.gotReq.UserMessagePersisted { + t.Fatalf("expected UserMessagePersisted=true for pre-persisted inbound message") } } func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-member"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-member"}} memberSvc := &fakeMemberService{isMember: true} policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} - chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-personal-1", RouteID: "route-personal-1"}} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-personal-1", RouteID: "route-personal-1"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("AI reply")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} @@ -334,19 +475,19 @@ func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) { } } -func TestChannelInboundProcessorPersonalGroupOwnerForceReply(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-owner"}} +func TestChannelInboundProcessorPersonalGroupOwnerWithoutMentionUsesPassivePersistence(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner"}} memberSvc := &fakeMemberService{isMember: true} policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} - chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-personal-2", RouteID: "route-personal-2"}} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-personal-2", RouteID: "route-personal-2"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("AI reply")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} @@ -366,10 +507,200 @@ func TestChannelInboundProcessorPersonalGroupOwnerForceReply(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if gateway.gotReq.Query == "" { - t.Fatalf("owner should trigger chat call in personal group") + if gateway.gotReq.Query != "" { + t.Fatalf("owner group message without mention should not trigger chat call") } - if len(sender.sent) != 1 { - t.Fatalf("expected one owner reply, got %d", len(sender.sent)) + if len(sender.sent) != 0 { + t.Fatalf("owner group message without mention should not send reply") + } + if len(chatSvc.persisted) != 1 { + t.Fatalf("expected one passive persisted message, got: %d", len(chatSvc.persisted)) + } + if got := chatSvc.persisted[0].Metadata["trigger_mode"]; got != "passive_sync" { + t.Fatalf("expected trigger_mode passive_sync, got: %v", got) + } +} + +func TestChannelInboundProcessorProcessingStatusSuccessLifecycle(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedHandle: channel.ProcessingStatusHandle{Token: "reaction-1"}, + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-1", RouteID: "route-1"}} + gateway := &fakeChatGateway{ + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, + }, + }, + onChat: func(req conversation.ChatRequest) { + if len(notifier.events) != 1 || notifier.events[0] != "started" { + t.Fatalf("expected started before chat call, got events: %+v", notifier.events) + } + }, + } + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "om_123", Text: "hello"}, + ReplyTarget: "chat_id:oc_123", + Sender: channel.Identity{SubjectID: "ext-1"}, + Conversation: channel.Conversation{ + ID: "oc_123", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "completed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) + } + if notifier.completedSeen.Token != "reaction-1" { + t.Fatalf("expected completed token reaction-1, got: %q", notifier.completedSeen.Token) + } + if notifier.failedCause != nil { + t.Fatalf("expected failed cause nil, got: %v", notifier.failedCause) + } + if len(notifier.info) == 0 || notifier.info[0].SourceMessageID != "om_123" { + t.Fatalf("expected processing info source message id om_123, got: %+v", notifier.info) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one outbound reply, got %d", len(sender.sent)) + } +} + +func TestChannelInboundProcessorProcessingStatusFailureLifecycle(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedHandle: channel.ProcessingStatusHandle{Token: "reaction-2"}, + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-2", RouteID: "route-2"}} + chatErr := errors.New("chat gateway unavailable") + gateway := &fakeChatGateway{err: chatErr} + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "om_456", Text: "hello"}, + ReplyTarget: "chat_id:oc_456", + Sender: channel.Identity{SubjectID: "ext-2"}, + Conversation: channel.Conversation{ + ID: "oc_456", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if !errors.Is(err, chatErr) { + t.Fatalf("expected chat error, got: %v", err) + } + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "failed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) + } + if !errors.Is(notifier.failedCause, chatErr) { + t.Fatalf("expected failed cause chat error, got: %v", notifier.failedCause) + } + if notifier.failedSeen.Token != "reaction-2" { + t.Fatalf("expected failed token reaction-2, got: %q", notifier.failedSeen.Token) + } + if len(sender.sent) != 0 { + t.Fatalf("expected no outbound reply on chat failure, got: %+v", sender.sent) + } +} + +func TestChannelInboundProcessorProcessingStatusErrorsAreBestEffort(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedErr: errors.New("start notify failed"), + completedErr: errors.New("completed notify failed"), + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-3", RouteID: "route-3"}} + gateway := &fakeChatGateway{ + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "om_789", Text: "hello"}, + ReplyTarget: "chat_id:oc_789", + Sender: channel.Identity{SubjectID: "ext-3"}, + Conversation: channel.Conversation{ + ID: "oc_789", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "completed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) + } + if notifier.completedSeen.Token != "" { + t.Fatalf("expected empty completed token after started failure, got: %q", notifier.completedSeen.Token) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one outbound reply, got %d", len(sender.sent)) + } +} + +func TestChannelInboundProcessorProcessingFailedNotifyErrorDoesNotOverrideChatError(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedHandle: channel.ProcessingStatusHandle{Token: "reaction-4"}, + failedErr: errors.New("failed notify error"), + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-4", RouteID: "route-4"}} + chatErr := errors.New("chat failed") + gateway := &fakeChatGateway{err: chatErr} + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "om_999", Text: "hello"}, + ReplyTarget: "chat_id:oc_999", + Sender: channel.Identity{SubjectID: "ext-4"}, + Conversation: channel.Conversation{ + ID: "oc_999", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if !errors.Is(err, chatErr) { + t.Fatalf("expected original chat error, got: %v", err) + } + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "failed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) } } diff --git a/internal/router/identity.go b/internal/router/identity.go index e4aea0b1..399fda71 100644 --- a/internal/router/identity.go +++ b/internal/router/identity.go @@ -10,7 +10,7 @@ import ( "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/preauth" ) @@ -59,7 +59,7 @@ func IdentityStateFromContext(ctx context.Context) (IdentityState, bool) { // ChannelIdentityService is the minimal interface for channel identity resolution. type ChannelIdentityService interface { - ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string) (channelidentities.ChannelIdentity, error) + ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string) (identities.ChannelIdentity, error) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error @@ -169,7 +169,7 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi channelConfigID = "" } subjectID := extractSubjectIdentity(msg) - displayName := extractDisplayName(msg) + displayName := r.resolveDisplayName(ctx, cfg, msg, subjectID) state := IdentityState{ Identity: InboundIdentity{ @@ -184,17 +184,11 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi return state, fmt.Errorf("cannot resolve identity: no channel_subject_id") } - channelIdentity, err := r.channelIdentities.ResolveByChannelIdentity(ctx, msg.Channel.String(), subjectID, displayName) - if err != nil { - return state, fmt.Errorf("resolve channel identity: %w", err) - } - - channelIdentityID := strings.TrimSpace(channelIdentity.ID) - state.Identity.ChannelIdentityID = channelIdentityID - linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID) + channelIdentityID, linkedUserID, err := r.resolveIdentityWithLinkedUser(ctx, msg, subjectID, displayName) if err != nil { return state, err } + state.Identity.ChannelIdentityID = channelIdentityID state.Identity.UserID = strings.TrimSpace(linkedUserID) if strings.TrimSpace(state.Identity.UserID) == "" { state.Identity.UserID = r.tryLinkConfiglessChannelIdentityToUser(ctx, msg, channelIdentityID) @@ -224,21 +218,12 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi isOwner := strings.TrimSpace(state.Identity.UserID) != "" && strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID) if !isOwner { - if isGroupConversationType(msg.Conversation.Type) { - // Ignore non-owner group messages for personal bots. - state.Decision = &IdentityDecision{Stop: true} - return state, nil - } - state.Decision = &IdentityDecision{ - Stop: true, - Reply: channel.Message{Text: r.unboundReply}, - } + // Ignore all non-owner messages for personal bots. + state.Decision = &IdentityDecision{Stop: true} return state, nil } - if isGroupConversationType(msg.Conversation.Type) { - // Owner can chat in group for personal bots. - state.Identity.ForceReply = true - } + // Owner is authorized, but group trigger policy is still decided by + // shouldTriggerAssistantResponse in channel routing. return state, nil } } @@ -283,6 +268,13 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi return state, err } + // In group conversations, silently drop unauthorized messages to avoid spamming + // the channel with "access denied" replies (same behavior as personal bot non-owner). + if isGroupConversationType(msg.Conversation.Type) { + state.Decision = &IdentityDecision{Stop: true} + return state, nil + } + state.Decision = &IdentityDecision{ Stop: true, Reply: channel.Message{Text: r.unboundReply}, @@ -290,6 +282,37 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi return state, nil } +func (r *IdentityResolver) resolveIdentityWithLinkedUser(ctx context.Context, msg channel.InboundMessage, primarySubjectID, displayName string) (string, string, error) { + candidates := identitySubjectCandidates(msg, primarySubjectID) + if len(candidates) == 0 { + return "", "", fmt.Errorf("cannot resolve identity: no channel_subject_id") + } + + firstChannelIdentityID := "" + for _, subjectID := range candidates { + channelIdentity, err := r.channelIdentities.ResolveByChannelIdentity(ctx, msg.Channel.String(), subjectID, displayName) + if err != nil { + return "", "", fmt.Errorf("resolve channel identity: %w", err) + } + channelIdentityID := strings.TrimSpace(channelIdentity.ID) + if channelIdentityID == "" { + continue + } + if firstChannelIdentityID == "" { + firstChannelIdentityID = channelIdentityID + } + linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID) + if err != nil { + return "", "", err + } + linkedUserID = strings.TrimSpace(linkedUserID) + if linkedUserID != "" { + return channelIdentityID, linkedUserID, nil + } + } + return firstChannelIdentityID, "", nil +} + func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.InboundMessage, botID, userID, subjectID string) (bool, IdentityDecision, error) { tokenText := strings.TrimSpace(msg.Message.PlainText()) if tokenText == "" || r.preauth == nil { @@ -411,21 +434,85 @@ func extractSubjectIdentity(msg channel.InboundMessage) string { return strings.TrimSpace(msg.Sender.DisplayName) } +func identitySubjectCandidates(msg channel.InboundMessage, primary string) []string { + candidates := make([]string, 0, 3) + appendUnique := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + for _, existing := range candidates { + if existing == value { + return + } + } + candidates = append(candidates, value) + } + + appendUnique(primary) + appendUnique(msg.Sender.Attribute("open_id")) + appendUnique(msg.Sender.Attribute("user_id")) + return candidates +} + func extractDisplayName(msg channel.InboundMessage) string { if strings.TrimSpace(msg.Sender.DisplayName) != "" { return strings.TrimSpace(msg.Sender.DisplayName) } - if strings.TrimSpace(msg.Sender.SubjectID) != "" { - return strings.TrimSpace(msg.Sender.SubjectID) + if value := strings.TrimSpace(msg.Sender.Attribute("display_name")); value != "" { + return value + } + if value := strings.TrimSpace(msg.Sender.Attribute("name")); value != "" { + return value } if value := strings.TrimSpace(msg.Sender.Attribute("username")); value != "" { return value } - if value := strings.TrimSpace(msg.Sender.Attribute("user_id")); value != "" { - return value + return "" +} + +func (r *IdentityResolver) resolveDisplayName(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, subjectID string) string { + displayName := extractDisplayName(msg) + if displayName != "" { + return displayName } - if value := strings.TrimSpace(msg.Sender.Attribute("open_id")); value != "" { - return value + return r.resolveDisplayNameFromDirectory(ctx, cfg, msg, subjectID) +} + +func (r *IdentityResolver) resolveDisplayNameFromDirectory(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, subjectID string) string { + if r.registry == nil { + return "" + } + subjectID = strings.TrimSpace(subjectID) + if subjectID == "" { + return "" + } + directoryAdapter, ok := r.registry.DirectoryAdapter(msg.Channel) + if !ok || directoryAdapter == nil { + return "" + } + if ctx == nil { + ctx = context.Background() + } + lookupCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + entry, err := directoryAdapter.ResolveEntry(lookupCtx, cfg, subjectID, channel.DirectoryEntryUser) + if err != nil { + if r.logger != nil { + r.logger.Debug( + "resolve display name from directory failed", + slog.String("channel", msg.Channel.String()), + slog.String("subject_id", subjectID), + slog.Any("error", err), + ) + } + return "" + } + if name := strings.TrimSpace(entry.Name); name != "" { + return name + } + if handle := strings.TrimSpace(entry.Handle); handle != "" { + return handle } return "" } diff --git a/internal/router/identity_test.go b/internal/router/identity_test.go index 45549492..843ba72e 100644 --- a/internal/router/identity_test.go +++ b/internal/router/identity_test.go @@ -2,28 +2,38 @@ package router import ( "context" + "errors" "log/slog" "testing" "time" "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/preauth" ) type fakeChannelIdentityService struct { - channelIdentity channelidentities.ChannelIdentity + channelIdentity identities.ChannelIdentity + bySubject map[string]identities.ChannelIdentity err error canonical map[string]string linked map[string]string calls int + lastDisplayName string } -func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Context, platform, externalID, displayName string) (channelidentities.ChannelIdentity, error) { +func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Context, platform, externalID, displayName string) (identities.ChannelIdentity, error) { f.calls++ + f.lastDisplayName = displayName if f.err != nil { - return channelidentities.ChannelIdentity{}, f.err + return identities.ChannelIdentity{}, f.err + } + if f.bySubject != nil { + if identity, ok := f.bySubject[externalID]; ok { + return identity, nil + } + return identities.ChannelIdentity{}, nil } return f.channelIdentity, nil } @@ -145,8 +155,44 @@ func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelCh return f.consumeErr } +type fakeDirectoryAdapter struct { + channelType channel.ChannelType + resolveFn func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) +} + +func (f *fakeDirectoryAdapter) Type() channel.ChannelType { + return f.channelType +} + +func (f *fakeDirectoryAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: f.channelType, + DisplayName: "FakeDirectory", + Capabilities: channel.ChannelCapabilities{}, + } +} + +func (f *fakeDirectoryAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (f *fakeDirectoryAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (f *fakeDirectoryAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (f *fakeDirectoryAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + if f.resolveFn != nil { + return f.resolveFn(ctx, cfg, input, kind) + } + return channel.DirectoryEntry{}, errors.New("resolve not implemented") +} + func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: true, botType: "public"} resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") @@ -173,8 +219,100 @@ func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) { } } +func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) { + registry := channel.NewRegistry() + directoryAdapter := &fakeDirectoryAdapter{ + channelType: channel.ChannelType("feishu"), + resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + if kind != channel.DirectoryEntryUser { + t.Fatalf("expected kind user, got %s", kind) + } + if input != "ou-directory" { + t.Fatalf("expected subject id ou-directory, got %s", input) + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + Name: "Directory Name", + }, nil + }, + } + if err := registry.Register(directoryAdapter); err != nil { + t.Fatalf("register directory adapter failed: %v", err) + } + + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-directory"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), registry, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "target-id", + Sender: channel.Identity{ + SubjectID: "ou-directory", + Attributes: map[string]string{ + "open_id": "ou-directory", + }, + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Identity.DisplayName != "Directory Name" { + t.Fatalf("expected directory display name, got %q", state.Identity.DisplayName) + } + if channelIdentitySvc.lastDisplayName != "Directory Name" { + t.Fatalf("expected upsert display name Directory Name, got %q", channelIdentitySvc.lastDisplayName) + } +} + +func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testing.T) { + registry := channel.NewRegistry() + directoryAdapter := &fakeDirectoryAdapter{ + channelType: channel.ChannelType("feishu"), + resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + return channel.DirectoryEntry{}, errors.New("lookup failed") + }, + } + if err := registry.Register(directoryAdapter); err != nil { + t.Fatalf("register directory adapter failed: %v", err) + } + + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-directory-fail"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), registry, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "target-id", + Sender: channel.Identity{ + SubjectID: "ou-directory-fail", + Attributes: map[string]string{ + "open_id": "ou-directory-fail", + "user_id": "u-directory-fail", + }, + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Identity.DisplayName != "" { + t.Fatalf("expected empty display name when directory lookup fails, got %q", state.Identity.DisplayName) + } + if channelIdentitySvc.lastDisplayName != "" { + t.Fatalf("expected empty upsert display name on lookup failure, got %q", channelIdentitySvc.lastDisplayName) + } +} + func TestIdentityResolverExistingMemberPasses(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-2"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}} memberSvc := &fakeMemberService{isMember: true} policySvc := &fakePolicyService{allow: false, botType: "public"} resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") @@ -196,7 +334,7 @@ func TestIdentityResolverExistingMemberPasses(t *testing.T) { } func TestIdentityResolverPreauthKey(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-3"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: false, botType: "public"} preauthSvc := &fakePreauthServiceIdentity{ @@ -232,7 +370,7 @@ func TestIdentityResolverPreauthKey(t *testing.T) { } func TestIdentityResolverPreauthKeyExpired(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-4"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: false, botType: "public"} preauthSvc := &fakePreauthServiceIdentity{ @@ -265,7 +403,7 @@ func TestIdentityResolverPreauthKeyExpired(t *testing.T) { } func TestIdentityResolverDenied(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-5"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-5"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: false, botType: "public"} resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") @@ -287,7 +425,7 @@ func TestIdentityResolverDenied(t *testing.T) { } func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-group"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-group"}} memberSvc := &fakeMemberService{isMember: true} policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") @@ -319,7 +457,7 @@ func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) { } func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-owner"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") @@ -342,13 +480,13 @@ func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) { if state.Decision != nil { t.Fatal("owner group message should pass") } - if !state.Identity.ForceReply { - t.Fatal("owner group message should force reply") + if state.Identity.ForceReply { + t.Fatal("owner group message should not force reply") } } func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-owner-direct"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner-direct"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner-direct"} resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") @@ -376,8 +514,57 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin } } +func TestIdentityResolverPersonalBotOwnerFallbackByAlternateSubject(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{ + bySubject: map[string]identities.ChannelIdentity{ + "ou-open-owner": {ID: "channelIdentity-open-owner"}, + "u-owner": {ID: "channelIdentity-user-owner"}, + }, + linked: map[string]string{ + "channelIdentity-user-owner": "owner-user-1", + }, + } + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "owner-user-1"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from owner"}, + Sender: channel.Identity{ + SubjectID: "ou-open-owner", + Attributes: map[string]string{ + "open_id": "ou-open-owner", + "user_id": "u-owner", + }, + }, + Conversation: channel.Conversation{ + ID: "p2p-1", + Type: "p2p", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("owner direct message should pass after alternate subject fallback") + } + if state.Identity.UserID != "owner-user-1" { + t.Fatalf("expected owner-user-1, got: %s", state.Identity.UserID) + } + if state.Identity.ChannelIdentityID != "channelIdentity-user-owner" { + t.Fatalf("expected fallback channel identity, got: %s", state.Identity.ChannelIdentityID) + } + if channelIdentitySvc.calls < 2 { + t.Fatalf("expected fallback resolution attempts, got calls=%d", channelIdentitySvc.calls) + } +} + func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-non-owner"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-non-owner"}} memberSvc := &fakeMemberService{isMember: true} policySvc := &fakePolicyService{allow: true, botType: "personal", ownerUserID: "channelIdentity-owner"} resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") @@ -400,8 +587,8 @@ func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing if state.Decision == nil || !state.Decision.Stop { t.Fatal("non-owner direct message should be rejected for personal bot") } - if state.Decision.Reply.Text != "Access denied." { - t.Fatalf("unexpected reject message: %s", state.Decision.Reply.Text) + if !state.Decision.Reply.IsEmpty() { + t.Fatal("non-owner direct message should be silently ignored") } } @@ -409,7 +596,7 @@ func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) { shadowID := "channelIdentity-shadow" humanID := "channelIdentity-human" channelIdentitySvc := &fakeChannelIdentityService{ - channelIdentity: channelidentities.ChannelIdentity{ID: shadowID}, + channelIdentity: identities.ChannelIdentity{ID: shadowID}, linked: map[string]string{ shadowID: shadowID, }, @@ -455,7 +642,7 @@ func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) { } func TestIdentityResolverBindConsumeErrorHandledAsDecision(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-shadow"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-shadow"}} bindSvc := &fakeBindService{ code: bind.Code{ ID: "code-2", @@ -488,7 +675,7 @@ func TestIdentityResolverBindCodeNotScopedToCurrentBot(t *testing.T) { shadowID := "channelIdentity-shadow-any-bot" humanID := "channelIdentity-human-any-bot" channelIdentitySvc := &fakeChannelIdentityService{ - channelIdentity: channelidentities.ChannelIdentity{ID: shadowID}, + channelIdentity: identities.ChannelIdentity{ID: shadowID}, linked: map[string]string{ shadowID: shadowID, }, @@ -529,8 +716,66 @@ func TestIdentityResolverBindCodeNotScopedToCurrentBot(t *testing.T) { } } +func TestIdentityResolverPublicBotGroupDeniedSilently(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-group-denied"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "group-target", + Sender: channel.Identity{SubjectID: "stranger-group"}, + Conversation: channel.Conversation{ + ID: "group-1", + Type: "group", + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("unauthorized group message should be stopped") + } + if !state.Decision.Reply.IsEmpty() { + t.Fatal("unauthorized group message should be silently dropped, not replied") + } +} + +func TestIdentityResolverPublicBotDirectDeniedWithReply(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-direct-denied"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "direct-target", + Sender: channel.Identity{SubjectID: "stranger-direct"}, + Conversation: channel.Conversation{ + ID: "p2p-1", + Type: "p2p", + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("unauthorized direct message should be stopped") + } + if state.Decision.Reply.IsEmpty() { + t.Fatal("unauthorized direct message should reply with access denied") + } +} + func TestIdentityResolverBindCodePlatformMismatch(t *testing.T) { - channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-platform-mismatch"}} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-platform-mismatch"}} bindSvc := &fakeBindService{ code: bind.Code{ ID: "code-platform", diff --git a/internal/schedule/service.go b/internal/schedule/service.go index 5af044b2..b7843723 100644 --- a/internal/schedule/service.go +++ b/internal/schedule/service.go @@ -9,12 +9,12 @@ import ( "sync" "time" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/robfig/cron/v3" "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -71,7 +71,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( if _, err := s.parser.Parse(req.Pattern); err != nil { return Schedule{}, fmt.Errorf("invalid cron pattern: %w", err) } - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return Schedule{}, err } @@ -104,7 +104,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( } func (s *Service) Get(ctx context.Context, id string) (Schedule, error) { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Schedule{}, err } @@ -119,7 +119,7 @@ func (s *Service) Get(ctx context.Context, id string) (Schedule, error) { } func (s *Service) List(ctx context.Context, botID string) ([]Schedule, error) { - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return nil, err } @@ -135,7 +135,7 @@ func (s *Service) List(ctx context.Context, botID string) ([]Schedule, error) { } func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Schedule, error) { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Schedule{}, err } @@ -191,7 +191,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sch } func (s *Service) Delete(ctx context.Context, id string) error { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return err } @@ -253,7 +253,7 @@ func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error { // resolveBotOwner returns the owner user ID for the given bot. func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, error) { - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return "", err } @@ -261,7 +261,7 @@ func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, er if err != nil { return "", fmt.Errorf("get bot: %w", err) } - ownerID := toUUIDString(bot.OwnerUserID) + ownerID := bot.OwnerUserID.String() if ownerID == "" { return "", fmt.Errorf("bot owner not found") } @@ -281,7 +281,7 @@ func (s *Service) generateTriggerToken(userID string) (string, error) { } func (s *Service) scheduleJob(schedule sqlc.Schedule) error { - id := toUUIDString(schedule.ID) + id := schedule.ID.String() if id == "" { return fmt.Errorf("schedule id missing") } @@ -299,7 +299,7 @@ func (s *Service) scheduleJob(schedule sqlc.Schedule) error { } func (s *Service) rescheduleJob(schedule sqlc.Schedule) { - id := toUUIDString(schedule.ID) + id := schedule.ID.String() if id == "" { return } @@ -321,14 +321,14 @@ func (s *Service) removeJob(id string) { func toSchedule(row sqlc.Schedule) Schedule { item := Schedule{ - ID: toUUIDString(row.ID), + ID: row.ID.String(), Name: row.Name, Description: row.Description, Pattern: row.Pattern, CurrentCalls: int(row.CurrentCalls), Enabled: row.Enabled, Command: row.Command, - BotID: toUUIDString(row.BotID), + BotID: row.BotID.String(), } if row.MaxCalls.Valid { max := int(row.MaxCalls.Int32) @@ -343,32 +343,10 @@ func toSchedule(row sqlc.Schedule) Schedule { return item } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - func toUUID(id string) pgtype.UUID { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return pgtype.UUID{} } return pgID } - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - id, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return id.String() -} diff --git a/internal/schedule/trigger.go b/internal/schedule/trigger.go index 1cc01b39..b15ed3e3 100644 --- a/internal/schedule/trigger.go +++ b/internal/schedule/trigger.go @@ -14,7 +14,7 @@ type TriggerPayload struct { ChatID string } -// Triggerer 负责触发与聊天相关的调度执行。 +// Triggerer triggers schedule execution for chat-related jobs. type Triggerer interface { TriggerSchedule(ctx context.Context, botID string, payload TriggerPayload, token string) error } diff --git a/internal/server/server.go b/internal/server/server.go index c0c41a9f..8d1b64c6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,7 @@ type Server struct { logger *slog.Logger } -func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, preauthHandler *handlers.PreauthHandler, bindHandler *handlers.BindHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { +func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, conversationHandler *handlers.MessageHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, preauthHandler *handlers.PreauthHandler, bindHandler *handlers.BindHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { if addr == "" { addr = ":8080" } @@ -63,8 +63,8 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han if embeddingsHandler != nil { embeddingsHandler.Register(e) } - if chatHandler != nil { - chatHandler.Register(e) + if conversationHandler != nil { + conversationHandler.Register(e) } if swaggerHandler != nil { swaggerHandler.Register(e) diff --git a/internal/settings/service.go b/internal/settings/service.go index 47d25e7a..3b3bf9de 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -7,10 +7,10 @@ import ( "log/slog" "strings" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -30,7 +30,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { // Get returns user-level settings. func (s *Service) Get(ctx context.Context, userID string) (Settings, error) { - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return Settings{}, err } @@ -55,7 +55,7 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) if s.queries == nil { return Settings{}, fmt.Errorf("settings queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return Settings{}, err } @@ -106,7 +106,7 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) } func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) { - pgID, err := parseUUID(botID) + pgID, err := db.ParseUUID(botID) if err != nil { return Settings{}, err } @@ -121,7 +121,7 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest if s.queries == nil { return Settings{}, fmt.Errorf("settings queries not configured") } - pgID, err := parseUUID(botID) + pgID, err := db.ParseUUID(botID) if err != nil { return Settings{}, err } @@ -191,7 +191,7 @@ func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { return fmt.Errorf("settings queries not configured") } - pgID, err := parseUUID(botID) + pgID, err := db.ParseUUID(botID) if err != nil { return err } @@ -278,13 +278,3 @@ func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype. return row.ID, nil } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/settings/types.go b/internal/settings/types.go index 9950566c..f750f8ff 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -6,12 +6,12 @@ const ( ) type Settings struct { - ChatModelID string `json:"chat_model_id" validate:"required"` - MemoryModelID string `json:"memory_model_id" validate:"required"` - EmbeddingModelID string `json:"embedding_model_id" validate:"required"` - MaxContextLoadTime int `json:"max_context_load_time" validate:"required"` - Language string `json:"language" validate:"required"` - AllowGuest bool `json:"allow_guest" validate:"required"` + ChatModelID string `json:"chat_model_id"` + MemoryModelID string `json:"memory_model_id"` + EmbeddingModelID string `json:"embedding_model_id"` + MaxContextLoadTime int `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` } type UpsertRequest struct { diff --git a/internal/subagent/service.go b/internal/subagent/service.go index c864a4db..c9e5ce77 100644 --- a/internal/subagent/service.go +++ b/internal/subagent/service.go @@ -8,10 +8,9 @@ import ( "log/slog" "strings" - "github.com/google/uuid" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -39,7 +38,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( if description == "" { return Subagent{}, fmt.Errorf("description is required") } - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return Subagent{}, err } @@ -70,7 +69,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( } func (s *Service) Get(ctx context.Context, id string) (Subagent, error) { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -85,7 +84,7 @@ func (s *Service) Get(ctx context.Context, id string) (Subagent, error) { } func (s *Service) List(ctx context.Context, botID string) ([]Subagent, error) { - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return nil, err } @@ -131,7 +130,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sub if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -152,7 +151,7 @@ func (s *Service) UpdateContext(ctx context.Context, id string, req UpdateContex if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -171,7 +170,7 @@ func (s *Service) UpdateSkills(ctx context.Context, id string, req UpdateSkillsR if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -195,7 +194,7 @@ func (s *Service) AddSkills(ctx context.Context, id string, req AddSkillsRequest if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -210,7 +209,7 @@ func (s *Service) AddSkills(ctx context.Context, id string, req AddSkillsRequest } func (s *Service) Delete(ctx context.Context, id string) error { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return err } @@ -231,10 +230,10 @@ func toSubagent(row sqlc.Subagent) (Subagent, error) { return Subagent{}, err } item := Subagent{ - ID: toUUIDString(row.ID), + ID: row.ID.String(), Name: row.Name, Description: row.Description, - BotID: toUUIDString(row.BotID), + BotID: row.BotID.String(), Messages: messages, Metadata: metadata, Skills: skills, @@ -336,24 +335,3 @@ func mergeSkills(existing []string, incoming []string) []string { return normalizeSkills(merged) } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - id, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return id.String() -} diff --git a/packages/sdk/src/@pinia/colada.gen.ts b/packages/sdk/src/@pinia/colada.gen.ts index acb165a5..3417e0c2 100644 --- a/packages/sdk/src/@pinia/colada.gen.ts +++ b/packages/sdk/src/@pinia/colada.gen.ts @@ -4,8 +4,8 @@ import { type _JSONValue, defineQueryOptions, type UseMutationOptions } from '@p import { serializeQueryKeyValue } from '../client'; import { client } from '../client.gen'; -import { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerFs, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdHistory, deleteBotsByBotIdHistoryById, deleteBotsByBotIdMcpById, deleteBotsByBotIdMemoryMemories, deleteBotsByBotIdMemoryMemoriesByMemoryId, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerFs, getBotsByBotIdContainerFsFile, getBotsByBotIdContainerFsStat, getBotsByBotIdContainerFsUsage, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdHistory, getBotsByBotIdHistoryById, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdMemoryMemories, getBotsByBotIdMemoryMemoriesByMemoryId, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, type Options, postAuthLogin, postBots, postBotsByBotIdChat, postBotsByBotIdChatStream, postBotsByBotIdContainer, postBotsByBotIdContainerFsDir, postBotsByBotIdContainerFsFile, postBotsByBotIdContainerFsMcp, postBotsByBotIdContainerFsUpload, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdHistory, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioBySessionId, postBotsByBotIdMemoryAdd, postBotsByBotIdMemoryEmbed, postBotsByBotIdMemorySearch, postBotsByBotIdMemoryUpdate, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendSession, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from '../sdk.gen'; -import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerFsData, DeleteBotsByBotIdContainerFsError, DeleteBotsByBotIdContainerFsResponse, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdHistoryByIdData, DeleteBotsByBotIdHistoryByIdError, DeleteBotsByBotIdHistoryData, DeleteBotsByBotIdHistoryError, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdMemoryMemoriesByMemoryIdData, DeleteBotsByBotIdMemoryMemoriesByMemoryIdError, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponse, DeleteBotsByBotIdMemoryMemoriesData, DeleteBotsByBotIdMemoryMemoriesError, DeleteBotsByBotIdMemoryMemoriesResponse, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteProvidersByIdData, DeleteProvidersByIdError, GetBotsByBotIdContainerData, GetBotsByBotIdContainerFsData, GetBotsByBotIdContainerFsFileData, GetBotsByBotIdContainerFsStatData, GetBotsByBotIdContainerFsUsageData, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdHistoryByIdData, GetBotsByBotIdHistoryData, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpData, GetBotsByBotIdMemoryMemoriesByMemoryIdData, GetBotsByBotIdMemoryMemoriesData, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleData, GetBotsByBotIdSettingsData, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsData, GetBotsByIdChannelByPlatformData, GetBotsByIdData, GetBotsByIdMembersData, GetBotsData, GetChannelsByPlatformData, GetChannelsData, GetModelsByIdData, GetModelsCountData, GetModelsData, GetModelsModelByModelIdData, GetProvidersByIdData, GetProvidersByIdModelsData, GetProvidersCountData, GetProvidersData, GetProvidersNameByNameData, GetUsersByIdData, GetUsersData, GetUsersMeChannelsByPlatformData, GetUsersMeData, PostAuthLoginData, PostAuthLoginError, PostAuthLoginResponse, PostBotsByBotIdChatData, PostBotsByBotIdChatError, PostBotsByBotIdChatResponse, PostBotsByBotIdChatStreamData, PostBotsByBotIdChatStreamError, PostBotsByBotIdChatStreamResponse, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerFsDirData, PostBotsByBotIdContainerFsDirError, PostBotsByBotIdContainerFsDirResponse, PostBotsByBotIdContainerFsFileData, PostBotsByBotIdContainerFsFileError, PostBotsByBotIdContainerFsFileResponse, PostBotsByBotIdContainerFsMcpData, PostBotsByBotIdContainerFsMcpError, PostBotsByBotIdContainerFsMcpResponse, PostBotsByBotIdContainerFsUploadData, PostBotsByBotIdContainerFsUploadError, PostBotsByBotIdContainerFsUploadResponse, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdHistoryData, PostBotsByBotIdHistoryError, PostBotsByBotIdHistoryResponse, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpStdioBySessionIdData, PostBotsByBotIdMcpStdioBySessionIdError, PostBotsByBotIdMcpStdioBySessionIdResponse, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdMemoryAddData, PostBotsByBotIdMemoryAddError, PostBotsByBotIdMemoryAddResponse, PostBotsByBotIdMemoryEmbedData, PostBotsByBotIdMemoryEmbedError, PostBotsByBotIdMemoryEmbedResponse, PostBotsByBotIdMemorySearchData, PostBotsByBotIdMemorySearchError, PostBotsByBotIdMemorySearchResponse, PostBotsByBotIdMemoryUpdateData, PostBotsByBotIdMemoryUpdateError, PostBotsByBotIdMemoryUpdateResponse, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleResponse, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsResponse, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendResponse, PostBotsByIdChannelByPlatformSendSessionData, PostBotsByIdChannelByPlatformSendSessionError, PostBotsByIdChannelByPlatformSendSessionResponse, PostBotsData, PostBotsError, PostBotsResponse, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsResponse, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableResponse, PostModelsError, PostModelsResponse, PostProvidersData, PostProvidersError, PostProvidersResponse, PostUsersData, PostUsersError, PostUsersResponse, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformResponse, PutBotsByIdData, PutBotsByIdError, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersResponse, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerResponse, PutBotsByIdResponse, PutModelsByIdData, PutModelsByIdError, PutModelsByIdResponse, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdResponse, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdResponse, PutUsersByIdData, PutUsersByIdError, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdResponse, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformResponse, PutUsersMeData, PutUsersMeError, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMeResponse } from '../types.gen'; +import { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdMcpById, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdChecks, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, getUsersMeIdentities, type Options, postAuthLogin, postBots, postBotsByBotIdContainer, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioByConnectionId, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByBotIdTools, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendChat, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from '../sdk.gen'; +import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteBotsByIdResponse, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteProvidersByIdData, DeleteProvidersByIdError, GetBotsByBotIdContainerData, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpData, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleData, GetBotsByBotIdSettingsData, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsData, GetBotsByIdChannelByPlatformData, GetBotsByIdChecksData, GetBotsByIdData, GetBotsByIdMembersData, GetBotsData, GetChannelsByPlatformData, GetChannelsData, GetModelsByIdData, GetModelsCountData, GetModelsData, GetModelsModelByModelIdData, GetProvidersByIdData, GetProvidersByIdModelsData, GetProvidersCountData, GetProvidersData, GetProvidersNameByNameData, GetUsersByIdData, GetUsersData, GetUsersMeChannelsByPlatformData, GetUsersMeData, GetUsersMeIdentitiesData, PostAuthLoginData, PostAuthLoginError, PostAuthLoginResponse, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpStdioByConnectionIdData, PostBotsByBotIdMcpStdioByConnectionIdError, PostBotsByBotIdMcpStdioByConnectionIdResponse, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleResponse, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsResponse, PostBotsByBotIdToolsData, PostBotsByBotIdToolsError, PostBotsByBotIdToolsResponse, PostBotsByIdChannelByPlatformSendChatData, PostBotsByIdChannelByPlatformSendChatError, PostBotsByIdChannelByPlatformSendChatResponse, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendResponse, PostBotsData, PostBotsError, PostBotsResponse, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsResponse, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableResponse, PostModelsError, PostModelsResponse, PostProvidersData, PostProvidersError, PostProvidersResponse, PostUsersData, PostUsersError, PostUsersResponse, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformResponse, PutBotsByIdData, PutBotsByIdError, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersResponse, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerResponse, PutBotsByIdResponse, PutModelsByIdData, PutModelsByIdError, PutModelsByIdResponse, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdResponse, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdResponse, PutUsersByIdData, PutUsersByIdError, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdResponse, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformResponse, PutUsersMeData, PutUsersMeError, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMeResponse } from '../types.gen'; /** * Login @@ -93,38 +93,6 @@ export const postBotsMutation = (options?: Partial>): UseM } }); -/** - * Chat with AI - * - * Send a chat message and get a response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChatMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdChatError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdChat({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Stream chat with AI - * - * Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChatStreamMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdChatStreamError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdChatStream({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - /** * Delete MCP container for bot */ @@ -170,160 +138,6 @@ export const postBotsByBotIdContainerMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdContainerFsError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdContainerFs({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFs', options); - -/** - * List files for a bot - * - * List entries under a relative path - */ -export const getBotsByBotIdContainerFsQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFs({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * MCP filesystem tools (JSON-RPC) - * - * Forwards MCP JSON-RPC requests to the MCP server inside the container. - * Required: - * - container task is running - * - container has data mount (default /data) bound to /users/ - * - container image contains the "mcp" binary - * Auth: Bearer JWT is used to determine user_id (sub or user_id). - * Paths must be relative (no leading slash) and must not contain "..". - * - * Example: tools/list - * {"jsonrpc":"2.0","id":1,"method":"tools/list"} - * - * Example: tools/call (fs.read) - * {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}} - */ -export const postBotsByBotIdContainerFsMcpMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsMcpError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsMcp({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Create a directory - */ -export const postBotsByBotIdContainerFsDirMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsDirError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsDir({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsFileQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFsFile', options); - -/** - * Read file content - */ -export const getBotsByBotIdContainerFsFileQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsFileQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFsFile({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Create or overwrite a file - */ -export const postBotsByBotIdContainerFsFileMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsFileError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsFile({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsStatQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFsStat', options); - -/** - * Get file or directory metadata - */ -export const getBotsByBotIdContainerFsStatQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsStatQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFsStat({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Upload a file - */ -export const postBotsByBotIdContainerFsUploadMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsUploadError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsUpload({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsUsageQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFsUsage', options); - -/** - * Get usage under a path - */ -export const getBotsByBotIdContainerFsUsageQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsUsageQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFsUsage({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - /** * Delete skills from data directory */ @@ -428,100 +242,14 @@ export const postBotsByBotIdContainerStopMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdHistoryError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdHistory({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdHistoryQueryKey = (options: Options) => createQueryKey('getBotsByBotIdHistory', options); - -/** - * List history records - * - * List history records for current user - */ -export const getBotsByBotIdHistoryQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdHistoryQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdHistory({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Create history record - * - * Create a history record for current user - */ -export const postBotsByBotIdHistoryMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdHistoryError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdHistory({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Delete history record - * - * Delete a history record by ID (must belong to current user) - */ -export const deleteBotsByBotIdHistoryByIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdHistoryByIdError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdHistoryById({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdHistoryByIdQueryKey = (options: Options) => createQueryKey('getBotsByBotIdHistoryById', options); - -/** - * Get history record - * - * Get a history record by ID (must belong to current user) - */ -export const getBotsByBotIdHistoryByIdQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdHistoryByIdQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdHistoryById({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -export const getBotsByBotIdMcpQueryKey = (options: Options) => createQueryKey('getBotsByBotIdMcp', options); +export const getBotsByBotIdMcpQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdMcp', options); /** * List MCP connections * * List MCP connections for a bot */ -export const getBotsByBotIdMcpQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdMcpQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdMcpQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdMcp({ @@ -570,9 +298,9 @@ export const postBotsByBotIdMcpStdioMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMcpStdioBySessionIdError> => ({ +export const postBotsByBotIdMcpStdioByConnectionIdMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMcpStdioByConnectionIdError> => ({ mutation: async (vars) => { - const { data } = await postBotsByBotIdMcpStdioBySessionId({ + const { data } = await postBotsByBotIdMcpStdioByConnectionId({ ...options, ...vars, throwOnError: true @@ -632,148 +360,14 @@ export const putBotsByBotIdMcpByIdMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemoryAddError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemoryAdd({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Embed and upsert memory - * - * Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryEmbedMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemoryEmbedError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemoryEmbed({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Delete memories - * - * Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemoriesMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdMemoryMemoriesError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdMemoryMemories({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdMemoryMemoriesQueryKey = (options: Options) => createQueryKey('getBotsByBotIdMemoryMemories', options); - -/** - * List memories - * - * List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemoriesQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdMemoryMemoriesQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdMemoryMemories({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Delete memory - * - * Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemoriesByMemoryIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdMemoryMemoriesByMemoryIdError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdMemoryMemoriesByMemoryId({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdMemoryMemoriesByMemoryIdQueryKey = (options: Options) => createQueryKey('getBotsByBotIdMemoryMemoriesByMemoryId', options); - -/** - * Get memory - * - * Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemoriesByMemoryIdQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdMemoryMemoriesByMemoryIdQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdMemoryMemoriesByMemoryId({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Search memories - * - * Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemorySearchMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemorySearchError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemorySearch({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Update memory - * - * Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryUpdateMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemoryUpdateError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemoryUpdate({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdScheduleQueryKey = (options: Options) => createQueryKey('getBotsByBotIdSchedule', options); +export const getBotsByBotIdScheduleQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdSchedule', options); /** * List schedules * * List schedules for current user */ -export const getBotsByBotIdScheduleQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdScheduleQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdScheduleQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdSchedule({ @@ -868,14 +462,14 @@ export const deleteBotsByBotIdSettingsMutation = (options?: Partial) => createQueryKey('getBotsByBotIdSettings', options); +export const getBotsByBotIdSettingsQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdSettings', options); /** * Get user settings * * Get agent settings for current user */ -export const getBotsByBotIdSettingsQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdSettingsQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdSettingsQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdSettings({ @@ -919,14 +513,14 @@ export const putBotsByBotIdSettingsMutation = (options?: Partial) => createQueryKey('getBotsByBotIdSubagents', options); +export const getBotsByBotIdSubagentsQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdSubagents', options); /** * List subagents * * List subagents for current user */ -export const getBotsByBotIdSubagentsQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdSubagentsQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdSubagentsQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdSubagents({ @@ -1091,12 +685,28 @@ export const putBotsByBotIdSubagentsByIdSkillsMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdToolsError> => ({ + mutation: async (vars) => { + const { data } = await postBotsByBotIdTools({ + ...options, + ...vars, + throwOnError: true + }); + return data; + } +}); + /** * Delete bot * * Delete a bot user (owner/admin only) */ -export const deleteBotsByIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByIdError> => ({ +export const deleteBotsByIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByIdError> => ({ mutation: async (vars) => { const { data } = await deleteBotsById({ ...options, @@ -1198,9 +808,9 @@ export const postBotsByIdChannelByPlatformSendMutation = (options?: Partial>): UseMutationOptions, PostBotsByIdChannelByPlatformSendSessionError> => ({ +export const postBotsByIdChannelByPlatformSendChatMutation = (options?: Partial>): UseMutationOptions, PostBotsByIdChannelByPlatformSendChatError> => ({ mutation: async (vars) => { - const { data } = await postBotsByIdChannelByPlatformSendSession({ + const { data } = await postBotsByIdChannelByPlatformSendChat({ ...options, ...vars, throwOnError: true @@ -1209,6 +819,25 @@ export const postBotsByIdChannelByPlatformSendSessionMutation = (options?: Parti } }); +export const getBotsByIdChecksQueryKey = (options: Options) => createQueryKey('getBotsByIdChecks', options); + +/** + * List bot runtime checks + * + * Evaluate bot attached resource checks in runtime + */ +export const getBotsByIdChecksQuery = defineQueryOptions((options: Options) => ({ + key: getBotsByIdChecksQueryKey(options), + query: async (context) => { + const { data } = await getBotsByIdChecks({ + ...options, + ...context, + throwOnError: true + }); + return data; + } +})); + export const getBotsByIdMembersQueryKey = (options: Options) => createQueryKey('getBotsByIdMembers', options); /** @@ -1750,6 +1379,25 @@ export const putUsersMeChannelsByPlatformMutation = (options?: Partial) => createQueryKey('getUsersMeIdentities', options); + +/** + * List current user's channel identities + * + * List all channel identities linked to current user + */ +export const getUsersMeIdentitiesQuery = defineQueryOptions((options?: Options) => ({ + key: getUsersMeIdentitiesQueryKey(options), + query: async (context) => { + const { data } = await getUsersMeIdentities({ + ...options, + ...context, + throwOnError: true + }); + return data; + } +})); + /** * Update current user password * diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 904fad11..bcb5c2ea 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -1,4 +1,4 @@ // This file is auto-generated by @hey-api/openapi-ts -export { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerFs, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdHistory, deleteBotsByBotIdHistoryById, deleteBotsByBotIdMcpById, deleteBotsByBotIdMemoryMemories, deleteBotsByBotIdMemoryMemoriesByMemoryId, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerFs, getBotsByBotIdContainerFsFile, getBotsByBotIdContainerFsStat, getBotsByBotIdContainerFsUsage, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdHistory, getBotsByBotIdHistoryById, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdMemoryMemories, getBotsByBotIdMemoryMemoriesByMemoryId, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, type Options, postAuthLogin, postBots, postBotsByBotIdChat, postBotsByBotIdChatStream, postBotsByBotIdContainer, postBotsByBotIdContainerFsDir, postBotsByBotIdContainerFsFile, postBotsByBotIdContainerFsMcp, postBotsByBotIdContainerFsUpload, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdHistory, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioBySessionId, postBotsByBotIdMemoryAdd, postBotsByBotIdMemoryEmbed, postBotsByBotIdMemorySearch, postBotsByBotIdMemoryUpdate, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendSession, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from './sdk.gen'; -export type { BotsBot, BotsBotMember, BotsCreateBotRequest, BotsListBotsResponse, BotsListMembersResponse, BotsTransferBotRequest, BotsUpdateBotRequest, BotsUpsertMemberRequest, ChannelAction, ChannelAttachment, ChannelAttachmentType, ChannelChannelCapabilities, ChannelChannelConfig, ChannelChannelUserBinding, ChannelConfigSchema, ChannelFieldSchema, ChannelFieldType, ChannelMessage, ChannelMessageFormat, ChannelMessagePart, ChannelMessagePartType, ChannelMessageTextStyle, ChannelReplyRef, ChannelSendRequest, ChannelTargetHint, ChannelTargetSpec, ChannelThreadRef, ChannelUpsertConfigRequest, ChannelUpsertUserConfigRequest, ChatChatRequest, ChatChatResponse, ChatModelMessage, ChatToolCall, ChatToolCallFunction, ClientOptions, DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerFsData, DeleteBotsByBotIdContainerFsError, DeleteBotsByBotIdContainerFsErrors, DeleteBotsByBotIdContainerFsResponse, DeleteBotsByBotIdContainerFsResponses, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdHistoryByIdData, DeleteBotsByBotIdHistoryByIdError, DeleteBotsByBotIdHistoryByIdErrors, DeleteBotsByBotIdHistoryByIdResponses, DeleteBotsByBotIdHistoryData, DeleteBotsByBotIdHistoryError, DeleteBotsByBotIdHistoryErrors, DeleteBotsByBotIdHistoryResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdMemoryMemoriesByMemoryIdData, DeleteBotsByBotIdMemoryMemoriesByMemoryIdError, DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponse, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses, DeleteBotsByBotIdMemoryMemoriesData, DeleteBotsByBotIdMemoryMemoriesError, DeleteBotsByBotIdMemoryMemoriesErrors, DeleteBotsByBotIdMemoryMemoriesResponse, DeleteBotsByBotIdMemoryMemoriesResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdError, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerError, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerFsData, GetBotsByBotIdContainerFsError, GetBotsByBotIdContainerFsErrors, GetBotsByBotIdContainerFsFileData, GetBotsByBotIdContainerFsFileError, GetBotsByBotIdContainerFsFileErrors, GetBotsByBotIdContainerFsFileResponse, GetBotsByBotIdContainerFsFileResponses, GetBotsByBotIdContainerFsResponse, GetBotsByBotIdContainerFsResponses, GetBotsByBotIdContainerFsStatData, GetBotsByBotIdContainerFsStatError, GetBotsByBotIdContainerFsStatErrors, GetBotsByBotIdContainerFsStatResponse, GetBotsByBotIdContainerFsStatResponses, GetBotsByBotIdContainerFsUsageData, GetBotsByBotIdContainerFsUsageError, GetBotsByBotIdContainerFsUsageErrors, GetBotsByBotIdContainerFsUsageResponse, GetBotsByBotIdContainerFsUsageResponses, GetBotsByBotIdContainerResponse, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsError, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponse, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponse, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdHistoryByIdData, GetBotsByBotIdHistoryByIdError, GetBotsByBotIdHistoryByIdErrors, GetBotsByBotIdHistoryByIdResponse, GetBotsByBotIdHistoryByIdResponses, GetBotsByBotIdHistoryData, GetBotsByBotIdHistoryError, GetBotsByBotIdHistoryErrors, GetBotsByBotIdHistoryResponse, GetBotsByBotIdHistoryResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdError, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponse, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpError, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponse, GetBotsByBotIdMcpResponses, GetBotsByBotIdMemoryMemoriesByMemoryIdData, GetBotsByBotIdMemoryMemoriesByMemoryIdError, GetBotsByBotIdMemoryMemoriesByMemoryIdErrors, GetBotsByBotIdMemoryMemoriesByMemoryIdResponse, GetBotsByBotIdMemoryMemoriesByMemoryIdResponses, GetBotsByBotIdMemoryMemoriesData, GetBotsByBotIdMemoryMemoriesError, GetBotsByBotIdMemoryMemoriesErrors, GetBotsByBotIdMemoryMemoriesResponse, GetBotsByBotIdMemoryMemoriesResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdError, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponse, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleError, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponse, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsError, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponse, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextError, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponse, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdError, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponse, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsError, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponse, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsError, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponse, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformError, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponse, GetBotsByIdChannelByPlatformResponses, GetBotsByIdData, GetBotsByIdError, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersError, GetBotsByIdMembersErrors, GetBotsByIdMembersResponse, GetBotsByIdMembersResponses, GetBotsByIdResponse, GetBotsByIdResponses, GetBotsData, GetBotsError, GetBotsErrors, GetBotsResponse, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformError, GetChannelsByPlatformErrors, GetChannelsByPlatformResponse, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsError, GetChannelsErrors, GetChannelsResponse, GetChannelsResponses, GetModelsByIdData, GetModelsByIdError, GetModelsByIdErrors, GetModelsByIdResponse, GetModelsByIdResponses, GetModelsCountData, GetModelsCountError, GetModelsCountErrors, GetModelsCountResponse, GetModelsCountResponses, GetModelsData, GetModelsError, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdError, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponse, GetModelsModelByModelIdResponses, GetModelsResponse, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdError, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsError, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponse, GetProvidersByIdModelsResponses, GetProvidersByIdResponse, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountError, GetProvidersCountErrors, GetProvidersCountResponse, GetProvidersCountResponses, GetProvidersData, GetProvidersError, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameError, GetProvidersNameByNameErrors, GetProvidersNameByNameResponse, GetProvidersNameByNameResponses, GetProvidersResponse, GetProvidersResponses, GetUsersByIdData, GetUsersByIdError, GetUsersByIdErrors, GetUsersByIdResponse, GetUsersByIdResponses, GetUsersData, GetUsersError, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformError, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponse, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeError, GetUsersMeErrors, GetUsersMeResponse, GetUsersMeResponses, GetUsersResponse, GetUsersResponses, GithubComMemohaiMemohInternalMcpConnection, HandlersChannelMeta, HandlersCreateContainerRequest, HandlersCreateContainerResponse, HandlersCreateSnapshotRequest, HandlersCreateSnapshotResponse, HandlersEmbeddingsInput, HandlersEmbeddingsRequest, HandlersEmbeddingsResponse, HandlersEmbeddingsUsage, HandlersEnableModelRequest, HandlersErrorResponse, HandlersFsDeleteResponse, HandlersFsListResponse, HandlersFsMkdirRequest, HandlersFsReadResponse, HandlersFsRestEntry, HandlersFsStatResponse, HandlersFsUsageResponse, HandlersFsWriteRequest, HandlersFsWriteResponse, HandlersGetContainerResponse, HandlersListSnapshotsResponse, HandlersLoginRequest, HandlersLoginResponse, HandlersMcpStdioRequest, HandlersMcpStdioResponse, HandlersMemoryAddPayload, HandlersMemoryDeleteAllPayload, HandlersMemoryEmbedUpsertPayload, HandlersMemorySearchPayload, HandlersSkillItem, HandlersSkillsDeleteRequest, HandlersSkillsOpResponse, HandlersSkillsResponse, HandlersSkillsUpsertRequest, HandlersSnapshotInfo, HistoryCreateRequest, HistoryListResponse, HistoryRecord, McpListResponse, McpUpsertRequest, MemoryDeleteResponse, MemoryEmbedInput, MemoryEmbedUpsertResponse, MemoryMemoryItem, MemoryMessage, MemorySearchResponse, MemoryUpdateRequest, ModelsAddRequest, ModelsAddResponse, ModelsCountResponse, ModelsGetResponse, ModelsModelType, ModelsUpdateRequest, PostAuthLoginData, PostAuthLoginError, PostAuthLoginErrors, PostAuthLoginResponse, PostAuthLoginResponses, PostBotsByBotIdChatData, PostBotsByBotIdChatError, PostBotsByBotIdChatErrors, PostBotsByBotIdChatResponse, PostBotsByBotIdChatResponses, PostBotsByBotIdChatStreamData, PostBotsByBotIdChatStreamError, PostBotsByBotIdChatStreamErrors, PostBotsByBotIdChatStreamResponse, PostBotsByBotIdChatStreamResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerFsDirData, PostBotsByBotIdContainerFsDirError, PostBotsByBotIdContainerFsDirErrors, PostBotsByBotIdContainerFsDirResponse, PostBotsByBotIdContainerFsDirResponses, PostBotsByBotIdContainerFsFileData, PostBotsByBotIdContainerFsFileError, PostBotsByBotIdContainerFsFileErrors, PostBotsByBotIdContainerFsFileResponse, PostBotsByBotIdContainerFsFileResponses, PostBotsByBotIdContainerFsMcpData, PostBotsByBotIdContainerFsMcpError, PostBotsByBotIdContainerFsMcpErrors, PostBotsByBotIdContainerFsMcpResponse, PostBotsByBotIdContainerFsMcpResponses, PostBotsByBotIdContainerFsUploadData, PostBotsByBotIdContainerFsUploadError, PostBotsByBotIdContainerFsUploadErrors, PostBotsByBotIdContainerFsUploadResponse, PostBotsByBotIdContainerFsUploadResponses, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdHistoryData, PostBotsByBotIdHistoryError, PostBotsByBotIdHistoryErrors, PostBotsByBotIdHistoryResponse, PostBotsByBotIdHistoryResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioBySessionIdData, PostBotsByBotIdMcpStdioBySessionIdError, PostBotsByBotIdMcpStdioBySessionIdErrors, PostBotsByBotIdMcpStdioBySessionIdResponse, PostBotsByBotIdMcpStdioBySessionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdMemoryAddData, PostBotsByBotIdMemoryAddError, PostBotsByBotIdMemoryAddErrors, PostBotsByBotIdMemoryAddResponse, PostBotsByBotIdMemoryAddResponses, PostBotsByBotIdMemoryEmbedData, PostBotsByBotIdMemoryEmbedError, PostBotsByBotIdMemoryEmbedErrors, PostBotsByBotIdMemoryEmbedResponse, PostBotsByBotIdMemoryEmbedResponses, PostBotsByBotIdMemorySearchData, PostBotsByBotIdMemorySearchError, PostBotsByBotIdMemorySearchErrors, PostBotsByBotIdMemorySearchResponse, PostBotsByBotIdMemorySearchResponses, PostBotsByBotIdMemoryUpdateData, PostBotsByBotIdMemoryUpdateError, PostBotsByBotIdMemoryUpdateErrors, PostBotsByBotIdMemoryUpdateResponse, PostBotsByBotIdMemoryUpdateResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponse, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponse, PostBotsByBotIdSubagentsResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponse, PostBotsByIdChannelByPlatformSendResponses, PostBotsByIdChannelByPlatformSendSessionData, PostBotsByIdChannelByPlatformSendSessionError, PostBotsByIdChannelByPlatformSendSessionErrors, PostBotsByIdChannelByPlatformSendSessionResponse, PostBotsByIdChannelByPlatformSendSessionResponses, PostBotsData, PostBotsError, PostBotsErrors, PostBotsResponse, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsErrors, PostEmbeddingsResponse, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableErrors, PostModelsEnableResponse, PostModelsEnableResponses, PostModelsError, PostModelsErrors, PostModelsResponse, PostModelsResponses, PostProvidersData, PostProvidersError, PostProvidersErrors, PostProvidersResponse, PostProvidersResponses, PostUsersData, PostUsersError, PostUsersErrors, PostUsersResponse, PostUsersResponses, ProvidersClientType, ProvidersCountResponse, ProvidersCreateRequest, ProvidersGetResponse, ProvidersUpdateRequest, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponse, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdError, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersErrors, PutBotsByIdMembersResponse, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponse, PutBotsByIdOwnerResponses, PutBotsByIdResponse, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdError, PutModelsByIdErrors, PutModelsByIdResponse, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponse, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdErrors, PutProvidersByIdResponse, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdError, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponse, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponse, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeError, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponse, PutUsersMeResponses, ScheduleCreateRequest, ScheduleListResponse, ScheduleNullableInt, ScheduleSchedule, ScheduleUpdateRequest, SettingsSettings, SettingsUpsertRequest, SubagentAddSkillsRequest, SubagentContextResponse, SubagentCreateRequest, SubagentListResponse, SubagentSkillsResponse, SubagentSubagent, SubagentUpdateContextRequest, SubagentUpdateRequest, SubagentUpdateSkillsRequest, UsersCreateUserRequest, UsersListUsersResponse, UsersResetPasswordRequest, UsersUpdatePasswordRequest, UsersUpdateProfileRequest, UsersUpdateUserRequest, UsersUser } from './types.gen'; +export { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdMcpById, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdChecks, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, getUsersMeIdentities, type Options, postAuthLogin, postBots, postBotsByBotIdContainer, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioByConnectionId, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByBotIdTools, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendChat, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from './sdk.gen'; +export type { AccountsAccount, AccountsCreateAccountRequest, AccountsListAccountsResponse, AccountsResetPasswordRequest, AccountsUpdateAccountRequest, AccountsUpdatePasswordRequest, AccountsUpdateProfileRequest, BotsBot, BotsBotCheck, BotsBotMember, BotsCreateBotRequest, BotsListBotsResponse, BotsListChecksResponse, BotsListMembersResponse, BotsTransferBotRequest, BotsUpdateBotRequest, BotsUpsertMemberRequest, ChannelAction, ChannelAttachment, ChannelAttachmentType, ChannelChannelCapabilities, ChannelChannelConfig, ChannelChannelIdentityBinding, ChannelConfigSchema, ChannelFieldSchema, ChannelFieldType, ChannelMessage, ChannelMessageFormat, ChannelMessagePart, ChannelMessagePartType, ChannelMessageTextStyle, ChannelReplyRef, ChannelSendRequest, ChannelTargetHint, ChannelTargetSpec, ChannelThreadRef, ChannelUpsertChannelIdentityConfigRequest, ChannelUpsertConfigRequest, ClientOptions, DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponse, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdError, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerError, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerResponse, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsError, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponse, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponse, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdError, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponse, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpError, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponse, GetBotsByBotIdMcpResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdError, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponse, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleError, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponse, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsError, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponse, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextError, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponse, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdError, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponse, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsError, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponse, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsError, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponse, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformError, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponse, GetBotsByIdChannelByPlatformResponses, GetBotsByIdChecksData, GetBotsByIdChecksError, GetBotsByIdChecksErrors, GetBotsByIdChecksResponse, GetBotsByIdChecksResponses, GetBotsByIdData, GetBotsByIdError, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersError, GetBotsByIdMembersErrors, GetBotsByIdMembersResponse, GetBotsByIdMembersResponses, GetBotsByIdResponse, GetBotsByIdResponses, GetBotsData, GetBotsError, GetBotsErrors, GetBotsResponse, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformError, GetChannelsByPlatformErrors, GetChannelsByPlatformResponse, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsError, GetChannelsErrors, GetChannelsResponse, GetChannelsResponses, GetModelsByIdData, GetModelsByIdError, GetModelsByIdErrors, GetModelsByIdResponse, GetModelsByIdResponses, GetModelsCountData, GetModelsCountError, GetModelsCountErrors, GetModelsCountResponse, GetModelsCountResponses, GetModelsData, GetModelsError, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdError, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponse, GetModelsModelByModelIdResponses, GetModelsResponse, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdError, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsError, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponse, GetProvidersByIdModelsResponses, GetProvidersByIdResponse, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountError, GetProvidersCountErrors, GetProvidersCountResponse, GetProvidersCountResponses, GetProvidersData, GetProvidersError, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameError, GetProvidersNameByNameErrors, GetProvidersNameByNameResponse, GetProvidersNameByNameResponses, GetProvidersResponse, GetProvidersResponses, GetUsersByIdData, GetUsersByIdError, GetUsersByIdErrors, GetUsersByIdResponse, GetUsersByIdResponses, GetUsersData, GetUsersError, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformError, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponse, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeError, GetUsersMeErrors, GetUsersMeIdentitiesData, GetUsersMeIdentitiesError, GetUsersMeIdentitiesErrors, GetUsersMeIdentitiesResponse, GetUsersMeIdentitiesResponses, GetUsersMeResponse, GetUsersMeResponses, GetUsersResponse, GetUsersResponses, GithubComMemohaiMemohInternalMcpConnection, HandlersChannelMeta, HandlersCreateContainerRequest, HandlersCreateContainerResponse, HandlersCreateSnapshotRequest, HandlersCreateSnapshotResponse, HandlersEmbeddingsInput, HandlersEmbeddingsRequest, HandlersEmbeddingsResponse, HandlersEmbeddingsUsage, HandlersEnableModelRequest, HandlersErrorResponse, HandlersGetContainerResponse, HandlersListMyIdentitiesResponse, HandlersListSnapshotsResponse, HandlersLoginRequest, HandlersLoginResponse, HandlersMcpStdioRequest, HandlersMcpStdioResponse, HandlersSkillItem, HandlersSkillsDeleteRequest, HandlersSkillsOpResponse, HandlersSkillsResponse, HandlersSkillsUpsertRequest, HandlersSnapshotInfo, IdentitiesChannelIdentity, McpListResponse, McpUpsertRequest, ModelsAddRequest, ModelsAddResponse, ModelsCountResponse, ModelsGetResponse, ModelsModelType, ModelsUpdateRequest, PostAuthLoginData, PostAuthLoginError, PostAuthLoginErrors, PostAuthLoginResponse, PostAuthLoginResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioByConnectionIdData, PostBotsByBotIdMcpStdioByConnectionIdError, PostBotsByBotIdMcpStdioByConnectionIdErrors, PostBotsByBotIdMcpStdioByConnectionIdResponse, PostBotsByBotIdMcpStdioByConnectionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponse, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponse, PostBotsByBotIdSubagentsResponses, PostBotsByBotIdToolsData, PostBotsByBotIdToolsError, PostBotsByBotIdToolsErrors, PostBotsByBotIdToolsResponse, PostBotsByBotIdToolsResponses, PostBotsByIdChannelByPlatformSendChatData, PostBotsByIdChannelByPlatformSendChatError, PostBotsByIdChannelByPlatformSendChatErrors, PostBotsByIdChannelByPlatformSendChatResponse, PostBotsByIdChannelByPlatformSendChatResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponse, PostBotsByIdChannelByPlatformSendResponses, PostBotsData, PostBotsError, PostBotsErrors, PostBotsResponse, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsErrors, PostEmbeddingsResponse, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableErrors, PostModelsEnableResponse, PostModelsEnableResponses, PostModelsError, PostModelsErrors, PostModelsResponse, PostModelsResponses, PostProvidersData, PostProvidersError, PostProvidersErrors, PostProvidersResponse, PostProvidersResponses, PostUsersData, PostUsersError, PostUsersErrors, PostUsersResponse, PostUsersResponses, ProvidersClientType, ProvidersCountResponse, ProvidersCreateRequest, ProvidersGetResponse, ProvidersUpdateRequest, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponse, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdError, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersErrors, PutBotsByIdMembersResponse, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponse, PutBotsByIdOwnerResponses, PutBotsByIdResponse, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdError, PutModelsByIdErrors, PutModelsByIdResponse, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponse, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdErrors, PutProvidersByIdResponse, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdError, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponse, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponse, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeError, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponse, PutUsersMeResponses, ScheduleCreateRequest, ScheduleListResponse, ScheduleNullableInt, ScheduleSchedule, ScheduleUpdateRequest, SettingsSettings, SettingsUpsertRequest, SubagentAddSkillsRequest, SubagentContextResponse, SubagentCreateRequest, SubagentListResponse, SubagentSkillsResponse, SubagentSubagent, SubagentUpdateContextRequest, SubagentUpdateRequest, SubagentUpdateSkillsRequest } from './types.gen'; diff --git a/packages/sdk/src/sdk.gen.ts b/packages/sdk/src/sdk.gen.ts index 44e22042..91a60eae 100644 --- a/packages/sdk/src/sdk.gen.ts +++ b/packages/sdk/src/sdk.gen.ts @@ -2,7 +2,7 @@ import type { Client, Options as Options2, TDataShape } from './client'; import { client } from './client.gen'; -import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerFsData, DeleteBotsByBotIdContainerFsErrors, DeleteBotsByBotIdContainerFsResponses, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdHistoryByIdData, DeleteBotsByBotIdHistoryByIdErrors, DeleteBotsByBotIdHistoryByIdResponses, DeleteBotsByBotIdHistoryData, DeleteBotsByBotIdHistoryErrors, DeleteBotsByBotIdHistoryResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdMemoryMemoriesByMemoryIdData, DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses, DeleteBotsByBotIdMemoryMemoriesData, DeleteBotsByBotIdMemoryMemoriesErrors, DeleteBotsByBotIdMemoryMemoriesResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerFsData, GetBotsByBotIdContainerFsErrors, GetBotsByBotIdContainerFsFileData, GetBotsByBotIdContainerFsFileErrors, GetBotsByBotIdContainerFsFileResponses, GetBotsByBotIdContainerFsResponses, GetBotsByBotIdContainerFsStatData, GetBotsByBotIdContainerFsStatErrors, GetBotsByBotIdContainerFsStatResponses, GetBotsByBotIdContainerFsUsageData, GetBotsByBotIdContainerFsUsageErrors, GetBotsByBotIdContainerFsUsageResponses, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdHistoryByIdData, GetBotsByBotIdHistoryByIdErrors, GetBotsByBotIdHistoryByIdResponses, GetBotsByBotIdHistoryData, GetBotsByBotIdHistoryErrors, GetBotsByBotIdHistoryResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponses, GetBotsByBotIdMemoryMemoriesByMemoryIdData, GetBotsByBotIdMemoryMemoriesByMemoryIdErrors, GetBotsByBotIdMemoryMemoriesByMemoryIdResponses, GetBotsByBotIdMemoryMemoriesData, GetBotsByBotIdMemoryMemoriesErrors, GetBotsByBotIdMemoryMemoriesResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponses, GetBotsByIdData, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersErrors, GetBotsByIdMembersResponses, GetBotsByIdResponses, GetBotsData, GetBotsErrors, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformErrors, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsErrors, GetChannelsResponses, GetModelsByIdData, GetModelsByIdErrors, GetModelsByIdResponses, GetModelsCountData, GetModelsCountErrors, GetModelsCountResponses, GetModelsData, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponses, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponses, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountErrors, GetProvidersCountResponses, GetProvidersData, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameErrors, GetProvidersNameByNameResponses, GetProvidersResponses, GetUsersByIdData, GetUsersByIdErrors, GetUsersByIdResponses, GetUsersData, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeErrors, GetUsersMeResponses, GetUsersResponses, PostAuthLoginData, PostAuthLoginErrors, PostAuthLoginResponses, PostBotsByBotIdChatData, PostBotsByBotIdChatErrors, PostBotsByBotIdChatResponses, PostBotsByBotIdChatStreamData, PostBotsByBotIdChatStreamErrors, PostBotsByBotIdChatStreamResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerFsDirData, PostBotsByBotIdContainerFsDirErrors, PostBotsByBotIdContainerFsDirResponses, PostBotsByBotIdContainerFsFileData, PostBotsByBotIdContainerFsFileErrors, PostBotsByBotIdContainerFsFileResponses, PostBotsByBotIdContainerFsMcpData, PostBotsByBotIdContainerFsMcpErrors, PostBotsByBotIdContainerFsMcpResponses, PostBotsByBotIdContainerFsUploadData, PostBotsByBotIdContainerFsUploadErrors, PostBotsByBotIdContainerFsUploadResponses, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdHistoryData, PostBotsByBotIdHistoryErrors, PostBotsByBotIdHistoryResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioBySessionIdData, PostBotsByBotIdMcpStdioBySessionIdErrors, PostBotsByBotIdMcpStdioBySessionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdMemoryAddData, PostBotsByBotIdMemoryAddErrors, PostBotsByBotIdMemoryAddResponses, PostBotsByBotIdMemoryEmbedData, PostBotsByBotIdMemoryEmbedErrors, PostBotsByBotIdMemoryEmbedResponses, PostBotsByBotIdMemorySearchData, PostBotsByBotIdMemorySearchErrors, PostBotsByBotIdMemorySearchResponses, PostBotsByBotIdMemoryUpdateData, PostBotsByBotIdMemoryUpdateErrors, PostBotsByBotIdMemoryUpdateResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponses, PostBotsByIdChannelByPlatformSendSessionData, PostBotsByIdChannelByPlatformSendSessionErrors, PostBotsByIdChannelByPlatformSendSessionResponses, PostBotsData, PostBotsErrors, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsErrors, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableErrors, PostModelsEnableResponses, PostModelsErrors, PostModelsResponses, PostProvidersData, PostProvidersErrors, PostProvidersResponses, PostUsersData, PostUsersErrors, PostUsersResponses, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersErrors, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponses, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdErrors, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdErrors, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponses } from './types.gen'; +import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponses, GetBotsByIdChecksData, GetBotsByIdChecksErrors, GetBotsByIdChecksResponses, GetBotsByIdData, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersErrors, GetBotsByIdMembersResponses, GetBotsByIdResponses, GetBotsData, GetBotsErrors, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformErrors, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsErrors, GetChannelsResponses, GetModelsByIdData, GetModelsByIdErrors, GetModelsByIdResponses, GetModelsCountData, GetModelsCountErrors, GetModelsCountResponses, GetModelsData, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponses, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponses, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountErrors, GetProvidersCountResponses, GetProvidersData, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameErrors, GetProvidersNameByNameResponses, GetProvidersResponses, GetUsersByIdData, GetUsersByIdErrors, GetUsersByIdResponses, GetUsersData, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeErrors, GetUsersMeIdentitiesData, GetUsersMeIdentitiesErrors, GetUsersMeIdentitiesResponses, GetUsersMeResponses, GetUsersResponses, PostAuthLoginData, PostAuthLoginErrors, PostAuthLoginResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioByConnectionIdData, PostBotsByBotIdMcpStdioByConnectionIdErrors, PostBotsByBotIdMcpStdioByConnectionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponses, PostBotsByBotIdToolsData, PostBotsByBotIdToolsErrors, PostBotsByBotIdToolsResponses, PostBotsByIdChannelByPlatformSendChatData, PostBotsByIdChannelByPlatformSendChatErrors, PostBotsByIdChannelByPlatformSendChatResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponses, PostBotsData, PostBotsErrors, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsErrors, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableErrors, PostModelsEnableResponses, PostModelsErrors, PostModelsResponses, PostProvidersData, PostProvidersErrors, PostProvidersResponses, PostUsersData, PostUsersErrors, PostUsersResponses, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersErrors, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponses, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdErrors, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdErrors, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponses } from './types.gen'; export type Options = Options2 & { /** @@ -53,34 +53,6 @@ export const postBots = (options: Options< } }); -/** - * Chat with AI - * - * Send a chat message and get a response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChat = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/chat', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Stream chat with AI - * - * Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChatStream = (options: Options) => (options.client ?? client).sse.post({ - url: '/bots/{bot_id}/chat/stream', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - /** * Delete MCP container for bot */ @@ -103,88 +75,6 @@ export const postBotsByBotIdContainer = (o } }); -/** - * Delete a file or directory - */ -export const deleteBotsByBotIdContainerFs = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/container/fs', ...options }); - -/** - * List files for a bot - * - * List entries under a relative path - */ -export const getBotsByBotIdContainerFs = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs', ...options }); - -/** - * MCP filesystem tools (JSON-RPC) - * - * Forwards MCP JSON-RPC requests to the MCP server inside the container. - * Required: - * - container task is running - * - container has data mount (default /data) bound to /users/ - * - container image contains the "mcp" binary - * Auth: Bearer JWT is used to determine user_id (sub or user_id). - * Paths must be relative (no leading slash) and must not contain "..". - * - * Example: tools/list - * {"jsonrpc":"2.0","id":1,"method":"tools/list"} - * - * Example: tools/call (fs.read) - * {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}} - */ -export const postBotsByBotIdContainerFsMcp = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/container/fs-mcp', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Create a directory - */ -export const postBotsByBotIdContainerFsDir = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/container/fs/dir', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Read file content - */ -export const getBotsByBotIdContainerFsFile = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs/file', ...options }); - -/** - * Create or overwrite a file - */ -export const postBotsByBotIdContainerFsFile = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/container/fs/file', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Get file or directory metadata - */ -export const getBotsByBotIdContainerFsStat = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs/stat', ...options }); - -/** - * Upload a file - */ -export const postBotsByBotIdContainerFsUpload = (options: Options) => (options.client ?? client).post({ url: '/bots/{bot_id}/container/fs/upload', ...options }); - -/** - * Get usage under a path - */ -export const getBotsByBotIdContainerFsUsage = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs/usage', ...options }); - /** * Delete skills from data directory */ @@ -241,54 +131,12 @@ export const postBotsByBotIdContainerStart = (options: Options) => (options.client ?? client).post({ url: '/bots/{bot_id}/container/stop', ...options }); -/** - * Delete all history records - * - * Delete all history records for current user - */ -export const deleteBotsByBotIdHistory = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/history', ...options }); - -/** - * List history records - * - * List history records for current user - */ -export const getBotsByBotIdHistory = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/history', ...options }); - -/** - * Create history record - * - * Create a history record for current user - */ -export const postBotsByBotIdHistory = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/history', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Delete history record - * - * Delete a history record by ID (must belong to current user) - */ -export const deleteBotsByBotIdHistoryById = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/history/{id}', ...options }); - -/** - * Get history record - * - * Get a history record by ID (must belong to current user) - */ -export const getBotsByBotIdHistoryById = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/history/{id}', ...options }); - /** * List MCP connections * * List MCP connections for a bot */ -export const getBotsByBotIdMcp = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/mcp', ...options }); +export const getBotsByBotIdMcp = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/mcp', ...options }); /** * Create MCP connection @@ -323,8 +171,8 @@ export const postBotsByBotIdMcpStdio = (op * * Proxies MCP JSON-RPC requests to a stdio MCP process in the container. */ -export const postBotsByBotIdMcpStdioBySessionId = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/mcp-stdio/{session_id}', +export const postBotsByBotIdMcpStdioByConnectionId = (options: Options) => (options.client ?? client).post({ + url: '/bots/{bot_id}/mcp-stdio/{connection_id}', ...options, headers: { 'Content-Type': 'application/json', @@ -360,103 +208,12 @@ export const putBotsByBotIdMcpById = (opti } }); -/** - * Add memory - * - * Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryAdd = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/add', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Embed and upsert memory - * - * Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryEmbed = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/embed', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Delete memories - * - * Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemories = (options: Options) => (options.client ?? client).delete({ - url: '/bots/{bot_id}/memory/memories', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * List memories - * - * List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemories = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/memory/memories', ...options }); - -/** - * Delete memory - * - * Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemoriesByMemoryId = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/memory/memories/{memoryId}', ...options }); - -/** - * Get memory - * - * Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemoriesByMemoryId = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/memory/memories/{memoryId}', ...options }); - -/** - * Search memories - * - * Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemorySearch = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/search', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Update memory - * - * Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryUpdate = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/update', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - /** * List schedules * * List schedules for current user */ -export const getBotsByBotIdSchedule = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/schedule', ...options }); +export const getBotsByBotIdSchedule = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/schedule', ...options }); /** * Create schedule @@ -505,14 +262,14 @@ export const putBotsByBotIdScheduleById = * * Remove agent settings for current user */ -export const deleteBotsByBotIdSettings = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/settings', ...options }); +export const deleteBotsByBotIdSettings = (options?: Options) => (options?.client ?? client).delete({ url: '/bots/{bot_id}/settings', ...options }); /** * Get user settings * * Get agent settings for current user */ -export const getBotsByBotIdSettings = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/settings', ...options }); +export const getBotsByBotIdSettings = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/settings', ...options }); /** * Update user settings @@ -547,7 +304,7 @@ export const putBotsByBotIdSettings = (opt * * List subagents for current user */ -export const getBotsByBotIdSubagents = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/subagents', ...options }); +export const getBotsByBotIdSubagents = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/subagents', ...options }); /** * Create subagent @@ -647,6 +404,20 @@ export const putBotsByBotIdSubagentsByIdSkills = (options: Options) => (options.client ?? client).post({ + url: '/bots/{bot_id}/tools', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options.headers + } +}); + /** * Delete bot * @@ -715,8 +486,8 @@ export const postBotsByIdChannelByPlatformSend = (options: Options) => (options.client ?? client).post({ - url: '/bots/{id}/channel/{platform}/send_session', +export const postBotsByIdChannelByPlatformSendChat = (options: Options) => (options.client ?? client).post({ + url: '/bots/{id}/channel/{platform}/send_chat', ...options, headers: { 'Content-Type': 'application/json', @@ -724,6 +495,13 @@ export const postBotsByIdChannelByPlatformSendSession = (options: Options) => (options.client ?? client).get({ url: '/bots/{id}/checks', ...options }); + /** * List bot members * @@ -1025,6 +803,13 @@ export const putUsersMeChannelsByPlatform = (options?: Options) => (options?.client ?? client).get({ url: '/users/me/identities', ...options }); + /** * Update current user password * diff --git a/packages/sdk/src/types.gen.ts b/packages/sdk/src/types.gen.ts index 9a91b061..53952883 100644 --- a/packages/sdk/src/types.gen.ts +++ b/packages/sdk/src/types.gen.ts @@ -4,18 +4,79 @@ export type ClientOptions = { baseUrl: string; }; +export type AccountsAccount = { + avatar_url?: string; + created_at?: string; + display_name?: string; + email?: string; + id?: string; + is_active?: boolean; + last_login_at?: string; + role?: string; + updated_at?: string; + username?: string; +}; + +export type AccountsCreateAccountRequest = { + avatar_url?: string; + display_name?: string; + email?: string; + is_active?: boolean; + password?: string; + role?: string; + username?: string; +}; + +export type AccountsListAccountsResponse = { + items?: Array; +}; + +export type AccountsResetPasswordRequest = { + new_password?: string; +}; + +export type AccountsUpdateAccountRequest = { + avatar_url?: string; + display_name?: string; + is_active?: boolean; + role?: string; +}; + +export type AccountsUpdatePasswordRequest = { + current_password?: string; + new_password?: string; +}; + +export type AccountsUpdateProfileRequest = { + avatar_url?: string; + display_name?: string; +}; + export type BotsBot = { avatar_url?: string; - created_at: string; - display_name: string; - id: string; - is_active: boolean; + check_issue_count?: number; + check_state?: string; + created_at?: string; + display_name?: string; + id?: string; + is_active?: boolean; metadata?: { [key: string]: unknown; }; - owner_user_id: string; - type: string; - updated_at: string; + owner_user_id?: string; + status?: string; + type?: string; + updated_at?: string; +}; + +export type BotsBotCheck = { + check_key?: string; + detail?: string; + metadata?: { + [key: string]: unknown; + }; + status?: string; + summary?: string; }; export type BotsBotMember = { @@ -36,7 +97,11 @@ export type BotsCreateBotRequest = { }; export type BotsListBotsResponse = { - items: Array; + items?: Array; +}; + +export type BotsListChecksResponse = { + items?: Array; }; export type BotsListMembersResponse = { @@ -77,7 +142,9 @@ export type ChannelAttachment = { }; mime?: string; name?: string; + platform_key?: string; size?: number; + source_platform?: string; thumbnail_url?: string; type?: ChannelAttachmentType; url?: string; @@ -125,7 +192,8 @@ export type ChannelChannelConfig = { verifiedAt?: string; }; -export type ChannelChannelUserBinding = { +export type ChannelChannelIdentityBinding = { + channelIdentityID?: string; channelType?: string; config?: { [key: string]: unknown; @@ -133,7 +201,6 @@ export type ChannelChannelUserBinding = { createdAt?: string; id?: string; updatedAt?: string; - userID?: string; }; export type ChannelConfigSchema = { @@ -171,6 +238,7 @@ export type ChannelMessage = { export type ChannelMessageFormat = 'plain' | 'markdown' | 'rich'; export type ChannelMessagePart = { + channel_identity_id?: string; emoji?: string; language?: string; metadata?: { @@ -180,7 +248,6 @@ export type ChannelMessagePart = { text?: string; type?: ChannelMessagePartType; url?: string; - user_id?: string; }; export type ChannelMessagePartType = 'text' | 'link' | 'code_block' | 'mention' | 'emoji'; @@ -193,9 +260,9 @@ export type ChannelReplyRef = { }; export type ChannelSendRequest = { + channel_identity_id?: string; message?: ChannelMessage; target?: string; - user_id?: string; }; export type ChannelTargetHint = { @@ -212,6 +279,12 @@ export type ChannelThreadRef = { id?: string; }; +export type ChannelUpsertChannelIdentityConfigRequest = { + config?: { + [key: string]: unknown; + }; +}; + export type ChannelUpsertConfigRequest = { credentials?: { [key: string]: unknown; @@ -227,76 +300,30 @@ export type ChannelUpsertConfigRequest = { verified_at?: string; }; -export type ChannelUpsertUserConfigRequest = { +export type GithubComMemohaiMemohInternalMcpConnection = { + active?: boolean; + bot_id?: string; config?: { [key: string]: unknown; }; -}; - -export type ChatChatRequest = { - allowed_actions?: Array; - channels?: Array; - current_channel?: string; - language?: string; - max_context_load_time?: number; - messages?: Array; - model?: string; - provider?: string; - query?: string; - skills?: Array; -}; - -export type ChatChatResponse = { - messages?: Array; - model?: string; - provider?: string; - skills?: Array; -}; - -export type ChatModelMessage = { - content?: Array; - name?: string; - role?: string; - tool_call_id?: string; - tool_calls?: Array; -}; - -export type ChatToolCall = { - function?: ChatToolCallFunction; + created_at?: string; id?: string; - type?: string; -}; - -export type ChatToolCallFunction = { - arguments?: string; name?: string; -}; - -export type GithubComMemohaiMemohInternalMcpConnection = { - active: boolean; - bot_id: string; - config: { - [key: string]: unknown; - }; - created_at: string; - id: string; - name: string; - type: string; - updated_at: string; + type?: string; + updated_at?: string; }; export type HandlersChannelMeta = { - capabilities: ChannelChannelCapabilities; - config_schema: ChannelConfigSchema; + capabilities?: ChannelChannelCapabilities; + config_schema?: ChannelConfigSchema; configless?: boolean; - display_name: string; + display_name?: string; target_spec?: ChannelTargetSpec; - type: string; + type?: string; user_config_schema?: ChannelConfigSchema; }; export type HandlersCreateContainerRequest = { - image?: string; snapshotter?: string; }; @@ -356,61 +383,6 @@ export type HandlersErrorResponse = { message?: string; }; -export type HandlersFsDeleteResponse = { - ok?: boolean; -}; - -export type HandlersFsListResponse = { - entries?: Array; - path?: string; -}; - -export type HandlersFsMkdirRequest = { - parents?: boolean; - path?: string; -}; - -export type HandlersFsReadResponse = { - content?: string; - mod_time?: string; - mode?: number; - path?: string; - size?: number; -}; - -export type HandlersFsRestEntry = { - is_dir?: boolean; - mod_time?: string; - mode?: number; - path?: string; - size?: number; -}; - -export type HandlersFsStatResponse = { - is_dir?: boolean; - mod_time?: string; - mode?: number; - path?: string; - size?: number; -}; - -export type HandlersFsUsageResponse = { - dir_count?: number; - file_count?: number; - path?: string; - total_bytes?: number; -}; - -export type HandlersFsWriteRequest = { - content?: string; - overwrite?: boolean; - path?: string; -}; - -export type HandlersFsWriteResponse = { - ok?: boolean; -}; - export type HandlersGetContainerResponse = { container_id?: string; container_path?: string; @@ -429,18 +401,18 @@ export type HandlersListSnapshotsResponse = { }; export type HandlersLoginRequest = { - password: string; - username: string; + password?: string; + username?: string; }; export type HandlersLoginResponse = { - access_token: string; - display_name: string; - expires_at: string; - role: string; - token_type: string; - user_id: string; - username: string; + access_token?: string; + display_name?: string; + expires_at?: string; + role?: string; + token_type?: string; + user_id?: string; + username?: string; }; export type HandlersMcpStdioRequest = { @@ -454,7 +426,7 @@ export type HandlersMcpStdioRequest = { }; export type HandlersMcpStdioResponse = { - session_id?: string; + connection_id?: string; tools?: Array; url?: string; }; @@ -492,80 +464,26 @@ export type HandlersSnapshotInfo = { updated_at?: string; }; -export type HandlersMemoryAddPayload = { - embedding_enabled?: boolean; - filters?: { - [key: string]: unknown; - }; - infer?: boolean; - message?: string; - messages?: Array; - metadata?: { - [key: string]: unknown; - }; - run_id?: string; -}; - -export type HandlersMemoryDeleteAllPayload = { - run_id?: string; -}; - -export type HandlersMemoryEmbedUpsertPayload = { - filters?: { - [key: string]: unknown; - }; - input?: MemoryEmbedInput; - metadata?: { - [key: string]: unknown; - }; - model?: string; - provider?: string; - run_id?: string; - source?: string; - type?: string; -}; - -export type HandlersMemorySearchPayload = { - embedding_enabled?: boolean; - filters?: { - [key: string]: unknown; - }; - limit?: number; - query?: string; - run_id?: string; - sources?: Array; +export type HandlersListMyIdentitiesResponse = { + items?: Array; + user_id?: string; }; export type HandlersSkillsOpResponse = { ok?: boolean; }; -export type HistoryCreateRequest = { - messages?: Array<{ - [key: string]: unknown; - }>; - metadata?: { - [key: string]: unknown; - }; - skills?: Array; -}; - -export type HistoryListResponse = { - items?: Array; -}; - -export type HistoryRecord = { - bot_id?: string; +export type IdentitiesChannelIdentity = { + channel?: string; + channel_subject_id?: string; + created_at?: string; + display_name?: string; id?: string; - messages?: Array<{ - [key: string]: unknown; - }>; metadata?: { [key: string]: unknown; }; - session_id?: string; - skills?: Array; - timestamp?: string; + updated_at?: string; + user_id?: string; }; export type McpListResponse = { @@ -581,63 +499,14 @@ export type McpUpsertRequest = { type?: string; }; -export type MemoryDeleteResponse = { - message?: string; -}; - -export type MemoryEmbedInput = { - image_url?: string; - text?: string; - video_url?: string; -}; - -export type MemoryEmbedUpsertResponse = { - dimensions?: number; - item?: MemoryMemoryItem; - model?: string; - provider?: string; -}; - -export type MemoryMemoryItem = { - agentId?: string; - botId?: string; - createdAt?: string; - hash?: string; - id?: string; - memory?: string; - metadata?: { - [key: string]: unknown; - }; - runId?: string; - score?: number; - sessionId?: string; - updatedAt?: string; -}; - -export type MemoryMessage = { - content?: string; - role?: string; -}; - -export type MemorySearchResponse = { - relations?: Array; - results?: Array; -}; - -export type MemoryUpdateRequest = { - embedding_enabled?: boolean; - memory?: string; - memory_id?: string; -}; - export type ModelsAddRequest = { dimensions?: number; input?: Array; is_multimodal?: boolean; - llm_provider_id: string; - model_id: string; - name: string; - type: ModelsModelType; + llm_provider_id?: string; + model_id?: string; + name?: string; + type?: ModelsModelType; }; export type ModelsAddResponse = { @@ -653,10 +522,10 @@ export type ModelsGetResponse = { dimensions?: number; input?: Array; is_multimodal?: boolean; - llm_provider_id: string; - model_id: string; - name: string; - type: ModelsModelType; + llm_provider_id?: string; + model_id?: string; + name?: string; + type?: ModelsModelType; }; export type ModelsModelType = 'chat' | 'embedding'; @@ -665,10 +534,10 @@ export type ModelsUpdateRequest = { dimensions?: number; input?: Array; is_multimodal?: boolean; - llm_provider_id: string; - model_id: string; - name: string; - type: ModelsModelType; + llm_provider_id?: string; + model_id?: string; + name?: string; + type?: ModelsModelType; }; export type ProvidersClientType = 'openai' | 'openai-compat' | 'anthropic' | 'google' | 'ollama'; @@ -692,15 +561,15 @@ export type ProvidersGetResponse = { * masked in response */ api_key?: string; - base_url: string; - client_type: string; - created_at: string; - id: string; + base_url?: string; + client_type?: string; + created_at?: string; + id?: string; metadata?: { [key: string]: unknown; }; - name: string; - updated_at: string; + name?: string; + updated_at?: string; }; export type ProvidersUpdateRequest = { @@ -755,12 +624,12 @@ export type ScheduleUpdateRequest = { }; export type SettingsSettings = { - allow_guest: boolean; - chat_model_id: string; - embedding_model_id: string; - language: string; - max_context_load_time: number; - memory_model_id: string; + allow_guest?: boolean; + chat_model_id?: string; + embedding_model_id?: string; + language?: string; + max_context_load_time?: number; + memory_model_id?: string; }; export type SettingsUpsertRequest = { @@ -838,54 +707,6 @@ export type SubagentUpdateSkillsRequest = { skills?: Array; }; -export type UsersCreateUserRequest = { - avatar_url?: string; - display_name?: string; - email?: string; - is_active?: boolean; - password?: string; - role?: string; - username?: string; -}; - -export type UsersListUsersResponse = { - items?: Array; -}; - -export type UsersResetPasswordRequest = { - new_password?: string; -}; - -export type UsersUpdatePasswordRequest = { - current_password?: string; - new_password?: string; -}; - -export type UsersUpdateProfileRequest = { - avatar_url?: string; - display_name?: string; -}; - -export type UsersUpdateUserRequest = { - avatar_url?: string; - display_name?: string; - is_active?: boolean; - role?: string; -}; - -export type UsersUser = { - avatar_url?: string; - created_at?: string; - display_name?: string; - email?: string; - id?: string; - is_active?: boolean; - last_login_at?: string; - role?: string; - updated_at?: string; - username?: string; -}; - export type PostAuthLoginData = { /** * Login request @@ -996,80 +817,6 @@ export type PostBotsResponses = { export type PostBotsResponse = PostBotsResponses[keyof PostBotsResponses]; -export type PostBotsByBotIdChatData = { - /** - * Chat request - */ - body: ChatChatRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/chat'; -}; - -export type PostBotsByBotIdChatErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdChatError = PostBotsByBotIdChatErrors[keyof PostBotsByBotIdChatErrors]; - -export type PostBotsByBotIdChatResponses = { - /** - * OK - */ - 200: ChatChatResponse; -}; - -export type PostBotsByBotIdChatResponse = PostBotsByBotIdChatResponses[keyof PostBotsByBotIdChatResponses]; - -export type PostBotsByBotIdChatStreamData = { - /** - * Chat request - */ - body: ChatChatRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/chat/stream'; -}; - -export type PostBotsByBotIdChatStreamErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdChatStreamError = PostBotsByBotIdChatStreamErrors[keyof PostBotsByBotIdChatStreamErrors]; - -export type PostBotsByBotIdChatStreamResponses = { - /** - * OK - */ - 200: string; -}; - -export type PostBotsByBotIdChatStreamResponse = PostBotsByBotIdChatStreamResponses[keyof PostBotsByBotIdChatStreamResponses]; - export type DeleteBotsByBotIdContainerData = { body?: never; path: { @@ -1173,409 +920,6 @@ export type PostBotsByBotIdContainerResponses = { export type PostBotsByBotIdContainerResponse = PostBotsByBotIdContainerResponses[keyof PostBotsByBotIdContainerResponses]; -export type DeleteBotsByBotIdContainerFsData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query: { - /** - * Relative path - */ - path: string; - /** - * Recursive delete for directories - */ - recursive?: boolean; - }; - url: '/bots/{bot_id}/container/fs'; -}; - -export type DeleteBotsByBotIdContainerFsErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdContainerFsError = DeleteBotsByBotIdContainerFsErrors[keyof DeleteBotsByBotIdContainerFsErrors]; - -export type DeleteBotsByBotIdContainerFsResponses = { - /** - * OK - */ - 200: HandlersFsDeleteResponse; -}; - -export type DeleteBotsByBotIdContainerFsResponse = DeleteBotsByBotIdContainerFsResponses[keyof DeleteBotsByBotIdContainerFsResponses]; - -export type GetBotsByBotIdContainerFsData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Relative directory path - */ - path?: string; - /** - * Recursive listing - */ - recursive?: boolean; - }; - url: '/bots/{bot_id}/container/fs'; -}; - -export type GetBotsByBotIdContainerFsErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsError = GetBotsByBotIdContainerFsErrors[keyof GetBotsByBotIdContainerFsErrors]; - -export type GetBotsByBotIdContainerFsResponses = { - /** - * OK - */ - 200: HandlersFsListResponse; -}; - -export type GetBotsByBotIdContainerFsResponse = GetBotsByBotIdContainerFsResponses[keyof GetBotsByBotIdContainerFsResponses]; - -export type PostBotsByBotIdContainerFsMcpData = { - /** - * JSON-RPC request - */ - body: { - [key: string]: unknown; - }; - headers: { - /** - * Bearer - */ - Authorization: string; - }; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/container/fs-mcp'; -}; - -export type PostBotsByBotIdContainerFsMcpErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsMcpError = PostBotsByBotIdContainerFsMcpErrors[keyof PostBotsByBotIdContainerFsMcpErrors]; - -export type PostBotsByBotIdContainerFsMcpResponses = { - /** - * JSON-RPC response: {jsonrpc,id,result|error} - */ - 200: { - [key: string]: unknown; - }; -}; - -export type PostBotsByBotIdContainerFsMcpResponse = PostBotsByBotIdContainerFsMcpResponses[keyof PostBotsByBotIdContainerFsMcpResponses]; - -export type PostBotsByBotIdContainerFsDirData = { - /** - * Directory payload - */ - body: HandlersFsMkdirRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/container/fs/dir'; -}; - -export type PostBotsByBotIdContainerFsDirErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsDirError = PostBotsByBotIdContainerFsDirErrors[keyof PostBotsByBotIdContainerFsDirErrors]; - -export type PostBotsByBotIdContainerFsDirResponses = { - /** - * OK - */ - 200: HandlersFsWriteResponse; -}; - -export type PostBotsByBotIdContainerFsDirResponse = PostBotsByBotIdContainerFsDirResponses[keyof PostBotsByBotIdContainerFsDirResponses]; - -export type GetBotsByBotIdContainerFsFileData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query: { - /** - * Relative file path - */ - path: string; - }; - url: '/bots/{bot_id}/container/fs/file'; -}; - -export type GetBotsByBotIdContainerFsFileErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsFileError = GetBotsByBotIdContainerFsFileErrors[keyof GetBotsByBotIdContainerFsFileErrors]; - -export type GetBotsByBotIdContainerFsFileResponses = { - /** - * OK - */ - 200: HandlersFsReadResponse; -}; - -export type GetBotsByBotIdContainerFsFileResponse = GetBotsByBotIdContainerFsFileResponses[keyof GetBotsByBotIdContainerFsFileResponses]; - -export type PostBotsByBotIdContainerFsFileData = { - /** - * File write payload - */ - body: HandlersFsWriteRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/container/fs/file'; -}; - -export type PostBotsByBotIdContainerFsFileErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Conflict - */ - 409: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsFileError = PostBotsByBotIdContainerFsFileErrors[keyof PostBotsByBotIdContainerFsFileErrors]; - -export type PostBotsByBotIdContainerFsFileResponses = { - /** - * OK - */ - 200: HandlersFsWriteResponse; -}; - -export type PostBotsByBotIdContainerFsFileResponse = PostBotsByBotIdContainerFsFileResponses[keyof PostBotsByBotIdContainerFsFileResponses]; - -export type GetBotsByBotIdContainerFsStatData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query: { - /** - * Relative path - */ - path: string; - }; - url: '/bots/{bot_id}/container/fs/stat'; -}; - -export type GetBotsByBotIdContainerFsStatErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsStatError = GetBotsByBotIdContainerFsStatErrors[keyof GetBotsByBotIdContainerFsStatErrors]; - -export type GetBotsByBotIdContainerFsStatResponses = { - /** - * OK - */ - 200: HandlersFsStatResponse; -}; - -export type GetBotsByBotIdContainerFsStatResponse = GetBotsByBotIdContainerFsStatResponses[keyof GetBotsByBotIdContainerFsStatResponses]; - -export type PostBotsByBotIdContainerFsUploadData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Relative file path or directory - */ - path?: string; - }; - url: '/bots/{bot_id}/container/fs/upload'; -}; - -export type PostBotsByBotIdContainerFsUploadErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsUploadError = PostBotsByBotIdContainerFsUploadErrors[keyof PostBotsByBotIdContainerFsUploadErrors]; - -export type PostBotsByBotIdContainerFsUploadResponses = { - /** - * OK - */ - 200: HandlersFsWriteResponse; -}; - -export type PostBotsByBotIdContainerFsUploadResponse = PostBotsByBotIdContainerFsUploadResponses[keyof PostBotsByBotIdContainerFsUploadResponses]; - -export type GetBotsByBotIdContainerFsUsageData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Relative directory path - */ - path?: string; - }; - url: '/bots/{bot_id}/container/fs/usage'; -}; - -export type GetBotsByBotIdContainerFsUsageErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsUsageError = GetBotsByBotIdContainerFsUsageErrors[keyof GetBotsByBotIdContainerFsUsageErrors]; - -export type GetBotsByBotIdContainerFsUsageResponses = { - /** - * OK - */ - 200: HandlersFsUsageResponse; -}; - -export type GetBotsByBotIdContainerFsUsageResponse = GetBotsByBotIdContainerFsUsageResponses[keyof GetBotsByBotIdContainerFsUsageResponses]; - export type DeleteBotsByBotIdContainerSkillsData = { /** * Delete skills payload @@ -1831,204 +1175,9 @@ export type PostBotsByBotIdContainerStopResponses = { export type PostBotsByBotIdContainerStopResponse = PostBotsByBotIdContainerStopResponses[keyof PostBotsByBotIdContainerStopResponses]; -export type DeleteBotsByBotIdHistoryData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/history'; -}; - -export type DeleteBotsByBotIdHistoryErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdHistoryError = DeleteBotsByBotIdHistoryErrors[keyof DeleteBotsByBotIdHistoryErrors]; - -export type DeleteBotsByBotIdHistoryResponses = { - /** - * No Content - */ - 204: unknown; -}; - -export type GetBotsByBotIdHistoryData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Limit - */ - limit?: number; - }; - url: '/bots/{bot_id}/history'; -}; - -export type GetBotsByBotIdHistoryErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdHistoryError = GetBotsByBotIdHistoryErrors[keyof GetBotsByBotIdHistoryErrors]; - -export type GetBotsByBotIdHistoryResponses = { - /** - * OK - */ - 200: HistoryListResponse; -}; - -export type GetBotsByBotIdHistoryResponse = GetBotsByBotIdHistoryResponses[keyof GetBotsByBotIdHistoryResponses]; - -export type PostBotsByBotIdHistoryData = { - /** - * History payload - */ - body: HistoryCreateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/history'; -}; - -export type PostBotsByBotIdHistoryErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdHistoryError = PostBotsByBotIdHistoryErrors[keyof PostBotsByBotIdHistoryErrors]; - -export type PostBotsByBotIdHistoryResponses = { - /** - * Created - */ - 201: HistoryRecord; -}; - -export type PostBotsByBotIdHistoryResponse = PostBotsByBotIdHistoryResponses[keyof PostBotsByBotIdHistoryResponses]; - -export type DeleteBotsByBotIdHistoryByIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * History ID - */ - id: string; - }; - query?: never; - url: '/bots/{bot_id}/history/{id}'; -}; - -export type DeleteBotsByBotIdHistoryByIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Forbidden - */ - 403: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdHistoryByIdError = DeleteBotsByBotIdHistoryByIdErrors[keyof DeleteBotsByBotIdHistoryByIdErrors]; - -export type DeleteBotsByBotIdHistoryByIdResponses = { - /** - * No Content - */ - 204: unknown; -}; - -export type GetBotsByBotIdHistoryByIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * History ID - */ - id: string; - }; - query?: never; - url: '/bots/{bot_id}/history/{id}'; -}; - -export type GetBotsByBotIdHistoryByIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdHistoryByIdError = GetBotsByBotIdHistoryByIdErrors[keyof GetBotsByBotIdHistoryByIdErrors]; - -export type GetBotsByBotIdHistoryByIdResponses = { - /** - * OK - */ - 200: HistoryRecord; -}; - -export type GetBotsByBotIdHistoryByIdResponse = GetBotsByBotIdHistoryByIdResponses[keyof GetBotsByBotIdHistoryByIdResponses]; - export type GetBotsByBotIdMcpData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/mcp'; }; @@ -2068,12 +1217,7 @@ export type PostBotsByBotIdMcpData = { * MCP payload */ body: McpUpsertRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/mcp'; }; @@ -2149,7 +1293,7 @@ export type PostBotsByBotIdMcpStdioResponses = { export type PostBotsByBotIdMcpStdioResponse = PostBotsByBotIdMcpStdioResponses[keyof PostBotsByBotIdMcpStdioResponses]; -export type PostBotsByBotIdMcpStdioBySessionIdData = { +export type PostBotsByBotIdMcpStdioByConnectionIdData = { /** * JSON-RPC request */ @@ -2162,15 +1306,15 @@ export type PostBotsByBotIdMcpStdioBySessionIdData = { */ bot_id: string; /** - * Session ID + * Connection ID */ - session_id: string; + connection_id: string; }; query?: never; - url: '/bots/{bot_id}/mcp-stdio/{session_id}'; + url: '/bots/{bot_id}/mcp-stdio/{connection_id}'; }; -export type PostBotsByBotIdMcpStdioBySessionIdErrors = { +export type PostBotsByBotIdMcpStdioByConnectionIdErrors = { /** * Bad Request */ @@ -2185,9 +1329,9 @@ export type PostBotsByBotIdMcpStdioBySessionIdErrors = { 500: HandlersErrorResponse; }; -export type PostBotsByBotIdMcpStdioBySessionIdError = PostBotsByBotIdMcpStdioBySessionIdErrors[keyof PostBotsByBotIdMcpStdioBySessionIdErrors]; +export type PostBotsByBotIdMcpStdioByConnectionIdError = PostBotsByBotIdMcpStdioByConnectionIdErrors[keyof PostBotsByBotIdMcpStdioByConnectionIdErrors]; -export type PostBotsByBotIdMcpStdioBySessionIdResponses = { +export type PostBotsByBotIdMcpStdioByConnectionIdResponses = { /** * JSON-RPC response: {jsonrpc,id,result|error} */ @@ -2196,15 +1340,11 @@ export type PostBotsByBotIdMcpStdioBySessionIdResponses = { }; }; -export type PostBotsByBotIdMcpStdioBySessionIdResponse = PostBotsByBotIdMcpStdioBySessionIdResponses[keyof PostBotsByBotIdMcpStdioBySessionIdResponses]; +export type PostBotsByBotIdMcpStdioByConnectionIdResponse = PostBotsByBotIdMcpStdioByConnectionIdResponses[keyof PostBotsByBotIdMcpStdioByConnectionIdResponses]; export type DeleteBotsByBotIdMcpByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * MCP ID */ @@ -2245,10 +1385,6 @@ export type DeleteBotsByBotIdMcpByIdResponses = { export type GetBotsByBotIdMcpByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * MCP ID */ @@ -2294,10 +1430,6 @@ export type PutBotsByBotIdMcpByIdData = { */ body: McpUpsertRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * MCP ID */ @@ -2337,318 +1469,9 @@ export type PutBotsByBotIdMcpByIdResponses = { export type PutBotsByBotIdMcpByIdResponse = PutBotsByBotIdMcpByIdResponses[keyof PutBotsByBotIdMcpByIdResponses]; -export type PostBotsByBotIdMemoryAddData = { - /** - * Add request - */ - body: HandlersMemoryAddPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/add'; -}; - -export type PostBotsByBotIdMemoryAddErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemoryAddError = PostBotsByBotIdMemoryAddErrors[keyof PostBotsByBotIdMemoryAddErrors]; - -export type PostBotsByBotIdMemoryAddResponses = { - /** - * OK - */ - 200: MemorySearchResponse; -}; - -export type PostBotsByBotIdMemoryAddResponse = PostBotsByBotIdMemoryAddResponses[keyof PostBotsByBotIdMemoryAddResponses]; - -export type PostBotsByBotIdMemoryEmbedData = { - /** - * Embed upsert request - */ - body: HandlersMemoryEmbedUpsertPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/embed'; -}; - -export type PostBotsByBotIdMemoryEmbedErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemoryEmbedError = PostBotsByBotIdMemoryEmbedErrors[keyof PostBotsByBotIdMemoryEmbedErrors]; - -export type PostBotsByBotIdMemoryEmbedResponses = { - /** - * OK - */ - 200: MemoryEmbedUpsertResponse; -}; - -export type PostBotsByBotIdMemoryEmbedResponse = PostBotsByBotIdMemoryEmbedResponses[keyof PostBotsByBotIdMemoryEmbedResponses]; - -export type DeleteBotsByBotIdMemoryMemoriesData = { - /** - * Delete all request - */ - body: HandlersMemoryDeleteAllPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/memories'; -}; - -export type DeleteBotsByBotIdMemoryMemoriesErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesError = DeleteBotsByBotIdMemoryMemoriesErrors[keyof DeleteBotsByBotIdMemoryMemoriesErrors]; - -export type DeleteBotsByBotIdMemoryMemoriesResponses = { - /** - * OK - */ - 200: MemoryDeleteResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesResponse = DeleteBotsByBotIdMemoryMemoriesResponses[keyof DeleteBotsByBotIdMemoryMemoriesResponses]; - -export type GetBotsByBotIdMemoryMemoriesData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Run ID - */ - run_id?: string; - /** - * Limit - */ - limit?: number; - }; - url: '/bots/{bot_id}/memory/memories'; -}; - -export type GetBotsByBotIdMemoryMemoriesErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdMemoryMemoriesError = GetBotsByBotIdMemoryMemoriesErrors[keyof GetBotsByBotIdMemoryMemoriesErrors]; - -export type GetBotsByBotIdMemoryMemoriesResponses = { - /** - * OK - */ - 200: MemorySearchResponse; -}; - -export type GetBotsByBotIdMemoryMemoriesResponse = GetBotsByBotIdMemoryMemoriesResponses[keyof GetBotsByBotIdMemoryMemoriesResponses]; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * Memory ID - */ - memoryId: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/memories/{memoryId}'; -}; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdError = DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors[keyof DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors]; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses = { - /** - * OK - */ - 200: MemoryDeleteResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponse = DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses[keyof DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses]; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * Memory ID - */ - memoryId: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/memories/{memoryId}'; -}; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdError = GetBotsByBotIdMemoryMemoriesByMemoryIdErrors[keyof GetBotsByBotIdMemoryMemoriesByMemoryIdErrors]; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdResponses = { - /** - * OK - */ - 200: MemoryMemoryItem; -}; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdResponse = GetBotsByBotIdMemoryMemoriesByMemoryIdResponses[keyof GetBotsByBotIdMemoryMemoriesByMemoryIdResponses]; - -export type PostBotsByBotIdMemorySearchData = { - /** - * Search request - */ - body: HandlersMemorySearchPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/search'; -}; - -export type PostBotsByBotIdMemorySearchErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemorySearchError = PostBotsByBotIdMemorySearchErrors[keyof PostBotsByBotIdMemorySearchErrors]; - -export type PostBotsByBotIdMemorySearchResponses = { - /** - * OK - */ - 200: MemorySearchResponse; -}; - -export type PostBotsByBotIdMemorySearchResponse = PostBotsByBotIdMemorySearchResponses[keyof PostBotsByBotIdMemorySearchResponses]; - -export type PostBotsByBotIdMemoryUpdateData = { - /** - * Update request - */ - body: MemoryUpdateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/update'; -}; - -export type PostBotsByBotIdMemoryUpdateErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemoryUpdateError = PostBotsByBotIdMemoryUpdateErrors[keyof PostBotsByBotIdMemoryUpdateErrors]; - -export type PostBotsByBotIdMemoryUpdateResponses = { - /** - * OK - */ - 200: MemoryMemoryItem; -}; - -export type PostBotsByBotIdMemoryUpdateResponse = PostBotsByBotIdMemoryUpdateResponses[keyof PostBotsByBotIdMemoryUpdateResponses]; - export type GetBotsByBotIdScheduleData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/schedule'; }; @@ -2680,12 +1503,7 @@ export type PostBotsByBotIdScheduleData = { * Schedule payload */ body: ScheduleCreateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/schedule'; }; @@ -2715,10 +1533,6 @@ export type PostBotsByBotIdScheduleResponse = PostBotsByBotIdScheduleResponses[k export type DeleteBotsByBotIdScheduleByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Schedule ID */ @@ -2751,10 +1565,6 @@ export type DeleteBotsByBotIdScheduleByIdResponses = { export type GetBotsByBotIdScheduleByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Schedule ID */ @@ -2796,10 +1606,6 @@ export type PutBotsByBotIdScheduleByIdData = { */ body: ScheduleUpdateRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Schedule ID */ @@ -2833,12 +1639,7 @@ export type PutBotsByBotIdScheduleByIdResponse = PutBotsByBotIdScheduleByIdRespo export type DeleteBotsByBotIdSettingsData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2865,12 +1666,7 @@ export type DeleteBotsByBotIdSettingsResponses = { export type GetBotsByBotIdSettingsData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2902,12 +1698,7 @@ export type PostBotsByBotIdSettingsData = { * Settings payload */ body: SettingsUpsertRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2939,12 +1730,7 @@ export type PutBotsByBotIdSettingsData = { * Settings payload */ body: SettingsUpsertRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2973,12 +1759,7 @@ export type PutBotsByBotIdSettingsResponse = PutBotsByBotIdSettingsResponses[key export type GetBotsByBotIdSubagentsData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/subagents'; }; @@ -3010,12 +1791,7 @@ export type PostBotsByBotIdSubagentsData = { * Subagent payload */ body: SubagentCreateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/subagents'; }; @@ -3045,10 +1821,6 @@ export type PostBotsByBotIdSubagentsResponse = PostBotsByBotIdSubagentsResponses export type DeleteBotsByBotIdSubagentsByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3085,10 +1857,6 @@ export type DeleteBotsByBotIdSubagentsByIdResponses = { export type GetBotsByBotIdSubagentsByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3130,10 +1898,6 @@ export type PutBotsByBotIdSubagentsByIdData = { */ body: SubagentUpdateRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3172,10 +1936,6 @@ export type PutBotsByBotIdSubagentsByIdResponse = PutBotsByBotIdSubagentsByIdRes export type GetBotsByBotIdSubagentsByIdContextData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3217,10 +1977,6 @@ export type PutBotsByBotIdSubagentsByIdContextData = { */ body: SubagentUpdateContextRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3259,10 +2015,6 @@ export type PutBotsByBotIdSubagentsByIdContextResponse = PutBotsByBotIdSubagents export type GetBotsByBotIdSubagentsByIdSkillsData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3304,10 +2056,6 @@ export type PostBotsByBotIdSubagentsByIdSkillsData = { */ body: SubagentAddSkillsRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3349,10 +2097,6 @@ export type PutBotsByBotIdSubagentsByIdSkillsData = { */ body: SubagentUpdateSkillsRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3388,6 +2132,51 @@ export type PutBotsByBotIdSubagentsByIdSkillsResponses = { export type PutBotsByBotIdSubagentsByIdSkillsResponse = PutBotsByBotIdSubagentsByIdSkillsResponses[keyof PutBotsByBotIdSubagentsByIdSkillsResponses]; +export type PostBotsByBotIdToolsData = { + /** + * JSON-RPC request + */ + body: { + [key: string]: unknown; + }; + path: { + /** + * Bot ID + */ + bot_id: string; + }; + query?: never; + url: '/bots/{bot_id}/tools'; +}; + +export type PostBotsByBotIdToolsErrors = { + /** + * Bad Request + */ + 400: HandlersErrorResponse; + /** + * Not Found + */ + 404: HandlersErrorResponse; + /** + * Internal Server Error + */ + 500: HandlersErrorResponse; +}; + +export type PostBotsByBotIdToolsError = PostBotsByBotIdToolsErrors[keyof PostBotsByBotIdToolsErrors]; + +export type PostBotsByBotIdToolsResponses = { + /** + * JSON-RPC response: {jsonrpc,id,result|error} + */ + 200: { + [key: string]: unknown; + }; +}; + +export type PostBotsByBotIdToolsResponse = PostBotsByBotIdToolsResponses[keyof PostBotsByBotIdToolsResponses]; + export type DeleteBotsByIdData = { body?: never; path: { @@ -3423,11 +2212,15 @@ export type DeleteBotsByIdError = DeleteBotsByIdErrors[keyof DeleteBotsByIdError export type DeleteBotsByIdResponses = { /** - * No Content + * Accepted */ - 204: unknown; + 202: { + [key: string]: string; + }; }; +export type DeleteBotsByIdResponse = DeleteBotsByIdResponses[keyof DeleteBotsByIdResponses]; + export type GetBotsByIdData = { body?: never; path: { @@ -3661,7 +2454,7 @@ export type PostBotsByIdChannelByPlatformSendResponses = { export type PostBotsByIdChannelByPlatformSendResponse = PostBotsByIdChannelByPlatformSendResponses[keyof PostBotsByIdChannelByPlatformSendResponses]; -export type PostBotsByIdChannelByPlatformSendSessionData = { +export type PostBotsByIdChannelByPlatformSendChatData = { /** * Send payload */ @@ -3677,10 +2470,10 @@ export type PostBotsByIdChannelByPlatformSendSessionData = { platform: string; }; query?: never; - url: '/bots/{id}/channel/{platform}/send_session'; + url: '/bots/{id}/channel/{platform}/send_chat'; }; -export type PostBotsByIdChannelByPlatformSendSessionErrors = { +export type PostBotsByIdChannelByPlatformSendChatErrors = { /** * Bad Request */ @@ -3699,9 +2492,9 @@ export type PostBotsByIdChannelByPlatformSendSessionErrors = { 500: HandlersErrorResponse; }; -export type PostBotsByIdChannelByPlatformSendSessionError = PostBotsByIdChannelByPlatformSendSessionErrors[keyof PostBotsByIdChannelByPlatformSendSessionErrors]; +export type PostBotsByIdChannelByPlatformSendChatError = PostBotsByIdChannelByPlatformSendChatErrors[keyof PostBotsByIdChannelByPlatformSendChatErrors]; -export type PostBotsByIdChannelByPlatformSendSessionResponses = { +export type PostBotsByIdChannelByPlatformSendChatResponses = { /** * OK */ @@ -3710,7 +2503,49 @@ export type PostBotsByIdChannelByPlatformSendSessionResponses = { }; }; -export type PostBotsByIdChannelByPlatformSendSessionResponse = PostBotsByIdChannelByPlatformSendSessionResponses[keyof PostBotsByIdChannelByPlatformSendSessionResponses]; +export type PostBotsByIdChannelByPlatformSendChatResponse = PostBotsByIdChannelByPlatformSendChatResponses[keyof PostBotsByIdChannelByPlatformSendChatResponses]; + +export type GetBotsByIdChecksData = { + body?: never; + path: { + /** + * Bot ID + */ + id: string; + }; + query?: never; + url: '/bots/{id}/checks'; +}; + +export type GetBotsByIdChecksErrors = { + /** + * Bad Request + */ + 400: HandlersErrorResponse; + /** + * Forbidden + */ + 403: HandlersErrorResponse; + /** + * Not Found + */ + 404: HandlersErrorResponse; + /** + * Internal Server Error + */ + 500: HandlersErrorResponse; +}; + +export type GetBotsByIdChecksError = GetBotsByIdChecksErrors[keyof GetBotsByIdChecksErrors]; + +export type GetBotsByIdChecksResponses = { + /** + * OK + */ + 200: BotsListChecksResponse; +}; + +export type GetBotsByIdChecksResponse = GetBotsByIdChecksResponses[keyof GetBotsByIdChecksResponses]; export type GetBotsByIdMembersData = { body?: never; @@ -4677,7 +3512,7 @@ export type GetUsersResponses = { /** * OK */ - 200: UsersListUsersResponse; + 200: AccountsListAccountsResponse; }; export type GetUsersResponse = GetUsersResponses[keyof GetUsersResponses]; @@ -4686,7 +3521,7 @@ export type PostUsersData = { /** * User payload */ - body: UsersCreateUserRequest; + body: AccountsCreateAccountRequest; path?: never; query?: never; url: '/users'; @@ -4713,7 +3548,7 @@ export type PostUsersResponses = { /** * Created */ - 201: UsersUser; + 201: AccountsAccount; }; export type PostUsersResponse = PostUsersResponses[keyof PostUsersResponses]; @@ -4742,7 +3577,7 @@ export type GetUsersMeResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type GetUsersMeResponse = GetUsersMeResponses[keyof GetUsersMeResponses]; @@ -4751,7 +3586,7 @@ export type PutUsersMeData = { /** * Profile payload */ - body: UsersUpdateProfileRequest; + body: AccountsUpdateProfileRequest; path?: never; query?: never; url: '/users/me'; @@ -4774,7 +3609,7 @@ export type PutUsersMeResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type PutUsersMeResponse = PutUsersMeResponses[keyof PutUsersMeResponses]; @@ -4812,7 +3647,7 @@ export type GetUsersMeChannelsByPlatformResponses = { /** * OK */ - 200: ChannelChannelUserBinding; + 200: ChannelChannelIdentityBinding; }; export type GetUsersMeChannelsByPlatformResponse = GetUsersMeChannelsByPlatformResponses[keyof GetUsersMeChannelsByPlatformResponses]; @@ -4821,7 +3656,7 @@ export type PutUsersMeChannelsByPlatformData = { /** * Channel user config payload */ - body: ChannelUpsertUserConfigRequest; + body: ChannelUpsertChannelIdentityConfigRequest; path: { /** * Channel platform @@ -4849,16 +3684,49 @@ export type PutUsersMeChannelsByPlatformResponses = { /** * OK */ - 200: ChannelChannelUserBinding; + 200: ChannelChannelIdentityBinding; }; export type PutUsersMeChannelsByPlatformResponse = PutUsersMeChannelsByPlatformResponses[keyof PutUsersMeChannelsByPlatformResponses]; +export type GetUsersMeIdentitiesData = { + body?: never; + path?: never; + query?: never; + url: '/users/me/identities'; +}; + +export type GetUsersMeIdentitiesErrors = { + /** + * Bad Request + */ + 400: HandlersErrorResponse; + /** + * Not Found + */ + 404: HandlersErrorResponse; + /** + * Internal Server Error + */ + 500: HandlersErrorResponse; +}; + +export type GetUsersMeIdentitiesError = GetUsersMeIdentitiesErrors[keyof GetUsersMeIdentitiesErrors]; + +export type GetUsersMeIdentitiesResponses = { + /** + * OK + */ + 200: HandlersListMyIdentitiesResponse; +}; + +export type GetUsersMeIdentitiesResponse = GetUsersMeIdentitiesResponses[keyof GetUsersMeIdentitiesResponses]; + export type PutUsersMePasswordData = { /** * Password payload */ - body: UsersUpdatePasswordRequest; + body: AccountsUpdatePasswordRequest; path?: never; query?: never; url: '/users/me/password'; @@ -4921,7 +3789,7 @@ export type GetUsersByIdResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type GetUsersByIdResponse = GetUsersByIdResponses[keyof GetUsersByIdResponses]; @@ -4930,7 +3798,7 @@ export type PutUsersByIdData = { /** * User update payload */ - body: UsersUpdateUserRequest; + body: AccountsUpdateAccountRequest; path: { /** * User ID @@ -4966,7 +3834,7 @@ export type PutUsersByIdResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type PutUsersByIdResponse = PutUsersByIdResponses[keyof PutUsersByIdResponses]; @@ -4975,7 +3843,7 @@ export type PutUsersByIdPasswordData = { /** * Password payload */ - body: UsersResetPasswordRequest; + body: AccountsResetPasswordRequest; path: { /** * User ID diff --git a/packages/shared/README.md b/packages/shared/README.md new file mode 100644 index 00000000..6554adbe --- /dev/null +++ b/packages/shared/README.md @@ -0,0 +1 @@ +# @memoh/shared diff --git a/packages/shared/package.json b/packages/shared/package.json new file mode 100644 index 00000000..6e841596 --- /dev/null +++ b/packages/shared/package.json @@ -0,0 +1,13 @@ +{ + "name": "@memoh/shared", + "version": "1.0.0", + "description": "", + "exports": { + ".": "./src/index.ts" + }, + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "packageManager": "pnpm@10.27.0" +} diff --git a/packages/shared/src/chatInfo.ts b/packages/shared/src/chatInfo.ts new file mode 100644 index 00000000..3a11d13b --- /dev/null +++ b/packages/shared/src/chatInfo.ts @@ -0,0 +1,15 @@ +export interface robot{ + description: string + time: Date, + id: string | number, + type: string, + action: 'robot', + state:'thinking'|'generate'|'complete' +} + +export interface user{ + description: string, + time: Date, + id: number | string, + action:'user' +} \ No newline at end of file diff --git a/packages/shared/src/index.ts b/packages/shared/src/index.ts new file mode 100644 index 00000000..f23cc870 --- /dev/null +++ b/packages/shared/src/index.ts @@ -0,0 +1,5 @@ +export * from './model' +export * from './schedule' +export * from './platform' +export * from './mcp' +export * from './chatInfo' diff --git a/packages/shared/src/mcp.ts b/packages/shared/src/mcp.ts new file mode 100644 index 00000000..da2993b1 --- /dev/null +++ b/packages/shared/src/mcp.ts @@ -0,0 +1,48 @@ +export interface BaseMCPConnection { + type: string + name: string +} + +export interface StdioMCPConnection extends BaseMCPConnection { + type: 'stdio' + command: string + args: string[] + env: Record + cwd: string +} + +export interface BaseHTTPMCPConnection extends BaseMCPConnection { + url: string + headers: Record +} + +export interface HTTPMCPConnection extends BaseHTTPMCPConnection { + type: 'http' +} + +export interface SSEMCPConnection extends BaseHTTPMCPConnection { + type: 'sse' +} + +export type MCPConnection = + | StdioMCPConnection + | HTTPMCPConnection + | SSEMCPConnection + + +export interface MCPListItem{ + id: string; + type: string; + name: string; + config: { + cwd: string; + env: Record; + args: string[]; + type: string; + command: string; + }; + active: boolean; + user: string; + createdAt: string; + updatedAt: string; +} \ No newline at end of file diff --git a/packages/shared/src/model.ts b/packages/shared/src/model.ts new file mode 100644 index 00000000..b48c52e0 --- /dev/null +++ b/packages/shared/src/model.ts @@ -0,0 +1,101 @@ +export enum ModelClientType { + OPENAI = 'openai', + ANTHROPIC = 'anthropic', + GOOGLE = 'google', +} + +export enum ModelType { + CHAT = 'chat', + EMBEDDING = 'embedding', +} + +export interface BaseModel { + /** + * @description The unique identifier for the model + * @example 'gpt-4o' + */ + modelId: string + + /** + * @description The base URL for the model + * @example 'https://api.openai.com/v1' + */ + baseUrl: string + + /** + * @description The API key for the model + * @example 'sk-1234567890' + */ + apiKey: string + + /** + * @description The client type for the model + * @enum {ModelClientType} + */ + clientType: ModelClientType + + /** + * @description The display name for the model + * @example 'GPT 4o' + */ + name?: string + + /** + * @description The model type + * @enum {ModelType} + * @default {ModelType.CHAT} + */ + type?: ModelType +} + +export interface EmbeddingModel extends BaseModel { + type?: ModelType.EMBEDDING + + /** + * @description The dimensions of the embedding + * @example 1536 + */ + dimensions: number +} + +export interface ChatModel extends BaseModel { + type?: ModelType.CHAT +} + +export type Model = EmbeddingModel | ChatModel + + +/** Model row type for list/table views. */ +export interface ModelList { + apiKey: string, + baseUrl: string, + clientType: 'OpenAI' | 'Anthropic' | 'Google', + modelId: string, + name: string, + type: 'chat' | 'embedding', + id: string, + defaultChatModel: boolean, + defaultEmbeddingModel: boolean, + defaultSummaryModel: boolean +} + +export interface ProviderInfo{ + api_key: string; + base_url: string; + client_type: string; + metadata: Record<'additionalProp1',object>; + name: string; +} + +export interface ModelInfo{ + dimensions:number + is_multimodal:boolean + input?: string[] + llm_provider_id:string + model_id:string + name:string + type: string + enable_as?:string +} + +export const clientType = ['openai', 'anthropic', 'google', 'ollama'] as const diff --git a/packages/shared/src/platform.ts b/packages/shared/src/platform.ts new file mode 100644 index 00000000..bc371c73 --- /dev/null +++ b/packages/shared/src/platform.ts @@ -0,0 +1,7 @@ +export interface Platform { + id: string + name: string + // endpoint: string + config: Record + active: boolean +} \ No newline at end of file diff --git a/packages/shared/src/schedule.ts b/packages/shared/src/schedule.ts new file mode 100644 index 00000000..e08ca076 --- /dev/null +++ b/packages/shared/src/schedule.ts @@ -0,0 +1,8 @@ +export interface Schedule { + id?: string + pattern: string + name: string + description: string + command: string + maxCalls?: number | null +} \ No newline at end of file diff --git a/packages/web/mise.toml b/packages/web/mise.toml index 78461797..82fe917e 100644 --- a/packages/web/mise.toml +++ b/packages/web/mise.toml @@ -2,23 +2,13 @@ alias = "dev" description = "Start web development server" run = "pnpm dev" -depends = [ - "//:pnpm-install", - "//:sdk-generate", -] +depends = ["//:pnpm-install"] [tasks.build] description = "Build web" run = "pnpm build" -depends = [ - "//:pnpm-install", - "//:sdk-generate", -] +depends = ["//:pnpm-install"] [tasks.start] description = "Start web" -run = "pnpm start" -depends = [ - "//:pnpm-install", - "//:sdk-generate", -] \ No newline at end of file +run = "pnpm start" \ No newline at end of file diff --git a/packages/web/package.json b/packages/web/package.json index 626e4107..58a831d4 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -9,8 +9,14 @@ "start": "vite preview" }, "dependencies": { - "@memoh/ui": "workspace:*", + "@fortawesome/fontawesome-svg-core": "^7.0.0", + "@fortawesome/free-brands-svg-icons": "^7.0.0", + "@fortawesome/free-regular-svg-icons": "^7.0.0", + "@fortawesome/free-solid-svg-icons": "^7.0.0", + "@fortawesome/vue-fontawesome": "^3.1.1", "@memoh/sdk": "workspace:*", + "@memoh/shared": "workspace:*", + "@memoh/ui": "workspace:*", "@pinia/colada": "^0.21.1", "@tailwindcss/vite": "^4.1.18", "@tanstack/vue-table": "^8.21.3", @@ -33,12 +39,7 @@ "vue-i18n": "^11.2.8", "vue-router": "^4.6.4", "vue-sonner": "^2.0.9", - "zod": "^4.3.5", - "@fortawesome/fontawesome-svg-core": "^7.0.0", - "@fortawesome/free-brands-svg-icons": "^7.0.0", - "@fortawesome/free-regular-svg-icons": "^7.0.0", - "@fortawesome/free-solid-svg-icons": "^7.0.0", - "@fortawesome/vue-fontawesome": "^3.1.1" + "zod": "^4.3.5" }, "devDependencies": { "@types/node": "^24.10.1", diff --git a/packages/web/public/channels/feishu.png b/packages/web/public/channels/feishu.png new file mode 100644 index 0000000000000000000000000000000000000000..73a55287721ed08c3cde07feb2dde6faa6eb715f GIT binary patch literal 669 zcmV;O0%HA%P)e~Pzk^tPgOn(0rE|?Qz z{|PskYBt)JPnX40krbS$+zs~;` zdZmk}|DLx0aC*4W-2Zcz|H8-M?(XsK)N50MKokaG7F-0@09Q0Gm6kUui%P;Y?e_nFREwgqdq8&VEANb7cJ`U`p4CxN z%9QzkT3Q9>Tt>FFswSpdlaZZnhU`hzy?|1af=UfYxw49K24(A1c_{}oWJgkpdseR` zA$Z$v&b>Z3sc_D5T7&DG;czsb@Yklk2s@m+y|aS4#ZtHLN9)uzoz6trJCt=E9)pIu zkM+bsjDg%gkn@~k403Df7-3gjbhoDS5~P?Kj1>VQf;=9!o;2!V^0bo>q|Btz*(Qb0mx!*nnKYp=&~R(O~l#nr6lBRbA$B^mj}1j zj6Wv#{mQ^3Ij#`WX|_{5dW;d0^)9BW9lxvVwA+wqm}Qv{_JNF9$l?WnUnGhwJ;!6yIjZR00000NkvXXu0mjf D0N72W literal 0 HcmV?d00001 diff --git a/packages/web/public/channels/telegram.webp b/packages/web/public/channels/telegram.webp new file mode 100644 index 0000000000000000000000000000000000000000..3afc19e3868ffdc9f365680a711948d0db305dab GIT binary patch literal 6724 zcmV-K8oT9ENk&FI8UO%SMM6+kP&iC48UO$ zJG@-?OFYe)44I9XK66Q#wR?6^mO{pbYmBqQC{P7hO!A)SM22)H=maO+vda$tv|aoE z6x&i>{XWB>cg@))UDaK-$0fraVq#u_H?VWgKE1oE{#D^W+Gy-@$5PxKPKcinjVv13 z3)d6j?s`+4J-DP0y1SE&`(#8O@dEDd?v}ZZ&*1J(#Nq*n=A^i5;}8;!?ZrF*o|6%| zEZUK6tG1g`k~a)7UfsJUuJfK)X82`Df*paQa1?X|KyoB0*gOCe4@<+zsf>(_$f#8a zbHBhh?&}n=Z9Cp($6wpFZQHhO+il9(wrwl;SKA!Y)Y)o#jlX9C=vv!<+IH^$HZsm$ z+P3%6wr%fY+lcmBbN&C{%pIKAdyExTu2`yDb+G%`))mZ*(}kB{+nljG8&_deuF~7~ z46eZFY;^pMe-WE<_%@5QQZU2S8mZM&UWn-By5 zEE{XO+U{ICwUrdt{|P{d=C}U{+PS(;*P#suw`VgfZk3yR$uUnIgR#8bju~(W!6DN} za>#~5HXX9(5a>JP-^LraXNgB@PsUH-aC;qsbI9{He8utBxJi~p!mXa2=D%9sp5H_< zaW29STz6?>ecO{f$GvQMKRaaDAzS@QI%MA=zug4wkZV>ca=Jn7iJ{~hUxHA)Qad;X z)y6(^2;MRDejOdMX!%s{uvgvOxb|4vp5B>Lfz_3>x>4=f>yGJi$Zo&pjyZ75faASy z^^{<*R*yV;dS~8;)lGY?V`lo^bjZ$o9QMe3mvNyIN2IkSjW%59Cbfs?d)Fak%iHRZ zins6n&UgPs*38k}cTYQ{v ztt0gVy>WZ??))#M6BScOf8sg%A$QEc1~ai>C7uo1<8gD~ci`>@PA%OI8km|CX8cpahjdaJy#3}ZvMTIRdennCK@L(9-MgssJ*m0lH~v@3~a z1FK!#5WH47SZAT9hZjVKak0#=4MA*`c{h%0!1I60U^fJ_MLKn^DQsqc`#-Dk6UuVQ z8v^=|mbv?Dz_cfZmOHc|ur0G-v4K#TS;r!DivYYW@@{?}F_5fakv~04mM zYs)T#*hjDq`_XKOkjouEU({h^%{worw+uPZ`VT4*{x_y>LkgTF* zvBY7NTI|MW17zmO+QMtYNd2T8M+N8a7E2mNtHmNldn7RKx`&T5j9E8(Z%#lOt$W2V z5V>2;4aO-9^U!c1dX2WpN<4v>qm^fa2)1yWGebr8+kArwc%J@@(co5}>_L5d23>S&BfuXBC^jX6t4J4{hJ&!boersrmd_%Xt#~(&2 z=|{hTuX@AGIvSc5e|D?Qg#bE#=z zQaRx@vJ#KxPE007jorFzeKhy+WO6i6<5XSrtmkBcG^31~EUmdSDME91B;RD}VK7%U z4N1~qFtt#zxu)r9qNkXU8d&|zrs?0kM5*begI3Q#hkMCVWA9h)|C%PJ3r!<-Y+p4k z<9Z46>~pGRD3@N+)O1><=A1okXnJ)4NuWv^%I_skO<$~#)k8>W*Hee4ssFEzhHv(g zr{N4sKsC3osfLzxG|fmMJwtm{L&KkYNmSE=O%)z^?ljUf zR6rFpgw+%zQceAu>Yko4)fAFShGrPnJtfyF8Iq~8N2j_~D0)gp(#gt-_Uo3ipRb%y z3#YY9&rmWYrL6pl?i9(nm6Pg=yYg!dZKV|@)Ahy5-B5Pn%5mzBp$JWzdLI_%zYSVu zlY!Q#tLJ5`SSRox5t@~Y?3FIeQ8F!fKp$GoMm0fY^-EBvbzI8=QV8r zZY&zQU(%G4iIFf*j2$!`s=cllS=FDvV0pUa8(%qvqoM%*ZYYX(r>#CjVldGcKZI3LeWmc52rG&TzEElR!;8dZcH9B;mM@f{$+Wk? z-VH;Cw)zx_)x3;_M$yO1P*elxHY_QM{y=5u(YDcunHCph51kSBSHLjn6vtSH()ENbf##(LIyw^;5*V>G<#8q>tYoI_SCyzd;;o*BS@3tc$FMaaZDyb`_wIe64JZvnDaPjyxh(Q zi~+tP&3W@znZUMqlR*{M#qkv99aw5c?Xso_ZfWfas<3*-E2KG~Jax;T1UNJ5f+{5H z%vefNn1txat2KX92y?6}12tLL_G2kc6hJ&_+o&55>M&)CDvvA2%A*_Lc zLseLqu?o2h79;m*{V5Rgo+nHXT&N0Zag3&rb|5jzs9iP`VQ>4(+hJ#V78~`qNjN@r zORoz?=o@wnTng#Jhp$Bf1%!so9aqBE7eBE$3@POAh<;>U& z(21+}ImAM(*9&^Ev@fkbd!8}kx80N*ap1}uc6DZM9vFDx@kLamU%>R>xqRb+MjZFL zVzZl2G;rlfp@^7m*wK+GP9Ea*gXK={srUZ)#Pg4XW&OZr*3nlsc5+h|vfm*V;^o|e zF#Gc6+>F)zL=eyWT;iV}K^+6ntA<}_Y`2VIs*KiWo zt-pjo$#A~8LGq#7fh2~*H4U=d8b&NMBjGV&D9)2yLVWkUV8`j<>wI(lMU#VA_GE3y za(f%`MWONLuE*_8oS)_^n9n!chqo+1EWfxcvd2(K;4+U1stb63Hs8a<`9n#;aM|PtKy`VC}9Hd(q6_VohSZtntvG7c;)O$1xl= z>0$3~;{L^yyWmdEd>`)8amf(lnGpwDfl16gp`i#A55<+Yjrec=A`h!s`0`W4db9&z zD=?)}$_RoUy6T&Y(+J>;pDwFe=*yowt!QHYO>DVzCg)TKz@4^#dJ7>u?Pbl<0MR%o zvF}PCY`Jv5059hn@O_}TFRr|%AcQ?mc=BDK15ee`JsrgU-75B6{G4l^`EAzv=DgX9 zLKs`TNxu9^wKz!H?@k=P1Fm4V$)NyTu30X~5lkoL- zw&UokfKhf}BTo0566x$Fb?s*? zVRh!^z6K0Q0()_n?TPz1Z;VfRrfXIUyqb;`!rC`XQ%geo3C10W{}5FRPkIpPA=Iq4 zp(Lozg2Km$K}qN!S?3^tuH+6+dPveUQnTF556YV1+FBC)i(v?0r!_6Xm!2Kz>8V*k z@q?tW^fnSm61=Xl=yApgVdIg8PkL6Q=l`0;`(qrRkI+|=@X?OcMi9rs_fcWJsF{g? z3!pgUc_iTzf{~Lju7gjS9XWZWW>ToSc@e?^cm4A(Fx|zpCy!7DD?5DB^F?%O>LvwB z`bz-p_P-xs@6+Xflasaxriq03q~@VCvuY>lpffQTIygJWZhre2rn~3@ore?59VGlx zb5_)0^^+tWaz1!2Pe*X?4or7Z2cKol2jOfq&_qCH%B1D&14)v6g;RRiZ9ic-**(^q zOn1|OqDFmumgYuT?M&G4wF7hoQjrAlVATj~=+@jtX zps!#`eei0Wm2yPdfn?R|mL{md@>T?-X^PUb=MzbiJeccm{&IrKOgxihRZsT%KVe-& z1PV$7rTJT$;9VmD>s!H*yR*=K%GNO!cI*0Z$Bqkzu-;JF9t@mtEh@534Qw|=0wa}eGjS-U@ektsqHY9sFrs1e=$ zMSX;}jIyV(uDG5h|H6wiyMV|3EN<9sM-X1;b0Q!;frye{_@sygf$3mH`v(DDbAJ(( z9bH6};tVjZ+MN(L?CByPMT*tWY=|huudwi9&pQz0?4%_EQXD}o}$m@55ncSXl{M1ccQvmhdDB6TtWstSD!y;TPeN zvG`5Y;BX7907)kT!gL+05#|aQRzdZ!9)nczgQAY%_Dq_5OK7C^sCJ78$TG+)@&k969T_`&q z3bPdsT5jeG$wx5alWHP%C^D6R&=P{-!+7?q1s(XL`gSKgV+aXtqHYHQPRkEb!!I;` zoixH+fQ07Bnw!2TXbOuje8Qx(UKejCDAkWap)H#pc5GM@LjgV^He$!p42OlL`3d`0 zp8}Lnb&XHBOFPS7Kv0-8U}!GTF_jI7E4AXwB%RJsgq_gPv_M(g_LqavX-RAFg*n$u zK!_-KXgUOC#%&vjhCLmhaNoiSW=lokMuZyx8l-62PB;vi30sAq^IRu zd>BEbBZVi@OQ)ESG>So@s=_Trg{SA7d}qjdY-Xb$RVTO z@5C?*yMQ9gVcM$0g5j*ymms6h!nvk|%v}H$Sx>^x`6{Zs;0g6bERpX@XCVVDvK+1} zZ{aEF&WD$!&C5Vx)&fR0Qd>#w!BLpLLYIYvrW789Kt`6s4do8pxH49Be5VLYo&D_x z)~s9tjclTCWi6AFL1pk+NFhgymV?554mPq}*9z|DqVZ5gmwDb`$*imVGT_KNR6-WcgbMdr|HbTo%sE;y0$!ggP2NvIl!^JW`qSmv-Ppx)MvL5=UmP zWG28@j`7AHuw`>W(G$=!lOm3mEDwcnICBGp-Pv=afyz{yUtYn^!$;V~kq<9InR_2X zic;4Cy6h;13QI3=BBkFA%+$AGq`1{bd z$!X{AVCMb;CFQ!s*tokN@PL-3e>7bZDO)fpN~AMAhJ|ev=*wWYGj|W16nECV3Ilb_ z09xk$yJgn|;Q=T``t>1eJ)}c-Xo5(??hBz6P>SLs8YZw(X~R^IXYM~%KeJUJ6P-Y% zxWF(oX+3O}AY}tBODA9zGj|iHp7NAHdDCvfTA}&7q#) z5UeUUUxBN-+FLj*=Co+PKvi3foQGXaWhB4I@Ngeq)z$v8u0dCS*u1VgGE84qBj+K5 zhNf6u9vtCm0J$4$PJZ%UuFTo6ueo!kD(F0Xh!wUEADle>2V_-MRUX_a(wSr=?Be>P zK~Hbc)rxD6qD?Dq3A0ob)I!Na6jr+7?e>(B8PjvCh0o$?M#A8pK&(1% zUrQI6fWgXVZ=bn&dL~uy^v;c|huq_tDzbD2){%!V(*#d<&*C?xS`_>s!qclMf~U!n zt%dbuVO)yf?Hc4Y6ENw4r-$EAn-}u*QflC75)E~+zO?e-X(Tc5APeiv${T5c2e_tI zb2HaZ9za6m!AkM+t2ov>)6wApDIM?tVcXtVcODd_18!Ww+BZCSp9*-;v#gEoJm6xrNeYVveOrWNC$d%~&?Si#L`=Xc&dQr8}W4&4@PIE=B*c zc15hQ;6D18hBxDkrE&N2(!E&vQ;cz=^FZiltDo6efMEedH_sWlC;IwbENG4`7L3n( zL>&E{z5Sh=1?_Rgf~nT)l0u))es%d-Fc?!Tn9ugJBKkek5o0YvJh5Q8RjZ@#ANJ2! zV!$V66|a!-BpM zz6fDB@5OT-#tnC2uYp13LOu%^5wrEZFtBtH$^ufXumGWpBN*NruwXS#SU_g(NDQ&y zL*?}`!h&_`vBfYyjODbm6htRW{r)b7de-m5LKkOf0g*|IFMq~>lOmD^`!PWaIt@?_ z2A%3?&;=nLSZa+FE*Scy$>LKi&|+X@S02MZMHiK{pvC|#5Zb=-1_a@=XfX~yTC6Z1 zZ-nBz$nq-mXz^WKK1d=UX)1^o>yV=biKg-!1m(Q1BtwfvxY1yk9Ib`WY*-LS7aK5J z{jd8%fNo0{A0S4Hjw2&b>B5(#=AcE3on@yGtjoQO77S3)Hij%^5wzrJd%75b6D_8> zD-fYO^>bP*!{~bZnTr74$(Q1xdydc+nYIADw1v=~XH=3vAk$Elw%$R9ZuRmUO9b?G zwAD>h09cyBlXOUfu-=Zj`k_KoM3$3gPw&hK@9k)65F#}7_x^iDkWZ1VaIzSM#%akS zeMI_Y7g^#8vXQBD3i;=h- z%gLj`&#+O~G}Ect7GshH#w55)iak<5ph-a?i+6WA1;%FUb@8?i#F3&tM5Guy=)*8h zM{cA;g5#~_cvwg=M_QMNVU98T+1BF7qeDT8wbp@W+f8Bg(i~H=I8TZR|DA^vw}{(* z3M2VuWRdoB(%L00@HT1DNion`9AWul++bACZIx~}CM|-GNQ+4GxJ{1Z|9gz`skPNZ zDQX=E{Lhw{H1ij~3Xa^4VXNgwibN8W9E9dPxIAgKeT%e~NwG; diff --git a/packages/web/src/components/Sidebar/index.vue b/packages/web/src/components/Sidebar/index.vue new file mode 100644 index 00000000..09b6bee3 --- /dev/null +++ b/packages/web/src/components/Sidebar/index.vue @@ -0,0 +1,124 @@ + + + diff --git a/packages/web/src/components/Sidebar/lists/chat-list-menu.vue b/packages/web/src/components/Sidebar/lists/chat-list-menu.vue index 9d62cd0b..5882de83 100644 --- a/packages/web/src/components/Sidebar/lists/chat-list-menu.vue +++ b/packages/web/src/components/Sidebar/lists/chat-list-menu.vue @@ -1,109 +1,46 @@