From 6acdd191c7b26eb9e7d82d5a6694bcd2f1f197c3 Mon Sep 17 00:00:00 2001 From: Ran <16112591+chen-ran@users.noreply.github.com> Date: Thu, 12 Feb 2026 17:13:03 +0800 Subject: [PATCH] Squashed commit of the following: commit bcdb026ae43e4f95d0b2c4f9bd440a2df9d6b514 Author: Ran <16112591+chen-ran@users.noreply.github.com> Date: Thu Feb 12 17:10:32 2026 +0800 chore: update DEVELOPMENT.md commit 30281742ef911884d3edce6ec7c317b81be6e281 Merge: ca5c6a1 5b05f13 Author: BBQ Date: Thu Feb 12 15:49:17 2026 +0800 merge(github/main): integrate fx dependency injection framework Merge upstream fx refactor and adapt all services to use go.uber.org/fx for dependency injection. Resolve conflicts in main.go, server.go, and service constructors while preserving our domain model changes. - Fix telegram adapter panic on shutdown (double close channel) - Fix feishu adapter processing messages after stop - Increase directory lookup timeout from 2s to 5s commit ca5c6a1866dabd8356fc4ba0c91a0a85e134fcea Author: BBQ Date: Thu Feb 12 15:33:09 2026 +0800 refactor(core): restructure conversation, channel and message domains - Rename chat module to conversation with flow-based architecture - Move channelidentities into channel/identities subpackage - Add channel/route for routing logic - Add message service with event hub - Add MCP providers: container, directory, schedule - Refactor Feishu/Telegram adapters with directory and stream support - Add platform management page and channel badges in web UI - Update database schema for conversations, messages and channel routes - Add @memoh/shared package for cross-package type definitions commit 75e2ef0467db19f1e4088b8669e9e4069f139bfd Merge: d99ba38 01cb6c8 Author: BBQ Date: Thu Feb 12 14:45:49 2026 +0800 merge(github): merge github/main, resolve index.ts URL conflict Keep our defensive absolute-URL check in createAuthFetcher. commit d99ba38b7d9d49b817309c9bbb9a9d57678f04be Merge: 860e20f 35ce7d1 Author: BBQ Date: Thu Feb 12 05:20:18 2026 +0800 merge(github): merge github/main, keep our code and docs/spec commit 860e20fe7045e361b8cf491d414852dbc036f22e Author: BBQ Date: Wed Feb 11 22:13:27 2026 +0800 docs(docs): add concepts and style guides for VitePress site - Add concepts: identity-and-binding, index (en/zh) - Add style: terminology (en/zh) - Update index and zh/index - Update .vitepress/config.ts commit a75fdb804015a46f3c131dc153e2283dd1b52f24 Author: BBQ Date: Wed Feb 11 17:37:16 2026 +0800 refactor(mcp): standardize unified tool gateway on go-sdk Split business executors from federation sources and migrate unified tool/federation transports to the official go-sdk for stricter MCP compliance and safer session lifecycle handling. Add targeted regression tests for accept compatibility, initialization retries, pending cleanup, and include updated swagger artifacts. commit 02b33c8e85bfc4da6859627291017f202912fdd1 Author: BBQ Date: Wed Feb 11 15:42:21 2026 +0800 refactor(core): finalize user-centric identity and policy cleanup Unify auth and chat identity semantics around user_id, enforce personal-bot owner-only authorization, and remove legacy compatibility branches in integration tests. commit 06e8619a37d6918542a3e6b9daa7b61bf1dcc2bd Author: BBQ Date: Wed Feb 11 14:47:03 2026 +0800 refactor(core): migrate channel identity and binding across app Align channel identity and bind flow across backend and app-facing layers, including generated swagger artifacts and package lock updates while excluding docs content changes. --- DEPLOYMENT.md | 9 +- agent/src/agent.ts | 141 +- agent/src/config.ts | 2 +- agent/src/index.ts | 21 +- agent/src/models.ts | 7 +- agent/src/modules/chat.ts | 11 +- agent/src/prompts/system.ts | 10 +- agent/src/prompts/user.ts | 10 +- agent/src/test/unified_mcp_tools.test.ts | 89 + agent/src/tools/contact.ts | 129 +- agent/src/tools/index.ts | 9 +- agent/src/tools/mcp.ts | 100 +- agent/src/tools/memory.ts | 14 +- agent/src/tools/message.ts | 34 +- agent/src/tools/subagent.ts | 33 +- agent/src/types/agent.ts | 10 +- cmd/agent/main.go | 465 ++-- cmd/feishu-echo/main.go | 120 + cmd/feishu-echo/main_test.go | 20 + cmd/mcp/Dockerfile | 18 + cmd/mcp/main.go | 3 +- db/migrations/0001_init.down.sql | 21 +- db/migrations/0001_init.up.sql | 197 +- db/queries/bind.sql | 22 + db/queries/bots.sql | 22 +- db/queries/channel_identities.sql | 49 + db/queries/channel_routes.sql | 87 + db/queries/channels.sql | 33 +- db/queries/contacts.sql | 76 - db/queries/containers.sql | 3 + db/queries/conversations.sql | 229 ++ db/queries/history.sql | 31 - db/queries/messages.sql | 118 + db/queries/settings.sql | 93 +- db/queries/users.sql | 92 +- docker-compose.yml | 2 +- docker/Dockerfile.agent | 2 +- docs/docs/.vitepress/config.ts | 48 +- docs/docs/concepts/identity-and-binding.md | 41 + docs/docs/concepts/index.md | 21 + docs/docs/index.md | 23 +- docs/docs/style/terminology.md | 40 + docs/docs/zh/concepts/identity-and-binding.md | 41 + docs/docs/zh/concepts/index.md | 21 + docs/docs/zh/index.md | 23 +- docs/docs/zh/style/terminology.md | 40 + internal/{users => accounts}/service.go | 233 +- internal/{users => accounts}/types.go | 19 +- internal/auth/jwt.go | 91 +- internal/bind/service.go | 242 ++ internal/bind/service_integration_test.go | 240 ++ internal/bind/service_test.go | 207 ++ internal/bind/types.go | 26 + internal/bots/service.go | 500 +++- internal/bots/types.go | 72 +- internal/channel/adapter.go | 48 +- internal/channel/adapters/feishu/config.go | 4 +- internal/channel/adapters/feishu/directory.go | 298 ++ .../channel/adapters/feishu/directory_test.go | 118 + internal/channel/adapters/feishu/feishu.go | 636 +++-- .../feishu/feishu_integration_test.go | 39 +- .../channel/adapters/feishu/feishu_test.go | 380 +++ internal/channel/adapters/feishu/inbound.go | 269 ++ .../feishu/{feishu_logger.go => logger.go} | 0 internal/channel/adapters/feishu/stream.go | 286 ++ internal/channel/adapters/local/cli.go | 37 +- internal/channel/adapters/local/hub.go | 106 +- internal/channel/adapters/local/hub_test.go | 50 + internal/channel/adapters/local/web.go | 37 +- internal/channel/adapters/telegram/config.go | 4 +- .../channel/adapters/telegram/directory.go | 239 ++ .../adapters/telegram/directory_test.go | 99 + internal/channel/adapters/telegram/logger.go | 19 + .../channel/adapters/telegram/logger_test.go | 46 + .../channel/adapters/telegram/markdown.go | 210 ++ .../adapters/telegram/markdown_test.go | 209 ++ internal/channel/adapters/telegram/stream.go | 211 ++ .../channel/adapters/telegram/stream_test.go | 119 + .../channel/adapters/telegram/telegram.go | 169 +- .../adapters/telegram/telegram_test.go | 280 ++ internal/channel/capabilities.go | 32 +- internal/channel/config_test.go | 2 +- internal/channel/connection.go | 50 +- internal/channel/directory.go | 2 +- internal/channel/helpers_test.go | 45 +- internal/channel/identities/service.go | 291 ++ .../service_identity_integration_test.go | 89 + .../identities/service_integration_test.go | 96 + internal/channel/identities/service_test.go | 37 + internal/channel/identities/types.go | 15 + internal/channel/inbound_test.go | 217 ++ internal/channel/manager.go | 25 +- internal/channel/manager_core_test.go | 128 - internal/channel/manager_integration_test.go | 60 +- internal/channel/outbound.go | 157 +- internal/channel/processor.go | 2 +- internal/channel/registry.go | 40 + internal/channel/registry_test.go | 63 + internal/channel/route/service.go | 362 +++ internal/channel/route/types.go | 68 + internal/channel/service.go | 256 +- internal/channel/types.go | 160 +- internal/chat/assistant_output.go | 2 +- internal/chat/resolver.go | 578 ++-- internal/chat/resolver_memory_context_test.go | 55 + internal/chat/resolver_test.go | 12 +- internal/chat/schedule_gateway.go | 2 +- internal/chat/service.go | 1000 +++++++ .../chat/service_presence_integration_test.go | 244 ++ internal/chat/types.go | 165 +- internal/contacts/service.go | 410 --- internal/contacts/types.go | 45 - .../conversation/flow/assistant_output.go | 36 + internal/conversation/flow/resolver.go | 1226 ++++++++ .../flow/resolver_memory_context_test.go | 55 + internal/conversation/flow/resolver_test.go | 158 ++ .../conversation/flow/schedule_gateway.go | 26 + internal/conversation/flow/types.go | 14 + internal/conversation/flow/types_alias.go | 15 + internal/conversation/interfaces.go | 20 + internal/conversation/resolver.go | 1179 ++++++++ internal/conversation/service_domain.go | 483 ++++ .../service_presence_integration_test.go | 242 ++ internal/conversation/types.go | 235 ++ internal/db/sqlc/bind.sql.go | 120 + internal/db/sqlc/bots.sql.go | 77 +- internal/db/sqlc/channel_identities.sql.go | 249 ++ internal/db/sqlc/channel_routes.sql.go | 298 ++ internal/db/sqlc/channels.sql.go | 144 +- internal/db/sqlc/contacts.sql.go | 380 --- internal/db/sqlc/containers.sql.go | 39 + internal/db/sqlc/conversations.sql.go | 678 +++++ internal/db/sqlc/history.sql.go | 178 -- internal/db/sqlc/messages.sql.go | 409 +++ internal/db/sqlc/models.go | 179 +- internal/db/sqlc/settings.sql.go | 312 ++- internal/db/sqlc/users.sql.go | 442 +-- internal/db/text.go | 11 + internal/db/text_test.go | 26 + internal/db/uuid.go | 12 - internal/directory/service.go | 226 -- internal/embeddings/dashscope.go | 2 +- internal/embeddings/resolver.go | 33 +- internal/handlers/auth.go | 57 +- internal/handlers/bind.go | 91 + internal/handlers/channel.go | 44 +- internal/handlers/chat.go | 238 -- internal/handlers/contacts.go | 183 -- internal/handlers/containerd.go | 328 ++- internal/handlers/embeddings.go | 2 +- internal/handlers/fs.go | 543 +++- internal/handlers/fs_mcp_session_test.go | 255 ++ internal/handlers/fs_rest.go | 585 ---- internal/handlers/history.go | 259 -- internal/handlers/local_channel.go | 134 +- internal/handlers/mcp.go | 47 +- internal/handlers/mcp_federation_gateway.go | 480 ++++ .../handlers/mcp_federation_gateway_test.go | 188 ++ internal/handlers/mcp_stdio.go | 66 +- internal/handlers/mcp_tools.go | 241 ++ internal/handlers/mcp_tools_test.go | 165 ++ internal/handlers/memory.go | 643 ++--- internal/handlers/message.go | 583 ++++ internal/handlers/preauth.go | 26 +- internal/handlers/schedule.go | 37 +- internal/handlers/settings.go | 56 +- internal/handlers/subagent.go | 88 +- internal/handlers/users.go | 328 ++- internal/history/service.go | 237 -- internal/history/types.go | 23 - internal/identity/types.go | 9 +- internal/identity/user.go | 12 +- internal/logger/logger.go | 8 +- internal/logger/logger_test.go | 12 +- internal/mcp/connections.go | 20 +- internal/mcp/jsonrpc.go | 77 - internal/mcp/manager.go | 157 +- internal/mcp/providers/container/fsops.go | 288 ++ .../mcp/providers/container/fsops_test.go | 148 + internal/mcp/providers/container/provider.go | 250 ++ .../mcp/providers/container/provider_test.go | 236 ++ internal/mcp/providers/directory/provider.go | 162 ++ .../mcp/providers/directory/provider_test.go | 72 + internal/mcp/providers/memory/provider.go | 205 ++ .../mcp/providers/memory/provider_test.go | 284 ++ internal/mcp/providers/message/provider.go | 179 ++ .../mcp/providers/message/provider_test.go | 247 ++ internal/mcp/providers/schedule/provider.go | 259 ++ .../mcp/providers/schedule/provider_test.go | 374 +++ internal/mcp/sources/federation/source.go | 276 ++ .../mcp/sources/federation/source_test.go | 126 + internal/mcp/tool_gateway_service.go | 168 ++ internal/mcp/tool_gateway_service_test.go | 126 + internal/mcp/tool_registry.go | 72 + internal/mcp/tool_registry_test.go | 83 + internal/mcp/tool_types.go | 197 ++ internal/mcp/tools.go | 421 --- internal/mcp/versioning.go | 27 +- internal/memory/indexer_test.go | 14 +- internal/memory/qdrant_store.go | 2 +- internal/memory/service.go | 24 +- internal/memory/service_test.go | 38 +- internal/memory/types.go | 102 +- internal/message/event/hub.go | 124 + internal/message/event/hub_test.go | 59 + internal/message/service.go | 358 +++ internal/message/types.go | 52 + internal/models/models.go | 47 +- internal/models/models_test.go | 19 +- internal/models/types.go | 14 +- internal/policy/service.go | 31 + internal/preauth/service.go | 44 +- internal/preauth/types.go | 15 +- internal/providers/service.go | 29 +- internal/providers/types.go | 12 +- internal/router/channel.go | 693 ++++- internal/router/channel_test.go | 749 +++-- internal/router/identity.go | 501 +++- internal/router/identity_test.go | 797 +++++- internal/schedule/service.go | 48 +- internal/schedule/trigger.go | 3 +- internal/schedule/types.go | 20 +- internal/settings/service.go | 203 +- internal/settings/types.go | 12 +- internal/subagent/service.go | 44 +- internal/subagent/types.go | 28 +- packages/sdk/src/@pinia/colada.gen.ts | 490 +--- packages/sdk/src/index.ts | 4 +- packages/sdk/src/sdk.gen.ts | 291 +- packages/sdk/src/types.gen.ts | 1738 ++---------- packages/shared/README.md | 1 + packages/shared/package.json | 13 + packages/shared/src/chatInfo.ts | 15 + packages/shared/src/index.ts | 5 + packages/shared/src/mcp.ts | 48 + packages/shared/src/model.ts | 101 + packages/shared/src/platform.ts | 7 + packages/shared/src/schedule.ts | 8 + packages/web/mise.toml | 16 +- packages/web/package.json | 15 +- packages/web/public/channels/feishu.png | Bin 0 -> 669 bytes packages/web/public/channels/telegram.webp | Bin 0 -> 6724 bytes packages/web/src/App.vue | 1 - packages/web/src/components/Sidebar/index.vue | 124 + .../Sidebar/lists/chat-list-menu.vue | 114 + .../Sidebar/lists/settings-list-menu.vue | 87 + .../web/src/components/Sidebar/lists/types.ts | 4 + .../web/src/components/add-provider/index.vue | 18 +- .../chat-list/channel-badge/index.vue | 46 + .../web/src/components/chat-list/index.vue | 104 +- .../components/chat-list/robot-chat/index.vue | 71 + .../components/chat-list/user-chat/index.vue | 40 +- .../web/src/components/create-mcp/index.vue | 8 +- .../web/src/components/create-model/index.vue | 37 +- .../src/components/main-container/index.vue | 2 +- packages/web/src/components/sidebar/index.vue | 100 - packages/web/src/composables/api/useAuth.ts | 13 + .../web/src/composables/api/useBotSettings.ts | 42 + packages/web/src/composables/api/useBots.ts | 196 ++ .../web/src/composables/api/useChannels.ts | 88 + packages/web/src/composables/api/useChat.ts | 528 +++- packages/web/src/composables/api/useMcp.ts | 21 +- packages/web/src/composables/api/useModels.ts | 87 + .../web/src/composables/api/usePlatform.ts | 2 +- .../web/src/composables/api/useProviders.ts | 80 + packages/web/src/composables/api/useUsers.ts | 60 + packages/web/src/composables/useAutoScroll.ts | 60 +- .../web/src/composables/useKeyValueTags.ts | 10 +- packages/web/src/i18n/locales/en.json | 211 +- packages/web/src/i18n/locales/zh.json | 211 +- packages/web/src/main.ts | 33 +- .../src/pages/bots/components/bot-card.vue | 125 +- .../pages/bots/components/bot-channels.vue | 72 +- .../pages/bots/components/bot-settings.vue | 132 +- .../components/channel-settings-panel.vue | 61 +- .../src/pages/bots/components/create-bot.vue | 98 +- .../pages/bots/components/model-select.vue | 12 +- packages/web/src/pages/bots/detail.vue | 888 +++++- packages/web/src/pages/bots/index.vue | 59 +- .../src/pages/chat/components/bot-sidebar.vue | 23 +- .../src/pages/chat/components/chat-area.vue | 11 +- .../pages/chat/components/thinking-block.vue | 1 - packages/web/src/pages/chat/index.vue | 162 +- packages/web/src/pages/login/index.vue | 9 +- packages/web/src/pages/main-section/index.vue | 2 +- packages/web/src/pages/mcp/index.vue | 24 +- .../pages/models/components/model-item.vue | 6 +- .../pages/models/components/model-list.vue | 6 +- .../pages/models/components/provider-form.vue | 4 +- packages/web/src/pages/models/index.vue | 51 +- .../web/src/pages/models/model-setting.vue | 69 +- packages/web/src/pages/settings/index.vue | 157 +- packages/web/src/pages/settings/user.vue | 499 ++++ packages/web/src/router.ts | 162 +- packages/web/src/store/User.ts | 72 + packages/web/src/store/chat-list.ts | 838 ++++-- packages/web/src/store/settings.ts | 1 - packages/web/src/store/user.ts | 12 + packages/web/src/types/index.ts | 0 packages/web/src/utils/channel-icons.ts | 30 + packages/web/src/utils/request.ts | 11 +- pnpm-lock.yaml | 5 + spec/docs.go | 2467 +++-------------- spec/swagger.json | 2467 +++-------------- spec/swagger.yaml | 1693 ++--------- 305 files changed, 33212 insertions(+), 17058 deletions(-) create mode 100644 agent/src/test/unified_mcp_tools.test.ts create mode 100644 cmd/feishu-echo/main.go create mode 100644 cmd/feishu-echo/main_test.go create mode 100644 cmd/mcp/Dockerfile create mode 100644 db/queries/bind.sql create mode 100644 db/queries/channel_identities.sql create mode 100644 db/queries/channel_routes.sql delete mode 100644 db/queries/contacts.sql create mode 100644 db/queries/conversations.sql delete mode 100644 db/queries/history.sql create mode 100644 db/queries/messages.sql create mode 100644 docs/docs/concepts/identity-and-binding.md create mode 100644 docs/docs/concepts/index.md create mode 100644 docs/docs/style/terminology.md create mode 100644 docs/docs/zh/concepts/identity-and-binding.md create mode 100644 docs/docs/zh/concepts/index.md create mode 100644 docs/docs/zh/style/terminology.md rename internal/{users => accounts}/service.go (54%) rename internal/{users => accounts}/types.go (68%) create mode 100644 internal/bind/service.go create mode 100644 internal/bind/service_integration_test.go create mode 100644 internal/bind/service_test.go create mode 100644 internal/bind/types.go create mode 100644 internal/channel/adapters/feishu/directory.go create mode 100644 internal/channel/adapters/feishu/directory_test.go create mode 100644 internal/channel/adapters/feishu/inbound.go rename internal/channel/adapters/feishu/{feishu_logger.go => logger.go} (100%) create mode 100644 internal/channel/adapters/feishu/stream.go create mode 100644 internal/channel/adapters/local/hub_test.go create mode 100644 internal/channel/adapters/telegram/directory.go create mode 100644 internal/channel/adapters/telegram/directory_test.go create mode 100644 internal/channel/adapters/telegram/logger.go create mode 100644 internal/channel/adapters/telegram/logger_test.go create mode 100644 internal/channel/adapters/telegram/markdown.go create mode 100644 internal/channel/adapters/telegram/markdown_test.go create mode 100644 internal/channel/adapters/telegram/stream.go create mode 100644 internal/channel/adapters/telegram/stream_test.go create mode 100644 internal/channel/identities/service.go create mode 100644 internal/channel/identities/service_identity_integration_test.go create mode 100644 internal/channel/identities/service_integration_test.go create mode 100644 internal/channel/identities/service_test.go create mode 100644 internal/channel/identities/types.go create mode 100644 internal/channel/inbound_test.go delete mode 100644 internal/channel/manager_core_test.go create mode 100644 internal/channel/registry_test.go create mode 100644 internal/channel/route/service.go create mode 100644 internal/channel/route/types.go create mode 100644 internal/chat/resolver_memory_context_test.go create mode 100644 internal/chat/service.go create mode 100644 internal/chat/service_presence_integration_test.go delete mode 100644 internal/contacts/service.go delete mode 100644 internal/contacts/types.go create mode 100644 internal/conversation/flow/assistant_output.go create mode 100644 internal/conversation/flow/resolver.go create mode 100644 internal/conversation/flow/resolver_memory_context_test.go create mode 100644 internal/conversation/flow/resolver_test.go create mode 100644 internal/conversation/flow/schedule_gateway.go create mode 100644 internal/conversation/flow/types.go create mode 100644 internal/conversation/flow/types_alias.go create mode 100644 internal/conversation/interfaces.go create mode 100644 internal/conversation/resolver.go create mode 100644 internal/conversation/service_domain.go create mode 100644 internal/conversation/service_presence_integration_test.go create mode 100644 internal/conversation/types.go create mode 100644 internal/db/sqlc/bind.sql.go create mode 100644 internal/db/sqlc/channel_identities.sql.go create mode 100644 internal/db/sqlc/channel_routes.sql.go delete mode 100644 internal/db/sqlc/contacts.sql.go create mode 100644 internal/db/sqlc/conversations.sql.go delete mode 100644 internal/db/sqlc/history.sql.go create mode 100644 internal/db/sqlc/messages.sql.go create mode 100644 internal/db/text.go create mode 100644 internal/db/text_test.go delete mode 100644 internal/directory/service.go create mode 100644 internal/handlers/bind.go delete mode 100644 internal/handlers/chat.go delete mode 100644 internal/handlers/contacts.go create mode 100644 internal/handlers/fs_mcp_session_test.go delete mode 100644 internal/handlers/fs_rest.go delete mode 100644 internal/handlers/history.go create mode 100644 internal/handlers/mcp_federation_gateway.go create mode 100644 internal/handlers/mcp_federation_gateway_test.go create mode 100644 internal/handlers/mcp_tools.go create mode 100644 internal/handlers/mcp_tools_test.go create mode 100644 internal/handlers/message.go delete mode 100644 internal/history/service.go delete mode 100644 internal/history/types.go create mode 100644 internal/mcp/providers/container/fsops.go create mode 100644 internal/mcp/providers/container/fsops_test.go create mode 100644 internal/mcp/providers/container/provider.go create mode 100644 internal/mcp/providers/container/provider_test.go create mode 100644 internal/mcp/providers/directory/provider.go create mode 100644 internal/mcp/providers/directory/provider_test.go create mode 100644 internal/mcp/providers/memory/provider.go create mode 100644 internal/mcp/providers/memory/provider_test.go create mode 100644 internal/mcp/providers/message/provider.go create mode 100644 internal/mcp/providers/message/provider_test.go create mode 100644 internal/mcp/providers/schedule/provider.go create mode 100644 internal/mcp/providers/schedule/provider_test.go create mode 100644 internal/mcp/sources/federation/source.go create mode 100644 internal/mcp/sources/federation/source_test.go create mode 100644 internal/mcp/tool_gateway_service.go create mode 100644 internal/mcp/tool_gateway_service_test.go create mode 100644 internal/mcp/tool_registry.go create mode 100644 internal/mcp/tool_registry_test.go create mode 100644 internal/mcp/tool_types.go delete mode 100644 internal/mcp/tools.go create mode 100644 internal/message/event/hub.go create mode 100644 internal/message/event/hub_test.go create mode 100644 internal/message/service.go create mode 100644 internal/message/types.go create mode 100644 packages/shared/README.md create mode 100644 packages/shared/package.json create mode 100644 packages/shared/src/chatInfo.ts create mode 100644 packages/shared/src/index.ts create mode 100644 packages/shared/src/mcp.ts create mode 100644 packages/shared/src/model.ts create mode 100644 packages/shared/src/platform.ts create mode 100644 packages/shared/src/schedule.ts create mode 100644 packages/web/public/channels/feishu.png create mode 100644 packages/web/public/channels/telegram.webp create mode 100644 packages/web/src/components/Sidebar/index.vue create mode 100644 packages/web/src/components/Sidebar/lists/chat-list-menu.vue create mode 100644 packages/web/src/components/Sidebar/lists/settings-list-menu.vue create mode 100644 packages/web/src/components/Sidebar/lists/types.ts create mode 100644 packages/web/src/components/chat-list/channel-badge/index.vue create mode 100644 packages/web/src/components/chat-list/robot-chat/index.vue delete mode 100644 packages/web/src/components/sidebar/index.vue create mode 100644 packages/web/src/composables/api/useAuth.ts create mode 100644 packages/web/src/composables/api/useBotSettings.ts create mode 100644 packages/web/src/composables/api/useBots.ts create mode 100644 packages/web/src/composables/api/useChannels.ts create mode 100644 packages/web/src/composables/api/useModels.ts create mode 100644 packages/web/src/composables/api/useProviders.ts create mode 100644 packages/web/src/composables/api/useUsers.ts create mode 100644 packages/web/src/pages/settings/user.vue create mode 100644 packages/web/src/store/User.ts create mode 100644 packages/web/src/types/index.ts create mode 100644 packages/web/src/utils/channel-icons.ts diff --git a/DEPLOYMENT.md b/DEPLOYMENT.md index 6b57052b..a9801af8 100644 --- a/DEPLOYMENT.md +++ b/DEPLOYMENT.md @@ -75,11 +75,10 @@ Advantages: ### Using Docker Compose ```bash -docker compose up -d # Start services -docker compose down # Stop services -docker compose logs -f # View logs -docker compose ps # View status -docker compose restart # Restart services +docker compose up -d # Start +docker compose down # Stop +docker compose logs -f # View logs +nerdctl images # Ensure that memoh-mcp:latest exsits ``` ### Bot Container Management diff --git a/agent/src/agent.ts b/agent/src/agent.ts index 66ea9015..3866575d 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -1,10 +1,9 @@ import { generateText, ImagePart, LanguageModelUsage, ModelMessage, stepCountIs, streamText, UserModelMessage } from 'ai' -import { AgentInput, AgentParams, AgentSkill, allActions, HTTPMCPConnection, MCPConnection, Schedule, StdioMCPConnection } from './types' +import { AgentInput, AgentParams, AgentSkill, allActions, Schedule } from './types' import { system, schedule, user, subagentSystem } from './prompts' import { AuthFetcher } from './index' import { createModel } from './model' import { AgentAction } from './types/action' -import { getTools } from './tools' import { extractAttachmentsFromText, stripAttachmentsFromMessages, @@ -21,15 +20,13 @@ export const createAgent = ({ language = 'Same as the user input', allowedActions = allActions, channels = [], - mcpConnections = [], skills = [], currentChannel = 'Unknown Channel', identity = { botId: '', - sessionId: '', containerId: '', - contactId: '', - contactName: '', + channelIdentityId: '', + displayName: '', }, auth, }: AgentParams, fetch: AuthFetcher) => { @@ -47,18 +44,6 @@ export const createAgent = ({ return enabledSkills.map(skill => skill.name) } - const getDefaultMCPConnections = (): MCPConnection[] => { - const fs: HTTPMCPConnection = { - type: 'http', - name: 'fs', - url: `${auth.baseUrl}/bots/${identity.botId}/container/fs-mcp`, - headers: { - 'Authorization': `Bearer ${auth.bearer}`, - }, - } - return [fs] - } - const loadSystemFiles = async () => { if (!auth?.bearer || !identity.botId) { return { @@ -67,18 +52,43 @@ export const createAgent = ({ toolsContent: '', } } - const fetchFile = async (path: string) => { - const response = await fetch(`/bots/${identity.botId}/container/fs/file?path=${encodeURIComponent(path)}`) - if (!response.ok) { - return '' + const readViaMCP = async (path: string): Promise => { + const url = `${auth.baseUrl.replace(/\/$/, '')}/bots/${identity.botId}/tools` + const headers: Record = { + 'Content-Type': 'application/json', + 'Accept': 'application/json, text/event-stream', + 'Authorization': `Bearer ${auth.bearer}`, } - const data = await response.json().catch(() => ({} as { content?: string })) - return typeof data?.content === 'string' ? data.content : '' + if (identity.channelIdentityId) { + headers['X-Memoh-Channel-Identity-Id'] = identity.channelIdentityId + } + const body = JSON.stringify({ + jsonrpc: '2.0', + id: `read-${path}`, + method: 'tools/call', + params: { name: 'read', arguments: { path } }, + }) + const response = await fetch(url, { method: 'POST', headers, body }) + if (!response.ok) return '' + const data = await response.json().catch(() => ({} as any)) + const structured = data?.result?.structuredContent ?? data?.result?.content?.[0]?.text + if (typeof structured === 'string') { + try { + const parsed = JSON.parse(structured) + return typeof parsed?.content === 'string' ? parsed.content : '' + } catch { + return structured + } + } + if (typeof structured === 'object' && structured?.content) { + return typeof structured.content === 'string' ? structured.content : '' + } + return '' } const [identityContent, soulContent, toolsContent] = await Promise.all([ - fetchFile('IDENTITY.md'), - fetchFile('SOUL.md'), - fetchFile('TOOLS.md'), + readViaMCP('IDENTITY.md'), + readViaMCP('SOUL.md'), + readViaMCP('TOOLS.md'), ]) return { identityContent, @@ -94,6 +104,7 @@ export const createAgent = ({ language, maxContextLoadTime: activeContextTime, channels, + currentChannel, skills, enabledSkills, identityContent, @@ -103,25 +114,32 @@ export const createAgent = ({ } const getAgentTools = async () => { - const tools = getTools(allowedActions, { - fetch, - model: modelConfig, - brave, - identity, - enableSkill, - }) - const defaultMCPConnections = getDefaultMCPConnections() - const { tools: mcpTools, close: closeMCP } = await getMCPTools([ - ...defaultMCPConnections, - ...mcpConnections, - ], { - botId: identity.botId, - auth, - fetch, - }) - Object.assign(tools, mcpTools) + const baseUrl = auth.baseUrl.replace(/\/$/, '') + const botId = identity.botId.trim() + if (!baseUrl || !botId) { + return { + tools: {}, + close: async () => {}, + } + } + const headers: Record = { + 'Authorization': `Bearer ${auth.bearer}`, + } + if (identity.channelIdentityId) { + headers['X-Memoh-Channel-Identity-Id'] = identity.channelIdentityId + } + if (identity.sessionToken) { + headers['X-Memoh-Session-Token'] = identity.sessionToken + } + if (identity.currentPlatform) { + headers['X-Memoh-Current-Platform'] = identity.currentPlatform + } + if (identity.replyTarget) { + headers['X-Memoh-Reply-Target'] = identity.replyTarget + } + const { tools: mcpTools, close: closeMCP } = await getMCPTools(`${baseUrl}/bots/${botId}/tools`, headers) return { - tools, + tools: mcpTools, close: closeMCP, } } @@ -130,8 +148,8 @@ export const createAgent = ({ const images = input.attachments.filter(attachment => attachment.type === 'image') const files = input.attachments.filter((a): a is ContainerFileAttachment => a.type === 'file') const text = user(input.query, { - contactId: identity.contactId, - contactName: identity.contactName, + channelIdentityId: identity.channelIdentityId || identity.contactId || '', + displayName: identity.displayName || identity.contactName || 'User', channel: currentChannel, date: new Date(), attachments: files, @@ -171,7 +189,7 @@ export const createAgent = ({ const { messages: strippedMessages, attachments: messageAttachments } = stripAttachmentsFromMessages(response.messages) const allAttachments = dedupeAttachments([...textAttachments, ...messageAttachments]) return { - messages: [userPrompt, ...strippedMessages], + messages: strippedMessages, reasoning: reasoning.map(part => part.text), usage, text: cleanedText, @@ -258,6 +276,28 @@ export const createAgent = ({ } } + const resolveStreamErrorMessage = (raw: unknown): string => { + if (raw instanceof Error && raw.message.trim()) { + return raw.message + } + if (typeof raw === 'string' && raw.trim()) { + return raw + } + if (raw && typeof raw === 'object') { + const candidate = raw as { message?: unknown; error?: unknown } + if (typeof candidate.message === 'string' && candidate.message.trim()) { + return candidate.message + } + if (typeof candidate.error === 'string' && candidate.error.trim()) { + return candidate.error + } + if (candidate.error instanceof Error && candidate.error.message.trim()) { + return candidate.error.message + } + } + return 'Model stream failed' + } + async function* stream(input: AgentInput): AsyncGenerator { const userPrompt = generateUserPrompt(input) const messages = [...input.messages, userPrompt] @@ -297,6 +337,9 @@ export const createAgent = ({ input, } for await (const chunk of fullStream) { + if (chunk.type === 'error') { + throw new Error(resolveStreamErrorMessage((chunk as { error?: unknown }).error)) + } switch (chunk.type) { case 'reasoning-start': yield { type: 'reasoning_start', @@ -376,7 +419,7 @@ export const createAgent = ({ const { messages: strippedMessages } = stripAttachmentsFromMessages(result.messages) yield { type: 'agent_end', - messages: [userPrompt, ...strippedMessages], + messages: strippedMessages, reasoning: result.reasoning, usage: result.usage!, skills: getEnabledSkills(), diff --git a/agent/src/config.ts b/agent/src/config.ts index 3a23b5f6..9875a18e 100644 --- a/agent/src/config.ts +++ b/agent/src/config.ts @@ -10,7 +10,7 @@ type AgentGatewayConfig = { 'server': { addr?: string }, - 'brave': { + 'brave'?: { api_key?: string base_url?: string } diff --git a/agent/src/index.ts b/agent/src/index.ts index 1ebf9d07..eef30d5a 100644 --- a/agent/src/index.ts +++ b/agent/src/index.ts @@ -7,9 +7,14 @@ import { loadConfig } from './config' const config = loadConfig('../config.toml') export const getBraveConfig = () => { + const apiKey = config.brave?.api_key?.trim() ?? '' + if (!apiKey) { + return undefined + } + const baseUrl = config.brave?.base_url?.trim() || 'https://api.search.brave.com/res/v1/' return { - apiKey: config.brave.api_key ?? '', - baseUrl: config.brave.base_url ?? 'https://api.search.brave.com/res/v1/', + apiKey, + baseUrl, } } @@ -44,16 +49,16 @@ export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => { return async (url: string, options?: RequestInit) => { const requestOptions = options ?? {} const headers = new Headers(requestOptions.headers || {}) - if (bearer) { + if (bearer && !headers.has('Authorization')) { headers.set('Authorization', `Bearer ${bearer}`) } - const requestUrl = new URL( - url, - `${getBaseUrl().replace(/\/+$/, '')}/`, - ).toString() + const baseURL = getBaseUrl() + const requestURL = /^https?:\/\//i.test(url) + ? url + : new URL(url, `${baseURL.replace(/\/$/, '')}/`).toString() - return await fetch(requestUrl, { + return await fetch(requestURL, { ...requestOptions, headers, }) diff --git a/agent/src/models.ts b/agent/src/models.ts index f4aca341..d7403c97 100644 --- a/agent/src/models.ts +++ b/agent/src/models.ts @@ -22,10 +22,11 @@ export const AllowedActionModel = z.enum(allActions) export const IdentityContextModel = z.object({ botId: z.string().min(1, 'Bot ID is required'), - sessionId: z.string().min(1, 'Session ID is required'), containerId: z.string().min(1, 'Container ID is required'), - contactId: z.string().min(1, 'Contact ID is required'), - contactName: z.string().min(1, 'Contact name is required'), + channelIdentityId: z.string().min(1, 'Channel identity ID is required'), + displayName: z.string().min(1, 'Display name is required'), + contactId: z.string().optional(), + contactName: z.string().optional(), contactAlias: z.string().optional(), userId: z.string().optional(), currentPlatform: z.string().optional(), diff --git a/agent/src/modules/chat.ts b/agent/src/modules/chat.ts index a67fac09..150d83df 100644 --- a/agent/src/modules/chat.ts +++ b/agent/src/modules/chat.ts @@ -4,7 +4,7 @@ import { createAgent } from '../agent' import { createAuthFetcher, getBaseUrl, getBraveConfig } from '../index' import { ModelConfig } from '../types' import { bearerMiddleware } from '../middlewares/bearer' -import { AgentSkillModel, AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models' +import { AgentSkillModel, AllowedActionModel, AttachmentModel, IdentityContextModel, ModelConfigModel, ScheduleModel } from '../models' import { allActions } from '../types' const AgentModel = z.object({ @@ -18,7 +18,6 @@ const AgentModel = z.object({ skills: z.array(z.string()), identity: IdentityContextModel, attachments: z.array(AttachmentModel).optional().default([]), - mcpConnections: z.array(MCPConnectionModel).optional().default([]), }) export const chatModule = new Elysia({ prefix: '/chat' }) @@ -33,7 +32,6 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentChannel: body.currentChannel, allowedActions: body.allowedActions, identity: body.identity, - mcpConnections: body.mcpConnections, auth: { bearer: bearer!, baseUrl: getBaseUrl(), @@ -63,7 +61,6 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentChannel: body.currentChannel, allowedActions: body.allowedActions, identity: body.identity, - mcpConnections: body.mcpConnections, auth: { bearer: bearer!, baseUrl: getBaseUrl(), @@ -81,9 +78,12 @@ export const chatModule = new Elysia({ prefix: '/chat' }) } } catch (error) { console.error(error) + const message = error instanceof Error && error.message.trim() + ? error.message + : 'Internal server error' yield sse(JSON.stringify({ type: 'error', - message: 'Internal server error', + message, })) } }, { @@ -99,7 +99,6 @@ export const chatModule = new Elysia({ prefix: '/chat' }) channels: body.channels, currentChannel: body.currentChannel, identity: body.identity, - mcpConnections: body.mcpConnections, auth: { bearer: bearer!, baseUrl: getBaseUrl(), diff --git a/agent/src/prompts/system.ts b/agent/src/prompts/system.ts index 7049b94b..11c8649e 100644 --- a/agent/src/prompts/system.ts +++ b/agent/src/prompts/system.ts @@ -6,6 +6,8 @@ export interface SystemParams { language: string maxContextLoadTime: number channels: string[] + /** Channel where the current session/message is from (e.g. telegram, feishu, web). */ + currentChannel: string skills: AgentSkill[] enabledSkills: AgentSkill[] identityContent?: string @@ -23,11 +25,12 @@ ${skill.content} `.trim() } -export const system = ({ +export const system = ({ date, language, maxContextLoadTime, channels, + currentChannel, skills, enabledSkills, identityContent, @@ -37,6 +40,7 @@ export const system = ({ const headers = { 'language': language, 'available-channels': channels.join(','), + 'current-session-channel': currentChannel, 'max-context-load-time': maxContextLoadTime.toString(), 'time-now': date.toISOString(), } @@ -97,8 +101,12 @@ You have a contacts book to record them that you do not need to worry about who ## Channels +The current session (and the latest user message) is from channel: ${quote(currentChannel)}. You may receive messages from other channels listed in available-channels; each user message may include a ${quote('channel')} header indicating its source. + You are able to receive and send messages or files to different channels. +When you need to resolve a user or group on a channel (e.g. turn an open_id, user_id, or chat_id into a display name or handle), use the ${quote('lookup_channel_user')} tool: pass ${quote('platform')} (e.g. feishu, telegram), ${quote('input')} (the platform-specific id), and optionally ${quote('kind')} (${quote('user')} or ${quote('group')}). It returns name, handle, and id for that entry. + ## Attachments ### Receive diff --git a/agent/src/prompts/user.ts b/agent/src/prompts/user.ts index ac46c23b..742b8a0b 100644 --- a/agent/src/prompts/user.ts +++ b/agent/src/prompts/user.ts @@ -1,8 +1,8 @@ import { ContainerFileAttachment } from '../types' export interface UserParams { - contactId: string - contactName: string + channelIdentityId: string + displayName: string channel: string date: Date attachments: ContainerFileAttachment[] @@ -10,11 +10,11 @@ export interface UserParams { export const user = ( query: string, - { contactId, contactName, channel, date, attachments }: UserParams + { channelIdentityId, displayName, channel, date, attachments }: UserParams ) => { const headers = { - 'contact-id': contactId, - 'contact-name': contactName, + 'channel-identity-id': channelIdentityId, + 'display-name': displayName, 'channel': channel, 'time': date.toISOString(), 'attachments': attachments.map(attachment => attachment.path), diff --git a/agent/src/test/unified_mcp_tools.test.ts b/agent/src/test/unified_mcp_tools.test.ts new file mode 100644 index 00000000..a3c7b319 --- /dev/null +++ b/agent/src/test/unified_mcp_tools.test.ts @@ -0,0 +1,89 @@ +import { describe, expect, test } from 'bun:test' +import { getMCPTools } from '../tools/mcp' + +describe('getMCPTools (unified endpoint)', () => { + test('loads tools from unified MCP HTTP endpoint', async () => { + const seenMethods: string[] = [] + const seenAuthHeaders: string[] = [] + + const server = Bun.serve({ + port: 0, + async fetch(request) { + seenAuthHeaders.push(request.headers.get('authorization') ?? '') + const body = await request.json().catch(() => ({} as Record)) + const method = typeof body?.method === 'string' ? body.method : '' + seenMethods.push(method) + + if (method === 'initialize') { + return Response.json({ + jsonrpc: '2.0', + id: body.id ?? null, + result: { + protocolVersion: '2025-06-18', + capabilities: { + tools: { + listChanged: false, + }, + }, + serverInfo: { + name: 'test-mcp', + version: '1.0.0', + }, + }, + }) + } + + if (method === 'notifications/initialized') { + return new Response(null, { status: 202 }) + } + + if (method === 'tools/list') { + return Response.json({ + jsonrpc: '2.0', + id: body.id ?? null, + result: { + tools: [ + { + name: 'search_memory', + description: 'Search memory', + inputSchema: { + type: 'object', + properties: { + query: { type: 'string' }, + }, + required: ['query'], + }, + }, + ], + }, + }) + } + + return Response.json({ + jsonrpc: '2.0', + id: body.id ?? null, + error: { + code: -32601, + message: 'method not found', + }, + }) + }, + }) + + try { + const endpoint = `http://127.0.0.1:${server.port}/bots/bot-1/tools` + const { tools, close } = await getMCPTools(endpoint, { + Authorization: 'Bearer test-token', + }) + + expect(Object.keys(tools)).toContain('search_memory') + expect(seenMethods).toContain('initialize') + expect(seenMethods).toContain('tools/list') + expect(seenAuthHeaders.some(value => value === 'Bearer test-token')).toBe(true) + + await close() + } finally { + server.stop(true) + } + }) +}) diff --git a/agent/src/tools/contact.ts b/agent/src/tools/contact.ts index 205eb280..de5e1c6a 100644 --- a/agent/src/tools/contact.ts +++ b/agent/src/tools/contact.ts @@ -11,100 +11,57 @@ export type ContactToolParams = { export const getContactTools = ({ fetch, identity }: ContactToolParams) => { const botId = identity.botId.trim() + const listMyIdentities = async () => { + const response = await fetch('/users/me/identities') + return response.json() + } + const contactSearch = tool({ - description: 'Search contacts by name or alias', + description: 'Search identity cards by platform, external id, or display name', inputSchema: z.object({ - query: z.string().describe('The query to search for contacts'), + query: z.string().describe('The query to search identities').optional().default(''), }), execute: async ({ query }) => { - const url = `/bots/${botId}/contacts?q=${encodeURIComponent(query)}` - const response = await fetch(url) - return response.json() + const payload = await listMyIdentities() + const keyword = query.trim().toLowerCase() + const items = Array.isArray(payload?.items) ? payload.items : [] + const filtered = keyword + ? items.filter((item: { platform?: string; external_id?: string; display_name?: string }) => { + const platform = String(item?.platform ?? '').toLowerCase() + const externalID = String(item?.external_id ?? '').toLowerCase() + const displayName = String(item?.display_name ?? '').toLowerCase() + return platform.includes(keyword) || externalID.includes(keyword) || displayName.includes(keyword) + }) + : items + return { + canonical_channel_identity_id: payload?.canonical_channel_identity_id ?? '', + total: filtered.length, + items: filtered, + } }, }) - const contactCreate = tool({ - description: 'Create a contact', + const contactCardMe = tool({ + description: 'Get my canonical identity card and all linked channel identities', + inputSchema: z.object({}), + execute: async () => { + return listMyIdentities() + }, + }) + + const contactIssueBindCode = tool({ + description: 'Issue a bind code for linking current channel identity to this account', inputSchema: z.object({ - name: z.string().describe('The display name of the contact'), - alias: z.string().describe('The alias of the contact').optional(), - tags: z.array(z.string()).describe('The tags of the contact').optional(), + ttl_seconds: z.number().int().positive().optional().describe('Bind code ttl in seconds'), }), - execute: async ({ name, alias, tags }) => { - const response = await fetch(`/bots/${botId}/contacts`, { + execute: async ({ ttl_seconds }) => { + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(`/bots/${botId}/bind_codes`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - display_name: name, - alias: alias, - tags: tags ?? [], - }), - }) - return response.json() - }, - }) - - const contactUpdate = tool({ - description: 'Update a contact', - inputSchema: z.object({ - contact_id: z.string().describe('The ID of the contact to update'), - name: z.string().describe('The display name of the contact').optional(), - alias: z.string().describe('The alias of the contact').optional(), - tags: z.array(z.string()).describe('The tags of the contact').optional(), - }), - execute: async ({ contact_id, name, alias, tags }) => { - const response = await fetch(`/bots/${botId}/contacts/${contact_id}`, { - method: 'PATCH', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - display_name: name, - alias: alias, - tags: tags ?? [], - }), - }) - return response.json() - }, - }) - - // const contactBindToken = tool({ - // description: 'Issue a one-time bind token for a contact', - // inputSchema: z.object({ - // contact_id: ContactID, - // target_platform: z.string().describe('The platform to bind the contact to'), - // target_external_id: z.string().describe('The external ID of the contact'), - // ttl_seconds: z.number().describe('The number of seconds the bind token is valid').optional(), - // }), - // execute: async ({ bot_id, contact_id, target_platform, target_external_id, ttl_seconds }) => { - // const response = await fetch(`/bots/${botId}/contacts/${contact_id}/bind_token`, { - // method: 'POST', - // headers: { 'Content-Type': 'application/json' }, - // body: JSON.stringify({ - // target_platform: target_platform, - // target_external_id: target_external_id, - // ttl_seconds: ttl_seconds, - // }), - // }) - // return response.json() - // }, - // }) - - const contactBind = tool({ - description: 'Bind a contact to a platform identity using a bind token', - inputSchema: z.object({ - contact_id: z.string().describe('The ID of the contact to bind'), - platform: z.string().describe('The platform to bind the contact to'), - external_id: z.string().describe('The external ID of the contact'), - bind_token: z.string().describe('The bind token to use'), - }), - execute: async ({ contact_id, platform, external_id, bind_token }) => { - const response = await fetch(`/bots/${botId}/contacts/${contact_id}/bind`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - platform: platform, - external_id: external_id, - bind_token: bind_token, - }), + body: JSON.stringify({ ttl_seconds }), }) return response.json() }, @@ -112,9 +69,7 @@ export const getContactTools = ({ fetch, identity }: ContactToolParams) => { return { 'contact_search': contactSearch, - 'contact_create': contactCreate, - 'contact_update': contactUpdate, - // 'contact_bind_token': contactBindToken, - 'contact_bind': contactBind, + 'contact_card_me': contactCardMe, + 'contact_issue_bind_code': contactIssueBindCode, } } diff --git a/agent/src/tools/index.ts b/agent/src/tools/index.ts index 34422aed..80456ac3 100644 --- a/agent/src/tools/index.ts +++ b/agent/src/tools/index.ts @@ -1,5 +1,5 @@ import { AuthFetcher } from '..' -import { AgentAction, BraveConfig, IdentityContext, ModelConfig } from '../types' +import { AgentAction, AgentAuthContext, BraveConfig, IdentityContext, ModelConfig } from '../types' import { ToolSet } from 'ai' import { getWebTools } from './web' import { getScheduleTools } from './schedule' @@ -14,12 +14,13 @@ export interface ToolsParams { model: ModelConfig brave?: BraveConfig identity: IdentityContext + auth: AgentAuthContext enableSkill: (skill: string) => void } export const getTools = ( actions: AgentAction[], - { fetch, model, brave, identity, enableSkill }: ToolsParams + { fetch, model, brave, identity, auth, enableSkill }: ToolsParams ) => { const tools: ToolSet = {} if (actions.includes(AgentAction.Web) && brave) { @@ -31,11 +32,11 @@ export const getTools = ( Object.assign(tools, scheduleTools) } if (actions.includes(AgentAction.Memory)) { - const memoryTools = getMemoryTools({ fetch }) + const memoryTools = getMemoryTools({ fetch, identity }) Object.assign(tools, memoryTools) } if (actions.includes(AgentAction.Subagent)) { - const subagentTools = getSubagentTools({ fetch, model, brave, identity }) + const subagentTools = getSubagentTools({ fetch, model, brave, identity, auth }) Object.assign(tools, subagentTools) } if (actions.includes(AgentAction.Contact)) { diff --git a/agent/src/tools/mcp.ts b/agent/src/tools/mcp.ts index 1bb4466f..a6a88447 100644 --- a/agent/src/tools/mcp.ts +++ b/agent/src/tools/mcp.ts @@ -1,101 +1,17 @@ -import { HTTPMCPConnection, MCPConnection, SSEMCPConnection, StdioMCPConnection } from '../types' import { createMCPClient } from '@ai-sdk/mcp' -import { AuthFetcher } from '../index' -import type { AgentAuthContext } from '../types/agent' - -type MCPToolOptions = { - botId?: string - auth?: AgentAuthContext - fetch?: AuthFetcher -} - -export const getMCPTools = async (connections: MCPConnection[], options: MCPToolOptions = {}) => { - const closeCallbacks: Array<() => Promise> = [] - - const getHTTPTools = async (connection: HTTPMCPConnection) => { - const client = await createMCPClient({ - transport: { - type: 'http', - url: connection.url, - headers: connection.headers, - } - }) - closeCallbacks.push(() => client.close()) - const tools = await client.tools() - return tools - } - - const getSSETools = async (connection: SSEMCPConnection) => { - const client = await createMCPClient({ - transport: { - type: 'sse', - url: connection.url, - headers: connection.headers, - } - }) - closeCallbacks.push(() => client.close()) - const tools = await client.tools() - return tools - } - - const getStdioTools = async (connection: StdioMCPConnection) => { - if (!options.fetch || !options.botId || !options.auth) { - throw new Error('stdio mcp requires auth fetcher and bot id') - } - const response = await options.fetch(`/bots/${options.botId}/mcp-stdio`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - name: connection.name, - command: connection.command, - args: connection.args ?? [], - env: connection.env ?? {}, - cwd: connection.cwd ?? '' - }) - }) - if (!response.ok) { - const text = await response.text().catch(() => '') - throw new Error(`mcp-stdio failed: ${response.status} ${text}`) - } - const data = await response.json().catch(() => ({} as { url?: string })) - const rawUrl = typeof data?.url === 'string' ? data.url : '' - if (!rawUrl) { - throw new Error('mcp-stdio response missing url') - } - const baseUrl = options.auth.baseUrl ?? '' - const url = rawUrl.startsWith('http') - ? rawUrl - : `${baseUrl.replace(/\/$/, '')}/${rawUrl.replace(/^\//, '')}` - return await getHTTPTools({ +export const getMCPTools = async (url: string, headers: Record = {}) => { + const client = await createMCPClient({ + transport: { type: 'http', - name: connection.name, url, - headers: { - 'Authorization': `Bearer ${options.auth.bearer}` - } - }) - } - - const toolSets = await Promise.all(connections.map(async (connection) => { - switch (connection.type) { - case 'http': - return getHTTPTools(connection) - case 'sse': - return getSSETools(connection) - case 'stdio': - return getStdioTools(connection) - default: - console.warn('unknown mcp connection type', connection) - return {} + headers, } - })) - + }) + const tools = await client.tools() return { - tools: Object.assign({}, ...toolSets), + tools, close: async () => { - await Promise.all(closeCallbacks.map(callback => callback())) + await client.close() } } } diff --git a/agent/src/tools/memory.ts b/agent/src/tools/memory.ts index 3936fae9..e2ed7b78 100644 --- a/agent/src/tools/memory.ts +++ b/agent/src/tools/memory.ts @@ -1,9 +1,11 @@ import { tool } from 'ai' import { AuthFetcher } from '..' +import type { IdentityContext } from '../types' import { z } from 'zod' export type MemoryToolParams = { fetch: AuthFetcher + identity: IdentityContext } type MemorySearchItem = { @@ -16,20 +18,26 @@ type MemorySearchItem = { } } -export const getMemoryTools = ({ fetch }: MemoryToolParams) => { +export const getMemoryTools = ({ fetch, identity }: MemoryToolParams) => { const searchMemory = tool({ description: 'Search for memories', inputSchema: z.object({ query: z.string().describe('The query to search for memories'), + limit: z.number().int().positive().max(50).optional(), }), - execute: async ({ query }) => { - const response = await fetch('/memory/search', { + execute: async ({ query, limit }) => { + const botId = identity.botId.trim() + if (!botId) { + throw new Error('botId is required to search memory') + } + const response = await fetch(`/bots/${botId}/memory/search`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({ query, + limit, }), }) const data = await response.json() diff --git a/agent/src/tools/message.ts b/agent/src/tools/message.ts index 77bbc906..c633c6a7 100644 --- a/agent/src/tools/message.ts +++ b/agent/src/tools/message.ts @@ -12,6 +12,7 @@ const SendMessageSchema = z.object({ bot_id: z.string().optional(), platform: z.string().optional(), target: z.string().optional(), + channel_identity_id: z.string().optional(), to_user_id: z.string().optional(), message: z.string(), }) @@ -25,38 +26,39 @@ export const getMessageTools = ({ fetch, identity }: MessageToolParams) => { const platform = (payload.platform ?? identity.currentPlatform ?? '').trim() const replyTarget = (identity.replyTarget ?? '').trim() const target = (payload.target ?? replyTarget).trim() - const toUserID = (payload.to_user_id ?? '').trim() + const channelIdentityID = (payload.channel_identity_id ?? payload.to_user_id ?? '').trim() if (!botId) { throw new Error('bot_id is required') } if (!platform) { throw new Error('platform is required') } - if (!target && !toUserID && !identity.sessionToken) { - throw new Error('target or to_user_id is required') + // Prefer chat token when there is no explicit target identity. + const useSessionToken = !!identity.sessionToken && !channelIdentityID + if (!target && !channelIdentityID && !useSessionToken) { + throw new Error('target or channel_identity_id is required') } - // Use session token if available and no explicit to_user_id specified - // This allows replying to current session without needing explicit auth - const useSessionToken = !!identity.sessionToken && !toUserID console.log('[Tool] send_message', { botId, platform, target: target || undefined, - toUserID: toUserID || undefined, + channelIdentityID: channelIdentityID || undefined, replyTarget, useSessionToken, }) - const body: Record = { message: payload.message } - if (!useSessionToken) { - if (target) { - body.to = target - } - if (toUserID) { - body.to_user_id = toUserID - } + const body: Record = { + message: { + text: payload.message, + }, + } + if (target) { + body.target = target + } + if (channelIdentityID) { + body.channel_identity_id = channelIdentityID } const url = useSessionToken - ? `/bots/${botId}/channel/${platform}/send_session` + ? `/bots/${botId}/channel/${platform}/send_chat` : `/bots/${botId}/channel/${platform}/send` const headers: Record = { 'Content-Type': 'application/json' } if (useSessionToken && identity.sessionToken) { diff --git a/agent/src/tools/subagent.ts b/agent/src/tools/subagent.ts index fe587e2a..65321660 100644 --- a/agent/src/tools/subagent.ts +++ b/agent/src/tools/subagent.ts @@ -1,7 +1,7 @@ import { tool } from 'ai' import { z } from 'zod' import { createAgent } from '../agent' -import { ModelConfig, BraveConfig } from '../types' +import { ModelConfig, BraveConfig, AgentAuthContext } from '../types' import { AuthFetcher } from '..' import { AgentAction, IdentityContext } from '../types/agent' @@ -10,14 +10,21 @@ export interface SubagentToolParams { model: ModelConfig brave?: BraveConfig identity: IdentityContext + auth: AgentAuthContext } -export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentToolParams) => { +export const getSubagentTools = ({ fetch, model, brave, identity, auth }: SubagentToolParams) => { + const botId = identity.botId.trim() + const base = `/bots/${botId}/subagents` + const listSubagents = tool({ description: 'List subagents for current user', inputSchema: z.object({}), execute: async () => { - const response = await fetch('/subagents', { method: 'GET' }) + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(base, { method: 'GET' }) return response.json() }, }) @@ -29,7 +36,10 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool description: z.string(), }), execute: async ({ name, description }) => { - const response = await fetch('/subagents', { + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(base, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ name, description }), @@ -44,7 +54,10 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool id: z.string().describe('Subagent ID'), }), execute: async ({ id }) => { - const response = await fetch(`/subagents/${id}`, { method: 'DELETE' }) + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(`${base}/${id}`, { method: 'DELETE' }) return response.status === 204 ? { success: true } : response.json() }, }) @@ -56,14 +69,17 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool query: z.string().describe('The prompt to ask the subagent to do.'), }), execute: async ({ name, query }) => { - const listResponse = await fetch('/subagents', { method: 'GET' }) + if (!botId) { + throw new Error('bot_id is required') + } + const listResponse = await fetch(base, { method: 'GET' }) const listPayload = await listResponse.json() const items = Array.isArray(listPayload?.items) ? listPayload.items : [] const target = items.find((item: { name?: string }) => item?.name === name) if (!target?.id) { throw new Error(`subagent not found: ${name}`) } - const contextResponse = await fetch(`/subagents/${target.id}/context`, { method: 'GET' }) + const contextResponse = await fetch(`${base}/${target.id}/context`, { method: 'GET' }) const contextPayload = await contextResponse.json() const contextMessages = Array.isArray(contextPayload?.messages) ? contextPayload.messages : [] const { askAsSubagent } = createAgent({ @@ -73,6 +89,7 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool AgentAction.Web, ], identity, + auth, }, fetch) const result = await askAsSubagent({ messages: contextMessages, @@ -81,7 +98,7 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool description: target.description, }) const updatedMessages = [...contextMessages, ...result.messages] - await fetch(`/subagents/${target.id}/context`, { + await fetch(`${base}/${target.id}/context`, { method: 'PUT', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ messages: updatedMessages }), diff --git a/agent/src/types/agent.ts b/agent/src/types/agent.ts index 430c3850..ec92ace4 100644 --- a/agent/src/types/agent.ts +++ b/agent/src/types/agent.ts @@ -1,15 +1,16 @@ import { ModelMessage } from 'ai' import { ModelConfig } from './model' import { AgentAttachment } from './attachment' -import { MCPConnection } from './mcp' export interface IdentityContext { botId: string - sessionId: string containerId: string - contactId: string - contactName: string + channelIdentityId: string + displayName: string + + contactId?: string + contactName?: string contactAlias?: string userId?: string @@ -48,7 +49,6 @@ export interface AgentParams { brave?: BraveConfig channels?: string[] currentChannel?: string - mcpConnections?: MCPConnection[] identity?: IdentityContext auth: AgentAuthContext skills?: AgentSkill[] diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 705daca4..1ec0c4d6 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -11,24 +11,41 @@ import ( "time" containerd "github.com/containerd/containerd/v2/client" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "go.uber.org/fx" + "go.uber.org/fx/fxevent" + "golang.org/x/crypto/bcrypt" + + "github.com/memohai/memoh/internal/accounts" + "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/boot" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/telegram" - "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/channel/route" "github.com/memohai/memoh/internal/config" - "github.com/memohai/memoh/internal/contacts" ctr "github.com/memohai/memoh/internal/containerd" + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/conversation/flow" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/embeddings" "github.com/memohai/memoh/internal/handlers" - "github.com/memohai/memoh/internal/history" "github.com/memohai/memoh/internal/logger" "github.com/memohai/memoh/internal/mcp" + mcpcontainer "github.com/memohai/memoh/internal/mcp/providers/container" + mcpdirectory "github.com/memohai/memoh/internal/mcp/providers/directory" + mcpmemory "github.com/memohai/memoh/internal/mcp/providers/memory" + mcpmessage "github.com/memohai/memoh/internal/mcp/providers/message" + mcpschedule "github.com/memohai/memoh/internal/mcp/providers/schedule" + mcpfederation "github.com/memohai/memoh/internal/mcp/sources/federation" "github.com/memohai/memoh/internal/memory" + "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/message/event" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/policy" "github.com/memohai/memoh/internal/preauth" @@ -38,114 +55,83 @@ import ( "github.com/memohai/memoh/internal/server" "github.com/memohai/memoh/internal/settings" "github.com/memohai/memoh/internal/subagent" - "github.com/memohai/memoh/internal/users" "github.com/memohai/memoh/internal/version" - "go.uber.org/fx" - "go.uber.org/fx/fxevent" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - "golang.org/x/crypto/bcrypt" ) -func provideConfig() (config.Config, error) { - cfgPath := os.Getenv("CONFIG_PATH") - cfg, err := config.Load(cfgPath) - if err != nil { - return config.Config{}, fmt.Errorf("load config: %v\n", err) - } - return cfg, nil -} - -func provideLogger(cfg config.Config) *slog.Logger { - logger.Init(cfg.Log.Level, cfg.Log.Format) - return logger.L -} - -func provideContainerdClient(lc fx.Lifecycle, runtimeConfig *boot.RuntimeConfig) (*containerd.Client, error) { - factory := ctr.DefaultClientFactory{SocketPath: runtimeConfig.ContainerdSocketPath} - client, err := factory.New(context.Background()) - if err != nil { - return nil, fmt.Errorf("connect containerd: %w", err) - } - - lc.Append(fx.Hook{ - OnStop: func(ctx context.Context) error { - if err := client.Close(); err != nil { - return fmt.Errorf("close containerd client: %w", err) - } - return nil - }, - }) - return client, nil -} - func main() { fx.New( fx.Provide( provideConfig, boot.ProvideRuntimeConfig, provideLogger, - - // misc provideContainerdClient, provideDBConn, provideDBQueries, + // containerd & mcp infrastructure fx.Annotate(ctr.NewDefaultService, fx.As(new(ctr.Service))), - mcp.NewManager, + provideMCPManager, + // memory pipeline provideMemoryLLM, provideEmbeddingsResolver, provideEmbeddingSetup, provideTextEmbedderForMemory, provideQdrantStore, memory.NewBM25Indexer, - provideChatResolver, - local.NewSessionHub, - provideChannelRegistry, - - provideChannelRouter, - provideChannelManager, - - chat.NewScheduleGateway, - fx.Annotate(func(scheduleGateway *chat.ScheduleGateway) schedule.Triggerer { - return scheduleGateway - }, fx.As(new(schedule.Triggerer))), + provideMemoryService, + // domain services (auto-wired) models.NewService, bots.NewService, - users.NewService, - providers.NewService, + accounts.NewService, settings.NewService, - history.NewService, - contacts.NewService, + providers.NewService, + policy.NewService, preauth.NewService, mcp.NewConnectionService, subagent.NewService, - schedule.NewService, - channel.NewService, - policy.NewService, - provideMemoryService, + conversation.NewService, + identities.NewService, + bind.NewService, + event.NewHub, + // services requiring provide functions + provideRouteService, + provideMessageService, + + // channel infrastructure + local.NewRouteHub, + provideChannelRegistry, + channel.NewService, + provideChannelRouter, + provideChannelManager, + + // conversation flow + provideChatResolver, + provideScheduleTriggerer, + schedule.NewService, + + // containerd handler & tool gateway + provideContainerdHandler, + provideToolGatewayService, + + // http handlers (group:"server_handlers") provideServerHandler(handlers.NewPingHandler), - provideServerHandler(handlers.NewAuthHandler), + provideServerHandler(provideAuthHandler), provideServerHandler(handlers.NewMemoryHandler), provideServerHandler(handlers.NewEmbeddingsHandler), - provideServerHandler(handlers.NewChatHandler), + provideServerHandler(provideMessageHandler), provideServerHandler(handlers.NewSwaggerHandler), provideServerHandler(handlers.NewProvidersHandler), provideServerHandler(handlers.NewModelsHandler), provideServerHandler(handlers.NewSettingsHandler), - provideServerHandler(handlers.NewHistoryHandler), - provideServerHandler(handlers.NewContactsHandler), provideServerHandler(handlers.NewPreauthHandler), + provideServerHandler(handlers.NewBindHandler), provideServerHandler(handlers.NewScheduleHandler), provideServerHandler(handlers.NewSubagentHandler), - handlers.NewContainerdHandler, - provideServerHandler(handlers.NewContainerdHandler), provideServerHandler(handlers.NewChannelHandler), - provideServerHandler(handlers.NewUsersHandler), + provideServerHandler(provideUsersHandler), provideServerHandler(handlers.NewMCPHandler), provideServerHandler(provideCLIHandler), provideServerHandler(provideWebHandler), @@ -156,16 +142,19 @@ func main() { startMemoryWarmup, startScheduleService, startChannelManager, + startContainerReconciliation, startServer, ), fx.WithLogger(func(logger *slog.Logger) fxevent.Logger { - l := &fxevent.SlogLogger{Logger: logger.With(slog.String("component", "fx"))} - // l.UseLogLevel(slog.LevelInfo) - return l + return &fxevent.SlogLogger{Logger: logger.With(slog.String("component", "fx"))} }), ).Run() } +// --------------------------------------------------------------------------- +// fx helper +// --------------------------------------------------------------------------- + func provideServerHandler(fn any) any { return fx.Annotate( fn, @@ -174,10 +163,40 @@ func provideServerHandler(fn any) any { ) } -func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) { - ctx := context.Background() // TODO: use timeout context +// --------------------------------------------------------------------------- +// infrastructure providers +// --------------------------------------------------------------------------- - conn, err := db.Open(ctx, cfg.Postgres) +func provideConfig() (config.Config, error) { + cfgPath := os.Getenv("CONFIG_PATH") + cfg, err := config.Load(cfgPath) + if err != nil { + return config.Config{}, fmt.Errorf("load config: %w", err) + } + return cfg, nil +} + +func provideLogger(cfg config.Config) *slog.Logger { + logger.Init(cfg.Log.Level, cfg.Log.Format) + return logger.L +} + +func provideContainerdClient(lc fx.Lifecycle, rc *boot.RuntimeConfig) (*containerd.Client, error) { + factory := ctr.DefaultClientFactory{SocketPath: rc.ContainerdSocketPath} + client, err := factory.New(context.Background()) + if err != nil { + return nil, fmt.Errorf("connect containerd: %w", err) + } + lc.Append(fx.Hook{ + OnStop: func(ctx context.Context) error { + return client.Close() + }, + }) + return client, nil +} + +func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) { + conn, err := db.Open(context.Background(), cfg.Postgres) if err != nil { return nil, fmt.Errorf("db connect: %w", err) } @@ -194,6 +213,23 @@ func provideDBQueries(conn *pgxpool.Pool) *dbsqlc.Queries { return dbsqlc.New(conn) } +func provideMCPManager(log *slog.Logger, service ctr.Service, cfg config.Config, conn *pgxpool.Pool) *mcp.Manager { + return mcp.NewManager(log, service, cfg.MCP, cfg.Containerd.Namespace, conn) +} + +// --------------------------------------------------------------------------- +// memory providers +// --------------------------------------------------------------------------- + +func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memory.LLM { + return &lazyLLMClient{ + modelsService: modelsService, + queries: queries, + timeout: 30 * time.Second, + logger: log, + } +} + func provideEmbeddingsResolver(log *slog.Logger, modelsService *models.Service, queries *dbsqlc.Queries) *embeddings.Resolver { return embeddings.NewResolver(log, modelsService, queries, 10*time.Second) } @@ -228,58 +264,158 @@ func provideTextEmbedderForMemory(resolver *embeddings.Resolver, setup embedding return buildTextEmbedder(resolver, setup.TextModel, setup.HasEmbeddingModels, log) } -func provideMemoryService(log *slog.Logger, llm memory.LLM, embedder embeddings.Embedder, store *memory.QdrantStore, resolver *embeddings.Resolver, bm25Indexer *memory.BM25Indexer, setup embeddingSetup) *memory.Service { - return memory.NewService(log, llm, embedder, store, resolver, bm25Indexer, setup.TextModel.ModelID, setup.MultimodalModel.ModelID) +func provideQdrantStore(log *slog.Logger, cfg config.Config, setup embeddingSetup) (*memory.QdrantStore, error) { + qcfg := cfg.Qdrant + timeout := time.Duration(qcfg.TimeoutSeconds) * time.Second + if setup.HasEmbeddingModels && len(setup.Vectors) > 0 { + store, err := memory.NewQdrantStoreWithVectors(log, qcfg.BaseURL, qcfg.APIKey, qcfg.Collection, setup.Vectors, "sparse_hash", timeout) + if err != nil { + return nil, fmt.Errorf("qdrant named vectors init: %w", err) + } + return store, nil + } + store, err := memory.NewQdrantStore(log, qcfg.BaseURL, qcfg.APIKey, qcfg.Collection, setup.TextModel.Dimensions, "sparse_hash", timeout) + if err != nil { + return nil, fmt.Errorf("qdrant init: %w", err) + } + return store, nil } -func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, historyService *history.Service, settingsService *settings.Service, mcpConnectionsService *mcp.ConnectionService, containerdHandler *handlers.ContainerdHandler) *chat.Resolver { - chatResolver := chat.NewResolver(log, modelsService, queries, memoryService, historyService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) - chatResolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) - return chatResolver +func provideMemoryService(log *slog.Logger, llm memory.LLM, embedder embeddings.Embedder, store *memory.QdrantStore, resolver *embeddings.Resolver, bm25 *memory.BM25Indexer, setup embeddingSetup) *memory.Service { + return memory.NewService(log, llm, embedder, store, resolver, bm25, setup.TextModel.ModelID, setup.MultimodalModel.ModelID) } -func provideChannelRegistry(log *slog.Logger, sessionHub *local.SessionHub) *channel.Registry { +// --------------------------------------------------------------------------- +// domain service providers (interface adapters) +// --------------------------------------------------------------------------- + +func provideRouteService(log *slog.Logger, queries *dbsqlc.Queries, chatService *conversation.Service) *route.DBService { + return route.NewService(log, queries, chatService) +} + +func provideMessageService(log *slog.Logger, queries *dbsqlc.Queries, hub *event.Hub) *message.DBService { + return message.NewService(log, queries, hub) +} + +func provideScheduleTriggerer(resolver *flow.Resolver) schedule.Triggerer { + return flow.NewScheduleGateway(resolver) +} + +// --------------------------------------------------------------------------- +// conversation flow +// --------------------------------------------------------------------------- + +func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, mcpConnService *mcp.ConnectionService, containerdHandler *handlers.ContainerdHandler) *flow.Resolver { + resolver := flow.NewResolver(log, modelsService, queries, memoryService, chatService, msgService, settingsService, mcpConnService, cfg.AgentGateway.BaseURL(), 120*time.Second) + resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) + return resolver +} + +// --------------------------------------------------------------------------- +// channel providers +// --------------------------------------------------------------------------- + +func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub) *channel.Registry { registry := channel.NewRegistry() registry.MustRegister(telegram.NewTelegramAdapter(log)) registry.MustRegister(feishu.NewFeishuAdapter(log)) - registry.MustRegister(local.NewCLIAdapter(sessionHub)) - registry.MustRegister(local.NewWebAdapter(sessionHub)) + registry.MustRegister(local.NewCLIAdapter(hub)) + registry.MustRegister(local.NewWebAdapter(hub)) return registry } -func provideChannelRouter(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, chatResolver *chat.Resolver, contactsService *contacts.Service, policyService *policy.Service, preauthService *preauth.Service, cfg config.Config) *router.ChannelInboundProcessor { - return router.NewChannelInboundProcessor(log, registry, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute) +func provideChannelRouter(log *slog.Logger, registry *channel.Registry, routeService *route.DBService, msgService *message.DBService, resolver *flow.Resolver, identityService *identities.Service, botService *bots.Service, policyService *policy.Service, preauthService *preauth.Service, bindService *bind.Service, rc *boot.RuntimeConfig) *router.ChannelInboundProcessor { + return router.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, botService, policyService, preauthService, bindService, rc.JwtSecret, 5*time.Minute) } func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, channelRouter *router.ChannelInboundProcessor) *channel.Manager { - channelManager := channel.NewManager(log, registry, channelService, channelRouter) + mgr := channel.NewManager(log, registry, channelService, channelRouter) if mw := channelRouter.IdentityMiddleware(); mw != nil { - channelManager.Use(mw) + mgr.Use(mw) } - return channelManager + return mgr } -func provideCLIHandler(channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, usersService *users.Service) *handlers.LocalChannelHandler { - return handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService) +// --------------------------------------------------------------------------- +// containerd handler & tool gateway +// --------------------------------------------------------------------------- + +func provideContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.Config, botService *bots.Service, accountService *accounts.Service, policyService *policy.Service, queries *dbsqlc.Queries) *handlers.ContainerdHandler { + return handlers.NewContainerdHandler(log, service, cfg.MCP, cfg.Containerd.Namespace, botService, accountService, policyService, queries) } -func provideWebHandler(channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, usersService *users.Service) *handlers.LocalChannelHandler { - return handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService) +func provideToolGatewayService(log *slog.Logger, cfg config.Config, channelManager *channel.Manager, registry *channel.Registry, channelService *channel.Service, scheduleService *schedule.Service, memoryService *memory.Service, chatService *conversation.Service, accountService *accounts.Service, manager *mcp.Manager, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService) *mcp.ToolGatewayService { + messageExec := mcpmessage.NewExecutor(log, channelManager, registry) + directoryExec := mcpdirectory.NewExecutor(log, registry, channelService, registry) + scheduleExec := mcpschedule.NewExecutor(log, scheduleService) + memoryExec := mcpmemory.NewExecutor(log, memoryService, chatService, accountService) + execWorkDir := cfg.MCP.DataMount + if strings.TrimSpace(execWorkDir) == "" { + execWorkDir = config.DefaultDataMount + } + fsExec := mcpcontainer.NewExecutor(log, manager, execWorkDir) + + fedGateway := handlers.NewMCPFederationGateway(log, containerdHandler) + fedSource := mcpfederation.NewSource(log, fedGateway, mcpConnService) + + svc := mcp.NewToolGatewayService( + log, + []mcp.ToolExecutor{messageExec, directoryExec, scheduleExec, memoryExec, fsExec}, + []mcp.ToolSource{fedSource}, + ) + containerdHandler.SetToolGatewayService(svc) + return svc } +// --------------------------------------------------------------------------- +// handler providers (interface adaptation / config extraction) +// --------------------------------------------------------------------------- + +func provideAuthHandler(log *slog.Logger, accountService *accounts.Service, rc *boot.RuntimeConfig) *handlers.AuthHandler { + return handlers.NewAuthHandler(log, accountService, rc.JwtSecret, rc.JwtExpiresIn) +} + +func provideMessageHandler(log *slog.Logger, resolver *flow.Resolver, chatService *conversation.Service, msgService *message.DBService, botService *bots.Service, accountService *accounts.Service, identityService *identities.Service, hub *event.Hub) *handlers.MessageHandler { + return handlers.NewMessageHandler(log, resolver, chatService, msgService, botService, accountService, identityService, hub) +} + +func provideUsersHandler(log *slog.Logger, accountService *accounts.Service, identityService *identities.Service, botService *bots.Service, routeService *route.DBService, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *handlers.UsersHandler { + return handlers.NewUsersHandler(log, accountService, identityService, botService, routeService, channelService, channelManager, registry) +} + +func provideCLIHandler(channelManager *channel.Manager, channelService *channel.Service, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler { + return handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, chatService, hub, botService, accountService) +} + +func provideWebHandler(channelManager *channel.Manager, channelService *channel.Service, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler { + return handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, chatService, hub, botService, accountService) +} + +// --------------------------------------------------------------------------- +// server +// --------------------------------------------------------------------------- + type serverParams struct { fx.In - Logger *slog.Logger - RuntimeConfig *boot.RuntimeConfig - Config config.Config - ServerHandlers []server.Handler `group:"server_handlers"` + Logger *slog.Logger + RuntimeConfig *boot.RuntimeConfig + Config config.Config + ServerHandlers []server.Handler `group:"server_handlers"` + ContainerdHandler *handlers.ContainerdHandler } func provideServer(params serverParams) *server.Server { - return server.NewServer(params.Logger, params.RuntimeConfig.ServerAddr, params.Config.Auth.JWTSecret, params.ServerHandlers...) + allHandlers := make([]server.Handler, 0, len(params.ServerHandlers)+1) + allHandlers = append(allHandlers, params.ServerHandlers...) + allHandlers = append(allHandlers, params.ContainerdHandler) + return server.NewServer(params.Logger, params.RuntimeConfig.ServerAddr, params.Config.Auth.JWTSecret, allHandlers...) } +// --------------------------------------------------------------------------- +// lifecycle hooks +// --------------------------------------------------------------------------- + func startMemoryWarmup(lc fx.Lifecycle, memoryService *memory.Service, logger *slog.Logger) { lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { @@ -291,19 +427,9 @@ func startMemoryWarmup(lc fx.Lifecycle, memoryService *memory.Service, logger *s return nil }, }) - } -func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager, logger *slog.Logger) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - channelManager.Start(ctx) - return nil - }, - }) -} - -func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service, logger *slog.Logger) { +func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service) { lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { return scheduleService.Bootstrap(ctx) @@ -311,40 +437,48 @@ func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service, lo }) } -func startServer( - lc fx.Lifecycle, - logger *slog.Logger, - srv *server.Server, - shutdowner fx.Shutdowner, - cfg config.Config, - queries *dbsqlc.Queries, - scheduleService *schedule.Service, - channelManager *channel.Manager, - botService *bots.Service, - containerdHandler *handlers.ContainerdHandler, -) { +func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager) { + ctx, cancel := context.WithCancel(context.Background()) + lc.Append(fx.Hook{ + OnStart: func(_ context.Context) error { + channelManager.Start(ctx) + return nil + }, + OnStop: func(stopCtx context.Context) error { + cancel() + return channelManager.Shutdown(stopCtx) + }, + }) +} + +func startContainerReconciliation(lc fx.Lifecycle, containerdHandler *handlers.ContainerdHandler, _ *mcp.ToolGatewayService) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + go containerdHandler.ReconcileContainers(ctx) + return nil + }, + }) +} + +func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler) { fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo()) lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { - if err := ensureAdminUser(ctx, logger, queries, cfg); err != nil { return err } - botService.SetContainerLifecycle(containerdHandler) go func() { - if err := srv.Start(); err != nil { // block until server is stopped + if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { logger.Error("server failed", slog.Any("error", err)) - _ = shutdowner.Shutdown() // shutdown the application if the server fails to start + _ = shutdowner.Shutdown() } }() - return nil }, OnStop: func(ctx context.Context) error { - // graceful shutdown if err := srv.Stop(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("server stop: %w", err) } @@ -353,6 +487,10 @@ func startServer( }) } +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder { if !hasModels { return nil @@ -368,44 +506,11 @@ func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetRespon } } -func provideQdrantStore(log *slog.Logger, cfgAll config.Config, setup embeddingSetup) (*memory.QdrantStore, error) { - cfg := cfgAll.Qdrant - timeout := time.Duration(cfg.TimeoutSeconds) * time.Second - if setup.HasEmbeddingModels && len(setup.Vectors) > 0 { - store, err := memory.NewQdrantStoreWithVectors( - log, - cfg.BaseURL, - cfg.APIKey, - cfg.Collection, - setup.Vectors, - "sparse_hash", - timeout, - ) - if err != nil { - return nil, fmt.Errorf("qdrant named vectors init: %w", err) - } - return store, nil - } - store, err := memory.NewQdrantStore( - log, - cfg.BaseURL, - cfg.APIKey, - cfg.Collection, - setup.TextModel.Dimensions, - "sparse_hash", - timeout, - ) - if err != nil { - return nil, fmt.Errorf("qdrant init: %w", err) - } - return store, nil -} - func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error { if queries == nil { return fmt.Errorf("db queries not configured") } - count, err := queries.CountUsers(ctx) + count, err := queries.CountAccounts(ctx) if err != nil { return err } @@ -428,6 +533,14 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer return err } + user, err := queries.CreateUser(ctx, dbsqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return fmt.Errorf("create admin user: %w", err) + } + emailValue := pgtype.Text{Valid: false} if email != "" { emailValue = pgtype.Text{String: email, Valid: true} @@ -435,10 +548,11 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer displayName := pgtype.Text{String: username, Valid: true} dataRoot := pgtype.Text{String: cfg.MCP.DataRoot, Valid: cfg.MCP.DataRoot != ""} - _, err = queries.CreateUser(ctx, dbsqlc.CreateUserParams{ - Username: username, + _, err = queries.CreateAccount(ctx, dbsqlc.CreateAccountParams{ + UserID: user.ID, + Username: pgtype.Text{String: username, Valid: true}, Email: emailValue, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, Role: "admin", DisplayName: displayName, AvatarUrl: pgtype.Text{Valid: false}, @@ -452,14 +566,9 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer return nil } -func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memory.LLM { - return &lazyLLMClient{ - modelsService: modelsService, - queries: queries, - timeout: 30 * time.Second, - logger: log, - } -} +// --------------------------------------------------------------------------- +// lazy LLM client +// --------------------------------------------------------------------------- type lazyLLMClient struct { modelsService *models.Service @@ -507,19 +616,19 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) { return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout) } -// skillLoaderAdapter bridges handlers.ContainerdHandler to chat.SkillLoader. +// skillLoaderAdapter bridges handlers.ContainerdHandler to flow.SkillLoader. type skillLoaderAdapter struct { handler *handlers.ContainerdHandler } -func (a *skillLoaderAdapter) LoadSkills(ctx context.Context, botID string) ([]chat.SkillEntry, error) { +func (a *skillLoaderAdapter) LoadSkills(ctx context.Context, botID string) ([]flow.SkillEntry, error) { items, err := a.handler.LoadSkills(ctx, botID) if err != nil { return nil, err } - entries := make([]chat.SkillEntry, len(items)) + entries := make([]flow.SkillEntry, len(items)) for i, item := range items { - entries[i] = chat.SkillEntry{ + entries[i] = flow.SkillEntry{ Name: item.Name, Description: item.Description, Content: item.Content, diff --git a/cmd/feishu-echo/main.go b/cmd/feishu-echo/main.go new file mode 100644 index 00000000..d092ca6a --- /dev/null +++ b/cmd/feishu-echo/main.go @@ -0,0 +1,120 @@ +// feishu-echo is a minimal Feishu bot that connects via WebSocket and counts received events. +// Used to verify whether message loss is due to our app logic or network/Feishu delivery. +// +// Usage: +// +// FEISHU_APP_ID=xxx FEISHU_APP_SECRET=xxx FEISHU_ENCRYPT=xxx FEISHU_VERIFY=xxx go run ./cmd/feishu-echo +package main + +import ( + "context" + "log" + "os" + "os/signal" + "strings" + "sync/atomic" + "time" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + larkws "github.com/larksuite/oapi-sdk-go/v3/ws" + + "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" +) + +type eventCounts struct { + messageReceive atomic.Int64 + messageRead atomic.Int64 + reactionCreated atomic.Int64 + reactionDeleted atomic.Int64 +} + +func (c *eventCounts) log() { + log.Printf("[feishu-echo] counts: receive=%d read=%d reaction_created=%d reaction_deleted=%d", + c.messageReceive.Load(), c.messageRead.Load(), c.reactionCreated.Load(), c.reactionDeleted.Load()) +} + +func main() { + appID := strings.TrimSpace(os.Getenv("FEISHU_APP_ID")) + appSecret := strings.TrimSpace(os.Getenv("FEISHU_APP_SECRET")) + encryptKey := strings.TrimSpace(os.Getenv("FEISHU_ENCRYPT")) + verifyToken := strings.TrimSpace(os.Getenv("FEISHU_VERIFY")) + + if appID == "" || appSecret == "" { + log.Fatal("FEISHU_APP_ID and FEISHU_APP_SECRET are required") + } + + log.Printf("[feishu-echo] starting with app_id=%s (encrypt=%v, verify=%v)", appID, encryptKey != "", verifyToken != "") + + counts := new(eventCounts) + eventDispatcher := dispatcher.NewEventDispatcher(verifyToken, encryptKey) + + eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, _ *larkim.P2MessageReceiveV1) error { + counts.messageReceive.Add(1) + counts.log() + return nil + }) + + eventDispatcher.OnP2MessageReadV1(func(_ context.Context, _ *larkim.P2MessageReadV1) error { + counts.messageRead.Add(1) + counts.log() + return nil + }) + + eventDispatcher.OnP2MessageReactionCreatedV1(func(_ context.Context, _ *larkim.P2MessageReactionCreatedV1) error { + counts.reactionCreated.Add(1) + counts.log() + return nil + }) + + eventDispatcher.OnP2MessageReactionDeletedV1(func(_ context.Context, _ *larkim.P2MessageReactionDeletedV1) error { + counts.reactionDeleted.Add(1) + counts.log() + return nil + }) + + client := larkws.NewClient( + appID, + appSecret, + larkws.WithEventHandler(eventDispatcher), + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + <-sig + log.Println("[feishu-echo] interrupt, shutting down") + cancel() + counts.log() + os.Exit(0) + }() + + const reconnectDelay = 3 * time.Second +run: + for { + if ctx.Err() != nil { + break run + } + log.Println("[feishu-echo] connecting to Feishu WebSocket...") + err := client.Start(ctx) + if ctx.Err() != nil { + break run + } + if err != nil { + log.Printf("[feishu-echo] client error: %v; reconnecting in %v", err, reconnectDelay) + } else { + log.Printf("[feishu-echo] connection closed; reconnecting in %v", reconnectDelay) + } + timer := time.NewTimer(reconnectDelay) + select { + case <-ctx.Done(): + timer.Stop() + break run + case <-timer.C: + } + } + counts.log() + log.Println("[feishu-echo] stopped") +} diff --git a/cmd/feishu-echo/main_test.go b/cmd/feishu-echo/main_test.go new file mode 100644 index 00000000..9ed03377 --- /dev/null +++ b/cmd/feishu-echo/main_test.go @@ -0,0 +1,20 @@ +package main + +import ( + "testing" +) + +func TestEventCounts(t *testing.T) { + c := new(eventCounts) + c.log() + if c.messageReceive.Load() != 0 || c.messageRead.Load() != 0 { + t.Fatalf("initial counts should be 0") + } + c.messageReceive.Add(2) + c.messageRead.Add(1) + c.reactionCreated.Add(1) + if c.messageReceive.Load() != 2 || c.messageRead.Load() != 1 || c.reactionCreated.Load() != 1 { + t.Fatalf("counts after add: receive=2 read=1 reaction_created=1") + } + c.log() +} diff --git a/cmd/mcp/Dockerfile b/cmd/mcp/Dockerfile new file mode 100644 index 00000000..1daf0c51 --- /dev/null +++ b/cmd/mcp/Dockerfile @@ -0,0 +1,18 @@ +FROM golang:1.25-alpine AS build + +WORKDIR /src +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +ARG TARGETARCH +ARG COMMIT_HASH=unknown +RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH:-amd64} \ + go build -trimpath -ldflags "-s -w -X github.com/memohai/memoh/internal/version.CommitHash=${COMMIT_HASH}" -o /out/mcp ./cmd/mcp + +FROM alpine:latest +RUN apk add --no-cache grep +WORKDIR /app +COPY --from=build /out/mcp /opt/mcp +COPY cmd/mcp/template /opt/mcp-template +ENTRYPOINT ["/bin/sh","-lc","bootstrap(){ [ -e /app/mcp ] || { mkdir -p /app; [ -f /opt/mcp ] && cp -a /opt/mcp /app/mcp 2>/dev/null || true; }; }; bootstrap; if [ -x /app/mcp ]; then exec /app/mcp \"$@\"; fi; exec /opt/mcp \"$@\"","--"] diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index ac5ed75f..9110eceb 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -10,7 +10,6 @@ import ( "syscall" "github.com/memohai/memoh/internal/logger" - "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/version" gomcp "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -19,11 +18,11 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() + // File tools (read/write/list/edit) are provided by the agent's MCP tool gateway, not this binary. server := gomcp.NewServer( &gomcp.Implementation{Name: "memoh-mcp", Version: version.GetInfo()}, nil, ) - mcp.RegisterTools(server) err := server.Run(ctx, &gomcp.StdioTransport{}) if ctx.Err() != nil { return diff --git a/db/migrations/0001_init.down.sql b/db/migrations/0001_init.down.sql index 28f5128d..8369a94b 100644 --- a/db/migrations/0001_init.down.sql +++ b/db/migrations/0001_init.down.sql @@ -1,26 +1,21 @@ -DROP TABLE IF EXISTS user_settings; DROP TABLE IF EXISTS subagents; DROP TABLE IF EXISTS schedule; DROP TABLE IF EXISTS lifecycle_events; DROP TABLE IF EXISTS container_versions; DROP TABLE IF EXISTS snapshots; DROP TABLE IF EXISTS containers; -DROP TABLE IF EXISTS channel_sessions; -DROP TABLE IF EXISTS contact_channels; +DROP TABLE IF EXISTS bot_history_messages; +DROP TABLE IF EXISTS bot_channel_routes; +DROP TABLE IF EXISTS channel_identity_bind_codes; DROP TABLE IF EXISTS bot_preauth_keys; DROP TABLE IF EXISTS bot_channel_configs; -DROP TABLE IF EXISTS user_channel_bindings; -DROP TABLE IF EXISTS history; -DROP TABLE IF EXISTS conversations; DROP TABLE IF EXISTS mcp_connections; -DROP TABLE IF EXISTS bot_model_configs; -DROP TABLE IF EXISTS bot_settings; DROP TABLE IF EXISTS bot_members; -DROP TABLE IF EXISTS contact_bind_tokens; DROP TABLE IF EXISTS bots; +DROP TABLE IF EXISTS model_variants; +DROP TABLE IF EXISTS models; +DROP TABLE IF EXISTS llm_providers; +DROP TABLE IF EXISTS user_channel_bindings; +DROP TABLE IF EXISTS channel_identities; DROP TABLE IF EXISTS users; -DROP TABLE IF EXISTS contacts; --- DROP TABLE IF EXISTS model_variants; --- DROP TABLE IF EXISTS models; --- DROP TABLE IF EXISTS llm_providers; DROP TYPE IF EXISTS user_role; diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index b9facd8b..812b7865 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -8,23 +8,58 @@ BEGIN END $$; +-- users: Memoh user principal CREATE TABLE IF NOT EXISTS users ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - username TEXT NOT NULL, + username TEXT, email TEXT, - password_hash TEXT NOT NULL, + password_hash TEXT, role user_role NOT NULL DEFAULT 'member', display_name TEXT, avatar_url TEXT, - is_active BOOLEAN NOT NULL DEFAULT true, data_root TEXT, + last_login_at TIMESTAMPTZ, + chat_model_id TEXT, + memory_model_id TEXT, + embedding_model_id TEXT, + max_context_load_time INTEGER NOT NULL DEFAULT 1440, + language TEXT NOT NULL DEFAULT 'auto', + is_active BOOLEAN NOT NULL DEFAULT true, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - last_login_at TIMESTAMPTZ, CONSTRAINT users_email_unique UNIQUE (email), CONSTRAINT users_username_unique UNIQUE (username) ); +-- channel_identities: unified inbound identity subject +CREATE TABLE IF NOT EXISTS channel_identities ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID REFERENCES users(id) ON DELETE SET NULL, + channel_type TEXT NOT NULL, + channel_subject_id TEXT NOT NULL, + display_name TEXT, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT channel_identities_channel_type_subject_unique UNIQUE (channel_type, channel_subject_id) +); + +CREATE INDEX IF NOT EXISTS idx_channel_identities_user_id ON channel_identities(user_id); + +-- user_channel_bindings: outbound delivery config +CREATE TABLE IF NOT EXISTS user_channel_bindings ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + channel_type TEXT NOT NULL, + config JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT user_channel_bindings_unique UNIQUE (user_id, channel_type) +); + +CREATE INDEX IF NOT EXISTS idx_user_channel_bindings_user_id ON user_channel_bindings(user_id); + CREATE TABLE IF NOT EXISTS llm_providers ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), name TEXT NOT NULL, @@ -73,10 +108,18 @@ CREATE TABLE IF NOT EXISTS bots ( display_name TEXT, avatar_url TEXT, is_active BOOLEAN NOT NULL DEFAULT true, + status TEXT NOT NULL DEFAULT 'ready', + max_context_load_time INTEGER NOT NULL DEFAULT 1440, + language TEXT NOT NULL DEFAULT 'auto', + allow_guest BOOLEAN NOT NULL DEFAULT false, + chat_model_id UUID REFERENCES models(id) ON DELETE SET NULL, + memory_model_id UUID REFERENCES models(id) ON DELETE SET NULL, + embedding_model_id UUID REFERENCES models(id) ON DELETE SET NULL, metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT bots_type_check CHECK (type IN ('personal', 'public')) + CONSTRAINT bots_type_check CHECK (type IN ('personal', 'public')), + CONSTRAINT bots_status_check CHECK (status IN ('creating', 'ready', 'deleting')) ); CREATE INDEX IF NOT EXISTS idx_bots_owner_user_id ON bots(owner_user_id); @@ -92,20 +135,6 @@ CREATE TABLE IF NOT EXISTS bot_members ( CREATE INDEX IF NOT EXISTS idx_bot_members_user_id ON bot_members(user_id); -CREATE TABLE IF NOT EXISTS bot_settings ( - bot_id UUID PRIMARY KEY REFERENCES bots(id) ON DELETE CASCADE, - max_context_load_time INTEGER NOT NULL DEFAULT 1440, - language TEXT NOT NULL DEFAULT 'auto', - allow_guest BOOLEAN NOT NULL DEFAULT false -); - -CREATE TABLE IF NOT EXISTS bot_model_configs ( - bot_id UUID PRIMARY KEY REFERENCES bots(id) ON DELETE CASCADE, - chat_model_id UUID REFERENCES models(id) ON DELETE SET NULL, - embedding_model_id UUID REFERENCES models(id) ON DELETE SET NULL, - memory_model_id UUID REFERENCES models(id) ON DELETE SET NULL -); - CREATE TABLE IF NOT EXISTS mcp_connections ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, @@ -121,45 +150,7 @@ CREATE TABLE IF NOT EXISTS mcp_connections ( CREATE INDEX IF NOT EXISTS idx_mcp_connections_bot_id ON mcp_connections(bot_id); -CREATE TABLE IF NOT EXISTS conversations ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - session_id TEXT NOT NULL, - channel_type TEXT NOT NULL, - chat_id TEXT, - sender_id TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT conversations_session_unique UNIQUE (bot_id, session_id) -); - -CREATE INDEX IF NOT EXISTS idx_conversations_bot_id ON conversations(bot_id); - -CREATE TABLE IF NOT EXISTS history ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - session_id TEXT NOT NULL, - messages JSONB NOT NULL, - metadata JSONB NOT NULL DEFAULT '{}'::jsonb, - skills TEXT[] NOT NULL DEFAULT '{}'::text[], - timestamp TIMESTAMPTZ NOT NULL -); - -CREATE INDEX IF NOT EXISTS idx_history_bot ON history(bot_id); -CREATE INDEX IF NOT EXISTS idx_history_session ON history(session_id); -CREATE INDEX IF NOT EXISTS idx_history_timestamp ON history(timestamp); - -CREATE TABLE IF NOT EXISTS user_channel_bindings ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, - channel_type TEXT NOT NULL, - config JSONB NOT NULL DEFAULT '{}'::jsonb, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT user_channel_bindings_unique UNIQUE (user_id, channel_type) -); - -CREATE INDEX IF NOT EXISTS idx_user_channel_bindings_user_id ON user_channel_bindings(user_id); +-- Bot history is bot-scoped (one history container per bot). CREATE TABLE IF NOT EXISTS bot_channel_configs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -183,26 +174,6 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_channel_external_identity CREATE INDEX IF NOT EXISTS idx_bot_channel_bot_id ON bot_channel_configs(bot_id); -CREATE TABLE IF NOT EXISTS contacts ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - user_id UUID REFERENCES users(id) ON DELETE SET NULL, - display_name TEXT, - alias TEXT, - tags TEXT[] NOT NULL DEFAULT '{}'::text[], - status TEXT NOT NULL DEFAULT 'active', - metadata JSONB NOT NULL DEFAULT '{}'::jsonb, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT contacts_status_check CHECK (status IN ('active', 'blocked', 'pending')) -); - -CREATE UNIQUE INDEX IF NOT EXISTS idx_contacts_bot_user_unique - ON contacts(bot_id, user_id) - WHERE user_id IS NOT NULL; - -CREATE INDEX IF NOT EXISTS idx_contacts_bot_id ON contacts(bot_id); - CREATE TABLE IF NOT EXISTS bot_preauth_keys ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, @@ -217,37 +188,61 @@ CREATE TABLE IF NOT EXISTS bot_preauth_keys ( CREATE INDEX IF NOT EXISTS idx_bot_preauth_keys_bot_id ON bot_preauth_keys(bot_id); CREATE INDEX IF NOT EXISTS idx_bot_preauth_keys_expires ON bot_preauth_keys(expires_at); -CREATE TABLE IF NOT EXISTS contact_channels ( +-- channel_identity_bind_codes: one-time codes for channel identity->user linking +CREATE TABLE IF NOT EXISTS channel_identity_bind_codes ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - contact_id UUID NOT NULL REFERENCES contacts(id) ON DELETE CASCADE, - platform TEXT NOT NULL, - external_id TEXT NOT NULL, - metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + token TEXT NOT NULL, + issued_by_user_id UUID NOT NULL REFERENCES users(id), + channel_type TEXT, + expires_at TIMESTAMPTZ, + used_at TIMESTAMPTZ, + used_by_channel_identity_id UUID REFERENCES channel_identities(id), created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT contact_channels_unique UNIQUE (bot_id, platform, external_id) + CONSTRAINT channel_identity_bind_codes_token_unique UNIQUE (token) ); -CREATE INDEX IF NOT EXISTS idx_contact_channels_contact_id ON contact_channels(contact_id); -CREATE INDEX IF NOT EXISTS idx_contact_channels_platform_external ON contact_channels(platform, external_id); +CREATE INDEX IF NOT EXISTS idx_channel_identity_bind_codes_channel_type ON channel_identity_bind_codes(channel_type); -CREATE TABLE IF NOT EXISTS channel_sessions ( - session_id TEXT PRIMARY KEY, +-- bot_channel_routes: route mapping for inbound channel threads to bot history. +CREATE TABLE IF NOT EXISTS bot_channel_routes ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, + channel_type TEXT NOT NULL, channel_config_id UUID REFERENCES bot_channel_configs(id) ON DELETE SET NULL, - user_id UUID REFERENCES users(id) ON DELETE CASCADE, - contact_id UUID REFERENCES contacts(id) ON DELETE SET NULL, - platform TEXT NOT NULL, - reply_target TEXT, - thread_id TEXT, + external_conversation_id TEXT NOT NULL, + external_thread_id TEXT, + default_reply_target TEXT, metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now() ); -CREATE INDEX IF NOT EXISTS idx_channel_sessions_bot_id ON channel_sessions(bot_id); -CREATE INDEX IF NOT EXISTS idx_channel_sessions_user_id ON channel_sessions(user_id); +CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_channel_routes_unique + ON bot_channel_routes (bot_id, channel_type, external_conversation_id, COALESCE(external_thread_id, '')); +CREATE INDEX IF NOT EXISTS idx_bot_channel_routes_bot ON bot_channel_routes(bot_id); + +-- bot_history_messages: unified message history under bot scope. +CREATE TABLE IF NOT EXISTS bot_history_messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, + route_id UUID REFERENCES bot_channel_routes(id) ON DELETE SET NULL, + sender_channel_identity_id UUID REFERENCES channel_identities(id), + sender_account_user_id UUID REFERENCES users(id), + channel_type TEXT, + source_message_id TEXT, + source_reply_to_message_id TEXT, + role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system', 'tool')), + content JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_bot_created ON bot_history_messages(bot_id, created_at); +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_route ON bot_history_messages(route_id); +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_source_lookup + ON bot_history_messages(channel_type, source_message_id); +CREATE INDEX IF NOT EXISTS idx_bot_history_messages_reply_lookup + ON bot_history_messages(channel_type, source_reply_to_message_id); CREATE TABLE IF NOT EXISTS containers ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -339,11 +334,3 @@ CREATE TABLE IF NOT EXISTS subagents ( CREATE INDEX IF NOT EXISTS idx_subagents_bot_id ON subagents(bot_id); CREATE INDEX IF NOT EXISTS idx_subagents_deleted ON subagents(deleted); -CREATE TABLE IF NOT EXISTS user_settings ( - user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, - chat_model_id TEXT, - memory_model_id TEXT, - embedding_model_id TEXT, - max_context_load_time INTEGER NOT NULL DEFAULT 1440, - language TEXT NOT NULL DEFAULT 'auto' -); diff --git a/db/queries/bind.sql b/db/queries/bind.sql new file mode 100644 index 00000000..a2f30218 --- /dev/null +++ b/db/queries/bind.sql @@ -0,0 +1,22 @@ +-- name: CreateBindCode :one +INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, channel_type, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at; + +-- name: GetBindCode :one +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1; + +-- name: GetBindCodeForUpdate :one +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1 +FOR UPDATE; + +-- name: MarkBindCodeUsed :one +UPDATE channel_identity_bind_codes +SET used_at = now(), used_by_channel_identity_id = $2 +WHERE id = $1 + AND used_at IS NULL +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at; diff --git a/db/queries/bots.sql b/db/queries/bots.sql index 5dd0e46a..6194f319 100644 --- a/db/queries/bots.sql +++ b/db/queries/bots.sql @@ -1,21 +1,21 @@ -- name: CreateBot :one -INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at; +INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata, status) +VALUES ($1, $2, $3, $4, $5, $6, $7) +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; -- name: GetBotByID :one -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE id = $1; -- name: ListBotsByOwner :many -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE owner_user_id = $1 ORDER BY created_at DESC; -- name: ListBotsByMember :many -SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.metadata, b.created_at, b.updated_at +SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.status, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at FROM bots b JOIN bot_members m ON m.bot_id = b.id WHERE m.user_id = $1 @@ -29,14 +29,20 @@ SET display_name = $2, metadata = $5, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at; +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; -- name: UpdateBotOwner :one UPDATE bots SET owner_user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at; +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; + +-- name: UpdateBotStatus :exec +UPDATE bots +SET status = $2, + updated_at = now() +WHERE id = $1; -- name: DeleteBotByID :exec DELETE FROM bots WHERE id = $1; diff --git a/db/queries/channel_identities.sql b/db/queries/channel_identities.sql new file mode 100644 index 00000000..8a761998 --- /dev/null +++ b/db/queries/channel_identities.sql @@ -0,0 +1,49 @@ +-- name: CreateChannelIdentity :one +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; + +-- name: GetChannelIdentityByID :one +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1; + +-- name: GetChannelIdentityByIDForUpdate :one +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1 +FOR UPDATE; + +-- name: GetChannelIdentityByChannelSubject :one +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE channel_type = $1 AND channel_subject_id = $2; + +-- name: UpsertChannelIdentityByChannelSubject :one +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +ON CONFLICT (channel_type, channel_subject_id) +DO UPDATE SET + display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), channel_identities.display_name), + metadata = EXCLUDED.metadata, + user_id = COALESCE(channel_identities.user_id, EXCLUDED.user_id), + updated_at = now() +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; + +-- name: ListChannelIdentitiesByUserID :many +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE user_id = $1 +ORDER BY created_at DESC; + +-- name: SetChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = $2, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; + +-- name: ClearChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = NULL, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at; diff --git a/db/queries/channel_routes.sql b/db/queries/channel_routes.sql new file mode 100644 index 00000000..be84a19f --- /dev/null +++ b/db/queries/channel_routes.sql @@ -0,0 +1,87 @@ +-- name: CreateChatRoute :one +INSERT INTO bot_channel_routes ( + bot_id, channel_type, channel_config_id, external_conversation_id, external_thread_id, default_reply_target, metadata +) +VALUES ( + sqlc.arg(bot_id), + sqlc.arg(platform), + sqlc.narg(channel_config_id)::uuid, + sqlc.arg(conversation_id), + sqlc.narg(thread_id)::text, + sqlc.narg(reply_target)::text, + sqlc.arg(metadata) +) +RETURNING + id, + sqlc.arg(chat_id)::uuid AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at; + +-- name: FindChatRoute :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = $1 + AND channel_type = sqlc.arg(platform) + AND external_conversation_id = sqlc.arg(conversation_id) + AND COALESCE(external_thread_id, '') = COALESCE(sqlc.narg(thread_id), '') +LIMIT 1; + +-- name: GetChatRouteByID :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE id = $1; + +-- name: ListChatRoutes :many +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = sqlc.arg(chat_id) +ORDER BY created_at ASC; + +-- name: UpdateChatRouteReplyTarget :exec +UPDATE bot_channel_routes +SET default_reply_target = sqlc.arg(reply_target), updated_at = now() +WHERE id = sqlc.arg(id); + +-- name: DeleteChatRoute :exec +DELETE FROM bot_channel_routes +WHERE id = $1; diff --git a/db/queries/channels.sql b/db/queries/channels.sql index 9323a26f..da22f1c2 100644 --- a/db/queries/channels.sql +++ b/db/queries/channels.sql @@ -48,40 +48,9 @@ DO UPDATE SET updated_at = now() RETURNING id, user_id, channel_type, config, created_at, updated_at; --- name: ListUserChannelBindingsByType :many +-- name: ListUserChannelBindingsByPlatform :many SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings WHERE channel_type = $1 ORDER BY created_at DESC; --- name: GetChannelSessionByID :one -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE session_id = $1 -LIMIT 1; - --- name: ListChannelSessionsByBotPlatform :many -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE bot_id = $1 AND platform = $2 -ORDER BY updated_at DESC; - --- name: UpsertChannelSession :one -INSERT INTO channel_sessions (session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) -ON CONFLICT (session_id) -DO UPDATE SET - bot_id = EXCLUDED.bot_id, - channel_config_id = EXCLUDED.channel_config_id, - user_id = EXCLUDED.user_id, - contact_id = EXCLUDED.contact_id, - platform = EXCLUDED.platform, - reply_target = EXCLUDED.reply_target, - thread_id = EXCLUDED.thread_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at; - --- name: DeleteChannelSession :exec -DELETE FROM channel_sessions -WHERE session_id = $1; diff --git a/db/queries/contacts.sql b/db/queries/contacts.sql deleted file mode 100644 index 7f5d9fe8..00000000 --- a/db/queries/contacts.sql +++ /dev/null @@ -1,76 +0,0 @@ --- name: CreateContact :one -INSERT INTO contacts (bot_id, user_id, display_name, alias, tags, status, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; - --- name: GetContactByID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE id = $1 -LIMIT 1; - --- name: GetContactByUserID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 AND user_id = $2 -LIMIT 1; - --- name: ListContactsByBot :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 -ORDER BY created_at DESC; - --- name: SearchContacts :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 - AND ( - display_name ILIKE sqlc.arg(query) - OR alias ILIKE sqlc.arg(query) - OR EXISTS ( - SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE sqlc.arg(query) - ) - ) -ORDER BY created_at DESC; - --- name: UpdateContact :one -UPDATE contacts -SET display_name = COALESCE(sqlc.narg(display_name), display_name), - alias = COALESCE(sqlc.narg(alias), alias), - tags = COALESCE(sqlc.narg(tags), tags), - status = COALESCE(NULLIF(sqlc.arg(status)::text, ''), status), - metadata = COALESCE(sqlc.narg(metadata), metadata), - updated_at = now() -WHERE id = sqlc.arg(id) -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; - --- name: UpdateContactUser :one -UPDATE contacts -SET user_id = $2, - updated_at = now() -WHERE id = $1 -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; - --- name: UpsertContactChannel :one -INSERT INTO contact_channels (bot_id, contact_id, platform, external_id, metadata) -VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (bot_id, platform, external_id) -DO UPDATE SET - contact_id = EXCLUDED.contact_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at; - --- name: GetContactChannelByIdentity :one -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE bot_id = $1 AND platform = $2 AND external_id = $3 -LIMIT 1; - --- name: ListContactChannelsByContact :many -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE contact_id = $1 -ORDER BY created_at DESC; - diff --git a/db/queries/containers.sql b/db/queries/containers.sql index 3e81c964..e3fba5ce 100644 --- a/db/queries/containers.sql +++ b/db/queries/containers.sql @@ -52,3 +52,6 @@ WHERE bot_id = sqlc.arg(bot_id); UPDATE containers SET status = 'stopped', last_stopped_at = now(), updated_at = now() WHERE bot_id = sqlc.arg(bot_id); + +-- name: ListAutoStartContainers :many +SELECT * FROM containers WHERE auto_start = true ORDER BY updated_at DESC; diff --git a/db/queries/conversations.sql b/db/queries/conversations.sql new file mode 100644 index 00000000..e3e25528 --- /dev/null +++ b/db/queries/conversations.sql @@ -0,0 +1,229 @@ +-- name: CreateChat :one +SELECT + b.id AS id, + b.id AS bot_id, + (COALESCE(NULLIF(sqlc.arg(kind)::text, ''), CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END))::text AS kind, + CASE WHEN sqlc.arg(kind) = 'thread' THEN sqlc.arg(parent_chat_id)::uuid ELSE NULL::uuid END AS parent_chat_id, + COALESCE(NULLIF(sqlc.arg(title)::text, ''), b.display_name) AS title, + COALESCE(sqlc.arg(created_by_user_id)::uuid, b.owner_user_id) AS created_by_user_id, + COALESCE(sqlc.arg(metadata)::jsonb, b.metadata) AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = sqlc.arg(bot_id) +LIMIT 1; + +-- name: GetChatByID :one +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1; + +-- name: ListChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id) +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = sqlc.arg(bot_id) + AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC; + +-- name: ListVisibleChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + b.created_at, + b.updated_at, + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = sqlc.arg(user_id) THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id) +WHERE b.id = sqlc.arg(bot_id) + AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC; + +-- name: GetChatReadAccessByUser :one +SELECT + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = sqlc.arg(user_id) THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id) +WHERE b.id = sqlc.arg(chat_id) + AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL) +LIMIT 1; + +-- name: ListThreadsByParent :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 + AND false +ORDER BY b.created_at DESC; + +-- name: UpdateChatTitle :one +UPDATE bots +SET display_name = sqlc.arg(title), + updated_at = now() +WHERE id = sqlc.arg(id) +RETURNING + id, + id AS bot_id, + CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + display_name AS title, + owner_user_id AS created_by_user_id, + metadata, + NULL::text AS model_id, + created_at, + updated_at; + +-- name: TouchChat :exec +UPDATE bots +SET updated_at = now() +WHERE id = sqlc.arg(chat_id); + +-- name: DeleteChat :exec +WITH deleted_messages AS ( + DELETE FROM bot_history_messages + WHERE bot_id = sqlc.arg(chat_id) +) +DELETE FROM bot_channel_routes bcr +WHERE bcr.bot_id = sqlc.arg(chat_id); + +-- chat_participants + +-- name: AddChatParticipant :one +INSERT INTO bot_members (bot_id, user_id, role) +VALUES (sqlc.arg(chat_id), sqlc.arg(user_id), sqlc.arg(role)) +ON CONFLICT (bot_id, user_id) DO UPDATE SET role = EXCLUDED.role +RETURNING bot_id AS chat_id, user_id, role, created_at AS joined_at; + +-- name: GetChatParticipant :one +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = sqlc.arg(chat_id) AND b.owner_user_id = sqlc.arg(user_id) +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = sqlc.arg(chat_id) AND bm.user_id = sqlc.arg(user_id) +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT * FROM owner_participant + UNION ALL + SELECT * FROM member_participant +) p +ORDER BY CASE WHEN role = 'owner' THEN 0 ELSE 1 END +LIMIT 1; + +-- name: ListChatParticipants :many +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = sqlc.arg(chat_id) +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = sqlc.arg(chat_id) + AND bm.user_id <> (SELECT owner_user_id FROM bots WHERE id = sqlc.arg(chat_id)) +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT * FROM owner_participant + UNION ALL + SELECT * FROM member_participant +) p +ORDER BY joined_at ASC; + +-- name: RemoveChatParticipant :exec +DELETE FROM bot_members +WHERE bot_id = sqlc.arg(chat_id) + AND user_id = sqlc.arg(user_id) + AND user_id <> (SELECT owner_user_id FROM bots WHERE id = sqlc.arg(chat_id)); + +-- name: CopyParticipantsToChat :exec +INSERT INTO bot_members (bot_id, user_id, role) +SELECT sqlc.arg(chat_id_2), bm.user_id, bm.role +FROM bot_members bm +WHERE bm.bot_id = sqlc.arg(chat_id) +ON CONFLICT (bot_id, user_id) DO NOTHING; + +-- chat_settings + +-- name: UpsertChatSettings :one +WITH resolved_model AS ( + SELECT id + FROM models + WHERE model_id = NULLIF(sqlc.narg(model_id)::text, '') + LIMIT 1 +), +updated AS ( + UPDATE bots + SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id), + updated_at = now() + WHERE bots.id = sqlc.arg(id) + RETURNING bots.id, bots.chat_model_id, bots.updated_at +) +SELECT + updated.id AS chat_id, + chat_models.model_id AS model_id, + updated.updated_at +FROM updated +LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id; + +-- name: GetChatSettings :one +SELECT + b.id AS chat_id, + chat_models.model_id AS model_id, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1; diff --git a/db/queries/history.sql b/db/queries/history.sql deleted file mode 100644 index 7e95576c..00000000 --- a/db/queries/history.sql +++ /dev/null @@ -1,31 +0,0 @@ --- name: CreateHistory :one -INSERT INTO history (bot_id, session_id, messages, metadata, skills, timestamp) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, bot_id, session_id, messages, metadata, skills, timestamp; - --- name: ListHistoryByBotSessionSince :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 AND timestamp >= $3 -ORDER BY timestamp ASC; - --- name: GetHistoryByID :one -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE id = $1; - --- name: ListHistoryByBotSession :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 -ORDER BY timestamp DESC -LIMIT $3; - --- name: DeleteHistoryByID :exec -DELETE FROM history -WHERE id = $1; - --- name: DeleteHistoryByBotSession :exec -DELETE FROM history -WHERE bot_id = $1 AND session_id = $2; - diff --git a/db/queries/messages.sql b/db/queries/messages.sql new file mode 100644 index 00000000..eaa5a02b --- /dev/null +++ b/db/queries/messages.sql @@ -0,0 +1,118 @@ +-- name: CreateMessage :one +INSERT INTO bot_history_messages ( + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id, + channel_type, + source_message_id, + source_reply_to_message_id, + role, + content, + metadata +) +VALUES ( + sqlc.arg(bot_id), + sqlc.narg(route_id)::uuid, + sqlc.narg(sender_channel_identity_id)::uuid, + sqlc.narg(sender_user_id)::uuid, + sqlc.narg(platform)::text, + sqlc.narg(external_message_id)::text, + sqlc.narg(source_reply_to_message_id)::text, + sqlc.arg(role), + sqlc.arg(content), + sqlc.arg(metadata) +) +RETURNING + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at; + +-- name: ListMessages :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) +ORDER BY created_at ASC; + +-- name: ListMessagesSince :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) + AND created_at >= sqlc.arg(created_at) +ORDER BY created_at ASC; + +-- name: ListMessagesBefore :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) + AND created_at < sqlc.arg(created_at) +ORDER BY created_at DESC +LIMIT sqlc.arg(max_count); + +-- name: ListMessagesLatest :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id) +ORDER BY created_at DESC +LIMIT sqlc.arg(max_count); + +-- name: DeleteMessagesByBot :exec +DELETE FROM bot_history_messages +WHERE bot_id = sqlc.arg(bot_id); diff --git a/db/queries/settings.sql b/db/queries/settings.sql index 4f35be30..926fc0b4 100644 --- a/db/queries/settings.sql +++ b/db/queries/settings.sql @@ -1,55 +1,64 @@ -- name: GetSettingsByUserID :one -SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language -FROM user_settings -WHERE user_id = $1; +SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language +FROM users +WHERE id = $1; -- name: UpsertUserSettings :one -INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language) -VALUES ($1, $2, $3, $4, $5, $6) -ON CONFLICT (user_id) DO UPDATE SET - chat_model_id = EXCLUDED.chat_model_id, - memory_model_id = EXCLUDED.memory_model_id, - embedding_model_id = EXCLUDED.embedding_model_id, - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language -RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language; +UPDATE users +SET chat_model_id = $2, + memory_model_id = $3, + embedding_model_id = $4, + max_context_load_time = $5, + language = $6, + updated_at = now() +WHERE id = $1 +RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language; -- name: GetSettingsByBotID :one -SELECT bot_id, max_context_load_time, language, allow_guest -FROM bot_settings -WHERE bot_id = $1; - --- name: GetBotModelConfigByBotID :one SELECT - bot_model_configs.bot_id, + bots.id AS bot_id, + bots.max_context_load_time, + bots.language, + bots.allow_guest, chat_models.model_id AS chat_model_id, memory_models.model_id AS memory_model_id, embedding_models.model_id AS embedding_model_id -FROM bot_model_configs -LEFT JOIN models AS chat_models ON chat_models.id = bot_model_configs.chat_model_id -LEFT JOIN models AS memory_models ON memory_models.id = bot_model_configs.memory_model_id -LEFT JOIN models AS embedding_models ON embedding_models.id = bot_model_configs.embedding_model_id -WHERE bot_model_configs.bot_id = $1; +FROM bots +LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = bots.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = bots.embedding_model_id +WHERE bots.id = $1; -- name: UpsertBotSettings :one -INSERT INTO bot_settings (bot_id, max_context_load_time, language, allow_guest) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language, - allow_guest = EXCLUDED.allow_guest -RETURNING bot_id, max_context_load_time, language, allow_guest; - --- name: UpsertBotModelConfig :one -INSERT INTO bot_model_configs (bot_id, chat_model_id, memory_model_id, embedding_model_id) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - chat_model_id = COALESCE(EXCLUDED.chat_model_id, bot_model_configs.chat_model_id), - memory_model_id = COALESCE(EXCLUDED.memory_model_id, bot_model_configs.memory_model_id), - embedding_model_id = COALESCE(EXCLUDED.embedding_model_id, bot_model_configs.embedding_model_id) -RETURNING bot_id, chat_model_id, memory_model_id, embedding_model_id; +WITH updated AS ( + UPDATE bots + SET max_context_load_time = sqlc.arg(max_context_load_time), + language = sqlc.arg(language), + allow_guest = sqlc.arg(allow_guest), + chat_model_id = COALESCE(sqlc.narg(chat_model_id)::uuid, bots.chat_model_id), + memory_model_id = COALESCE(sqlc.narg(memory_model_id)::uuid, bots.memory_model_id), + embedding_model_id = COALESCE(sqlc.narg(embedding_model_id)::uuid, bots.embedding_model_id), + updated_at = now() + WHERE bots.id = sqlc.arg(id) + RETURNING bots.id, bots.max_context_load_time, bots.language, bots.allow_guest, bots.chat_model_id, bots.memory_model_id, bots.embedding_model_id +) +SELECT + updated.id AS bot_id, + updated.max_context_load_time, + updated.language, + updated.allow_guest, + chat_models.model_id AS chat_model_id, + memory_models.model_id AS memory_model_id, + embedding_models.model_id AS embedding_model_id +FROM updated +LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = updated.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = updated.embedding_model_id; -- name: DeleteSettingsByBotID :exec -DELETE FROM bot_settings -WHERE bot_id = $1; - +UPDATE bots +SET max_context_load_time = 1440, + language = 'auto', + allow_guest = false, + updated_at = now() +WHERE id = $1; diff --git a/db/queries/users.sql b/db/queries/users.sql index 6506d935..87dc6e50 100644 --- a/db/queries/users.sql +++ b/db/queries/users.sql @@ -1,20 +1,38 @@ -- name: CreateUser :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - sqlc.arg(username), - sqlc.arg(email), - sqlc.arg(password_hash), - sqlc.arg(role)::user_role, - sqlc.arg(display_name), - sqlc.arg(avatar_url), - sqlc.arg(is_active), - sqlc.arg(data_root) -) +INSERT INTO users (is_active, metadata) +VALUES ($1, $2) RETURNING *; --- name: UpsertUserByUsername :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) +-- name: GetUserByID :one +SELECT * +FROM users +WHERE id = $1; + +-- name: UpdateUserStatus :one +UPDATE users +SET is_active = $2, + updated_at = now() +WHERE id = $1 +RETURNING *; + +-- name: CreateAccount :one +UPDATE users +SET username = sqlc.arg(username), + email = sqlc.arg(email), + password_hash = sqlc.arg(password_hash), + role = sqlc.arg(role)::user_role, + display_name = sqlc.arg(display_name), + avatar_url = sqlc.arg(avatar_url), + is_active = sqlc.arg(is_active), + data_root = sqlc.arg(data_root), + updated_at = now() +WHERE id = sqlc.arg(user_id) +RETURNING *; + +-- name: UpsertAccountByUsername :one +INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, metadata) VALUES ( + sqlc.arg(user_id), sqlc.arg(username), sqlc.arg(email), sqlc.arg(password_hash), @@ -22,7 +40,8 @@ VALUES ( sqlc.arg(display_name), sqlc.arg(avatar_url), sqlc.arg(is_active), - sqlc.arg(data_root) + sqlc.arg(data_root), + '{}'::jsonb ) ON CONFLICT (username) DO UPDATE SET email = EXCLUDED.email, @@ -35,39 +54,27 @@ ON CONFLICT (username) DO UPDATE SET updated_at = now() RETURNING *; --- name: GetUserByUsername :one +-- name: GetAccountByUsername :one SELECT * FROM users WHERE username = sqlc.arg(username); --- name: GetUserByIdentity :one +-- name: GetAccountByIdentity :one SELECT * FROM users WHERE username = sqlc.arg(identity) OR email = sqlc.arg(identity); --- name: GetUserByID :one -SELECT * FROM users WHERE id = sqlc.arg(id); +-- name: GetAccountByUserID :one +SELECT * FROM users WHERE id = sqlc.arg(user_id); --- name: CreateUserWithID :one -INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - sqlc.arg(id), - sqlc.arg(username), - sqlc.arg(email), - sqlc.arg(password_hash), - sqlc.arg(role)::user_role, - sqlc.arg(display_name), - sqlc.arg(avatar_url), - sqlc.arg(is_active), - sqlc.arg(data_root) -) -RETURNING *; +-- name: CountAccounts :one +SELECT COUNT(*)::bigint AS count +FROM users +WHERE username IS NOT NULL + AND password_hash IS NOT NULL; --- name: CountUsers :one -SELECT COUNT(*)::bigint AS count FROM users; - --- name: ListUsers :many +-- name: ListAccounts :many SELECT * FROM users +WHERE username IS NOT NULL ORDER BY created_at DESC; - --- name: UpdateUserProfile :one +-- name: UpdateAccountProfile :one UPDATE users SET display_name = $2, avatar_url = $3, @@ -76,27 +83,26 @@ SET display_name = $2, WHERE id = $1 RETURNING *; --- name: UpdateUserAdmin :one +-- name: UpdateAccountAdmin :one UPDATE users SET role = sqlc.arg(role)::user_role, display_name = sqlc.arg(display_name), avatar_url = sqlc.arg(avatar_url), is_active = sqlc.arg(is_active), updated_at = now() -WHERE id = sqlc.arg(id) +WHERE id = sqlc.arg(user_id) RETURNING *; --- name: UpdateUserPassword :one +-- name: UpdateAccountPassword :one UPDATE users SET password_hash = $2, updated_at = now() WHERE id = $1 RETURNING *; --- name: UpdateUserLastLogin :one +-- name: UpdateAccountLastLogin :one UPDATE users SET last_login_at = now(), updated_at = now() WHERE id = $1 RETURNING *; - diff --git a/docker-compose.yml b/docker-compose.yml index 28f554ca..9b9088a5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -109,4 +109,4 @@ volumes: networks: memoh-network: - driver: bridge + driver: bridge \ No newline at end of file diff --git a/docker/Dockerfile.agent b/docker/Dockerfile.agent index da4c4cef..9f860353 100644 --- a/docker/Dockerfile.agent +++ b/docker/Dockerfile.agent @@ -23,6 +23,6 @@ COPY --from=builder /build/package.json /app/package.json EXPOSE 8081 HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ - CMD wget --no-verbose --tries=1 --spider http://localhost:8081/health || exit 1 + CMD wget --no-verbose --tries=1 --spider http://agent:8081/health || exit 1 CMD ["bun", "run", "dist/index.js"] diff --git a/docs/docs/.vitepress/config.ts b/docs/docs/.vitepress/config.ts index 7a3152c2..ef051e48 100644 --- a/docs/docs/.vitepress/config.ts +++ b/docs/docs/.vitepress/config.ts @@ -27,19 +27,63 @@ export default defineConfig({ sidebar: { '/': [ { - text: 'Hello Memoh', + text: 'Overview', link: '/index.md' }, { text: 'Getting Started', link: '/getting-started.md' + }, + { + text: 'Core Concepts', + items: [ + { + text: 'Concepts Overview', + link: '/concepts/index.md' + }, + { + text: 'Accounts and Linking', + link: '/concepts/identity-and-binding.md' + } + ] + }, + { + text: 'Documentation Style Guide', + items: [ + { + text: 'Terminology Rules', + link: '/style/terminology.md' + } + ] } ], '/zh/': [ { - text: 'Hello Memoh', + text: '文档总览', link: '/zh/index.md' }, + { + text: '核心概念', + items: [ + { + text: '概念总览', + link: '/zh/concepts/index.md' + }, + { + text: '账号模型与绑定', + link: '/zh/concepts/identity-and-binding.md' + } + ] + }, + { + text: '文档写作规范', + items: [ + { + text: '术语规范', + link: '/zh/style/terminology.md' + } + ] + } ] }, diff --git a/docs/docs/concepts/identity-and-binding.md b/docs/docs/concepts/identity-and-binding.md new file mode 100644 index 00000000..bb7982b6 --- /dev/null +++ b/docs/docs/concepts/identity-and-binding.md @@ -0,0 +1,41 @@ +# Accounts and Linking + +## Account Model + +Memoh treats platform accounts and system accounts as two different entities: + +- **Platform Account (`ChannelIdentity`)** is the user's account on an external access platform (for example, a TG account), not a Memoh internal account. +- **System Account (`User`)** is an internal account in Memoh. + +A platform account can exist before linking. +`bind` is the mechanism that links these two account types. + +## Access Platform and Bot + +- **Access Platform (`channel`)** is where inbound messages come from. +- **Bot** is an authorization and resource boundary inside Memoh. + +Bots are managed by system accounts, while inbound messages are produced by platform accounts. + +## Why Linking Is Account-Scoped + +Account linking exists to establish account ownership, not to grant bot resources directly: + +- It links platform accounts and system accounts independent of any single bot. +- It avoids coupling account linking with member management semantics. +- It keeps bot authorization and account linking decoupled. + +## Linking Flow (Current Consensus) + +1. A user requests a bind code under their own system account. +2. The platform account sends the code from a supported access-platform conversation. +3. Memoh validates the code and links platform account to system account. +4. Bot membership and authorization are handled by their own flows. + +## Bot Type Semantics + +- **Public bot**: supports member-based collaboration. +- **Personal bot**: conceptually single-owner, and should not rely on member semantics. + +> Note: The conceptual model is documented here as product semantics. +> Runtime behavior may still be in transition as implementations are tightened. diff --git a/docs/docs/concepts/index.md b/docs/docs/concepts/index.md new file mode 100644 index 00000000..e6d54718 --- /dev/null +++ b/docs/docs/concepts/index.md @@ -0,0 +1,21 @@ +# Core Concepts + +This section defines the core account and access concepts used by Memoh. + +## Concept Map + +- **System Account (`User`)**: an internal account in Memoh. +- **Platform Account (`ChannelIdentity`)**: a user's account on an external access platform, not a Memoh account (for example, the user's Telegram (TG) account). +- **Bot**: an access and resource boundary managed by a system account. +- **Account Linking (`bind`)**: the process that links a platform account to a system account. + +## Why This Matters + +Memoh receives messages from external access platforms, but manages permissions and resources inside the system. +To keep these concerns clear, the model separates platform accounts from system accounts, while keeping bot access control as an independent concern. + +Terminology note: "platform account" always means the user's account on that platform (such as TG), not an internal account created by this project. + +## In This Chapter + +- [Accounts and Linking](/concepts/identity-and-binding.md) diff --git a/docs/docs/index.md b/docs/docs/index.md index dbb041ff..076fe5a9 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -1 +1,22 @@ -# Hello Memoh +# Memoh Documentation + +Memoh is a multi-member, long-memory, containerized AI agent system. + +## Documentation Sections + +- [Getting Started](/getting-started.md) +- [Core Concepts](/concepts/index.md) + +## For Contributors + +- [Terminology Rules](/style/terminology.md) + +## Current Focus + +The current documentation iteration focuses on account semantics: + +- Distinguishing system accounts and platform accounts +- Explaining why account linking is account-scoped +- Clarifying the relationship between account linking and bot access + +Note: "platform account" means the user's account on external platforms (for example, TG), not a Memoh account. diff --git a/docs/docs/style/terminology.md b/docs/docs/style/terminology.md new file mode 100644 index 00000000..a510d624 --- /dev/null +++ b/docs/docs/style/terminology.md @@ -0,0 +1,40 @@ +# Terminology Rules + +> Audience: documentation contributors and maintainers. +> This page defines writing terms. It is not product user guidance. + +## Canonical Terms + +- **System Account (`User`)**: the account inside Memoh. +- **Platform Account (`ChannelIdentity`)**: the user's account on an external access platform, not a Memoh account. +- **Access Platform (`channel`)**: the external platform carrying inbound messages. +- **Account Linking (`bind`)**: linking a Platform Account to a System Account. +- **Bind Code**: one-time code used for account linking. +- **Bot**: resource and authorization boundary managed by a System Account. + +## Preferred Wording + +- Write **"platform account"** instead of "actor" in user-facing docs. +- Write **"access platform"** instead of "channel" when describing product behavior. +- Keep code aliases in parentheses on first mention: + - `Platform Account (ChannelIdentity)` + - `System Account (User)` + - `Account Linking (bind)` + +## Disallowed or Discouraged Terms + +- Avoid plain **actor** in conceptual docs (except when quoting code symbols). +- Avoid ambiguous **platform user** phrasing (it does not distinguish system vs platform account). +- Avoid wording that implies Platform Account is created inside Memoh. + +## Example Sentences + +- Correct: "A platform account is the user's TG account, not a Memoh account." +- Correct: "Account linking binds a platform account to a system account." +- Incorrect: "Actor is a user in Memoh." + +## Contributor Checklist + +- Is every "account" term clearly scoped (system vs platform)? +- Is "channel" replaced by "access platform" in prose? +- Are code aliases kept only as parenthetical references? diff --git a/docs/docs/zh/concepts/identity-and-binding.md b/docs/docs/zh/concepts/identity-and-binding.md new file mode 100644 index 00000000..c7d275d3 --- /dev/null +++ b/docs/docs/zh/concepts/identity-and-binding.md @@ -0,0 +1,41 @@ +# 账号模型与绑定 + +## 账号模型 + +Memoh 将平台账号与系统账号视为两类不同实体: + +- **平台账号(`ChannelIdentity`)** 是用户在外部接入平台上的账号(例如飞书账号),不是 Memoh 内部账号。 +- **系统账号(`User`)** 是 Memoh 系统内账号。 + +平台账号在初始阶段可以不绑定系统账号。 +`bind` 的职责是完成这两类账号的关联。 + +## 接入平台与 Bot + +- **接入平台(`channel`)** 是入站消息来源。 +- **Bot** 是系统内的授权与资源边界。 + +Bot 由系统账号管理,入站消息由平台账号产生。 + +## 为什么账号绑定是账号作用域 + +账号绑定的目标是建立账号归属关系,而不是直接发放 bot 资源权限: + +- 它只负责平台账号与系统账号的绑定; +- 不把账号绑定与成员管理语义耦合在一起; +- 让 bot 访问控制保持独立、可演进。 + +## 账号绑定流程(当前共识) + +1. 用户以自己的系统账号申请 bind code; +2. 平台账号在支持的接入平台会话中发送 code; +3. 系统校验 code,完成平台账号到系统账号的绑定; +4. bot 成员与授权由独立流程处理。 + +## Bot 类型语义 + +- **Public bot**:支持成员协作语义。 +- **Personal bot**:语义上应为单 owner,不应依赖成员机制。 + +> 注:本文档记录的是产品语义与共识方向。 +> 部分运行时细节仍可能处于收敛阶段。 diff --git a/docs/docs/zh/concepts/index.md b/docs/docs/zh/concepts/index.md new file mode 100644 index 00000000..fe931ca6 --- /dev/null +++ b/docs/docs/zh/concepts/index.md @@ -0,0 +1,21 @@ +# 核心概念 + +本章节用于定义 Memoh 的核心账号与访问概念。 + +## 概念图 + +- **系统账号(`User`)**:Memoh 系统内账号。 +- **平台账号(`ChannelIdentity`)**:用户在外部接入平台上的账号,不是 Memoh 系统内账号(例如用户的飞书账号)。 +- **Bot**:由系统账号管理的资源与访问边界。 +- **账号绑定(`bind`)**:把平台账号关联到系统账号的过程。 + +## 为什么重要 + +Memoh 需要同时处理外部接入平台消息与系统内权限控制。 +因此我们明确区分平台账号与系统账号,并将 bot 授权与账号绑定解耦。 + +术语说明:文档中的“平台账号”统一指用户在对应平台上的真实账号(如飞书账号),不指本项目内部账号。 + +## 本章内容 + +- [账号模型与绑定](/zh/concepts/identity-and-binding.md) diff --git a/docs/docs/zh/index.md b/docs/docs/zh/index.md index dbb041ff..5a068f92 100644 --- a/docs/docs/zh/index.md +++ b/docs/docs/zh/index.md @@ -1 +1,22 @@ -# Hello Memoh +# Memoh 文档 + +Memoh 是一个多成员、长记忆、容器化的 AI Agent 系统。 + +## 文档章节 + +- [快速开始](/getting-started.md) +- [核心概念](/zh/concepts/index.md) + +## 面向文档贡献者 + +- [术语规范](/zh/style/terminology.md) + +## 当前维护范围 + +当前文档先聚焦账号语义与访问控制: + +- 区分系统账号与平台账号 +- 解释为什么账号绑定是账号作用域 +- 说明账号绑定与 bot 访问控制之间的关系 + +说明:“平台账号”指用户在外部平台上的真实账号(例如飞书账号),不是 Memoh 系统账号。 diff --git a/docs/docs/zh/style/terminology.md b/docs/docs/zh/style/terminology.md new file mode 100644 index 00000000..5d76ec7c --- /dev/null +++ b/docs/docs/zh/style/terminology.md @@ -0,0 +1,40 @@ +# 术语规范 + +> 适用对象:文档编写者与维护者。 +> 本页用于统一写作语义,不是面向最终用户的功能说明。 + +## 规范术语 + +- **系统账号(`User`)**:Memoh 系统内账号。 +- **平台账号(`ChannelIdentity`)**:用户在外部接入平台上的账号,不是 Memoh 内账号。 +- **接入平台(`channel`)**:承载入站消息的外部平台。 +- **账号绑定(`bind`)**:把平台账号关联到系统账号的过程。 +- **绑定码(Bind Code)**:用于账号绑定的一次性代码。 +- **Bot**:由系统账号管理的资源与授权边界。 + +## 推荐写法 + +- 面向产品语义时,优先写 **“平台账号”**,不要直接写 actor。 +- 描述业务行为时,优先写 **“接入平台”**,不要直接写 channel。 +- 首次出现保留技术别名,后续可只用中文术语: + - 平台账号(`ChannelIdentity`) + - 系统账号(`User`) + - 账号绑定(`bind`) + +## 禁用或不推荐写法 + +- 在概念文档中直接使用 **actor**(除非明确引用代码符号)。 +- 使用含糊表述如 **“平台用户”**(未区分系统账号与平台账号)。 +- 写出“平台账号是 Memoh 内部账号”这类错误语义。 + +## 示例 + +- 正确:**“平台账号是用户在飞书上的账号,不是 Memoh 系统账号。”** +- 正确:**“账号绑定用于把平台账号关联到系统账号。”** +- 错误:**“Actor 是 Memoh 里的用户。”** + +## 自检清单 + +- 是否明确区分了系统账号与平台账号? +- 叙述中是否将 channel 表述为接入平台? +- 是否仅在首处保留技术别名? diff --git a/internal/users/service.go b/internal/accounts/service.go similarity index 54% rename from internal/users/service.go rename to internal/accounts/service.go index 24c0c405..e683ef8a 100644 --- a/internal/users/service.go +++ b/internal/accounts/service.go @@ -1,4 +1,4 @@ -package users +package accounts import ( "context" @@ -8,14 +8,15 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "golang.org/x/crypto/bcrypt" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) +// Service provides account (credential) management for users. type Service struct { queries *sqlc.Queries logger *slog.Logger @@ -24,120 +25,125 @@ type Service struct { var ( ErrInvalidPassword = errors.New("invalid password") ErrInvalidCredentials = errors.New("invalid credentials") - ErrInactiveUser = errors.New("user is inactive") + ErrInactiveAccount = errors.New("account is inactive") ) +// NewService creates a new accounts service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { if log == nil { log = slog.Default() } return &Service{ queries: queries, - logger: log.With(slog.String("service", "users")), + logger: log.With(slog.String("service", "accounts")), } } -func (s *Service) Get(ctx context.Context, userID string) (User, error) { +// Get returns an account by user id. +func (s *Service) Get(ctx context.Context, userID string) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { - return User{}, err + return Account{}, err } - row, err := s.queries.GetUserByID(ctx, pgID) + row, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) Login(ctx context.Context, identity, password string) (User, error) { +// Login authenticates by identity (username or email) and password. +func (s *Service) Login(ctx context.Context, identity, password string) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } identity = strings.TrimSpace(identity) if identity == "" || strings.TrimSpace(password) == "" { - return User{}, ErrInvalidCredentials + return Account{}, ErrInvalidCredentials } - row, err := s.queries.GetUserByIdentity(ctx, identity) + row, err := s.queries.GetAccountByIdentity(ctx, pgtype.Text{String: identity, Valid: true}) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return User{}, ErrInvalidCredentials + return Account{}, ErrInvalidCredentials } - return User{}, err + return Account{}, err } if !row.IsActive { - return User{}, ErrInactiveUser + return Account{}, ErrInactiveAccount } - if err := bcrypt.CompareHashAndPassword([]byte(row.PasswordHash), []byte(password)); err != nil { - return User{}, ErrInvalidCredentials + if !row.PasswordHash.Valid { + return Account{}, ErrInvalidCredentials } - if _, err := s.queries.UpdateUserLastLogin(ctx, row.ID); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(row.PasswordHash.String), []byte(password)); err != nil { + return Account{}, ErrInvalidCredentials + } + if _, err := s.queries.UpdateAccountLastLogin(ctx, row.ID); err != nil { if s.logger != nil { s.logger.Warn("touch last login failed", slog.Any("error", err)) } } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) ListUsers(ctx context.Context) ([]User, error) { +// ListAccounts returns all accounts. +func (s *Service) ListAccounts(ctx context.Context) ([]Account, error) { if s.queries == nil { - return nil, fmt.Errorf("user queries not configured") + return nil, fmt.Errorf("account queries not configured") } - rows, err := s.queries.ListUsers(ctx) + rows, err := s.queries.ListAccounts(ctx) if err != nil { return nil, err } - items := make([]User, 0, len(rows)) + items := make([]Account, 0, len(rows)) for _, row := range rows { - items = append(items, toUser(row)) + items = append(items, toAccount(row)) } return items, nil } -func (s *Service) ListUsersByType(ctx context.Context, userType string) ([]User, error) { - if s.queries == nil { - return nil, fmt.Errorf("user queries not configured") - } - return nil, fmt.Errorf("user type filtering is not supported") -} - +// IsAdmin checks if the user has admin role. func (s *Service) IsAdmin(ctx context.Context, userID string) (bool, error) { if s.queries == nil { - return false, fmt.Errorf("user queries not configured") + return false, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return false, err } - row, err := s.queries.GetUserByID(ctx, pgID) + row, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, nil + } return false, err } return isAdminRole(row.Role), nil } -func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, error) { +// Create creates a new account for an existing user. +func (s *Service) Create(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } username := strings.TrimSpace(req.Username) if username == "" { - return User{}, fmt.Errorf("username is required") + return Account{}, fmt.Errorf("username is required") } password := strings.TrimSpace(req.Password) if password == "" { - return User{}, fmt.Errorf("password is required") + return Account{}, fmt.Errorf("password is required") } role, err := normalizeRole(req.Role) if err != nil { - return User{}, err + return Account{}, err } hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return User{}, err + return Account{}, err } displayName := strings.TrimSpace(req.DisplayName) @@ -151,6 +157,10 @@ func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, isActive = *req.IsActive } + pgUserID, err := db.ParseUUID(userID) + if err != nil { + return Account{}, err + } emailValue := pgtype.Text{Valid: false} if email != "" { emailValue = pgtype.Text{String: email, Valid: true} @@ -161,10 +171,11 @@ func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, avatarValue = pgtype.Text{String: avatarURL, Valid: true} } - row, err := s.queries.CreateUser(ctx, sqlc.CreateUserParams{ - Username: username, + row, err := s.queries.CreateAccount(ctx, sqlc.CreateAccountParams{ + UserID: pgUserID, + Username: pgtype.Text{String: username, Valid: true}, Email: emailValue, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, Role: role, DisplayName: displayValue, AvatarUrl: avatarValue, @@ -172,28 +183,51 @@ func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, DataRoot: pgtype.Text{Valid: false}, }) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) UpdateUserAdmin(ctx context.Context, userID string, req UpdateUserRequest) (User, error) { +// CreateHuman keeps compatibility with older call sites. +func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) { + userID = strings.TrimSpace(userID) + if userID == "" { + if s.queries == nil { + return Account{}, fmt.Errorf("account queries not configured") + } + userRow, err := s.queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return Account{}, err + } + if !userRow.ID.Valid { + return Account{}, fmt.Errorf("create user: invalid id") + } + userID = userRow.ID.String() + } + return s.Create(ctx, userID, req) +} + +// UpdateAdmin updates account fields as admin. +func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAccountRequest) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { - return User{}, err + return Account{}, err } - existing, err := s.queries.GetUserByID(ctx, pgID) + existing, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { - return User{}, err + return Account{}, err } role := fmt.Sprint(existing.Role) if req.Role != nil { role, err = normalizeRole(*req.Role) if err != nil { - return User{}, err + return Account{}, err } } displayName := strings.TrimSpace(existing.DisplayName.String) @@ -201,7 +235,7 @@ func (s *Service) UpdateUserAdmin(ctx context.Context, userID string, req Update displayName = strings.TrimSpace(*req.DisplayName) } if displayName == "" { - displayName = existing.Username + displayName = strings.TrimSpace(existing.Username.String) } avatarURL := strings.TrimSpace(existing.AvatarUrl.String) if req.AvatarURL != nil { @@ -212,94 +246,100 @@ func (s *Service) UpdateUserAdmin(ctx context.Context, userID string, req Update isActive = *req.IsActive } - row, err := s.queries.UpdateUserAdmin(ctx, sqlc.UpdateUserAdminParams{ - ID: pgID, + row, err := s.queries.UpdateAccountAdmin(ctx, sqlc.UpdateAccountAdminParams{ + UserID: pgID, Role: role, DisplayName: pgtype.Text{String: displayName, Valid: displayName != ""}, AvatarUrl: pgtype.Text{String: avatarURL, Valid: avatarURL != ""}, IsActive: isActive, }) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdateProfileRequest) (User, error) { +// UpdateProfile updates the user's profile. +func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdateProfileRequest) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { - return User{}, err + return Account{}, err } - existing, err := s.queries.GetUserByID(ctx, pgID) + existing, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { - return User{}, err + return Account{}, err } displayName := strings.TrimSpace(existing.DisplayName.String) if req.DisplayName != nil { displayName = strings.TrimSpace(*req.DisplayName) } if displayName == "" { - displayName = existing.Username + displayName = strings.TrimSpace(existing.Username.String) } avatarURL := strings.TrimSpace(existing.AvatarUrl.String) if req.AvatarURL != nil { avatarURL = strings.TrimSpace(*req.AvatarURL) } - row, err := s.queries.UpdateUserProfile(ctx, sqlc.UpdateUserProfileParams{ + row, err := s.queries.UpdateAccountProfile(ctx, sqlc.UpdateAccountProfileParams{ ID: pgID, DisplayName: pgtype.Text{String: displayName, Valid: displayName != ""}, AvatarUrl: pgtype.Text{String: avatarURL, Valid: avatarURL != ""}, IsActive: existing.IsActive, }) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } +// UpdatePassword changes the password after verifying the current one. func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, newPassword string) error { if s.queries == nil { - return fmt.Errorf("user queries not configured") + return fmt.Errorf("account queries not configured") } if strings.TrimSpace(newPassword) == "" { return fmt.Errorf("new password is required") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return err } - existing, err := s.queries.GetUserByID(ctx, pgID) + existing, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { return err } if strings.TrimSpace(currentPassword) == "" { return ErrInvalidPassword } - if err := bcrypt.CompareHashAndPassword([]byte(existing.PasswordHash), []byte(currentPassword)); err != nil { + if !existing.PasswordHash.Valid { + return ErrInvalidPassword + } + if err := bcrypt.CompareHashAndPassword([]byte(existing.PasswordHash.String), []byte(currentPassword)); err != nil { return ErrInvalidPassword } hashed, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { return err } - _, err = s.queries.UpdateUserPassword(ctx, sqlc.UpdateUserPasswordParams{ + _, err = s.queries.UpdateAccountPassword(ctx, sqlc.UpdateAccountPasswordParams{ ID: pgID, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, }) return err } +// ResetPassword sets a new password without requiring the current one. func (s *Service) ResetPassword(ctx context.Context, userID, newPassword string) error { if s.queries == nil { - return fmt.Errorf("user queries not configured") + return fmt.Errorf("account queries not configured") } if strings.TrimSpace(newPassword) == "" { return fmt.Errorf("new password is required") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return err } @@ -307,9 +347,9 @@ func (s *Service) ResetPassword(ctx context.Context, userID, newPassword string) if err != nil { return err } - _, err = s.queries.UpdateUserPassword(ctx, sqlc.UpdateUserPasswordParams{ + _, err = s.queries.UpdateAccountPassword(ctx, sqlc.UpdateAccountPasswordParams{ ID: pgID, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, }) return err } @@ -339,7 +379,8 @@ func isAdminRole(role any) bool { } } -func toUser(row sqlc.User) User { +func toAccount(row sqlc.User) Account { + username := strings.TrimSpace(row.Username.String) email := "" if row.Email.Valid { email = row.Email.String @@ -348,6 +389,9 @@ func toUser(row sqlc.User) User { if row.DisplayName.Valid { displayName = row.DisplayName.String } + if displayName == "" { + displayName = username + } avatarURL := "" if row.AvatarUrl.Valid { avatarURL = row.AvatarUrl.String @@ -364,9 +408,9 @@ func toUser(row sqlc.User) User { if row.LastLoginAt.Valid { lastLogin = row.LastLoginAt.Time } - return User{ - ID: toUUIDString(row.ID), - Username: row.Username, + return Account{ + ID: row.ID.String(), + Username: username, Email: email, Role: fmt.Sprint(row.Role), DisplayName: displayName, @@ -378,24 +422,3 @@ func toUser(row sqlc.User) User { } } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} diff --git a/internal/users/types.go b/internal/accounts/types.go similarity index 68% rename from internal/users/types.go rename to internal/accounts/types.go index 431225fc..7a3b4f62 100644 --- a/internal/users/types.go +++ b/internal/accounts/types.go @@ -1,8 +1,9 @@ -package users +package accounts import "time" -type User struct { +// Account represents a human account credential record. +type Account struct { ID string `json:"id"` Username string `json:"username"` Email string `json:"email,omitempty"` @@ -15,7 +16,8 @@ type User struct { LastLoginAt time.Time `json:"last_login_at,omitempty"` } -type CreateUserRequest struct { +// CreateAccountRequest is the input for creating an account. +type CreateAccountRequest struct { Username string `json:"username"` Password string `json:"password"` Email string `json:"email,omitempty"` @@ -25,27 +27,32 @@ type CreateUserRequest struct { IsActive *bool `json:"is_active,omitempty"` } -type UpdateUserRequest struct { +// UpdateAccountRequest is the input for admin-level account updates. +type UpdateAccountRequest struct { Role *string `json:"role,omitempty"` DisplayName *string `json:"display_name,omitempty"` AvatarURL *string `json:"avatar_url,omitempty"` IsActive *bool `json:"is_active,omitempty"` } +// UpdateProfileRequest is the input for self-service profile updates. type UpdateProfileRequest struct { DisplayName *string `json:"display_name,omitempty"` AvatarURL *string `json:"avatar_url,omitempty"` } +// UpdatePasswordRequest is the input for password change. type UpdatePasswordRequest struct { CurrentPassword string `json:"current_password,omitempty"` NewPassword string `json:"new_password"` } +// ResetPasswordRequest is the input for admin password reset. type ResetPasswordRequest struct { NewPassword string `json:"new_password"` } -type ListUsersResponse struct { - Items []User `json:"items"` +// ListAccountsResponse wraps a list of accounts. +type ListAccountsResponse struct { + Items []Account `json:"items"` } diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 790035d9..8c837b43 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -13,15 +13,14 @@ import ( ) const ( - claimSubject = "sub" - claimUserID = "user_id" - claimType = "typ" - claimBotID = "bot_id" - claimPlatform = "platform" - claimReplyTarget = "reply_target" - claimSessionID = "session_id" - claimContactID = "contact_id" - sessionTokenType = "channel_session" + claimSubject = "sub" + claimUserID = "user_id" + claimChannelIdentityID = "channel_identity_id" + claimType = "typ" + claimBotID = "bot_id" + claimChatID = "chat_id" + claimRouteID = "route_id" + chatTokenType = "chat_route" ) // JWTMiddleware returns a JWT auth middleware configured for HS256 tokens. @@ -84,24 +83,28 @@ func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time return signed, expiresAt, nil } -type SessionToken struct { - BotID string - Platform string - ReplyTarget string - SessionID string - ContactID string +// ChatToken holds the claims for a chat-based JWT used for route-based reply. +type ChatToken struct { + BotID string + ChatID string + RouteID string + UserID string + ChannelIdentityID string } -// GenerateSessionToken creates a signed JWT for channel session reply. -func GenerateSessionToken(info SessionToken, secret string, expiresIn time.Duration) (string, time.Time, error) { +// GenerateChatToken creates a signed JWT for chat route reply. +func GenerateChatToken(info ChatToken, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(info.BotID) == "" { return "", time.Time{}, fmt.Errorf("bot id is required") } - if strings.TrimSpace(info.Platform) == "" { - return "", time.Time{}, fmt.Errorf("platform is required") + if strings.TrimSpace(info.ChatID) == "" { + return "", time.Time{}, fmt.Errorf("chat id is required") } - if strings.TrimSpace(info.ReplyTarget) == "" { - return "", time.Time{}, fmt.Errorf("reply target is required") + if strings.TrimSpace(info.UserID) == "" { + info.UserID = strings.TrimSpace(info.ChannelIdentityID) + } + if strings.TrimSpace(info.UserID) == "" { + return "", time.Time{}, fmt.Errorf("user id is required") } if strings.TrimSpace(secret) == "" { return "", time.Time{}, fmt.Errorf("jwt secret is required") @@ -113,14 +116,14 @@ func GenerateSessionToken(info SessionToken, secret string, expiresIn time.Durat now := time.Now().UTC() expiresAt := now.Add(expiresIn) claims := jwt.MapClaims{ - claimType: sessionTokenType, - claimBotID: info.BotID, - claimPlatform: info.Platform, - claimReplyTarget: info.ReplyTarget, - claimSessionID: info.SessionID, - claimContactID: info.ContactID, - "iat": now.Unix(), - "exp": expiresAt.Unix(), + claimType: chatTokenType, + claimBotID: info.BotID, + claimChatID: info.ChatID, + claimRouteID: info.RouteID, + claimUserID: info.UserID, + claimChannelIdentityID: info.ChannelIdentityID, + "iat": now.Unix(), + "exp": expiresAt.Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) signed, err := token.SignedString([]byte(secret)) @@ -130,26 +133,30 @@ func GenerateSessionToken(info SessionToken, secret string, expiresIn time.Durat return signed, expiresAt, nil } -// SessionTokenFromContext extracts the session token claims from context. -func SessionTokenFromContext(c echo.Context) (SessionToken, error) { +// ChatTokenFromContext extracts the chat token claims from context. +func ChatTokenFromContext(c echo.Context) (ChatToken, error) { token, ok := c.Get("user").(*jwt.Token) if !ok || token == nil || !token.Valid { - return SessionToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token") + return ChatToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token") } claims, ok := token.Claims.(jwt.MapClaims) if !ok { - return SessionToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token claims") + return ChatToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token claims") } - if claimString(claims, claimType) != sessionTokenType { - return SessionToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid session token") + if claimString(claims, claimType) != chatTokenType { + return ChatToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid chat token") } - return SessionToken{ - BotID: claimString(claims, claimBotID), - Platform: claimString(claims, claimPlatform), - ReplyTarget: claimString(claims, claimReplyTarget), - SessionID: claimString(claims, claimSessionID), - ContactID: claimString(claims, claimContactID), - }, nil + info := ChatToken{ + BotID: claimString(claims, claimBotID), + ChatID: claimString(claims, claimChatID), + RouteID: claimString(claims, claimRouteID), + UserID: claimString(claims, claimUserID), + ChannelIdentityID: claimString(claims, claimChannelIdentityID), + } + if strings.TrimSpace(info.UserID) == "" { + info.UserID = strings.TrimSpace(info.ChannelIdentityID) + } + return info, nil } func claimString(claims jwt.MapClaims, key string) string { diff --git a/internal/bind/service.go b/internal/bind/service.go new file mode 100644 index 00000000..a0c84188 --- /dev/null +++ b/internal/bind/service.go @@ -0,0 +1,242 @@ +package bind + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +const ( + defaultTTL = 24 * time.Hour + maxTokenRetries = 5 +) + +// Service manages channel identity->user bind code lifecycle. +type Service struct { + pool *pgxpool.Pool + queries *sqlc.Queries + logger *slog.Logger +} + +// NewService creates a bind code service. +func NewService(log *slog.Logger, pool *pgxpool.Pool, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + pool: pool, + queries: queries, + logger: log.With(slog.String("service", "bind")), + } +} + +// Issue creates a new bind code issued by the given user. +// Platform is optional; when provided, bind consume must happen on the same channel platform. +func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, ttl time.Duration) (Code, error) { + if s.queries == nil { + return Code{}, fmt.Errorf("bind queries not configured") + } + if ttl <= 0 { + ttl = defaultTTL + } + + pgUserID, err := db.ParseUUID(issuedByUserID) + if err != nil { + return Code{}, fmt.Errorf("invalid user id: %w", err) + } + normalizedPlatform := normalizePlatform(platform) + + expiresAt := time.Now().UTC().Add(ttl) + for i := 0; i < maxTokenRetries; i++ { + token := strings.ToUpper(strings.ReplaceAll(uuid.NewString(), "-", "")[:8]) + row, err := s.queries.CreateBindCode(ctx, sqlc.CreateBindCodeParams{ + Token: token, + IssuedByUserID: pgUserID, + ChannelType: pgtype.Text{ + String: normalizedPlatform, + Valid: normalizedPlatform != "", + }, + ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true}, + }) + if err == nil { + return toCode(row), nil + } + if isUniqueViolation(err) { + continue + } + return Code{}, fmt.Errorf("create bind code: %w", err) + } + return Code{}, fmt.Errorf("create bind code: token collision after retries") +} + +// Get looks up a bind code by token. +func (s *Service) Get(ctx context.Context, token string) (Code, error) { + if s.queries == nil { + return Code{}, fmt.Errorf("bind queries not configured") + } + row, err := s.queries.GetBindCode(ctx, strings.TrimSpace(token)) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Code{}, ErrCodeNotFound + } + return Code{}, err + } + return toCode(row), nil +} + +// Consume validates and consumes a bind code and links the channel identity to issuer user. +func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID string) error { + if s.queries == nil || s.pool == nil { + return fmt.Errorf("bind service not configured") + } + + // Fast-fail based on caller snapshot before opening a transaction. + if !code.UsedAt.IsZero() { + return ErrCodeUsed + } + if !code.ExpiresAt.IsZero() && time.Now().UTC().After(code.ExpiresAt) { + return ErrCodeExpired + } + token := strings.TrimSpace(code.Token) + if token == "" { + return ErrCodeNotFound + } + sourceIdentityID := strings.TrimSpace(channelIdentityID) + if sourceIdentityID == "" { + return fmt.Errorf("channel identity id is required") + } + pgSourceIdentityID, err := db.ParseUUID(sourceIdentityID) + if err != nil { + return err + } + + tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return fmt.Errorf("begin bind consume tx: %w", err) + } + defer func() { _ = tx.Rollback(ctx) }() + qtx := s.queries.WithTx(tx) + + lockedCodeRow, err := qtx.GetBindCodeForUpdate(ctx, token) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrCodeNotFound + } + return fmt.Errorf("lock bind code: %w", err) + } + lockedCode := toCode(lockedCodeRow) + if !lockedCode.UsedAt.IsZero() { + return ErrCodeUsed + } + if !lockedCode.ExpiresAt.IsZero() && time.Now().UTC().After(lockedCode.ExpiresAt) { + return ErrCodeExpired + } + if strings.TrimSpace(code.Platform) != "" && !strings.EqualFold(lockedCode.Platform, strings.TrimSpace(code.Platform)) { + return ErrCodeMismatch + } + + targetUserID := strings.TrimSpace(lockedCode.IssuedByUserID) + if targetUserID == "" { + return fmt.Errorf("bind code issuer user is missing") + } + pgTargetUserID, err := db.ParseUUID(targetUserID) + if err != nil { + return err + } + + if _, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("channel identity not found") + } + return fmt.Errorf("lock source identity: %w", err) + } + sourceIdentity, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("channel identity not found") + } + return fmt.Errorf("reload source identity: %w", err) + } + if sourceIdentity.UserID.Valid && sourceIdentity.UserID.String() != targetUserID { + return ErrLinkConflict + } + if !sourceIdentity.UserID.Valid { + if _, err := qtx.SetChannelIdentityLinkedUser(ctx, sqlc.SetChannelIdentityLinkedUserParams{ + ID: pgSourceIdentityID, + UserID: pgTargetUserID, + }); err != nil { + return fmt.Errorf("link channel identity user: %w", err) + } + } + + if _, err := qtx.MarkBindCodeUsed(ctx, sqlc.MarkBindCodeUsedParams{ + ID: lockedCodeRow.ID, + UsedByChannelIdentityID: pgSourceIdentityID, + }); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrCodeUsed + } + return fmt.Errorf("mark bind code used: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("commit bind consume tx: %w", err) + } + + s.logger.Info("bind code consumed", + slog.String("code_id", lockedCode.ID), + slog.String("platform", lockedCode.Platform), + slog.String("channel_identity", sourceIdentityID), + slog.String("target_user", targetUserID), + ) + return nil +} + +func toCode(row sqlc.ChannelIdentityBindCode) Code { + c := Code{ + ID: row.ID.String(), + Token: row.Token, + IssuedByUserID: row.IssuedByUserID.String(), + CreatedAt: row.CreatedAt.Time, + } + if row.ChannelType.Valid { + c.Platform = normalizePlatform(row.ChannelType.String) + } + if row.ExpiresAt.Valid { + c.ExpiresAt = row.ExpiresAt.Time + } + if row.UsedAt.Valid { + c.UsedAt = row.UsedAt.Time + } + if row.UsedByChannelIdentityID.Valid { + c.UsedByChannelIdentityID = row.UsedByChannelIdentityID.String() + } + return c +} + +func isUniqueViolation(err error) bool { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) { + return false + } + if pgErr.Code != "23505" { + return false + } + return pgErr.ConstraintName == "" || pgErr.ConstraintName == "channel_identity_bind_codes_token_unique" +} + +func normalizePlatform(raw string) string { + return strings.ToLower(strings.TrimSpace(raw)) +} diff --git a/internal/bind/service_integration_test.go b/internal/bind/service_integration_test.go new file mode 100644 index 00000000..b6f0f5ec --- /dev/null +++ b/internal/bind/service_integration_test.go @@ -0,0 +1,240 @@ +package bind_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *identities.Service, *bind.Service, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + channelIdentitySvc := identities.NewService(logger, queries) + bindSvc := bind.NewService(logger, pool, queries) + return queries, channelIdentitySvc, bindSvc, func() { pool.Close() } +} + +func createUserForBind(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"}) + if err != nil { + return "", err + } + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, "feishu", fmt.Sprintf("bind-success-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channel identity failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { + t.Fatalf("consume bind code failed: %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if after.UsedAt.IsZero() { + t.Fatal("expected used_at to be set after consume") + } + if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { + t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) + } + + linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != ownerUserID { + t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) + } + + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrCodeUsed) { + t.Fatalf("expected ErrCodeUsed on second consume, got %v", err) + } +} + +func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + otherUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create other user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, "feishu", fmt.Sprintf("bind-rollback-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channel identity failed: %v", err) + } + if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { + t.Fatalf("pre-link source channel identity failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { + t.Fatalf("expected ErrLinkConflict, got %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if !after.UsedAt.IsZero() { + t.Fatal("expected used_at to remain empty when consume fails") + } +} + +func TestIntegrationConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { + t.Fatalf("consume bind code failed: %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if after.UsedAt.IsZero() { + t.Fatal("expected code used_at set after consume") + } + if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { + t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) + } + + linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != ownerUserID { + t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) + } +} + +func TestIntegrationConsumeConflictDoesNotMarkUsed(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + issuerUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create issuer user failed: %v", err) + } + otherUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create other user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-conflict-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { + t.Fatalf("pre-link source channelIdentity failed: %v", err) + } + code, err := bindSvc.Issue(ctx, issuerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { + t.Fatalf("expected ErrLinkConflict, got %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if !after.UsedAt.IsZero() { + t.Fatal("expected code to remain unused after conflict") + } +} diff --git a/internal/bind/service_test.go b/internal/bind/service_test.go new file mode 100644 index 00000000..298a744b --- /dev/null +++ b/internal/bind/service_test.go @@ -0,0 +1,207 @@ +package bind + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func TestParseUUID(t *testing.T) { + tests := []struct { + name string + id string + wantErr bool + }{ + {"empty", "", true}, + {"blank", " ", true}, + {"invalid", "not-a-uuid", true}, + {"valid", "550e8400-e29b-41d4-a716-446655440000", false}, + {"valid with spaces", " 550e8400-e29b-41d4-a716-446655440000 ", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := db.ParseUUID(tt.id) + if (err != nil) != tt.wantErr { + t.Errorf("parseUUID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !got.Valid { + t.Error("expected valid UUID") + } + }) + } +} + +func TestIsUniqueViolation(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"other error", assertAnError(), false}, + {"unique violation token", &pgconn.PgError{Code: "23505", ConstraintName: "channel_identity_bind_codes_token_unique"}, true}, + {"unique violation empty constraint", &pgconn.PgError{Code: "23505", ConstraintName: ""}, true}, + {"wrong code", &pgconn.PgError{Code: "23503", ConstraintName: "some_fk"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isUniqueViolation(tt.err); got != tt.want { + t.Errorf("isUniqueViolation() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNormalizePlatform(t *testing.T) { + tests := []struct { + raw string + want string + }{ + {"", ""}, + {" Feishu ", "feishu"}, + {"TELEGRAM", "telegram"}, + } + for _, tt := range tests { + if got := normalizePlatform(tt.raw); got != tt.want { + t.Errorf("normalizePlatform(%q) = %q, want %q", tt.raw, got, tt.want) + } + } +} + +func TestToCode(t *testing.T) { + pgID, err := db.ParseUUID("550e8400-e29b-41d4-a716-446655440000") + if err != nil { + t.Fatal(err) + } + now := time.Now().UTC() + var usedBy pgtype.UUID + _ = usedBy.Scan("660e8400-e29b-41d4-a716-446655440001") + row := sqlc.ChannelIdentityBindCode{ + ID: pgID, + Token: "ABC12345", + IssuedByUserID: pgID, + ChannelType: pgtype.Text{String: " Feishu ", Valid: true}, + ExpiresAt: pgtype.Timestamptz{Time: now, Valid: true}, + UsedAt: pgtype.Timestamptz{Time: now, Valid: true}, + UsedByChannelIdentityID: usedBy, + CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, + } + + c := toCode(row) + if c.Token != "ABC12345" { + t.Errorf("Token = %q", c.Token) + } + if c.Platform != "feishu" { + t.Errorf("Platform = %q (normalized)", c.Platform) + } + if c.IssuedByUserID != "550e8400-e29b-41d4-a716-446655440000" { + t.Errorf("IssuedByUserID = %q", c.IssuedByUserID) + } + if c.ExpiresAt.IsZero() { + t.Error("ExpiresAt should be set") + } + if c.UsedAt.IsZero() { + t.Error("UsedAt should be set") + } + if c.CreatedAt.IsZero() { + t.Error("CreatedAt should be set") + } +} + +func TestToCode_OptionalFields(t *testing.T) { + pgID, err := db.ParseUUID("550e8400-e29b-41d4-a716-446655440000") + if err != nil { + t.Fatal(err) + } + now := time.Now().UTC() + row := sqlc.ChannelIdentityBindCode{ + ID: pgID, + Token: "TOKEN", + IssuedByUserID: pgID, + ChannelType: pgtype.Text{Valid: false}, + ExpiresAt: pgtype.Timestamptz{Valid: false}, + UsedAt: pgtype.Timestamptz{Valid: false}, + CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, + } + c := toCode(row) + if c.Platform != "" { + t.Errorf("Platform should be empty, got %q", c.Platform) + } + if !c.ExpiresAt.IsZero() { + t.Error("ExpiresAt should be zero") + } + if !c.UsedAt.IsZero() { + t.Error("UsedAt should be zero") + } +} + +func assertAnError() error { + return errForTest +} + +var errForTest = errTyp{msg: "test error"} + +type errTyp struct{ msg string } + +func (e errTyp) Error() string { return e.msg } + +func TestService_Issue_NilQueries(t *testing.T) { + svc := NewService(nil, nil, nil) + _, err := svc.Issue(context.Background(), "550e8400-e29b-41d4-a716-446655440000", "feishu", time.Hour) + if err == nil { + t.Fatal("expected error when queries nil") + } +} + +func TestService_Issue_InvalidUserID(t *testing.T) { + svc := NewService(nil, nil, nil) + _, err := svc.Issue(context.Background(), "invalid", "feishu", time.Hour) + if err == nil { + t.Fatal("expected error for invalid user id") + } +} + +func TestService_Get_NilQueries(t *testing.T) { + svc := NewService(nil, nil, nil) + _, err := svc.Get(context.Background(), "TOKEN") + if err == nil { + t.Fatal("expected error when queries nil") + } +} + +func TestService_Consume_NilConfig(t *testing.T) { + svc := NewService(nil, nil, nil) + code := Code{Token: "ABC", IssuedByUserID: "550e8400-e29b-41d4-a716-446655440000"} + err := svc.Consume(context.Background(), code, "660e8400-e29b-41d4-a716-446655440001") + if err == nil { + t.Fatal("expected error when service not configured") + } +} + +// Consume fast-path (CodeUsed, CodeExpired, EmptyToken) runs after nil check; covered by integration tests. + +func TestService_Consume_EmptyChannelIdentityID(t *testing.T) { + svc := NewService(nil, nil, nil) + code := Code{Token: "ABC", IssuedByUserID: "550e8400-e29b-41d4-a716-446655440000"} + err := svc.Consume(context.Background(), code, "") + if err == nil { + t.Fatal("expected error when channel identity id empty") + } +} + +func TestService_Consume_InvalidChannelIdentityID(t *testing.T) { + svc := NewService(nil, nil, nil) + code := Code{Token: "ABC", IssuedByUserID: "550e8400-e29b-41d4-a716-446655440000"} + err := svc.Consume(context.Background(), code, "not-a-uuid") + if err == nil { + t.Fatal("expected error for invalid channel identity id") + } +} + diff --git a/internal/bind/types.go b/internal/bind/types.go new file mode 100644 index 00000000..0f5182b0 --- /dev/null +++ b/internal/bind/types.go @@ -0,0 +1,26 @@ +package bind + +import ( + "errors" + "time" +) + +var ( + ErrCodeNotFound = errors.New("bind code not found") + ErrCodeUsed = errors.New("bind code already used") + ErrCodeExpired = errors.New("bind code expired") + ErrCodeMismatch = errors.New("bind code context mismatch") + ErrLinkConflict = errors.New("channel identity user link conflict") +) + +// Code represents a one-time bind code for linking channel identity to user. +type Code struct { + ID string `json:"id"` + Platform string `json:"platform,omitempty"` + Token string `json:"token"` + IssuedByUserID string `json:"issued_by_user_id"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + UsedAt time.Time `json:"used_at,omitempty"` + UsedByChannelIdentityID string `json:"used_by_channel_identity_id,omitempty"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/internal/bots/service.go b/internal/bots/service.go index 567115af..4c3dc3b6 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "os" "strings" "time" @@ -13,24 +14,33 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) +// Service provides bot CRUD and membership management. type Service struct { queries *sqlc.Queries logger *slog.Logger containerLifecycle ContainerLifecycle } -var ( - ErrBotNotFound = errors.New("bot not found") - ErrBotAccessDenied = errors.New("bot access denied") +const ( + botLifecycleOperationTimeout = 5 * time.Minute ) +var ( + ErrBotNotFound = errors.New("bot not found") + ErrBotAccessDenied = errors.New("bot access denied") + ErrOwnerUserNotFound = errors.New("owner user not found") +) + +// AccessPolicy controls bot access behavior. type AccessPolicy struct { AllowPublicMember bool } +// NewService creates a new bot service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { if log == nil { log = slog.Default() @@ -46,7 +56,8 @@ func (s *Service) SetContainerLifecycle(lc ContainerLifecycle) { s.containerLifecycle = lc } -func (s *Service) AuthorizeAccess(ctx context.Context, actorID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { +// AuthorizeAccess checks whether userID may access the given bot. +func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } @@ -57,17 +68,18 @@ func (s *Service) AuthorizeAccess(ctx context.Context, actorID, botID string, is } return Bot{}, err } - if isAdmin || bot.OwnerUserID == actorID { + if isAdmin || bot.OwnerUserID == userID { return bot, nil } if policy.AllowPublicMember && bot.Type == BotTypePublic { - if _, err := s.GetMember(ctx, botID, actorID); err == nil { + if _, err := s.GetMember(ctx, botID, userID); err == nil { return bot, nil } } return Bot{}, ErrBotAccessDenied } +// Create creates a new bot owned by owner user. func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotRequest) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") @@ -76,10 +88,13 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR if ownerID == "" { return Bot{}, fmt.Errorf("owner user id is required") } - ownerUUID, err := parseUUID(ownerID) + ownerUUID, err := db.ParseUUID(ownerID) if err != nil { return Bot{}, err } + if err := s.ensureUserExists(ctx, ownerUUID); err != nil { + return Bot{}, err + } normalizedType, err := normalizeBotType(req.Type) if err != nil { return Bot{}, err @@ -108,6 +123,7 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR AvatarUrl: pgtype.Text{String: avatarURL, Valid: avatarURL != ""}, IsActive: isActive, Metadata: payload, + Status: BotStatusCreating, }) if err != nil { return Bot{}, err @@ -116,22 +132,19 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR if err != nil { return Bot{}, err } - if s.containerLifecycle != nil { - if err := s.containerLifecycle.SetupBotContainer(ctx, bot.ID); err != nil { - s.logger.Error("failed to setup bot container", - slog.String("bot_id", bot.ID), - slog.Any("error", err), - ) - } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err } + s.enqueueCreateLifecycle(bot.ID) return bot, nil } +// Get returns a bot by its ID. func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return Bot{}, err } @@ -139,14 +152,22 @@ func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { if err != nil { return Bot{}, err } - return toBot(row) + bot, err := toBot(row) + if err != nil { + return Bot{}, err + } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err + } + return bot, nil } +// ListByOwner returns bots owned by the given user. func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, error) { if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") } - ownerUUID, err := parseUUID(ownerUserID) + ownerUUID, err := db.ParseUUID(ownerUserID) if err != nil { return nil, err } @@ -160,20 +181,24 @@ func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, e if err != nil { return nil, err } + if err := s.attachCheckSummary(ctx, &item, row); err != nil { + return nil, err + } items = append(items, item) } return items, nil } -func (s *Service) ListByMember(ctx context.Context, userID string) ([]Bot, error) { +// ListByMember returns bots where the user is a member. +func (s *Service) ListByMember(ctx context.Context, channelIdentityID string) ([]Bot, error) { if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") } - userUUID, err := parseUUID(userID) + memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { return nil, err } - rows, err := s.queries.ListBotsByMember(ctx, userUUID) + rows, err := s.queries.ListBotsByMember(ctx, memberUUID) if err != nil { return nil, err } @@ -183,17 +208,21 @@ func (s *Service) ListByMember(ctx context.Context, userID string) ([]Bot, error if err != nil { return nil, err } + if err := s.attachCheckSummary(ctx, &item, row); err != nil { + return nil, err + } items = append(items, item) } return items, nil } -func (s *Service) ListAccessible(ctx context.Context, userID string) ([]Bot, error) { - owned, err := s.ListByOwner(ctx, userID) +// ListAccessible returns all bots the user can access (owned or member). +func (s *Service) ListAccessible(ctx context.Context, channelIdentityID string) ([]Bot, error) { + owned, err := s.ListByOwner(ctx, channelIdentityID) if err != nil { return nil, err } - members, err := s.ListByMember(ctx, userID) + members, err := s.ListByMember(ctx, channelIdentityID) if err != nil { return nil, err } @@ -213,11 +242,12 @@ func (s *Service) ListAccessible(ctx context.Context, userID string) ([]Bot, err return items, nil } +// Update updates bot profile fields. func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return Bot{}, err } @@ -261,21 +291,32 @@ func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest if err != nil { return Bot{}, err } - return toBot(row) + bot, err := toBot(row) + if err != nil { + return Bot{}, err + } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err + } + return bot, nil } +// TransferOwner transfers bot ownership to another user. func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID string) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return Bot{}, err } - ownerUUID, err := parseUUID(ownerUserID) + ownerUUID, err := db.ParseUUID(ownerUserID) if err != nil { return Bot{}, err } + if err := s.ensureUserExists(ctx, ownerUUID); err != nil { + return Bot{}, err + } row, err := s.queries.UpdateBotOwner(ctx, sqlc.UpdateBotOwnerParams{ ID: botUUID, OwnerUserID: ownerUUID, @@ -283,40 +324,153 @@ func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID s if err != nil { return Bot{}, err } - return toBot(row) + bot, err := toBot(row) + if err != nil { + return Bot{}, err + } + if err := s.attachCheckSummary(ctx, &bot, row); err != nil { + return Bot{}, err + } + return bot, nil } +// Delete removes a bot and its associated resources. func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { return fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return err } - if _, err := s.queries.GetBotByID(ctx, botUUID); err != nil { + row, err := s.queries.GetBotByID(ctx, botUUID) + if err != nil { return err } - if s.containerLifecycle != nil { - if err := s.containerLifecycle.CleanupBotContainer(ctx, botID); err != nil { - s.logger.Error("failed to cleanup bot container", + if strings.TrimSpace(row.Status) == BotStatusDeleting { + return nil + } + if err := s.queries.UpdateBotStatus(ctx, sqlc.UpdateBotStatusParams{ + ID: botUUID, + Status: BotStatusDeleting, + }); err != nil { + return err + } + s.enqueueDeleteLifecycle(botID) + return nil +} + +// ListChecks evaluates runtime resource checks for a bot. +func (s *Service) ListChecks(ctx context.Context, botID string) ([]BotCheck, error) { + if s.queries == nil { + return nil, fmt.Errorf("bot queries not configured") + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return nil, err + } + row, err := s.queries.GetBotByID(ctx, botUUID) + if err != nil { + return nil, err + } + return s.buildRuntimeChecks(ctx, row) +} + +func (s *Service) enqueueCreateLifecycle(botID string) { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), botLifecycleOperationTimeout) + defer cancel() + + if s.containerLifecycle != nil { + if err := s.containerLifecycle.SetupBotContainer(ctx, botID); err != nil { + s.logger.Error("bot container setup failed", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + } + } + + if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil { + s.logger.Error("failed to update bot status to ready after create", slog.String("bot_id", botID), slog.Any("error", err), ) } - } - return s.queries.DeleteBotByID(ctx, botUUID) + }() } +func (s *Service) enqueueDeleteLifecycle(botID string) { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), botLifecycleOperationTimeout) + defer cancel() + + if s.containerLifecycle != nil { + if err := s.containerLifecycle.CleanupBotContainer(ctx, botID); err != nil { + s.logger.Error("bot container cleanup failed", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + } + } + + botUUID, err := db.ParseUUID(botID) + if err != nil { + s.logger.Error("invalid bot id while finalizing delete", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + _ = s.updateStatus(ctx, botID, BotStatusReady) + return + } + if err := s.queries.DeleteBotByID(ctx, botUUID); err != nil { + s.logger.Error("failed to delete bot after cleanup", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + _ = s.updateStatus(ctx, botID, BotStatusReady) + return + } + }() +} + +func (s *Service) updateStatus(ctx context.Context, botID, status string) error { + if s.queries == nil { + return fmt.Errorf("bot queries not configured") + } + botUUID, err := db.ParseUUID(botID) + if err != nil { + return err + } + return s.queries.UpdateBotStatus(ctx, sqlc.UpdateBotStatusParams{ + ID: botUUID, + Status: strings.TrimSpace(status), + }) +} + +func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) error { + if s.queries == nil { + return fmt.Errorf("bot queries not configured") + } + _, err := s.queries.GetUserByID(ctx, userID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrOwnerUserNotFound + } + return err + } + return nil +} + +// UpsertMember creates or updates a bot membership. func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemberRequest) (BotMember, error) { if s.queries == nil { return BotMember{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return BotMember{}, err } - userUUID, err := parseUUID(req.UserID) + memberUUID, err := db.ParseUUID(req.UserID) if err != nil { return BotMember{}, err } @@ -326,7 +480,7 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb } row, err := s.queries.UpsertBotMember(ctx, sqlc.UpsertBotMemberParams{ BotID: botUUID, - UserID: userUUID, + UserID: memberUUID, Role: role, }) if err != nil { @@ -335,11 +489,12 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb return toBotMember(row), nil } +// ListMembers returns all members of a bot. func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, error) { if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return nil, err } @@ -354,21 +509,22 @@ func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, e return items, nil } -func (s *Service) GetMember(ctx context.Context, botID, userID string) (BotMember, error) { +// GetMember returns a specific bot member. +func (s *Service) GetMember(ctx context.Context, botID, channelIdentityID string) (BotMember, error) { if s.queries == nil { return BotMember{}, fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return BotMember{}, err } - userUUID, err := parseUUID(userID) + memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { return BotMember{}, err } row, err := s.queries.GetBotMember(ctx, sqlc.GetBotMemberParams{ BotID: botUUID, - UserID: userUUID, + UserID: memberUUID, }) if err != nil { return BotMember{}, err @@ -376,26 +532,52 @@ func (s *Service) GetMember(ctx context.Context, botID, userID string) (BotMembe return toBotMember(row), nil } -func (s *Service) DeleteMember(ctx context.Context, botID, userID string) error { +// DeleteMember removes a member from a bot. +func (s *Service) DeleteMember(ctx context.Context, botID, channelIdentityID string) error { if s.queries == nil { return fmt.Errorf("bot queries not configured") } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return err } - userUUID, err := parseUUID(userID) + memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { return err } return s.queries.DeleteBotMember(ctx, sqlc.DeleteBotMemberParams{ BotID: botUUID, - UserID: userUUID, + UserID: memberUUID, }) } +// UpsertMemberSimple creates or updates a bot membership with a direct channel identity ID and role. +// This satisfies the router.BotMemberService interface. +func (s *Service) UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error { + _, err := s.UpsertMember(ctx, botID, UpsertMemberRequest{ + UserID: channelIdentityID, + Role: role, + }) + return err +} + +// IsMember checks if a user is a member of a bot. +func (s *Service) IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) { + _, err := s.GetMember(ctx, botID, channelIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, nil + } + return false, err + } + return true, nil +} + func normalizeBotType(raw string) (string, error) { normalized := strings.ToLower(strings.TrimSpace(raw)) + if normalized == "" { + return BotTypePersonal, nil + } switch normalized { case BotTypePersonal, BotTypePublic: return normalized, nil @@ -439,15 +621,18 @@ func toBot(row sqlc.Bot) (Bot, error) { updatedAt = row.UpdatedAt.Time } return Bot{ - ID: toUUIDString(row.ID), - OwnerUserID: toUUIDString(row.OwnerUserID), - Type: row.Type, - DisplayName: displayName, - AvatarURL: avatarURL, - IsActive: row.IsActive, - Metadata: metadata, - CreatedAt: createdAt, - UpdatedAt: updatedAt, + ID: row.ID.String(), + OwnerUserID: row.OwnerUserID.String(), + Type: row.Type, + DisplayName: displayName, + AvatarURL: avatarURL, + IsActive: row.IsActive, + Status: strings.TrimSpace(row.Status), + CheckState: BotCheckStateUnknown, + CheckIssueCount: 0, + Metadata: metadata, + CreatedAt: createdAt, + UpdatedAt: updatedAt, }, nil } @@ -457,8 +642,8 @@ func toBotMember(row sqlc.BotMember) BotMember { createdAt = row.CreatedAt.Time } return BotMember{ - BotID: toUUIDString(row.BotID), - UserID: toUUIDString(row.UserID), + BotID: row.BotID.String(), + UserID: row.UserID.String(), Role: row.Role, CreatedAt: createdAt, } @@ -478,24 +663,197 @@ func decodeMetadata(payload []byte) (map[string]any, error) { return data, nil } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) +func (s *Service) attachCheckSummary(ctx context.Context, bot *Bot, row sqlc.Bot) error { + checks, err := s.buildRuntimeChecks(ctx, row) if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) + return err } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil + checkState, issueCount := summarizeChecks(checks) + bot.CheckState = checkState + bot.CheckIssueCount = issueCount + return nil } -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" +func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot) ([]BotCheck, error) { + status := strings.TrimSpace(row.Status) + checks := make([]BotCheck, 0, 4) + + if status == BotStatusCreating { + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerInit, + Status: BotCheckStatusUnknown, + Summary: "Initialization is in progress.", + Detail: "Bot resources are still being provisioned.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusUnknown, + Summary: "Container record is pending.", + Detail: "Container record will be checked after initialization.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusUnknown, + Summary: "Container task state is pending.", + Detail: "Task state will be checked after initialization.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusUnknown, + Summary: "Container host path check is pending.", + Detail: "Data path will be checked after initialization.", + }) + return checks, nil } - parsed, err := uuid.FromBytes(value.Bytes[:]) + if status == BotStatusDeleting { + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyDelete, + Status: BotCheckStatusUnknown, + Summary: "Deletion is in progress.", + Detail: "Bot resources are being cleaned up.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusUnknown, + Summary: "Container record check is skipped.", + Detail: "Bot is deleting and container checks are paused.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusUnknown, + Summary: "Container task check is skipped.", + Detail: "Bot is deleting and task checks are paused.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusUnknown, + Summary: "Container host path check is skipped.", + Detail: "Bot is deleting and data path checks are paused.", + }) + return checks, nil + } + + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerInit, + Status: BotCheckStatusOK, + Summary: "Initialization finished.", + }) + + containerRow, err := s.queries.GetContainerByBotID(ctx, row.ID) if err != nil { - return "" + if errors.Is(err, pgx.ErrNoRows) { + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusError, + Summary: "Container record is missing.", + Detail: "No container is attached to this bot.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusUnknown, + Summary: "Container task state is unknown.", + Detail: "Task state cannot be determined without a container record.", + }) + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusUnknown, + Summary: "Container data path is unknown.", + Detail: "Data path cannot be determined without a container record.", + }) + return checks, nil + } + return nil, err } - return parsed.String() + + checks = append(checks, BotCheck{ + CheckKey: BotCheckKeyContainerRecord, + Status: BotCheckStatusOK, + Summary: "Container record exists.", + Detail: fmt.Sprintf("container_id=%s", strings.TrimSpace(containerRow.ContainerID)), + Metadata: map[string]any{ + "container_id": strings.TrimSpace(containerRow.ContainerID), + "namespace": strings.TrimSpace(containerRow.Namespace), + "image": strings.TrimSpace(containerRow.Image), + }, + }) + + taskStatus := strings.TrimSpace(strings.ToLower(containerRow.Status)) + taskCheck := BotCheck{ + CheckKey: BotCheckKeyContainerTask, + Status: BotCheckStatusWarn, + Summary: "Container task state needs attention.", + } + switch taskStatus { + case "running", "created", "stopped", "paused": + taskCheck.Status = BotCheckStatusOK + taskCheck.Summary = "Container task state is reported." + taskCheck.Detail = fmt.Sprintf("status=%s", taskStatus) + case "": + taskCheck.Detail = "status is empty" + default: + taskCheck.Detail = fmt.Sprintf("unexpected status=%s", taskStatus) + } + taskCheck.Metadata = map[string]any{"status": taskStatus} + checks = append(checks, taskCheck) + + hostPath := "" + if containerRow.HostPath.Valid { + hostPath = strings.TrimSpace(containerRow.HostPath.String) + } + dataCheck := BotCheck{ + CheckKey: BotCheckKeyContainerData, + Status: BotCheckStatusWarn, + Summary: "Container host path needs attention.", + Metadata: map[string]any{"host_path": hostPath}, + } + if hostPath == "" { + dataCheck.Detail = "host path is empty" + checks = append(checks, dataCheck) + return checks, nil + } + info, statErr := os.Stat(hostPath) + switch { + case statErr == nil && info != nil && info.IsDir(): + dataCheck.Status = BotCheckStatusOK + dataCheck.Summary = "Container host path is accessible." + dataCheck.Detail = hostPath + case statErr == nil: + dataCheck.Status = BotCheckStatusError + dataCheck.Summary = "Container host path is invalid." + dataCheck.Detail = "host path is not a directory" + case errors.Is(statErr, os.ErrNotExist): + dataCheck.Status = BotCheckStatusError + dataCheck.Summary = "Container host path does not exist." + dataCheck.Detail = hostPath + default: + dataCheck.Status = BotCheckStatusWarn + dataCheck.Summary = "Container host path cannot be checked." + dataCheck.Detail = statErr.Error() + } + checks = append(checks, dataCheck) + + return checks, nil +} + +func summarizeChecks(checks []BotCheck) (string, int32) { + if len(checks) == 0 { + return BotCheckStateUnknown, 0 + } + var issueCount int32 + unknownCount := 0 + for _, check := range checks { + switch check.Status { + case BotCheckStatusWarn, BotCheckStatusError: + issueCount++ + case BotCheckStatusUnknown: + unknownCount++ + } + } + if issueCount > 0 { + return BotCheckStateIssue, issueCount + } + if unknownCount == len(checks) { + return BotCheckStateUnknown, 0 + } + return BotCheckStateOK, 0 } diff --git a/internal/bots/types.go b/internal/bots/types.go index 1a8c9294..e002524c 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -5,18 +5,23 @@ import ( "time" ) +// Bot represents a bot entity. type Bot struct { - ID string `json:"id" validate:"required"` - OwnerUserID string `json:"owner_user_id" validate:"required"` - Type string `json:"type" validate:"required"` - DisplayName string `json:"display_name" validate:"required"` - AvatarURL string `json:"avatar_url,omitempty"` - IsActive bool `json:"is_active" validate:"required"` - Metadata map[string]any `json:"metadata,omitempty"` - CreatedAt time.Time `json:"created_at" validate:"required"` - UpdatedAt time.Time `json:"updated_at" validate:"required"` + ID string `json:"id"` + OwnerUserID string `json:"owner_user_id"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url,omitempty"` + IsActive bool `json:"is_active"` + Status string `json:"status"` + CheckState string `json:"check_state"` + CheckIssueCount int32 `json:"check_issue_count"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } +// BotMember represents a bot membership record. type BotMember struct { BotID string `json:"bot_id"` UserID string `json:"user_id"` @@ -24,6 +29,16 @@ type BotMember struct { CreatedAt time.Time `json:"created_at"` } +// BotCheck represents one resource check row for a bot. +type BotCheck struct { + CheckKey string `json:"check_key"` + Status string `json:"status"` + Summary string `json:"summary"` + Detail string `json:"detail,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// CreateBotRequest is the input for creating a bot. type CreateBotRequest struct { Type string `json:"type"` DisplayName string `json:"display_name,omitempty"` @@ -32,6 +47,7 @@ type CreateBotRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } +// UpdateBotRequest is the input for updating a bot. type UpdateBotRequest struct { DisplayName *string `json:"display_name,omitempty"` AvatarURL *string `json:"avatar_url,omitempty"` @@ -39,23 +55,32 @@ type UpdateBotRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } +// TransferBotRequest is the input for transferring bot ownership. type TransferBotRequest struct { OwnerUserID string `json:"owner_user_id"` } +// UpsertMemberRequest is the input for upserting a bot member. type UpsertMemberRequest struct { UserID string `json:"user_id"` Role string `json:"role,omitempty"` } +// ListBotsResponse wraps a list of bots. type ListBotsResponse struct { - Items []Bot `json:"items" validate:"required"` + Items []Bot `json:"items"` } +// ListMembersResponse wraps a list of bot members. type ListMembersResponse struct { Items []BotMember `json:"items"` } +// ListChecksResponse wraps a list of bot checks. +type ListChecksResponse struct { + Items []BotCheck `json:"items"` +} + // ContainerLifecycle handles container lifecycle events bound to bot operations. type ContainerLifecycle interface { SetupBotContainer(ctx context.Context, botID string) error @@ -67,6 +92,33 @@ const ( BotTypePublic = "public" ) +const ( + BotStatusCreating = "creating" + BotStatusReady = "ready" + BotStatusDeleting = "deleting" +) + +const ( + BotCheckStateOK = "ok" + BotCheckStateIssue = "issue" + BotCheckStateUnknown = "unknown" +) + +const ( + BotCheckStatusOK = "ok" + BotCheckStatusWarn = "warn" + BotCheckStatusError = "error" + BotCheckStatusUnknown = "unknown" +) + +const ( + BotCheckKeyContainerInit = "container.init" + BotCheckKeyContainerRecord = "container.record" + BotCheckKeyContainerTask = "container.task" + BotCheckKeyContainerData = "container.data_path" + BotCheckKeyDelete = "bot.delete" +) + const ( MemberRoleOwner = "owner" MemberRoleAdmin = "admin" diff --git a/internal/channel/adapter.go b/internal/channel/adapter.go index 9cb5f6b7..0b12cf07 100644 --- a/internal/channel/adapter.go +++ b/internal/channel/adapter.go @@ -12,9 +12,42 @@ var ErrStopNotSupported = errors.New("channel connection stop not supported") // InboundHandler is a callback invoked when a message arrives from a channel. type InboundHandler func(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error -// ReplySender sends an outbound reply within the scope of a single inbound message. -type ReplySender interface { +// StreamReplySender sends replies within a single inbound-processing scope. +// It supports both one-shot delivery and streaming sessions. +type StreamReplySender interface { Send(ctx context.Context, msg OutboundMessage) error + OpenStream(ctx context.Context, target string, opts StreamOptions) (OutboundStream, error) +} + +// OutboundStream is a live stream session for emitting outbound events. +type OutboundStream interface { + Push(ctx context.Context, event StreamEvent) error + Close(ctx context.Context) error +} + +// ProcessingStatusInfo carries context for channel-level processing status updates. +type ProcessingStatusInfo struct { + BotID string + ChatID string + RouteID string + ChannelIdentityID string + UserID string + Query string + ReplyTarget string + SourceMessageID string +} + +// ProcessingStatusHandle stores channel-specific state between status callbacks. +type ProcessingStatusHandle struct { + Token string +} + +// ProcessingStatusNotifier reports processing lifecycle updates to channel platforms. +// Implementations should be best-effort and idempotent. +type ProcessingStatusNotifier interface { + ProcessingStarted(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo) (ProcessingStatusHandle, error) + ProcessingCompleted(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle) error + ProcessingFailed(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle, cause error) error } // Adapter is the base interface every channel adapter must implement. @@ -59,6 +92,17 @@ type Sender interface { Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error } +// StreamSender is an adapter capable of opening outbound stream sessions. +type StreamSender interface { + OpenStream(ctx context.Context, cfg ChannelConfig, target string, opts StreamOptions) (OutboundStream, error) +} + +// MessageEditor updates and deletes already-sent messages when supported. +type MessageEditor interface { + Update(ctx context.Context, cfg ChannelConfig, target string, messageID string, msg Message) error + Unsend(ctx context.Context, cfg ChannelConfig, target string, messageID string) error +} + // Receiver is an adapter capable of establishing a long-lived connection to receive messages. type Receiver interface { Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go index d7def25a..74133ed4 100644 --- a/internal/channel/adapters/feishu/config.go +++ b/internal/channel/adapters/feishu/config.go @@ -79,8 +79,8 @@ func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { return true } - if criteria.ExternalID != "" { - if criteria.ExternalID == cfg.OpenID || criteria.ExternalID == cfg.UserID { + if criteria.SubjectID != "" { + if criteria.SubjectID == cfg.OpenID || criteria.SubjectID == cfg.UserID { return true } } diff --git a/internal/channel/adapters/feishu/directory.go b/internal/channel/adapters/feishu/directory.go new file mode 100644 index 00000000..8493eed9 --- /dev/null +++ b/internal/channel/adapters/feishu/directory.go @@ -0,0 +1,298 @@ +package feishu + +import ( + "context" + "fmt" + "strings" + + lark "github.com/larksuite/oapi-sdk-go/v3" + larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + defaultDirectoryPageSize = 20 + maxDirectoryPageSize = 200 +) + +func directoryLimit(n int) int { + if n <= 0 { + return defaultDirectoryPageSize + } + if n > maxDirectoryPageSize { + return maxDirectoryPageSize + } + return n +} + +// ListPeers lists users (peers) from Feishu contact, optionally filtered by query. +func (a *FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + pageSize := directoryLimit(query.Limit) + req := larkcontact.NewListUserReqBuilder(). + UserIdType(larkcontact.UserIdTypeOpenId). + DepartmentIdType(larkcontact.DepartmentIdTypeOpenDepartmentId). + DepartmentId("0"). + PageSize(pageSize). + Build() + resp, err := client.Contact.User.List(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu list users: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu list users: code=%d msg=%s", resp.Code, resp.Msg) + } + entries := make([]channel.DirectoryEntry, 0, len(resp.Data.Items)) + for _, u := range resp.Data.Items { + e := feishuUserToEntry(u) + if query.Query != "" && !strings.Contains(strings.ToLower(e.Name+e.Handle), strings.ToLower(query.Query)) { + continue + } + entries = append(entries, e) + } + return entries, nil +} + +// ListGroups lists chat groups from Feishu IM, optionally filtered by query. +func (a *FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + pageSize := directoryLimit(query.Limit) + var items []*larkim.ListChat + if strings.TrimSpace(query.Query) != "" { + req := larkim.NewSearchChatReqBuilder(). + UserIdType("open_id"). + Query(strings.TrimSpace(query.Query)). + PageSize(pageSize). + Build() + resp, err := client.Im.Chat.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu search chats: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu search chats: code=%d msg=%s", resp.Code, resp.Msg) + } + items = resp.Data.Items + } else { + req := larkim.NewListChatReqBuilder(). + UserIdType("open_id"). + PageSize(pageSize). + Build() + resp, err := client.Im.Chat.List(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu list chats: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu list chats: code=%d msg=%s", resp.Code, resp.Msg) + } + items = resp.Data.Items + } + entries := make([]channel.DirectoryEntry, 0, len(items)) + for _, c := range items { + entries = append(entries, feishuChatToEntry(c)) + } + return entries, nil +} + +// ListGroupMembers lists members of a Feishu chat group. +func (a *FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + chatID := strings.TrimSpace(groupID) + if strings.HasPrefix(chatID, "chat_id:") { + chatID = strings.TrimPrefix(chatID, "chat_id:") + } + if chatID == "" { + return nil, fmt.Errorf("feishu list group members: empty group id") + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + pageSize := directoryLimit(query.Limit) + req := larkim.NewGetChatMembersReqBuilder(). + ChatId(chatID). + MemberIdType("open_id"). + PageSize(pageSize). + Build() + resp, err := client.Im.ChatMembers.Get(ctx, req) + if err != nil { + return nil, fmt.Errorf("feishu get chat members: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu get chat members: code=%d msg=%s", resp.Code, resp.Msg) + } + entries := make([]channel.DirectoryEntry, 0, len(resp.Data.Items)) + for _, m := range resp.Data.Items { + e := feishuMemberToEntry(m) + if query.Query != "" && !strings.Contains(strings.ToLower(e.Name+e.Handle), strings.ToLower(query.Query)) { + continue + } + entries = append(entries, e) + } + return entries, nil +} + +// ResolveEntry resolves an input string to a user or group DirectoryEntry. +func (a *FeishuAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return channel.DirectoryEntry{}, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + input = strings.TrimSpace(input) + switch kind { + case channel.DirectoryEntryUser: + return a.resolveUser(ctx, client, input) + case channel.DirectoryEntryGroup: + return a.resolveGroup(ctx, client, input) + default: + return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry: unsupported kind %q", kind) + } +} + +func (a *FeishuAdapter) resolveUser(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { + userID, userIDType := parseFeishuUserInput(input) + if userID == "" { + return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry user: invalid input %q", input) + } + req := larkcontact.NewGetUserReqBuilder(). + UserId(userID). + UserIdType(userIDType). + Build() + resp, err := client.Contact.User.Get(ctx, req) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: %w", err) + } + if !resp.Success() { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: code=%d msg=%s", resp.Code, resp.Msg) + } + if resp.Data == nil || resp.Data.User == nil { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: empty response") + } + return feishuUserToEntry(resp.Data.User), nil +} + +func (a *FeishuAdapter) resolveGroup(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { + chatID := strings.TrimSpace(input) + if strings.HasPrefix(chatID, "chat_id:") { + chatID = strings.TrimPrefix(chatID, "chat_id:") + } + if chatID == "" { + return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry group: invalid input %q", input) + } + req := larkim.NewGetChatReqBuilder(). + ChatId(chatID). + UserIdType("open_id"). + Build() + resp, err := client.Im.Chat.Get(ctx, req) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get chat: %w", err) + } + if !resp.Success() { + return channel.DirectoryEntry{}, fmt.Errorf("feishu get chat: code=%d msg=%s", resp.Code, resp.Msg) + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryGroup, + ID: "chat_id:" + chatID, + Name: ptrStr(resp.Data.Name), + AvatarURL: ptrStr(resp.Data.Avatar), + Metadata: map[string]any{"chat_id": chatID}, + }, nil +} + +func parseFeishuUserInput(raw string) (userID, userIDType string) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "" + } + if strings.HasPrefix(raw, "open_id:") { + return strings.TrimSpace(strings.TrimPrefix(raw, "open_id:")), larkcontact.UserIdTypeOpenId + } + if strings.HasPrefix(raw, "user_id:") { + return strings.TrimSpace(strings.TrimPrefix(raw, "user_id:")), larkcontact.UserIdTypeUserId + } + if strings.HasPrefix(raw, "ou_") { + return raw, larkcontact.UserIdTypeOpenId + } + if strings.HasPrefix(raw, "u_") || strings.HasPrefix(raw, "u-") { + return raw, larkcontact.UserIdTypeUserId + } + return raw, larkcontact.UserIdTypeOpenId +} + +func feishuUserToEntry(u *larkcontact.User) channel.DirectoryEntry { + openID := ptrStr(u.OpenId) + userID := ptrStr(u.UserId) + id := "open_id:" + openID + if openID == "" && userID != "" { + id = "user_id:" + userID + } + meta := make(map[string]any) + if u.OpenId != nil { + meta["open_id"] = *u.OpenId + } + if u.UserId != nil { + meta["user_id"] = *u.UserId + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: id, + Name: ptrStr(u.Name), + Handle: ptrStr(u.Nickname), + AvatarURL: feishuAvatarURL(u.Avatar), + Metadata: meta, + } +} + +func feishuChatToEntry(c *larkim.ListChat) channel.DirectoryEntry { + chatID := ptrStr(c.ChatId) + meta := map[string]any{"chat_id": chatID} + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryGroup, + ID: "chat_id:" + chatID, + Name: ptrStr(c.Name), + AvatarURL: ptrStr(c.Avatar), + Metadata: meta, + } +} + +func feishuMemberToEntry(m *larkim.ListMember) channel.DirectoryEntry { + id := ptrStr(m.MemberId) + meta := make(map[string]any) + if m.MemberIdType != nil { + meta["member_id_type"] = *m.MemberIdType + } + prefix := "open_id:" + if m.MemberIdType != nil && *m.MemberIdType == "user_id" { + prefix = "user_id:" + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: prefix + id, + Name: ptrStr(m.Name), + Metadata: meta, + } +} + +func ptrStr(s *string) string { + if s == nil { + return "" + } + return strings.TrimSpace(*s) +} + +func feishuAvatarURL(avatar *larkcontact.AvatarInfo) string { + if avatar == nil || avatar.Avatar72 == nil { + return "" + } + return strings.TrimSpace(*avatar.Avatar72) +} diff --git a/internal/channel/adapters/feishu/directory_test.go b/internal/channel/adapters/feishu/directory_test.go new file mode 100644 index 00000000..caedadc0 --- /dev/null +++ b/internal/channel/adapters/feishu/directory_test.go @@ -0,0 +1,118 @@ +package feishu + +import ( + "testing" + + larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func Test_directoryLimit(t *testing.T) { + tests := []struct { + name string + n int + want int + }{ + {"zero", 0, defaultDirectoryPageSize}, + {"negative", -1, defaultDirectoryPageSize}, + {"one", 1, 1}, + {"default", defaultDirectoryPageSize, defaultDirectoryPageSize}, + {"over max", maxDirectoryPageSize + 100, maxDirectoryPageSize}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := directoryLimit(tt.n); got != tt.want { + t.Errorf("directoryLimit() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseFeishuUserInput(t *testing.T) { + tests := []struct { + raw string + wantID string + wantIDType string + }{ + {"open_id:ou_xxx", "ou_xxx", larkcontact.UserIdTypeOpenId}, + {"user_id:u_yyy", "u_yyy", larkcontact.UserIdTypeUserId}, + {"ou_abc", "ou_abc", larkcontact.UserIdTypeOpenId}, + {"u_123", "u_123", larkcontact.UserIdTypeUserId}, + {" open_id: ou_zzz ", "ou_zzz", larkcontact.UserIdTypeOpenId}, + {"", "", ""}, + } + for _, tt := range tests { + id, idType := parseFeishuUserInput(tt.raw) + if id != tt.wantID || idType != tt.wantIDType { + t.Errorf("parseFeishuUserInput(%q) = %q, %q; want %q, %q", tt.raw, id, idType, tt.wantID, tt.wantIDType) + } + } +} + +func Test_ptrStr(t *testing.T) { + s := "x" + if got := ptrStr(nil); got != "" { + t.Errorf("ptrStr(nil) = %q", got) + } + if got := ptrStr(&s); got != "x" { + t.Errorf("ptrStr(&s) = %q", got) + } + space := " a " + if got := ptrStr(&space); got != "a" { + t.Errorf("ptrStr(space) = %q", got) + } +} + +func Test_feishuUserToEntry(t *testing.T) { + openID := "ou_1" + name := "Alice" + u := &larkcontact.User{OpenId: &openID, Name: &name} + e := feishuUserToEntry(u) + if e.Kind != "user" || e.ID != "open_id:ou_1" || e.Name != "Alice" { + t.Errorf("feishuUserToEntry = %+v", e) + } + userID := "u_2" + u2 := &larkcontact.User{UserId: &userID, Name: &name} + e2 := feishuUserToEntry(u2) + if e2.ID != "user_id:u_2" { + t.Errorf("feishuUserToEntry user_id only = %+v", e2) + } +} + +func Test_feishuChatToEntry(t *testing.T) { + chatID := "oc_abc" + name := "Test Group" + c := &larkim.ListChat{ChatId: &chatID, Name: &name} + e := feishuChatToEntry(c) + if e.Kind != "group" || e.ID != "chat_id:oc_abc" || e.Name != "Test Group" { + t.Errorf("feishuChatToEntry = %+v", e) + } +} + +func Test_feishuMemberToEntry(t *testing.T) { + memberID := "ou_m1" + memberType := "open_id" + name := "Bob" + m := &larkim.ListMember{MemberId: &memberID, MemberIdType: &memberType, Name: &name} + e := feishuMemberToEntry(m) + if e.Kind != "user" || e.ID != "open_id:ou_m1" || e.Name != "Bob" { + t.Errorf("feishuMemberToEntry = %+v", e) + } + memberTypeUser := "user_id" + m2 := &larkim.ListMember{MemberId: &memberID, MemberIdType: &memberTypeUser, Name: &name} + e2 := feishuMemberToEntry(m2) + if e2.ID != "user_id:ou_m1" { + t.Errorf("feishuMemberToEntry user_id type = %+v", e2) + } +} + +func Test_feishuAvatarURL(t *testing.T) { + if got := feishuAvatarURL(nil); got != "" { + t.Errorf("feishuAvatarURL(nil) = %q", got) + } + url72 := "https://avatar.example/72.png" + a := &larkcontact.AvatarInfo{Avatar72: &url72} + if got := feishuAvatarURL(a); got != url72 { + t.Errorf("feishuAvatarURL = %q", got) + } +} diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index e1cc6d7f..d07be6a8 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -25,6 +25,75 @@ type FeishuAdapter struct { logger *slog.Logger } +const processingBusyReactionType = "Typing" + +type messageReactionAPI interface { + Create(ctx context.Context, req *larkim.CreateMessageReactionReq, options ...larkcore.RequestOptionFunc) (*larkim.CreateMessageReactionResp, error) + Delete(ctx context.Context, req *larkim.DeleteMessageReactionReq, options ...larkcore.RequestOptionFunc) (*larkim.DeleteMessageReactionResp, error) +} + +type processingReactionGateway interface { + Add(ctx context.Context, messageID, reactionType string) (string, error) + Remove(ctx context.Context, messageID, reactionID string) error +} + +type larkProcessingReactionGateway struct { + api messageReactionAPI +} + +func (g *larkProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { + if g == nil || g.api == nil { + return "", fmt.Errorf("feishu reaction api not configured") + } + req := larkim.NewCreateMessageReactionReqBuilder(). + MessageId(messageID). + Body(larkim.NewCreateMessageReactionReqBodyBuilder(). + ReactionType(larkim.NewEmojiBuilder().EmojiType(reactionType).Build()). + Build()). + Build() + resp, err := g.api.Create(ctx, req) + if err != nil { + return "", err + } + if resp == nil || !resp.Success() { + code := 0 + msg := "" + if resp != nil { + code = resp.Code + msg = resp.Msg + } + return "", fmt.Errorf("feishu add reaction failed: %s (code: %d)", msg, code) + } + if resp.Data == nil || resp.Data.ReactionId == nil || strings.TrimSpace(*resp.Data.ReactionId) == "" { + return "", fmt.Errorf("feishu add reaction failed: empty reaction id") + } + return strings.TrimSpace(*resp.Data.ReactionId), nil +} + +func (g *larkProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { + if g == nil || g.api == nil { + return fmt.Errorf("feishu reaction api not configured") + } + req := larkim.NewDeleteMessageReactionReqBuilder(). + MessageId(messageID). + ReactionId(reactionID). + Build() + resp, err := g.api.Delete(ctx, req) + if err != nil { + return err + } + if resp == nil || !resp.Success() { + code := 0 + msg := "" + if resp != nil { + code = resp.Code + msg = resp.Msg + } + return fmt.Errorf("feishu remove reaction failed: %s (code: %d)", msg, code) + } + return nil +} + // NewFeishuAdapter creates a FeishuAdapter with the given logger. func NewFeishuAdapter(log *slog.Logger) *FeishuAdapter { if log == nil { @@ -46,10 +115,14 @@ func (a *FeishuAdapter) Descriptor() channel.Descriptor { Type: Type, DisplayName: "Feishu", Capabilities: channel.ChannelCapabilities{ - Text: true, - RichText: true, - Attachments: true, - Reply: true, + Text: true, + RichText: true, + Attachments: true, + Media: true, + Reactions: true, + Reply: true, + Streaming: true, + BlockStreaming: true, }, ConfigSchema: channel.ConfigSchema{ Version: 1, @@ -84,6 +157,79 @@ func (a *FeishuAdapter) Descriptor() channel.Descriptor { } } +// ProcessingStarted adds a transient reaction to indicate the inbound message is being processed. +func (a *FeishuAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + messageID := strings.TrimSpace(info.SourceMessageID) + if messageID == "" { + return channel.ProcessingStatusHandle{}, nil + } + gateway, err := a.processingReactionGateway(cfg) + if err != nil { + return channel.ProcessingStatusHandle{}, err + } + token, err := addProcessingReaction(ctx, gateway, messageID, processingBusyReactionType) + if err != nil { + return channel.ProcessingStatusHandle{}, err + } + return channel.ProcessingStatusHandle{Token: token}, nil +} + +// ProcessingCompleted removes the transient processing reaction before output is sent. +func (a *FeishuAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { + messageID := strings.TrimSpace(info.SourceMessageID) + reactionID := strings.TrimSpace(handle.Token) + if messageID == "" || reactionID == "" { + return nil + } + gateway, err := a.processingReactionGateway(cfg) + if err != nil { + return err + } + return removeProcessingReaction(ctx, gateway, messageID, reactionID) +} + +// ProcessingFailed removes the transient processing reaction when chat processing fails. +func (a *FeishuAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { + return a.ProcessingCompleted(ctx, cfg, msg, info, handle) +} + +func (a *FeishuAdapter) processingReactionGateway(cfg channel.ChannelConfig) (processingReactionGateway, error) { + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + gateway := &larkProcessingReactionGateway{api: client.Im.V1.MessageReaction} + return gateway, nil +} + +func addProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionType string) (string, error) { + if gateway == nil { + return "", fmt.Errorf("processing reaction gateway is nil") + } + msgID := strings.TrimSpace(messageID) + if msgID == "" { + return "", nil + } + rxType := strings.TrimSpace(reactionType) + if rxType == "" { + return "", fmt.Errorf("processing reaction type is empty") + } + return gateway.Add(ctx, msgID, rxType) +} + +func removeProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionID string) error { + if gateway == nil { + return fmt.Errorf("processing reaction gateway is nil") + } + msgID := strings.TrimSpace(messageID) + rxID := strings.TrimSpace(reactionID) + if msgID == "" || rxID == "" { + return nil + } + return gateway.Remove(ctx, msgID, rxID) +} + // NormalizeConfig validates and normalizes a Feishu channel configuration map. func (a *FeishuAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) @@ -127,48 +273,110 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, return nil, err } connCtx, cancel := context.WithCancel(ctx) - eventDispatcher := dispatcher.NewEventDispatcher( - feishuCfg.VerificationToken, - feishuCfg.EncryptKey, - ) - eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error { - msg := extractFeishuInbound(event) - text := msg.Message.PlainText() - if text == "" && len(msg.Message.Attachments) == 0 { - return nil - } - msg.BotID = cfg.BotID - if a.logger != nil { - a.logger.Info( - "inbound received", - slog.String("config_id", cfg.ID), - slog.String("session_id", msg.SessionID()), - slog.String("chat_type", msg.Conversation.Type), - slog.String("text", common.SummarizeText(text)), - ) - } - go func() { - if err := handler(connCtx, cfg, msg); err != nil && a.logger != nil { - a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + newClient := func() *larkws.Client { + eventDispatcher := dispatcher.NewEventDispatcher( + feishuCfg.VerificationToken, + feishuCfg.EncryptKey, + ) + eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error { + if connCtx.Err() != nil { + return nil } - }() - return nil - }) - eventDispatcher.OnP2MessageReadV1(func(_ context.Context, _ *larkim.P2MessageReadV1) error { - return nil - }) - - client := larkws.NewClient( - feishuCfg.AppID, - feishuCfg.AppSecret, - larkws.WithEventHandler(eventDispatcher), - larkws.WithLogger(newLarkSlogLogger(a.logger)), - larkws.WithLogLevel(larkcore.LogLevelDebug), - ) + msg := extractFeishuInbound(event) + text := msg.Message.PlainText() + rawMessageID := "" + rawMessageType := "" + if event != nil && event.Event != nil && event.Event.Message != nil { + if event.Event.Message.MessageId != nil { + rawMessageID = strings.TrimSpace(*event.Event.Message.MessageId) + } + if event.Event.Message.MessageType != nil { + rawMessageType = strings.TrimSpace(string(*event.Event.Message.MessageType)) + } + } + if text == "" && len(msg.Message.Attachments) == 0 { + if a.logger != nil { + a.logger.Info( + "inbound ignored empty payload", + slog.String("config_id", cfg.ID), + slog.String("message_id", rawMessageID), + slog.String("message_type", rawMessageType), + slog.String("chat_type", msg.Conversation.Type), + ) + } + return nil + } + msg.BotID = cfg.BotID + if a.logger != nil { + isMentioned := false + if msg.Metadata != nil { + if v, ok := msg.Metadata["is_mentioned"].(bool); ok { + isMentioned = v + } + } + a.logger.Info( + "inbound received", + slog.String("config_id", cfg.ID), + slog.String("message_id", rawMessageID), + slog.String("message_type", rawMessageType), + slog.String("route_key", msg.RoutingKey()), + slog.String("chat_type", msg.Conversation.Type), + slog.Bool("is_mentioned", isMentioned), + slog.String("text", common.SummarizeText(text)), + ) + } + go func() { + if err := handler(connCtx, cfg, msg); err != nil && a.logger != nil { + a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + }() + return nil + }) + eventDispatcher.OnP2MessageReadV1(func(_ context.Context, _ *larkim.P2MessageReadV1) error { + return nil + }) + // Ignore reaction lifecycle events explicitly to avoid SDK "not found handler" noise logs. + // These events are expected because the adapter uses reactions for processing status. + eventDispatcher.OnP2MessageReactionCreatedV1(func(_ context.Context, _ *larkim.P2MessageReactionCreatedV1) error { + return nil + }) + eventDispatcher.OnP2MessageReactionDeletedV1(func(_ context.Context, _ *larkim.P2MessageReactionDeletedV1) error { + return nil + }) + return larkws.NewClient( + feishuCfg.AppID, + feishuCfg.AppSecret, + larkws.WithEventHandler(eventDispatcher), + larkws.WithLogger(newLarkSlogLogger(a.logger)), + larkws.WithLogLevel(larkcore.LogLevelDebug), + ) + } go func() { - if err := client.Start(connCtx); err != nil && a.logger != nil { - a.logger.Error("client start failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + const reconnectDelay = 3 * time.Second + for { + if connCtx.Err() != nil { + return + } + client := newClient() + err := client.Start(connCtx) + if connCtx.Err() != nil { + return + } + if a.logger != nil { + if err != nil { + a.logger.Error("client start failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } else { + a.logger.Warn("client exited without error; reconnecting", slog.String("config_id", cfg.ID)) + } + } + timer := time.NewTimer(reconnectDelay) + select { + case <-connCtx.Done(): + timer.Stop() + return + case <-timer.C: + } } }() @@ -256,6 +464,39 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg return a.handleResponse(cfg.ID, resp, err) } +// OpenStream opens a Feishu streaming session. +// The adapter strategy uses one interactive card and patches it incrementally. +func (a *FeishuAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("feishu target is required") + } + feishuCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + receiveID, receiveType, err := resolveFeishuReceiveID(target) + if err != nil { + return nil, err + } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return &feishuOutboundStream{ + adapter: a, + cfg: cfg, + target: target, + reply: opts.Reply, + client: client, + receiveID: receiveID, + receiveType: receiveType, + patchInterval: feishuStreamPatchInterval, + }, nil +} + func (a *FeishuAdapter) handleReplyResponse(configID string, resp *larkim.ReplyMessageResp, err error) error { if err != nil { if a.logger != nil { @@ -307,66 +548,83 @@ func (a *FeishuAdapter) handleResponse(configID string, resp *larkim.CreateMessa } func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, receiveID, receiveType string, att channel.Attachment, text string) error { - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, att.URL, nil) - if err != nil { - return fmt.Errorf("failed to build download request: %w", err) - } - httpClient := &http.Client{Timeout: 60 * time.Second} - resp, err := httpClient.Do(httpReq) - if err != nil { - return fmt.Errorf("failed to download attachment: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to download attachment, status: %d", resp.StatusCode) - } - var msgType string var contentMap map[string]string - - if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { - uploadReq := larkim.NewCreateImageReqBuilder(). - Body(larkim.NewCreateImageReqBodyBuilder(). - ImageType(larkim.ImageTypeMessage). - Image(resp.Body). - Build()). - Build() - uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) - if err != nil { - return fmt.Errorf("failed to upload image: %w", err) + sourcePlatform := strings.TrimSpace(att.SourcePlatform) + platformKey := strings.TrimSpace(att.PlatformKey) + if platformKey != "" && (sourcePlatform == "" || strings.EqualFold(sourcePlatform, Type.String())) { + if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { + msgType = larkim.MsgTypeImage + contentMap = map[string]string{"image_key": platformKey} + } else { + msgType = larkim.MsgTypeFile + contentMap = map[string]string{"file_key": platformKey} } - if uploadResp == nil || !uploadResp.Success() { - code, msg := 0, "" - if uploadResp != nil { - code, msg = uploadResp.Code, uploadResp.Msg - } - return fmt.Errorf("failed to upload image: %s (code: %d)", msg, code) - } - msgType = larkim.MsgTypeImage - contentMap = map[string]string{"image_key": *uploadResp.Data.ImageKey} } else { - fileType := resolveFeishuFileType(att.Name, att.Mime) - uploadReq := larkim.NewCreateFileReqBuilder(). - Body(larkim.NewCreateFileReqBodyBuilder(). - FileType(fileType). - FileName(att.Name). - File(resp.Body). - Build()). - Build() - uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) + downloadURL := strings.TrimSpace(att.URL) + if downloadURL == "" { + return fmt.Errorf("failed to download attachment: url is required when platform key is unavailable") + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) if err != nil { - return fmt.Errorf("failed to upload file: %w", err) + return fmt.Errorf("failed to build download request: %w", err) } - if uploadResp == nil || !uploadResp.Success() { - code, msg := 0, "" - if uploadResp != nil { - code, msg = uploadResp.Code, uploadResp.Msg + httpClient := &http.Client{Timeout: 60 * time.Second} + resp, err := httpClient.Do(httpReq) + if err != nil { + return fmt.Errorf("failed to download attachment: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download attachment, status: %d", resp.StatusCode) + } + if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { + uploadReq := larkim.NewCreateImageReqBuilder(). + Body(larkim.NewCreateImageReqBodyBuilder(). + ImageType(larkim.ImageTypeMessage). + Image(resp.Body). + Build()). + Build() + uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("failed to upload image: %w", err) } - return fmt.Errorf("failed to upload file: %s (code: %d)", msg, code) + if uploadResp == nil || !uploadResp.Success() { + code, msg := 0, "" + if uploadResp != nil { + code, msg = uploadResp.Code, uploadResp.Msg + } + return fmt.Errorf("failed to upload image: %s (code: %d)", msg, code) + } + msgType = larkim.MsgTypeImage + contentMap = map[string]string{"image_key": *uploadResp.Data.ImageKey} + } else { + fileType := resolveFeishuFileType(att.Name, att.Mime) + fileName := strings.TrimSpace(att.Name) + if fileName == "" { + fileName = "attachment" + } + uploadReq := larkim.NewCreateFileReqBuilder(). + Body(larkim.NewCreateFileReqBodyBuilder(). + FileType(fileType). + FileName(fileName). + File(resp.Body). + Build()). + Build() + uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("failed to upload file: %w", err) + } + if uploadResp == nil || !uploadResp.Success() { + code, msg := 0, "" + if uploadResp != nil { + code, msg = uploadResp.Code, uploadResp.Msg + } + return fmt.Errorf("failed to upload file: %s (code: %d)", msg, code) + } + msgType = larkim.MsgTypeFile + contentMap = map[string]string{"file_key": *uploadResp.Data.FileKey} } - msgType = larkim.MsgTypeFile - contentMap = map[string]string{"file_key": *uploadResp.Data.FileKey} } content, err := json.Marshal(contentMap) @@ -421,10 +679,77 @@ func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { line := []any{} for _, part := range msg.Parts { - if part.Type == channel.MessagePartText { + switch part.Type { + case channel.MessagePartText: + text := strings.TrimSpace(part.Text) + if text == "" { + continue + } line = append(line, map[string]any{ "tag": "text", - "text": part.Text, + "text": text, + }) + case channel.MessagePartLink: + url := strings.TrimSpace(part.URL) + label := strings.TrimSpace(part.Text) + if label == "" { + label = url + } + if url == "" || label == "" { + continue + } + line = append(line, map[string]any{ + "tag": "a", + "text": label, + "href": url, + }) + case channel.MessagePartCodeBlock: + code := strings.TrimSpace(part.Text) + if code == "" { + continue + } + language := strings.TrimSpace(part.Language) + if language != "" { + code = "```" + language + "\n" + code + "\n```" + } else { + code = "```\n" + code + "\n```" + } + line = append(line, map[string]any{ + "tag": "text", + "text": code, + }) + case channel.MessagePartMention: + mention := strings.TrimSpace(part.Text) + if mention == "" { + mention = strings.TrimSpace(part.ChannelIdentityID) + } + if mention == "" { + continue + } + line = append(line, map[string]any{ + "tag": "text", + "text": "@" + mention, + }) + case channel.MessagePartEmoji: + emoji := strings.TrimSpace(part.Emoji) + if emoji == "" { + emoji = strings.TrimSpace(part.Text) + } + if emoji == "" { + continue + } + line = append(line, map[string]any{ + "tag": "text", + "text": emoji, + }) + } + } + if len(line) == 0 { + text := strings.TrimSpace(msg.PlainText()) + if text != "" { + line = append(line, map[string]any{ + "tag": "text", + "text": text, }) } } @@ -433,120 +758,3 @@ func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { payload, err := json.Marshal(pc) return string(payload), err } - -func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessage { - if event == nil || event.Event == nil || event.Event.Message == nil { - return channel.InboundMessage{Channel: Type} - } - message := event.Event.Message - - var msg channel.Message - if message.MessageId != nil { - msg.ID = *message.MessageId - } - - var contentMap map[string]any - if message.Content != nil { - _ = json.Unmarshal([]byte(*message.Content), &contentMap) - } - - if message.MessageType != nil { - switch *message.MessageType { - case larkim.MsgTypeText: - if txt, ok := contentMap["text"].(string); ok { - msg.Text = txt - } - case larkim.MsgTypeImage: - if key, ok := contentMap["image_key"].(string); ok { - msg.Attachments = append(msg.Attachments, channel.Attachment{ - Type: channel.AttachmentImage, - URL: key, - }) - } - case larkim.MsgTypeFile, larkim.MsgTypeAudio: - if key, ok := contentMap["file_key"].(string); ok { - name, _ := contentMap["file_name"].(string) - msg.Attachments = append(msg.Attachments, channel.Attachment{ - Type: channel.AttachmentType(*message.MessageType), - URL: key, - Name: name, - }) - } - } - } - - if message.ParentId != nil && *message.ParentId != "" { - msg.Reply = &channel.ReplyRef{ - MessageID: *message.ParentId, - } - } - - senderID, senderOpenID := "", "" - if event.Event.Sender != nil && event.Event.Sender.SenderId != nil { - if event.Event.Sender.SenderId.UserId != nil { - senderID = strings.TrimSpace(*event.Event.Sender.SenderId.UserId) - } - if event.Event.Sender.SenderId.OpenId != nil { - senderOpenID = strings.TrimSpace(*event.Event.Sender.SenderId.OpenId) - } - } - chatID := "" - chatType := "" - if message.ChatId != nil { - chatID = strings.TrimSpace(*message.ChatId) - } - if message.ChatType != nil { - chatType = strings.TrimSpace(*message.ChatType) - } - replyTo := senderOpenID - if replyTo == "" { - replyTo = senderID - } - if chatType != "" && chatType != "p2p" && chatID != "" { - replyTo = "chat_id:" + chatID - } - attrs := map[string]string{} - if senderID != "" { - attrs["user_id"] = senderID - } - if senderOpenID != "" { - attrs["open_id"] = senderOpenID - } - externalID := senderOpenID - if externalID == "" { - externalID = senderID - } - - return channel.InboundMessage{ - Channel: Type, - Message: msg, - ReplyTarget: replyTo, - Sender: channel.Identity{ - ExternalID: externalID, - DisplayName: senderOpenID, - Attributes: attrs, - }, - Conversation: channel.Conversation{ - ID: chatID, - Type: chatType, - }, - ReceivedAt: time.Now().UTC(), - Source: "feishu", - } -} - -func resolveFeishuReceiveID(raw string) (string, string, error) { - if raw == "" { - return "", "", fmt.Errorf("feishu target is required") - } - if strings.HasPrefix(raw, "open_id:") { - return strings.TrimPrefix(raw, "open_id:"), larkim.ReceiveIdTypeOpenId, nil - } - if strings.HasPrefix(raw, "user_id:") { - return strings.TrimPrefix(raw, "user_id:"), larkim.ReceiveIdTypeUserId, nil - } - if strings.HasPrefix(raw, "chat_id:") { - return strings.TrimPrefix(raw, "chat_id:"), larkim.ReceiveIdTypeChatId, nil - } - return raw, larkim.ReceiveIdTypeOpenId, nil -} diff --git a/internal/channel/adapters/feishu/feishu_integration_test.go b/internal/channel/adapters/feishu/feishu_integration_test.go index 80daea33..35e4fe56 100644 --- a/internal/channel/adapters/feishu/feishu_integration_test.go +++ b/internal/channel/adapters/feishu/feishu_integration_test.go @@ -11,30 +11,25 @@ import ( "github.com/memohai/memoh/internal/channel" ) -// TestFeishuGateway_Integration 飞书通道集成测试 -// 运行此测试需要设置环境变量: -// FEISHU_APP_ID: 飞书应用的 App ID -// FEISHU_APP_SECRET: 飞书应用的 App Secret -// FEISHU_ENCRYPT_KEY: (可选) 飞书应用的 Encrypt Key -// FEISHU_VERIFICATION_TOKEN: (可选) 飞书应用的 Verification Token +// TestFeishuGateway_Integration runs Feishu channel integration test. +// Required env: FEISHU_APP_ID, FEISHU_APP_SECRET. +// Optional: FEISHU_ENCRYPT_KEY, FEISHU_VERIFICATION_TOKEN. func TestFeishuGateway_Integration(t *testing.T) { appID := os.Getenv("FEISHU_APP_ID") appSecret := os.Getenv("FEISHU_APP_SECRET") if appID == "" || appSecret == "" { - t.Skip("跳过集成测试: 未设置 FEISHU_APP_ID 或 FEISHU_APP_SECRET 环境变量") + t.Skip("skipping integration test: FEISHU_APP_ID or FEISHU_APP_SECRET not set") } encryptKey := os.Getenv("FEISHU_ENCRYPT_KEY") verificationToken := os.Getenv("FEISHU_VERIFICATION_TOKEN") - // 使用更规范的日志配置 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelInfo, })) adapter := NewFeishuAdapter(logger) - // 构造测试配置 cfg := channel.ChannelConfig{ ID: "integration-test-bot", Credentials: map[string]any{ @@ -45,28 +40,23 @@ func TestFeishuGateway_Integration(t *testing.T) { }, } - // 定义测试上下文,设置合理的超时时间 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - // 消息计数,用于验证是否收到消息 receivedChan := make(chan channel.InboundMessage, 1) - // 模拟 InboundHandler handler := func(ctx context.Context, c channel.ChannelConfig, msg channel.InboundMessage) error { plainText := msg.Message.PlainText() - logger.Info("测试收到消息", + logger.Info("received message in test", slog.String("text", plainText), slog.String("user_id", msg.Sender.Attribute("user_id")), - slog.String("session_id", msg.SessionID())) + slog.String("route_key", msg.RoutingKey())) - // 将消息放入通道,供主测试逻辑验证 select { case receivedChan <- msg: default: } - // 自动回复测试 (验证下行链路) reply := channel.OutboundMessage{ Target: msg.ReplyTarget, Message: channel.Message{ @@ -78,7 +68,6 @@ func TestFeishuGateway_Integration(t *testing.T) { return fmt.Errorf("failed to send reply: %w", err) } - // 模拟异步主动推送测试 go func() { time.Sleep(1 * time.Second) pushMsg := channel.OutboundMessage{ @@ -93,31 +82,27 @@ func TestFeishuGateway_Integration(t *testing.T) { return nil } - // 启动适配器 - logger.Info("正在启动飞书适配器...", slog.String("app_id", appID)) + logger.Info("starting Feishu adapter", slog.String("app_id", appID)) runner, err := adapter.Connect(ctx, cfg, handler) if err != nil { - t.Fatalf("适配器启动失败: %v", err) + t.Fatalf("adapter connect failed: %v", err) } defer func() { _ = runner.Stop(context.Background()) }() fmt.Println("==================================================================") - fmt.Println("🚀 飞书集成测试已就绪!") - fmt.Println("请在飞书客户端向机器人发送一条消息,以完成端到端验证。") - fmt.Println("测试将在收到第一条消息或 10 分钟超时后结束。") + fmt.Println("Feishu integration test ready. Send a message in Feishu client to verify.") + fmt.Println("Test ends on first message received or 10 min timeout.") fmt.Println("==================================================================") - // 等待测试结果 select { case msg := <-receivedChan: - logger.Info("集成测试验证成功!", slog.String("received_text", msg.Message.PlainText())) - // 给一点时间让异步推送完成 + logger.Info("integration test passed", slog.String("received_text", msg.Message.PlainText())) time.Sleep(2 * time.Second) case <-ctx.Done(): if ctx.Err() == context.DeadlineExceeded { - t.Log("测试超时结束") + t.Log("test timed out") } } } diff --git a/internal/channel/adapters/feishu/feishu_test.go b/internal/channel/adapters/feishu/feishu_test.go index 5c6c47dc..77847386 100644 --- a/internal/channel/adapters/feishu/feishu_test.go +++ b/internal/channel/adapters/feishu/feishu_test.go @@ -1,11 +1,48 @@ package feishu import ( + "context" + "encoding/json" + "errors" + "strings" "testing" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" ) +type fakeProcessingReactionGateway struct { + addCalls []struct{ messageID, reactionType string } + removeCalls []struct{ messageID, reactionID string } + addResponse []struct { + reactionID string + err error + } + removeErr error +} + +func (g *fakeProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { + g.addCalls = append(g.addCalls, struct{ messageID, reactionType string }{ + messageID: messageID, + reactionType: reactionType, + }) + if len(g.addResponse) == 0 { + return "reaction-default", nil + } + resp := g.addResponse[0] + g.addResponse = g.addResponse[1:] + return resp.reactionID, resp.err +} + +func (g *fakeProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { + g.removeCalls = append(g.removeCalls, struct{ messageID, reactionID string }{ + messageID: messageID, + reactionID: reactionID, + }) + return g.removeErr +} + func TestResolveFeishuReceiveID(t *testing.T) { t.Parallel() @@ -70,6 +107,21 @@ func TestExtractFeishuInboundP2P(t *testing.T) { if got.ReplyTarget != "ou_1" { t.Fatalf("unexpected reply target: %s", got.ReplyTarget) } + if got.Sender.DisplayName != "" { + t.Fatalf("expected empty sender display name, got: %s", got.Sender.DisplayName) + } + if got.Sender.SubjectID != "ou_1" { + t.Fatalf("unexpected sender subject id: %s", got.Sender.SubjectID) + } + if got.Sender.Attribute("open_id") != "ou_1" { + t.Fatalf("unexpected sender open_id: %s", got.Sender.Attribute("open_id")) + } + if got.Sender.Attribute("user_id") != "u_1" { + t.Fatalf("unexpected sender user_id: %s", got.Sender.Attribute("user_id")) + } + if mentioned, _ := got.Metadata["is_mentioned"].(bool); mentioned { + t.Fatalf("unexpected mention flag for p2p message") + } } func TestExtractFeishuInboundGroup(t *testing.T) { @@ -101,6 +153,9 @@ func TestExtractFeishuInboundGroup(t *testing.T) { if got.ReplyTarget != "chat_id:oc_2" { t.Fatalf("unexpected reply target: %s", got.ReplyTarget) } + if mentioned, _ := got.Metadata["is_mentioned"].(bool); mentioned { + t.Fatalf("unexpected mention flag for group message without mentions") + } } func TestExtractFeishuInboundNonText(t *testing.T) { @@ -119,3 +174,328 @@ func TestExtractFeishuInboundNonText(t *testing.T) { t.Fatalf("expected empty text, got %s", got.Message.PlainText()) } } + +func TestExtractFeishuInboundImageAttachmentReference(t *testing.T) { + t.Parallel() + + content := `{"image_key":"img_1"}` + msgType := larkim.MsgTypeImage + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &content, + }, + }, + } + got := extractFeishuInbound(event) + if len(got.Message.Attachments) != 1 { + t.Fatalf("expected one attachment, got %d", len(got.Message.Attachments)) + } + att := got.Message.Attachments[0] + if att.Type != channel.AttachmentImage { + t.Fatalf("unexpected attachment type: %s", att.Type) + } + if att.PlatformKey != "img_1" { + t.Fatalf("unexpected platform key: %s", att.PlatformKey) + } + if att.SourcePlatform != Type.String() { + t.Fatalf("unexpected source platform: %s", att.SourcePlatform) + } +} + +func TestFeishuDescriptorIncludesStreamingAndMedia(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + caps := adapter.Descriptor().Capabilities + if !caps.Streaming { + t.Fatal("expected streaming capability") + } + if !caps.Media { + t.Fatal("expected media capability") + } +} + +func TestBuildFeishuStreamCardContent(t *testing.T) { + t.Parallel() + + payload, err := buildFeishuStreamCardContent("hello") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var parsed map[string]any + if err := json.Unmarshal([]byte(payload), &parsed); err != nil { + t.Fatalf("unexpected json error: %v", err) + } + cfg, ok := parsed["config"].(map[string]any) + if !ok { + t.Fatalf("missing config: %+v", parsed) + } + value, ok := cfg["update_multi"].(bool) + if !ok || !value { + t.Fatalf("expected update_multi=true, got: %#v", cfg["update_multi"]) + } +} + +func TestNormalizeFeishuStreamText(t *testing.T) { + t.Parallel() + + if got := normalizeFeishuStreamText(" "); got != feishuStreamThinkingText { + t.Fatalf("unexpected thinking text: %s", got) + } + long := strings.Repeat("a", feishuStreamMaxRunes+100) + got := normalizeFeishuStreamText(long) + if len([]rune(got)) > feishuStreamMaxRunes+4 { + t.Fatalf("expected truncated text, got len=%d", len([]rune(got))) + } + if !strings.HasPrefix(got, "...\n") { + t.Fatalf("expected truncation prefix, got: %s", got[:4]) + } +} + +func TestProcessFeishuCardMarkdown(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + want string + }{ + {"literal newline", "a\\nb", "a\nb"}, + {"atx h1", "# Title", "**Title**"}, + {"atx h2", "## Section", "**Section**"}, + {"atx h6", "###### Small", "**Small**"}, + {"heading with newline", "# Hi\n\nBody", "**Hi**\n\nBody"}, + {"no heading", "plain **bold**", "plain **bold**"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := processFeishuCardMarkdown(tt.in) + if got != tt.want { + t.Errorf("processFeishuCardMarkdown() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestExtractFeishuInboundMention(t *testing.T) { + t.Parallel() + + text := `{"text":"@bot hi","mentions":[{"key":"@bot"}]}` + msgType := larkim.MsgTypeText + chatType := "group" + chatID := "oc_3" + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &text, + ChatType: &chatType, + ChatId: &chatID, + }, + }, + } + got := extractFeishuInbound(event) + mentioned, ok := got.Metadata["is_mentioned"].(bool) + if !ok || !mentioned { + t.Fatalf("expected mention flag to be true") + } +} + +func TestExtractFeishuInboundMentionFromEventMentions(t *testing.T) { + t.Parallel() + + text := `{"text":"hello"}` + msgType := larkim.MsgTypeText + chatType := "group" + chatID := "oc_mention_event" + mention := larkim.NewMentionEventBuilder().Key("@_user_1").Build() + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &text, + ChatType: &chatType, + ChatId: &chatID, + Mentions: []*larkim.MentionEvent{mention}, + }, + }, + } + got := extractFeishuInbound(event) + mentioned, ok := got.Metadata["is_mentioned"].(bool) + if !ok || !mentioned { + t.Fatalf("expected mention flag from event mentions") + } +} + +func TestExtractFeishuInboundPostMention(t *testing.T) { + t.Parallel() + + content := `{"zh_cn":{"title":"","content":[[{"tag":"at","user_name":"bot"},{"tag":"text","text":" hi"}]]}}` + msgType := larkim.MsgTypePost + chatType := "group" + chatID := "oc_post_1" + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &content, + ChatType: &chatType, + ChatId: &chatID, + }, + }, + } + + got := extractFeishuInbound(event) + if got.Message.PlainText() == "" { + t.Fatalf("expected post message to be converted into text") + } + mentioned, ok := got.Metadata["is_mentioned"].(bool) + if !ok || !mentioned { + t.Fatalf("expected mention flag for post message") + } +} + +func TestAddProcessingReactionFirstSuccess(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{ + addResponse: []struct { + reactionID string + err error + }{ + {reactionID: "reaction-1"}, + }, + } + token, err := addProcessingReaction(context.Background(), gateway, "om_1", "Typing") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "reaction-1" { + t.Fatalf("expected token reaction-1, got %q", token) + } + if len(gateway.addCalls) != 1 { + t.Fatalf("expected one add call, got %d", len(gateway.addCalls)) + } + if gateway.addCalls[0].messageID != "om_1" || gateway.addCalls[0].reactionType != "Typing" { + t.Fatalf("unexpected add params: %+v", gateway.addCalls[0]) + } +} + +func TestAddProcessingReactionReturnsError(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{ + addResponse: []struct { + reactionID string + err error + }{ + {err: errors.New("invalid reaction type")}, + }, + } + token, err := addProcessingReaction(context.Background(), gateway, "om_2", "INVALID") + if err == nil { + t.Fatal("expected error") + } + if token != "" { + t.Fatalf("expected empty token, got %q", token) + } + if len(gateway.addCalls) != 1 { + t.Fatalf("expected one add call, got %d", len(gateway.addCalls)) + } + if gateway.addCalls[0].reactionType != "INVALID" { + t.Fatalf("unexpected add call sequence: %+v", gateway.addCalls) + } +} + +func TestAddProcessingReactionNoMessageID(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{} + token, err := addProcessingReaction(context.Background(), gateway, "", "Typing") + if err != nil { + t.Fatalf("expected no error for empty message id, got: %v", err) + } + if token != "" { + t.Fatalf("expected empty token, got %q", token) + } + if len(gateway.addCalls) != 0 { + t.Fatalf("expected no add calls, got %+v", gateway.addCalls) + } +} + +func TestRemoveProcessingReaction(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{} + if err := removeProcessingReaction(context.Background(), gateway, "om_3", "reaction-3"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gateway.removeCalls) != 1 { + t.Fatalf("expected one remove call, got %d", len(gateway.removeCalls)) + } + if gateway.removeCalls[0].messageID != "om_3" || gateway.removeCalls[0].reactionID != "reaction-3" { + t.Fatalf("unexpected remove params: %+v", gateway.removeCalls[0]) + } +} + +func TestRemoveProcessingReactionNoopForEmptyToken(t *testing.T) { + t.Parallel() + + gateway := &fakeProcessingReactionGateway{} + if err := removeProcessingReaction(context.Background(), gateway, "om_4", ""); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gateway.removeCalls) != 0 { + t.Fatalf("expected no remove calls, got %+v", gateway.removeCalls) + } +} + +func TestFeishuProcessingStartedNoSourceMessageID(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + handle, err := adapter.ProcessingStarted( + context.Background(), + channel.ChannelConfig{}, + channel.InboundMessage{}, + channel.ProcessingStatusInfo{}, + ) + if err != nil { + t.Fatalf("expected no error for empty source message id, got: %v", err) + } + if handle.Token != "" { + t.Fatalf("expected empty token, got %q", handle.Token) + } +} + +func TestFeishuProcessingStartedRequiresConfigWhenSourceMessageExists(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + _, err := adapter.ProcessingStarted( + context.Background(), + channel.ChannelConfig{}, + channel.InboundMessage{}, + channel.ProcessingStatusInfo{SourceMessageID: "om_5"}, + ) + if err == nil { + t.Fatal("expected error when credentials are missing") + } +} + +func TestFeishuProcessingCompletedNoopWithoutToken(t *testing.T) { + t.Parallel() + + adapter := NewFeishuAdapter(nil) + err := adapter.ProcessingCompleted( + context.Background(), + channel.ChannelConfig{}, + channel.InboundMessage{}, + channel.ProcessingStatusInfo{SourceMessageID: "om_6"}, + channel.ProcessingStatusHandle{}, + ) + if err != nil { + t.Fatalf("expected no error for empty token, got: %v", err) + } +} diff --git a/internal/channel/adapters/feishu/inbound.go b/internal/channel/adapters/feishu/inbound.go new file mode 100644 index 00000000..1c32aa5c --- /dev/null +++ b/internal/channel/adapters/feishu/inbound.go @@ -0,0 +1,269 @@ +package feishu + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" +) + +// extractFeishuInbound converts a Feishu P2MessageReceiveV1 event into a channel.InboundMessage. +func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessage { + if event == nil || event.Event == nil || event.Event.Message == nil { + return channel.InboundMessage{Channel: Type} + } + message := event.Event.Message + + var msg channel.Message + if message.MessageId != nil { + msg.ID = *message.MessageId + } + + var contentMap map[string]any + if message.Content != nil { + _ = json.Unmarshal([]byte(*message.Content), &contentMap) + } + isMentioned := hasFeishuMention(contentMap, message.Mentions) + + if message.MessageType != nil { + switch *message.MessageType { + case larkim.MsgTypeText: + if txt, ok := contentMap["text"].(string); ok { + msg.Text = txt + } + case larkim.MsgTypePost: + if postText := extractFeishuPostText(contentMap); postText != "" { + msg.Text = postText + } + case larkim.MsgTypeImage: + if key, ok := contentMap["image_key"].(string); ok { + msg.Attachments = append(msg.Attachments, channel.Attachment{ + Type: channel.AttachmentImage, + PlatformKey: key, + SourcePlatform: Type.String(), + }) + } + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + if key, ok := contentMap["file_key"].(string); ok { + name, _ := contentMap["file_name"].(string) + attType := channel.AttachmentFile + switch *message.MessageType { + case larkim.MsgTypeAudio: + attType = channel.AttachmentAudio + case larkim.MsgTypeMedia: + attType = channel.AttachmentVideo + } + msg.Attachments = append(msg.Attachments, channel.Attachment{ + Type: attType, + PlatformKey: key, + SourcePlatform: Type.String(), + Name: name, + }) + } + } + } + + if message.ParentId != nil && *message.ParentId != "" { + msg.Reply = &channel.ReplyRef{ + MessageID: *message.ParentId, + } + } + + senderID, senderOpenID := "", "" + if event.Event.Sender != nil && event.Event.Sender.SenderId != nil { + if event.Event.Sender.SenderId.UserId != nil { + senderID = strings.TrimSpace(*event.Event.Sender.SenderId.UserId) + } + if event.Event.Sender.SenderId.OpenId != nil { + senderOpenID = strings.TrimSpace(*event.Event.Sender.SenderId.OpenId) + } + } + chatID := "" + chatType := "" + if message.ChatId != nil { + chatID = strings.TrimSpace(*message.ChatId) + } + if message.ChatType != nil { + chatType = strings.TrimSpace(*message.ChatType) + } + replyTo := senderOpenID + if replyTo == "" { + replyTo = senderID + } + if chatType != "" && chatType != "p2p" && chatID != "" { + replyTo = "chat_id:" + chatID + } + attrs := map[string]string{} + if senderID != "" { + attrs["user_id"] = senderID + } + if senderOpenID != "" { + attrs["open_id"] = senderOpenID + } + subjectID := senderOpenID + if subjectID == "" { + subjectID = senderID + } + + return channel.InboundMessage{ + Channel: Type, + Message: msg, + ReplyTarget: replyTo, + Sender: channel.Identity{ + SubjectID: subjectID, + Attributes: attrs, + }, + Conversation: channel.Conversation{ + ID: chatID, + Type: chatType, + }, + ReceivedAt: time.Now().UTC(), + Source: "feishu", + Metadata: map[string]any{ + "is_mentioned": isMentioned, + }, + } +} + +func hasFeishuMention(contentMap map[string]any, mentions []*larkim.MentionEvent) bool { + if len(mentions) > 0 { + return true + } + if len(contentMap) == 0 { + return false + } + raw, ok := contentMap["mentions"] + if ok { + switch values := raw.(type) { + case []any: + if len(values) > 0 { + return true + } + case []map[string]any: + if len(values) > 0 { + return true + } + case map[string]any: + if len(values) > 0 { + return true + } + } + } + if text, ok := contentMap["text"].(string); ok { + normalized := strings.ToLower(strings.TrimSpace(text)) + if strings.Contains(normalized, "@_user_") || strings.Contains(normalized, "") { + return true + } + } + return hasFeishuAtTag(contentMap) +} + +func hasFeishuAtTag(raw any) bool { + switch value := raw.(type) { + case map[string]any: + if tag, ok := value["tag"].(string); ok && strings.EqualFold(strings.TrimSpace(tag), "at") { + return true + } + for _, child := range value { + if hasFeishuAtTag(child) { + return true + } + } + case []any: + for _, child := range value { + if hasFeishuAtTag(child) { + return true + } + } + } + return false +} + +func extractFeishuPostText(contentMap map[string]any) string { + zhCN, ok := contentMap["zh_cn"].(map[string]any) + if !ok { + return "" + } + linesRaw, ok := zhCN["content"].([]any) + if !ok { + return "" + } + parts := make([]string, 0, 8) + for _, rawLine := range linesRaw { + line, ok := rawLine.([]any) + if !ok { + continue + } + for _, rawPart := range line { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + tag := strings.ToLower(strings.TrimSpace(stringValue(part["tag"]))) + switch tag { + case "text", "a": + text := strings.TrimSpace(stringValue(part["text"])) + if text != "" { + parts = append(parts, text) + } + case "at": + name := strings.TrimSpace(stringValue(part["text"])) + if name == "" { + name = strings.TrimSpace(stringValue(part["name"])) + } + if name == "" { + name = strings.TrimSpace(stringValue(part["user_name"])) + } + if name == "" { + parts = append(parts, "@") + continue + } + if !strings.HasPrefix(name, "@") { + name = "@" + name + } + parts = append(parts, name) + default: + text := strings.TrimSpace(stringValue(part["text"])) + if text != "" { + parts = append(parts, text) + } + } + } + } + if len(parts) == 0 { + return "" + } + return strings.Join(parts, " ") +} + +func stringValue(raw any) string { + if raw == nil { + return "" + } + value, ok := raw.(string) + if ok { + return value + } + return fmt.Sprint(raw) +} + +// resolveFeishuReceiveID parses target (open_id:/user_id:/chat_id: prefix) and returns receiveID and receiveType. +func resolveFeishuReceiveID(raw string) (string, string, error) { + if raw == "" { + return "", "", fmt.Errorf("feishu target is required") + } + if strings.HasPrefix(raw, "open_id:") { + return strings.TrimPrefix(raw, "open_id:"), larkim.ReceiveIdTypeOpenId, nil + } + if strings.HasPrefix(raw, "user_id:") { + return strings.TrimPrefix(raw, "user_id:"), larkim.ReceiveIdTypeUserId, nil + } + if strings.HasPrefix(raw, "chat_id:") { + return strings.TrimPrefix(raw, "chat_id:"), larkim.ReceiveIdTypeChatId, nil + } + return raw, larkim.ReceiveIdTypeOpenId, nil +} diff --git a/internal/channel/adapters/feishu/feishu_logger.go b/internal/channel/adapters/feishu/logger.go similarity index 100% rename from internal/channel/adapters/feishu/feishu_logger.go rename to internal/channel/adapters/feishu/logger.go diff --git a/internal/channel/adapters/feishu/stream.go b/internal/channel/adapters/feishu/stream.go new file mode 100644 index 00000000..ce063119 --- /dev/null +++ b/internal/channel/adapters/feishu/stream.go @@ -0,0 +1,286 @@ +package feishu + +import ( + "context" + "encoding/json" + "fmt" + "regexp" + "strings" + "sync/atomic" + "time" + + "github.com/google/uuid" + lark "github.com/larksuite/oapi-sdk-go/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + feishuStreamThinkingText = "Thinking..." + feishuStreamPatchInterval = 700 * time.Millisecond + feishuStreamMaxRunes = 8000 +) + +type feishuOutboundStream struct { + adapter *FeishuAdapter + cfg channel.ChannelConfig + target string + reply *channel.ReplyRef + client *lark.Client + receiveID string + receiveType string + cardMessageID string + textBuffer strings.Builder + lastPatchedAt time.Time + lastPatched string + patchInterval time.Duration + closed atomic.Bool +} + +func (s *feishuOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.adapter == nil { + return fmt.Errorf("feishu stream not configured") + } + if s.closed.Load() { + return fmt.Errorf("feishu stream is closed") + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + switch event.Type { + case channel.StreamEventStatus: + if event.Status == channel.StreamStatusStarted { + return s.ensureCard(ctx, feishuStreamThinkingText) + } + return nil + case channel.StreamEventDelta: + if event.Delta == "" { + return nil + } + s.textBuffer.WriteString(event.Delta) + if err := s.ensureCard(ctx, feishuStreamThinkingText); err != nil { + return err + } + if time.Since(s.lastPatchedAt) < s.patchInterval && !strings.Contains(event.Delta, "\n") { + return nil + } + return s.patchCard(ctx, s.textBuffer.String()) + case channel.StreamEventFinal: + if event.Final == nil || event.Final.Message.IsEmpty() { + return nil + } + msg := event.Final.Message + finalText := strings.TrimSpace(msg.PlainText()) + if finalText == "" { + finalText = strings.TrimSpace(s.textBuffer.String()) + } + if finalText != "" { + if err := s.ensureCard(ctx, feishuStreamThinkingText); err != nil { + return err + } + if err := s.patchCard(ctx, finalText); err != nil { + return err + } + } + if len(msg.Attachments) > 0 { + media := msg + media.Format = "" + media.Text = "" + media.Parts = nil + media.Actions = nil + media.Reply = nil + return s.adapter.Send(ctx, s.cfg, channel.OutboundMessage{ + Target: s.target, + Message: media, + }) + } + return nil + case channel.StreamEventError: + errText := strings.TrimSpace(event.Error) + if errText == "" { + return nil + } + if err := s.ensureCard(ctx, feishuStreamThinkingText); err != nil { + return err + } + return s.patchCard(ctx, "Error: "+errText) + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } +} + +func (s *feishuOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} + +func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) error { + if strings.TrimSpace(s.cardMessageID) != "" { + return nil + } + if s.client == nil { + return fmt.Errorf("feishu client not configured") + } + content, err := buildFeishuStreamCardContent(text) + if err != nil { + return err + } + if s.reply != nil && strings.TrimSpace(s.reply.MessageID) != "" { + replyReq := larkim.NewReplyMessageReqBuilder(). + MessageId(strings.TrimSpace(s.reply.MessageID)). + Body(larkim.NewReplyMessageReqBodyBuilder(). + Content(content). + MsgType(larkim.MsgTypeInteractive). + Uuid(uuid.NewString()). + Build()). + Build() + replyResp, err := s.client.Im.V1.Message.Reply(ctx, replyReq) + if err != nil { + return err + } + if replyResp == nil || !replyResp.Success() { + code, msg := 0, "" + if replyResp != nil { + code, msg = replyResp.Code, replyResp.Msg + } + return fmt.Errorf("feishu stream reply failed: %s (code: %d)", msg, code) + } + if replyResp.Data == nil || replyResp.Data.MessageId == nil || strings.TrimSpace(*replyResp.Data.MessageId) == "" { + return fmt.Errorf("feishu stream reply failed: empty message id") + } + s.cardMessageID = strings.TrimSpace(*replyResp.Data.MessageId) + s.lastPatched = normalizeFeishuStreamText(text) + s.lastPatchedAt = time.Now() + return nil + } + createReq := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(s.receiveType). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(s.receiveID). + MsgType(larkim.MsgTypeInteractive). + Content(content). + Uuid(uuid.NewString()). + Build()). + Build() + createResp, err := s.client.Im.V1.Message.Create(ctx, createReq) + if err != nil { + return err + } + if createResp == nil || !createResp.Success() { + code, msg := 0, "" + if createResp != nil { + code, msg = createResp.Code, createResp.Msg + } + return fmt.Errorf("feishu stream create failed: %s (code: %d)", msg, code) + } + if createResp.Data == nil || createResp.Data.MessageId == nil || strings.TrimSpace(*createResp.Data.MessageId) == "" { + return fmt.Errorf("feishu stream create failed: empty message id") + } + s.cardMessageID = strings.TrimSpace(*createResp.Data.MessageId) + s.lastPatched = normalizeFeishuStreamText(text) + s.lastPatchedAt = time.Now() + return nil +} + +func (s *feishuOutboundStream) patchCard(ctx context.Context, text string) error { + if strings.TrimSpace(s.cardMessageID) == "" { + return fmt.Errorf("feishu stream card message not initialized") + } + contentText := normalizeFeishuStreamText(text) + if contentText == s.lastPatched { + return nil + } + content, err := buildFeishuStreamCardContent(contentText) + if err != nil { + return err + } + patchReq := larkim.NewPatchMessageReqBuilder(). + MessageId(strings.TrimSpace(s.cardMessageID)). + Body(larkim.NewPatchMessageReqBodyBuilder(). + Content(content). + Build()). + Build() + patchResp, err := s.client.Im.V1.Message.Patch(ctx, patchReq) + if err != nil { + return err + } + if patchResp == nil || !patchResp.Success() { + code, msg := 0, "" + if patchResp != nil { + code, msg = patchResp.Code, patchResp.Msg + } + return fmt.Errorf("feishu stream patch failed: %s (code: %d)", msg, code) + } + s.lastPatched = contentText + s.lastPatchedAt = time.Now() + return nil +} + +func buildFeishuStreamCardContent(text string) (string, error) { + content := normalizeFeishuStreamText(text) + content = processFeishuCardMarkdown(content) + card := map[string]any{ + "config": map[string]any{ + "wide_screen_mode": true, + "enable_forward": true, + "update_multi": true, + }, + "elements": []map[string]any{ + { + "tag": "div", + "fields": []map[string]any{ + { + "is_short": false, + "text": map[string]any{ + "tag": "lark_md", + "content": content, + }, + }, + }, + }, + }, + } + data, err := json.Marshal(card) + if err != nil { + return "", err + } + return string(data), nil +} + +var feishuCardHeadingPrefix = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) + +// processFeishuCardMarkdown normalizes markdown for Feishu card lark_md (e.g. ATX headings to bold). +func processFeishuCardMarkdown(s string) string { + s = strings.ReplaceAll(s, "\\n", "\n") + s = feishuCardHeadingPrefix.ReplaceAllStringFunc(s, func(m string) string { + parts := feishuCardHeadingPrefix.FindStringSubmatch(m) + if len(parts) == 2 { + return "**" + parts[1] + "**" + } + return m + }) + return s +} + +func normalizeFeishuStreamText(text string) string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return feishuStreamThinkingText + } + runes := []rune(trimmed) + if len(runes) <= feishuStreamMaxRunes { + return trimmed + } + return "...\n" + string(runes[len(runes)-feishuStreamMaxRunes:]) +} diff --git a/internal/channel/adapters/local/cli.go b/internal/channel/adapters/local/cli.go index 9241b207..3b026fb6 100644 --- a/internal/channel/adapters/local/cli.go +++ b/internal/channel/adapters/local/cli.go @@ -10,11 +10,11 @@ import ( // CLIAdapter implements channel.Sender for the local CLI channel. type CLIAdapter struct { - hub *SessionHub + hub *RouteHub } -// NewCLIAdapter creates a CLIAdapter backed by the given session hub. -func NewCLIAdapter(hub *SessionHub) *CLIAdapter { +// NewCLIAdapter creates a CLIAdapter backed by the given route hub. +func NewCLIAdapter(hub *RouteHub) *CLIAdapter { return &CLIAdapter{hub: hub} } @@ -30,20 +30,22 @@ func (a *CLIAdapter) Descriptor() channel.Descriptor { DisplayName: "CLI", Configless: true, Capabilities: channel.ChannelCapabilities{ - Text: true, - Reply: true, - Attachments: true, + Text: true, + Reply: true, + Attachments: true, + Streaming: true, + BlockStreaming: true, }, TargetSpec: channel.TargetSpec{ - Format: "session_id", + Format: "bot_id", Hints: []channel.TargetHint{ - {Label: "Session ID", Example: "cli:uuid"}, + {Label: "Bot ID", Example: "bot_123"}, }, }, } } -// Send publishes an outbound message to the CLI session hub. +// Send publishes an outbound message to the CLI route hub. func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("cli hub not configured") @@ -58,3 +60,20 @@ func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg ch a.hub.Publish(target, msg) return nil } + +// OpenStream opens a local stream session bound to the target route. +func (a *CLIAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + if a.hub == nil { + return nil, fmt.Errorf("cli hub not configured") + } + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("cli target is required") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return newLocalOutboundStream(a.hub, target), nil +} diff --git a/internal/channel/adapters/local/hub.go b/internal/channel/adapters/local/hub.go index ddc4b4bd..0fef9edb 100644 --- a/internal/channel/adapters/local/hub.go +++ b/internal/channel/adapters/local/hub.go @@ -1,51 +1,60 @@ package local import ( + "context" + "fmt" "sync" + "sync/atomic" "github.com/google/uuid" "github.com/memohai/memoh/internal/channel" ) -// SessionHub is a pub/sub hub that routes outbound messages to CLI/Web session subscribers. -type SessionHub struct { - mu sync.RWMutex - sessions map[string]map[string]chan channel.OutboundMessage +// RouteHubEvent is a routed outbound stream event for local transports. +type RouteHubEvent struct { + Target string `json:"target"` + Event channel.StreamEvent `json:"event"` } -// NewSessionHub creates an empty SessionHub. -func NewSessionHub() *SessionHub { - return &SessionHub{ - sessions: map[string]map[string]chan channel.OutboundMessage{}, +// RouteHub is a pub/sub hub that routes outbound messages to local subscribers by route key. +type RouteHub struct { + mu sync.RWMutex + streams map[string]map[string]chan RouteHubEvent +} + +// NewRouteHub creates an empty RouteHub. +func NewRouteHub() *RouteHub { + return &RouteHub{ + streams: map[string]map[string]chan RouteHubEvent{}, } } -// Subscribe registers a new stream for the given session and returns a stream ID, +// Subscribe registers a new stream for the given route key and returns a stream ID, // a read-only channel for messages, and a cancel function to unsubscribe. -func (h *SessionHub) Subscribe(sessionID string) (string, <-chan channel.OutboundMessage, func()) { +func (h *RouteHub) Subscribe(routeKey string) (string, <-chan RouteHubEvent, func()) { streamID := uuid.NewString() - ch := make(chan channel.OutboundMessage, 32) + ch := make(chan RouteHubEvent, 32) h.mu.Lock() - streams, ok := h.sessions[sessionID] + streams, ok := h.streams[routeKey] if !ok { - streams = map[string]chan channel.OutboundMessage{} - h.sessions[sessionID] = streams + streams = map[string]chan RouteHubEvent{} + h.streams[routeKey] = streams } streams[streamID] = ch h.mu.Unlock() cancel := func() { h.mu.Lock() - streams := h.sessions[sessionID] + streams := h.streams[routeKey] if streams != nil { if current, ok := streams[streamID]; ok { delete(streams, streamID) close(current) } if len(streams) == 0 { - delete(h.sessions, sessionID) + delete(h.streams, routeKey) } } h.mu.Unlock() @@ -54,16 +63,73 @@ func (h *SessionHub) Subscribe(sessionID string) (string, <-chan channel.Outboun return streamID, ch, cancel } -// Publish delivers a message to all subscribers of the given session. +// Publish delivers a message to all subscribers of the given route key. // Slow receivers are silently dropped. -func (h *SessionHub) Publish(sessionID string, msg channel.OutboundMessage) { +func (h *RouteHub) Publish(routeKey string, msg channel.OutboundMessage) { + h.PublishEvent(routeKey, channel.StreamEvent{ + Type: channel.StreamEventFinal, + Final: &channel.StreamFinalizePayload{ + Message: msg.Message, + }, + }) +} + +// PublishEvent delivers a stream event to all subscribers of the given route key. +// Slow receivers are silently dropped. +func (h *RouteHub) PublishEvent(routeKey string, event channel.StreamEvent) { h.mu.RLock() defer h.mu.RUnlock() - for _, ch := range h.sessions[sessionID] { + for _, ch := range h.streams[routeKey] { + payload := RouteHubEvent{ + Target: routeKey, + Event: event, + } select { - case ch <- msg: + case ch <- payload: default: // Drop if receiver is slow. } } } + +type localOutboundStream struct { + hub *RouteHub + target string + closed atomic.Bool +} + +func newLocalOutboundStream(hub *RouteHub, target string) channel.OutboundStream { + return &localOutboundStream{ + hub: hub, + target: target, + } +} + +func (s *localOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.hub == nil { + return fmt.Errorf("route hub not configured") + } + if s.closed.Load() { + return fmt.Errorf("stream is closed") + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.hub.PublishEvent(s.target, event) + return nil +} + +func (s *localOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} diff --git a/internal/channel/adapters/local/hub_test.go b/internal/channel/adapters/local/hub_test.go new file mode 100644 index 00000000..d9814af9 --- /dev/null +++ b/internal/channel/adapters/local/hub_test.go @@ -0,0 +1,50 @@ +package local + +import ( + "context" + "testing" + "time" + + "github.com/memohai/memoh/internal/channel" +) + +func TestRouteHubPublishEvent(t *testing.T) { + t.Parallel() + + hub := NewRouteHub() + _, stream, cancel := hub.Subscribe("bot-1") + defer cancel() + + hub.PublishEvent("bot-1", channel.StreamEvent{ + Type: channel.StreamEventDelta, + Delta: "hello", + }) + + select { + case item := <-stream: + if item.Target != "bot-1" { + t.Fatalf("unexpected target: %s", item.Target) + } + if item.Event.Type != channel.StreamEventDelta { + t.Fatalf("unexpected event type: %s", item.Event.Type) + } + case <-time.After(time.Second): + t.Fatal("expected event but timed out") + } +} + +func TestLocalOutboundStreamClose(t *testing.T) { + t.Parallel() + + hub := NewRouteHub() + stream := newLocalOutboundStream(hub, "bot-2") + if err := stream.Close(context.Background()); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + if err := stream.Push(context.Background(), channel.StreamEvent{ + Type: channel.StreamEventDelta, + Delta: "should fail", + }); err == nil { + t.Fatal("expected push on closed stream to fail") + } +} diff --git a/internal/channel/adapters/local/web.go b/internal/channel/adapters/local/web.go index 1490604b..70309748 100644 --- a/internal/channel/adapters/local/web.go +++ b/internal/channel/adapters/local/web.go @@ -10,11 +10,11 @@ import ( // WebAdapter implements channel.Sender for the local Web channel. type WebAdapter struct { - hub *SessionHub + hub *RouteHub } -// NewWebAdapter creates a WebAdapter backed by the given session hub. -func NewWebAdapter(hub *SessionHub) *WebAdapter { +// NewWebAdapter creates a WebAdapter backed by the given route hub. +func NewWebAdapter(hub *RouteHub) *WebAdapter { return &WebAdapter{hub: hub} } @@ -30,20 +30,22 @@ func (a *WebAdapter) Descriptor() channel.Descriptor { DisplayName: "Web", Configless: true, Capabilities: channel.ChannelCapabilities{ - Text: true, - Reply: true, - Attachments: true, + Text: true, + Reply: true, + Attachments: true, + Streaming: true, + BlockStreaming: true, }, TargetSpec: channel.TargetSpec{ - Format: "session_id", + Format: "bot_id", Hints: []channel.TargetHint{ - {Label: "Session ID", Example: "web:uuid"}, + {Label: "Bot ID", Example: "bot_123"}, }, }, } } -// Send publishes an outbound message to the Web session hub. +// Send publishes an outbound message to the Web route hub. func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("web hub not configured") @@ -58,3 +60,20 @@ func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg ch a.hub.Publish(target, msg) return nil } + +// OpenStream opens a local stream session bound to the target route. +func (a *WebAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + if a.hub == nil { + return nil, fmt.Errorf("web hub not configured") + } + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("web target is required") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return newLocalOutboundStream(a.hub, target), nil +} diff --git a/internal/channel/adapters/telegram/config.go b/internal/channel/adapters/telegram/config.go index 51d2d3a6..d77e2feb 100644 --- a/internal/channel/adapters/telegram/config.go +++ b/internal/channel/adapters/telegram/config.go @@ -82,8 +82,8 @@ func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) { return true } - if criteria.ExternalID != "" { - if criteria.ExternalID == cfg.ChatID || criteria.ExternalID == cfg.UserID || strings.EqualFold(criteria.ExternalID, cfg.Username) { + if criteria.SubjectID != "" { + if criteria.SubjectID == cfg.ChatID || criteria.SubjectID == cfg.UserID || strings.EqualFold(criteria.SubjectID, cfg.Username) { return true } } diff --git a/internal/channel/adapters/telegram/directory.go b/internal/channel/adapters/telegram/directory.go new file mode 100644 index 00000000..7901cdeb --- /dev/null +++ b/internal/channel/adapters/telegram/directory.go @@ -0,0 +1,239 @@ +package telegram + +import ( + "context" + "fmt" + "strconv" + "strings" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + defaultDirectoryLimit = 50 + maxDirectoryLimit = 200 +) + +func directoryLimit(n int) int { + if n <= 0 { + return defaultDirectoryLimit + } + if n > maxDirectoryLimit { + return maxDirectoryLimit + } + return n +} + +// ListPeers returns users the bot can reach. Telegram Bot API does not provide a list of users; returns empty. +func (a *TelegramAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +// ListGroups returns chats the bot is in. Telegram Bot API does not provide a list of chats; returns empty. +func (a *TelegramAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +// ListGroupMembers returns administrators of the given group (Telegram only exposes admin list, not full members). +func (a *TelegramAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + telegramCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + bot, err := a.getOrCreateBot(telegramCfg.BotToken, cfg.ID) + if err != nil { + return nil, err + } + chatID, superGroupUsername := parseTelegramChatInput(strings.TrimSpace(groupID)) + if chatID == 0 && superGroupUsername == "" { + return nil, fmt.Errorf("telegram list group members: invalid group id %q", groupID) + } + config := tgbotapi.ChatAdministratorsConfig{ + ChatConfig: tgbotapi.ChatConfig{ChatID: chatID, SuperGroupUsername: superGroupUsername}, + } + members, err := bot.GetChatAdministrators(config) + if err != nil { + return nil, fmt.Errorf("telegram get chat administrators: %w", err) + } + limit := directoryLimit(query.Limit) + entries := make([]channel.DirectoryEntry, 0, limit) + for i, m := range members { + if i >= limit { + break + } + if m.User == nil { + continue + } + e := telegramUserToEntry(m.User) + if query.Query != "" && !strings.Contains(strings.ToLower(e.Name+e.Handle), strings.ToLower(query.Query)) { + continue + } + entries = append(entries, e) + } + return entries, nil +} + +// ResolveEntry resolves an input string to a user or group DirectoryEntry using getChat / getChatMember. +func (a *TelegramAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + telegramCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return channel.DirectoryEntry{}, err + } + bot, err := a.getOrCreateBot(telegramCfg.BotToken, cfg.ID) + if err != nil { + return channel.DirectoryEntry{}, err + } + input = strings.TrimSpace(input) + switch kind { + case channel.DirectoryEntryUser: + return a.resolveTelegramUser(ctx, bot, input) + case channel.DirectoryEntryGroup: + return a.resolveTelegramGroup(ctx, bot, input) + default: + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry: unsupported kind %q", kind) + } +} + +func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { + chatID, userID := parseTelegramUserInput(input) + if chatID == 0 && userID == 0 { + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry user: invalid input %q", input) + } + if userID != 0 { + config := tgbotapi.GetChatMemberConfig{ + ChatConfigWithUser: tgbotapi.ChatConfigWithUser{ + ChatID: chatID, + SuperGroupUsername: "", + UserID: userID, + }, + } + member, err := bot.GetChatMember(config) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: %w", err) + } + if member.User == nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: empty user") + } + return telegramUserToEntry(member.User), nil + } + chatConfig := tgbotapi.ChatInfoConfig{ChatConfig: tgbotapi.ChatConfig{ChatID: chatID}} + chat, err := bot.GetChat(chatConfig) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat: %w", err) + } + if !chat.IsPrivate() { + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry user: chat %d is not private", chatID) + } + name := strings.TrimSpace(chat.FirstName + " " + chat.LastName) + if name == "" { + name = strings.TrimSpace(chat.Title) + } + handle := strings.TrimSpace(chat.UserName) + if handle != "" && !strings.HasPrefix(handle, "@") { + handle = "@" + handle + } + idStr := strconv.FormatInt(chat.ID, 10) + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: idStr, + Name: name, + Handle: handle, + Metadata: map[string]any{ + "chat_id": idStr, + "username": chat.UserName, + }, + }, nil +} + +func (a *TelegramAdapter) resolveTelegramGroup(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { + chatID, superGroupUsername := parseTelegramChatInput(input) + if chatID == 0 && superGroupUsername == "" { + return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry group: invalid input %q", input) + } + config := tgbotapi.ChatInfoConfig{ + ChatConfig: tgbotapi.ChatConfig{ChatID: chatID, SuperGroupUsername: superGroupUsername}, + } + chat, err := bot.GetChat(config) + if err != nil { + return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat: %w", err) + } + idStr := strconv.FormatInt(chat.ID, 10) + name := strings.TrimSpace(chat.Title) + if name == "" { + name = strings.TrimSpace(chat.FirstName + " " + chat.LastName) + } + handle := strings.TrimSpace(chat.UserName) + if handle != "" && !strings.HasPrefix(handle, "@") { + handle = "@" + handle + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryGroup, + ID: idStr, + Name: name, + Handle: handle, + Metadata: map[string]any{"chat_id": idStr, "type": chat.Type}, + }, nil +} + +// parseTelegramChatInput parses input as chat_id (numeric) or @channel_username. Returns (chatID, superGroupUsername). +func parseTelegramChatInput(input string) (chatID int64, superGroupUsername string) { + input = strings.TrimSpace(input) + if input == "" { + return 0, "" + } + if strings.HasPrefix(input, "@") { + return 0, input + } + id, err := strconv.ParseInt(input, 10, 64) + if err != nil { + return 0, "" + } + return id, "" +} + +// parseTelegramUserInput parses input as "chat_id" (private chat) or "chat_id:user_id". Returns (chatID, userID); userID 0 means resolve as private chat. +func parseTelegramUserInput(input string) (chatID, userID int64) { + input = strings.TrimSpace(input) + if input == "" { + return 0, 0 + } + if idx := strings.Index(input, ":"); idx >= 0 { + left := strings.TrimSpace(input[:idx]) + right := strings.TrimSpace(input[idx+1:]) + cid, err1 := strconv.ParseInt(left, 10, 64) + uid, err2 := strconv.ParseInt(right, 10, 64) + if err1 == nil && err2 == nil && cid != 0 && uid != 0 { + return cid, uid + } + } + id, err := strconv.ParseInt(input, 10, 64) + if err != nil { + return 0, 0 + } + return id, 0 +} + +func telegramUserToEntry(u *tgbotapi.User) channel.DirectoryEntry { + if u == nil { + return channel.DirectoryEntry{Kind: channel.DirectoryEntryUser} + } + name := strings.TrimSpace(u.FirstName + " " + u.LastName) + handle := strings.TrimSpace(u.UserName) + if handle != "" && !strings.HasPrefix(handle, "@") { + handle = "@" + handle + } + idStr := strconv.FormatInt(u.ID, 10) + meta := map[string]any{"user_id": idStr} + if u.UserName != "" { + meta["username"] = u.UserName + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: idStr, + Name: name, + Handle: handle, + Metadata: meta, + } +} diff --git a/internal/channel/adapters/telegram/directory_test.go b/internal/channel/adapters/telegram/directory_test.go new file mode 100644 index 00000000..ebaf313e --- /dev/null +++ b/internal/channel/adapters/telegram/directory_test.go @@ -0,0 +1,99 @@ +package telegram + +import ( + "strconv" + "testing" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +func Test_directoryLimit(t *testing.T) { + tests := []struct { + name string + n int + want int + }{ + {"zero", 0, defaultDirectoryLimit}, + {"negative", -1, defaultDirectoryLimit}, + {"one", 1, 1}, + {"default", defaultDirectoryLimit, defaultDirectoryLimit}, + {"over max", maxDirectoryLimit + 100, maxDirectoryLimit}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := directoryLimit(tt.n); got != tt.want { + t.Errorf("directoryLimit() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseTelegramChatInput(t *testing.T) { + tests := []struct { + input string + wantID int64 + wantUsername string + }{ + {"123456789", 123456789, ""}, + {" -100123 ", -100123, ""}, + {"@channel", 0, "@channel"}, + {" @supergroup ", 0, "@supergroup"}, + {"", 0, ""}, + {"abc", 0, ""}, + } + for _, tt := range tests { + chatID, username := parseTelegramChatInput(tt.input) + if chatID != tt.wantID || username != tt.wantUsername { + t.Errorf("parseTelegramChatInput(%q) = %d, %q; want %d, %q", tt.input, chatID, username, tt.wantID, tt.wantUsername) + } + } +} + +func Test_parseTelegramUserInput(t *testing.T) { + tests := []struct { + input string + wantChat int64 + wantUser int64 + }{ + {"12345", 12345, 0}, + {" -100 ", -100, 0}, + {"12345:67890", 12345, 67890}, + {" -100 : 200 ", -100, 200}, + {"", 0, 0}, + {"abc", 0, 0}, + {"1:2:3", 0, 0}, + } + for _, tt := range tests { + chatID, userID := parseTelegramUserInput(tt.input) + if chatID != tt.wantChat || userID != tt.wantUser { + t.Errorf("parseTelegramUserInput(%q) = %d, %d; want %d, %d", tt.input, chatID, userID, tt.wantChat, tt.wantUser) + } + } +} + +func Test_telegramUserToEntry(t *testing.T) { + u := &tgbotapi.User{ID: 123, UserName: "alice", FirstName: "Alice", LastName: "Smith"} + e := telegramUserToEntry(u) + if e.Kind != channel.DirectoryEntryUser { + t.Errorf("Kind = %q", e.Kind) + } + if e.ID != strconv.FormatInt(123, 10) { + t.Errorf("ID = %q", e.ID) + } + if e.Name != "Alice Smith" { + t.Errorf("Name = %q", e.Name) + } + if e.Handle != "@alice" { + t.Errorf("Handle = %q", e.Handle) + } + if e.Metadata["user_id"] != "123" || e.Metadata["username"] != "alice" { + t.Errorf("Metadata = %+v", e.Metadata) + } + // nil user + e2 := telegramUserToEntry(nil) + if e2.Kind != channel.DirectoryEntryUser || e2.ID != "" { + t.Errorf("telegramUserToEntry(nil) = %+v", e2) + } +} diff --git a/internal/channel/adapters/telegram/logger.go b/internal/channel/adapters/telegram/logger.go new file mode 100644 index 00000000..e2fa7645 --- /dev/null +++ b/internal/channel/adapters/telegram/logger.go @@ -0,0 +1,19 @@ +package telegram + +import ( + "fmt" + "log/slog" +) + +// slogBotLogger adapts slog.Logger to tgbotapi.BotLogger so library logs go through slog. +type slogBotLogger struct { + log *slog.Logger +} + +func (s *slogBotLogger) Println(v ...interface{}) { + s.log.Warn(fmt.Sprint(v...)) +} + +func (s *slogBotLogger) Printf(format string, v ...interface{}) { + s.log.Warn(fmt.Sprintf(format, v...)) +} diff --git a/internal/channel/adapters/telegram/logger_test.go b/internal/channel/adapters/telegram/logger_test.go new file mode 100644 index 00000000..dbcbb2d6 --- /dev/null +++ b/internal/channel/adapters/telegram/logger_test.go @@ -0,0 +1,46 @@ +package telegram + +import ( + "bytes" + "log/slog" + "testing" +) + +func TestSlogBotLogger_Println(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + log := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + w := &slogBotLogger{log: log} + + w.Println("hello") + if !bytes.Contains(buf.Bytes(), []byte("level=WARN")) { + t.Fatalf("expected WARN level in output: %s", buf.String()) + } + if !bytes.Contains(buf.Bytes(), []byte("hello")) { + t.Fatalf("expected message in output: %s", buf.String()) + } + + buf.Reset() + w.Println("err", 123) + out := buf.String() + if !bytes.Contains(buf.Bytes(), []byte("err")) || !bytes.Contains(buf.Bytes(), []byte("123")) { + t.Fatalf("expected err and 123 in output: %s", out) + } +} + +func TestSlogBotLogger_Printf(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + log := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + w := &slogBotLogger{log: log} + + w.Printf("retrying in %d seconds...", 3) + if !bytes.Contains(buf.Bytes(), []byte("level=WARN")) { + t.Fatalf("expected WARN level: %s", buf.String()) + } + if !bytes.Contains(buf.Bytes(), []byte("retrying in 3 seconds")) { + t.Fatalf("expected formatted message: %s", buf.String()) + } +} diff --git a/internal/channel/adapters/telegram/markdown.go b/internal/channel/adapters/telegram/markdown.go new file mode 100644 index 00000000..a3fba424 --- /dev/null +++ b/internal/channel/adapters/telegram/markdown.go @@ -0,0 +1,210 @@ +package telegram + +import ( + "fmt" + "regexp" + "strings" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + codeBlockPlaceholder = "\x00CB" + inlineCodePlaceholder = "\x00IC" +) + +var ( + reCodeBlockFence = regexp.MustCompile("(?s)```(\\w*)\\n?(.*?)```") + reInlineCode = regexp.MustCompile("`([^`\\n]+?)`") + reBold = regexp.MustCompile(`\*\*(.+?)\*\*`) + reStrike = regexp.MustCompile(`~~(.+?)~~`) + reLink = regexp.MustCompile(`\[([^\]]+?)\]\(([^)]+?)\)`) + reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) + reListBullet = regexp.MustCompile(`(?m)^(\s*)[-+]\s`) + reItalic = regexp.MustCompile(`\*([^*\n]+?)\*`) +) + +// formatTelegramOutput converts standard markdown to Telegram-compatible HTML +// when the message format is markdown. Returns the formatted text and the +// Telegram parse mode to use. +func formatTelegramOutput(text string, format channel.MessageFormat) (string, string) { + if format == channel.MessageFormatMarkdown && strings.TrimSpace(text) != "" { + return markdownToTelegramHTML(text), tgbotapi.ModeHTML + } + return text, "" +} + +// markdownToTelegramHTML converts standard markdown to Telegram-compatible HTML. +// +// Supported conversions: +// - Fenced code blocks (```lang ... ```) →

+//   - Inline code (`code`) → 
+//   - Bold (**text**) → 
+//   - Italic (*text*) → 
+//   - Strikethrough (~~text~~) → 
+//   - Links ([text](url)) → 
+//   - Headings (# text) → 
+//   - Unordered lists (- item) → bullet
+//   - Block quotes (> text) → 
+func markdownToTelegramHTML(text string) string { + if strings.TrimSpace(text) == "" { + return text + } + + // Split by fenced code blocks (``` ... ```). + // Even-indexed segments are normal text, odd-indexed are code content. + segments := splitCodeBlocks(text) + var buf strings.Builder + for i, seg := range segments { + if i%2 == 0 { + buf.WriteString(convertInlineMarkdown(seg)) + } else { + lang, code := extractCodeBlockLang(seg) + escaped := telegramEscapeHTML(strings.TrimRight(code, "\n")) + if lang != "" { + buf.WriteString(fmt.Sprintf("
%s
", lang, escaped)) + } else { + buf.WriteString("
" + escaped + "
") + } + } + } + return strings.TrimSpace(buf.String()) +} + +// splitCodeBlocks splits text by triple-backtick fences. +// Returns alternating [normal, code, normal, code, ...] segments. +func splitCodeBlocks(text string) []string { + const fence = "```" + var segments []string + for { + start := strings.Index(text, fence) + if start < 0 { + segments = append(segments, text) + break + } + segments = append(segments, text[:start]) + rest := text[start+len(fence):] + end := strings.Index(rest, fence) + if end < 0 { + // Unclosed code block: treat remainder as normal text. + segments = append(segments, text[start:]) + // Remove the last normal segment and replace with full remainder. + segments[len(segments)-2] = segments[len(segments)-2] + segments[len(segments)-1] + segments = segments[:len(segments)-1] + break + } + segments = append(segments, rest[:end]) + text = rest[end+len(fence):] + } + return segments +} + +// extractCodeBlockLang separates the optional language tag from code content. +func extractCodeBlockLang(block string) (string, string) { + idx := strings.IndexByte(block, '\n') + if idx < 0 { + // Single line: check if it looks like a language tag. + trimmed := strings.TrimSpace(block) + if trimmed != "" && !strings.Contains(trimmed, " ") && len(trimmed) <= 20 { + return trimmed, "" + } + return "", block + } + firstLine := strings.TrimSpace(block[:idx]) + rest := block[idx+1:] + if firstLine != "" && !strings.Contains(firstLine, " ") && len(firstLine) <= 20 { + return firstLine, rest + } + // No language tag: strip leading newline from content. + return "", strings.TrimLeft(block, "\n") +} + +// convertInlineMarkdown converts inline markdown formatting to Telegram HTML. +func convertInlineMarkdown(text string) string { + if strings.TrimSpace(text) == "" { + return text + } + + // Protect inline code spans from further processing. + var inlineCodes []string + text = reInlineCode.ReplaceAllStringFunc(text, func(match string) string { + idx := len(inlineCodes) + inlineCodes = append(inlineCodes, match) + return fmt.Sprintf("%s%d\x00", inlineCodePlaceholder, idx) + }) + + // Escape HTML entities. + text = telegramEscapeHTML(text) + + // Bold: **text** → text (must run before italic). + text = reBold.ReplaceAllString(text, "$1") + + // Strikethrough: ~~text~~ → text. + text = reStrike.ReplaceAllString(text, "$1") + + // Links: [text](url) →
text. + text = reLink.ReplaceAllString(text, `$1`) + + // Headings: # text → bold line. + text = reHeading.ReplaceAllString(text, "$1") + + // Unordered lists: - item / + item → bullet. + text = reListBullet.ReplaceAllString(text, "${1}• ") + + // Italic: *text* → text (after bold, so ** is already consumed). + text = reItalic.ReplaceAllString(text, "$1") + + // Block quotes: > text →
. + text = convertBlockquotes(text) + + // Restore inline code spans. + for i, original := range inlineCodes { + sub := reInlineCode.FindStringSubmatch(original) + content := "" + if len(sub) >= 2 { + content = sub[1] + } + placeholder := fmt.Sprintf("%s%d\x00", inlineCodePlaceholder, i) + text = strings.Replace(text, placeholder, ""+telegramEscapeHTML(content)+"", 1) + } + + return text +} + +// convertBlockquotes converts markdown block quotes to Telegram HTML blockquotes. +// After HTML escaping, ">" becomes ">", so we match the escaped form. +func convertBlockquotes(text string) string { + lines := strings.Split(text, "\n") + var result []string + var quoteLines []string + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "> ") || trimmed == ">" { + content := strings.TrimPrefix(trimmed, "> ") + if content == ">" { + content = "" + } + quoteLines = append(quoteLines, content) + } else { + if len(quoteLines) > 0 { + result = append(result, "
"+strings.Join(quoteLines, "\n")+"
") + quoteLines = nil + } + result = append(result, line) + } + } + if len(quoteLines) > 0 { + result = append(result, "
"+strings.Join(quoteLines, "\n")+"
") + } + return strings.Join(result, "\n") +} + +// telegramEscapeHTML escapes characters that are special in HTML. +func telegramEscapeHTML(text string) string { + text = strings.ReplaceAll(text, "&", "&") + text = strings.ReplaceAll(text, "<", "<") + text = strings.ReplaceAll(text, ">", ">") + return text +} diff --git a/internal/channel/adapters/telegram/markdown_test.go b/internal/channel/adapters/telegram/markdown_test.go new file mode 100644 index 00000000..e2d1c790 --- /dev/null +++ b/internal/channel/adapters/telegram/markdown_test.go @@ -0,0 +1,209 @@ +package telegram + +import ( + "strings" + "testing" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +func TestMarkdownToTelegramHTML(t *testing.T) { + tests := []struct { + name string + input string + contains []string + absent []string + }{ + { + name: "bold", + input: "this is **bold** text", + contains: []string{"bold"}, + absent: []string{"**"}, + }, + { + name: "italic", + input: "this is *italic* text", + contains: []string{"italic"}, + }, + { + name: "bold and italic", + input: "**bold** and *italic*", + contains: []string{"bold", "italic"}, + }, + { + name: "strikethrough", + input: "this is ~~deleted~~ text", + contains: []string{"deleted"}, + absent: []string{"~~"}, + }, + { + name: "inline code", + input: "use `fmt.Println` here", + contains: []string{"fmt.Println"}, + absent: []string{"`fmt.Println`"}, + }, + { + name: "link", + input: "visit [Google](https://google.com)", + contains: []string{`Google`}, + }, + { + name: "heading", + input: "# Title\nsome text", + contains: []string{"Title"}, + absent: []string{"# Title"}, + }, + { + name: "unordered list", + input: "- first\n- second\n+ third", + contains: []string{"• first", "• second", "• third"}, + }, + { + name: "fenced code block", + input: "```go\nfmt.Println(\"hello\")\n```", + contains: []string{`
`, "fmt.Println", "
"}, + }, + { + name: "fenced code block without language", + input: "```\nplain code\n```", + contains: []string{"
plain code
"}, + }, + { + name: "html entities escaped", + input: "a < b && c > d", + contains: []string{"<", "&&", ">"}, + absent: []string{"< b", "> d"}, + }, + { + name: "code block preserves content", + input: "```\n**not bold** \n```", + contains: []string{"**not bold**", "<tag>"}, + absent: []string{"", ""}, + }, + { + name: "inline code preserves content", + input: "use `**not bold**` inline", + contains: []string{"**not bold**"}, + absent: []string{""}, + }, + { + name: "blockquote", + input: "> quoted line\n> another line", + contains: []string{"
"}, + }, + { + name: "empty input", + input: "", + contains: nil, + }, + { + name: "plain text no conversion", + input: "just plain text here", + contains: []string{"just plain text here"}, + }, + { + name: "link with ampersand in url", + input: "[search](https://example.com?a=1&b=2)", + contains: []string{`search`}, + }, + { + name: "bold inside link", + input: "**[click here](https://example.com)**", + contains: []string{"", `click here`, ""}, + }, + { + name: "mixed formatting", + input: "# Summary\n\n**Hello** world, visit [docs](https://docs.io).\n\n- item one\n- item two\n\n```python\nprint(\"hi\")\n```", + contains: []string{"Summary", "Hello", `docs`, "• item one", `class="language-python"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := markdownToTelegramHTML(tt.input) + for _, want := range tt.contains { + if !strings.Contains(result, want) { + t.Errorf("expected result to contain %q, got:\n%s", want, result) + } + } + for _, absent := range tt.absent { + if strings.Contains(result, absent) { + t.Errorf("expected result NOT to contain %q, got:\n%s", absent, result) + } + } + }) + } +} + +func TestFormatTelegramOutput(t *testing.T) { + tests := []struct { + name string + text string + format channel.MessageFormat + wantMode string + wantContains string + }{ + { + name: "markdown format returns html mode", + text: "**bold**", + format: channel.MessageFormatMarkdown, + wantMode: tgbotapi.ModeHTML, + wantContains: "bold", + }, + { + name: "plain format returns empty mode", + text: "hello", + format: channel.MessageFormatPlain, + wantMode: "", + }, + { + name: "empty text returns empty mode", + text: "", + format: channel.MessageFormatMarkdown, + wantMode: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + text, mode := formatTelegramOutput(tt.text, tt.format) + if mode != tt.wantMode { + t.Errorf("expected mode %q, got %q", tt.wantMode, mode) + } + if tt.wantContains != "" && !strings.Contains(text, tt.wantContains) { + t.Errorf("expected text to contain %q, got %q", tt.wantContains, text) + } + }) + } +} + +func TestSplitCodeBlocks(t *testing.T) { + tests := []struct { + name string + input string + want int // expected number of segments + }{ + {name: "no code blocks", input: "hello world", want: 1}, + {name: "one code block", input: "before```code```after", want: 3}, + {name: "two code blocks", input: "a```b```c```d```e", want: 5}, + {name: "unclosed code block", input: "before```unclosed", want: 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + segments := splitCodeBlocks(tt.input) + if len(segments) != tt.want { + t.Errorf("expected %d segments, got %d: %v", tt.want, len(segments), segments) + } + }) + } +} + +func TestTelegramEscapeHTML(t *testing.T) { + input := `a < b & c > d` + result := telegramEscapeHTML(input) + if result != "a < b & c > d" { + t.Errorf("unexpected escape result: %s", result) + } +} diff --git a/internal/channel/adapters/telegram/stream.go b/internal/channel/adapters/telegram/stream.go new file mode 100644 index 00000000..771d565a --- /dev/null +++ b/internal/channel/adapters/telegram/stream.go @@ -0,0 +1,211 @@ +package telegram + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "sync/atomic" + "time" + + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + + "github.com/memohai/memoh/internal/channel" +) + +const telegramStreamEditThrottle = 350 * time.Millisecond + +type telegramOutboundStream struct { + adapter *TelegramAdapter + cfg channel.ChannelConfig + target string + reply *channel.ReplyRef + parseMode string + closed atomic.Bool + mu sync.Mutex + buf strings.Builder + streamChatID int64 + streamMsgID int + lastEdited string + lastEditedAt time.Time +} + +func (s *telegramOutboundStream) getBotAndReply(ctx context.Context) (bot *tgbotapi.BotAPI, replyTo int, err error) { + telegramCfg, err := parseConfig(s.cfg.Credentials) + if err != nil { + return nil, 0, err + } + bot, err = s.adapter.getOrCreateBot(telegramCfg.BotToken, s.cfg.ID) + if err != nil { + return nil, 0, err + } + replyTo = parseReplyToMessageID(s.reply) + return bot, replyTo, nil +} + +func (s *telegramOutboundStream) ensureStreamMessage(ctx context.Context, text string) error { + s.mu.Lock() + if s.streamMsgID != 0 { + s.mu.Unlock() + return nil + } + bot, replyTo, err := s.getBotAndReply(ctx) + if err != nil { + s.mu.Unlock() + return err + } + if strings.TrimSpace(text) == "" { + text = "..." + } + chatID, msgID, err := sendTelegramTextReturnMessage(bot, s.target, text, replyTo, s.parseMode) + if err != nil { + s.mu.Unlock() + return err + } + s.streamChatID = chatID + s.streamMsgID = msgID + s.lastEdited = text + s.lastEditedAt = time.Now() + s.mu.Unlock() + return nil +} + +func (s *telegramOutboundStream) editStreamMessage(ctx context.Context, text string) error { + s.mu.Lock() + chatID := s.streamChatID + msgID := s.streamMsgID + lastEdited := s.lastEdited + lastEditedAt := s.lastEditedAt + s.mu.Unlock() + if msgID == 0 { + return nil + } + if strings.TrimSpace(text) == lastEdited { + return nil + } + if time.Since(lastEditedAt) < telegramStreamEditThrottle && !strings.Contains(text, "\n") { + return nil + } + bot, _, err := s.getBotAndReply(ctx) + if err != nil { + return err + } + if err := editTelegramMessageText(bot, chatID, msgID, text, s.parseMode); err != nil { + return err + } + s.mu.Lock() + s.lastEdited = text + s.lastEditedAt = time.Now() + s.mu.Unlock() + return nil +} + +func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.adapter == nil { + return fmt.Errorf("telegram stream not configured") + } + if s.closed.Load() { + return fmt.Errorf("telegram stream is closed") + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + switch event.Type { + case channel.StreamEventStatus: + return nil + case channel.StreamEventDelta: + if event.Delta == "" { + return nil + } + s.mu.Lock() + s.buf.WriteString(event.Delta) + content := s.buf.String() + s.mu.Unlock() + if err := s.ensureStreamMessage(ctx, content); err != nil { + return err + } + return s.editStreamMessage(ctx, content) + case channel.StreamEventFinal: + if event.Final == nil || event.Final.Message.IsEmpty() { + s.mu.Lock() + finalText := strings.TrimSpace(s.buf.String()) + s.mu.Unlock() + if finalText != "" { + _ = s.ensureStreamMessage(ctx, finalText) + _ = s.editStreamMessage(ctx, finalText) + } + return nil + } + msg := event.Final.Message + finalText := strings.TrimSpace(msg.PlainText()) + s.mu.Lock() + if finalText == "" { + finalText = strings.TrimSpace(s.buf.String()) + } + s.mu.Unlock() + // Convert markdown to Telegram HTML for the final message. + formatted, pm := formatTelegramOutput(finalText, msg.Format) + if pm != "" { + s.mu.Lock() + s.parseMode = pm + s.mu.Unlock() + finalText = formatted + } + if err := s.ensureStreamMessage(ctx, finalText); err != nil { + return err + } + if err := s.editStreamMessage(ctx, finalText); err != nil { + return err + } + if len(msg.Attachments) > 0 { + replyTo := parseReplyToMessageID(s.reply) + telegramCfg, err := parseConfig(s.cfg.Credentials) + if err != nil { + return err + } + bot, err := s.adapter.getOrCreateBot(telegramCfg.BotToken, s.cfg.ID) + if err != nil { + return err + } + parseMode := resolveTelegramParseMode(msg.Format) + for i, att := range msg.Attachments { + rto := replyTo + if i > 0 { + rto = 0 + } + if err := sendTelegramAttachment(bot, s.target, att, "", rto, parseMode); err != nil && s.adapter.logger != nil { + s.adapter.logger.Error("stream final attachment failed", slog.String("config_id", s.cfg.ID), slog.Any("error", err)) + } + } + } + return nil + case channel.StreamEventError: + errText := strings.TrimSpace(event.Error) + if errText == "" { + return nil + } + display := "Error: " + errText + if err := s.ensureStreamMessage(ctx, display); err != nil { + return err + } + return s.editStreamMessage(ctx, display) + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } +} + +func (s *telegramOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} diff --git a/internal/channel/adapters/telegram/stream_test.go b/internal/channel/adapters/telegram/stream_test.go new file mode 100644 index 00000000..56ea4032 --- /dev/null +++ b/internal/channel/adapters/telegram/stream_test.go @@ -0,0 +1,119 @@ +package telegram + +import ( + "context" + "strings" + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +func TestTelegramOutboundStream_CloseNil(t *testing.T) { + t.Parallel() + + var s *telegramOutboundStream + ctx := context.Background() + if err := s.Close(ctx); err != nil { + t.Fatalf("Close on nil stream should return nil: %v", err) + } +} + +func TestTelegramOutboundStream_PushClosed(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + s.closed.Store(true) + + ctx := context.Background() + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "x"}) + if err == nil { + t.Fatal("Push on closed stream should return error") + } + if !strings.Contains(err.Error(), "closed") { + t.Fatalf("expected closed error: %v", err) + } +} + +func TestTelegramOutboundStream_PushStatusNoOp(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + + ctx := context.Background() + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventStatus}) + if err != nil { + t.Fatalf("StreamEventStatus should be no-op: %v", err) + } +} + +func TestTelegramOutboundStream_PushNilAdapter(t *testing.T) { + t.Parallel() + + s := &telegramOutboundStream{adapter: nil} + ctx := context.Background() + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "x"}) + if err == nil { + t.Fatal("Push with nil adapter should return error") + } + if !strings.Contains(err.Error(), "not configured") { + t.Fatalf("expected not configured error: %v", err) + } +} + +func TestTelegramOutboundStream_PushUnsupportedEventType(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx := context.Background() + + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventType("unknown")}) + if err == nil { + t.Fatal("Push with unknown event type should return error") + } + if !strings.Contains(err.Error(), "unsupported") { + t.Fatalf("expected unsupported error: %v", err) + } +} + +func TestTelegramOutboundStream_PushEmptyDeltaNoOp(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx := context.Background() + + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: ""}) + if err != nil { + t.Fatalf("empty delta should be no-op: %v", err) + } +} + +func TestTelegramOutboundStream_PushErrorEventEmptyNoOp(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx := context.Background() + + err := s.Push(ctx, channel.StreamEvent{Type: channel.StreamEventError, Error: ""}) + if err != nil { + t.Fatalf("empty error event should be no-op: %v", err) + } +} + +func TestTelegramOutboundStream_CloseContextCanceled(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + s := &telegramOutboundStream{adapter: adapter} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := s.Close(ctx) + if err != context.Canceled { + t.Fatalf("Close with canceled context should return context.Canceled: %v", err) + } +} diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index 90e46062..1c1544de 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -15,6 +15,8 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/common" ) +const telegramMaxMessageLength = 4096 + // TelegramAdapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Telegram. type TelegramAdapter struct { logger *slog.Logger @@ -27,10 +29,12 @@ func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { if log == nil { log = slog.Default() } - return &TelegramAdapter{ + adapter := &TelegramAdapter{ logger: log.With(slog.String("adapter", "telegram")), bots: make(map[string]*tgbotapi.BotAPI), } + _ = tgbotapi.SetLogger(&slogBotLogger{log: adapter.logger}) + return adapter } func (a *TelegramAdapter) getOrCreateBot(token, configID string) (*tgbotapi.BotAPI, error) { @@ -67,11 +71,13 @@ func (a *TelegramAdapter) Descriptor() channel.Descriptor { Type: Type, DisplayName: "Telegram", Capabilities: channel.ChannelCapabilities{ - Text: true, - Markdown: true, - Reply: true, - Attachments: true, - Media: true, + Text: true, + Markdown: true, + Reply: true, + Attachments: true, + Media: true, + Streaming: true, + BlockStreaming: true, }, ConfigSchema: channel.ConfigSchema{ Version: 1, @@ -159,10 +165,6 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig for { select { case <-connCtx.Done(): - if a.logger != nil { - a.logger.Info("stop", slog.String("config_id", cfg.ID)) - } - bot.StopReceivingUpdates() return case update, ok := <-updates: if !ok { @@ -183,7 +185,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig if text == "" && len(attachments) == 0 { continue } - externalID, displayName, attrs := resolveTelegramSender(update.Message) + subjectID, displayName, attrs := resolveTelegramSender(update.Message) chatID := "" chatType := "" chatName := "" @@ -193,6 +195,10 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig chatName = strings.TrimSpace(update.Message.Chat.Title) } replyRef := buildTelegramReplyRef(update.Message, chatID) + isReplyToBot := update.Message.ReplyToMessage != nil && + update.Message.ReplyToMessage.From != nil && + update.Message.ReplyToMessage.From.IsBot + isMentioned := isTelegramBotMentioned(update.Message, bot.Self.UserName) msg := channel.InboundMessage{ Channel: Type, Message: channel.Message{ @@ -205,7 +211,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig BotID: cfg.BotID, ReplyTarget: chatID, Sender: channel.Identity{ - ExternalID: externalID, + SubjectID: subjectID, DisplayName: displayName, Attributes: attrs, }, @@ -216,6 +222,10 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig }, ReceivedAt: time.Unix(int64(update.Message.Date), 0).UTC(), Source: "telegram", + Metadata: map[string]any{ + "is_mentioned": isMentioned, + "is_reply_to_bot": isReplyToBot, + }, } if a.logger != nil { a.logger.Info( @@ -237,12 +247,19 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig } }() - stop := func(context.Context) error { + stop := func(_ context.Context) error { if a.logger != nil { a.logger.Info("stop", slog.String("config_id", cfg.ID)) } - cancel() bot.StopReceivingUpdates() + cancel() + // Drain remaining updates so the library's polling goroutine can + // finish writing and exit. Without this, the in-flight long-poll + // HTTP request keeps the old getUpdates session alive, causing + // "Conflict: terminated by other getUpdates request" when a new + // connection starts with the same bot token. + for range updates { + } return nil } return channel.NewConnection(cfg, stop), nil @@ -269,7 +286,7 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m return fmt.Errorf("message is required") } text := strings.TrimSpace(msg.Message.PlainText()) - parseMode := resolveTelegramParseMode(msg.Message.Format) + text, parseMode := formatTelegramOutput(text, msg.Message.Format) replyTo := parseReplyToMessageID(msg.Message.Reply) if len(msg.Message.Attachments) > 0 { usedCaption := false @@ -298,6 +315,28 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m return sendTelegramText(bot, to, text, replyTo, parseMode) } +// OpenStream opens a Telegram streaming session. +// The adapter sends one message then edits it in place as deltas arrive (editMessageText), +// avoiding one message per delta and rate limits. +func (a *TelegramAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("telegram target is required") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return &telegramOutboundStream{ + adapter: a, + cfg: cfg, + target: target, + reply: opts.Reply, + parseMode: "", + }, nil +} + func resolveTelegramSender(msg *tgbotapi.Message) (string, string, map[string]string) { attrs := map[string]string{} if msg == nil { @@ -368,36 +407,69 @@ func parseReplyToMessageID(reply *channel.ReplyRef) int { } func sendTelegramText(bot *tgbotapi.BotAPI, target string, text string, replyTo int, parseMode string) error { + _, _, err := sendTelegramTextReturnMessage(bot, target, text, replyTo, parseMode) + return err +} + +// sendTelegramTextReturnMessage sends a text message and returns the chat ID and message ID for later editing. +func sendTelegramTextReturnMessage(bot *tgbotapi.BotAPI, target string, text string, replyTo int, parseMode string) (chatID int64, messageID int, err error) { + var sent tgbotapi.Message if strings.HasPrefix(target, "@") { message := tgbotapi.NewMessageToChannel(target, text) message.ParseMode = parseMode if replyTo > 0 { message.ReplyToMessageID = replyTo } - _, err := bot.Send(message) - return err + sent, err = bot.Send(message) + if err != nil { + return 0, 0, err + } + } else { + chatID, err = strconv.ParseInt(target, 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("telegram target must be @username or chat_id") + } + message := tgbotapi.NewMessage(chatID, text) + message.ParseMode = parseMode + if replyTo > 0 { + message.ReplyToMessageID = replyTo + } + sent, err = bot.Send(message) + if err != nil { + return 0, 0, err + } } - chatID, err := strconv.ParseInt(target, 10, 64) - if err != nil { - return fmt.Errorf("telegram target must be @username or chat_id") + if sent.Chat != nil { + chatID = sent.Chat.ID } - message := tgbotapi.NewMessage(chatID, text) - message.ParseMode = parseMode - if replyTo > 0 { - message.ReplyToMessageID = replyTo + messageID = sent.MessageID + return chatID, messageID, nil +} + +func editTelegramMessageText(bot *tgbotapi.BotAPI, chatID int64, messageID int, text string, parseMode string) error { + if len(text) > telegramMaxMessageLength { + text = text[:telegramMaxMessageLength-3] + "..." } - _, err = bot.Send(message) + edit := tgbotapi.NewEditMessageText(chatID, messageID, text) + edit.ParseMode = parseMode + _, err := bot.Send(edit) return err } func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Attachment, caption string, replyTo int, parseMode string) error { - if strings.TrimSpace(att.URL) == "" { - return fmt.Errorf("attachment url is required") + urlRef := strings.TrimSpace(att.URL) + keyRef := strings.TrimSpace(att.PlatformKey) + sourcePlatform := strings.TrimSpace(att.SourcePlatform) + if urlRef == "" && keyRef == "" { + return fmt.Errorf("attachment reference is required") } if strings.TrimSpace(caption) == "" && strings.TrimSpace(att.Caption) != "" { caption = strings.TrimSpace(att.Caption) } - file := tgbotapi.FileURL(att.URL) + file := tgbotapi.RequestFileData(tgbotapi.FileURL(urlRef)) + if keyRef != "" && (sourcePlatform == "" || strings.EqualFold(sourcePlatform, Type.String())) { + file = tgbotapi.FileID(keyRef) + } isChannel := strings.HasPrefix(target, "@") switch att.Type { case channel.AttachmentImage: @@ -565,6 +637,33 @@ func resolveTelegramParseMode(format channel.MessageFormat) string { } } +func isTelegramBotMentioned(msg *tgbotapi.Message, botUsername string) bool { + if msg == nil { + return false + } + normalizedBot := strings.ToLower(strings.TrimPrefix(strings.TrimSpace(botUsername), "@")) + if normalizedBot != "" { + text := strings.TrimSpace(msg.Text) + if text == "" { + text = strings.TrimSpace(msg.Caption) + } + if text != "" { + if strings.Contains(strings.ToLower(text), "@"+normalizedBot) { + return true + } + } + } + entities := make([]tgbotapi.MessageEntity, 0, len(msg.Entities)+len(msg.CaptionEntities)) + entities = append(entities, msg.Entities...) + entities = append(entities, msg.CaptionEntities...) + for _, entity := range entities { + if entity.Type == "text_mention" && entity.User != nil && entity.User.IsBot { + return true + } + } + return false +} + func (a *TelegramAdapter) collectTelegramAttachments(bot *tgbotapi.BotAPI, msg *tgbotapi.Message) []channel.Attachment { if msg == nil { return nil @@ -633,12 +732,14 @@ func (a *TelegramAdapter) buildTelegramAttachment(bot *tgbotapi.BotAPI, attType } } att := channel.Attachment{ - Type: attType, - URL: strings.TrimSpace(url), - Name: strings.TrimSpace(name), - Mime: strings.TrimSpace(mime), - Size: size, - Metadata: map[string]any{}, + Type: attType, + URL: strings.TrimSpace(url), + PlatformKey: strings.TrimSpace(fileID), + SourcePlatform: Type.String(), + Name: strings.TrimSpace(name), + Mime: strings.TrimSpace(mime), + Size: size, + Metadata: map[string]any{}, } if fileID != "" { att.Metadata["file_id"] = fileID diff --git a/internal/channel/adapters/telegram/telegram_test.go b/internal/channel/adapters/telegram/telegram_test.go index 6d3a5834..b3c4ca46 100644 --- a/internal/channel/adapters/telegram/telegram_test.go +++ b/internal/channel/adapters/telegram/telegram_test.go @@ -1,9 +1,12 @@ package telegram import ( + "context" + "strings" "testing" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/memohai/memoh/internal/channel" ) func TestResolveTelegramSender(t *testing.T) { @@ -24,3 +27,280 @@ func TestResolveTelegramSender(t *testing.T) { t.Fatalf("unexpected attrs: %#v", attrs) } } + +func TestIsTelegramBotMentioned(t *testing.T) { + t.Parallel() + + t.Run("text mention", func(t *testing.T) { + t.Parallel() + msg := &tgbotapi.Message{ + Text: "hello @MemohBot", + } + if !isTelegramBotMentioned(msg, "memohbot") { + t.Fatalf("expected bot mention from text") + } + }) + + t.Run("entity text mention", func(t *testing.T) { + t.Parallel() + msg := &tgbotapi.Message{ + Entities: []tgbotapi.MessageEntity{ + { + Type: "text_mention", + User: &tgbotapi.User{IsBot: true}, + }, + }, + } + if !isTelegramBotMentioned(msg, "") { + t.Fatalf("expected bot mention from text_mention entity") + } + }) + + t.Run("not mentioned", func(t *testing.T) { + t.Parallel() + msg := &tgbotapi.Message{ + Text: "hello everyone", + } + if isTelegramBotMentioned(msg, "memohbot") { + t.Fatalf("expected no mention") + } + }) +} + +func TestTelegramDescriptorIncludesStreaming(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + caps := adapter.Descriptor().Capabilities + if !caps.Streaming { + t.Fatal("expected streaming capability") + } + if !caps.Media { + t.Fatal("expected media capability") + } +} + +func TestBuildTelegramAttachmentIncludesPlatformReference(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + att := adapter.buildTelegramAttachment(nil, channel.AttachmentFile, "file_1", "doc.txt", "text/plain", 10) + if att.PlatformKey != "file_1" { + t.Fatalf("unexpected platform key: %s", att.PlatformKey) + } + if att.SourcePlatform != Type.String() { + t.Fatalf("unexpected source platform: %s", att.SourcePlatform) + } +} + +func TestParseReplyToMessageID(t *testing.T) { + t.Parallel() + + if got := parseReplyToMessageID(nil); got != 0 { + t.Fatalf("nil reply should return 0: %d", got) + } + if got := parseReplyToMessageID(&channel.ReplyRef{}); got != 0 { + t.Fatalf("empty MessageID should return 0: %d", got) + } + if got := parseReplyToMessageID(&channel.ReplyRef{MessageID: " 123 "}); got != 123 { + t.Fatalf("expected 123: %d", got) + } + if got := parseReplyToMessageID(&channel.ReplyRef{MessageID: "abc"}); got != 0 { + t.Fatalf("invalid number should return 0: %d", got) + } +} + +func TestResolveTelegramParseMode(t *testing.T) { + t.Parallel() + + if got := resolveTelegramParseMode(channel.MessageFormatMarkdown); got != tgbotapi.ModeMarkdown { + t.Fatalf("markdown should return ModeMarkdown: %s", got) + } + if got := resolveTelegramParseMode(channel.MessageFormatPlain); got != "" { + t.Fatalf("plain should return empty: %s", got) + } + if got := resolveTelegramParseMode(channel.MessageFormatRich); got != "" { + t.Fatalf("rich should return empty: %s", got) + } +} + +func TestBuildTelegramReplyRef(t *testing.T) { + t.Parallel() + + if buildTelegramReplyRef(nil, "123") != nil { + t.Fatal("nil msg should return nil") + } + msg := &tgbotapi.Message{} + if buildTelegramReplyRef(msg, "123") != nil { + t.Fatal("msg without ReplyToMessage should return nil") + } + msg.ReplyToMessage = &tgbotapi.Message{MessageID: 42} + ref := buildTelegramReplyRef(msg, " -100 ") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.MessageID != "42" || ref.Target != "-100" { + t.Fatalf("unexpected ref: %+v", ref) + } +} + +func TestPickTelegramPhoto(t *testing.T) { + t.Parallel() + + if got := pickTelegramPhoto(nil); got.FileID != "" { + t.Fatalf("nil should return zero: %+v", got) + } + if got := pickTelegramPhoto([]tgbotapi.PhotoSize{}); got.FileID != "" { + t.Fatalf("empty slice should return zero: %+v", got) + } + one := tgbotapi.PhotoSize{FileID: "a", FileSize: 100, Width: 10, Height: 10} + if got := pickTelegramPhoto([]tgbotapi.PhotoSize{one}); got.FileID != "a" { + t.Fatalf("single photo should return it: %+v", got) + } + photos := []tgbotapi.PhotoSize{ + {FileID: "small", FileSize: 100, Width: 100, Height: 100}, + {FileID: "large", FileSize: 500, Width: 200, Height: 200}, + } + if got := pickTelegramPhoto(photos); got.FileID != "large" { + t.Fatalf("should pick largest by size: %+v", got) + } +} + +func TestTelegramAdapter_Type(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + if adapter.Type() != Type { + t.Fatalf("Type should return telegram: %s", adapter.Type()) + } +} + +func TestTelegramAdapter_OpenStreamEmptyTarget(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + ctx := context.Background() + cfg := channel.ChannelConfig{} + _, err := adapter.OpenStream(ctx, cfg, "", channel.StreamOptions{}) + if err == nil { + t.Fatal("empty target should return error") + } + if !strings.Contains(err.Error(), "target") { + t.Fatalf("expected target in error: %v", err) + } +} + +func TestResolveTelegramSender_SenderChat(t *testing.T) { + t.Parallel() + + msg := &tgbotapi.Message{ + SenderChat: &tgbotapi.Chat{ID: 456, UserName: "group", Title: "My Group"}, + } + externalID, displayName, attrs := resolveTelegramSender(msg) + if externalID != "456" { + t.Fatalf("unexpected externalID: %s", externalID) + } + if displayName != "My Group" { + t.Fatalf("unexpected displayName: %s", displayName) + } + if attrs["sender_chat_id"] != "456" || attrs["sender_chat_username"] != "group" { + t.Fatalf("unexpected attrs: %#v", attrs) + } +} + +func TestBuildTelegramAudio(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramAudio("@channel", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@channel" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramAudio("invalid", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } + if !strings.Contains(err.Error(), "chat_id") { + t.Fatalf("expected chat_id in error: %v", err) + } +} + +func TestBuildTelegramVoice(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramVoice("@ch", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@ch" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramVoice("x", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } +} + +func TestBuildTelegramVideo(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramVideo("@ch", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@ch" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramVideo("bad", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } +} + +func TestBuildTelegramAnimation(t *testing.T) { + t.Parallel() + + cfg, err := buildTelegramAnimation("@ch", tgbotapi.FileID("f1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.ChannelUsername != "@ch" { + t.Fatalf("unexpected channel: %s", cfg.ChannelUsername) + } + _, err = buildTelegramAnimation("x", tgbotapi.FileID("f1")) + if err == nil { + t.Fatal("invalid target should return error") + } +} + +func TestTelegramAdapter_NormalizeAndResolve(t *testing.T) { + t.Parallel() + + adapter := NewTelegramAdapter(nil) + norm, err := adapter.NormalizeConfig(map[string]any{"botToken": "t1"}) + if err != nil { + t.Fatalf("NormalizeConfig: %v", err) + } + if norm["botToken"] != "t1" { + t.Fatalf("unexpected normalized: %#v", norm) + } + userNorm, err := adapter.NormalizeUserConfig(map[string]any{"username": "u1"}) + if err != nil { + t.Fatalf("NormalizeUserConfig: %v", err) + } + if userNorm["username"] != "u1" { + t.Fatalf("unexpected user config: %#v", userNorm) + } + if got := adapter.NormalizeTarget("https://t.me/x"); got != "@x" { + t.Fatalf("NormalizeTarget: %s", got) + } + target, err := adapter.ResolveTarget(map[string]any{"chat_id": "123"}) + if err != nil { + t.Fatalf("ResolveTarget: %v", err) + } + if target != "123" { + t.Fatalf("ResolveTarget: %s", target) + } +} diff --git a/internal/channel/capabilities.go b/internal/channel/capabilities.go index a1cf379c..2de4c90c 100644 --- a/internal/channel/capabilities.go +++ b/internal/channel/capabilities.go @@ -3,20 +3,20 @@ package channel // ChannelCapabilities describes the feature matrix of a channel type. // It is used by the outbound layer to validate message content before delivery. type ChannelCapabilities struct { - Text bool `json:"text"` - Markdown bool `json:"markdown"` - RichText bool `json:"rich_text"` - Attachments bool `json:"attachments"` - Media bool `json:"media"` - Reactions bool `json:"reactions"` - Buttons bool `json:"buttons"` - Reply bool `json:"reply"` - Threads bool `json:"threads"` - Streaming bool `json:"streaming"` - Polls bool `json:"polls"` - Edit bool `json:"edit"` - Unsend bool `json:"unsend"` - NativeCommands bool `json:"native_commands"` - BlockStreaming bool `json:"block_streaming"` - ChatTypes []string `json:"chat_types,omitempty"` + Text bool `json:"text"` + Markdown bool `json:"markdown"` + RichText bool `json:"rich_text"` + Attachments bool `json:"attachments"` + Media bool `json:"media"` + Reactions bool `json:"reactions"` + Buttons bool `json:"buttons"` + Reply bool `json:"reply"` + Threads bool `json:"threads"` + Streaming bool `json:"streaming"` + Polls bool `json:"polls"` + Edit bool `json:"edit"` + Unsend bool `json:"unsend"` + NativeCommands bool `json:"native_commands"` + BlockStreaming bool `json:"block_streaming"` + ChatTypes []string `json:"chat_types,omitempty"` } diff --git a/internal/channel/config_test.go b/internal/channel/config_test.go index 837b3b53..b71e232e 100644 --- a/internal/channel/config_test.go +++ b/internal/channel/config_test.go @@ -63,7 +63,7 @@ func (a *testConfigAdapter) ResolveTarget(raw map[string]any) (string, error) { func (a *testConfigAdapter) MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { value := channel.ReadString(raw, "user") - return value != "" && value == criteria.ExternalID + return value != "" && value == criteria.SubjectID } func (a *testConfigAdapter) BuildUserConfig(identity channel.Identity) map[string]any { diff --git a/internal/channel/connection.go b/internal/channel/connection.go index ffcf75c1..a64a476f 100644 --- a/internal/channel/connection.go +++ b/internal/channel/connection.go @@ -14,6 +14,13 @@ type connectionEntry struct { } func (m *Manager) refresh(ctx context.Context) { + // Serialize refresh calls to prevent concurrent reconcile from starting + // duplicate adapter connections. + if !m.refreshMu.TryLock() { + return + } + defer m.refreshMu.Unlock() + if m.service == nil { return } @@ -75,29 +82,41 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error m.mu.Lock() entry := m.connections[cfg.ID] + + // Config unchanged — nothing to do. if entry != nil && !entry.config.UpdatedAt.Before(cfg.UpdatedAt) { m.mu.Unlock() return nil } + + // Need to stop existing connection before starting a new one. + // Keep the lock to prevent another goroutine from starting a duplicate. + var oldConn Connection if entry != nil { - m.mu.Unlock() + oldConn = entry.connection + delete(m.connections, cfg.ID) + } + m.mu.Unlock() + + if oldConn != nil { if m.logger != nil { m.logger.Info("adapter restart", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } - if err := entry.connection.Stop(ctx); err != nil { + if err := oldConn.Stop(ctx); err != nil { if errors.Is(err, ErrStopNotSupported) { if m.logger != nil { m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } + // Re-insert the entry since we can't restart it. + m.mu.Lock() + if _, exists := m.connections[cfg.ID]; !exists { + m.connections[cfg.ID] = entry + } + m.mu.Unlock() return nil } return err } - m.mu.Lock() - delete(m.connections, cfg.ID) - m.mu.Unlock() - } else { - m.mu.Unlock() } receiver, ok := m.registry.GetReceiver(cfg.ChannelType) @@ -105,6 +124,15 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error return nil } + // Double-check: another goroutine may have already started a connection + // for this config while we were stopping the old one. + m.mu.Lock() + if existing, ok := m.connections[cfg.ID]; ok && existing != nil { + m.mu.Unlock() + return nil + } + m.mu.Unlock() + if m.logger != nil { m.logger.Info("adapter start", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } @@ -116,7 +144,15 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error if err != nil { return err } + m.mu.Lock() + // Final check: if another goroutine raced and inserted first, stop our new + // connection and keep the existing one. + if existing, ok := m.connections[cfg.ID]; ok && existing != nil { + m.mu.Unlock() + _ = conn.Stop(ctx) + return nil + } m.connections[cfg.ID] = &connectionEntry{ config: cfg, connection: conn, diff --git a/internal/channel/directory.go b/internal/channel/directory.go index fa82e98e..32457912 100644 --- a/internal/channel/directory.go +++ b/internal/channel/directory.go @@ -32,5 +32,5 @@ type ChannelDirectoryAdapter interface { ListPeers(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) ListGroups(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) ListGroupMembers(ctx context.Context, cfg ChannelConfig, groupID string, query DirectoryQuery) ([]DirectoryEntry, error) - ResolveTarget(ctx context.Context, cfg ChannelConfig, input string, kind DirectoryEntryKind) (DirectoryEntry, error) + ResolveEntry(ctx context.Context, cfg ChannelConfig, input string, kind DirectoryEntryKind) (DirectoryEntry, error) } diff --git a/internal/channel/helpers_test.go b/internal/channel/helpers_test.go index de50b163..50434c57 100644 --- a/internal/channel/helpers_test.go +++ b/internal/channel/helpers_test.go @@ -55,13 +55,52 @@ func TestBindingCriteriaFromIdentity(t *testing.T) { t.Parallel() criteria := BindingCriteriaFromIdentity(Identity{ - ExternalID: "u1", + SubjectID: "u1", Attributes: map[string]string{"username": "alice"}, }) - if criteria.ExternalID != "u1" { - t.Fatalf("unexpected external id: %s", criteria.ExternalID) + if criteria.SubjectID != "u1" { + t.Fatalf("unexpected subject id: %s", criteria.SubjectID) } if criteria.Attribute("username") != "alice" { t.Fatalf("unexpected username: %s", criteria.Attribute("username")) } } + +func TestNormalizeChannelConfigStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {name: "default pending", input: "", want: "pending"}, + {name: "pending passthrough", input: "pending", want: "pending"}, + {name: "verified passthrough", input: "verified", want: "verified"}, + {name: "disabled passthrough", input: "disabled", want: "disabled"}, + {name: "active alias", input: "active", want: "verified"}, + {name: "inactive alias", input: "inactive", want: "disabled"}, + {name: "unknown status", input: "paused", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := normalizeChannelConfigStatus(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got != tt.want { + t.Fatalf("unexpected status: got %s, want %s", got, tt.want) + } + }) + } +} diff --git a/internal/channel/identities/service.go b/internal/channel/identities/service.go new file mode 100644 index 00000000..b5f0126d --- /dev/null +++ b/internal/channel/identities/service.go @@ -0,0 +1,291 @@ +package identities + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +// Service provides channel identity lifecycle operations. +type Service struct { + queries *sqlc.Queries + logger *slog.Logger +} + +var ( + ErrChannelIdentityNotFound = errors.New("channel identity not found") +) + +// NewService creates a new channel identity service. +func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + queries: queries, + logger: log.With(slog.String("service", "channel/identities")), + } +} + +// Create creates a new channel identity for the given channel subject. +func (s *Service) Create(ctx context.Context, channel, channelSubjectID, displayName string) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + channel = normalizeChannel(channel) + channelSubjectID = strings.TrimSpace(channelSubjectID) + if channel == "" || channelSubjectID == "" { + return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + } + row, err := s.queries.CreateChannelIdentity(ctx, sqlc.CreateChannelIdentityParams{ + UserID: pgtype.UUID{}, + ChannelType: channel, + ChannelSubjectID: channelSubjectID, + DisplayName: toPgText(displayName), + Metadata: emptyMetadataBytes(), + }) + if err != nil { + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// GetByID returns a channel identity by its ID. +func (s *Service) GetByID(ctx context.Context, channelIdentityID string) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + pgID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return ChannelIdentity{}, err + } + row, err := s.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ChannelIdentity{}, ErrChannelIdentityNotFound + } + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// Canonicalize validates and returns the same channel identity ID. +func (s *Service) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { + if s.queries == nil { + return "", fmt.Errorf("channel identity queries not configured") + } + pgID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return "", err + } + _, err = s.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", ErrChannelIdentityNotFound + } + return "", err + } + return channelIdentityID, nil +} + +// ResolveByChannelIdentity looks up or creates a channel identity for (channel, channel_subject_id). +func (s *Service) ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + channel = normalizeChannel(channel) + channelSubjectID = strings.TrimSpace(channelSubjectID) + if channel == "" || channelSubjectID == "" { + return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + } + + row, err := s.queries.UpsertChannelIdentityByChannelSubject(ctx, sqlc.UpsertChannelIdentityByChannelSubjectParams{ + UserID: pgtype.UUID{}, + ChannelType: channel, + ChannelSubjectID: channelSubjectID, + DisplayName: toPgText(displayName), + Metadata: emptyMetadataBytes(), + }) + if err != nil { + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// UpsertChannelIdentity creates or updates a channel identity mapping. +func (s *Service) UpsertChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string, metadata map[string]any) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + channel = normalizeChannel(channel) + channelSubjectID = strings.TrimSpace(channelSubjectID) + if metadata == nil { + metadata = map[string]any{} + } + metaBytes, err := json.Marshal(metadata) + if err != nil { + return ChannelIdentity{}, err + } + row, err := s.queries.UpsertChannelIdentityByChannelSubject(ctx, sqlc.UpsertChannelIdentityByChannelSubjectParams{ + UserID: pgtype.UUID{}, + ChannelType: channel, + ChannelSubjectID: channelSubjectID, + DisplayName: toPgText(displayName), + Metadata: metaBytes, + }) + if err != nil { + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// ListCanonicalChannelIdentities lists channel identities under the same linked user. +func (s *Service) ListCanonicalChannelIdentities(ctx context.Context, channelIdentityID string) ([]ChannelIdentity, error) { + if s.queries == nil { + return nil, fmt.Errorf("channel identity queries not configured") + } + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return nil, err + } + row, err := s.queries.GetChannelIdentityByID(ctx, pgChannelIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrChannelIdentityNotFound + } + return nil, err + } + if !row.UserID.Valid { + return []ChannelIdentity{toChannelIdentity(row)}, nil + } + rows, err := s.queries.ListChannelIdentitiesByUserID(ctx, row.UserID) + if err != nil { + return nil, err + } + result := make([]ChannelIdentity, 0, len(rows)) + for _, item := range rows { + result = append(result, toChannelIdentity(item)) + } + return result, nil +} + +// ListUserChannelIdentities lists all channel identities linked to a user. +func (s *Service) ListUserChannelIdentities(ctx context.Context, userID string) ([]ChannelIdentity, error) { + if s.queries == nil { + return nil, fmt.Errorf("channel identity queries not configured") + } + pgUserID, err := db.ParseUUID(userID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChannelIdentitiesByUserID(ctx, pgUserID) + if err != nil { + return nil, err + } + result := make([]ChannelIdentity, 0, len(rows)) + for _, row := range rows { + result = append(result, toChannelIdentity(row)) + } + return result, nil +} + +// GetLinkedUserID returns the linked user ID for a channel identity. +func (s *Service) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { + if s.queries == nil { + return "", fmt.Errorf("channel identity queries not configured") + } + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return "", err + } + row, err := s.queries.GetChannelIdentityByID(ctx, pgChannelIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", nil + } + return "", err + } + if !row.UserID.Valid { + return "", nil + } + return row.UserID.String(), nil +} + +// LinkChannelIdentityToUser binds a channel identity to a user. +func (s *Service) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { + if s.queries == nil { + return fmt.Errorf("channel identity queries not configured") + } + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return err + } + pgUserID, err := db.ParseUUID(userID) + if err != nil { + return err + } + _, err = s.queries.SetChannelIdentityLinkedUser(ctx, sqlc.SetChannelIdentityLinkedUserParams{ + ID: pgChannelIdentityID, + UserID: pgUserID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrChannelIdentityNotFound + } + return err + } + return nil +} + +func toChannelIdentity(row sqlc.ChannelIdentity) ChannelIdentity { + var metadata map[string]any + if len(row.Metadata) > 0 { + _ = json.Unmarshal(row.Metadata, &metadata) + } + if metadata == nil { + metadata = map[string]any{} + } + displayName := "" + if row.DisplayName.Valid { + displayName = strings.TrimSpace(row.DisplayName.String) + } + userID := "" + if row.UserID.Valid { + userID = row.UserID.String() + } + return ChannelIdentity{ + ID: row.ID.String(), + UserID: userID, + Channel: row.ChannelType, + ChannelSubjectID: row.ChannelSubjectID, + DisplayName: displayName, + Metadata: metadata, + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), + } +} + +func normalizeChannel(channel string) string { + return strings.ToLower(strings.TrimSpace(channel)) +} + +func toPgText(value string) pgtype.Text { + value = strings.TrimSpace(value) + return pgtype.Text{ + String: value, + Valid: value != "", + } +} + +func emptyMetadataBytes() []byte { + return []byte("{}") +} diff --git a/internal/channel/identities/service_identity_integration_test.go b/internal/channel/identities/service_identity_integration_test.go new file mode 100644 index 00000000..9cda339a --- /dev/null +++ b/internal/channel/identities/service_identity_integration_test.go @@ -0,0 +1,89 @@ +package identities_test + +import ( + "context" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupChannelIdentityIdentityIntegrationTest(t *testing.T) (*identities.Service, *sqlc.Queries, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + svc := identities.NewService(logger, queries) + return svc, queries, func() { pool.Close() } +} + +func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) { + svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + externalID := fmt.Sprintf("stable_%d", time.Now().UnixNano()) + first, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "first") + if err != nil { + t.Fatalf("first resolve failed: %v", err) + } + second, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "second") + if err != nil { + t.Fatalf("second resolve failed: %v", err) + } + if first.ID != second.ID { + t.Fatalf("expected same channelIdentity id, got %s and %s", first.ID, second.ID) + } +} + +func TestChannelIdentityLinkToUser(t *testing.T) { + svc, queries, cleanup := setupChannelIdentityIdentityIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", fmt.Sprintf("link_%d", time.Now().UnixNano()), "tg") + if err != nil { + t.Fatalf("resolve channelIdentity failed: %v", err) + } + user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + t.Fatalf("create user failed: %v", err) + } + userID := uuid.UUID(user.ID.Bytes).String() + + if err := svc.LinkChannelIdentityToUser(ctx, channelIdentity.ID, userID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + linkedUserID, err := svc.GetLinkedUserID(ctx, channelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != userID { + t.Fatalf("expected linked user=%s, got %s", userID, linkedUserID) + } +} diff --git a/internal/channel/identities/service_integration_test.go b/internal/channel/identities/service_integration_test.go new file mode 100644 index 00000000..6f1833d7 --- /dev/null +++ b/internal/channel/identities/service_integration_test.go @@ -0,0 +1,96 @@ +//go:build ignore +// +build ignore + +package identities_test + +import ( + "context" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupIntegrationTest(t *testing.T) (*identities.Service, *sqlc.Queries, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + svc := identities.NewService(logger, queries) + + return svc, queries, func() { pool.Close() } +} + +func TestIntegrationResolveByChannelIdentityStability(t *testing.T) { + svc, _, cleanup := setupIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + key := fmt.Sprintf("ext_%d", time.Now().UnixNano()) + + first, err := svc.ResolveByChannelIdentity(ctx, "feishu", key, "first") + if err != nil { + t.Fatalf("first resolve failed: %v", err) + } + second, err := svc.ResolveByChannelIdentity(ctx, "feishu", key, "second") + if err != nil { + t.Fatalf("second resolve failed: %v", err) + } + if first.ID != second.ID { + t.Fatalf("expected stable channelIdentity id, got %s and %s", first.ID, second.ID) + } +} + +func TestIntegrationLinkChannelIdentityToUser(t *testing.T) { + svc, queries, cleanup := setupIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + key := fmt.Sprintf("bind_%d", time.Now().UnixNano()) + channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", key, "tg-user") + if err != nil { + t.Fatalf("resolve channelIdentity failed: %v", err) + } + + user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + t.Fatalf("create user failed: %v", err) + } + userID := uuid.UUID(user.ID.Bytes).String() + + if err := svc.LinkChannelIdentityToUser(ctx, channelIdentity.ID, userID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + linkedUserID, err := svc.GetLinkedUserID(ctx, channelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != userID { + t.Fatalf("expected linked user=%s, got %s", userID, linkedUserID) + } +} diff --git a/internal/channel/identities/service_test.go b/internal/channel/identities/service_test.go new file mode 100644 index 00000000..2bdb6bde --- /dev/null +++ b/internal/channel/identities/service_test.go @@ -0,0 +1,37 @@ +package identities + +import "testing" + +func TestNormalizeChannel(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"feishu", "feishu"}, + {" FEISHU ", "feishu"}, + {"Web", "web"}, + {"", ""}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + result := normalizeChannel(tc.input) + if result != tc.expected { + t.Errorf("normalizeChannel(%q) = %q, want %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestToPgText(t *testing.T) { + value := toPgText(" display ") + if !value.Valid { + t.Fatal("expected valid text for non-empty input") + } + if value.String != "display" { + t.Fatalf("expected trimmed text display, got %q", value.String) + } + empty := toPgText(" ") + if empty.Valid { + t.Fatal("expected invalid text for empty input") + } +} diff --git a/internal/channel/identities/types.go b/internal/channel/identities/types.go new file mode 100644 index 00000000..cfc36d30 --- /dev/null +++ b/internal/channel/identities/types.go @@ -0,0 +1,15 @@ +package identities + +import "time" + +// ChannelIdentity is a unified inbound identity subject across channels. +type ChannelIdentity struct { + ID string `json:"id"` + UserID string `json:"user_id,omitempty"` + Channel string `json:"channel"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName string `json:"display_name,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/internal/channel/inbound_test.go b/internal/channel/inbound_test.go new file mode 100644 index 00000000..2135d5ba --- /dev/null +++ b/internal/channel/inbound_test.go @@ -0,0 +1,217 @@ +package channel + +import ( + "context" + "fmt" + "log/slog" + "testing" +) + +// mockAdapter is used for inbound handleInbound tests. +type mockAdapter struct { + sentMessages []OutboundMessage + streamEvents []StreamEvent +} + +func (m *mockAdapter) Type() ChannelType { return ChannelType("test") } +func (m *mockAdapter) Descriptor() Descriptor { + return Descriptor{ + Type: ChannelType("test"), + DisplayName: "Test", + Capabilities: ChannelCapabilities{ + Text: true, + Reply: true, + Streaming: true, + }, + } +} +func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { + m.sentMessages = append(m.sentMessages, msg) + return nil +} + +func (m *mockAdapter) OpenStream(ctx context.Context, cfg ChannelConfig, target string, opts StreamOptions) (OutboundStream, error) { + return &mockAdapterStream{adapter: m}, nil +} + +type mockAdapterStream struct { + adapter *mockAdapter +} + +func (s *mockAdapterStream) Push(ctx context.Context, event StreamEvent) error { + if s == nil || s.adapter == nil { + return nil + } + s.adapter.streamEvents = append(s.adapter.streamEvents, event) + if event.Type == StreamEventFinal && event.Final != nil && !event.Final.Message.IsEmpty() { + s.adapter.sentMessages = append(s.adapter.sentMessages, OutboundMessage{ + Target: "stream-target", + Message: event.Final.Message, + }) + } + return nil +} + +func (s *mockAdapterStream) Close(ctx context.Context) error { + return nil +} + +type fakeInboundProcessor struct { + resp *OutboundMessage + err error + gotCfg ChannelConfig + gotMsg InboundMessage +} + +func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { + f.gotCfg = cfg + f.gotMsg = msg + if f.err != nil { + return f.err + } + if f.resp == nil { + return nil + } + if sender == nil { + return fmt.Errorf("sender missing") + } + return sender.Send(ctx, *f.resp) +} + +type fakeInboundStreamProcessor struct{} + +func (f *fakeInboundStreamProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { + stream, err := sender.OpenStream(ctx, "stream-target", StreamOptions{}) + if err != nil { + return err + } + if err := stream.Push(ctx, StreamEvent{ + Type: StreamEventDelta, + Delta: "partial", + }); err != nil { + return err + } + if err := stream.Push(ctx, StreamEvent{ + Type: StreamEventFinal, + Final: &StreamFinalizePayload{ + Message: Message{Text: "stream-final"}, + }, + }); err != nil { + return err + } + return stream.Close(ctx) +} + +func TestManager_handleInbound(t *testing.T) { + logger := slog.Default() + + t.Run("with_reply_sends_successfully", func(t *testing.T) { + processor := &fakeInboundProcessor{ + resp: &OutboundMessage{ + Target: "target-id", + Message: Message{ + Text: "AI reply content", + }, + }, + } + + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + adapter := &mockAdapter{} + m.RegisterAdapter(adapter) + + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + msg := InboundMessage{ + Channel: ChannelType("test"), + Message: Message{Text: "hello"}, + ReplyTarget: "target-id", + Conversation: Conversation{ + ID: "chat-1", + Type: "p2p", + }, + } + + err := m.handleInbound(context.Background(), cfg, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(adapter.sentMessages) != 1 { + t.Fatalf("expected 1 reply sent, got %d", len(adapter.sentMessages)) + } + if adapter.sentMessages[0].Message.PlainText() != "AI reply content" { + t.Errorf("reply content mismatch: %s", adapter.sentMessages[0].Message.PlainText()) + } + if adapter.sentMessages[0].Target != "target-id" { + t.Errorf("reply target mismatch: %s", adapter.sentMessages[0].Target) + } + }) + + t.Run("no_reply_does_not_send", func(t *testing.T) { + processor := &fakeInboundProcessor{resp: nil} + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + adapter := &mockAdapter{} + m.RegisterAdapter(adapter) + + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + msg := InboundMessage{ + Channel: ChannelType("test"), + Message: Message{Text: "hello"}, + ReplyTarget: "target-id", + } + + err := m.handleInbound(context.Background(), cfg, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(adapter.sentMessages) != 0 { + t.Errorf("expected no reply sent, got %+v", adapter.sentMessages) + } + }) + + t.Run("handler_error_returns_error", func(t *testing.T) { + processor := &fakeInboundProcessor{err: context.Canceled} + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + cfg := ChannelConfig{ID: "bot-1"} + msg := InboundMessage{Message: Message{Text: " "}} // whitespace-only message + + err := m.handleInbound(context.Background(), cfg, msg) + if err == nil { + t.Errorf("expected handler to return error") + } + }) + + t.Run("stream sender forwards events", func(t *testing.T) { + processor := &fakeInboundStreamProcessor{} + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) + adapter := &mockAdapter{} + m.RegisterAdapter(adapter) + + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + msg := InboundMessage{ + Channel: ChannelType("test"), + Message: Message{Text: "hello"}, + ReplyTarget: "stream-target", + Conversation: Conversation{ + ID: "chat-1", + Type: "p2p", + }, + } + if err := m.handleInbound(context.Background(), cfg, msg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(adapter.streamEvents) < 2 { + t.Fatalf("expected at least two stream events, got %d", len(adapter.streamEvents)) + } + if len(adapter.sentMessages) == 0 { + t.Fatal("expected stream final message to be published") + } + if adapter.sentMessages[len(adapter.sentMessages)-1].Message.PlainText() != "stream-final" { + t.Fatalf("unexpected stream final message: %s", adapter.sentMessages[len(adapter.sentMessages)-1].Message.PlainText()) + } + }) +} diff --git a/internal/channel/manager.go b/internal/channel/manager.go index feead996..94b0647d 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -18,19 +18,12 @@ type ConfigLister interface { // ConfigResolver resolves effective configs and user bindings. Used for outbound sending. type ConfigResolver interface { ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) - GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) + GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) } -// BindingStore resolves user-channel bindings. Used by identity resolution. +// BindingStore resolves channel-identity bindings. Used by identity resolution. type BindingStore interface { - ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) -} - -// SessionStore manages channel session lifecycle. Used by identity resolution. -type SessionStore interface { - GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) - UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error - ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]ChannelSession, error) + ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) } // ConfigStore is the full persistence interface. Components should depend on smaller @@ -39,8 +32,7 @@ type ConfigStore interface { ConfigLister ConfigResolver BindingStore - SessionStore - UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) + UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) } // Middleware wraps an InboundHandler to add cross-cutting behavior. @@ -69,6 +61,7 @@ type Manager struct { inboundCtx context.Context inboundCancel context.CancelFunc mu sync.Mutex + refreshMu sync.Mutex connections map[string]*connectionEntry } @@ -187,14 +180,14 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp } target := strings.TrimSpace(req.Target) if target == "" { - targetUserID := strings.TrimSpace(req.UserID) - if targetUserID == "" { + targetChannelIdentityID := strings.TrimSpace(req.ChannelIdentityID) + if targetChannelIdentityID == "" { return fmt.Errorf("target or user_id is required") } - userCfg, err := m.service.GetUserConfig(ctx, targetUserID, channelType) + userCfg, err := m.service.GetChannelIdentityConfig(ctx, targetChannelIdentityID, channelType) if err != nil { if m.logger != nil { - m.logger.Warn("channel binding missing", slog.String("channel", channelType.String()), slog.String("user_id", targetUserID)) + m.logger.Warn("channel binding missing", slog.String("channel", channelType.String()), slog.String("channel_identity_id", targetChannelIdentityID)) } return fmt.Errorf("channel binding required") } diff --git a/internal/channel/manager_core_test.go b/internal/channel/manager_core_test.go deleted file mode 100644 index 9283bcd6..00000000 --- a/internal/channel/manager_core_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package channel - -import ( - "context" - "fmt" - "log/slog" - "testing" -) - -// mockAdapter 专门用于 Manager 路由测试 -type mockAdapter struct { - sentMessages []OutboundMessage -} - -func (m *mockAdapter) Type() ChannelType { return ChannelType("test") } -func (m *mockAdapter) Descriptor() Descriptor { - return Descriptor{Type: ChannelType("test"), DisplayName: "Test", Capabilities: ChannelCapabilities{Text: true}} -} -func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { - m.sentMessages = append(m.sentMessages, msg) - return nil -} - -type fakeInboundProcessor struct { - resp *OutboundMessage - err error - gotCfg ChannelConfig - gotMsg InboundMessage -} - -func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error { - f.gotCfg = cfg - f.gotMsg = msg - if f.err != nil { - return f.err - } - if f.resp == nil { - return nil - } - if sender == nil { - return fmt.Errorf("sender missing") - } - return sender.Send(ctx, *f.resp) -} - -func TestManager_HandleInbound_CoreLogic(t *testing.T) { - logger := slog.Default() - - t.Run("返回回复_发送成功", func(t *testing.T) { - processor := &fakeInboundProcessor{ - resp: &OutboundMessage{ - Target: "target-id", - Message: Message{ - Text: "AI回复内容", - }, - }, - } - - reg := NewRegistry() - m := NewManager(logger, reg, &fakeConfigStore{}, processor) - adapter := &mockAdapter{} - m.RegisterAdapter(adapter) - - cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} - msg := InboundMessage{ - Channel: ChannelType("test"), - Message: Message{Text: "你好"}, - ReplyTarget: "target-id", - Conversation: Conversation{ - ID: "chat-1", - Type: "p2p", - }, - } - - err := m.handleInbound(context.Background(), cfg, msg) - if err != nil { - t.Fatalf("不应报错: %v", err) - } - - // 验证: 是否正确调用了 Adapter 发送回复 - if len(adapter.sentMessages) != 1 { - t.Fatalf("应该发送 1 条回复,实际发送: %d", len(adapter.sentMessages)) - } - if adapter.sentMessages[0].Message.PlainText() != "AI回复内容" { - t.Errorf("回复内容错误: %s", adapter.sentMessages[0].Message.PlainText()) - } - if adapter.sentMessages[0].Target != "target-id" { - t.Errorf("回复目标错误: %s", adapter.sentMessages[0].Target) - } - }) - - t.Run("无回复_不发送", func(t *testing.T) { - processor := &fakeInboundProcessor{resp: nil} - reg := NewRegistry() - m := NewManager(logger, reg, &fakeConfigStore{}, processor) - adapter := &mockAdapter{} - m.RegisterAdapter(adapter) - - cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} - msg := InboundMessage{ - Channel: ChannelType("test"), - Message: Message{Text: "你好"}, - ReplyTarget: "target-id", - } - - err := m.handleInbound(context.Background(), cfg, msg) - if err != nil { - t.Fatalf("不应报错: %v", err) - } - - if len(adapter.sentMessages) != 0 { - t.Errorf("不应发送回复,实际发送: %+v", adapter.sentMessages) - } - }) - - t.Run("处理失败_返回错误", func(t *testing.T) { - processor := &fakeInboundProcessor{err: context.Canceled} - reg := NewRegistry() - m := NewManager(logger, reg, &fakeConfigStore{}, processor) - cfg := ChannelConfig{ID: "bot-1"} - msg := InboundMessage{Message: Message{Text: " "}} // 空格消息 - - err := m.handleInbound(context.Background(), cfg, msg) - if err == nil { - t.Errorf("应返回处理错误") - } - }) -} diff --git a/internal/channel/manager_integration_test.go b/internal/channel/manager_integration_test.go index 82b910c0..fc296014 100644 --- a/internal/channel/manager_integration_test.go +++ b/internal/channel/manager_integration_test.go @@ -12,26 +12,25 @@ import ( ) type fakeConfigStore struct { - effectiveConfig ChannelConfig - userConfig ChannelUserBinding - configsByType map[ChannelType][]ChannelConfig - session ChannelSession - boundUserID string + effectiveConfig ChannelConfig + channelIdentityConfig ChannelIdentityBinding + configsByType map[ChannelType][]ChannelConfig + boundChannelIdentityID string } func (f *fakeConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { return f.effectiveConfig, nil } -func (f *fakeConfigStore) GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) { - if f.userConfig.ID == "" && len(f.userConfig.Config) == 0 { - return ChannelUserBinding{}, fmt.Errorf("channel user config not found") +func (f *fakeConfigStore) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { + if f.channelIdentityConfig.ID == "" && len(f.channelIdentityConfig.Config) == 0 { + return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") } - return f.userConfig, nil + return f.channelIdentityConfig, nil } -func (f *fakeConfigStore) UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) { - return f.userConfig, nil +func (f *fakeConfigStore) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { + return f.channelIdentityConfig, nil } func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { @@ -41,26 +40,11 @@ func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType Cha return f.configsByType[channelType], nil } -func (f *fakeConfigStore) ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { - if f.boundUserID == "" { +func (f *fakeConfigStore) ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { + if f.boundChannelIdentityID == "" { return "", fmt.Errorf("channel user binding not found") } - return f.boundUserID, nil -} - -func (f *fakeConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]ChannelSession, error) { - return nil, nil -} - -func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) { - if f.session.SessionID == sessionID { - return f.session, nil - } - return ChannelSession{}, nil -} - -func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { - return nil + return f.boundChannelIdentityID, nil } type fakeInboundProcessorIntegration struct { @@ -70,7 +54,7 @@ type fakeInboundProcessorIntegration struct { gotMsg InboundMessage } -func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error { +func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { f.gotCfg = cfg f.gotMsg = msg if f.err != nil { @@ -101,8 +85,8 @@ func (f *fakeAdapter) Descriptor() Descriptor { return Descriptor{Type: f.channelType, DisplayName: "Fake", Capabilities: ChannelCapabilities{Text: true}} } -func (f *fakeAdapter) ResolveTarget(userConfig map[string]any) (string, error) { - value := strings.TrimSpace(ReadString(userConfig, "target")) +func (f *fakeAdapter) ResolveTarget(channelIdentityConfig map[string]any) (string, error) { + value := strings.TrimSpace(ReadString(channelIdentityConfig, "target")) if value == "" { return "", fmt.Errorf("missing target") } @@ -135,13 +119,7 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { t.Parallel() log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) - store := &fakeConfigStore{ - session: ChannelSession{ - SessionID: "telegram:bot-1:chat-1", - BotID: "bot-1", - UserID: "user-1", - }, - } + store := &fakeConfigStore{} processor := &fakeInboundProcessorIntegration{ resp: &OutboundMessage{ Target: "123", @@ -202,7 +180,7 @@ func TestManagerSendUsesBinding(t *testing.T) { Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), }, - userConfig: ChannelUserBinding{ + channelIdentityConfig: ChannelIdentityBinding{ ID: "binding-1", Config: map[string]any{"target": "alice"}, }, @@ -213,7 +191,7 @@ func TestManagerSendUsesBinding(t *testing.T) { manager.RegisterAdapter(adapter) err := manager.Send(context.Background(), "bot-1", ChannelType("test"), SendRequest{ - UserID: "user-1", + ChannelIdentityID: "user-1", Message: Message{ Text: "hello", }, diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go index a8d1482d..4110b6e3 100644 --- a/internal/channel/outbound.go +++ b/internal/channel/outbound.go @@ -302,6 +302,9 @@ func validateMessageCapabilities(registry *Registry, channelType ChannelType, ms if msg.Reply != nil && !caps.Reply { return fmt.Errorf("channel does not support reply") } + if strings.TrimSpace(msg.ID) != "" && !caps.Edit { + return fmt.Errorf("channel does not support edit") + } return nil } @@ -316,12 +319,40 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel if msg.Message.IsEmpty() { return fmt.Errorf("message is required") } - if err := validateMessageCapabilities(m.registry, cfg.ChannelType, msg.Message); err != nil { + normalized := msg + attachments, err := normalizeAttachmentRefs(msg.Message.Attachments, cfg.ChannelType) + if err != nil { return err } + normalized.Message.Attachments = attachments + if err := validateMessageCapabilities(m.registry, cfg.ChannelType, normalized.Message); err != nil { + return err + } + editor, _ := m.registry.GetMessageEditor(cfg.ChannelType) + if strings.TrimSpace(normalized.Message.ID) != "" { + if editor == nil { + return fmt.Errorf("channel does not support edit") + } + var lastErr error + for i := 0; i < policy.RetryMax; i++ { + err := editor.Update(ctx, cfg, target, strings.TrimSpace(normalized.Message.ID), normalized.Message) + if err == nil { + return nil + } + lastErr = err + if m.logger != nil { + m.logger.Warn("edit outbound retry", + slog.String("channel", cfg.ChannelType.String()), + slog.Int("attempt", i+1), + slog.Any("error", err)) + } + time.Sleep(time.Duration(i+1) * time.Duration(policy.RetryBackoffMs) * time.Millisecond) + } + return fmt.Errorf("edit outbound failed after retries: %w", lastErr) + } var lastErr error for i := 0; i < policy.RetryMax; i++ { - err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: msg.Message}) + err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: normalized.Message}) if err == nil { return nil } @@ -337,6 +368,27 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel return fmt.Errorf("send outbound failed after retries: %w", lastErr) } +func normalizeAttachmentRefs(attachments []Attachment, defaultPlatform ChannelType) ([]Attachment, error) { + if len(attachments) == 0 { + return nil, nil + } + normalized := make([]Attachment, 0, len(attachments)) + for _, att := range attachments { + item := att + item.URL = strings.TrimSpace(item.URL) + item.PlatformKey = strings.TrimSpace(item.PlatformKey) + item.SourcePlatform = strings.TrimSpace(item.SourcePlatform) + if item.SourcePlatform == "" && item.PlatformKey != "" { + item.SourcePlatform = defaultPlatform.String() + } + if item.URL == "" && item.PlatformKey == "" { + return nil, fmt.Errorf("attachment reference is required") + } + normalized = append(normalized, item) + } + return normalized, nil +} + func requiresMedia(attachments []Attachment) bool { for _, att := range attachments { switch att.Type { @@ -349,21 +401,55 @@ func requiresMedia(attachments []Attachment) bool { return false } -func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) ReplySender { +func validateStreamEvent(registry *Registry, channelType ChannelType, event StreamEvent) error { + caps, _ := registry.GetCapabilities(channelType) + switch event.Type { + case StreamEventStatus: + if event.Status == "" { + return fmt.Errorf("stream status is required") + } + case StreamEventDelta: + if !caps.Streaming && !caps.BlockStreaming { + return fmt.Errorf("channel does not support streaming") + } + case StreamEventFinal: + if event.Final == nil { + return fmt.Errorf("stream final payload is required") + } + if err := validateMessageCapabilities(registry, channelType, event.Final.Message); err != nil { + return err + } + if _, err := normalizeAttachmentRefs(event.Final.Message.Attachments, channelType); err != nil { + return err + } + case StreamEventError: + if strings.TrimSpace(event.Error) == "" { + return fmt.Errorf("stream error is required") + } + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } + return nil +} + +func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) StreamReplySender { sender, _ := m.registry.GetSender(channelType) + streamSender, _ := m.registry.GetStreamSender(channelType) return &managerReplySender{ - manager: m, - sender: sender, - channelType: channelType, - config: cfg, + manager: m, + sender: sender, + streamSender: streamSender, + channelType: channelType, + config: cfg, } } type managerReplySender struct { - manager *Manager - sender Sender - channelType ChannelType - config ChannelConfig + manager *Manager + sender Sender + streamSender StreamSender + channelType ChannelType + config ChannelConfig } func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) error { @@ -382,3 +468,52 @@ func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) erro } return nil } + +func (s *managerReplySender) OpenStream(ctx context.Context, target string, opts StreamOptions) (OutboundStream, error) { + if s.manager == nil { + return nil, fmt.Errorf("channel manager not configured") + } + if s.streamSender == nil { + return nil, fmt.Errorf("channel stream sender not configured") + } + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("target is required") + } + caps, _ := s.manager.registry.GetCapabilities(s.channelType) + if !caps.Streaming && !caps.BlockStreaming { + return nil, fmt.Errorf("channel does not support streaming") + } + stream, err := s.streamSender.OpenStream(ctx, s.config, target, opts) + if err != nil { + return nil, err + } + return &managerOutboundStream{ + manager: s.manager, + stream: stream, + channelType: s.channelType, + }, nil +} + +type managerOutboundStream struct { + manager *Manager + stream OutboundStream + channelType ChannelType +} + +func (s *managerOutboundStream) Push(ctx context.Context, event StreamEvent) error { + if s.manager == nil || s.stream == nil { + return fmt.Errorf("stream is not configured") + } + if err := validateStreamEvent(s.manager.registry, s.channelType, event); err != nil { + return err + } + return s.stream.Push(ctx, event) +} + +func (s *managerOutboundStream) Close(ctx context.Context) error { + if s.stream == nil { + return fmt.Errorf("stream is not configured") + } + return s.stream.Close(ctx) +} diff --git a/internal/channel/processor.go b/internal/channel/processor.go index 6f4e79f2..bb363013 100644 --- a/internal/channel/processor.go +++ b/internal/channel/processor.go @@ -4,5 +4,5 @@ import "context" // InboundProcessor handles inbound messages and replies through the given sender. type InboundProcessor interface { - HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error + HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error } diff --git a/internal/channel/registry.go b/internal/channel/registry.go index b9a629ce..9a672d59 100644 --- a/internal/channel/registry.go +++ b/internal/channel/registry.go @@ -71,6 +71,16 @@ func (r *Registry) Get(channelType ChannelType) (Adapter, bool) { return adapter, ok } +// DirectoryAdapter returns the directory adapter for the given channel type if it implements ChannelDirectoryAdapter. +func (r *Registry) DirectoryAdapter(channelType ChannelType) (ChannelDirectoryAdapter, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + dir, ok := adapter.(ChannelDirectoryAdapter) + return dir, ok +} + // List returns all registered adapters. func (r *Registry) List() []Adapter { r.mu.RLock() @@ -185,6 +195,26 @@ func (r *Registry) GetSender(channelType ChannelType) (Sender, bool) { return sender, ok } +// GetStreamSender returns the StreamSender for the given channel type, or nil if unsupported. +func (r *Registry) GetStreamSender(channelType ChannelType) (StreamSender, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + streamSender, ok := adapter.(StreamSender) + return streamSender, ok +} + +// GetMessageEditor returns the MessageEditor for the given channel type, or nil if unsupported. +func (r *Registry) GetMessageEditor(channelType ChannelType) (MessageEditor, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + editor, ok := adapter.(MessageEditor) + return editor, ok +} + // GetReceiver returns the Receiver for the given channel type, or nil if unsupported. func (r *Registry) GetReceiver(channelType ChannelType) (Receiver, bool) { adapter, ok := r.Get(channelType) @@ -195,6 +225,16 @@ func (r *Registry) GetReceiver(channelType ChannelType) (Receiver, bool) { return receiver, ok } +// GetProcessingStatusNotifier returns the ProcessingStatusNotifier for the given channel type, or nil if unsupported. +func (r *Registry) GetProcessingStatusNotifier(channelType ChannelType) (ProcessingStatusNotifier, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + notifier, ok := adapter.(ProcessingStatusNotifier) + return notifier, ok +} + // --- Dispatch methods (replace former global functions in config.go / target.go) --- // NormalizeConfig validates and normalizes a channel configuration map. diff --git a/internal/channel/registry_test.go b/internal/channel/registry_test.go new file mode 100644 index 00000000..c27c3875 --- /dev/null +++ b/internal/channel/registry_test.go @@ -0,0 +1,63 @@ +package channel_test + +import ( + "context" + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +const dirTestChannelType = channel.ChannelType("dir-test") + +// dirMockAdapter implements Adapter and ChannelDirectoryAdapter for registry DirectoryAdapter tests. +type dirMockAdapter struct{} + +func (a *dirMockAdapter) Type() channel.ChannelType { return dirTestChannelType } + +func (a *dirMockAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{Type: dirTestChannelType, DisplayName: "DirTest"} +} + +func (a *dirMockAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (a *dirMockAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (a *dirMockAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (a *dirMockAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + return channel.DirectoryEntry{}, nil +} + +func TestDirectoryAdapter_Unsupported(t *testing.T) { + t.Parallel() + reg := newTestConfigRegistry() + dir, ok := reg.DirectoryAdapter(testChannelType) + if ok || dir != nil { + t.Fatalf("DirectoryAdapter(test) = (%v, %v), want (nil, false)", dir, ok) + } +} + +func TestDirectoryAdapter_Supported(t *testing.T) { + t.Parallel() + reg := channel.NewRegistry() + reg.MustRegister(&dirMockAdapter{}) + dir, ok := reg.DirectoryAdapter(dirTestChannelType) + if !ok || dir == nil { + t.Fatalf("DirectoryAdapter(dir-test) = (%v, %v), want (non-nil, true)", dir, ok) + } +} + +func TestDirectoryAdapter_UnknownType(t *testing.T) { + t.Parallel() + reg := channel.NewRegistry() + dir, ok := reg.DirectoryAdapter(channel.ChannelType("unknown")) + if ok || dir != nil { + t.Fatalf("DirectoryAdapter(unknown) = (%v, %v), want (nil, false)", dir, ok) + } +} diff --git a/internal/channel/route/service.go b/internal/channel/route/service.go new file mode 100644 index 00000000..c50aeafe --- /dev/null +++ b/internal/channel/route/service.go @@ -0,0 +1,362 @@ +package route + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/conversation" + dbpkg "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +// ConversationService contains the minimal conversation behavior required by route resolution. +type ConversationService interface { + Create(ctx context.Context, botID, channelIdentityID string, req conversation.CreateRequest) (conversation.Chat, error) + IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) + AddParticipant(ctx context.Context, conversationID, channelIdentityID, role string) (conversation.Participant, error) +} + +// DBService manages channel routes and route-to-conversation resolution. +type DBService struct { + queries *sqlc.Queries + conversation ConversationService + logger *slog.Logger +} + +// NewService creates a channel route service. +func NewService(log *slog.Logger, queries *sqlc.Queries, conversationService ConversationService) *DBService { + if log == nil { + log = slog.Default() + } + return &DBService{ + queries: queries, + conversation: conversationService, + logger: log.With(slog.String("service", "channel/route")), + } +} + +// Create creates a route. +func (s *DBService) Create(ctx context.Context, input CreateInput) (Route, error) { + pgConversationID, err := dbpkg.ParseUUID(input.ChatID) + if err != nil { + return Route{}, err + } + pgBotID, err := dbpkg.ParseUUID(input.BotID) + if err != nil { + return Route{}, err + } + var pgConfigID pgtype.UUID + if strings.TrimSpace(input.ChannelConfigID) != "" { + pgConfigID, err = dbpkg.ParseUUID(input.ChannelConfigID) + if err != nil { + return Route{}, err + } + } + metadata, err := json.Marshal(nonNilMap(input.Metadata)) + if err != nil { + return Route{}, fmt.Errorf("marshal route metadata: %w", err) + } + + row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{ + ChatID: pgConversationID, + BotID: pgBotID, + Platform: input.Platform, + ChannelConfigID: pgConfigID, + ConversationID: input.ConversationID, + ThreadID: toPgText(input.ThreadID), + ReplyTarget: toPgText(input.ReplyTarget), + Metadata: metadata, + }) + if err != nil { + return Route{}, fmt.Errorf("create route: %w", err) + } + + return toRouteFromCreate(row), nil +} + +// Find finds a route by bot/platform/external-conversation/thread. +func (s *DBService) Find(ctx context.Context, botID, platform, conversationID, threadID string) (Route, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return Route{}, err + } + row, err := s.queries.FindChatRoute(ctx, sqlc.FindChatRouteParams{ + BotID: pgBotID, + Platform: platform, + ConversationID: conversationID, + ThreadID: toPgText(threadID), + }) + if err != nil { + return Route{}, err + } + return toRouteFromFind(row), nil +} + +// GetByID gets a route by ID. +func (s *DBService) GetByID(ctx context.Context, routeID string) (Route, error) { + pgID, err := dbpkg.ParseUUID(routeID) + if err != nil { + return Route{}, err + } + row, err := s.queries.GetChatRouteByID(ctx, pgID) + if err != nil { + return Route{}, err + } + return toRouteFromGet(row), nil +} + +// List lists all routes for a conversation. +func (s *DBService) List(ctx context.Context, conversationID string) ([]Route, error) { + pgID, err := dbpkg.ParseUUID(conversationID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatRoutes(ctx, pgID) + if err != nil { + return nil, err + } + routes := make([]Route, 0, len(rows)) + for _, row := range rows { + routes = append(routes, toRouteFromList(row)) + } + return routes, nil +} + +// Delete deletes a route by ID. +func (s *DBService) Delete(ctx context.Context, routeID string) error { + pgID, err := dbpkg.ParseUUID(routeID) + if err != nil { + return err + } + return s.queries.DeleteChatRoute(ctx, pgID) +} + +// UpdateReplyTarget updates default reply target. +func (s *DBService) UpdateReplyTarget(ctx context.Context, routeID, replyTarget string) error { + pgID, err := dbpkg.ParseUUID(routeID) + if err != nil { + return err + } + return s.queries.UpdateChatRouteReplyTarget(ctx, sqlc.UpdateChatRouteReplyTargetParams{ + ID: pgID, + ReplyTarget: toPgText(replyTarget), + }) +} + +// ResolveConversation finds or creates a conversation route for an inbound message. +func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput) (ResolveConversationResult, error) { + route, err := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, input.ThreadID) + if err == nil { + if strings.TrimSpace(input.ChannelIdentityID) != "" && s.conversation != nil { + ok, checkErr := s.conversation.IsParticipant(ctx, route.ChatID, input.ChannelIdentityID) + if checkErr != nil { + return ResolveConversationResult{}, fmt.Errorf("check conversation participant: %w", checkErr) + } + if !ok { + if _, addErr := s.conversation.AddParticipant(ctx, route.ChatID, input.ChannelIdentityID, conversation.RoleMember); addErr != nil && s.logger != nil { + s.logger.Warn("auto-add participant failed", slog.Any("error", addErr)) + } + } + } + if strings.TrimSpace(input.ReplyTarget) != "" && input.ReplyTarget != route.ReplyTarget { + if updateErr := s.UpdateReplyTarget(ctx, route.ID, input.ReplyTarget); updateErr != nil && s.logger != nil { + s.logger.Warn("update route reply target failed", slog.Any("error", updateErr)) + } + } + pgConversationID, parseErr := dbpkg.ParseUUID(route.ChatID) + if parseErr != nil { + return ResolveConversationResult{}, fmt.Errorf("parse route conversation id: %w", parseErr) + } + if touchErr := s.queries.TouchChat(ctx, pgConversationID); touchErr != nil && s.logger != nil { + s.logger.Warn("touch conversation failed", slog.Any("error", touchErr)) + } + return ResolveConversationResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil + } + + if s.conversation == nil { + return ResolveConversationResult{}, fmt.Errorf("conversation service not configured") + } + + kind := determineConversationKind(input.ThreadID, input.ConversationType) + creatorChannelIdentityID := s.resolveConversationCreatorChannelIdentityID(ctx, input.BotID, input.ChannelIdentityID, kind) + + var parentConversationID string + if kind == conversation.KindThread { + parentRoute, parentErr := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, "") + if parentErr == nil { + parentConversationID = parentRoute.ChatID + } + } + + createdConversation, err := s.conversation.Create(ctx, input.BotID, creatorChannelIdentityID, conversation.CreateRequest{ + Kind: kind, + ParentChatID: parentConversationID, + }) + if err != nil { + return ResolveConversationResult{}, fmt.Errorf("create conversation: %w", err) + } + + if strings.TrimSpace(input.ChannelIdentityID) != "" && strings.TrimSpace(input.ChannelIdentityID) != strings.TrimSpace(creatorChannelIdentityID) { + if _, addErr := s.conversation.AddParticipant(ctx, createdConversation.ID, input.ChannelIdentityID, conversation.RoleMember); addErr != nil && s.logger != nil { + s.logger.Warn("auto-add creator participant failed", slog.Any("error", addErr)) + } + } + + newRoute, err := s.Create(ctx, CreateInput{ + ChatID: createdConversation.ID, + BotID: input.BotID, + Platform: input.Platform, + ChannelConfigID: input.ChannelConfigID, + ConversationID: input.ConversationID, + ThreadID: input.ThreadID, + ReplyTarget: input.ReplyTarget, + }) + if err != nil { + return ResolveConversationResult{}, fmt.Errorf("create route: %w", err) + } + + return ResolveConversationResult{ChatID: createdConversation.ID, RouteID: newRoute.ID, Created: true}, nil +} + +func determineConversationKind(threadID, conversationType string) string { + if strings.TrimSpace(threadID) != "" { + return conversation.KindThread + } + ct := strings.ToLower(strings.TrimSpace(conversationType)) + if ct == "p2p" || ct == "private" || ct == "" { + return conversation.KindDirect + } + return conversation.KindGroup +} + +func (s *DBService) resolveConversationCreatorChannelIdentityID(ctx context.Context, botID, fallbackChannelIdentityID, kind string) string { + fallback := strings.TrimSpace(fallbackChannelIdentityID) + if kind != conversation.KindGroup || s.queries == nil { + return fallback + } + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return fallback + } + row, err := s.queries.GetBotByID(ctx, pgBotID) + if err != nil { + if s.logger != nil { + s.logger.Warn("resolve bot owner for group conversation failed", slog.Any("error", err)) + } + return fallback + } + ownerChannelIdentityID := row.OwnerUserID.String() + if strings.TrimSpace(ownerChannelIdentityID) == "" { + return fallback + } + return ownerChannelIdentityID +} + +func toRouteFromCreate(row sqlc.CreateChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromFind(row sqlc.FindChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromGet(row sqlc.GetChatRouteByIDRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromList(row sqlc.ListChatRoutesRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFields(id, conversationID, botID pgtype.UUID, platform string, channelConfigID pgtype.UUID, externalConversationID string, threadID, replyTarget pgtype.Text, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Route { + return Route{ + ID: id.String(), + ChatID: conversationID.String(), + BotID: botID.String(), + Platform: platform, + ChannelConfigID: channelConfigID.String(), + ConversationID: externalConversationID, + ThreadID: dbpkg.TextToString(threadID), + ReplyTarget: dbpkg.TextToString(replyTarget), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, + } +} + +func toPgText(value string) pgtype.Text { + value = strings.TrimSpace(value) + if value == "" { + return pgtype.Text{} + } + return pgtype.Text{String: value, Valid: true} +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} diff --git a/internal/channel/route/types.go b/internal/channel/route/types.go new file mode 100644 index 00000000..a7c1e44c --- /dev/null +++ b/internal/channel/route/types.go @@ -0,0 +1,68 @@ +package route + +import ( + "context" + "time" +) + +// Route maps external channel conversations to an internal conversation. +type Route struct { + ID string `json:"id"` + ChatID string `json:"chat_id"` + BotID string `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID string `json:"channel_config_id,omitempty"` + ConversationID string `json:"conversation_id"` + ThreadID string `json:"thread_id,omitempty"` + ReplyTarget string `json:"reply_target,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ResolveConversationResult is returned by ResolveConversation. +type ResolveConversationResult struct { + ChatID string + RouteID string + Created bool +} + +// CreateInput is the input for creating a route. +type CreateInput struct { + ChatID string + BotID string + Platform string + ChannelConfigID string + ConversationID string + ThreadID string + ReplyTarget string + Metadata map[string]any +} + +// ResolveInput is the input for route-to-conversation resolution. +type ResolveInput struct { + BotID string + Platform string + ConversationID string + ThreadID string + ConversationType string + ChannelIdentityID string + ChannelConfigID string + ReplyTarget string +} + +// Resolver defines the route resolution behavior used by inbound routing. +type Resolver interface { + ResolveConversation(ctx context.Context, input ResolveInput) (ResolveConversationResult, error) +} + +// Service defines route management behavior. +type Service interface { + Resolver + Create(ctx context.Context, input CreateInput) (Route, error) + Find(ctx context.Context, botID, platform, conversationID, threadID string) (Route, error) + GetByID(ctx context.Context, routeID string) (Route, error) + List(ctx context.Context, chatID string) ([]Route, error) + Delete(ctx context.Context, routeID string) error + UpdateReplyTarget(ctx context.Context, routeID, replyTarget string) error +} diff --git a/internal/channel/service.go b/internal/channel/service.go index d11437a2..5414af65 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -65,9 +65,9 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch if err != nil { return ChannelConfig{}, err } - status := strings.TrimSpace(req.Status) - if status == "" { - status = "pending" + status, err := normalizeChannelConfigStatus(req.Status) + if err != nil { + return ChannelConfig{}, err } verifiedAt := pgtype.Timestamptz{Valid: false} if req.VerifiedAt != nil { @@ -94,35 +94,52 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch return normalizeChannelConfig(row) } -// UpsertUserConfig creates or updates a user's channel binding. -func (s *Service) UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) { +func normalizeChannelConfigStatus(raw string) (string, error) { + status := strings.ToLower(strings.TrimSpace(raw)) + if status == "" { + return "pending", nil + } + switch status { + case "pending", "verified", "disabled": + return status, nil + case "active": + return "verified", nil + case "inactive": + return "disabled", nil + default: + return "", fmt.Errorf("invalid channel status: %s", raw) + } +} + +// UpsertChannelIdentityConfig creates or updates a channel identity's channel binding. +func (s *Service) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { if s.queries == nil { - return ChannelUserBinding{}, fmt.Errorf("channel queries not configured") + return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") } if channelType == "" { - return ChannelUserBinding{}, fmt.Errorf("channel type is required") + return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") } normalized, err := s.registry.NormalizeUserConfig(channelType, req.Config) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } payload, err := json.Marshal(normalized) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - pgUserID, err := db.ParseUUID(actorUserID) + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } row, err := s.queries.UpsertUserChannelBinding(ctx, sqlc.UpsertUserChannelBindingParams{ - UserID: pgUserID, + UserID: pgChannelIdentityID, ChannelType: channelType.String(), Config: payload, }) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - return normalizeChannelUserBindingRow(row) + return normalizeChannelIdentityBinding(row) } // ResolveEffectiveConfig returns the active channel configuration for a bot. @@ -181,54 +198,54 @@ func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType return items, nil } -// GetUserConfig returns the user's channel binding for the given channel type. -func (s *Service) GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) { +// GetChannelIdentityConfig returns the channel identity's channel binding for the given channel type. +func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { if s.queries == nil { - return ChannelUserBinding{}, fmt.Errorf("channel queries not configured") + return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") } if channelType == "" { - return ChannelUserBinding{}, fmt.Errorf("channel type is required") + return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") } - pgUserID, err := db.ParseUUID(actorUserID) + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } row, err := s.queries.GetUserChannelBinding(ctx, sqlc.GetUserChannelBindingParams{ - UserID: pgUserID, + UserID: pgChannelIdentityID, ChannelType: channelType.String(), }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return ChannelUserBinding{}, fmt.Errorf("channel user config not found") + return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") } - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } config, err := DecodeConfigMap(row.Config) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - return ChannelUserBinding{ - ID: db.UUIDToString(row.ID), - ChannelType: ChannelType(row.ChannelType), - UserID: db.UUIDToString(row.UserID), - Config: config, - CreatedAt: db.TimeFromPg(row.CreatedAt), - UpdatedAt: db.TimeFromPg(row.UpdatedAt), + return ChannelIdentityBinding{ + ID: row.ID.String(), + ChannelType: ChannelType(row.ChannelType), + ChannelIdentityID: row.UserID.String(), + Config: config, + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } -// ListUserConfigsByType returns all user bindings for the given channel type. -func (s *Service) ListUserConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelUserBinding, error) { +// ListChannelIdentityConfigsByType returns all channel identity bindings for the given channel type. +func (s *Service) ListChannelIdentityConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelIdentityBinding, error) { if s.queries == nil { return nil, fmt.Errorf("channel queries not configured") } - rows, err := s.queries.ListUserChannelBindingsByType(ctx, channelType.String()) + rows, err := s.queries.ListUserChannelBindingsByPlatform(ctx, channelType.String()) if err != nil { return nil, err } - items := make([]ChannelUserBinding, 0, len(rows)) + items := make([]ChannelIdentityBinding, 0, len(rows)) for _, row := range rows { - item, err := normalizeChannelUserBindingRow(row) + item, err := normalizeChannelIdentityBinding(row) if err != nil { return nil, err } @@ -237,119 +254,9 @@ func (s *Service) ListUserConfigsByType(ctx context.Context, channelType Channel return items, nil } -// GetChannelSession returns the session with the given ID, or an empty session if not found. -func (s *Service) GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) { - if s.queries == nil { - return ChannelSession{}, fmt.Errorf("channel queries not configured") - } - row, err := s.queries.GetChannelSessionByID(ctx, sessionID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return ChannelSession{}, nil - } - return ChannelSession{}, err - } - return normalizeChannelSession(row) -} - -// ListSessionsByBotPlatform returns all sessions for the given bot and platform. -func (s *Service) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]ChannelSession, error) { - if s.queries == nil { - return nil, fmt.Errorf("channel queries not configured") - } - botID = strings.TrimSpace(botID) - platform = strings.TrimSpace(platform) - if botID == "" { - return nil, fmt.Errorf("bot id is required") - } - if platform == "" { - return nil, fmt.Errorf("platform is required") - } - pgBotID, err := db.ParseUUID(botID) - if err != nil { - return nil, err - } - rows, err := s.queries.ListChannelSessionsByBotPlatform(ctx, sqlc.ListChannelSessionsByBotPlatformParams{ - BotID: pgBotID, - Platform: platform, - }) - if err != nil { - return nil, err - } - items := make([]ChannelSession, 0, len(rows)) - for _, row := range rows { - item, err := normalizeChannelSession(row) - if err != nil { - return nil, err - } - items = append(items, item) - } - return items, nil -} - -// UpsertChannelSession creates or updates a channel session record. -func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { - if s.queries == nil { - return fmt.Errorf("channel queries not configured") - } - pgUserID := pgtype.UUID{Valid: false} - if strings.TrimSpace(userID) != "" { - parsed, err := db.ParseUUID(userID) - if err != nil { - return err - } - pgUserID = parsed - } - botUUID, err := db.ParseUUID(botID) - if err != nil { - return err - } - var channelUUID pgtype.UUID - if strings.TrimSpace(channelConfigID) != "" { - channelUUID, err = db.ParseUUID(channelConfigID) - if err != nil { - return err - } - } - pgContactID := pgtype.UUID{Valid: false} - if strings.TrimSpace(contactID) != "" { - parsed, err := db.ParseUUID(contactID) - if err != nil { - return err - } - pgContactID = parsed - } - payload := metadata - if payload == nil { - payload = map[string]any{} - } - metaBytes, err := json.Marshal(payload) - if err != nil { - return err - } - _, err = s.queries.UpsertChannelSession(ctx, sqlc.UpsertChannelSessionParams{ - SessionID: sessionID, - BotID: botUUID, - ChannelConfigID: channelUUID, - UserID: pgUserID, - ContactID: pgContactID, - Platform: platform, - ReplyTarget: pgtype.Text{ - String: strings.TrimSpace(replyTarget), - Valid: strings.TrimSpace(replyTarget) != "", - }, - ThreadID: pgtype.Text{ - String: strings.TrimSpace(threadID), - Valid: strings.TrimSpace(threadID) != "", - }, - Metadata: metaBytes, - }) - return err -} - -// ResolveUserBinding finds the user ID whose channel binding matches the given criteria. -func (s *Service) ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { - rows, err := s.ListUserConfigsByType(ctx, channelType) +// ResolveChannelIdentityBinding finds the channel identity ID whose channel binding matches the given criteria. +func (s *Service) ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { + rows, err := s.ListChannelIdentityConfigsByType(ctx, channelType) if err != nil { return "", err } @@ -358,7 +265,7 @@ func (s *Service) ResolveUserBinding(ctx context.Context, channelType ChannelTyp } for _, row := range rows { if s.registry.MatchUserBinding(channelType, row.Config, criteria) { - return row.UserID, nil + return row.ChannelIdentityID, nil } } return "", fmt.Errorf("channel user binding not found") @@ -386,54 +293,31 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { externalIdentity = strings.TrimSpace(row.ExternalIdentity.String) } return ChannelConfig{ - ID: db.UUIDToString(row.ID), - BotID: db.UUIDToString(row.BotID), + ID: row.ID.String(), + BotID: row.BotID.String(), ChannelType: ChannelType(row.ChannelType), Credentials: credentials, ExternalIdentity: externalIdentity, SelfIdentity: selfIdentity, - Routing: routing, - Status: strings.TrimSpace(row.Status), + Routing: routing, + Status: strings.TrimSpace(row.Status), VerifiedAt: verifiedAt, CreatedAt: db.TimeFromPg(row.CreatedAt), UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } -func normalizeChannelUserBindingRow(row sqlc.UserChannelBinding) (ChannelUserBinding, error) { +func normalizeChannelIdentityBinding(row sqlc.UserChannelBinding) (ChannelIdentityBinding, error) { config, err := DecodeConfigMap(row.Config) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - return ChannelUserBinding{ - ID: db.UUIDToString(row.ID), - ChannelType: ChannelType(row.ChannelType), - UserID: db.UUIDToString(row.UserID), - Config: config, - CreatedAt: db.TimeFromPg(row.CreatedAt), - UpdatedAt: db.TimeFromPg(row.UpdatedAt), + return ChannelIdentityBinding{ + ID: row.ID.String(), + ChannelType: ChannelType(row.ChannelType), + ChannelIdentityID: row.UserID.String(), + Config: config, + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } - -func normalizeChannelSession(row sqlc.ChannelSession) (ChannelSession, error) { - metadata, err := DecodeConfigMap(row.Metadata) - if err != nil { - return ChannelSession{}, err - } - return ChannelSession{ - SessionID: row.SessionID, - BotID: db.UUIDToString(row.BotID), - ChannelConfigID: db.UUIDToString(row.ChannelConfigID), - UserID: db.UUIDToString(row.UserID), - ContactID: db.UUIDToString(row.ContactID), - Platform: row.Platform, - ReplyTarget: strings.TrimSpace(row.ReplyTarget.String), - ThreadID: strings.TrimSpace(row.ThreadID.String), - Metadata: metadata, - CreatedAt: db.TimeFromPg(row.CreatedAt), - UpdatedAt: db.TimeFromPg(row.UpdatedAt), - }, nil -} - - - diff --git a/internal/channel/types.go b/internal/channel/types.go index 0bb2147e..2a4d4d46 100644 --- a/internal/channel/types.go +++ b/internal/channel/types.go @@ -17,7 +17,7 @@ func (c ChannelType) String() string { // Identity represents a sender's identity on a channel. type Identity struct { - ExternalID string + SubjectID string DisplayName string Attributes map[string]string } @@ -45,7 +45,7 @@ type InboundMessage struct { Message Message BotID string ReplyTarget string - SessionKey string + RouteKey string Sender Identity Conversation Conversation ReceivedAt time.Time @@ -53,22 +53,22 @@ type InboundMessage struct { Metadata map[string]any } -// SessionID returns a stable identifier for the conversation session. +// RoutingKey returns a stable identifier used for reply routing. // Format: platform:bot_id:conversation_id[:sender_id]. -func (m InboundMessage) SessionID() string { - if strings.TrimSpace(m.SessionKey) != "" { - return strings.TrimSpace(m.SessionKey) +func (m InboundMessage) RoutingKey() string { + if strings.TrimSpace(m.RouteKey) != "" { + return strings.TrimSpace(m.RouteKey) } - senderID := strings.TrimSpace(m.Sender.ExternalID) + senderID := strings.TrimSpace(m.Sender.SubjectID) if senderID == "" { senderID = strings.TrimSpace(m.Sender.DisplayName) } - return GenerateSessionID(string(m.Channel), m.BotID, m.Conversation.ID, m.Conversation.Type, senderID) + return GenerateRoutingKey(string(m.Channel), m.BotID, m.Conversation.ID, m.Conversation.Type, senderID) } -// GenerateSessionID builds a session identifier from platform, bot, conversation, and sender info. +// GenerateRoutingKey builds a route key from platform, bot, conversation, and sender info. // For group chats, the sender ID is appended to provide per-user context. -func GenerateSessionID(platform, botID, conversationID, conversationType, senderID string) string { +func GenerateRoutingKey(platform, botID, conversationID, conversationType, senderID string) string { parts := []string{platform, botID, conversationID} ct := strings.ToLower(strings.TrimSpace(conversationType)) if ct != "" && ct != "p2p" && ct != "private" { @@ -86,6 +86,47 @@ type OutboundMessage struct { Message Message `json:"message"` } +// StreamEventType defines the kind of outbound stream event. +type StreamEventType string + +const ( + StreamEventStatus StreamEventType = "status" + StreamEventDelta StreamEventType = "delta" + StreamEventFinal StreamEventType = "final" + StreamEventError StreamEventType = "error" +) + +// StreamStatus indicates the lifecycle state of a streaming reply. +type StreamStatus string + +const ( + StreamStatusStarted StreamStatus = "started" + StreamStatusCompleted StreamStatus = "completed" + StreamStatusFailed StreamStatus = "failed" +) + +// StreamFinalizePayload carries the final reply message emitted by a stream. +type StreamFinalizePayload struct { + Message Message `json:"message"` +} + +// StreamEvent represents a unified stream event routed through the channel layer. +type StreamEvent struct { + Type StreamEventType `json:"type"` + Status StreamStatus `json:"status,omitempty"` + Delta string `json:"delta,omitempty"` + Final *StreamFinalizePayload `json:"final,omitempty"` + Error string `json:"error,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// StreamOptions configures how an outbound stream is initialized. +type StreamOptions struct { + Reply *ReplyRef `json:"reply,omitempty"` + SourceMessageID string `json:"source_message_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + // MessageFormat indicates how the message text should be rendered. type MessageFormat string @@ -118,14 +159,14 @@ const ( // MessagePart is a single element within a rich-text message. type MessagePart struct { - Type MessagePartType `json:"type"` - Text string `json:"text,omitempty"` - URL string `json:"url,omitempty"` - Styles []MessageTextStyle `json:"styles,omitempty"` - Language string `json:"language,omitempty"` - UserID string `json:"user_id,omitempty"` - Emoji string `json:"emoji,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Type MessagePartType `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []MessageTextStyle `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // AttachmentType classifies the kind of binary attachment. @@ -142,17 +183,33 @@ const ( // Attachment represents a binary file attached to a message. type Attachment struct { - Type AttachmentType `json:"type"` - URL string `json:"url,omitempty"` - Name string `json:"name,omitempty"` - Size int64 `json:"size,omitempty"` - Mime string `json:"mime,omitempty"` - DurationMs int64 `json:"duration_ms,omitempty"` - Width int `json:"width,omitempty"` - Height int `json:"height,omitempty"` - ThumbnailURL string `json:"thumbnail_url,omitempty"` - Caption string `json:"caption,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Type AttachmentType `json:"type"` + URL string `json:"url,omitempty"` + PlatformKey string `json:"platform_key,omitempty"` + SourcePlatform string `json:"source_platform,omitempty"` + Name string `json:"name,omitempty"` + Size int64 `json:"size,omitempty"` + Mime string `json:"mime,omitempty"` + DurationMs int64 `json:"duration_ms,omitempty"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + ThumbnailURL string `json:"thumbnail_url,omitempty"` + Caption string `json:"caption,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// Reference returns the strongest available attachment reference. +// URL is preferred for cross-platform portability, then platform key. +func (a Attachment) Reference() string { + if strings.TrimSpace(a.URL) != "" { + return strings.TrimSpace(a.URL) + } + return strings.TrimSpace(a.PlatformKey) +} + +// HasReference reports whether URL or platform key is available. +func (a Attachment) HasReference() bool { + return a.Reference() != "" } // Action describes an interactive button or link in a message. @@ -227,7 +284,7 @@ func (m Message) PlainText() string { // BindingCriteria specifies conditions for matching a user-channel binding. type BindingCriteria struct { - ExternalID string + SubjectID string Attributes map[string]string } @@ -242,7 +299,7 @@ func (c BindingCriteria) Attribute(key string) string { // BindingCriteriaFromIdentity creates BindingCriteria from a channel Identity. func BindingCriteriaFromIdentity(identity Identity) BindingCriteria { return BindingCriteria{ - ExternalID: strings.TrimSpace(identity.ExternalID), + SubjectID: strings.TrimSpace(identity.SubjectID), Attributes: identity.Attributes, } } @@ -262,14 +319,14 @@ type ChannelConfig struct { UpdatedAt time.Time } -// ChannelUserBinding represents a user's binding to a specific channel type. -type ChannelUserBinding struct { - ID string - ChannelType ChannelType - UserID string - Config map[string]any - CreatedAt time.Time - UpdatedAt time.Time +// ChannelIdentityBinding represents a channel identity's binding to a specific channel type. +type ChannelIdentityBinding struct { + ID string + ChannelType ChannelType + ChannelIdentityID string + Config map[string]any + CreatedAt time.Time + UpdatedAt time.Time } // UpsertConfigRequest is the input for creating or updating a channel configuration. @@ -282,29 +339,14 @@ type UpsertConfigRequest struct { VerifiedAt *time.Time `json:"verified_at,omitempty"` } -// UpsertUserConfigRequest is the input for creating or updating a user-channel binding. -type UpsertUserConfigRequest struct { +// UpsertChannelIdentityConfigRequest is the input for creating or updating a channel-identity binding. +type UpsertChannelIdentityConfigRequest struct { Config map[string]any `json:"config"` } -// ChannelSession tracks an active conversation session on a channel. -type ChannelSession struct { - SessionID string - BotID string - ChannelConfigID string - UserID string - ContactID string - Platform string - ReplyTarget string - ThreadID string - Metadata map[string]any - CreatedAt time.Time - UpdatedAt time.Time -} - // SendRequest is the input for sending an outbound message through a channel. type SendRequest struct { - Target string `json:"target,omitempty"` - UserID string `json:"user_id,omitempty"` - Message Message `json:"message"` + Target string `json:"target,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Message Message `json:"message"` } diff --git a/internal/chat/assistant_output.go b/internal/chat/assistant_output.go index ae00f91f..235b121d 100644 --- a/internal/chat/assistant_output.go +++ b/internal/chat/assistant_output.go @@ -1,4 +1,4 @@ -package chat +package conversation import "strings" diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 70a1b412..47f1d84e 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "bufio" @@ -9,13 +9,14 @@ import ( "io" "log/slog" "net/http" + "sort" "strings" "time" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" - "github.com/memohai/memoh/internal/history" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" "github.com/memohai/memoh/internal/models" @@ -23,7 +24,13 @@ import ( "github.com/memohai/memoh/internal/settings" ) -const defaultMaxContextMinutes = 24 * 60 +const ( + defaultMaxContextMinutes = 24 * 60 + memoryContextLimitPerScope = 4 + memoryContextMaxItems = 8 + memoryContextItemMaxChars = 220 + sharedMemoryNamespace = "bot" +) // SkillEntry represents a skill loaded from the container. type SkillEntry struct { @@ -43,7 +50,7 @@ type Resolver struct { modelsService *models.Service queries *sqlc.Queries memoryService *memory.Service - historyService *history.Service + chatService *Service settingsService *settings.Service mcpService *mcp.ConnectionService skillLoader SkillLoader @@ -60,7 +67,7 @@ func NewResolver( modelsService *models.Service, queries *sqlc.Queries, memoryService *memory.Service, - historyService *history.Service, + chatService *Service, settingsService *settings.Service, mcpService *mcp.ConnectionService, gatewayBaseURL string, @@ -77,12 +84,12 @@ func NewResolver( modelsService: modelsService, queries: queries, memoryService: memoryService, - historyService: historyService, + chatService: chatService, settingsService: settingsService, mcpService: mcpService, gatewayBaseURL: gatewayBaseURL, timeout: timeout, - logger: log.With(slog.String("service", "chat")), + logger: log.With(slog.String("service", "chat_resolver")), httpClient: &http.Client{Timeout: timeout}, streamingClient: &http.Client{}, } @@ -104,16 +111,13 @@ type gatewayModelConfig struct { } type gatewayIdentity struct { - BotID string `json:"botId"` - SessionID string `json:"sessionId"` - ContainerID string `json:"containerId"` - ContactID string `json:"contactId"` - ContactName string `json:"contactName"` - ContactAlias string `json:"contactAlias,omitempty"` - UserID string `json:"userId,omitempty"` - CurrentPlatform string `json:"currentPlatform,omitempty"` - ReplyTarget string `json:"replyTarget,omitempty"` - SessionToken string `json:"sessionToken,omitempty"` + BotID string `json:"botId"` + ContainerID string `json:"containerId"` + ChannelIdentityID string `json:"channelIdentityId"` + DisplayName string `json:"displayName"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + ReplyTarget string `json:"replyTarget,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` } type gatewaySkill struct { @@ -129,7 +133,6 @@ type gatewayRequest struct { Channels []string `json:"channels"` CurrentChannel string `json:"currentChannel"` AllowedActions []string `json:"allowedActions,omitempty"` - MCPConnections []map[string]any `json:"mcpConnections"` Messages []ModelMessage `json:"messages"` Skills []string `json:"skills"` UsableSkills []gatewaySkill `json:"usableSkills"` @@ -160,7 +163,6 @@ type triggerScheduleRequest struct { Channels []string `json:"channels"` CurrentChannel string `json:"currentChannel"` AllowedActions []string `json:"allowedActions,omitempty"` - MCPConnections []map[string]any `json:"mcpConnections"` Messages []ModelMessage `json:"messages"` Skills []string `json:"skills"` UsableSkills []gatewaySkill `json:"usableSkills"` @@ -184,8 +186,8 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex if strings.TrimSpace(req.BotID) == "" { return resolvedContext{}, fmt.Errorf("bot id is required") } - if strings.TrimSpace(req.SessionID) == "" { - return resolvedContext{}, fmt.Errorf("session id is required") + if strings.TrimSpace(req.ChatID) == "" { + return resolvedContext{}, fmt.Errorf("chat id is required") } skipHistory := req.MaxContextLoadTime < 0 @@ -194,11 +196,21 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex if err != nil { return resolvedContext{}, err } + + // Check chat-level model override. + var chatSettings Settings + if r.chatService != nil { + chatSettings, err = r.chatService.GetSettings(ctx, req.ChatID) + if err != nil { + return resolvedContext{}, err + } + } + userSettings, err := r.loadUserSettings(ctx, req.UserID) if err != nil { return resolvedContext{}, err } - chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings) + chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings) if err != nil { return resolvedContext{}, err } @@ -209,20 +221,18 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes) var messages []ModelMessage - var historySkills []string - if !skipHistory { - messages, err = r.loadHistoryMessages(ctx, req.BotID, req.SessionID, maxCtx) - if err != nil { - return resolvedContext{}, err - } - historySkills, err = r.loadHistorySkills(ctx, req.BotID, req.SessionID, maxCtx) + if !skipHistory && r.chatService != nil { + messages, err = r.loadMessages(ctx, req.ChatID, maxCtx) if err != nil { return resolvedContext{}, err } } + if memoryMsg := r.loadMemoryContextMessage(ctx, req); memoryMsg != nil { + messages = append(messages, *memoryMsg) + } messages = append(messages, req.Messages...) messages = sanitizeMessages(messages) - skills := dedup(append(historySkills, req.Skills...)) + skills := dedup(req.Skills) containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) var usableSkills []gatewaySkill @@ -246,24 +256,6 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex usableSkills = []gatewaySkill{} } - mcpConnections := []map[string]any{} - if r.mcpService != nil { - items, err := r.mcpService.ListActiveByBot(ctx, req.BotID) - if err != nil { - r.logger.Warn("failed to load mcp connections", slog.String("bot_id", req.BotID), slog.Any("error", err)) - } else { - for _, item := range items { - payload := map[string]any{} - for k, v := range item.Config { - payload[k] = v - } - payload["name"] = item.Name - payload["type"] = item.Type - mcpConnections = append(mcpConnections, payload) - } - } - } - payload := gatewayRequest{ Model: gatewayModelConfig{ ModelID: chatModel.ModelID, @@ -276,22 +268,18 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex Channels: nonNilStrings(req.Channels), CurrentChannel: req.CurrentChannel, AllowedActions: req.AllowedActions, - MCPConnections: mcpConnections, - Messages: nonNilMessages(messages), + Messages: nonNilModelMessages(messages), Skills: nonNilStrings(skills), UsableSkills: usableSkills, Query: req.Query, Identity: gatewayIdentity{ - BotID: req.BotID, - SessionID: req.SessionID, - ContainerID: containerID, - ContactID: firstNonEmpty(req.ContactID, req.UserID, req.BotID), - ContactName: firstNonEmpty(req.ContactName, "User"), - ContactAlias: req.ContactAlias, - UserID: req.UserID, - CurrentPlatform: req.CurrentChannel, - ReplyTarget: req.ReplyTarget, - SessionToken: req.SessionToken, + BotID: req.BotID, + ContainerID: containerID, + ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID), + DisplayName: firstNonEmpty(req.DisplayName, "User"), + CurrentPlatform: req.CurrentChannel, + ReplyTarget: "", + SessionToken: req.ChatToken, }, Attachments: []any{}, } @@ -311,7 +299,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err if err != nil { return ChatResponse{}, err } - if err := r.storeRound(ctx, req.BotID, req.SessionID, req.Query, resp.Messages, resp.Skills); err != nil { + if err := r.storeRound(ctx, req, resp.Messages); err != nil { return ChatResponse{}, err } return ChatResponse{ @@ -333,13 +321,16 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc return fmt.Errorf("schedule command is required") } - sessionID := "schedule:" + payload.ID + chatID := payload.ChatID + if strings.TrimSpace(chatID) == "" { + chatID = "schedule-" + payload.ID + } req := ChatRequest{ - BotID: botID, - SessionID: sessionID, - Query: payload.Command, - UserID: payload.OwnerUserID, - Token: token, + BotID: botID, + ChatID: chatID, + Query: payload.Command, + UserID: payload.OwnerUserID, + Token: token, } rc, err := r.resolve(ctx, req) if err != nil { @@ -352,17 +343,14 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc Channels: rc.payload.Channels, CurrentChannel: rc.payload.CurrentChannel, AllowedActions: rc.payload.AllowedActions, - MCPConnections: rc.payload.MCPConnections, Messages: rc.payload.Messages, Skills: rc.payload.Skills, UsableSkills: rc.payload.UsableSkills, Identity: gatewayIdentity{ - BotID: rc.payload.Identity.BotID, - SessionID: rc.payload.Identity.SessionID, - ContainerID: rc.payload.Identity.ContainerID, - ContactID: botID, - ContactName: "Scheduler", - UserID: payload.OwnerUserID, + BotID: rc.payload.Identity.BotID, + ContainerID: rc.payload.Identity.ContainerID, + ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID), + DisplayName: "Scheduler", }, Attachments: rc.payload.Attachments, Schedule: gatewaySchedule{ @@ -379,7 +367,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc if err != nil { return err } - return r.storeRound(ctx, botID, sessionID, payload.Command, resp.Messages, resp.Skills) + return r.storeRound(ctx, req, resp.Messages) } // --- StreamChat --- @@ -390,27 +378,38 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre errCh := make(chan error, 1) r.logger.Info("gateway stream start", slog.String("bot_id", req.BotID), - slog.String("session_id", req.SessionID), + slog.String("chat_id", req.ChatID), ) go func() { defer close(chunkCh) defer close(errCh) - rc, err := r.resolve(ctx, req) + streamReq := req + rc, err := r.resolve(ctx, streamReq) if err != nil { r.logger.Error("gateway stream resolve failed", - slog.String("bot_id", req.BotID), - slog.String("session_id", req.SessionID), + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), slog.Any("error", err), ) errCh <- err return } - if err := r.streamChat(ctx, rc.payload, req.BotID, req.SessionID, req.Query, req.Token, chunkCh); err != nil { + if err := r.persistUserMessage(ctx, streamReq); err != nil { + r.logger.Error("gateway stream persist user message failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + streamReq.UserMessagePersisted = true + if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil { r.logger.Error("gateway stream request failed", - slog.String("bot_id", req.BotID), - slog.String("session_id", req.SessionID), + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), slog.Any("error", err), ) errCh <- err @@ -502,7 +501,7 @@ func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerSched return parsed, nil } -func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error { +func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error { body, err := json.Marshal(payload) if err != nil { return err @@ -515,8 +514,8 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Accept", "text/event-stream") - if strings.TrimSpace(token) != "" { - httpReq.Header.Set("Authorization", token) + if strings.TrimSpace(req.Token) != "" { + httpReq.Header.Set("Authorization", req.Token) } resp, err := r.streamingClient.Do(httpReq) @@ -558,7 +557,7 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID if stored { continue } - if handled, storeErr := r.tryStoreStream(ctx, botID, sessionID, query, currentEvent, data); storeErr != nil { + if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil { return storeErr } else if handled { stored = true @@ -568,16 +567,16 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID } // tryStoreStream attempts to extract final messages from a stream event and persist them. -func (r *Resolver) tryStoreStream(ctx context.Context, botID, sessionID, query, eventType, data string) (bool, error) { +func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) { // event: done + data: {messages: [...]} if eventType == "done" { var resp gatewayResponse if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + return true, r.storeRound(ctx, req, resp.Messages) } } - // data: {"type":"agent_end"|"done", ...} + // data: {"type":"text_delta"|"agent_end"|"done", ...} var envelope struct { Type string `json:"type"` Data json.RawMessage `json:"data"` @@ -585,13 +584,13 @@ func (r *Resolver) tryStoreStream(ctx context.Context, botID, sessionID, query, Skills []string `json:"skills"` } if err := json.Unmarshal([]byte(data), &envelope); err == nil { - if envelope.Type == "agent_end" && len(envelope.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, envelope.Messages, envelope.Skills) + if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 { + return true, r.storeRound(ctx, req, envelope.Messages) } if envelope.Type == "done" && len(envelope.Data) > 0 { var resp gatewayResponse if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + return true, r.storeRound(ctx, req, resp.Messages) } } } @@ -599,7 +598,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, botID, sessionID, query, // fallback: data: {messages: [...]} var resp gatewayResponse if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + return true, r.storeRound(ctx, req, resp.Messages) } return false, nil } @@ -622,103 +621,312 @@ func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit strin return "mcp-" + botID } -// --- history helpers --- +// --- message loading --- -func (r *Resolver) loadHistoryMessages(ctx context.Context, botID, sessionID string, maxContextMinutes int) ([]ModelMessage, error) { - if r.historyService == nil { - return nil, fmt.Errorf("history service not configured") - } +func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) - records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, since) + msgs, err := r.chatService.ListMessagesSince(ctx, chatID, since) if err != nil { return nil, err } - var messages []ModelMessage - for _, record := range records { - msgs, err := recordToMessages(record) - if err != nil { - r.logger.Warn("skip malformed history record", slog.String("record_id", record.ID), slog.Any("error", err)) + var result []ModelMessage + for _, m := range msgs { + var mm ModelMessage + if err := json.Unmarshal(m.Content, &mm); err != nil { + // Fallback: treat content as text string. + mm = ModelMessage{Role: m.Role, Content: m.Content} + } else { + mm.Role = m.Role + } + result = append(result, mm) + } + return result, nil +} + +type memoryContextItem struct { + Namespace string + Item memory.MemoryItem +} + +func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage { + if r.memoryService == nil { + return nil + } + if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + return nil + } + + results := make([]memoryContextItem, 0, memoryContextLimitPerScope) + seen := map[string]struct{}{} + resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ + Query: req.Query, + BotID: req.BotID, + Limit: memoryContextLimitPerScope, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": req.BotID, + "botId": req.BotID, + }, + }) + if err != nil { + r.logger.Warn("memory search for context failed", + slog.String("namespace", sharedMemoryNamespace), + slog.Any("error", err), + ) + return nil + } + for _, item := range resp.Results { + key := strings.TrimSpace(item.ID) + if key == "" { + key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory) + } + if key == "" { continue } - messages = append(messages, msgs...) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item}) + } + if len(results) == 0 { + return nil } - return messages, nil -} -func (r *Resolver) loadHistorySkills(ctx context.Context, botID, sessionID string, maxContextMinutes int) ([]string, error) { - if r.historyService == nil { - return nil, fmt.Errorf("history service not configured") + sort.Slice(results, func(i, j int) bool { + return results[i].Item.Score > results[j].Item.Score + }) + if len(results) > memoryContextMaxItems { + results = results[:memoryContextMaxItems] } - since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) - records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, since) - if err != nil { - return nil, err - } - var combined []string - for _, record := range records { - combined = append(combined, record.Skills...) - } - return dedup(combined), nil -} -// recordToMessages converts a history record (stored as []map[string]any) to typed ModelMessages. -func recordToMessages(record history.Record) ([]ModelMessage, error) { - if len(record.Messages) == 0 { - return nil, nil + var sb strings.Builder + sb.WriteString("Relevant memory context (use when helpful):\n") + for _, entry := range results { + text := strings.TrimSpace(entry.Item.Memory) + if text == "" { + continue + } + sb.WriteString("- [") + sb.WriteString(entry.Namespace) + sb.WriteString("] ") + sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars)) + sb.WriteString("\n") } - raw, err := json.Marshal(record.Messages) - if err != nil { - return nil, err + payload := strings.TrimSpace(sb.String()) + if payload == "" { + return nil } - var msgs []ModelMessage - if err := json.Unmarshal(raw, &msgs); err != nil { - return nil, err + msg := ModelMessage{ + Role: "system", + Content: NewTextContent(payload), } - return msgs, nil + return &msg } // --- store helpers --- -func (r *Resolver) storeRound(ctx context.Context, botID, sessionID, query string, messages []ModelMessage, skills []string) error { - if err := r.storeHistory(ctx, botID, sessionID, query, messages, skills); err != nil { - return err - } - r.storeMemory(ctx, botID, sessionID, query, messages) - return nil -} - -func (r *Resolver) storeHistory(ctx context.Context, botID, sessionID, query string, messages []ModelMessage, skills []string) error { - if r.historyService == nil { - return fmt.Errorf("history service not configured") - } - if strings.TrimSpace(botID) == "" || strings.TrimSpace(sessionID) == "" { - return fmt.Errorf("bot id and session id are required") - } - if strings.TrimSpace(query) == "" && len(messages) == 0 { +func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error { + if r.chatService == nil { return nil } - // Convert typed messages to []map[string]any for the history service. - raw, err := json.Marshal(messages) + if strings.TrimSpace(req.BotID) == "" { + return fmt.Errorf("bot id is required for persistence") + } + text := strings.TrimSpace(req.Query) + if text == "" { + return nil + } + + message := ModelMessage{ + Role: "user", + Content: NewTextContent(text), + } + content, err := json.Marshal(message) if err != nil { return err } - var rows []map[string]any - if err := json.Unmarshal(raw, &rows); err != nil { - return err - } - _, err = r.historyService.Create(ctx, botID, strings.TrimSpace(sessionID), history.CreateRequest{ - Messages: rows, - Metadata: map[string]any{"query": strings.TrimSpace(query)}, - Skills: skills, - }) + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + _, err = r.chatService.PersistMessage( + ctx, + req.BotID, + req.RouteID, + senderChannelIdentityID, + senderUserID, + req.CurrentChannel, + req.ExternalMessageID, + "", + "user", + content, + buildRouteMetadata(req), + ) return err } -func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query string, messages []ModelMessage) { +func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error { + // Add user query as the first message if not already present in the round. + // This ensures the user's prompt is persisted alongside the assistant's response. + fullRound := make([]ModelMessage, 0, len(messages)+1) + hasUserQuery := false + for _, m := range messages { + if m.Role == "user" && m.TextContent() == req.Query { + hasUserQuery = true + break + } + } + if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" { + fullRound = append(fullRound, ModelMessage{ + Role: "user", + Content: NewTextContent(req.Query), + }) + } + for _, m := range messages { + if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) { + // User message was already persisted before streaming; skip duplicate copy in round payload. + continue + } + fullRound = append(fullRound, m) + } + if len(fullRound) == 0 { + return nil + } + + r.storeMessages(ctx, req, fullRound) + r.storeMemory(ctx, req.BotID, fullRound) + return nil +} + +func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) { + if r.chatService == nil { + return + } + if strings.TrimSpace(req.BotID) == "" { + return + } + meta := buildRouteMetadata(req) + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + for _, msg := range messages { + content, err := json.Marshal(msg) + if err != nil { + continue + } + messageSenderChannelIdentityID := "" + messageSenderUserID := "" + externalMessageID := "" + sourceReplyToMessageID := "" + if msg.Role == "user" { + messageSenderChannelIdentityID = senderChannelIdentityID + messageSenderUserID = senderUserID + externalMessageID = req.ExternalMessageID + } else if strings.TrimSpace(req.ExternalMessageID) != "" { + // Assistant/tool/system outputs are linked to the inbound source message for cross-channel reply threading. + sourceReplyToMessageID = req.ExternalMessageID + } + if _, err := r.chatService.PersistMessage( + ctx, + req.BotID, + req.RouteID, + messageSenderChannelIdentityID, + messageSenderUserID, + req.CurrentChannel, + externalMessageID, + sourceReplyToMessageID, + msg.Role, + content, + meta, + ); err != nil { + r.logger.Warn("persist message failed", slog.Any("error", err)) + } + } +} + +func buildRouteMetadata(req ChatRequest) map[string]any { + if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" { + return nil + } + meta := map[string]any{} + if strings.TrimSpace(req.RouteID) != "" { + meta["route_id"] = req.RouteID + } + if strings.TrimSpace(req.CurrentChannel) != "" { + meta["platform"] = req.CurrentChannel + } + return meta +} + +func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) { + channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID) + userID := strings.TrimSpace(req.UserID) + + channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID) + userAsUserValid := r.isExistingUserID(ctx, userID) + userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID) + + senderChannelIdentityID := "" + switch { + case channelIdentityValid: + senderChannelIdentityID = channelIdentityID + case userAsChannelIdentityValid && !userAsUserValid: + // Some flows may carry channel_identity_id in req.UserID. + senderChannelIdentityID = userID + } + + senderUserID := "" + if userAsUserValid { + senderUserID = userID + } + if senderUserID == "" && senderChannelIdentityID != "" { + if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" { + senderUserID = linked + } + } + return senderChannelIdentityID, senderUserID +} + +func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetChannelIdentityByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetUserByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string { + if r.queries == nil { + return "" + } + pgID, err := parseUUID(channelIdentityID) + if err != nil { + return "" + } + row, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil || !row.UserID.Valid { + return "" + } + return row.UserID.String() +} + +func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) { if r.memoryService == nil { return } - if strings.TrimSpace(botID) == "" || strings.TrimSpace(sessionID) == "" { + if strings.TrimSpace(botID) == "" { return } memMsgs := make([]memory.Message, 0, len(messages)) @@ -736,27 +944,42 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query stri if len(memMsgs) == 0 { return } + r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID) +} + +func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + "botId": botID, + } if _, err := r.memoryService.Add(ctx, memory.AddRequest{ - Messages: memMsgs, - BotID: botID, - SessionID: strings.TrimSpace(sessionID), + Messages: msgs, + BotID: botID, + Filters: filters, }); err != nil { - r.logger.Warn("store memory failed", slog.Any("error", err)) + r.logger.Warn("store memory failed", + slog.String("namespace", namespace), + slog.String("scope_id", scopeID), + slog.Any("error", err), + ) } } // --- model selection --- -func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings) (models.GetResponse, sqlc.LlmProvider, error) { +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) { if r.modelsService == nil { return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") } modelID := strings.TrimSpace(req.Model) providerFilter := strings.TrimSpace(req.Provider) - // Priority: request model > bot settings > user settings. + // Priority: request model > chat settings > bot settings > user settings. if modelID == "" && providerFilter == "" { - if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { + if value := strings.TrimSpace(cs.ModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { modelID = value } else if value := strings.TrimSpace(us.ChatModelID); value != "" { modelID = value @@ -922,7 +1145,7 @@ func nonNilStrings(s []string) []string { return s } -func nonNilMessages(m []ModelMessage) []ModelMessage { +func nonNilModelMessages(m []ModelMessage) []ModelMessage { if m == nil { return []ModelMessage{} } @@ -936,14 +1159,17 @@ func truncate(s string, n int) string { return s[:n] + "..." } +func truncateMemorySnippet(s string, n int) string { + trimmed := strings.TrimSpace(s) + if len(trimmed) <= n { + return trimmed + } + return strings.TrimSpace(trimmed[:n]) + "..." +} + func parseUUID(id string) (pgtype.UUID, error) { - trimmed := strings.TrimSpace(id) - if trimmed == "" { + if strings.TrimSpace(id) == "" { return pgtype.UUID{}, fmt.Errorf("empty id") } - var pgID pgtype.UUID - if err := pgID.Scan(trimmed); err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - return pgID, nil + return db.ParseUUID(id) } diff --git a/internal/chat/resolver_memory_context_test.go b/internal/chat/resolver_memory_context_test.go new file mode 100644 index 00000000..48c9eb39 --- /dev/null +++ b/internal/chat/resolver_memory_context_test.go @@ -0,0 +1,55 @@ +package conversation + +import ( + "context" + "log/slog" + "strings" + "testing" + + "github.com/memohai/memoh/internal/memory" +) + +func TestLoadMemoryContextMessage_NoMemoryService(t *testing.T) { + resolver := &Resolver{ + memoryService: nil, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + }) + if msg != nil { + t.Fatalf("expected nil message when memory service is nil") + } +} + +func TestLoadMemoryContextMessage_SearchFailureFallback(t *testing.T) { + resolver := &Resolver{ + memoryService: &memory.Service{}, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + UserID: "user-1", + }) + if msg != nil { + t.Fatalf("expected nil message when memory search cannot return results") + } +} + +func TestTruncateMemorySnippet(t *testing.T) { + longText := strings.Repeat("a", 20) + " " + got := truncateMemorySnippet(longText, 10) + if got != strings.Repeat("a", 10)+"..." { + t.Fatalf("unexpected truncated value: %q", got) + } + + shortText := " short " + got = truncateMemorySnippet(shortText, 10) + if got != "short" { + t.Fatalf("unexpected trimmed short value: %q", got) + } +} diff --git a/internal/chat/resolver_test.go b/internal/chat/resolver_test.go index 6fc40224..e866ba00 100644 --- a/internal/chat/resolver_test.go +++ b/internal/chat/resolver_test.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "context" @@ -47,12 +47,10 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) { Messages: []ModelMessage{}, Skills: []string{}, Identity: gatewayIdentity{ - BotID: "bot-123", - SessionID: "schedule:sched-1", - ContainerID: "mcp-bot-123", - ContactID: "bot-123", - ContactName: "Scheduler", - UserID: "owner-user-1", + BotID: "bot-123", + ContainerID: "mcp-bot-123", + ChannelIdentityID: "owner-user-1", + DisplayName: "Scheduler", }, Attachments: []any{}, Schedule: gatewaySchedule{ diff --git a/internal/chat/schedule_gateway.go b/internal/chat/schedule_gateway.go index d1578065..8e543f78 100644 --- a/internal/chat/schedule_gateway.go +++ b/internal/chat/schedule_gateway.go @@ -1,4 +1,4 @@ -package chat +package conversation import ( "context" diff --git a/internal/chat/service.go b/internal/chat/service.go new file mode 100644 index 00000000..1ce4a6e5 --- /dev/null +++ b/internal/chat/service.go @@ -0,0 +1,1000 @@ +package conversation + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +var ( + ErrChatNotFound = errors.New("chat not found") + ErrNotParticipant = errors.New("not a participant") + ErrPermissionDenied = errors.New("permission denied") +) + +// Service manages chat lifecycle, participants, settings, and routes. +type Service struct { + queries *sqlc.Queries + logger *slog.Logger +} + +// NewService creates a chat service. +func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + queries: queries, + logger: log.With(slog.String("service", "chat")), + } +} + +// --- Chat CRUD --- + +// Create creates a new chat and adds the creator as owner. +func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Chat, error) { + kind := strings.TrimSpace(req.Kind) + if kind == "" { + kind = KindDirect + } + if kind != KindDirect && kind != KindGroup && kind != KindThread { + return Chat{}, fmt.Errorf("invalid chat kind: %s", kind) + } + + pgBotID, err := parseUUID(botID) + if err != nil { + return Chat{}, fmt.Errorf("invalid bot id: %w", err) + } + pgChannelIdentityID := pgtype.UUID{} + if strings.TrimSpace(channelIdentityID) != "" { + pgChannelIdentityID, err = parseUUID(channelIdentityID) + if err != nil { + return Chat{}, fmt.Errorf("invalid user id: %w", err) + } + } + + var pgParent pgtype.UUID + if kind == KindThread && strings.TrimSpace(req.ParentChatID) != "" { + pgParent, err = parseUUID(req.ParentChatID) + if err != nil { + return Chat{}, fmt.Errorf("invalid parent chat id: %w", err) + } + } + + metadata, err := json.Marshal(nonNilMap(req.Metadata)) + if err != nil { + return Chat{}, fmt.Errorf("marshal chat metadata: %w", err) + } + + row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{ + BotID: pgBotID, + Kind: kind, + ParentChatID: pgParent, + Title: strings.TrimSpace(req.Title), + CreatedByUserID: pgChannelIdentityID, + Metadata: metadata, + }) + if err != nil { + return Chat{}, fmt.Errorf("create chat: %w", err) + } + + // Add creator as owner when user identity is available. + if pgChannelIdentityID.Valid { + if _, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: row.ID, + UserID: pgChannelIdentityID, + Role: RoleOwner, + }); err != nil { + return Chat{}, fmt.Errorf("add owner participant: %w", err) + } + } + + // For threads, copy participants from parent. + if kind == KindThread && pgParent.Valid { + if err := s.queries.CopyParticipantsToChat(ctx, sqlc.CopyParticipantsToChatParams{ + ChatID: pgParent, + ChatID2: row.ID, + }); err != nil { + s.logger.Warn("copy parent participants failed", slog.Any("error", err)) + } + } + + return toChatFromCreate(row), nil +} + +// Get returns a chat by ID. +func (s *Service) Get(ctx context.Context, chatID string) (Chat, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return Chat{}, ErrChatNotFound + } + row, err := s.queries.GetChatByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Chat{}, ErrChatNotFound + } + return Chat{}, err + } + return toChatFromGet(row), nil +} + +// GetReadAccess resolves whether a user can read a chat. +func (s *Service) GetReadAccess(ctx context.Context, chatID, channelIdentityID string) (ChatReadAccess, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + row, err := s.queries.GetChatReadAccessByUser(ctx, sqlc.GetChatReadAccessByUserParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ChatReadAccess{}, ErrPermissionDenied + } + return ChatReadAccess{}, err + } + return ChatReadAccess{ + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + }, nil +} + +// ListByBotAndChannelIdentity returns all chats visible to the user for a bot. +func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ChatListItem, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return nil, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListVisibleChatsByBotAndUser(ctx, sqlc.ListVisibleChatsByBotAndUserParams{ + BotID: pgBotID, + UserID: pgChannelIdentityID, + }) + if err != nil { + return nil, err + } + chats := make([]ChatListItem, 0, len(rows)) + for _, row := range rows { + chats = append(chats, toChatListItem(row)) + } + return chats, nil +} + +// ListThreads returns threads for a parent chat. +func (s *Service) ListThreads(ctx context.Context, parentChatID string) ([]Chat, error) { + pgID, err := parseUUID(parentChatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListThreadsByParent(ctx, pgID) + if err != nil { + return nil, err + } + chats := make([]Chat, 0, len(rows)) + for _, row := range rows { + chats = append(chats, toChatFromThread(row)) + } + return chats, nil +} + +// Delete deletes a chat (cascade deletes messages, routes, participants, settings). +func (s *Service) Delete(ctx context.Context, chatID string) error { + pgID, err := parseUUID(chatID) + if err != nil { + return ErrChatNotFound + } + return s.queries.DeleteChat(ctx, pgID) +} + +// --- Participants --- + +// AddParticipant adds a user identity to a chat. +func (s *Service) AddParticipant(ctx context.Context, chatID, channelIdentityID, role string) (Participant, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return Participant{}, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, err + } + if role == "" { + role = RoleMember + } + row, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + Role: role, + }) + if err != nil { + return Participant{}, err + } + return toParticipantFromAdd(row), nil +} + +// GetParticipant returns a participant record. +func (s *Service) GetParticipant(ctx context.Context, chatID, channelIdentityID string) (Participant, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return Participant{}, ErrNotParticipant + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, ErrNotParticipant + } + row, err := s.queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Participant{}, ErrNotParticipant + } + return Participant{}, err + } + return toParticipantFromGet(row), nil +} + +// IsParticipant checks whether a user identity is a participant in a chat. +func (s *Service) IsParticipant(ctx context.Context, chatID, channelIdentityID string) (bool, error) { + _, err := s.GetParticipant(ctx, chatID, channelIdentityID) + if errors.Is(err, ErrNotParticipant) { + return false, nil + } + return err == nil, err +} + +// ListParticipants returns all participants for a chat. +func (s *Service) ListParticipants(ctx context.Context, chatID string) ([]Participant, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatParticipants(ctx, pgID) + if err != nil { + return nil, err + } + participants := make([]Participant, 0, len(rows)) + for _, row := range rows { + participants = append(participants, toParticipantFromList(row)) + } + return participants, nil +} + +// RemoveParticipant removes a user identity from a chat. +func (s *Service) RemoveParticipant(ctx context.Context, chatID, channelIdentityID string) error { + pgChatID, err := parseUUID(chatID) + if err != nil { + return err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return err + } + return s.queries.RemoveChatParticipant(ctx, sqlc.RemoveChatParticipantParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + }) +} + +// --- Settings --- + +// GetSettings returns settings for a chat. Returns defaults if not found. +func (s *Service) GetSettings(ctx context.Context, chatID string) (Settings, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return defaultSettings(chatID), nil + } + row, err := s.queries.GetChatSettings(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return defaultSettings(chatID), nil + } + return Settings{}, err + } + return toSettingsFromRead(row), nil +} + +// UpdateSettings updates chat settings. +func (s *Service) UpdateSettings(ctx context.Context, chatID string, req UpdateSettingsRequest) (Settings, error) { + current, err := s.GetSettings(ctx, chatID) + if err != nil { + return Settings{}, err + } + if req.ModelID != nil { + current.ModelID = *req.ModelID + } + + pgID, err := parseUUID(chatID) + if err != nil { + return Settings{}, err + } + row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ + ID: pgID, + ModelID: toPgText(current.ModelID), + }) + if err != nil { + return Settings{}, err + } + return toSettingsFromUpsert(row), nil +} + +// --- Routes --- + +// CreateRoute creates a new chat route. +func (s *Service) CreateRoute(ctx context.Context, chatID string, r Route) (Route, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return Route{}, err + } + pgBotID, err := parseUUID(r.BotID) + if err != nil { + return Route{}, err + } + var pgConfigID pgtype.UUID + if strings.TrimSpace(r.ChannelConfigID) != "" { + pgConfigID, err = parseUUID(r.ChannelConfigID) + if err != nil { + return Route{}, err + } + } + metadata, err := json.Marshal(nonNilMap(r.Metadata)) + if err != nil { + return Route{}, fmt.Errorf("marshal route metadata: %w", err) + } + row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{ + ChatID: pgChatID, + BotID: pgBotID, + Platform: r.Platform, + ChannelConfigID: pgConfigID, + ConversationID: r.ConversationID, + ThreadID: toPgText(r.ThreadID), + ReplyTarget: toPgText(r.ReplyTarget), + Metadata: metadata, + }) + if err != nil { + return Route{}, fmt.Errorf("create route: %w", err) + } + return toRouteFromCreate(row), nil +} + +// FindRoute looks up a route by (bot_id, platform, conversation_id, thread_id). +func (s *Service) FindRoute(ctx context.Context, botID, platform, conversationID, threadID string) (Route, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return Route{}, err + } + row, err := s.queries.FindChatRoute(ctx, sqlc.FindChatRouteParams{ + BotID: pgBotID, + Platform: platform, + ConversationID: conversationID, + ThreadID: toPgText(threadID), + }) + if err != nil { + return Route{}, err + } + return toRouteFromFind(row), nil +} + +// GetRouteByID returns a single route by its ID. +func (s *Service) GetRouteByID(ctx context.Context, routeID string) (Route, error) { + pgID, err := parseUUID(routeID) + if err != nil { + return Route{}, err + } + row, err := s.queries.GetChatRouteByID(ctx, pgID) + if err != nil { + return Route{}, err + } + return toRouteFromGet(row), nil +} + +// ListRoutes lists all routes for a chat. +func (s *Service) ListRoutes(ctx context.Context, chatID string) ([]Route, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatRoutes(ctx, pgID) + if err != nil { + return nil, err + } + routes := make([]Route, 0, len(rows)) + for _, row := range rows { + routes = append(routes, toRouteFromList(row)) + } + return routes, nil +} + +// DeleteRoute deletes a route. +func (s *Service) DeleteRoute(ctx context.Context, routeID string) error { + pgID, err := parseUUID(routeID) + if err != nil { + return err + } + return s.queries.DeleteChatRoute(ctx, pgID) +} + +// UpdateRouteReplyTarget updates the reply target for a route. +func (s *Service) UpdateRouteReplyTarget(ctx context.Context, routeID, replyTarget string) error { + pgID, err := parseUUID(routeID) + if err != nil { + return err + } + return s.queries.UpdateChatRouteReplyTarget(ctx, sqlc.UpdateChatRouteReplyTargetParams{ + ID: pgID, + ReplyTarget: toPgText(replyTarget), + }) +} + +// --- ResolveChat --- + +// ResolveChat finds or creates a chat for a channel inbound message. +func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversationID, threadID, conversationType, channelIdentityID, channelConfigID, replyTarget string) (ResolveChatResult, error) { + // Look up existing route. + route, err := s.FindRoute(ctx, botID, platform, conversationID, threadID) + if err == nil { + // Route found, ensure the sender identity is a participant. + if strings.TrimSpace(channelIdentityID) != "" { + ok, checkErr := s.IsParticipant(ctx, route.ChatID, channelIdentityID) + if checkErr != nil { + return ResolveChatResult{}, fmt.Errorf("check chat participant: %w", checkErr) + } + if !ok { + if _, err := s.AddParticipant(ctx, route.ChatID, channelIdentityID, RoleMember); err != nil { + s.logger.Warn("auto-add participant failed", slog.Any("error", err)) + } + } + } + // Update reply target if changed. + if strings.TrimSpace(replyTarget) != "" && replyTarget != route.ReplyTarget { + if err := s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget); err != nil && s.logger != nil { + s.logger.Warn("update route reply target failed", slog.Any("error", err)) + } + } + pgRouteChatID, parseErr := parseUUID(route.ChatID) + if parseErr != nil { + return ResolveChatResult{}, fmt.Errorf("parse route chat id: %w", parseErr) + } + if err := s.queries.TouchChat(ctx, pgRouteChatID); err != nil && s.logger != nil { + s.logger.Warn("touch chat failed", slog.Any("error", err)) + } + return ResolveChatResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil + } + + // Route not found, create chat + route + participant. + kind := determineChatKind(threadID, conversationType) + creatorChannelIdentityID := s.resolveChatCreatorChannelIdentityID(ctx, botID, channelIdentityID, kind) + + var parentChatID string + if kind == KindThread { + parentRoute, parentErr := s.FindRoute(ctx, botID, platform, conversationID, "") + if parentErr == nil { + parentChatID = parentRoute.ChatID + } + } + + c, err := s.Create(ctx, botID, creatorChannelIdentityID, CreateRequest{ + Kind: kind, + ParentChatID: parentChatID, + }) + if err != nil { + return ResolveChatResult{}, fmt.Errorf("create chat: %w", err) + } + if strings.TrimSpace(channelIdentityID) != "" && strings.TrimSpace(channelIdentityID) != strings.TrimSpace(creatorChannelIdentityID) { + if _, err := s.AddParticipant(ctx, c.ID, channelIdentityID, RoleMember); err != nil { + s.logger.Warn("auto-add creator participant failed", slog.Any("error", err)) + } + } + + newRoute, err := s.CreateRoute(ctx, c.ID, Route{ + BotID: botID, + Platform: platform, + ChannelConfigID: channelConfigID, + ConversationID: conversationID, + ThreadID: threadID, + ReplyTarget: replyTarget, + }) + if err != nil { + return ResolveChatResult{}, fmt.Errorf("create route: %w", err) + } + + return ResolveChatResult{ChatID: c.ID, RouteID: newRoute.ID, Created: true}, nil +} + +// --- Messages --- + +// PersistMessage writes a single message to bot_history_messages. +func (s *Service) PersistMessage(ctx context.Context, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, sourceReplyToMessageID, role string, content json.RawMessage, metadata map[string]any) (Message, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return Message{}, err + } + var pgRouteID pgtype.UUID + if strings.TrimSpace(routeID) != "" { + pgRouteID, err = parseUUID(routeID) + if err != nil { + return Message{}, err + } + } + var pgSender pgtype.UUID + if strings.TrimSpace(senderChannelIdentityID) != "" { + pgSender, err = parseUUID(senderChannelIdentityID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender channel identity id: %w", err) + } + } + var pgSenderUser pgtype.UUID + if strings.TrimSpace(senderUserID) != "" { + pgSenderUser, err = parseUUID(senderUserID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender user id: %w", err) + } + } + metaBytes, err := json.Marshal(nonNilMap(metadata)) + if err != nil { + return Message{}, fmt.Errorf("marshal message metadata: %w", err) + } + if len(content) == 0 { + content = []byte("{}") + } + + row, err := s.queries.CreateMessage(ctx, sqlc.CreateMessageParams{ + BotID: pgBotID, + RouteID: pgRouteID, + SenderChannelIdentityID: pgSender, + SenderUserID: pgSenderUser, + Platform: toPgText(platform), + ExternalMessageID: toPgText(externalMessageID), + SourceReplyToMessageID: toPgText(sourceReplyToMessageID), + Role: role, + Content: content, + Metadata: metaBytes, + }) + if err != nil { + return Message{}, err + } + return toMessageFromCreate(row), nil +} + +// ListMessages returns all messages for a bot. +func (s *Service) ListMessages(ctx context.Context, botID string) ([]Message, error) { + pgID, err := parseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessages(ctx, pgID) + if err != nil { + return nil, err + } + return toMessagesFromList(rows), nil +} + +// ListMessagesSince returns bot messages since a given time. +func (s *Service) ListMessagesSince(ctx context.Context, botID string, since time.Time) ([]Message, error) { + pgID, err := parseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesSince(ctx, sqlc.ListMessagesSinceParams{ + BotID: pgID, + CreatedAt: pgtype.Timestamptz{Time: since, Valid: true}, + }) + if err != nil { + return nil, err + } + return toMessagesFromSince(rows), nil +} + +// ListMessagesLatest returns the latest N bot messages (most recent first). +func (s *Service) ListMessagesLatest(ctx context.Context, botID string, limit int32) ([]Message, error) { + pgID, err := parseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesLatest(ctx, sqlc.ListMessagesLatestParams{ + BotID: pgID, + MaxCount: limit, + }) + if err != nil { + return nil, err + } + return toMessagesFromLatest(rows), nil +} + +// DeleteMessages deletes all messages for a bot. +func (s *Service) DeleteMessages(ctx context.Context, botID string) error { + pgID, err := parseUUID(botID) + if err != nil { + return err + } + return s.queries.DeleteMessagesByBot(ctx, pgID) +} + +// --- conversion helpers --- + +func toChatFromCreate(row sqlc.CreateChatRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromGet(row sqlc.GetChatByIDRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Chat { + return Chat{ + ID: id.String(), + BotID: botID.String(), + Kind: kind, + ParentChatID: parentChatID.String(), + Title: db.TextToString(title), + CreatedBy: createdBy.String(), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, + } +} + +func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem { + return ChatListItem{ + ID: row.ID.String(), + BotID: row.BotID.String(), + Kind: row.Kind, + ParentChatID: row.ParentChatID.String(), + Title: db.TextToString(row.Title), + CreatedBy: row.CreatedByUserID.String(), + Metadata: parseJSONMap(row.Metadata), + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + } +} + +func toParticipantFromAdd(row sqlc.AddChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromGet(row sqlc.GetChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromList(row sqlc.ListChatParticipantsRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFields(chatID, userID pgtype.UUID, role string, joinedAt pgtype.Timestamptz) Participant { + return Participant{ + ChatID: chatID.String(), + UserID: userID.String(), + Role: role, + JoinedAt: joinedAt.Time, + } +} + +func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings { + return Settings{ + ChatID: row.ChatID.String(), + ModelID: db.TextToString(row.ModelID), + } +} + +func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings { + return Settings{ + ChatID: row.ChatID.String(), + ModelID: db.TextToString(row.ModelID), + } +} + +func toRouteFromCreate(row sqlc.CreateChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromFind(row sqlc.FindChatRouteRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromGet(row sqlc.GetChatRouteByIDRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFromList(row sqlc.ListChatRoutesRow) Route { + return toRouteFields( + row.ID, + row.ChatID, + row.BotID, + row.Platform, + row.ChannelConfigID, + row.ConversationID, + row.ThreadID, + row.ReplyTarget, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toRouteFields(id, chatID, botID pgtype.UUID, platform string, channelConfigID pgtype.UUID, conversationID string, threadID, replyTarget pgtype.Text, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Route { + return Route{ + ID: id.String(), + ChatID: chatID.String(), + BotID: botID.String(), + Platform: platform, + ChannelConfigID: channelConfigID.String(), + ConversationID: conversationID, + ThreadID: db.TextToString(threadID), + ReplyTarget: db.TextToString(replyTarget), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, + } +} + +func toMessageFromCreate(row sqlc.CreateMessageRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromListRow(row sqlc.ListMessagesRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromSinceRow(row sqlc.ListMessagesSinceRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromLatestRow(row sqlc.ListMessagesLatestRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFields(id, botID, routeID, senderChannelIdentityID, senderUserID pgtype.UUID, platform, externalMessageID, sourceReplyToMessageID pgtype.Text, role string, content, metadata []byte, createdAt pgtype.Timestamptz) Message { + return Message{ + ID: id.String(), + BotID: botID.String(), + RouteID: routeID.String(), + SenderChannelIdentityID: senderChannelIdentityID.String(), + SenderUserID: senderUserID.String(), + Platform: db.TextToString(platform), + ExternalMessageID: db.TextToString(externalMessageID), + SourceReplyToMessageID: db.TextToString(sourceReplyToMessageID), + Role: role, + Content: json.RawMessage(content), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + } +} + +func toMessagesFromList(rows []sqlc.ListMessagesRow) []Message { + msgs := make([]Message, 0, len(rows)) + for _, row := range rows { + msgs = append(msgs, toMessageFromListRow(row)) + } + return msgs +} + +func toMessagesFromSince(rows []sqlc.ListMessagesSinceRow) []Message { + msgs := make([]Message, 0, len(rows)) + for _, row := range rows { + msgs = append(msgs, toMessageFromSinceRow(row)) + } + return msgs +} + +func toMessagesFromLatest(rows []sqlc.ListMessagesLatestRow) []Message { + msgs := make([]Message, 0, len(rows)) + for _, row := range rows { + msgs = append(msgs, toMessageFromLatestRow(row)) + } + return msgs +} + +func defaultSettings(chatID string) Settings { + return Settings{ + ChatID: chatID, + } +} + +func determineChatKind(threadID, conversationType string) string { + if strings.TrimSpace(threadID) != "" { + return KindThread + } + ct := strings.ToLower(strings.TrimSpace(conversationType)) + if ct == "p2p" || ct == "private" || ct == "" { + return KindDirect + } + return KindGroup +} + +func toPgText(s string) pgtype.Text { + s = strings.TrimSpace(s) + if s == "" { + return pgtype.Text{} + } + return pgtype.Text{String: s, Valid: true} +} + +func pgTimePtr(ts pgtype.Timestamptz) *time.Time { + if !ts.Valid { + return nil + } + value := ts.Time + return &value +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} + +func (s *Service) resolveChatCreatorChannelIdentityID(ctx context.Context, botID, fallbackChannelIdentityID, kind string) string { + fallback := strings.TrimSpace(fallbackChannelIdentityID) + if kind != KindGroup || s.queries == nil { + return fallback + } + pgBotID, err := parseUUID(botID) + if err != nil { + return fallback + } + row, err := s.queries.GetBotByID(ctx, pgBotID) + if err != nil { + s.logger.Warn("resolve bot owner for group chat failed", slog.Any("error", err)) + return fallback + } + ownerChannelIdentityID := row.OwnerUserID.String() + if strings.TrimSpace(ownerChannelIdentityID) == "" { + return fallback + } + return ownerChannelIdentityID +} diff --git a/internal/chat/service_presence_integration_test.go b/internal/chat/service_presence_integration_test.go new file mode 100644 index 00000000..9b07851b --- /dev/null +++ b/internal/chat/service_presence_integration_test.go @@ -0,0 +1,244 @@ +package conversation_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channel/identities" + conversation "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +type chatPresenceFixture struct { + chatSvc *conversation.Service + channelIdentitySvc *identities.Service + queries *sqlc.Queries + cleanup func() +} + +func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + queries := sqlc.New(pool) + + return chatPresenceFixture{ + chatSvc: conversation.NewService(logger, queries), + channelIdentitySvc: identities.NewService(logger, queries), + queries: queries, + cleanup: func() { pool.Close() }, + } +} + +func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, err := json.Marshal(map[string]any{"source": "chat-presence-integration-test"}) + if err != nil { + return "", err + } + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "presence-test-bot", Valid: true}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, string, string, string) { + t.Helper() + + fixture := setupChatPresenceIntegrationTest(t) + ctx := context.Background() + + ownerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + fixture.cleanup() + t.Fatalf("create owner user failed: %v", err) + } + observerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + fixture.cleanup() + t.Fatalf("create observer user failed: %v", err) + } + botID, err := createBotForChatPresence(ctx, fixture.queries, ownerUserID) + if err != nil { + fixture.cleanup() + t.Fatalf("create bot failed: %v", err) + } + + createdChat, err := fixture.chatSvc.Create(ctx, botID, ownerUserID, conversation.CreateRequest{ + Kind: conversation.KindGroup, + Title: "presence-observed", + }) + if err != nil { + fixture.cleanup() + t.Fatalf("create chat failed: %v", err) + } + + observedChannelIdentity, err := fixture.channelIdentitySvc.ResolveByChannelIdentity( + ctx, + "feishu", + fmt.Sprintf("presence-channelIdentity-%d", time.Now().UnixNano()), + "presence-observer", + ) + if err != nil { + fixture.cleanup() + t.Fatalf("resolve channelIdentity failed: %v", err) + } + + _, err = fixture.chatSvc.PersistMessage( + ctx, + botID, + "", + observedChannelIdentity.ID, + "", + "feishu", + fmt.Sprintf("ext-msg-%d", time.Now().UnixNano()), + "", + "user", + []byte(`{"content":"hello from observed channelIdentity"}`), + nil, + ) + if err != nil { + fixture.cleanup() + t.Fatalf("persist message failed: %v", err) + } + + return fixture, botID, createdChat.ID, observerUserID, observedChannelIdentity.ID +} + +func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + beforeBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats before bind failed: %v", err) + } + if len(beforeBind) != 0 { + t.Fatalf("expected no visible chats before bind, got %d", len(beforeBind)) + } + + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + afterBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats after bind failed: %v", err) + } + if len(afterBind) == 0 { + t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind)) + } + + var target *conversation.ChatListItem + for i := range afterBind { + if afterBind[i].ID == chatID { + target = &afterBind[i] + break + } + } + if target == nil { + t.Fatalf("expected chat %s in visible list after bind", chatID) + } + if target.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected access_mode=%s, got %s", conversation.AccessModeChannelIdentityObserved, target.AccessMode) + } + if target.ParticipantRole != "" { + t.Fatalf("expected empty participant_role for observed chat, got %s", target.ParticipantRole) + } + if target.LastObservedAt == nil { + t.Fatal("expected last_observed_at to be set for observed chat") + } +} + +func TestObservedAccessReadableButNotParticipant(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + access, err := fixture.chatSvc.GetReadAccess(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("get read access failed: %v", err) + } + if access.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected read access %s, got %s", conversation.AccessModeChannelIdentityObserved, access.AccessMode) + } + + messages, err := fixture.chatSvc.ListMessages(ctx, chatID) + if err != nil { + t.Fatalf("list messages failed: %v", err) + } + if len(messages) == 0 { + t.Fatal("expected observed user can read chat messages") + } + + _, err = fixture.chatSvc.GetParticipant(ctx, chatID, observerUserID) + if !errors.Is(err, conversation.ErrNotParticipant) { + t.Fatalf("expected ErrNotParticipant for observed user, got %v", err) + } + ok, err := fixture.chatSvc.IsParticipant(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("check participant failed: %v", err) + } + if ok { + t.Fatal("expected observed user to remain non-participant") + } + + visibleChats, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list visible chats failed: %v", err) + } + if len(visibleChats) == 0 || visibleChats[0].AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatal("expected observed list entry with channel_identity_observed access mode") + } +} diff --git a/internal/chat/types.go b/internal/chat/types.go index 51bb60fe..2e203280 100644 --- a/internal/chat/types.go +++ b/internal/chat/types.go @@ -1,12 +1,134 @@ -// Package chat orchestrates conversations with the agent gateway, including -// synchronous and streaming chat, scheduled triggers, history, and memory storage. -package chat +// Package conversation orchestrates interactions with the agent gateway, including +// synchronous and streaming responses, scheduled triggers, messages, and memory storage. +package conversation import ( "encoding/json" "strings" + "time" ) +// Chat kind constants. +const ( + KindDirect = "direct" + KindGroup = "group" + KindThread = "thread" +) + +// Participant role constants. +const ( + RoleOwner = "owner" + RoleAdmin = "admin" + RoleMember = "member" +) + +// Chat list access mode constants. +const ( + AccessModeParticipant = "participant" + AccessModeChannelIdentityObserved = "channel_identity_observed" +) + +// Chat is the first-class conversation container. +type Chat struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ChatListItem is a chat entry with access context for list rendering. +type ChatListItem struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role,omitempty"` + LastObservedAt *time.Time `json:"last_observed_at,omitempty"` +} + +// ChatReadAccess is the resolved access context for reading chat content. +type ChatReadAccess struct { + AccessMode string + ParticipantRole string + LastObservedAt *time.Time +} + +// Participant represents a chat member. +type Participant struct { + ChatID string `json:"chat_id"` + UserID string `json:"user_id"` + Role string `json:"role"` + JoinedAt time.Time `json:"joined_at"` +} + +// Settings holds per-chat configuration. +type Settings struct { + ChatID string `json:"chat_id"` + ModelID string `json:"model_id,omitempty"` +} + +// Route maps external channel conversations to a chat. +type Route struct { + ID string `json:"id"` + ChatID string `json:"chat_id"` + BotID string `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID string `json:"channel_config_id,omitempty"` + ConversationID string `json:"conversation_id"` + ThreadID string `json:"thread_id,omitempty"` + ReplyTarget string `json:"reply_target,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// Message represents a single persisted bot message. +type Message struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + RouteID string `json:"route_id,omitempty"` + SenderChannelIdentityID string `json:"sender_channel_identity_id,omitempty"` + SenderUserID string `json:"sender_user_id,omitempty"` + Platform string `json:"platform,omitempty"` + ExternalMessageID string `json:"external_message_id,omitempty"` + SourceReplyToMessageID string `json:"source_reply_to_message_id,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// CreateRequest is the input for creating a bot-scoped conversation container. +type CreateRequest struct { + Kind string `json:"kind"` + Title string `json:"title,omitempty"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// UpdateSettingsRequest is the input for updating chat settings. +type UpdateSettingsRequest struct { + ModelID *string `json:"model_id,omitempty"` +} + +// ResolveChatResult is returned by ResolveChat. +type ResolveChatResult struct { + ChatID string + RouteID string + Created bool +} + // ModelMessage is the canonical message format exchanged with the agent gateway. // Aligned with Vercel AI SDK ModelMessage structure. type ModelMessage struct { @@ -73,14 +195,14 @@ func NewTextContent(text string) json.RawMessage { // ContentPart represents one element of a multi-part message content. type ContentPart struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - URL string `json:"url,omitempty"` - Styles []string `json:"styles,omitempty"` - Language string `json:"language,omitempty"` - UserID string `json:"user_id,omitempty"` - Emoji string `json:"emoji,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []string `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // HasValue reports whether the content part carries a meaningful value. @@ -105,16 +227,17 @@ type ToolCallFunction struct { // ChatRequest is the input for Chat and StreamChat. type ChatRequest struct { - BotID string `json:"-"` - SessionID string `json:"-"` - Token string `json:"-"` - UserID string `json:"-"` - ContainerID string `json:"-"` - ContactID string `json:"-"` - ContactName string `json:"-"` - ContactAlias string `json:"-"` - ReplyTarget string `json:"-"` - SessionToken string `json:"-"` + BotID string `json:"-"` + ChatID string `json:"-"` + Token string `json:"-"` + UserID string `json:"-"` + SourceChannelIdentityID string `json:"-"` + ContainerID string `json:"-"` + DisplayName string `json:"-"` + RouteID string `json:"-"` + ChatToken string `json:"-"` + ExternalMessageID string `json:"-"` + UserMessagePersisted bool `json:"-"` Query string `json:"query"` Model string `json:"model,omitempty"` diff --git a/internal/contacts/service.go b/internal/contacts/service.go deleted file mode 100644 index 1a8a40c3..00000000 --- a/internal/contacts/service.go +++ /dev/null @@ -1,410 +0,0 @@ -package contacts - -import ( - "context" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgtype" - - "github.com/memohai/memoh/internal/db/sqlc" -) - -type Service struct { - queries *sqlc.Queries -} - -func NewService(queries *sqlc.Queries) *Service { - return &Service{queries: queries} -} - -func (s *Service) GetByID(ctx context.Context, contactID string) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgID, err := parseUUID(contactID) - if err != nil { - return Contact{}, err - } - row, err := s.queries.GetContactByID(ctx, pgID) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) GetByUserID(ctx context.Context, botID, userID string) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return Contact{}, err - } - pgUserID, err := parseUUID(userID) - if err != nil { - return Contact{}, err - } - row, err := s.queries.GetContactByUserID(ctx, sqlc.GetContactByUserIDParams{ - BotID: pgBotID, - UserID: pgUserID, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (ContactChannel, error) { - if s.queries == nil { - return ContactChannel{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return ContactChannel{}, err - } - row, err := s.queries.GetContactChannelByIdentity(ctx, sqlc.GetContactChannelByIdentityParams{ - BotID: pgBotID, - Platform: platform, - ExternalID: externalID, - }) - if err != nil { - return ContactChannel{}, err - } - return normalizeContactChannel(row) -} - -func (s *Service) ListChannelsByContact(ctx context.Context, contactID string) ([]ContactChannel, error) { - if s.queries == nil { - return nil, fmt.Errorf("contacts queries not configured") - } - pgContactID, err := parseUUID(contactID) - if err != nil { - return nil, err - } - rows, err := s.queries.ListContactChannelsByContact(ctx, pgContactID) - if err != nil { - return nil, err - } - items := make([]ContactChannel, 0, len(rows)) - for _, row := range rows { - item, err := normalizeContactChannel(row) - if err != nil { - return nil, err - } - items = append(items, item) - } - return items, nil -} - -func (s *Service) ListByBot(ctx context.Context, botID string) ([]Contact, error) { - if s.queries == nil { - return nil, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return nil, err - } - rows, err := s.queries.ListContactsByBot(ctx, pgBotID) - if err != nil { - return nil, err - } - items := make([]Contact, 0, len(rows)) - for _, row := range rows { - contact, err := normalizeContact(row) - if err != nil { - return nil, err - } - items = append(items, contact) - } - return items, nil -} - -func (s *Service) Search(ctx context.Context, botID, query string) ([]Contact, error) { - if s.queries == nil { - return nil, fmt.Errorf("contacts queries not configured") - } - trimmed := strings.TrimSpace(query) - if trimmed == "" { - return s.ListByBot(ctx, botID) - } - pgBotID, err := parseUUID(botID) - if err != nil { - return nil, err - } - search := "%" + trimmed + "%" - rows, err := s.queries.SearchContacts(ctx, sqlc.SearchContactsParams{ - BotID: pgBotID, - Query: pgtype.Text{String: search, Valid: true}, - }) - if err != nil { - return nil, err - } - items := make([]Contact, 0, len(rows)) - for _, row := range rows { - contact, err := normalizeContact(row) - if err != nil { - return nil, err - } - items = append(items, contact) - } - return items, nil -} - -func (s *Service) Create(ctx context.Context, req CreateRequest) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(req.BotID) - if err != nil { - return Contact{}, err - } - pgUserID := pgtype.UUID{Valid: false} - if strings.TrimSpace(req.UserID) != "" { - parsed, err := parseUUID(req.UserID) - if err != nil { - return Contact{}, err - } - pgUserID = parsed - } - payload, err := json.Marshal(defaultMetadata(req.Metadata)) - if err != nil { - return Contact{}, err - } - row, err := s.queries.CreateContact(ctx, sqlc.CreateContactParams{ - BotID: pgBotID, - UserID: pgUserID, - DisplayName: pgtype.Text{String: strings.TrimSpace(req.DisplayName), Valid: strings.TrimSpace(req.DisplayName) != ""}, - Alias: pgtype.Text{String: strings.TrimSpace(req.Alias), Valid: strings.TrimSpace(req.Alias) != ""}, - Tags: normalizeTags(req.Tags), - Status: normalizeStatus(req.Status), - Metadata: payload, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) CreateGuest(ctx context.Context, botID, displayName string) (Contact, error) { - return s.Create(ctx, CreateRequest{ - BotID: botID, - DisplayName: displayName, - Status: "active", - }) -} - -func (s *Service) Update(ctx context.Context, contactID string, req UpdateRequest) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgID, err := parseUUID(contactID) - if err != nil { - return Contact{}, err - } - var displayName pgtype.Text - if req.DisplayName != nil { - displayName = pgtype.Text{String: strings.TrimSpace(*req.DisplayName), Valid: strings.TrimSpace(*req.DisplayName) != ""} - } - var alias pgtype.Text - if req.Alias != nil { - alias = pgtype.Text{String: strings.TrimSpace(*req.Alias), Valid: strings.TrimSpace(*req.Alias) != ""} - } - var tags []string - if req.Tags != nil { - tags = normalizeTags(*req.Tags) - } - status := "" - if req.Status != nil { - status = normalizeStatus(*req.Status) - } - var metadata []byte - if req.Metadata != nil { - encoded, err := json.Marshal(defaultMetadata(req.Metadata)) - if err != nil { - return Contact{}, err - } - metadata = encoded - } - row, err := s.queries.UpdateContact(ctx, sqlc.UpdateContactParams{ - ID: pgID, - DisplayName: displayName, - Alias: alias, - Tags: tags, - Status: status, - Metadata: metadata, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) BindUser(ctx context.Context, contactID, userID string) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgContactID, err := parseUUID(contactID) - if err != nil { - return Contact{}, err - } - pgUserID, err := parseUUID(userID) - if err != nil { - return Contact{}, err - } - row, err := s.queries.UpdateContactUser(ctx, sqlc.UpdateContactUserParams{ - ID: pgContactID, - UserID: pgUserID, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (ContactChannel, error) { - if s.queries == nil { - return ContactChannel{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return ContactChannel{}, err - } - pgContactID, err := parseUUID(contactID) - if err != nil { - return ContactChannel{}, err - } - payload, err := json.Marshal(defaultMetadata(metadata)) - if err != nil { - return ContactChannel{}, err - } - row, err := s.queries.UpsertContactChannel(ctx, sqlc.UpsertContactChannelParams{ - BotID: pgBotID, - ContactID: pgContactID, - Platform: strings.TrimSpace(platform), - ExternalID: strings.TrimSpace(externalID), - Metadata: payload, - }) - if err != nil { - return ContactChannel{}, err - } - return normalizeContactChannel(row) -} - -func normalizeContact(row sqlc.Contact) (Contact, error) { - metadata, err := decodeMetadata(row.Metadata) - if err != nil { - return Contact{}, err - } - return Contact{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), - UserID: toUUIDString(row.UserID), - DisplayName: strings.TrimSpace(row.DisplayName.String), - Alias: strings.TrimSpace(row.Alias.String), - Tags: normalizeTags(row.Tags), - Status: strings.TrimSpace(row.Status), - Metadata: metadata, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), - }, nil -} - -func normalizeContactChannel(row sqlc.ContactChannel) (ContactChannel, error) { - metadata, err := decodeMetadata(row.Metadata) - if err != nil { - return ContactChannel{}, err - } - return ContactChannel{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), - ContactID: toUUIDString(row.ContactID), - Platform: strings.TrimSpace(row.Platform), - ExternalID: strings.TrimSpace(row.ExternalID), - Metadata: metadata, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), - }, nil -} - -func decodeMetadata(raw []byte) (map[string]any, error) { - if len(raw) == 0 { - return map[string]any{}, nil - } - var payload map[string]any - if err := json.Unmarshal(raw, &payload); err != nil { - return nil, err - } - if payload == nil { - payload = map[string]any{} - } - return payload, nil -} - -func defaultMetadata(value map[string]any) map[string]any { - if value == nil { - return map[string]any{} - } - return value -} - -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} - -func timeFromPg(value pgtype.Timestamptz) time.Time { - if value.Valid { - return value.Time - } - return time.Time{} -} - -func normalizeTags(tags []string) []string { - seen := map[string]struct{}{} - normalized := make([]string, 0, len(tags)) - for _, tag := range tags { - trimmed := strings.TrimSpace(tag) - if trimmed == "" { - continue - } - if _, ok := seen[trimmed]; ok { - continue - } - seen[trimmed] = struct{}{} - normalized = append(normalized, trimmed) - } - return normalized -} - -func normalizeStatus(status string) string { - trimmed := strings.ToLower(strings.TrimSpace(status)) - switch trimmed { - case "active", "blocked", "pending": - return trimmed - case "": - return "active" - default: - return "active" - } -} diff --git a/internal/contacts/types.go b/internal/contacts/types.go deleted file mode 100644 index f39ff1e3..00000000 --- a/internal/contacts/types.go +++ /dev/null @@ -1,45 +0,0 @@ -package contacts - -import "time" - -type Contact struct { - ID string - BotID string - UserID string - DisplayName string - Alias string - Tags []string - Status string - Metadata map[string]any - CreatedAt time.Time - UpdatedAt time.Time -} - -type ContactChannel struct { - ID string - BotID string - ContactID string - Platform string - ExternalID string - Metadata map[string]any - CreatedAt time.Time - UpdatedAt time.Time -} - -type CreateRequest struct { - BotID string - UserID string - DisplayName string - Alias string - Tags []string - Status string - Metadata map[string]any -} - -type UpdateRequest struct { - DisplayName *string - Alias *string - Tags *[]string - Status *string - Metadata map[string]any -} diff --git a/internal/conversation/flow/assistant_output.go b/internal/conversation/flow/assistant_output.go new file mode 100644 index 00000000..b752fcdf --- /dev/null +++ b/internal/conversation/flow/assistant_output.go @@ -0,0 +1,36 @@ +package flow + +import "strings" + +// ExtractAssistantOutputs collects assistant-role outputs from a slice of ModelMessages. +func ExtractAssistantOutputs(messages []ModelMessage) []AssistantOutput { + if len(messages) == 0 { + return nil + } + outputs := make([]AssistantOutput, 0, len(messages)) + for _, msg := range messages { + if msg.Role != "assistant" { + continue + } + content := strings.TrimSpace(msg.TextContent()) + parts := filterContentParts(msg.ContentParts()) + if content == "" && len(parts) == 0 { + continue + } + outputs = append(outputs, AssistantOutput{Content: content, Parts: parts}) + } + return outputs +} + +func filterContentParts(parts []ContentPart) []ContentPart { + if len(parts) == 0 { + return nil + } + filtered := make([]ContentPart, 0, len(parts)) + for _, p := range parts { + if p.HasValue() { + filtered = append(filtered, p) + } + } + return filtered +} diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go new file mode 100644 index 00000000..9ed3225d --- /dev/null +++ b/internal/conversation/flow/resolver.go @@ -0,0 +1,1226 @@ +package flow + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "sort" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/memory" + messagepkg "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/schedule" + "github.com/memohai/memoh/internal/settings" +) + +const ( + defaultMaxContextMinutes = 24 * 60 + memoryContextLimitPerScope = 4 + memoryContextMaxItems = 8 + memoryContextItemMaxChars = 220 + sharedMemoryNamespace = "bot" +) + +// SkillEntry represents a skill loaded from the container. +type SkillEntry struct { + Name string + Description string + Content string + Metadata map[string]any +} + +// SkillLoader loads skills for a given bot from its container. +type SkillLoader interface { + LoadSkills(ctx context.Context, botID string) ([]SkillEntry, error) +} + +// ConversationSettingsReader defines settings lookup behavior needed by flow resolution. +type ConversationSettingsReader interface { + GetSettings(ctx context.Context, conversationID string) (conversation.Settings, error) +} + +// Resolver orchestrates chat with the agent gateway. +type Resolver struct { + modelsService *models.Service + queries *sqlc.Queries + memoryService *memory.Service + conversationSvc ConversationSettingsReader + messageService messagepkg.Service + settingsService *settings.Service + mcpService *mcp.ConnectionService + skillLoader SkillLoader + gatewayBaseURL string + timeout time.Duration + logger *slog.Logger + httpClient *http.Client + streamingClient *http.Client +} + +// NewResolver creates a Resolver that communicates with the agent gateway. +func NewResolver( + log *slog.Logger, + modelsService *models.Service, + queries *sqlc.Queries, + memoryService *memory.Service, + conversationSvc ConversationSettingsReader, + messageService messagepkg.Service, + settingsService *settings.Service, + mcpService *mcp.ConnectionService, + gatewayBaseURL string, + timeout time.Duration, +) *Resolver { + if strings.TrimSpace(gatewayBaseURL) == "" { + gatewayBaseURL = "http://127.0.0.1:8081" + } + gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/") + if timeout <= 0 { + timeout = 60 * time.Second + } + return &Resolver{ + modelsService: modelsService, + queries: queries, + memoryService: memoryService, + conversationSvc: conversationSvc, + messageService: messageService, + settingsService: settingsService, + mcpService: mcpService, + gatewayBaseURL: gatewayBaseURL, + timeout: timeout, + logger: log.With(slog.String("service", "conversation_resolver")), + httpClient: &http.Client{Timeout: timeout}, + streamingClient: &http.Client{}, + } +} + +// SetSkillLoader sets the skill loader used to populate usable skills in gateway requests. +func (r *Resolver) SetSkillLoader(sl SkillLoader) { + r.skillLoader = sl +} + +// --- gateway payload --- + +type gatewayModelConfig struct { + ModelID string `json:"modelId"` + ClientType string `json:"clientType"` + Input []string `json:"input"` + APIKey string `json:"apiKey"` + BaseURL string `json:"baseUrl"` +} + +type gatewayIdentity struct { + BotID string `json:"botId"` + ContainerID string `json:"containerId"` + ChannelIdentityID string `json:"channelIdentityId"` + DisplayName string `json:"displayName"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + ReplyTarget string `json:"replyTarget,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` +} + +type gatewaySkill struct { + Name string `json:"name"` + Description string `json:"description"` + Content string `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type gatewayRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Query string `json:"query"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` +} + +type gatewayResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` +} + +// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule. +type gatewaySchedule struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Pattern string `json:"pattern"` + MaxCalls *int `json:"maxCalls,omitempty"` + Command string `json:"command"` +} + +// triggerScheduleRequest is the payload for POST /chat/trigger-schedule. +type triggerScheduleRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` + Schedule gatewaySchedule `json:"schedule"` +} + +// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) --- + +type resolvedContext struct { + payload gatewayRequest + model models.GetResponse + provider sqlc.LlmProvider +} + +func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContext, error) { + if strings.TrimSpace(req.Query) == "" { + return resolvedContext{}, fmt.Errorf("query is required") + } + if strings.TrimSpace(req.BotID) == "" { + return resolvedContext{}, fmt.Errorf("bot id is required") + } + if strings.TrimSpace(req.ChatID) == "" { + return resolvedContext{}, fmt.Errorf("chat id is required") + } + + skipHistory := req.MaxContextLoadTime < 0 + + botSettings, err := r.loadBotSettings(ctx, req.BotID) + if err != nil { + return resolvedContext{}, err + } + + // Check chat-level model override. + var chatSettings Settings + if r.conversationSvc != nil { + chatSettings, err = r.conversationSvc.GetSettings(ctx, req.ChatID) + if err != nil { + return resolvedContext{}, err + } + } + + userSettings, err := r.loadUserSettings(ctx, req.UserID) + if err != nil { + return resolvedContext{}, err + } + chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings) + if err != nil { + return resolvedContext{}, err + } + clientType, err := normalizeClientType(provider.ClientType) + if err != nil { + return resolvedContext{}, err + } + maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes) + + var messages []ModelMessage + if !skipHistory && r.conversationSvc != nil { + messages, err = r.loadMessages(ctx, req.ChatID, maxCtx) + if err != nil { + return resolvedContext{}, err + } + } + if memoryMsg := r.loadMemoryContextMessage(ctx, req); memoryMsg != nil { + messages = append(messages, *memoryMsg) + } + messages = append(messages, req.Messages...) + messages = sanitizeMessages(messages) + skills := dedup(req.Skills) + containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) + + var usableSkills []gatewaySkill + if r.skillLoader != nil { + entries, err := r.skillLoader.LoadSkills(ctx, req.BotID) + if err != nil { + r.logger.Warn("failed to load usable skills", slog.String("bot_id", req.BotID), slog.Any("error", err)) + } else { + usableSkills = make([]gatewaySkill, 0, len(entries)) + for _, e := range entries { + usableSkills = append(usableSkills, gatewaySkill{ + Name: e.Name, + Description: e.Description, + Content: e.Content, + Metadata: e.Metadata, + }) + } + } + } + if usableSkills == nil { + usableSkills = []gatewaySkill{} + } + + payload := gatewayRequest{ + Model: gatewayModelConfig{ + ModelID: chatModel.ModelID, + ClientType: clientType, + Input: chatModel.Input, + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + }, + ActiveContextTime: maxCtx, + Channels: nonNilStrings(req.Channels), + CurrentChannel: req.CurrentChannel, + AllowedActions: req.AllowedActions, + Messages: nonNilModelMessages(messages), + Skills: nonNilStrings(skills), + UsableSkills: usableSkills, + Query: req.Query, + Identity: gatewayIdentity{ + BotID: req.BotID, + ContainerID: containerID, + ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID), + DisplayName: r.resolveDisplayName(ctx, req), + CurrentPlatform: req.CurrentChannel, + ReplyTarget: "", + SessionToken: req.ChatToken, + }, + Attachments: []any{}, + } + + return resolvedContext{payload: payload, model: chatModel, provider: provider}, nil +} + +// --- Chat --- + +// Chat sends a synchronous chat request to the agent gateway and stores the result. +func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { + rc, err := r.resolve(ctx, req) + if err != nil { + return ChatResponse{}, err + } + resp, err := r.postChat(ctx, rc.payload, req.Token) + if err != nil { + return ChatResponse{}, err + } + if err := r.storeRound(ctx, req, resp.Messages); err != nil { + return ChatResponse{}, err + } + return ChatResponse{ + Messages: resp.Messages, + Skills: resp.Skills, + Model: rc.model.ModelID, + Provider: rc.provider.ClientType, + }, nil +} + +// --- TriggerSchedule --- + +// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint. +func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + if strings.TrimSpace(botID) == "" { + return fmt.Errorf("bot id is required") + } + if strings.TrimSpace(payload.Command) == "" { + return fmt.Errorf("schedule command is required") + } + + chatID := payload.ChatID + if strings.TrimSpace(chatID) == "" { + chatID = "schedule-" + payload.ID + } + req := ChatRequest{ + BotID: botID, + ChatID: chatID, + Query: payload.Command, + UserID: payload.OwnerUserID, + Token: token, + } + rc, err := r.resolve(ctx, req) + if err != nil { + return err + } + + triggerReq := triggerScheduleRequest{ + Model: rc.payload.Model, + ActiveContextTime: rc.payload.ActiveContextTime, + Channels: rc.payload.Channels, + CurrentChannel: rc.payload.CurrentChannel, + AllowedActions: rc.payload.AllowedActions, + Messages: rc.payload.Messages, + Skills: rc.payload.Skills, + UsableSkills: rc.payload.UsableSkills, + Identity: gatewayIdentity{ + BotID: rc.payload.Identity.BotID, + ContainerID: rc.payload.Identity.ContainerID, + ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID), + DisplayName: "Scheduler", + }, + Attachments: rc.payload.Attachments, + Schedule: gatewaySchedule{ + ID: payload.ID, + Name: payload.Name, + Description: payload.Description, + Pattern: payload.Pattern, + MaxCalls: payload.MaxCalls, + Command: payload.Command, + }, + } + + resp, err := r.postTriggerSchedule(ctx, triggerReq, token) + if err != nil { + return err + } + return r.storeRound(ctx, req, resp.Messages) +} + +// --- StreamChat --- + +// StreamChat sends a streaming chat request to the agent gateway. +func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { + chunkCh := make(chan StreamChunk) + errCh := make(chan error, 1) + r.logger.Info("gateway stream start", + slog.String("bot_id", req.BotID), + slog.String("chat_id", req.ChatID), + ) + + go func() { + defer close(chunkCh) + defer close(errCh) + + streamReq := req + rc, err := r.resolve(ctx, streamReq) + if err != nil { + r.logger.Error("gateway stream resolve failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + if err := r.persistUserMessage(ctx, streamReq); err != nil { + r.logger.Error("gateway stream persist user message failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + streamReq.UserMessagePersisted = true + if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil { + r.logger.Error("gateway stream request failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + } + }() + return chunkCh, errCh +} + +// --- HTTP helpers --- + +func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/" + r.logger.Info("gateway request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +// postTriggerSchedule sends a trigger-schedule request to the agent gateway. +func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/trigger-schedule" + r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + url := r.gatewayBaseURL + "/chat/stream" + r.logger.Info("gateway stream request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + if strings.TrimSpace(req.Token) != "" { + httpReq.Header.Set("Authorization", req.Token) + } + + resp, err := r.streamingClient.Do(httpReq) + if err != nil { + r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err)) + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(resp.Body) + r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300))) + return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody))) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + + currentEvent := "" + stored := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + continue + } + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" || data == "[DONE]" { + continue + } + chunkCh <- StreamChunk([]byte(data)) + + if stored { + continue + } + if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil { + return storeErr + } else if handled { + stored = true + } + } + return scanner.Err() +} + +// tryStoreStream attempts to extract final messages from a stream event and persist them. +func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) { + // event: done + data: {messages: [...]} + if eventType == "done" { + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + + // data: {"type":"text_delta"|"agent_end"|"done", ...} + var envelope struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + } + if err := json.Unmarshal([]byte(data), &envelope); err == nil { + if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 { + return true, r.storeRound(ctx, req, envelope.Messages) + } + if envelope.Type == "done" && len(envelope.Data) > 0 { + var resp gatewayResponse + if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + } + + // fallback: data: {messages: [...]} + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + return false, nil +} + +// --- container resolution --- + +func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string { + if strings.TrimSpace(explicit) != "" { + return explicit + } + if r.queries != nil { + pgBotID, err := parseResolverUUID(botID) + if err == nil { + row, err := r.queries.GetContainerByBotID(ctx, pgBotID) + if err == nil && strings.TrimSpace(row.ContainerID) != "" { + return row.ContainerID + } + } + } + return "mcp-" + botID +} + +// --- message loading --- + +func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { + if r.messageService == nil { + return nil, nil + } + since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) + msgs, err := r.messageService.ListSince(ctx, chatID, since) + if err != nil { + return nil, err + } + var result []ModelMessage + for _, m := range msgs { + var mm ModelMessage + if err := json.Unmarshal(m.Content, &mm); err != nil { + // Fallback: treat content as text string. + mm = ModelMessage{Role: m.Role, Content: m.Content} + } else { + mm.Role = m.Role + } + result = append(result, mm) + } + return result, nil +} + +type memoryContextItem struct { + Namespace string + Item memory.MemoryItem +} + +func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage { + if r.memoryService == nil { + return nil + } + if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + return nil + } + + results := make([]memoryContextItem, 0, memoryContextLimitPerScope) + seen := map[string]struct{}{} + resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ + Query: req.Query, + BotID: req.BotID, + Limit: memoryContextLimitPerScope, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": req.BotID, + "botId": req.BotID, + }, + }) + if err != nil { + r.logger.Warn("memory search for context failed", + slog.String("namespace", sharedMemoryNamespace), + slog.Any("error", err), + ) + return nil + } + for _, item := range resp.Results { + key := strings.TrimSpace(item.ID) + if key == "" { + key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory) + } + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item}) + } + if len(results) == 0 { + return nil + } + + sort.Slice(results, func(i, j int) bool { + return results[i].Item.Score > results[j].Item.Score + }) + if len(results) > memoryContextMaxItems { + results = results[:memoryContextMaxItems] + } + + var sb strings.Builder + sb.WriteString("Relevant memory context (use when helpful):\n") + for _, entry := range results { + text := strings.TrimSpace(entry.Item.Memory) + if text == "" { + continue + } + sb.WriteString("- [") + sb.WriteString(entry.Namespace) + sb.WriteString("] ") + sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars)) + sb.WriteString("\n") + } + payload := strings.TrimSpace(sb.String()) + if payload == "" { + return nil + } + msg := ModelMessage{ + Role: "system", + Content: NewTextContent(payload), + } + return &msg +} + +// --- store helpers --- + +func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error { + if r.messageService == nil { + return nil + } + if strings.TrimSpace(req.BotID) == "" { + return fmt.Errorf("bot id is required for persistence") + } + text := strings.TrimSpace(req.Query) + if text == "" { + return nil + } + + message := ModelMessage{ + Role: "user", + Content: NewTextContent(text), + } + content, err := json.Marshal(message) + if err != nil { + return err + } + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + _, err = r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: senderChannelIdentityID, + SenderUserID: senderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: req.ExternalMessageID, + Role: "user", + Content: content, + Metadata: buildRouteMetadata(req), + }) + return err +} + +func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error { + // Add user query as the first message if not already present in the round. + // This ensures the user's prompt is persisted alongside the assistant's response. + fullRound := make([]ModelMessage, 0, len(messages)+1) + hasUserQuery := false + for _, m := range messages { + if m.Role == "user" && m.TextContent() == req.Query { + hasUserQuery = true + break + } + } + if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" { + fullRound = append(fullRound, ModelMessage{ + Role: "user", + Content: NewTextContent(req.Query), + }) + } + for _, m := range messages { + if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) { + // User message was already persisted before streaming; skip duplicate copy in round payload. + continue + } + fullRound = append(fullRound, m) + } + if len(fullRound) == 0 { + return nil + } + + r.storeMessages(ctx, req, fullRound) + r.storeMemory(ctx, req.BotID, fullRound) + return nil +} + +func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) { + if r.messageService == nil { + return + } + if strings.TrimSpace(req.BotID) == "" { + return + } + meta := buildRouteMetadata(req) + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + for _, msg := range messages { + content, err := json.Marshal(msg) + if err != nil { + continue + } + messageSenderChannelIdentityID := "" + messageSenderUserID := "" + externalMessageID := "" + sourceReplyToMessageID := "" + if msg.Role == "user" { + messageSenderChannelIdentityID = senderChannelIdentityID + messageSenderUserID = senderUserID + externalMessageID = req.ExternalMessageID + } else if strings.TrimSpace(req.ExternalMessageID) != "" { + // Assistant/tool/system outputs are linked to the inbound source message for cross-channel reply threading. + sourceReplyToMessageID = req.ExternalMessageID + } + if _, err := r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: messageSenderChannelIdentityID, + SenderUserID: messageSenderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: externalMessageID, + SourceReplyToMessageID: sourceReplyToMessageID, + Role: msg.Role, + Content: content, + Metadata: meta, + }); err != nil { + r.logger.Warn("persist message failed", slog.Any("error", err)) + } + } +} + +func buildRouteMetadata(req ChatRequest) map[string]any { + if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" { + return nil + } + meta := map[string]any{} + if strings.TrimSpace(req.RouteID) != "" { + meta["route_id"] = req.RouteID + } + if strings.TrimSpace(req.CurrentChannel) != "" { + meta["platform"] = req.CurrentChannel + } + return meta +} + +func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) { + channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID) + userID := strings.TrimSpace(req.UserID) + + channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID) + userAsUserValid := r.isExistingUserID(ctx, userID) + userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID) + + senderChannelIdentityID := "" + switch { + case channelIdentityValid: + senderChannelIdentityID = channelIdentityID + case userAsChannelIdentityValid && !userAsUserValid: + // Some flows may carry channel_identity_id in req.UserID. + senderChannelIdentityID = userID + } + + senderUserID := "" + if userAsUserValid { + senderUserID = userID + } + if senderUserID == "" && senderChannelIdentityID != "" { + if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" { + senderUserID = linked + } + } + return senderChannelIdentityID, senderUserID +} + +func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetChannelIdentityByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetUserByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string { + if r.queries == nil { + return "" + } + pgID, err := parseResolverUUID(channelIdentityID) + if err != nil { + return "" + } + row, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil || !row.UserID.Valid { + return "" + } + return row.UserID.String() +} + +// resolveDisplayName returns the best available display name for the request identity: +// req.DisplayName if set, else channel identity's display_name, else linked user's display_name, else "User". +func (r *Resolver) resolveDisplayName(ctx context.Context, req ChatRequest) string { + if name := strings.TrimSpace(req.DisplayName); name != "" { + return name + } + if r.queries == nil { + return "User" + } + channelIdentityID := firstNonEmpty(req.SourceChannelIdentityID, req.UserID) + if channelIdentityID == "" { + return "User" + } + pgID, err := parseResolverUUID(channelIdentityID) + if err != nil { + return "User" + } + ci, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err == nil && ci.DisplayName.Valid { + if name := strings.TrimSpace(ci.DisplayName.String); name != "" { + return name + } + } + linkedUserID := r.linkedUserIDFromChannelIdentity(ctx, channelIdentityID) + if linkedUserID == "" { + return "User" + } + userPgID, err := parseResolverUUID(linkedUserID) + if err != nil { + return "User" + } + u, err := r.queries.GetUserByID(ctx, userPgID) + if err != nil || !u.DisplayName.Valid { + return "User" + } + if name := strings.TrimSpace(u.DisplayName.String); name != "" { + return name + } + return "User" +} + +func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) { + if r.memoryService == nil { + return + } + if strings.TrimSpace(botID) == "" { + return + } + memMsgs := make([]memory.Message, 0, len(messages)) + for _, msg := range messages { + text := strings.TrimSpace(msg.TextContent()) + if text == "" { + continue + } + role := msg.Role + if strings.TrimSpace(role) == "" { + role = "assistant" + } + memMsgs = append(memMsgs, memory.Message{Role: role, Content: text}) + } + if len(memMsgs) == 0 { + return + } + r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID) +} + +func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + "botId": botID, + } + if _, err := r.memoryService.Add(ctx, memory.AddRequest{ + Messages: msgs, + BotID: botID, + Filters: filters, + }); err != nil { + r.logger.Warn("store memory failed", + slog.String("namespace", namespace), + slog.String("scope_id", scopeID), + slog.Any("error", err), + ) + } +} + +// --- model selection --- + +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) { + if r.modelsService == nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + } + modelID := strings.TrimSpace(req.Model) + providerFilter := strings.TrimSpace(req.Provider) + + // Priority: request model > chat settings > bot settings > user settings. + if modelID == "" && providerFilter == "" { + if value := strings.TrimSpace(cs.ModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(us.ChatModelID); value != "" { + modelID = value + } + } + + if modelID == "" { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings") + } + + if providerFilter == "" { + return r.fetchChatModel(ctx, modelID) + } + + candidates, err := r.listCandidates(ctx, providerFilter) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + for _, m := range candidates { + if m.ModelID == modelID { + prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return m, prov, nil + } + } + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter) +} + +func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { + model, err := r.modelsService.GetByModelID(ctx, modelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + if model.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + } + prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, prov, nil +} + +func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) { + var all []models.GetResponse + var err error + if providerFilter != "" { + all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) + } else { + all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) + } + if err != nil { + return nil, err + } + filtered := make([]models.GetResponse, 0, len(all)) + for _, m := range all { + if m.Type == models.ModelTypeChat { + filtered = append(filtered, m) + } + } + return filtered, nil +} + +// --- settings --- + +type resolvedUserSettings struct { + ChatModelID string +} + +func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (resolvedUserSettings, error) { + if r.settingsService == nil || strings.TrimSpace(userID) == "" { + return resolvedUserSettings{}, nil + } + s, err := r.settingsService.Get(ctx, userID) + if err != nil { + return resolvedUserSettings{}, err + } + return resolvedUserSettings{ + ChatModelID: strings.TrimSpace(s.ChatModelID), + }, nil +} + +func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) { + if r.settingsService == nil { + return settings.Settings{ + MaxContextLoadTime: settings.DefaultMaxContextLoadTime, + Language: settings.DefaultLanguage, + }, nil + } + return r.settingsService.GetBot(ctx, botID) +} + +// --- utility --- + +func normalizeClientType(clientType string) (string, error) { + switch strings.ToLower(strings.TrimSpace(clientType)) { + case "openai", "openai-compat": + return "openai", nil + case "anthropic": + return "anthropic", nil + case "google": + return "google", nil + default: + return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType) + } +} + +func sanitizeMessages(messages []ModelMessage) []ModelMessage { + cleaned := make([]ModelMessage, 0, len(messages)) + for _, msg := range messages { + if strings.TrimSpace(msg.Role) == "" { + continue + } + if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" { + continue + } + cleaned = append(cleaned, msg) + } + return cleaned +} + +func dedup(items []string) []string { + seen := make(map[string]struct{}, len(items)) + result := make([]string, 0, len(items)) + for _, s := range items { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + result = append(result, trimmed) + } + return result +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} + +func coalescePositiveInt(values ...int) int { + for _, v := range values { + if v > 0 { + return v + } + } + return defaultMaxContextMinutes +} + +func nonNilStrings(s []string) []string { + if s == nil { + return []string{} + } + return s +} + +func nonNilModelMessages(m []ModelMessage) []ModelMessage { + if m == nil { + return []ModelMessage{} + } + return m +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func truncateMemorySnippet(s string, n int) string { + trimmed := strings.TrimSpace(s) + if len(trimmed) <= n { + return trimmed + } + return strings.TrimSpace(trimmed[:n]) + "..." +} + +func parseResolverUUID(id string) (pgtype.UUID, error) { + if strings.TrimSpace(id) == "" { + return pgtype.UUID{}, fmt.Errorf("empty id") + } + return db.ParseUUID(id) +} diff --git a/internal/conversation/flow/resolver_memory_context_test.go b/internal/conversation/flow/resolver_memory_context_test.go new file mode 100644 index 00000000..1e326986 --- /dev/null +++ b/internal/conversation/flow/resolver_memory_context_test.go @@ -0,0 +1,55 @@ +package flow + +import ( + "context" + "log/slog" + "strings" + "testing" + + "github.com/memohai/memoh/internal/memory" +) + +func TestLoadMemoryContextMessage_NoMemoryService(t *testing.T) { + resolver := &Resolver{ + memoryService: nil, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + }) + if msg != nil { + t.Fatalf("expected nil message when memory service is nil") + } +} + +func TestLoadMemoryContextMessage_SearchFailureFallback(t *testing.T) { + resolver := &Resolver{ + memoryService: &memory.Service{}, + logger: slog.Default(), + } + msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{ + Query: "hello", + BotID: "bot-1", + ChatID: "chat-1", + UserID: "user-1", + }) + if msg != nil { + t.Fatalf("expected nil message when memory search cannot return results") + } +} + +func TestTruncateMemorySnippet(t *testing.T) { + longText := strings.Repeat("a", 20) + " " + got := truncateMemorySnippet(longText, 10) + if got != strings.Repeat("a", 10)+"..." { + t.Fatalf("unexpected truncated value: %q", got) + } + + shortText := " short " + got = truncateMemorySnippet(shortText, 10) + if got != "short" { + t.Fatalf("unexpected trimmed short value: %q", got) + } +} diff --git a/internal/conversation/flow/resolver_test.go b/internal/conversation/flow/resolver_test.go new file mode 100644 index 00000000..702d77bf --- /dev/null +++ b/internal/conversation/flow/resolver_test.go @@ -0,0 +1,158 @@ +package flow + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestPostTriggerSchedule_Endpoint(t *testing.T) { + var capturedPath string + var capturedBody []byte + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedAuth = r.Header.Get("Authorization") + capturedBody, _ = io.ReadAll(r.Body) + resp := gatewayResponse{ + Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + maxCalls := 5 + req := triggerScheduleRequest{ + Model: gatewayModelConfig{ + ModelID: "gpt-4", + ClientType: "openai", + APIKey: "sk-test", + BaseURL: "https://api.openai.com", + }, + ActiveContextTime: 1440, + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Identity: gatewayIdentity{ + BotID: "bot-123", + ContainerID: "mcp-bot-123", + ChannelIdentityID: "owner-user-1", + DisplayName: "Scheduler", + }, + Attachments: []any{}, + Schedule: gatewaySchedule{ + ID: "sched-1", + Name: "daily report", + Description: "generate daily report", + Pattern: "0 9 * * *", + MaxCalls: &maxCalls, + Command: "generate the daily report", + }, + } + + resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token") + if err != nil { + t.Fatalf("postTriggerSchedule returned error: %v", err) + } + + if capturedPath != "/chat/trigger-schedule" { + t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath) + } + if capturedAuth != "Bearer test-token" { + t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth) + } + if len(resp.Messages) != 1 { + t.Errorf("expected 1 message, got %d", len(resp.Messages)) + } + + var body map[string]any + if err := json.Unmarshal(capturedBody, &body); err != nil { + t.Fatalf("failed to parse captured body: %v", err) + } + schedule, ok := body["schedule"].(map[string]any) + if !ok { + t.Fatal("expected 'schedule' field in request body") + } + if schedule["id"] != "sched-1" { + t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"]) + } + if schedule["command"] != "generate the daily report" { + t.Errorf("expected schedule.command, got %v", schedule["command"]) + } + if _, hasQuery := body["query"]; hasQuery { + t.Error("trigger-schedule request should not contain 'query' field") + } +} + +func TestPostTriggerSchedule_NoAuth(t *testing.T) { + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + resp := gatewayResponse{Messages: []ModelMessage{}} + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + req := triggerScheduleRequest{ + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Attachments: []any{}, + Schedule: gatewaySchedule{ID: "s1", Command: "test"}, + } + + _, err := resolver.postTriggerSchedule(context.Background(), req, "") + if err != nil { + t.Fatalf("postTriggerSchedule returned error: %v", err) + } + if capturedAuth != "" { + t.Errorf("expected no Authorization header, got %s", capturedAuth) + } +} + +func TestPostTriggerSchedule_GatewayError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal error")) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + req := triggerScheduleRequest{ + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Attachments: []any{}, + Schedule: gatewaySchedule{ID: "s1", Command: "test"}, + } + + _, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok") + if err == nil { + t.Fatal("expected error for 500 response") + } +} diff --git a/internal/conversation/flow/schedule_gateway.go b/internal/conversation/flow/schedule_gateway.go new file mode 100644 index 00000000..4b3c2138 --- /dev/null +++ b/internal/conversation/flow/schedule_gateway.go @@ -0,0 +1,26 @@ +package flow + +import ( + "context" + "fmt" + + "github.com/memohai/memoh/internal/schedule" +) + +// ScheduleGateway adapts schedule trigger calls to the chat Resolver. +type ScheduleGateway struct { + resolver *Resolver +} + +// NewScheduleGateway creates a ScheduleGateway backed by the given Resolver. +func NewScheduleGateway(resolver *Resolver) *ScheduleGateway { + return &ScheduleGateway{resolver: resolver} +} + +// TriggerSchedule delegates a schedule trigger to the chat Resolver. +func (g *ScheduleGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + if g == nil || g.resolver == nil { + return fmt.Errorf("chat resolver not configured") + } + return g.resolver.TriggerSchedule(ctx, botID, payload, token) +} diff --git a/internal/conversation/flow/types.go b/internal/conversation/flow/types.go new file mode 100644 index 00000000..a4c19313 --- /dev/null +++ b/internal/conversation/flow/types.go @@ -0,0 +1,14 @@ +package flow + +import ( + "context" + + "github.com/memohai/memoh/internal/schedule" +) + +// Runner defines conversation execution behavior for sync, stream, and scheduled flows. +type Runner interface { + Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) + StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) + TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error +} diff --git a/internal/conversation/flow/types_alias.go b/internal/conversation/flow/types_alias.go new file mode 100644 index 00000000..1b345db6 --- /dev/null +++ b/internal/conversation/flow/types_alias.go @@ -0,0 +1,15 @@ +package flow + +import "github.com/memohai/memoh/internal/conversation" + +type ModelMessage = conversation.ModelMessage +type ContentPart = conversation.ContentPart +type ToolCall = conversation.ToolCall +type ToolCallFunction = conversation.ToolCallFunction +type AssistantOutput = conversation.AssistantOutput +type ChatRequest = conversation.ChatRequest +type ChatResponse = conversation.ChatResponse +type StreamChunk = conversation.StreamChunk +type Settings = conversation.Settings + +var NewTextContent = conversation.NewTextContent diff --git a/internal/conversation/interfaces.go b/internal/conversation/interfaces.go new file mode 100644 index 00000000..a11dc0a5 --- /dev/null +++ b/internal/conversation/interfaces.go @@ -0,0 +1,20 @@ +package conversation + +import "context" + +// Reader defines conversation lookup behavior. +type Reader interface { + Get(ctx context.Context, conversationID string) (Chat, error) +} + +// ParticipantChecker defines participant membership checks. +type ParticipantChecker interface { + IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) +} + +// Accessor defines read access checks for conversation-scoped operations. +type Accessor interface { + Reader + ParticipantChecker + GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error) +} diff --git a/internal/conversation/resolver.go b/internal/conversation/resolver.go new file mode 100644 index 00000000..a78f9e2c --- /dev/null +++ b/internal/conversation/resolver.go @@ -0,0 +1,1179 @@ +package conversation + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "sort" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/memory" + messagepkg "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/schedule" + "github.com/memohai/memoh/internal/settings" +) + +const ( + defaultMaxContextMinutes = 24 * 60 + memoryContextLimitPerScope = 4 + memoryContextMaxItems = 8 + memoryContextItemMaxChars = 220 + sharedMemoryNamespace = "bot" +) + +// SkillEntry represents a skill loaded from the container. +type SkillEntry struct { + Name string + Description string + Content string + Metadata map[string]any +} + +// SkillLoader loads skills for a given bot from its container. +type SkillLoader interface { + LoadSkills(ctx context.Context, botID string) ([]SkillEntry, error) +} + +// Resolver orchestrates chat with the agent gateway. +type Resolver struct { + modelsService *models.Service + queries *sqlc.Queries + memoryService *memory.Service + chatService *Service + messageService messagepkg.Service + settingsService *settings.Service + mcpService *mcp.ConnectionService + skillLoader SkillLoader + gatewayBaseURL string + timeout time.Duration + logger *slog.Logger + httpClient *http.Client + streamingClient *http.Client +} + +// NewResolver creates a Resolver that communicates with the agent gateway. +func NewResolver( + log *slog.Logger, + modelsService *models.Service, + queries *sqlc.Queries, + memoryService *memory.Service, + chatService *Service, + messageService messagepkg.Service, + settingsService *settings.Service, + mcpService *mcp.ConnectionService, + gatewayBaseURL string, + timeout time.Duration, +) *Resolver { + if strings.TrimSpace(gatewayBaseURL) == "" { + gatewayBaseURL = "http://127.0.0.1:8081" + } + gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/") + if timeout <= 0 { + timeout = 60 * time.Second + } + return &Resolver{ + modelsService: modelsService, + queries: queries, + memoryService: memoryService, + chatService: chatService, + messageService: messageService, + settingsService: settingsService, + mcpService: mcpService, + gatewayBaseURL: gatewayBaseURL, + timeout: timeout, + logger: log.With(slog.String("service", "chat_resolver")), + httpClient: &http.Client{Timeout: timeout}, + streamingClient: &http.Client{}, + } +} + +// SetSkillLoader sets the skill loader used to populate usable skills in gateway requests. +func (r *Resolver) SetSkillLoader(sl SkillLoader) { + r.skillLoader = sl +} + +// --- gateway payload --- + +type gatewayModelConfig struct { + ModelID string `json:"modelId"` + ClientType string `json:"clientType"` + Input []string `json:"input"` + APIKey string `json:"apiKey"` + BaseURL string `json:"baseUrl"` +} + +type gatewayIdentity struct { + BotID string `json:"botId"` + ContainerID string `json:"containerId"` + ChannelIdentityID string `json:"channelIdentityId"` + DisplayName string `json:"displayName"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + ReplyTarget string `json:"replyTarget,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` +} + +type gatewaySkill struct { + Name string `json:"name"` + Description string `json:"description"` + Content string `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type gatewayRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Query string `json:"query"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` +} + +type gatewayResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` +} + +// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule. +type gatewaySchedule struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Pattern string `json:"pattern"` + MaxCalls *int `json:"maxCalls,omitempty"` + Command string `json:"command"` +} + +// triggerScheduleRequest is the payload for POST /chat/trigger-schedule. +type triggerScheduleRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + UsableSkills []gatewaySkill `json:"usableSkills"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` + Schedule gatewaySchedule `json:"schedule"` +} + +// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) --- + +type resolvedContext struct { + payload gatewayRequest + model models.GetResponse + provider sqlc.LlmProvider +} + +func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContext, error) { + if strings.TrimSpace(req.Query) == "" { + return resolvedContext{}, fmt.Errorf("query is required") + } + if strings.TrimSpace(req.BotID) == "" { + return resolvedContext{}, fmt.Errorf("bot id is required") + } + if strings.TrimSpace(req.ChatID) == "" { + return resolvedContext{}, fmt.Errorf("chat id is required") + } + + skipHistory := req.MaxContextLoadTime < 0 + + botSettings, err := r.loadBotSettings(ctx, req.BotID) + if err != nil { + return resolvedContext{}, err + } + + // Check chat-level model override. + var chatSettings Settings + if r.chatService != nil { + chatSettings, err = r.chatService.GetSettings(ctx, req.ChatID) + if err != nil { + return resolvedContext{}, err + } + } + + userSettings, err := r.loadUserSettings(ctx, req.UserID) + if err != nil { + return resolvedContext{}, err + } + chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings) + if err != nil { + return resolvedContext{}, err + } + clientType, err := normalizeClientType(provider.ClientType) + if err != nil { + return resolvedContext{}, err + } + maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes) + + var messages []ModelMessage + if !skipHistory && r.chatService != nil { + messages, err = r.loadMessages(ctx, req.ChatID, maxCtx) + if err != nil { + return resolvedContext{}, err + } + } + if memoryMsg := r.loadMemoryContextMessage(ctx, req); memoryMsg != nil { + messages = append(messages, *memoryMsg) + } + messages = append(messages, req.Messages...) + messages = sanitizeMessages(messages) + skills := dedup(req.Skills) + containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) + + var usableSkills []gatewaySkill + if r.skillLoader != nil { + entries, err := r.skillLoader.LoadSkills(ctx, req.BotID) + if err != nil { + r.logger.Warn("failed to load usable skills", slog.String("bot_id", req.BotID), slog.Any("error", err)) + } else { + usableSkills = make([]gatewaySkill, 0, len(entries)) + for _, e := range entries { + usableSkills = append(usableSkills, gatewaySkill{ + Name: e.Name, + Description: e.Description, + Content: e.Content, + Metadata: e.Metadata, + }) + } + } + } + if usableSkills == nil { + usableSkills = []gatewaySkill{} + } + + payload := gatewayRequest{ + Model: gatewayModelConfig{ + ModelID: chatModel.ModelID, + ClientType: clientType, + Input: chatModel.Input, + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + }, + ActiveContextTime: maxCtx, + Channels: nonNilStrings(req.Channels), + CurrentChannel: req.CurrentChannel, + AllowedActions: req.AllowedActions, + Messages: nonNilModelMessages(messages), + Skills: nonNilStrings(skills), + UsableSkills: usableSkills, + Query: req.Query, + Identity: gatewayIdentity{ + BotID: req.BotID, + ContainerID: containerID, + ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID), + DisplayName: firstNonEmpty(req.DisplayName, "User"), + CurrentPlatform: req.CurrentChannel, + ReplyTarget: "", + SessionToken: req.ChatToken, + }, + Attachments: []any{}, + } + + return resolvedContext{payload: payload, model: chatModel, provider: provider}, nil +} + +// --- Chat --- + +// Chat sends a synchronous chat request to the agent gateway and stores the result. +func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { + rc, err := r.resolve(ctx, req) + if err != nil { + return ChatResponse{}, err + } + resp, err := r.postChat(ctx, rc.payload, req.Token) + if err != nil { + return ChatResponse{}, err + } + if err := r.storeRound(ctx, req, resp.Messages); err != nil { + return ChatResponse{}, err + } + return ChatResponse{ + Messages: resp.Messages, + Skills: resp.Skills, + Model: rc.model.ModelID, + Provider: rc.provider.ClientType, + }, nil +} + +// --- TriggerSchedule --- + +// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint. +func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + if strings.TrimSpace(botID) == "" { + return fmt.Errorf("bot id is required") + } + if strings.TrimSpace(payload.Command) == "" { + return fmt.Errorf("schedule command is required") + } + + chatID := payload.ChatID + if strings.TrimSpace(chatID) == "" { + chatID = "schedule-" + payload.ID + } + req := ChatRequest{ + BotID: botID, + ChatID: chatID, + Query: payload.Command, + UserID: payload.OwnerUserID, + Token: token, + } + rc, err := r.resolve(ctx, req) + if err != nil { + return err + } + + triggerReq := triggerScheduleRequest{ + Model: rc.payload.Model, + ActiveContextTime: rc.payload.ActiveContextTime, + Channels: rc.payload.Channels, + CurrentChannel: rc.payload.CurrentChannel, + AllowedActions: rc.payload.AllowedActions, + Messages: rc.payload.Messages, + Skills: rc.payload.Skills, + UsableSkills: rc.payload.UsableSkills, + Identity: gatewayIdentity{ + BotID: rc.payload.Identity.BotID, + ContainerID: rc.payload.Identity.ContainerID, + ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID), + DisplayName: "Scheduler", + }, + Attachments: rc.payload.Attachments, + Schedule: gatewaySchedule{ + ID: payload.ID, + Name: payload.Name, + Description: payload.Description, + Pattern: payload.Pattern, + MaxCalls: payload.MaxCalls, + Command: payload.Command, + }, + } + + resp, err := r.postTriggerSchedule(ctx, triggerReq, token) + if err != nil { + return err + } + return r.storeRound(ctx, req, resp.Messages) +} + +// --- StreamChat --- + +// StreamChat sends a streaming chat request to the agent gateway. +func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { + chunkCh := make(chan StreamChunk) + errCh := make(chan error, 1) + r.logger.Info("gateway stream start", + slog.String("bot_id", req.BotID), + slog.String("chat_id", req.ChatID), + ) + + go func() { + defer close(chunkCh) + defer close(errCh) + + streamReq := req + rc, err := r.resolve(ctx, streamReq) + if err != nil { + r.logger.Error("gateway stream resolve failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + if err := r.persistUserMessage(ctx, streamReq); err != nil { + r.logger.Error("gateway stream persist user message failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + return + } + streamReq.UserMessagePersisted = true + if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil { + r.logger.Error("gateway stream request failed", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.Any("error", err), + ) + errCh <- err + } + }() + return chunkCh, errCh +} + +// --- HTTP helpers --- + +func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/" + r.logger.Info("gateway request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +// postTriggerSchedule sends a trigger-schedule request to the agent gateway. +func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/trigger-schedule" + r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + +func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + url := r.gatewayBaseURL + "/chat/stream" + r.logger.Info("gateway stream request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + if strings.TrimSpace(req.Token) != "" { + httpReq.Header.Set("Authorization", req.Token) + } + + resp, err := r.streamingClient.Do(httpReq) + if err != nil { + r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err)) + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(resp.Body) + r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300))) + return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody))) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + + currentEvent := "" + stored := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + continue + } + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" || data == "[DONE]" { + continue + } + chunkCh <- StreamChunk([]byte(data)) + + if stored { + continue + } + if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil { + return storeErr + } else if handled { + stored = true + } + } + return scanner.Err() +} + +// tryStoreStream attempts to extract final messages from a stream event and persist them. +func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) { + // event: done + data: {messages: [...]} + if eventType == "done" { + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + + // data: {"type":"text_delta"|"agent_end"|"done", ...} + var envelope struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + } + if err := json.Unmarshal([]byte(data), &envelope); err == nil { + if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 { + return true, r.storeRound(ctx, req, envelope.Messages) + } + if envelope.Type == "done" && len(envelope.Data) > 0 { + var resp gatewayResponse + if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + } + } + + // fallback: data: {messages: [...]} + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, req, resp.Messages) + } + return false, nil +} + +// --- container resolution --- + +func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string { + if strings.TrimSpace(explicit) != "" { + return explicit + } + if r.queries != nil { + pgBotID, err := parseResolverUUID(botID) + if err == nil { + row, err := r.queries.GetContainerByBotID(ctx, pgBotID) + if err == nil && strings.TrimSpace(row.ContainerID) != "" { + return row.ContainerID + } + } + } + return "mcp-" + botID +} + +// --- message loading --- + +func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { + if r.messageService == nil { + return nil, nil + } + since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) + msgs, err := r.messageService.ListSince(ctx, chatID, since) + if err != nil { + return nil, err + } + var result []ModelMessage + for _, m := range msgs { + var mm ModelMessage + if err := json.Unmarshal(m.Content, &mm); err != nil { + // Fallback: treat content as text string. + mm = ModelMessage{Role: m.Role, Content: m.Content} + } else { + mm.Role = m.Role + } + result = append(result, mm) + } + return result, nil +} + +type memoryContextItem struct { + Namespace string + Item memory.MemoryItem +} + +func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage { + if r.memoryService == nil { + return nil + } + if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + return nil + } + + results := make([]memoryContextItem, 0, memoryContextLimitPerScope) + seen := map[string]struct{}{} + resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ + Query: req.Query, + BotID: req.BotID, + Limit: memoryContextLimitPerScope, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": req.BotID, + "botId": req.BotID, + }, + }) + if err != nil { + r.logger.Warn("memory search for context failed", + slog.String("namespace", sharedMemoryNamespace), + slog.Any("error", err), + ) + return nil + } + for _, item := range resp.Results { + key := strings.TrimSpace(item.ID) + if key == "" { + key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory) + } + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item}) + } + if len(results) == 0 { + return nil + } + + sort.Slice(results, func(i, j int) bool { + return results[i].Item.Score > results[j].Item.Score + }) + if len(results) > memoryContextMaxItems { + results = results[:memoryContextMaxItems] + } + + var sb strings.Builder + sb.WriteString("Relevant memory context (use when helpful):\n") + for _, entry := range results { + text := strings.TrimSpace(entry.Item.Memory) + if text == "" { + continue + } + sb.WriteString("- [") + sb.WriteString(entry.Namespace) + sb.WriteString("] ") + sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars)) + sb.WriteString("\n") + } + payload := strings.TrimSpace(sb.String()) + if payload == "" { + return nil + } + msg := ModelMessage{ + Role: "system", + Content: NewTextContent(payload), + } + return &msg +} + +// --- store helpers --- + +func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error { + if r.messageService == nil { + return nil + } + if strings.TrimSpace(req.BotID) == "" { + return fmt.Errorf("bot id is required for persistence") + } + text := strings.TrimSpace(req.Query) + if text == "" { + return nil + } + + message := ModelMessage{ + Role: "user", + Content: NewTextContent(text), + } + content, err := json.Marshal(message) + if err != nil { + return err + } + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + _, err = r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: senderChannelIdentityID, + SenderUserID: senderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: req.ExternalMessageID, + Role: "user", + Content: content, + Metadata: buildRouteMetadata(req), + }) + return err +} + +func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error { + // Add user query as the first message if not already present in the round. + // This ensures the user's prompt is persisted alongside the assistant's response. + fullRound := make([]ModelMessage, 0, len(messages)+1) + hasUserQuery := false + for _, m := range messages { + if m.Role == "user" && m.TextContent() == req.Query { + hasUserQuery = true + break + } + } + if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" { + fullRound = append(fullRound, ModelMessage{ + Role: "user", + Content: NewTextContent(req.Query), + }) + } + for _, m := range messages { + if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) { + // User message was already persisted before streaming; skip duplicate copy in round payload. + continue + } + fullRound = append(fullRound, m) + } + if len(fullRound) == 0 { + return nil + } + + r.storeMessages(ctx, req, fullRound) + r.storeMemory(ctx, req.BotID, fullRound) + return nil +} + +func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) { + if r.messageService == nil { + return + } + if strings.TrimSpace(req.BotID) == "" { + return + } + meta := buildRouteMetadata(req) + senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + for _, msg := range messages { + content, err := json.Marshal(msg) + if err != nil { + continue + } + messageSenderChannelIdentityID := "" + messageSenderUserID := "" + externalMessageID := "" + sourceReplyToMessageID := "" + if msg.Role == "user" { + messageSenderChannelIdentityID = senderChannelIdentityID + messageSenderUserID = senderUserID + externalMessageID = req.ExternalMessageID + } else if strings.TrimSpace(req.ExternalMessageID) != "" { + // Assistant/tool/system outputs are linked to the inbound source message for cross-channel reply threading. + sourceReplyToMessageID = req.ExternalMessageID + } + if _, err := r.messageService.Persist(ctx, messagepkg.PersistInput{ + BotID: req.BotID, + RouteID: req.RouteID, + SenderChannelIdentityID: messageSenderChannelIdentityID, + SenderUserID: messageSenderUserID, + Platform: req.CurrentChannel, + ExternalMessageID: externalMessageID, + SourceReplyToMessageID: sourceReplyToMessageID, + Role: msg.Role, + Content: content, + Metadata: meta, + }); err != nil { + r.logger.Warn("persist message failed", slog.Any("error", err)) + } + } +} + +func buildRouteMetadata(req ChatRequest) map[string]any { + if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" { + return nil + } + meta := map[string]any{} + if strings.TrimSpace(req.RouteID) != "" { + meta["route_id"] = req.RouteID + } + if strings.TrimSpace(req.CurrentChannel) != "" { + meta["platform"] = req.CurrentChannel + } + return meta +} + +func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) { + channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID) + userID := strings.TrimSpace(req.UserID) + + channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID) + userAsUserValid := r.isExistingUserID(ctx, userID) + userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID) + + senderChannelIdentityID := "" + switch { + case channelIdentityValid: + senderChannelIdentityID = channelIdentityID + case userAsChannelIdentityValid && !userAsUserValid: + // Some flows may carry channel_identity_id in req.UserID. + senderChannelIdentityID = userID + } + + senderUserID := "" + if userAsUserValid { + senderUserID = userID + } + if senderUserID == "" && senderChannelIdentityID != "" { + if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" { + senderUserID = linked + } + } + return senderChannelIdentityID, senderUserID +} + +func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetChannelIdentityByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool { + if r.queries == nil { + return false + } + pgID, err := parseResolverUUID(id) + if err != nil { + return false + } + _, err = r.queries.GetUserByID(ctx, pgID) + return err == nil +} + +func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string { + if r.queries == nil { + return "" + } + pgID, err := parseResolverUUID(channelIdentityID) + if err != nil { + return "" + } + row, err := r.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil || !row.UserID.Valid { + return "" + } + return row.UserID.String() +} + +func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) { + if r.memoryService == nil { + return + } + if strings.TrimSpace(botID) == "" { + return + } + memMsgs := make([]memory.Message, 0, len(messages)) + for _, msg := range messages { + text := strings.TrimSpace(msg.TextContent()) + if text == "" { + continue + } + role := msg.Role + if strings.TrimSpace(role) == "" { + role = "assistant" + } + memMsgs = append(memMsgs, memory.Message{Role: role, Content: text}) + } + if len(memMsgs) == 0 { + return + } + r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID) +} + +func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + "botId": botID, + } + if _, err := r.memoryService.Add(ctx, memory.AddRequest{ + Messages: msgs, + BotID: botID, + Filters: filters, + }); err != nil { + r.logger.Warn("store memory failed", + slog.String("namespace", namespace), + slog.String("scope_id", scopeID), + slog.Any("error", err), + ) + } +} + +// --- model selection --- + +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) { + if r.modelsService == nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + } + modelID := strings.TrimSpace(req.Model) + providerFilter := strings.TrimSpace(req.Provider) + + // Priority: request model > chat settings > bot settings > user settings. + if modelID == "" && providerFilter == "" { + if value := strings.TrimSpace(cs.ModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(us.ChatModelID); value != "" { + modelID = value + } + } + + if modelID == "" { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings") + } + + if providerFilter == "" { + return r.fetchChatModel(ctx, modelID) + } + + candidates, err := r.listCandidates(ctx, providerFilter) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + for _, m := range candidates { + if m.ModelID == modelID { + prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return m, prov, nil + } + } + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter) +} + +func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { + model, err := r.modelsService.GetByModelID(ctx, modelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + if model.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + } + prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, prov, nil +} + +func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) { + var all []models.GetResponse + var err error + if providerFilter != "" { + all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) + } else { + all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) + } + if err != nil { + return nil, err + } + filtered := make([]models.GetResponse, 0, len(all)) + for _, m := range all { + if m.Type == models.ModelTypeChat { + filtered = append(filtered, m) + } + } + return filtered, nil +} + +// --- settings --- + +type resolvedUserSettings struct { + ChatModelID string +} + +func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (resolvedUserSettings, error) { + if r.settingsService == nil || strings.TrimSpace(userID) == "" { + return resolvedUserSettings{}, nil + } + s, err := r.settingsService.Get(ctx, userID) + if err != nil { + return resolvedUserSettings{}, err + } + return resolvedUserSettings{ + ChatModelID: strings.TrimSpace(s.ChatModelID), + }, nil +} + +func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) { + if r.settingsService == nil { + return settings.Settings{ + MaxContextLoadTime: settings.DefaultMaxContextLoadTime, + Language: settings.DefaultLanguage, + }, nil + } + return r.settingsService.GetBot(ctx, botID) +} + +// --- utility --- + +func normalizeClientType(clientType string) (string, error) { + switch strings.ToLower(strings.TrimSpace(clientType)) { + case "openai", "openai-compat": + return "openai", nil + case "anthropic": + return "anthropic", nil + case "google": + return "google", nil + default: + return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType) + } +} + +func sanitizeMessages(messages []ModelMessage) []ModelMessage { + cleaned := make([]ModelMessage, 0, len(messages)) + for _, msg := range messages { + if strings.TrimSpace(msg.Role) == "" { + continue + } + if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" { + continue + } + cleaned = append(cleaned, msg) + } + return cleaned +} + +func dedup(items []string) []string { + seen := make(map[string]struct{}, len(items)) + result := make([]string, 0, len(items)) + for _, s := range items { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + result = append(result, trimmed) + } + return result +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} + +func coalescePositiveInt(values ...int) int { + for _, v := range values { + if v > 0 { + return v + } + } + return defaultMaxContextMinutes +} + +func nonNilStrings(s []string) []string { + if s == nil { + return []string{} + } + return s +} + +func nonNilModelMessages(m []ModelMessage) []ModelMessage { + if m == nil { + return []ModelMessage{} + } + return m +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func truncateMemorySnippet(s string, n int) string { + trimmed := strings.TrimSpace(s) + if len(trimmed) <= n { + return trimmed + } + return strings.TrimSpace(trimmed[:n]) + "..." +} + +func parseResolverUUID(id string) (pgtype.UUID, error) { + if strings.TrimSpace(id) == "" { + return pgtype.UUID{}, fmt.Errorf("empty id") + } + return db.ParseUUID(id) +} diff --git a/internal/conversation/service_domain.go b/internal/conversation/service_domain.go new file mode 100644 index 00000000..8e0f0099 --- /dev/null +++ b/internal/conversation/service_domain.go @@ -0,0 +1,483 @@ +package conversation + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + dbpkg "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +var ( + ErrChatNotFound = errors.New("chat not found") + ErrNotParticipant = errors.New("not a participant") + ErrPermissionDenied = errors.New("permission denied") +) + +// Service manages conversation lifecycle, participants, and settings. +type Service struct { + queries *sqlc.Queries + logger *slog.Logger +} + +// NewService creates a conversation service. +func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + queries: queries, + logger: log.With(slog.String("service", "conversation")), + } +} + +// Create creates a new conversation and adds the creator as owner. +func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Chat, error) { + kind := strings.TrimSpace(req.Kind) + if kind == "" { + kind = KindDirect + } + if kind != KindDirect && kind != KindGroup && kind != KindThread { + return Chat{}, fmt.Errorf("invalid conversation kind: %s", kind) + } + + pgBotID, err := parseUUID(botID) + if err != nil { + return Chat{}, fmt.Errorf("invalid bot id: %w", err) + } + pgChannelIdentityID := pgtype.UUID{} + if strings.TrimSpace(channelIdentityID) != "" { + pgChannelIdentityID, err = parseUUID(channelIdentityID) + if err != nil { + return Chat{}, fmt.Errorf("invalid channel identity id: %w", err) + } + } + + var pgParent pgtype.UUID + if kind == KindThread && strings.TrimSpace(req.ParentChatID) != "" { + pgParent, err = parseUUID(req.ParentChatID) + if err != nil { + return Chat{}, fmt.Errorf("invalid parent conversation id: %w", err) + } + } + + metadata, err := json.Marshal(nonNilMap(req.Metadata)) + if err != nil { + return Chat{}, fmt.Errorf("marshal conversation metadata: %w", err) + } + + row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{ + BotID: pgBotID, + Kind: kind, + ParentChatID: pgParent, + Title: strings.TrimSpace(req.Title), + CreatedByUserID: pgChannelIdentityID, + Metadata: metadata, + }) + if err != nil { + return Chat{}, fmt.Errorf("create conversation: %w", err) + } + + // Add creator as owner when the channel identity is available. + if pgChannelIdentityID.Valid { + if _, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: row.ID, + UserID: pgChannelIdentityID, + Role: RoleOwner, + }); err != nil { + return Chat{}, fmt.Errorf("add owner participant: %w", err) + } + } + + // For threads, copy parent participants. + if kind == KindThread && pgParent.Valid { + if err := s.queries.CopyParticipantsToChat(ctx, sqlc.CopyParticipantsToChatParams{ + ChatID: pgParent, + ChatID2: row.ID, + }); err != nil && s.logger != nil { + s.logger.Warn("copy parent participants failed", slog.Any("error", err)) + } + } + + return toChatFromCreate(row), nil +} + +// Get returns a conversation by ID. +func (s *Service) Get(ctx context.Context, conversationID string) (Chat, error) { + pgID, err := parseUUID(conversationID) + if err != nil { + return Chat{}, ErrChatNotFound + } + row, err := s.queries.GetChatByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Chat{}, ErrChatNotFound + } + return Chat{}, err + } + return toChatFromGet(row), nil +} + +// GetReadAccess resolves whether a user can read a conversation. +func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error) { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + row, err := s.queries.GetChatReadAccessByUser(ctx, sqlc.GetChatReadAccessByUserParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ChatReadAccess{}, ErrPermissionDenied + } + return ChatReadAccess{}, err + } + return ChatReadAccess{ + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + }, nil +} + +// ListByBotAndChannelIdentity returns all visible conversations for a bot and channel identity. +func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ChatListItem, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return nil, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListVisibleChatsByBotAndUser(ctx, sqlc.ListVisibleChatsByBotAndUserParams{ + BotID: pgBotID, + UserID: pgChannelIdentityID, + }) + if err != nil { + return nil, err + } + conversations := make([]ChatListItem, 0, len(rows)) + for _, row := range rows { + conversations = append(conversations, toChatListItem(row)) + } + return conversations, nil +} + +// ListThreads returns threads for a parent conversation. +func (s *Service) ListThreads(ctx context.Context, parentConversationID string) ([]Chat, error) { + pgID, err := parseUUID(parentConversationID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListThreadsByParent(ctx, pgID) + if err != nil { + return nil, err + } + conversations := make([]Chat, 0, len(rows)) + for _, row := range rows { + conversations = append(conversations, toChatFromThread(row)) + } + return conversations, nil +} + +// Delete deletes a conversation and linked records. +func (s *Service) Delete(ctx context.Context, conversationID string) error { + pgID, err := parseUUID(conversationID) + if err != nil { + return ErrChatNotFound + } + return s.queries.DeleteChat(ctx, pgID) +} + +// AddParticipant adds a channel identity to a conversation. +func (s *Service) AddParticipant(ctx context.Context, conversationID, channelIdentityID, role string) (Participant, error) { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return Participant{}, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, err + } + if role == "" { + role = RoleMember + } + row, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + Role: role, + }) + if err != nil { + return Participant{}, err + } + return toParticipantFromAdd(row), nil +} + +// GetParticipant returns a conversation participant. +func (s *Service) GetParticipant(ctx context.Context, conversationID, channelIdentityID string) (Participant, error) { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return Participant{}, ErrNotParticipant + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, ErrNotParticipant + } + row, err := s.queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Participant{}, ErrNotParticipant + } + return Participant{}, err + } + return toParticipantFromGet(row), nil +} + +// IsParticipant checks whether a channel identity is a participant. +func (s *Service) IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) { + _, err := s.GetParticipant(ctx, conversationID, channelIdentityID) + if errors.Is(err, ErrNotParticipant) { + return false, nil + } + return err == nil, err +} + +// ListParticipants returns all participants for a conversation. +func (s *Service) ListParticipants(ctx context.Context, conversationID string) ([]Participant, error) { + pgID, err := parseUUID(conversationID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatParticipants(ctx, pgID) + if err != nil { + return nil, err + } + participants := make([]Participant, 0, len(rows)) + for _, row := range rows { + participants = append(participants, toParticipantFromList(row)) + } + return participants, nil +} + +// RemoveParticipant removes a participant from a conversation. +func (s *Service) RemoveParticipant(ctx context.Context, conversationID, channelIdentityID string) error { + pgConversationID, err := parseUUID(conversationID) + if err != nil { + return err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return err + } + return s.queries.RemoveChatParticipant(ctx, sqlc.RemoveChatParticipantParams{ + ChatID: pgConversationID, + UserID: pgChannelIdentityID, + }) +} + +// GetSettings returns conversation settings and falls back to defaults when missing. +func (s *Service) GetSettings(ctx context.Context, conversationID string) (Settings, error) { + pgID, err := parseUUID(conversationID) + if err != nil { + return defaultSettings(conversationID), nil + } + row, err := s.queries.GetChatSettings(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return defaultSettings(conversationID), nil + } + return Settings{}, err + } + return toSettingsFromRead(row), nil +} + +// UpdateSettings updates conversation settings. +func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req UpdateSettingsRequest) (Settings, error) { + current, err := s.GetSettings(ctx, conversationID) + if err != nil { + return Settings{}, err + } + if req.ModelID != nil { + current.ModelID = *req.ModelID + } + + pgID, err := parseUUID(conversationID) + if err != nil { + return Settings{}, err + } + row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ + ID: pgID, + ModelID: toPgText(current.ModelID), + }) + if err != nil { + return Settings{}, err + } + return toSettingsFromUpsert(row), nil +} + +func toChatFromCreate(row sqlc.CreateChatRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromGet(row sqlc.GetChatByIDRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat { + return toChatFields( + row.ID, + row.BotID, + row.Kind, + row.ParentChatID, + row.Title, + row.CreatedByUserID, + row.Metadata, + row.CreatedAt, + row.UpdatedAt, + ) +} + +func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Chat { + return Chat{ + ID: id.String(), + BotID: botID.String(), + Kind: kind, + ParentChatID: parentChatID.String(), + Title: dbpkg.TextToString(title), + CreatedBy: createdBy.String(), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + UpdatedAt: updatedAt.Time, + } +} + +func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem { + return ChatListItem{ + ID: row.ID.String(), + BotID: row.BotID.String(), + Kind: row.Kind, + ParentChatID: row.ParentChatID.String(), + Title: dbpkg.TextToString(row.Title), + CreatedBy: row.CreatedByUserID.String(), + Metadata: parseJSONMap(row.Metadata), + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + } +} + +func toParticipantFromAdd(row sqlc.AddChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromGet(row sqlc.GetChatParticipantRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFromList(row sqlc.ListChatParticipantsRow) Participant { + return toParticipantFields(row.ChatID, row.UserID, row.Role, row.JoinedAt) +} + +func toParticipantFields(conversationID, userID pgtype.UUID, role string, joinedAt pgtype.Timestamptz) Participant { + return Participant{ + ChatID: conversationID.String(), + UserID: userID.String(), + Role: role, + JoinedAt: joinedAt.Time, + } +} + +func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings { + return Settings{ + ChatID: row.ChatID.String(), + ModelID: dbpkg.TextToString(row.ModelID), + } +} + +func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings { + return Settings{ + ChatID: row.ChatID.String(), + ModelID: dbpkg.TextToString(row.ModelID), + } +} + +func defaultSettings(conversationID string) Settings { + return Settings{ + ChatID: conversationID, + } +} + +func parseUUID(id string) (pgtype.UUID, error) { + return dbpkg.ParseUUID(id) +} + +func toPgText(s string) pgtype.Text { + s = strings.TrimSpace(s) + if s == "" { + return pgtype.Text{} + } + return pgtype.Text{String: s, Valid: true} +} + +func pgTimePtr(ts pgtype.Timestamptz) *time.Time { + if !ts.Valid { + return nil + } + value := ts.Time + return &value +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} diff --git a/internal/conversation/service_presence_integration_test.go b/internal/conversation/service_presence_integration_test.go new file mode 100644 index 00000000..fef12f72 --- /dev/null +++ b/internal/conversation/service_presence_integration_test.go @@ -0,0 +1,242 @@ +package conversation_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channel/identities" + conversation "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/message" +) + +type chatPresenceFixture struct { + chatSvc *conversation.Service + messageSvc message.Service + channelIdentitySvc *identities.Service + queries *sqlc.Queries + cleanup func() +} + +func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + queries := sqlc.New(pool) + + return chatPresenceFixture{ + chatSvc: conversation.NewService(logger, queries), + messageSvc: message.NewService(logger, queries), + channelIdentitySvc: identities.NewService(logger, queries), + queries: queries, + cleanup: func() { pool.Close() }, + } +} + +func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, err := json.Marshal(map[string]any{"source": "chat-presence-integration-test"}) + if err != nil { + return "", err + } + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "presence-test-bot", Valid: true}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return row.ID.String(), nil +} + +func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, string, string, string) { + t.Helper() + + fixture := setupChatPresenceIntegrationTest(t) + ctx := context.Background() + + ownerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + fixture.cleanup() + t.Fatalf("create owner user failed: %v", err) + } + observerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + fixture.cleanup() + t.Fatalf("create observer user failed: %v", err) + } + botID, err := createBotForChatPresence(ctx, fixture.queries, ownerUserID) + if err != nil { + fixture.cleanup() + t.Fatalf("create bot failed: %v", err) + } + + createdChat, err := fixture.chatSvc.Create(ctx, botID, ownerUserID, conversation.CreateRequest{ + Kind: conversation.KindGroup, + Title: "presence-observed", + }) + if err != nil { + fixture.cleanup() + t.Fatalf("create chat failed: %v", err) + } + + observedChannelIdentity, err := fixture.channelIdentitySvc.ResolveByChannelIdentity( + ctx, + "feishu", + fmt.Sprintf("presence-channelIdentity-%d", time.Now().UnixNano()), + "presence-observer", + ) + if err != nil { + fixture.cleanup() + t.Fatalf("resolve channelIdentity failed: %v", err) + } + + _, err = fixture.messageSvc.Persist(ctx, message.PersistInput{ + BotID: botID, + SenderChannelIdentityID: observedChannelIdentity.ID, + Platform: "feishu", + ExternalMessageID: fmt.Sprintf("ext-msg-%d", time.Now().UnixNano()), + Role: "user", + Content: []byte(`{"content":"hello from observed channelIdentity"}`), + }) + if err != nil { + fixture.cleanup() + t.Fatalf("persist message failed: %v", err) + } + + return fixture, botID, createdChat.ID, observerUserID, observedChannelIdentity.ID +} + +func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + beforeBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats before bind failed: %v", err) + } + if len(beforeBind) != 0 { + t.Fatalf("expected no visible chats before bind, got %d", len(beforeBind)) + } + + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + afterBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats after bind failed: %v", err) + } + if len(afterBind) == 0 { + t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind)) + } + + var target *conversation.ChatListItem + for i := range afterBind { + if afterBind[i].ID == chatID { + target = &afterBind[i] + break + } + } + if target == nil { + t.Fatalf("expected chat %s in visible list after bind", chatID) + } + if target.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected access_mode=%s, got %s", conversation.AccessModeChannelIdentityObserved, target.AccessMode) + } + if target.ParticipantRole != "" { + t.Fatalf("expected empty participant_role for observed chat, got %s", target.ParticipantRole) + } + if target.LastObservedAt == nil { + t.Fatal("expected last_observed_at to be set for observed chat") + } +} + +func TestObservedAccessReadableButNotParticipant(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + access, err := fixture.chatSvc.GetReadAccess(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("get read access failed: %v", err) + } + if access.AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatalf("expected read access %s, got %s", conversation.AccessModeChannelIdentityObserved, access.AccessMode) + } + + messages, err := fixture.messageSvc.List(ctx, chatID) + if err != nil { + t.Fatalf("list messages failed: %v", err) + } + if len(messages) == 0 { + t.Fatal("expected observed user can read chat messages") + } + + _, err = fixture.chatSvc.GetParticipant(ctx, chatID, observerUserID) + if !errors.Is(err, conversation.ErrNotParticipant) { + t.Fatalf("expected ErrNotParticipant for observed user, got %v", err) + } + ok, err := fixture.chatSvc.IsParticipant(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("check participant failed: %v", err) + } + if ok { + t.Fatal("expected observed user to remain non-participant") + } + + visibleChats, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list visible chats failed: %v", err) + } + if len(visibleChats) == 0 || visibleChats[0].AccessMode != conversation.AccessModeChannelIdentityObserved { + t.Fatal("expected observed list entry with channel_identity_observed access mode") + } +} diff --git a/internal/conversation/types.go b/internal/conversation/types.go new file mode 100644 index 00000000..5a10738b --- /dev/null +++ b/internal/conversation/types.go @@ -0,0 +1,235 @@ +// Package conversation defines conversation domain types and rules. +package conversation + +import ( + "encoding/json" + "strings" + "time" +) + +// Chat kind constants. +const ( + KindDirect = "direct" + KindGroup = "group" + KindThread = "thread" +) + +// Participant role constants. +const ( + RoleOwner = "owner" + RoleAdmin = "admin" + RoleMember = "member" +) + +// Chat list access mode constants. +const ( + AccessModeParticipant = "participant" + AccessModeChannelIdentityObserved = "channel_identity_observed" +) + +// Conversation is the first-class conversation container. +type Conversation struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ConversationListItem is a conversation entry with access context for list rendering. +type ConversationListItem struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role,omitempty"` + LastObservedAt *time.Time `json:"last_observed_at,omitempty"` +} + +// ConversationReadAccess is the resolved access context for reading conversation content. +type ConversationReadAccess struct { + AccessMode string + ParticipantRole string + LastObservedAt *time.Time +} + +// Backward-compatible aliases while call sites migrate. +type Chat = Conversation +type ChatListItem = ConversationListItem +type ChatReadAccess = ConversationReadAccess + +// Participant represents a chat member. +type Participant struct { + ChatID string `json:"chat_id"` + UserID string `json:"user_id"` + Role string `json:"role"` + JoinedAt time.Time `json:"joined_at"` +} + +// Settings holds per-chat configuration. +type Settings struct { + ChatID string `json:"chat_id"` + ModelID string `json:"model_id,omitempty"` +} + +// CreateRequest is the input for creating a bot-scoped conversation container. +type CreateRequest struct { + Kind string `json:"kind"` + Title string `json:"title,omitempty"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// UpdateSettingsRequest is the input for updating chat settings. +type UpdateSettingsRequest struct { + ModelID *string `json:"model_id,omitempty"` +} + +// ModelMessage is the canonical message format exchanged with the agent gateway. +// Aligned with Vercel AI SDK ModelMessage structure. +type ModelMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +// TextContent extracts the plain text from the message content. +// If content is a string, it returns it directly. +// If content is an array of parts, it joins all text-type parts. +func (m ModelMessage) TextContent() string { + if len(m.Content) == 0 { + return "" + } + var s string + if err := json.Unmarshal(m.Content, &s); err == nil { + return s + } + var parts []ContentPart + if err := json.Unmarshal(m.Content, &parts); err == nil { + texts := make([]string, 0, len(parts)) + for _, p := range parts { + if strings.TrimSpace(p.Text) != "" { + texts = append(texts, p.Text) + } + } + return strings.Join(texts, "\n") + } + return "" +} + +// ContentParts parses the content as an array of ContentPart. +// Returns nil if the content is a plain string or not parseable. +func (m ModelMessage) ContentParts() []ContentPart { + if len(m.Content) == 0 { + return nil + } + var parts []ContentPart + if err := json.Unmarshal(m.Content, &parts); err != nil { + return nil + } + return parts +} + +// HasContent reports whether the message carries non-empty content or tool calls. +func (m ModelMessage) HasContent() bool { + if strings.TrimSpace(m.TextContent()) != "" { + return true + } + if len(m.ContentParts()) > 0 { + return true + } + return len(m.ToolCalls) > 0 +} + +// NewTextContent creates a json.RawMessage from a plain string. +func NewTextContent(text string) json.RawMessage { + data, _ := json.Marshal(text) + return data +} + +// ContentPart represents one element of a multi-part message content. +type ContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []string `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// HasValue reports whether the content part carries a meaningful value. +func (p ContentPart) HasValue() bool { + return strings.TrimSpace(p.Text) != "" || + strings.TrimSpace(p.URL) != "" || + strings.TrimSpace(p.Emoji) != "" +} + +// ToolCall represents a function/tool invocation in an assistant message. +type ToolCall struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Function ToolCallFunction `json:"function"` +} + +// ToolCallFunction holds the name and serialized arguments of a tool call. +type ToolCallFunction struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ChatRequest is the input for Chat and StreamChat. +type ChatRequest struct { + BotID string `json:"-"` + ChatID string `json:"-"` + Token string `json:"-"` + UserID string `json:"-"` + SourceChannelIdentityID string `json:"-"` + ContainerID string `json:"-"` + DisplayName string `json:"-"` + RouteID string `json:"-"` + ChatToken string `json:"-"` + ExternalMessageID string `json:"-"` + UserMessagePersisted bool `json:"-"` + + Query string `json:"query"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` + MaxContextLoadTime int `json:"max_context_load_time,omitempty"` + Language string `json:"language,omitempty"` + Channels []string `json:"channels,omitempty"` + CurrentChannel string `json:"current_channel,omitempty"` + Messages []ModelMessage `json:"messages,omitempty"` + Skills []string `json:"skills,omitempty"` + AllowedActions []string `json:"allowed_actions,omitempty"` +} + +// ChatResponse is the output of a non-streaming chat call. +type ChatResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills,omitempty"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` +} + +// StreamChunk is a raw JSON chunk from the streaming response. +type StreamChunk = json.RawMessage + +// AssistantOutput holds extracted assistant content for downstream consumers. +type AssistantOutput struct { + Content string + Parts []ContentPart +} diff --git a/internal/db/sqlc/bind.sql.go b/internal/db/sqlc/bind.sql.go new file mode 100644 index 00000000..39347a54 --- /dev/null +++ b/internal/db/sqlc/bind.sql.go @@ -0,0 +1,120 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: bind.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createBindCode = `-- name: CreateBindCode :one +INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, channel_type, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at +` + +type CreateBindCodeParams struct { + Token string `json:"token"` + IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` + ChannelType pgtype.Text `json:"channel_type"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` +} + +func (q *Queries) CreateBindCode(ctx context.Context, arg CreateBindCodeParams) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, createBindCode, + arg.Token, + arg.IssuedByUserID, + arg.ChannelType, + arg.ExpiresAt, + ) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.ChannelType, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} + +const getBindCode = `-- name: GetBindCode :one +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1 +` + +func (q *Queries) GetBindCode(ctx context.Context, token string) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, getBindCode, token) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.ChannelType, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} + +const getBindCodeForUpdate = `-- name: GetBindCodeForUpdate :one +SELECT id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1 +FOR UPDATE +` + +func (q *Queries) GetBindCodeForUpdate(ctx context.Context, token string) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, getBindCodeForUpdate, token) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.ChannelType, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} + +const markBindCodeUsed = `-- name: MarkBindCodeUsed :one +UPDATE channel_identity_bind_codes +SET used_at = now(), used_by_channel_identity_id = $2 +WHERE id = $1 + AND used_at IS NULL +RETURNING id, token, issued_by_user_id, channel_type, expires_at, used_at, used_by_channel_identity_id, created_at +` + +type MarkBindCodeUsedParams struct { + ID pgtype.UUID `json:"id"` + UsedByChannelIdentityID pgtype.UUID `json:"used_by_channel_identity_id"` +} + +func (q *Queries) MarkBindCodeUsed(ctx context.Context, arg MarkBindCodeUsedParams) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, markBindCodeUsed, arg.ID, arg.UsedByChannelIdentityID) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.ChannelType, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} diff --git a/internal/db/sqlc/bots.sql.go b/internal/db/sqlc/bots.sql.go index 69f4b4ed..7d175c97 100644 --- a/internal/db/sqlc/bots.sql.go +++ b/internal/db/sqlc/bots.sql.go @@ -12,9 +12,9 @@ import ( ) const createBot = `-- name: CreateBot :one -INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata, status) +VALUES ($1, $2, $3, $4, $5, $6, $7) +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type CreateBotParams struct { @@ -24,6 +24,7 @@ type CreateBotParams struct { AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` Metadata []byte `json:"metadata"` + Status string `json:"status"` } func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (Bot, error) { @@ -34,6 +35,7 @@ func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (Bot, erro arg.AvatarUrl, arg.IsActive, arg.Metadata, + arg.Status, ) var i Bot err := row.Scan( @@ -43,6 +45,13 @@ func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (Bot, erro &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -74,7 +83,7 @@ func (q *Queries) DeleteBotMember(ctx context.Context, arg DeleteBotMemberParams } const getBotByID = `-- name: GetBotByID :one -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE id = $1 ` @@ -89,6 +98,13 @@ func (q *Queries) GetBotByID(ctx context.Context, id pgtype.UUID) (Bot, error) { &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -153,7 +169,7 @@ func (q *Queries) ListBotMembers(ctx context.Context, botID pgtype.UUID) ([]BotM } const listBotsByMember = `-- name: ListBotsByMember :many -SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.metadata, b.created_at, b.updated_at +SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.status, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at FROM bots b JOIN bot_members m ON m.bot_id = b.id WHERE m.user_id = $1 @@ -176,6 +192,13 @@ func (q *Queries) ListBotsByMember(ctx context.Context, userID pgtype.UUID) ([]B &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -191,7 +214,7 @@ func (q *Queries) ListBotsByMember(ctx context.Context, userID pgtype.UUID) ([]B } const listBotsByOwner = `-- name: ListBotsByOwner :many -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE owner_user_id = $1 ORDER BY created_at DESC @@ -213,6 +236,13 @@ func (q *Queries) ListBotsByOwner(ctx context.Context, ownerUserID pgtype.UUID) &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -232,7 +262,7 @@ UPDATE bots SET owner_user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type UpdateBotOwnerParams struct { @@ -250,6 +280,13 @@ func (q *Queries) UpdateBotOwner(ctx context.Context, arg UpdateBotOwnerParams) &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -265,7 +302,7 @@ SET display_name = $2, metadata = $5, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type UpdateBotProfileParams struct { @@ -292,6 +329,13 @@ func (q *Queries) UpdateBotProfile(ctx context.Context, arg UpdateBotProfilePara &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.Status, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -299,6 +343,23 @@ func (q *Queries) UpdateBotProfile(ctx context.Context, arg UpdateBotProfilePara return i, err } +const updateBotStatus = `-- name: UpdateBotStatus :exec +UPDATE bots +SET status = $2, + updated_at = now() +WHERE id = $1 +` + +type UpdateBotStatusParams struct { + ID pgtype.UUID `json:"id"` + Status string `json:"status"` +} + +func (q *Queries) UpdateBotStatus(ctx context.Context, arg UpdateBotStatusParams) error { + _, err := q.db.Exec(ctx, updateBotStatus, arg.ID, arg.Status) + return err +} + const upsertBotMember = `-- name: UpsertBotMember :one INSERT INTO bot_members (bot_id, user_id, role) VALUES ($1, $2, $3) diff --git a/internal/db/sqlc/channel_identities.sql.go b/internal/db/sqlc/channel_identities.sql.go new file mode 100644 index 00000000..b7ad5b17 --- /dev/null +++ b/internal/db/sqlc/channel_identities.sql.go @@ -0,0 +1,249 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: channel_identities.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const clearChannelIdentityLinkedUser = `-- name: ClearChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = NULL, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +` + +func (q *Queries) ClearChannelIdentityLinkedUser(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, clearChannelIdentityLinkedUser, id) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const createChannelIdentity = `-- name: CreateChannelIdentity :one +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +` + +type CreateChannelIdentityParams struct { + UserID pgtype.UUID `json:"user_id"` + ChannelType string `json:"channel_type"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName pgtype.Text `json:"display_name"` + Metadata []byte `json:"metadata"` +} + +func (q *Queries) CreateChannelIdentity(ctx context.Context, arg CreateChannelIdentityParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, createChannelIdentity, + arg.UserID, + arg.ChannelType, + arg.ChannelSubjectID, + arg.DisplayName, + arg.Metadata, + ) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChannelIdentityByChannelSubject = `-- name: GetChannelIdentityByChannelSubject :one +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE channel_type = $1 AND channel_subject_id = $2 +` + +type GetChannelIdentityByChannelSubjectParams struct { + ChannelType string `json:"channel_type"` + ChannelSubjectID string `json:"channel_subject_id"` +} + +func (q *Queries) GetChannelIdentityByChannelSubject(ctx context.Context, arg GetChannelIdentityByChannelSubjectParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, getChannelIdentityByChannelSubject, arg.ChannelType, arg.ChannelSubjectID) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChannelIdentityByID = `-- name: GetChannelIdentityByID :one +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1 +` + +func (q *Queries) GetChannelIdentityByID(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, getChannelIdentityByID, id) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChannelIdentityByIDForUpdate = `-- name: GetChannelIdentityByIDForUpdate :one +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1 +FOR UPDATE +` + +func (q *Queries) GetChannelIdentityByIDForUpdate(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, getChannelIdentityByIDForUpdate, id) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listChannelIdentitiesByUserID = `-- name: ListChannelIdentitiesByUserID :many +SELECT id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE user_id = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListChannelIdentitiesByUserID(ctx context.Context, userID pgtype.UUID) ([]ChannelIdentity, error) { + rows, err := q.db.Query(ctx, listChannelIdentitiesByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChannelIdentity + for rows.Next() { + var i ChannelIdentity + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setChannelIdentityLinkedUser = `-- name: SetChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = $2, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +` + +type SetChannelIdentityLinkedUserParams struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) SetChannelIdentityLinkedUser(ctx context.Context, arg SetChannelIdentityLinkedUserParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, setChannelIdentityLinkedUser, arg.ID, arg.UserID) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertChannelIdentityByChannelSubject = `-- name: UpsertChannelIdentityByChannelSubject :one +INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +ON CONFLICT (channel_type, channel_subject_id) +DO UPDATE SET + display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), channel_identities.display_name), + metadata = EXCLUDED.metadata, + user_id = COALESCE(channel_identities.user_id, EXCLUDED.user_id), + updated_at = now() +RETURNING id, user_id, channel_type, channel_subject_id, display_name, metadata, created_at, updated_at +` + +type UpsertChannelIdentityByChannelSubjectParams struct { + UserID pgtype.UUID `json:"user_id"` + ChannelType string `json:"channel_type"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName pgtype.Text `json:"display_name"` + Metadata []byte `json:"metadata"` +} + +func (q *Queries) UpsertChannelIdentityByChannelSubject(ctx context.Context, arg UpsertChannelIdentityByChannelSubjectParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, upsertChannelIdentityByChannelSubject, + arg.UserID, + arg.ChannelType, + arg.ChannelSubjectID, + arg.DisplayName, + arg.Metadata, + ) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChannelType, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/db/sqlc/channel_routes.sql.go b/internal/db/sqlc/channel_routes.sql.go new file mode 100644 index 00000000..0be56106 --- /dev/null +++ b/internal/db/sqlc/channel_routes.sql.go @@ -0,0 +1,298 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: channel_routes.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createChatRoute = `-- name: CreateChatRoute :one +INSERT INTO bot_channel_routes ( + bot_id, channel_type, channel_config_id, external_conversation_id, external_thread_id, default_reply_target, metadata +) +VALUES ( + $1, + $2, + $3::uuid, + $4, + $5::text, + $6::text, + $7 +) +RETURNING + id, + $8::uuid AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +` + +type CreateChatRouteParams struct { + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + ChatID pgtype.UUID `json:"chat_id"` +} + +type CreateChatRouteRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) CreateChatRoute(ctx context.Context, arg CreateChatRouteParams) (CreateChatRouteRow, error) { + row := q.db.QueryRow(ctx, createChatRoute, + arg.BotID, + arg.Platform, + arg.ChannelConfigID, + arg.ConversationID, + arg.ThreadID, + arg.ReplyTarget, + arg.Metadata, + arg.ChatID, + ) + var i CreateChatRouteRow + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteChatRoute = `-- name: DeleteChatRoute :exec +DELETE FROM bot_channel_routes +WHERE id = $1 +` + +func (q *Queries) DeleteChatRoute(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteChatRoute, id) + return err +} + +const findChatRoute = `-- name: FindChatRoute :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = $1 + AND channel_type = $2 + AND external_conversation_id = $3 + AND COALESCE(external_thread_id, '') = COALESCE($4, '') +LIMIT 1 +` + +type FindChatRouteParams struct { + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` +} + +type FindChatRouteRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) FindChatRoute(ctx context.Context, arg FindChatRouteParams) (FindChatRouteRow, error) { + row := q.db.QueryRow(ctx, findChatRoute, + arg.BotID, + arg.Platform, + arg.ConversationID, + arg.ThreadID, + ) + var i FindChatRouteRow + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChatRouteByID = `-- name: GetChatRouteByID :one +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE id = $1 +` + +type GetChatRouteByIDRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) GetChatRouteByID(ctx context.Context, id pgtype.UUID) (GetChatRouteByIDRow, error) { + row := q.db.QueryRow(ctx, getChatRouteByID, id) + var i GetChatRouteByIDRow + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listChatRoutes = `-- name: ListChatRoutes :many +SELECT + id, + bot_id AS chat_id, + bot_id, + channel_type AS platform, + channel_config_id, + external_conversation_id AS conversation_id, + external_thread_id AS thread_id, + default_reply_target AS reply_target, + metadata, + created_at, + updated_at +FROM bot_channel_routes +WHERE bot_id = $1 +ORDER BY created_at ASC +` + +type ListChatRoutesRow struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) ListChatRoutes(ctx context.Context, chatID pgtype.UUID) ([]ListChatRoutesRow, error) { + rows, err := q.db.Query(ctx, listChatRoutes, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChatRoutesRow + for rows.Next() { + var i ListChatRoutesRow + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateChatRouteReplyTarget = `-- name: UpdateChatRouteReplyTarget :exec +UPDATE bot_channel_routes +SET default_reply_target = $1, updated_at = now() +WHERE id = $2 +` + +type UpdateChatRouteReplyTargetParams struct { + ReplyTarget pgtype.Text `json:"reply_target"` + ID pgtype.UUID `json:"id"` +} + +func (q *Queries) UpdateChatRouteReplyTarget(ctx context.Context, arg UpdateChatRouteReplyTargetParams) error { + _, err := q.db.Exec(ctx, updateChatRouteReplyTarget, arg.ReplyTarget, arg.ID) + return err +} diff --git a/internal/db/sqlc/channels.sql.go b/internal/db/sqlc/channels.sql.go index c38738d8..52c62a40 100644 --- a/internal/db/sqlc/channels.sql.go +++ b/internal/db/sqlc/channels.sql.go @@ -11,16 +11,6 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -const deleteChannelSession = `-- name: DeleteChannelSession :exec -DELETE FROM channel_sessions -WHERE session_id = $1 -` - -func (q *Queries) DeleteChannelSession(ctx context.Context, sessionID string) error { - _, err := q.db.Exec(ctx, deleteChannelSession, sessionID) - return err -} - const getBotChannelConfig = `-- name: GetBotChannelConfig :one SELECT id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, status, verified_at, created_at, updated_at FROM bot_channel_configs @@ -85,32 +75,6 @@ func (q *Queries) GetBotChannelConfigByExternalIdentity(ctx context.Context, arg return i, err } -const getChannelSessionByID = `-- name: GetChannelSessionByID :one -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE session_id = $1 -LIMIT 1 -` - -func (q *Queries) GetChannelSessionByID(ctx context.Context, sessionID string) (ChannelSession, error) { - row := q.db.QueryRow(ctx, getChannelSessionByID, sessionID) - var i ChannelSession - err := row.Scan( - &i.SessionID, - &i.BotID, - &i.ChannelConfigID, - &i.UserID, - &i.ContactID, - &i.Platform, - &i.ReplyTarget, - &i.ThreadID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - const getUserChannelBinding = `-- name: GetUserChannelBinding :one SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings @@ -177,59 +141,15 @@ func (q *Queries) ListBotChannelConfigsByType(ctx context.Context, channelType s return items, nil } -const listChannelSessionsByBotPlatform = `-- name: ListChannelSessionsByBotPlatform :many -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE bot_id = $1 AND platform = $2 -ORDER BY updated_at DESC -` - -type ListChannelSessionsByBotPlatformParams struct { - BotID pgtype.UUID `json:"bot_id"` - Platform string `json:"platform"` -} - -func (q *Queries) ListChannelSessionsByBotPlatform(ctx context.Context, arg ListChannelSessionsByBotPlatformParams) ([]ChannelSession, error) { - rows, err := q.db.Query(ctx, listChannelSessionsByBotPlatform, arg.BotID, arg.Platform) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChannelSession - for rows.Next() { - var i ChannelSession - if err := rows.Scan( - &i.SessionID, - &i.BotID, - &i.ChannelConfigID, - &i.UserID, - &i.ContactID, - &i.Platform, - &i.ReplyTarget, - &i.ThreadID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listUserChannelBindingsByType = `-- name: ListUserChannelBindingsByType :many +const listUserChannelBindingsByPlatform = `-- name: ListUserChannelBindingsByPlatform :many SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings WHERE channel_type = $1 ORDER BY created_at DESC ` -func (q *Queries) ListUserChannelBindingsByType(ctx context.Context, channelType string) ([]UserChannelBinding, error) { - rows, err := q.db.Query(ctx, listUserChannelBindingsByType, channelType) +func (q *Queries) ListUserChannelBindingsByPlatform(ctx context.Context, channelType string) ([]UserChannelBinding, error) { + rows, err := q.db.Query(ctx, listUserChannelBindingsByPlatform, channelType) if err != nil { return nil, err } @@ -315,64 +235,6 @@ func (q *Queries) UpsertBotChannelConfig(ctx context.Context, arg UpsertBotChann return i, err } -const upsertChannelSession = `-- name: UpsertChannelSession :one -INSERT INTO channel_sessions (session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) -ON CONFLICT (session_id) -DO UPDATE SET - bot_id = EXCLUDED.bot_id, - channel_config_id = EXCLUDED.channel_config_id, - user_id = EXCLUDED.user_id, - contact_id = EXCLUDED.contact_id, - platform = EXCLUDED.platform, - reply_target = EXCLUDED.reply_target, - thread_id = EXCLUDED.thread_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -` - -type UpsertChannelSessionParams struct { - SessionID string `json:"session_id"` - BotID pgtype.UUID `json:"bot_id"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - UserID pgtype.UUID `json:"user_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - ReplyTarget pgtype.Text `json:"reply_target"` - ThreadID pgtype.Text `json:"thread_id"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) UpsertChannelSession(ctx context.Context, arg UpsertChannelSessionParams) (ChannelSession, error) { - row := q.db.QueryRow(ctx, upsertChannelSession, - arg.SessionID, - arg.BotID, - arg.ChannelConfigID, - arg.UserID, - arg.ContactID, - arg.Platform, - arg.ReplyTarget, - arg.ThreadID, - arg.Metadata, - ) - var i ChannelSession - err := row.Scan( - &i.SessionID, - &i.BotID, - &i.ChannelConfigID, - &i.UserID, - &i.ContactID, - &i.Platform, - &i.ReplyTarget, - &i.ThreadID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - const upsertUserChannelBinding = `-- name: UpsertUserChannelBinding :one INSERT INTO user_channel_bindings (user_id, channel_type, config) VALUES ($1, $2, $3) diff --git a/internal/db/sqlc/contacts.sql.go b/internal/db/sqlc/contacts.sql.go deleted file mode 100644 index 3cf19028..00000000 --- a/internal/db/sqlc/contacts.sql.go +++ /dev/null @@ -1,380 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 -// source: contacts.sql - -package sqlc - -import ( - "context" - - "github.com/jackc/pgx/v5/pgtype" -) - -const createContact = `-- name: CreateContact :one -INSERT INTO contacts (bot_id, user_id, display_name, alias, tags, status, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -` - -type CreateContactParams struct { - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` - DisplayName pgtype.Text `json:"display_name"` - Alias pgtype.Text `json:"alias"` - Tags []string `json:"tags"` - Status string `json:"status"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) CreateContact(ctx context.Context, arg CreateContactParams) (Contact, error) { - row := q.db.QueryRow(ctx, createContact, - arg.BotID, - arg.UserID, - arg.DisplayName, - arg.Alias, - arg.Tags, - arg.Status, - arg.Metadata, - ) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getContactByID = `-- name: GetContactByID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE id = $1 -LIMIT 1 -` - -func (q *Queries) GetContactByID(ctx context.Context, id pgtype.UUID) (Contact, error) { - row := q.db.QueryRow(ctx, getContactByID, id) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getContactByUserID = `-- name: GetContactByUserID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 AND user_id = $2 -LIMIT 1 -` - -type GetContactByUserIDParams struct { - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` -} - -func (q *Queries) GetContactByUserID(ctx context.Context, arg GetContactByUserIDParams) (Contact, error) { - row := q.db.QueryRow(ctx, getContactByUserID, arg.BotID, arg.UserID) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getContactChannelByIdentity = `-- name: GetContactChannelByIdentity :one -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE bot_id = $1 AND platform = $2 AND external_id = $3 -LIMIT 1 -` - -type GetContactChannelByIdentityParams struct { - BotID pgtype.UUID `json:"bot_id"` - Platform string `json:"platform"` - ExternalID string `json:"external_id"` -} - -func (q *Queries) GetContactChannelByIdentity(ctx context.Context, arg GetContactChannelByIdentityParams) (ContactChannel, error) { - row := q.db.QueryRow(ctx, getContactChannelByIdentity, arg.BotID, arg.Platform, arg.ExternalID) - var i ContactChannel - err := row.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Platform, - &i.ExternalID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const listContactChannelsByContact = `-- name: ListContactChannelsByContact :many -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE contact_id = $1 -ORDER BY created_at DESC -` - -func (q *Queries) ListContactChannelsByContact(ctx context.Context, contactID pgtype.UUID) ([]ContactChannel, error) { - rows, err := q.db.Query(ctx, listContactChannelsByContact, contactID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ContactChannel - for rows.Next() { - var i ContactChannel - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Platform, - &i.ExternalID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listContactsByBot = `-- name: ListContactsByBot :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 -ORDER BY created_at DESC -` - -func (q *Queries) ListContactsByBot(ctx context.Context, botID pgtype.UUID) ([]Contact, error) { - rows, err := q.db.Query(ctx, listContactsByBot, botID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Contact - for rows.Next() { - var i Contact - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const searchContacts = `-- name: SearchContacts :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 - AND ( - display_name ILIKE $2 - OR alias ILIKE $2 - OR EXISTS ( - SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE $2 - ) - ) -ORDER BY created_at DESC -` - -type SearchContactsParams struct { - BotID pgtype.UUID `json:"bot_id"` - Query pgtype.Text `json:"query"` -} - -func (q *Queries) SearchContacts(ctx context.Context, arg SearchContactsParams) ([]Contact, error) { - rows, err := q.db.Query(ctx, searchContacts, arg.BotID, arg.Query) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Contact - for rows.Next() { - var i Contact - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const updateContact = `-- name: UpdateContact :one -UPDATE contacts -SET display_name = COALESCE($1, display_name), - alias = COALESCE($2, alias), - tags = COALESCE($3, tags), - status = COALESCE(NULLIF($4::text, ''), status), - metadata = COALESCE($5, metadata), - updated_at = now() -WHERE id = $6 -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -` - -type UpdateContactParams struct { - DisplayName pgtype.Text `json:"display_name"` - Alias pgtype.Text `json:"alias"` - Tags []string `json:"tags"` - Status string `json:"status"` - Metadata []byte `json:"metadata"` - ID pgtype.UUID `json:"id"` -} - -func (q *Queries) UpdateContact(ctx context.Context, arg UpdateContactParams) (Contact, error) { - row := q.db.QueryRow(ctx, updateContact, - arg.DisplayName, - arg.Alias, - arg.Tags, - arg.Status, - arg.Metadata, - arg.ID, - ) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const updateContactUser = `-- name: UpdateContactUser :one -UPDATE contacts -SET user_id = $2, - updated_at = now() -WHERE id = $1 -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -` - -type UpdateContactUserParams struct { - ID pgtype.UUID `json:"id"` - UserID pgtype.UUID `json:"user_id"` -} - -func (q *Queries) UpdateContactUser(ctx context.Context, arg UpdateContactUserParams) (Contact, error) { - row := q.db.QueryRow(ctx, updateContactUser, arg.ID, arg.UserID) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const upsertContactChannel = `-- name: UpsertContactChannel :one -INSERT INTO contact_channels (bot_id, contact_id, platform, external_id, metadata) -VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (bot_id, platform, external_id) -DO UPDATE SET - contact_id = EXCLUDED.contact_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -` - -type UpsertContactChannelParams struct { - BotID pgtype.UUID `json:"bot_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - ExternalID string `json:"external_id"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) UpsertContactChannel(ctx context.Context, arg UpsertContactChannelParams) (ContactChannel, error) { - row := q.db.QueryRow(ctx, upsertContactChannel, - arg.BotID, - arg.ContactID, - arg.Platform, - arg.ExternalID, - arg.Metadata, - ) - var i ContactChannel - err := row.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Platform, - &i.ExternalID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} diff --git a/internal/db/sqlc/containers.sql.go b/internal/db/sqlc/containers.sql.go index 3141781b..b193443e 100644 --- a/internal/db/sqlc/containers.sql.go +++ b/internal/db/sqlc/containers.sql.go @@ -72,6 +72,45 @@ func (q *Queries) GetContainerByContainerID(ctx context.Context, containerID str return i, err } +const listAutoStartContainers = `-- name: ListAutoStartContainers :many +SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE auto_start = true ORDER BY updated_at DESC +` + +func (q *Queries) ListAutoStartContainers(ctx context.Context) ([]Container, error) { + rows, err := q.db.Query(ctx, listAutoStartContainers) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Container + for rows.Next() { + var i Container + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.ContainerID, + &i.ContainerName, + &i.Image, + &i.Status, + &i.Namespace, + &i.AutoStart, + &i.HostPath, + &i.ContainerPath, + &i.CreatedAt, + &i.UpdatedAt, + &i.LastStartedAt, + &i.LastStoppedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateContainerStarted = `-- name: UpdateContainerStarted :exec UPDATE containers SET status = 'running', last_started_at = now(), updated_at = now() diff --git a/internal/db/sqlc/conversations.sql.go b/internal/db/sqlc/conversations.sql.go new file mode 100644 index 00000000..f9ba8fc4 --- /dev/null +++ b/internal/db/sqlc/conversations.sql.go @@ -0,0 +1,678 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: conversations.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const addChatParticipant = `-- name: AddChatParticipant :one + +INSERT INTO bot_members (bot_id, user_id, role) +VALUES ($1, $2, $3) +ON CONFLICT (bot_id, user_id) DO UPDATE SET role = EXCLUDED.role +RETURNING bot_id AS chat_id, user_id, role, created_at AS joined_at +` + +type AddChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` +} + +type AddChatParticipantRow struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt pgtype.Timestamptz `json:"joined_at"` +} + +// chat_participants +func (q *Queries) AddChatParticipant(ctx context.Context, arg AddChatParticipantParams) (AddChatParticipantRow, error) { + row := q.db.QueryRow(ctx, addChatParticipant, arg.ChatID, arg.UserID, arg.Role) + var i AddChatParticipantRow + err := row.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ) + return i, err +} + +const copyParticipantsToChat = `-- name: CopyParticipantsToChat :exec +INSERT INTO bot_members (bot_id, user_id, role) +SELECT $1, bm.user_id, bm.role +FROM bot_members bm +WHERE bm.bot_id = $2 +ON CONFLICT (bot_id, user_id) DO NOTHING +` + +type CopyParticipantsToChatParams struct { + ChatID2 pgtype.UUID `json:"chat_id_2"` + ChatID pgtype.UUID `json:"chat_id"` +} + +func (q *Queries) CopyParticipantsToChat(ctx context.Context, arg CopyParticipantsToChatParams) error { + _, err := q.db.Exec(ctx, copyParticipantsToChat, arg.ChatID2, arg.ChatID) + return err +} + +const createChat = `-- name: CreateChat :one +SELECT + b.id AS id, + b.id AS bot_id, + (COALESCE(NULLIF($1::text, ''), CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END))::text AS kind, + CASE WHEN $1 = 'thread' THEN $2::uuid ELSE NULL::uuid END AS parent_chat_id, + COALESCE(NULLIF($3::text, ''), b.display_name) AS title, + COALESCE($4::uuid, b.owner_user_id) AS created_by_user_id, + COALESCE($5::jsonb, b.metadata) AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $6 +LIMIT 1 +` + +type CreateChatParams struct { + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title string `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + BotID pgtype.UUID `json:"bot_id"` +} + +type CreateChatRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) CreateChat(ctx context.Context, arg CreateChatParams) (CreateChatRow, error) { + row := q.db.QueryRow(ctx, createChat, + arg.Kind, + arg.ParentChatID, + arg.Title, + arg.CreatedByUserID, + arg.Metadata, + arg.BotID, + ) + var i CreateChatRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteChat = `-- name: DeleteChat :exec +WITH deleted_messages AS ( + DELETE FROM bot_history_messages + WHERE bot_id = $1 +) +DELETE FROM bot_channel_routes bcr +WHERE bcr.bot_id = $1 +` + +func (q *Queries) DeleteChat(ctx context.Context, chatID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteChat, chatID) + return err +} + +const getChatByID = `-- name: GetChatByID :one +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 +` + +type GetChatByIDRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) GetChatByID(ctx context.Context, id pgtype.UUID) (GetChatByIDRow, error) { + row := q.db.QueryRow(ctx, getChatByID, id) + var i GetChatByIDRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChatParticipant = `-- name: GetChatParticipant :one +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = $1 AND b.owner_user_id = $2 +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = $1 AND bm.user_id = $2 +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT chat_id, user_id, role, joined_at FROM owner_participant + UNION ALL + SELECT chat_id, user_id, role, joined_at FROM member_participant +) p +ORDER BY CASE WHEN role = 'owner' THEN 0 ELSE 1 END +LIMIT 1 +` + +type GetChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` +} + +type GetChatParticipantRow struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt pgtype.Timestamptz `json:"joined_at"` +} + +func (q *Queries) GetChatParticipant(ctx context.Context, arg GetChatParticipantParams) (GetChatParticipantRow, error) { + row := q.db.QueryRow(ctx, getChatParticipant, arg.ChatID, arg.UserID) + var i GetChatParticipantRow + err := row.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ) + return i, err +} + +const getChatReadAccessByUser = `-- name: GetChatReadAccessByUser :one +SELECT + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = $1 THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1 +WHERE b.id = $2 + AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL) +LIMIT 1 +` + +type GetChatReadAccessByUserParams struct { + UserID pgtype.UUID `json:"user_id"` + ChatID pgtype.UUID `json:"chat_id"` +} + +type GetChatReadAccessByUserRow struct { + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role"` + LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` +} + +func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAccessByUserParams) (GetChatReadAccessByUserRow, error) { + row := q.db.QueryRow(ctx, getChatReadAccessByUser, arg.UserID, arg.ChatID) + var i GetChatReadAccessByUserRow + err := row.Scan(&i.AccessMode, &i.ParticipantRole, &i.LastObservedAt) + return i, err +} + +const getChatSettings = `-- name: GetChatSettings :one +SELECT + b.id AS chat_id, + chat_models.model_id AS model_id, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 +` + +type GetChatSettingsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + ModelID pgtype.Text `json:"model_id"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) GetChatSettings(ctx context.Context, id pgtype.UUID) (GetChatSettingsRow, error) { + row := q.db.QueryRow(ctx, getChatSettings, id) + var i GetChatSettingsRow + err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt) + return i, err +} + +const listChatParticipants = `-- name: ListChatParticipants :many +WITH owner_participant AS ( + SELECT b.id AS chat_id, b.owner_user_id AS user_id, 'owner'::text AS role, b.created_at AS joined_at + FROM bots b + WHERE b.id = $1 +), +member_participant AS ( + SELECT bm.bot_id AS chat_id, bm.user_id, bm.role, bm.created_at AS joined_at + FROM bot_members bm + WHERE bm.bot_id = $1 + AND bm.user_id <> (SELECT owner_user_id FROM bots WHERE id = $1) +) +SELECT chat_id, user_id, role, joined_at +FROM ( + SELECT chat_id, user_id, role, joined_at FROM owner_participant + UNION ALL + SELECT chat_id, user_id, role, joined_at FROM member_participant +) p +ORDER BY joined_at ASC +` + +type ListChatParticipantsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt pgtype.Timestamptz `json:"joined_at"` +} + +func (q *Queries) ListChatParticipants(ctx context.Context, chatID pgtype.UUID) ([]ListChatParticipantsRow, error) { + rows, err := q.db.Query(ctx, listChatParticipants, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChatParticipantsRow + for rows.Next() { + var i ListChatParticipantsRow + if err := rows.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatsByBotAndUser = `-- name: ListChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1 +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $2 + AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC +` + +type ListChatsByBotAndUserParams struct { + UserID pgtype.UUID `json:"user_id"` + BotID pgtype.UUID `json:"bot_id"` +} + +type ListChatsByBotAndUserRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) ListChatsByBotAndUser(ctx context.Context, arg ListChatsByBotAndUserParams) ([]ListChatsByBotAndUserRow, error) { + rows, err := q.db.Query(ctx, listChatsByBotAndUser, arg.UserID, arg.BotID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChatsByBotAndUserRow + for rows.Next() { + var i ListChatsByBotAndUserRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listThreadsByParent = `-- name: ListThreadsByParent :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + chat_models.model_id AS model_id, + b.created_at, + b.updated_at +FROM bots b +LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id +WHERE b.id = $1 + AND false +ORDER BY b.created_at DESC +` + +type ListThreadsByParentRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) ListThreadsByParent(ctx context.Context, id pgtype.UUID) ([]ListThreadsByParentRow, error) { + rows, err := q.db.Query(ctx, listThreadsByParent, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListThreadsByParentRow + for rows.Next() { + var i ListThreadsByParentRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listVisibleChatsByBotAndUser = `-- name: ListVisibleChatsByBotAndUser :many +SELECT + b.id AS id, + b.id AS bot_id, + CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + b.display_name AS title, + b.owner_user_id AS created_by_user_id, + b.metadata AS metadata, + b.created_at, + b.updated_at, + 'participant'::text AS access_mode, + (CASE + WHEN b.owner_user_id = $1 THEN 'owner' + ELSE COALESCE(bm.role, ''::text) + END)::text AS participant_role, + NULL::timestamptz AS last_observed_at +FROM bots b +LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1 +WHERE b.id = $2 + AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL) +ORDER BY b.updated_at DESC +` + +type ListVisibleChatsByBotAndUserParams struct { + UserID pgtype.UUID `json:"user_id"` + BotID pgtype.UUID `json:"bot_id"` +} + +type ListVisibleChatsByBotAndUserRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role"` + LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` +} + +func (q *Queries) ListVisibleChatsByBotAndUser(ctx context.Context, arg ListVisibleChatsByBotAndUserParams) ([]ListVisibleChatsByBotAndUserRow, error) { + rows, err := q.db.Query(ctx, listVisibleChatsByBotAndUser, arg.UserID, arg.BotID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListVisibleChatsByBotAndUserRow + for rows.Next() { + var i ListVisibleChatsByBotAndUserRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + &i.AccessMode, + &i.ParticipantRole, + &i.LastObservedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const removeChatParticipant = `-- name: RemoveChatParticipant :exec +DELETE FROM bot_members +WHERE bot_id = $1 + AND user_id = $2 + AND user_id <> (SELECT owner_user_id FROM bots WHERE id = $1) +` + +type RemoveChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) RemoveChatParticipant(ctx context.Context, arg RemoveChatParticipantParams) error { + _, err := q.db.Exec(ctx, removeChatParticipant, arg.ChatID, arg.UserID) + return err +} + +const touchChat = `-- name: TouchChat :exec +UPDATE bots +SET updated_at = now() +WHERE id = $1 +` + +func (q *Queries) TouchChat(ctx context.Context, chatID pgtype.UUID) error { + _, err := q.db.Exec(ctx, touchChat, chatID) + return err +} + +const updateChatTitle = `-- name: UpdateChatTitle :one +UPDATE bots +SET display_name = $1, + updated_at = now() +WHERE id = $2 +RETURNING + id, + id AS bot_id, + CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind, + NULL::uuid AS parent_chat_id, + display_name AS title, + owner_user_id AS created_by_user_id, + metadata, + NULL::text AS model_id, + created_at, + updated_at +` + +type UpdateChatTitleParams struct { + Title pgtype.Text `json:"title"` + ID pgtype.UUID `json:"id"` +} + +type UpdateChatTitleRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + ModelID pgtype.Text `json:"model_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams) (UpdateChatTitleRow, error) { + row := q.db.QueryRow(ctx, updateChatTitle, arg.Title, arg.ID) + var i UpdateChatTitleRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.ModelID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertChatSettings = `-- name: UpsertChatSettings :one + +WITH resolved_model AS ( + SELECT id + FROM models + WHERE model_id = NULLIF($1::text, '') + LIMIT 1 +), +updated AS ( + UPDATE bots + SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id), + updated_at = now() + WHERE bots.id = $2 + RETURNING bots.id, bots.chat_model_id, bots.updated_at +) +SELECT + updated.id AS chat_id, + chat_models.model_id AS model_id, + updated.updated_at +FROM updated +LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id +` + +type UpsertChatSettingsParams struct { + ModelID pgtype.Text `json:"model_id"` + ID pgtype.UUID `json:"id"` +} + +type UpsertChatSettingsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + ModelID pgtype.Text `json:"model_id"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +// chat_settings +func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) { + row := q.db.QueryRow(ctx, upsertChatSettings, arg.ModelID, arg.ID) + var i UpsertChatSettingsRow + err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt) + return i, err +} diff --git a/internal/db/sqlc/history.sql.go b/internal/db/sqlc/history.sql.go deleted file mode 100644 index 0fc2033c..00000000 --- a/internal/db/sqlc/history.sql.go +++ /dev/null @@ -1,178 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 -// source: history.sql - -package sqlc - -import ( - "context" - - "github.com/jackc/pgx/v5/pgtype" -) - -const createHistory = `-- name: CreateHistory :one -INSERT INTO history (bot_id, session_id, messages, metadata, skills, timestamp) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, bot_id, session_id, messages, metadata, skills, timestamp -` - -type CreateHistoryParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Messages []byte `json:"messages"` - Metadata []byte `json:"metadata"` - Skills []string `json:"skills"` - Timestamp pgtype.Timestamptz `json:"timestamp"` -} - -func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (History, error) { - row := q.db.QueryRow(ctx, createHistory, - arg.BotID, - arg.SessionID, - arg.Messages, - arg.Metadata, - arg.Skills, - arg.Timestamp, - ) - var i History - err := row.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ) - return i, err -} - -const deleteHistoryByBotSession = `-- name: DeleteHistoryByBotSession :exec -DELETE FROM history -WHERE bot_id = $1 AND session_id = $2 -` - -type DeleteHistoryByBotSessionParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` -} - -func (q *Queries) DeleteHistoryByBotSession(ctx context.Context, arg DeleteHistoryByBotSessionParams) error { - _, err := q.db.Exec(ctx, deleteHistoryByBotSession, arg.BotID, arg.SessionID) - return err -} - -const deleteHistoryByID = `-- name: DeleteHistoryByID :exec -DELETE FROM history -WHERE id = $1 -` - -func (q *Queries) DeleteHistoryByID(ctx context.Context, id pgtype.UUID) error { - _, err := q.db.Exec(ctx, deleteHistoryByID, id) - return err -} - -const getHistoryByID = `-- name: GetHistoryByID :one -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE id = $1 -` - -func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History, error) { - row := q.db.QueryRow(ctx, getHistoryByID, id) - var i History - err := row.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ) - return i, err -} - -const listHistoryByBotSession = `-- name: ListHistoryByBotSession :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 -ORDER BY timestamp DESC -LIMIT $3 -` - -type ListHistoryByBotSessionParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Limit int32 `json:"limit"` -} - -func (q *Queries) ListHistoryByBotSession(ctx context.Context, arg ListHistoryByBotSessionParams) ([]History, error) { - rows, err := q.db.Query(ctx, listHistoryByBotSession, arg.BotID, arg.SessionID, arg.Limit) - if err != nil { - return nil, err - } - defer rows.Close() - var items []History - for rows.Next() { - var i History - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listHistoryByBotSessionSince = `-- name: ListHistoryByBotSessionSince :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 AND timestamp >= $3 -ORDER BY timestamp ASC -` - -type ListHistoryByBotSessionSinceParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Timestamp pgtype.Timestamptz `json:"timestamp"` -} - -func (q *Queries) ListHistoryByBotSessionSince(ctx context.Context, arg ListHistoryByBotSessionSinceParams) ([]History, error) { - rows, err := q.db.Query(ctx, listHistoryByBotSessionSince, arg.BotID, arg.SessionID, arg.Timestamp) - if err != nil { - return nil, err - } - defer rows.Close() - var items []History - for rows.Next() { - var i History - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} diff --git a/internal/db/sqlc/messages.sql.go b/internal/db/sqlc/messages.sql.go new file mode 100644 index 00000000..dffb3037 --- /dev/null +++ b/internal/db/sqlc/messages.sql.go @@ -0,0 +1,409 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: messages.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createMessage = `-- name: CreateMessage :one +INSERT INTO bot_history_messages ( + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id, + channel_type, + source_message_id, + source_reply_to_message_id, + role, + content, + metadata +) +VALUES ( + $1, + $2::uuid, + $3::uuid, + $4::uuid, + $5::text, + $6::text, + $7::text, + $8, + $9, + $10 +) +RETURNING + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +` + +type CreateMessageParams struct { + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` +} + +type CreateMessageRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (CreateMessageRow, error) { + row := q.db.QueryRow(ctx, createMessage, + arg.BotID, + arg.RouteID, + arg.SenderChannelIdentityID, + arg.SenderUserID, + arg.Platform, + arg.ExternalMessageID, + arg.SourceReplyToMessageID, + arg.Role, + arg.Content, + arg.Metadata, + ) + var i CreateMessageRow + err := row.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ) + return i, err +} + +const deleteMessagesByBot = `-- name: DeleteMessagesByBot :exec +DELETE FROM bot_history_messages +WHERE bot_id = $1 +` + +func (q *Queries) DeleteMessagesByBot(ctx context.Context, botID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteMessagesByBot, botID) + return err +} + +const listMessages = `-- name: ListMessages :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 +ORDER BY created_at ASC +` + +type ListMessagesRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessages(ctx context.Context, botID pgtype.UUID) ([]ListMessagesRow, error) { + rows, err := q.db.Query(ctx, listMessages, botID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesRow + for rows.Next() { + var i ListMessagesRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesBefore = `-- name: ListMessagesBefore :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 + AND created_at < $2 +ORDER BY created_at DESC +LIMIT $3 +` + +type ListMessagesBeforeParams struct { + BotID pgtype.UUID `json:"bot_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + MaxCount int32 `json:"max_count"` +} + +type ListMessagesBeforeRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessagesBefore(ctx context.Context, arg ListMessagesBeforeParams) ([]ListMessagesBeforeRow, error) { + rows, err := q.db.Query(ctx, listMessagesBefore, arg.BotID, arg.CreatedAt, arg.MaxCount) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesBeforeRow + for rows.Next() { + var i ListMessagesBeforeRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesLatest = `-- name: ListMessagesLatest :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 +ORDER BY created_at DESC +LIMIT $2 +` + +type ListMessagesLatestParams struct { + BotID pgtype.UUID `json:"bot_id"` + MaxCount int32 `json:"max_count"` +} + +type ListMessagesLatestRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessagesLatest(ctx context.Context, arg ListMessagesLatestParams) ([]ListMessagesLatestRow, error) { + rows, err := q.db.Query(ctx, listMessagesLatest, arg.BotID, arg.MaxCount) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesLatestRow + for rows.Next() { + var i ListMessagesLatestRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesSince = `-- name: ListMessagesSince :many +SELECT + id, + bot_id, + route_id, + sender_channel_identity_id, + sender_account_user_id AS sender_user_id, + channel_type AS platform, + source_message_id AS external_message_id, + source_reply_to_message_id, + role, + content, + metadata, + created_at +FROM bot_history_messages +WHERE bot_id = $1 + AND created_at >= $2 +ORDER BY created_at ASC +` + +type ListMessagesSinceParams struct { + BotID pgtype.UUID `json:"bot_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +type ListMessagesSinceRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListMessagesSince(ctx context.Context, arg ListMessagesSinceParams) ([]ListMessagesSinceRow, error) { + rows, err := q.db.Query(ctx, listMessagesSince, arg.BotID, arg.CreatedAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMessagesSinceRow + for rows.Next() { + var i ListMessagesSinceRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.SourceReplyToMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 0341bf03..fff73aea 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -9,15 +9,22 @@ import ( ) type Bot struct { - ID pgtype.UUID `json:"id"` - OwnerUserID pgtype.UUID `json:"owner_user_id"` - Type string `json:"type"` - DisplayName pgtype.Text `json:"display_name"` - AvatarUrl pgtype.Text `json:"avatar_url"` - IsActive bool `json:"is_active"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` + ID pgtype.UUID `json:"id"` + OwnerUserID pgtype.UUID `json:"owner_user_id"` + Type string `json:"type"` + DisplayName pgtype.Text `json:"display_name"` + AvatarUrl pgtype.Text `json:"avatar_url"` + IsActive bool `json:"is_active"` + Status string `json:"status"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + MemoryModelID pgtype.UUID `json:"memory_model_id"` + EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } type BotChannelConfig struct { @@ -35,6 +42,34 @@ type BotChannelConfig struct { UpdatedAt pgtype.Timestamptz `json:"updated_at"` } +type BotChannelRoute struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + ChannelType string `json:"channel_type"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ExternalConversationID string `json:"external_conversation_id"` + ExternalThreadID pgtype.Text `json:"external_thread_id"` + DefaultReplyTarget pgtype.Text `json:"default_reply_target"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +type BotHistoryMessage struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderAccountUserID pgtype.UUID `json:"sender_account_user_id"` + ChannelType pgtype.Text `json:"channel_type"` + SourceMessageID pgtype.Text `json:"source_message_id"` + SourceReplyToMessageID pgtype.Text `json:"source_reply_to_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + type BotMember struct { BotID pgtype.UUID `json:"bot_id"` UserID pgtype.UUID `json:"user_id"` @@ -42,13 +77,6 @@ type BotMember struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } -type BotModelConfig struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` - MemoryModelID pgtype.UUID `json:"memory_model_id"` -} - type BotPreauthKey struct { ID pgtype.UUID `json:"id"` BotID pgtype.UUID `json:"bot_id"` @@ -59,49 +87,26 @@ type BotPreauthKey struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } -type BotSetting struct { - BotID pgtype.UUID `json:"bot_id"` - MaxContextLoadTime int32 `json:"max_context_load_time"` - Language string `json:"language"` - AllowGuest bool `json:"allow_guest"` +type ChannelIdentity struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + ChannelType string `json:"channel_type"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName pgtype.Text `json:"display_name"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } -type ChannelSession struct { - SessionID string `json:"session_id"` - BotID pgtype.UUID `json:"bot_id"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - UserID pgtype.UUID `json:"user_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - ReplyTarget pgtype.Text `json:"reply_target"` - ThreadID pgtype.Text `json:"thread_id"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -type Contact struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` - DisplayName pgtype.Text `json:"display_name"` - Alias pgtype.Text `json:"alias"` - Tags []string `json:"tags"` - Status string `json:"status"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -type ContactChannel struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - ExternalID string `json:"external_id"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` +type ChannelIdentityBindCode struct { + ID pgtype.UUID `json:"id"` + Token string `json:"token"` + IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` + ChannelType pgtype.Text `json:"channel_type"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` + UsedAt pgtype.Timestamptz `json:"used_at"` + UsedByChannelIdentityID pgtype.UUID `json:"used_by_channel_identity_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` } type Container struct { @@ -129,27 +134,6 @@ type ContainerVersion struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } -type Conversation struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - ChannelType string `json:"channel_type"` - ChatID pgtype.Text `json:"chat_id"` - SenderID pgtype.Text `json:"sender_id"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -type History struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Messages []byte `json:"messages"` - Metadata []byte `json:"metadata"` - Skills []string `json:"skills"` - Timestamp pgtype.Timestamptz `json:"timestamp"` -} - type LifecycleEvent struct { ID string `json:"id"` ContainerID string `json:"container_id"` @@ -240,18 +224,24 @@ type Subagent struct { } type User struct { - ID pgtype.UUID `json:"id"` - Username string `json:"username"` - Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` - Role string `json:"role"` - DisplayName pgtype.Text `json:"display_name"` - AvatarUrl pgtype.Text `json:"avatar_url"` - IsActive bool `json:"is_active"` - DataRoot pgtype.Text `json:"data_root"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` - LastLoginAt pgtype.Timestamptz `json:"last_login_at"` + ID pgtype.UUID `json:"id"` + Username pgtype.Text `json:"username"` + Email pgtype.Text `json:"email"` + PasswordHash pgtype.Text `json:"password_hash"` + Role string `json:"role"` + DisplayName pgtype.Text `json:"display_name"` + AvatarUrl pgtype.Text `json:"avatar_url"` + DataRoot pgtype.Text `json:"data_root"` + LastLoginAt pgtype.Timestamptz `json:"last_login_at"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + IsActive bool `json:"is_active"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } type UserChannelBinding struct { @@ -262,12 +252,3 @@ type UserChannelBinding struct { CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } - -type UserSetting struct { - UserID pgtype.UUID `json:"user_id"` - ChatModelID pgtype.Text `json:"chat_model_id"` - MemoryModelID pgtype.Text `json:"memory_model_id"` - EmbeddingModelID pgtype.Text `json:"embedding_model_id"` - MaxContextLoadTime int32 `json:"max_context_load_time"` - Language string `json:"language"` -} diff --git a/internal/db/sqlc/settings.sql.go b/internal/db/sqlc/settings.sql.go index 25d96c07..f931844d 100644 --- a/internal/db/sqlc/settings.sql.go +++ b/internal/db/sqlc/settings.sql.go @@ -12,173 +12,67 @@ import ( ) const deleteSettingsByBotID = `-- name: DeleteSettingsByBotID :exec -DELETE FROM bot_settings -WHERE bot_id = $1 +UPDATE bots +SET max_context_load_time = 1440, + language = 'auto', + allow_guest = false, + updated_at = now() +WHERE id = $1 ` -func (q *Queries) DeleteSettingsByBotID(ctx context.Context, botID pgtype.UUID) error { - _, err := q.db.Exec(ctx, deleteSettingsByBotID, botID) +func (q *Queries) DeleteSettingsByBotID(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteSettingsByBotID, id) return err } -const getBotModelConfigByBotID = `-- name: GetBotModelConfigByBotID :one +const getSettingsByBotID = `-- name: GetSettingsByBotID :one SELECT - bot_model_configs.bot_id, + bots.id AS bot_id, + bots.max_context_load_time, + bots.language, + bots.allow_guest, chat_models.model_id AS chat_model_id, memory_models.model_id AS memory_model_id, embedding_models.model_id AS embedding_model_id -FROM bot_model_configs -LEFT JOIN models AS chat_models ON chat_models.id = bot_model_configs.chat_model_id -LEFT JOIN models AS memory_models ON memory_models.id = bot_model_configs.memory_model_id -LEFT JOIN models AS embedding_models ON embedding_models.id = bot_model_configs.embedding_model_id -WHERE bot_model_configs.bot_id = $1 +FROM bots +LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = bots.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = bots.embedding_model_id +WHERE bots.id = $1 ` -type GetBotModelConfigByBotIDRow struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.Text `json:"chat_model_id"` - MemoryModelID pgtype.Text `json:"memory_model_id"` - EmbeddingModelID pgtype.Text `json:"embedding_model_id"` +type GetSettingsByBotIDRow struct { + BotID pgtype.UUID `json:"bot_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` } -func (q *Queries) GetBotModelConfigByBotID(ctx context.Context, botID pgtype.UUID) (GetBotModelConfigByBotIDRow, error) { - row := q.db.QueryRow(ctx, getBotModelConfigByBotID, botID) - var i GetBotModelConfigByBotIDRow - err := row.Scan( - &i.BotID, - &i.ChatModelID, - &i.MemoryModelID, - &i.EmbeddingModelID, - ) - return i, err -} - -const getSettingsByBotID = `-- name: GetSettingsByBotID :one -SELECT bot_id, max_context_load_time, language, allow_guest -FROM bot_settings -WHERE bot_id = $1 -` - -func (q *Queries) GetSettingsByBotID(ctx context.Context, botID pgtype.UUID) (BotSetting, error) { - row := q.db.QueryRow(ctx, getSettingsByBotID, botID) - var i BotSetting +func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSettingsByBotIDRow, error) { + row := q.db.QueryRow(ctx, getSettingsByBotID, id) + var i GetSettingsByBotIDRow err := row.Scan( &i.BotID, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, ) return i, err } const getSettingsByUserID = `-- name: GetSettingsByUserID :one -SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language -FROM user_settings -WHERE user_id = $1 +SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language +FROM users +WHERE id = $1 ` -func (q *Queries) GetSettingsByUserID(ctx context.Context, userID pgtype.UUID) (UserSetting, error) { - row := q.db.QueryRow(ctx, getSettingsByUserID, userID) - var i UserSetting - err := row.Scan( - &i.UserID, - &i.ChatModelID, - &i.MemoryModelID, - &i.EmbeddingModelID, - &i.MaxContextLoadTime, - &i.Language, - ) - return i, err -} - -const upsertBotModelConfig = `-- name: UpsertBotModelConfig :one -INSERT INTO bot_model_configs (bot_id, chat_model_id, memory_model_id, embedding_model_id) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - chat_model_id = COALESCE(EXCLUDED.chat_model_id, bot_model_configs.chat_model_id), - memory_model_id = COALESCE(EXCLUDED.memory_model_id, bot_model_configs.memory_model_id), - embedding_model_id = COALESCE(EXCLUDED.embedding_model_id, bot_model_configs.embedding_model_id) -RETURNING bot_id, chat_model_id, memory_model_id, embedding_model_id -` - -type UpsertBotModelConfigParams struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - MemoryModelID pgtype.UUID `json:"memory_model_id"` - EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` -} - -type UpsertBotModelConfigRow struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - MemoryModelID pgtype.UUID `json:"memory_model_id"` - EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` -} - -func (q *Queries) UpsertBotModelConfig(ctx context.Context, arg UpsertBotModelConfigParams) (UpsertBotModelConfigRow, error) { - row := q.db.QueryRow(ctx, upsertBotModelConfig, - arg.BotID, - arg.ChatModelID, - arg.MemoryModelID, - arg.EmbeddingModelID, - ) - var i UpsertBotModelConfigRow - err := row.Scan( - &i.BotID, - &i.ChatModelID, - &i.MemoryModelID, - &i.EmbeddingModelID, - ) - return i, err -} - -const upsertBotSettings = `-- name: UpsertBotSettings :one -INSERT INTO bot_settings (bot_id, max_context_load_time, language, allow_guest) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language, - allow_guest = EXCLUDED.allow_guest -RETURNING bot_id, max_context_load_time, language, allow_guest -` - -type UpsertBotSettingsParams struct { - BotID pgtype.UUID `json:"bot_id"` - MaxContextLoadTime int32 `json:"max_context_load_time"` - Language string `json:"language"` - AllowGuest bool `json:"allow_guest"` -} - -func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsParams) (BotSetting, error) { - row := q.db.QueryRow(ctx, upsertBotSettings, - arg.BotID, - arg.MaxContextLoadTime, - arg.Language, - arg.AllowGuest, - ) - var i BotSetting - err := row.Scan( - &i.BotID, - &i.MaxContextLoadTime, - &i.Language, - &i.AllowGuest, - ) - return i, err -} - -const upsertUserSettings = `-- name: UpsertUserSettings :one -INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language) -VALUES ($1, $2, $3, $4, $5, $6) -ON CONFLICT (user_id) DO UPDATE SET - chat_model_id = EXCLUDED.chat_model_id, - memory_model_id = EXCLUDED.memory_model_id, - embedding_model_id = EXCLUDED.embedding_model_id, - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language -RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language -` - -type UpsertUserSettingsParams struct { +type GetSettingsByUserIDRow struct { UserID pgtype.UUID `json:"user_id"` ChatModelID pgtype.Text `json:"chat_model_id"` MemoryModelID pgtype.Text `json:"memory_model_id"` @@ -187,16 +81,130 @@ type UpsertUserSettingsParams struct { Language string `json:"language"` } -func (q *Queries) UpsertUserSettings(ctx context.Context, arg UpsertUserSettingsParams) (UserSetting, error) { - row := q.db.QueryRow(ctx, upsertUserSettings, - arg.UserID, - arg.ChatModelID, - arg.MemoryModelID, - arg.EmbeddingModelID, - arg.MaxContextLoadTime, - arg.Language, - ) - var i UserSetting +func (q *Queries) GetSettingsByUserID(ctx context.Context, id pgtype.UUID) (GetSettingsByUserIDRow, error) { + row := q.db.QueryRow(ctx, getSettingsByUserID, id) + var i GetSettingsByUserIDRow + err := row.Scan( + &i.UserID, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + ) + return i, err +} + +const upsertBotSettings = `-- name: UpsertBotSettings :one +WITH updated AS ( + UPDATE bots + SET max_context_load_time = $1, + language = $2, + allow_guest = $3, + chat_model_id = COALESCE($4::uuid, bots.chat_model_id), + memory_model_id = COALESCE($5::uuid, bots.memory_model_id), + embedding_model_id = COALESCE($6::uuid, bots.embedding_model_id), + updated_at = now() + WHERE bots.id = $7 + RETURNING bots.id, bots.max_context_load_time, bots.language, bots.allow_guest, bots.chat_model_id, bots.memory_model_id, bots.embedding_model_id +) +SELECT + updated.id AS bot_id, + updated.max_context_load_time, + updated.language, + updated.allow_guest, + chat_models.model_id AS chat_model_id, + memory_models.model_id AS memory_model_id, + embedding_models.model_id AS embedding_model_id +FROM updated +LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = updated.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = updated.embedding_model_id +` + +type UpsertBotSettingsParams struct { + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + MemoryModelID pgtype.UUID `json:"memory_model_id"` + EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` + ID pgtype.UUID `json:"id"` +} + +type UpsertBotSettingsRow struct { + BotID pgtype.UUID `json:"bot_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` +} + +func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsParams) (UpsertBotSettingsRow, error) { + row := q.db.QueryRow(ctx, upsertBotSettings, + arg.MaxContextLoadTime, + arg.Language, + arg.AllowGuest, + arg.ChatModelID, + arg.MemoryModelID, + arg.EmbeddingModelID, + arg.ID, + ) + var i UpsertBotSettingsRow + err := row.Scan( + &i.BotID, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + ) + return i, err +} + +const upsertUserSettings = `-- name: UpsertUserSettings :one +UPDATE users +SET chat_model_id = $2, + memory_model_id = $3, + embedding_model_id = $4, + max_context_load_time = $5, + language = $6, + updated_at = now() +WHERE id = $1 +RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language +` + +type UpsertUserSettingsParams struct { + ID pgtype.UUID `json:"id"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` +} + +type UpsertUserSettingsRow struct { + UserID pgtype.UUID `json:"user_id"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` +} + +func (q *Queries) UpsertUserSettings(ctx context.Context, arg UpsertUserSettingsParams) (UpsertUserSettingsRow, error) { + row := q.db.QueryRow(ctx, upsertUserSettings, + arg.ID, + arg.ChatModelID, + arg.MemoryModelID, + arg.EmbeddingModelID, + arg.MaxContextLoadTime, + arg.Language, + ) + var i UpsertUserSettingsRow err := row.Scan( &i.UserID, &i.ChatModelID, diff --git a/internal/db/sqlc/users.sql.go b/internal/db/sqlc/users.sql.go index d421dbd3..60c2aa4a 100644 --- a/internal/db/sqlc/users.sql.go +++ b/internal/db/sqlc/users.sql.go @@ -11,45 +11,49 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -const countUsers = `-- name: CountUsers :one -SELECT COUNT(*)::bigint AS count FROM users +const countAccounts = `-- name: CountAccounts :one +SELECT COUNT(*)::bigint AS count +FROM users +WHERE username IS NOT NULL + AND password_hash IS NOT NULL ` -func (q *Queries) CountUsers(ctx context.Context) (int64, error) { - row := q.db.QueryRow(ctx, countUsers) +func (q *Queries) CountAccounts(ctx context.Context) (int64, error) { + row := q.db.QueryRow(ctx, countAccounts) var count int64 err := row.Scan(&count) return count, err } -const createUser = `-- name: CreateUser :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - $1, - $2, - $3, - $4::user_role, - $5, - $6, - $7, - $8 -) -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +const createAccount = `-- name: CreateAccount :one +UPDATE users +SET username = $1, + email = $2, + password_hash = $3, + role = $4::user_role, + display_name = $5, + avatar_url = $6, + is_active = $7, + data_root = $8, + updated_at = now() +WHERE id = $9 +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type CreateUserParams struct { - Username string `json:"username"` +type CreateAccountParams struct { + Username pgtype.Text `json:"username"` Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` + PasswordHash pgtype.Text `json:"password_hash"` Role string `json:"role"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` DataRoot pgtype.Text `json:"data_root"` + UserID pgtype.UUID `json:"user_id"` } -func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { - row := q.db.QueryRow(ctx, createUser, +func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (User, error) { + row := q.db.QueryRow(ctx, createAccount, arg.Username, arg.Email, arg.PasswordHash, @@ -58,6 +62,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e arg.AvatarUrl, arg.IsActive, arg.DataRoot, + arg.UserID, ) var i User err := row.Scan( @@ -68,55 +73,34 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const createUserWithID = `-- name: CreateUserWithID :one -INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - $1, - $2, - $3, - $4, - $5::user_role, - $6, - $7, - $8, - $9 -) -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +const createUser = `-- name: CreateUser :one +INSERT INTO users (is_active, metadata) +VALUES ($1, $2) +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type CreateUserWithIDParams struct { - ID pgtype.UUID `json:"id"` - Username string `json:"username"` - Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` - Role string `json:"role"` - DisplayName pgtype.Text `json:"display_name"` - AvatarUrl pgtype.Text `json:"avatar_url"` - IsActive bool `json:"is_active"` - DataRoot pgtype.Text `json:"data_root"` +type CreateUserParams struct { + IsActive bool `json:"is_active"` + Metadata []byte `json:"metadata"` } -func (q *Queries) CreateUserWithID(ctx context.Context, arg CreateUserWithIDParams) (User, error) { - row := q.db.QueryRow(ctx, createUserWithID, - arg.ID, - arg.Username, - arg.Email, - arg.PasswordHash, - arg.Role, - arg.DisplayName, - arg.AvatarUrl, - arg.IsActive, - arg.DataRoot, - ) +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { + row := q.db.QueryRow(ctx, createUser, arg.IsActive, arg.Metadata) var i User err := row.Scan( &i.ID, @@ -126,17 +110,115 @@ func (q *Queries) CreateUserWithID(ctx context.Context, arg CreateUserWithIDPara &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, + ) + return i, err +} + +const getAccountByIdentity = `-- name: GetAccountByIdentity :one +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 OR email = $1 +` + +func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text) (User, error) { + row := q.db.QueryRow(ctx, getAccountByIdentity, identity) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getAccountByUserID = `-- name: GetAccountByUserID :one +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE id = $1 +` + +func (q *Queries) GetAccountByUserID(ctx context.Context, userID pgtype.UUID) (User, error) { + row := q.db.QueryRow(ctx, getAccountByUserID, userID) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getAccountByUsername = `-- name: GetAccountByUsername :one +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 +` + +func (q *Queries) GetAccountByUsername(ctx context.Context, username pgtype.Text) (User, error) { + row := q.db.QueryRow(ctx, getAccountByUsername, username) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, ) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users WHERE id = $1 +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at +FROM users +WHERE id = $1 ` func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) { @@ -150,70 +232,29 @@ func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const getUserByIdentity = `-- name: GetUserByIdentity :one -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users WHERE username = $1 OR email = $1 -` - -func (q *Queries) GetUserByIdentity(ctx context.Context, identity string) (User, error) { - row := q.db.QueryRow(ctx, getUserByIdentity, identity) - var i User - err := row.Scan( - &i.ID, - &i.Username, - &i.Email, - &i.PasswordHash, - &i.Role, - &i.DisplayName, - &i.AvatarUrl, - &i.IsActive, - &i.DataRoot, - &i.CreatedAt, - &i.UpdatedAt, - &i.LastLoginAt, - ) - return i, err -} - -const getUserByUsername = `-- name: GetUserByUsername :one -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users WHERE username = $1 -` - -func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) { - row := q.db.QueryRow(ctx, getUserByUsername, username) - var i User - err := row.Scan( - &i.ID, - &i.Username, - &i.Email, - &i.PasswordHash, - &i.Role, - &i.DisplayName, - &i.AvatarUrl, - &i.IsActive, - &i.DataRoot, - &i.CreatedAt, - &i.UpdatedAt, - &i.LastLoginAt, - ) - return i, err -} - -const listUsers = `-- name: ListUsers :many -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users +const listAccounts = `-- name: ListAccounts :many +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users +WHERE username IS NOT NULL ORDER BY created_at DESC ` -func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { - rows, err := q.db.Query(ctx, listUsers) +func (q *Queries) ListAccounts(ctx context.Context) ([]User, error) { + rows, err := q.db.Query(ctx, listAccounts) if err != nil { return nil, err } @@ -229,11 +270,17 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ); err != nil { return nil, err } @@ -245,7 +292,7 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { return items, nil } -const updateUserAdmin = `-- name: UpdateUserAdmin :one +const updateAccountAdmin = `-- name: UpdateAccountAdmin :one UPDATE users SET role = $1::user_role, display_name = $2, @@ -253,24 +300,24 @@ SET role = $1::user_role, is_active = $4, updated_at = now() WHERE id = $5 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpdateUserAdminParams struct { +type UpdateAccountAdminParams struct { Role string `json:"role"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` - ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` } -func (q *Queries) UpdateUserAdmin(ctx context.Context, arg UpdateUserAdminParams) (User, error) { - row := q.db.QueryRow(ctx, updateUserAdmin, +func (q *Queries) UpdateAccountAdmin(ctx context.Context, arg UpdateAccountAdminParams) (User, error) { + row := q.db.QueryRow(ctx, updateAccountAdmin, arg.Role, arg.DisplayName, arg.AvatarUrl, arg.IsActive, - arg.ID, + arg.UserID, ) var i User err := row.Scan( @@ -281,25 +328,31 @@ func (q *Queries) UpdateUserAdmin(ctx context.Context, arg UpdateUserAdminParams &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const updateUserLastLogin = `-- name: UpdateUserLastLogin :one +const updateAccountLastLogin = `-- name: UpdateAccountLastLogin :one UPDATE users SET last_login_at = now(), updated_at = now() WHERE id = $1 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -func (q *Queries) UpdateUserLastLogin(ctx context.Context, id pgtype.UUID) (User, error) { - row := q.db.QueryRow(ctx, updateUserLastLogin, id) +func (q *Queries) UpdateAccountLastLogin(ctx context.Context, id pgtype.UUID) (User, error) { + row := q.db.QueryRow(ctx, updateAccountLastLogin, id) var i User err := row.Scan( &i.ID, @@ -309,30 +362,36 @@ func (q *Queries) UpdateUserLastLogin(ctx context.Context, id pgtype.UUID) (User &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const updateUserPassword = `-- name: UpdateUserPassword :one +const updateAccountPassword = `-- name: UpdateAccountPassword :one UPDATE users SET password_hash = $2, updated_at = now() WHERE id = $1 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpdateUserPasswordParams struct { +type UpdateAccountPasswordParams struct { ID pgtype.UUID `json:"id"` - PasswordHash string `json:"password_hash"` + PasswordHash pgtype.Text `json:"password_hash"` } -func (q *Queries) UpdateUserPassword(ctx context.Context, arg UpdateUserPasswordParams) (User, error) { - row := q.db.QueryRow(ctx, updateUserPassword, arg.ID, arg.PasswordHash) +func (q *Queries) UpdateAccountPassword(ctx context.Context, arg UpdateAccountPasswordParams) (User, error) { + row := q.db.QueryRow(ctx, updateAccountPassword, arg.ID, arg.PasswordHash) var i User err := row.Scan( &i.ID, @@ -342,34 +401,40 @@ func (q *Queries) UpdateUserPassword(ctx context.Context, arg UpdateUserPassword &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const updateUserProfile = `-- name: UpdateUserProfile :one +const updateAccountProfile = `-- name: UpdateAccountProfile :one UPDATE users SET display_name = $2, avatar_url = $3, is_active = $4, updated_at = now() WHERE id = $1 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpdateUserProfileParams struct { +type UpdateAccountProfileParams struct { ID pgtype.UUID `json:"id"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` } -func (q *Queries) UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error) { - row := q.db.QueryRow(ctx, updateUserProfile, +func (q *Queries) UpdateAccountProfile(ctx context.Context, arg UpdateAccountProfileParams) (User, error) { + row := q.db.QueryRow(ctx, updateAccountProfile, arg.ID, arg.DisplayName, arg.AvatarUrl, @@ -384,26 +449,73 @@ func (q *Queries) UpdateUserProfile(ctx context.Context, arg UpdateUserProfilePa &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const upsertUserByUsername = `-- name: UpsertUserByUsername :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) +const updateUserStatus = `-- name: UpdateUserStatus :one +UPDATE users +SET is_active = $2, + updated_at = now() +WHERE id = $1 +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at +` + +type UpdateUserStatusParams struct { + ID pgtype.UUID `json:"id"` + IsActive bool `json:"is_active"` +} + +func (q *Queries) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) { + row := q.db.QueryRow(ctx, updateUserStatus, arg.ID, arg.IsActive) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertAccountByUsername = `-- name: UpsertAccountByUsername :one +INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, metadata) VALUES ( $1, $2, $3, - $4::user_role, - $5, + $4, + $5::user_role, $6, $7, - $8 + $8, + $9, + '{}'::jsonb ) ON CONFLICT (username) DO UPDATE SET email = EXCLUDED.email, @@ -414,13 +526,14 @@ ON CONFLICT (username) DO UPDATE SET is_active = EXCLUDED.is_active, data_root = EXCLUDED.data_root, updated_at = now() -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpsertUserByUsernameParams struct { - Username string `json:"username"` +type UpsertAccountByUsernameParams struct { + UserID pgtype.UUID `json:"user_id"` + Username pgtype.Text `json:"username"` Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` + PasswordHash pgtype.Text `json:"password_hash"` Role string `json:"role"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` @@ -428,8 +541,9 @@ type UpsertUserByUsernameParams struct { DataRoot pgtype.Text `json:"data_root"` } -func (q *Queries) UpsertUserByUsername(ctx context.Context, arg UpsertUserByUsernameParams) (User, error) { - row := q.db.QueryRow(ctx, upsertUserByUsername, +func (q *Queries) UpsertAccountByUsername(ctx context.Context, arg UpsertAccountByUsernameParams) (User, error) { + row := q.db.QueryRow(ctx, upsertAccountByUsername, + arg.UserID, arg.Username, arg.Email, arg.PasswordHash, @@ -448,11 +562,17 @@ func (q *Queries) UpsertUserByUsername(ctx context.Context, arg UpsertUserByUser &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } diff --git a/internal/db/text.go b/internal/db/text.go new file mode 100644 index 00000000..c0da1768 --- /dev/null +++ b/internal/db/text.go @@ -0,0 +1,11 @@ +package db + +import "github.com/jackc/pgx/v5/pgtype" + +// TextToString returns the string value of pgtype.Text, or "" when invalid. +func TextToString(value pgtype.Text) string { + if !value.Valid { + return "" + } + return value.String +} diff --git a/internal/db/text_test.go b/internal/db/text_test.go new file mode 100644 index 00000000..22fe0542 --- /dev/null +++ b/internal/db/text_test.go @@ -0,0 +1,26 @@ +package db + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +func TestTextToString(t *testing.T) { + tests := []struct { + name string + value pgtype.Text + want string + }{ + {"valid", pgtype.Text{String: "hello", Valid: true}, "hello"}, + {"invalid", pgtype.Text{}, ""}, + {"valid empty", pgtype.Text{String: "", Valid: true}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TextToString(tt.value); got != tt.want { + t.Errorf("TextToString() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/db/uuid.go b/internal/db/uuid.go index a0b3dfd3..7ba824ec 100644 --- a/internal/db/uuid.go +++ b/internal/db/uuid.go @@ -21,18 +21,6 @@ func ParseUUID(id string) (pgtype.UUID, error) { return pgID, nil } -// UUIDToString converts a pgtype.UUID to its string representation. -func UUIDToString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} - // TimeFromPg converts a pgtype.Timestamptz to time.Time. func TimeFromPg(value pgtype.Timestamptz) time.Time { if value.Valid { diff --git a/internal/directory/service.go b/internal/directory/service.go deleted file mode 100644 index 158039e1..00000000 --- a/internal/directory/service.go +++ /dev/null @@ -1,226 +0,0 @@ -package directory - -import ( - "context" - "errors" - "fmt" - "log/slog" - "strings" - - "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/contacts" -) - -var ( - ErrNotFound = errors.New("directory entry not found") - ErrAmbiguous = errors.New("directory entry ambiguous") - ErrUnsupported = errors.New("directory operation unsupported") -) - -type ContactReader interface { - Search(ctx context.Context, botID, query string) ([]contacts.Contact, error) - ListByBot(ctx context.Context, botID string) ([]contacts.Contact, error) - ListChannelsByContact(ctx context.Context, contactID string) ([]contacts.ContactChannel, error) -} - -type ChannelSessionStore interface { - ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]channel.ChannelSession, error) -} - -type LocalService struct { - contacts ContactReader - sessions ChannelSessionStore - logger *slog.Logger -} - -func NewLocalService(log *slog.Logger, contacts ContactReader, sessions ChannelSessionStore) *LocalService { - if log == nil { - log = slog.Default() - } - return &LocalService{ - contacts: contacts, - sessions: sessions, - logger: log.With(slog.String("service", "directory")), - } -} - -func (s *LocalService) ListPeers(ctx context.Context, botID, platform, query string, limit int) ([]channel.DirectoryEntry, error) { - if s.contacts == nil { - return nil, fmt.Errorf("contacts service not configured") - } - trimmed := strings.TrimSpace(query) - var items []contacts.Contact - var err error - if trimmed == "" { - items, err = s.contacts.ListByBot(ctx, botID) - } else { - items, err = s.contacts.Search(ctx, botID, trimmed) - } - if err != nil { - return nil, err - } - results := make([]channel.DirectoryEntry, 0, len(items)) - for _, contact := range items { - channels, err := s.contacts.ListChannelsByContact(ctx, contact.ID) - if err != nil { - if s.logger != nil { - s.logger.Warn("list contact channels failed", slog.String("contact_id", contact.ID), slog.Any("error", err)) - } - continue - } - for _, ch := range channels { - if platform != "" && ch.Platform != platform { - continue - } - entry := channel.DirectoryEntry{ - Kind: channel.DirectoryEntryUser, - ID: strings.TrimSpace(ch.ExternalID), - Name: chooseContactName(contact, ch), - Handle: strings.TrimSpace(contact.Alias), - Metadata: map[string]any{}, - } - if entry.ID == "" { - continue - } - entry.Metadata["contact_id"] = contact.ID - if contact.UserID != "" { - entry.Metadata["user_id"] = contact.UserID - } - entry.Metadata["platform"] = ch.Platform - results = append(results, entry) - if limit > 0 && len(results) >= limit { - return results, nil - } - } - } - return results, nil -} - -func (s *LocalService) ListGroups(ctx context.Context, botID, platform, query string, limit int) ([]channel.DirectoryEntry, error) { - if s.sessions == nil { - return nil, fmt.Errorf("channel session store not configured") - } - platform = strings.TrimSpace(platform) - if platform == "" { - return nil, fmt.Errorf("platform is required") - } - sessions, err := s.sessions.ListSessionsByBotPlatform(ctx, botID, platform) - if err != nil { - return nil, err - } - trimmed := strings.TrimSpace(query) - results := make([]channel.DirectoryEntry, 0, len(sessions)) - for _, session := range sessions { - if !isGroupSession(session) { - continue - } - name := channel.ReadString(session.Metadata, "conversation_name", "name") - entryID := strings.TrimSpace(session.ReplyTarget) - if entryID == "" { - entryID = strings.TrimSpace(session.SessionID) - } - if entryID == "" { - continue - } - if trimmed != "" && !matchesQuery(trimmed, entryID, name) { - continue - } - results = append(results, channel.DirectoryEntry{ - Kind: channel.DirectoryEntryGroup, - ID: entryID, - Name: strings.TrimSpace(name), - Metadata: session.Metadata, - }) - if limit > 0 && len(results) >= limit { - return results, nil - } - } - return results, nil -} - -func (s *LocalService) ListGroupMembers(ctx context.Context, botID, platform, groupID string, limit int) ([]channel.DirectoryEntry, error) { - return nil, ErrUnsupported -} - -func (s *LocalService) ResolveTarget(ctx context.Context, botID, platform, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { - trimmed := strings.TrimSpace(input) - if trimmed == "" { - return channel.DirectoryEntry{}, ErrNotFound - } - switch kind { - case channel.DirectoryEntryGroup: - items, err := s.ListGroups(ctx, botID, platform, trimmed, 5) - if err != nil { - return channel.DirectoryEntry{}, err - } - return pickSingleMatch(items, trimmed) - default: - items, err := s.ListPeers(ctx, botID, platform, trimmed, 5) - if err != nil { - return channel.DirectoryEntry{}, err - } - return pickSingleMatch(items, trimmed) - } -} - -func pickSingleMatch(items []channel.DirectoryEntry, input string) (channel.DirectoryEntry, error) { - if len(items) == 0 { - return channel.DirectoryEntry{}, ErrNotFound - } - if len(items) == 1 { - return items[0], nil - } - lower := strings.ToLower(strings.TrimSpace(input)) - var exact *channel.DirectoryEntry - for i := range items { - if strings.ToLower(strings.TrimSpace(items[i].ID)) == lower { - exact = &items[i] - break - } - if strings.ToLower(strings.TrimSpace(items[i].Name)) == lower { - exact = &items[i] - break - } - } - if exact != nil { - return *exact, nil - } - return channel.DirectoryEntry{}, ErrAmbiguous -} - -func chooseContactName(contact contacts.Contact, ch contacts.ContactChannel) string { - if strings.TrimSpace(contact.DisplayName) != "" { - return strings.TrimSpace(contact.DisplayName) - } - if strings.TrimSpace(contact.Alias) != "" { - return strings.TrimSpace(contact.Alias) - } - if strings.TrimSpace(ch.ExternalID) != "" { - return strings.TrimSpace(ch.ExternalID) - } - return "" -} - -func isGroupSession(session channel.ChannelSession) bool { - value := strings.ToLower(strings.TrimSpace(channel.ReadString(session.Metadata, "conversation_type", "chat_type", "type"))) - if value == "" { - return false - } - if strings.Contains(value, "group") { - return true - } - return false -} - -func matchesQuery(query string, fields ...string) bool { - needle := strings.ToLower(strings.TrimSpace(query)) - if needle == "" { - return true - } - for _, field := range fields { - if strings.Contains(strings.ToLower(strings.TrimSpace(field)), needle) { - return true - } - } - return false -} diff --git a/internal/embeddings/dashscope.go b/internal/embeddings/dashscope.go index c15fc797..f16dc424 100644 --- a/internal/embeddings/dashscope.go +++ b/internal/embeddings/dashscope.go @@ -33,7 +33,7 @@ type DashScopeUsage struct { } type dashScopeRequest struct { - Model string `json:"model"` + Model string `json:"model"` Input dashScopeRequestInput `json:"input"` } diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go index 48d9d98c..f90e5cb3 100644 --- a/internal/embeddings/resolver.go +++ b/internal/embeddings/resolver.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/models" ) @@ -26,12 +27,12 @@ const ( ) type Request struct { - Type string - Provider string - Model string - Dimensions int - Input Input - UserID string + Type string + Provider string + Model string + Dimensions int + Input Input + ChannelIdentityID string } type Input struct { @@ -180,8 +181,8 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model } // If no model specified and no provider specified, try to get per-user embedding model. - if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.UserID) != "" { - modelID, err := r.loadUserEmbeddingModelID(ctx, req.UserID) + if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.ChannelIdentityID) != "" { + modelID, err := r.loadChannelIdentityEmbeddingModelID(ctx, req.ChannelIdentityID) if err != nil { return models.GetResponse{}, err } @@ -257,15 +258,15 @@ func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.L return r.queries.GetLlmProviderByID(ctx, pgID) } -func (r *Resolver) loadUserEmbeddingModelID(ctx context.Context, userID string) (string, error) { +func (r *Resolver) loadChannelIdentityEmbeddingModelID(ctx context.Context, channelIdentityID string) (string, error) { if r.queries == nil { return "", nil } - pgUserID, err := parseUUID(userID) + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { return "", err } - row, err := r.queries.GetSettingsByUserID(ctx, pgUserID) + row, err := r.queries.GetSettingsByUserID(ctx, pgChannelIdentityID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", nil @@ -275,13 +276,3 @@ func (r *Resolver) loadUserEmbeddingModelID(ctx context.Context, userID string) return strings.TrimSpace(row.EmbeddingModelID.String), nil } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index fb908109..dfb4503d 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -9,39 +9,38 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/boot" - "github.com/memohai/memoh/internal/users" ) type AuthHandler struct { - userService *users.Service - jwtSecret string - expiresIn time.Duration - logger *slog.Logger + accountService *accounts.Service + jwtSecret string + expiresIn time.Duration + logger *slog.Logger } type LoginRequest struct { - Username string `json:"username" validate:"required"` - Password string `json:"password" validate:"required"` + Username string `json:"username"` + Password string `json:"password"` } type LoginResponse struct { - AccessToken string `json:"access_token" validate:"required"` - TokenType string `json:"token_type" validate:"required"` - ExpiresAt string `json:"expires_at" validate:"required"` - UserID string `json:"user_id" validate:"required"` - Role string `json:"role" validate:"required"` - DisplayName string `json:"display_name" validate:"required"` - Username string `json:"username" validate:"required"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresAt string `json:"expires_at"` + UserID string `json:"user_id"` + Role string `json:"role"` + DisplayName string `json:"display_name"` + Username string `json:"username"` } -func NewAuthHandler(log *slog.Logger, userService *users.Service, runtimeConfig *boot.RuntimeConfig) *AuthHandler { +func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecret string, expiresIn time.Duration) *AuthHandler { return &AuthHandler{ - userService: userService, - jwtSecret: runtimeConfig.JwtSecret, - expiresIn: runtimeConfig.JwtExpiresIn, - logger: log.With(slog.String("handler", "auth")), + accountService: accountService, + jwtSecret: jwtSecret, + expiresIn: expiresIn, + logger: log.With(slog.String("handler", "auth")), } } @@ -60,7 +59,7 @@ func (h *AuthHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /auth/login [post] func (h *AuthHandler) Login(c echo.Context) error { - if h.userService == nil { + if h.accountService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "user service not configured") } if strings.TrimSpace(h.jwtSecret) == "" { @@ -79,17 +78,17 @@ func (h *AuthHandler) Login(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "username and password are required") } - user, err := h.userService.Login(c.Request().Context(), req.Username, req.Password) + account, err := h.accountService.Login(c.Request().Context(), req.Username, req.Password) if err != nil { - if errors.Is(err, users.ErrInvalidCredentials) { + if errors.Is(err, accounts.ErrInvalidCredentials) { return echo.NewHTTPError(http.StatusUnauthorized, "invalid credentials") } - if errors.Is(err, users.ErrInactiveUser) { + if errors.Is(err, accounts.ErrInactiveAccount) { return echo.NewHTTPError(http.StatusUnauthorized, "user is inactive") } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - token, expiresAt, err := auth.GenerateToken(user.ID, h.jwtSecret, h.expiresIn) + token, expiresAt, err := auth.GenerateToken(account.ID, h.jwtSecret, h.expiresIn) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -98,9 +97,9 @@ func (h *AuthHandler) Login(c echo.Context) error { AccessToken: token, TokenType: "Bearer", ExpiresAt: expiresAt.Format(time.RFC3339), - UserID: user.ID, - Username: user.Username, - Role: user.Role, - DisplayName: user.DisplayName, + UserID: account.ID, + Username: account.Username, + Role: account.Role, + DisplayName: account.DisplayName, }) } diff --git a/internal/handlers/bind.go b/internal/handlers/bind.go new file mode 100644 index 00000000..0c106e08 --- /dev/null +++ b/internal/handlers/bind.go @@ -0,0 +1,91 @@ +package handlers + +import ( + "errors" + "io" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/identity" +) + +// BindHandler manages channel identity bind code issuance via REST API. +type BindHandler struct { + service *bind.Service + logger *slog.Logger +} + +// NewBindHandler creates a BindHandler. +func NewBindHandler(log *slog.Logger, service *bind.Service) *BindHandler { + if log == nil { + log = slog.Default() + } + return &BindHandler{ + service: service, + logger: log.With(slog.String("handler", "bind")), + } +} + +// Register registers bind code routes. +func (h *BindHandler) Register(e *echo.Echo) { + e.POST("/users/me/bind_codes", h.Issue) +} + +type bindIssueRequest struct { + Platform string `json:"platform,omitempty"` + TTLSeconds int `json:"ttl_seconds,omitempty"` +} + +type bindIssueResponse struct { + Token string `json:"token"` + Platform string `json:"platform,omitempty"` + ExpiresAt time.Time `json:"expires_at"` +} + +// Issue creates a new bind code for the current user. +func (h *BindHandler) Issue(c echo.Context) error { + if h.service == nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, "bind service not available") + } + userID, err := h.requireUserID(c) + if err != nil { + return err + } + + var req bindIssueRequest + if err := c.Bind(&req); err != nil && !errors.Is(err, io.EOF) { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ttl := 24 * time.Hour + if req.TTLSeconds > 0 { + ttl = time.Duration(req.TTLSeconds) * time.Second + } + + code, err := h.service.Issue(c.Request().Context(), userID, strings.TrimSpace(req.Platform), ttl) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, bindIssueResponse{ + Token: code.Token, + Platform: code.Platform, + ExpiresAt: code.ExpiresAt, + }) +} + +func (h *BindHandler) requireUserID(c echo.Context) (string, error) { + userID, err := auth.UserIDFromContext(c) + if err != nil { + return "", err + } + if err := identity.ValidateChannelIdentityID(userID); err != nil { + return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return userID, nil +} diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index 774004c8..9f8ec869 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -23,26 +23,26 @@ func NewChannelHandler(service *channel.Service, registry *channel.Registry) *Ch func (h *ChannelHandler) Register(e *echo.Echo) { group := e.Group("/users/me/channels") - group.GET("/:platform", h.GetUserConfig) - group.PUT("/:platform", h.UpsertUserConfig) + group.GET("/:platform", h.GetChannelIdentityConfig) + group.PUT("/:platform", h.UpsertChannelIdentityConfig) metaGroup := e.Group("/channels") metaGroup.GET("", h.ListChannels) metaGroup.GET("/:platform", h.GetChannel) } -// GetUserConfig godoc +// GetChannelIdentityConfig godoc // @Summary Get channel user config // @Description Get channel binding configuration for current user // @Tags channel // @Param platform path string true "Channel platform" -// @Success 200 {object} channel.ChannelUserBinding +// @Success 200 {object} channel.ChannelIdentityBinding // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me/channels/{platform} [get] -func (h *ChannelHandler) GetUserConfig(c echo.Context) error { - userID, err := h.requireUserID(c) +func (h *ChannelHandler) GetChannelIdentityConfig(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -50,7 +50,7 @@ func (h *ChannelHandler) GetUserConfig(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.GetUserConfig(c.Request().Context(), userID, channelType) + resp, err := h.service.GetChannelIdentityConfig(c.Request().Context(), channelIdentityID, channelType) if err != nil { if strings.Contains(err.Error(), "not found") { return echo.NewHTTPError(http.StatusNotFound, err.Error()) @@ -60,18 +60,18 @@ func (h *ChannelHandler) GetUserConfig(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -// UpsertUserConfig godoc +// UpsertChannelIdentityConfig godoc // @Summary Update channel user config // @Description Update channel binding configuration for current user // @Tags channel // @Param platform path string true "Channel platform" -// @Param payload body channel.UpsertUserConfigRequest true "Channel user config payload" -// @Success 200 {object} channel.ChannelUserBinding +// @Param payload body channel.UpsertChannelIdentityConfigRequest true "Channel user config payload" +// @Success 200 {object} channel.ChannelIdentityBinding // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me/channels/{platform} [put] -func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { - userID, err := h.requireUserID(c) +func (h *ChannelHandler) UpsertChannelIdentityConfig(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -79,14 +79,14 @@ func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - var req channel.UpsertUserConfigRequest + var req channel.UpsertChannelIdentityConfigRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } if req.Config == nil { req.Config = map[string]any{} } - resp, err := h.service.UpsertUserConfig(c.Request().Context(), userID, channelType, req) + resp, err := h.service.UpsertChannelIdentityConfig(c.Request().Context(), channelIdentityID, channelType, req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -94,11 +94,11 @@ func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { } type ChannelMeta struct { - Type string `json:"type" validate:"required"` - DisplayName string `json:"display_name" validate:"required"` + Type string `json:"type"` + DisplayName string `json:"display_name"` Configless bool `json:"configless"` - Capabilities channel.ChannelCapabilities `json:"capabilities" validate:"required"` - ConfigSchema channel.ConfigSchema `json:"config_schema" validate:"required"` + Capabilities channel.ChannelCapabilities `json:"capabilities"` + ConfigSchema channel.ConfigSchema `json:"config_schema"` UserConfigSchema channel.ConfigSchema `json:"user_config_schema"` TargetSpec channel.TargetSpec `json:"target_spec"` } @@ -160,13 +160,13 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -func (h *ChannelHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go deleted file mode 100644 index 9c44b2f1..00000000 --- a/internal/handlers/chat.go +++ /dev/null @@ -1,238 +0,0 @@ -package handlers - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net/http" - "strings" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" - "github.com/memohai/memoh/internal/chat" - "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" -) - -type ChatHandler struct { - resolver *chat.Resolver - botService *bots.Service - userService *users.Service - logger *slog.Logger -} - -func NewChatHandler(log *slog.Logger, resolver *chat.Resolver, botService *bots.Service, userService *users.Service) *ChatHandler { - return &ChatHandler{ - resolver: resolver, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "chat")), - } -} - -func (h *ChatHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/chat") - group.POST("", h.Chat) - group.POST("/stream", h.StreamChat) -} - -// Chat godoc -// @Summary Chat with AI -// @Description Send a chat message and get a response. The system will automatically select an appropriate chat model from the database. -// @Tags chat -// @Accept json -// @Produce json -// @Param bot_id path string true "Bot ID" -// @Param request body chat.ChatRequest true "Chat request" -// @Success 200 {object} chat.ChatResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/chat [post] -func (h *ChatHandler) Chat(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - - var req chat.ChatRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - if req.Query == "" { - return echo.NewHTTPError(http.StatusBadRequest, "query is required") - } - req.BotID = botID - req.SessionID = sessionID - req.Token = c.Request().Header.Get("Authorization") - req.UserID = userID - if strings.TrimSpace(req.ContactID) == "" { - req.ContactID = userID - } - if strings.TrimSpace(req.ContactName) == "" { - req.ContactName = "User" - } - - resp, err := h.resolver.Chat(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - return c.JSON(http.StatusOK, resp) -} - -// StreamChat godoc -// @Summary Stream chat with AI -// @Description Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database. -// @Tags chat -// @Accept json -// @Produce text/event-stream -// @Param bot_id path string true "Bot ID" -// @Param request body chat.ChatRequest true "Chat request" -// @Success 200 {string} string -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/chat/stream [post] -func (h *ChatHandler) StreamChat(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - h.logger.Info("chat stream request received", - slog.String("bot_id", botID), - slog.String("session_id", c.QueryParam("session_id")), - slog.String("user_id", userID), - ) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - - var req chat.ChatRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - if req.Query == "" { - return echo.NewHTTPError(http.StatusBadRequest, "query is required") - } - req.BotID = botID - req.SessionID = sessionID - req.Token = c.Request().Header.Get("Authorization") - req.UserID = userID - if strings.TrimSpace(req.ContactID) == "" { - req.ContactID = userID - } - if strings.TrimSpace(req.ContactName) == "" { - req.ContactName = "User" - } - - // Set headers for SSE - c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") - c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") - c.Response().Header().Set(echo.HeaderConnection, "keep-alive") - c.Response().WriteHeader(http.StatusOK) - - // Get streaming channels - chunkChan, errChan := h.resolver.StreamChat(c.Request().Context(), req) - - // Create a flusher - flusher, ok := c.Response().Writer.(http.Flusher) - if !ok { - return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") - } - - writer := bufio.NewWriter(c.Response().Writer) - - // Stream chunks - for { - select { - case chunk, ok := <-chunkChan: - if !ok { - // Channel closed, send done message - writer.WriteString("data: [DONE]\n\n") - writer.Flush() - flusher.Flush() - return nil - } - - // Marshal chunk to JSON - data, err := json.Marshal(chunk) - if err != nil { - continue - } - - // Write SSE format - writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) - writer.Flush() - flusher.Flush() - - case err := <-errChan: - if err != nil { - h.logger.Error("chat stream failed", slog.Any("error", err)) - // Send error as SSE event - errData := map[string]string{"error": err.Error()} - data, _ := json.Marshal(errData) - writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) - writer.Flush() - flusher.Flush() - return nil - } - } - } -} - -func (h *ChatHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateUserID(userID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return userID, nil -} - -func (h *ChatHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil -} diff --git a/internal/handlers/contacts.go b/internal/handlers/contacts.go deleted file mode 100644 index bee73a31..00000000 --- a/internal/handlers/contacts.go +++ /dev/null @@ -1,183 +0,0 @@ -package handlers - -import ( - "context" - "errors" - "net/http" - "strings" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" -) - -type ContactsHandler struct { - service *contacts.Service - botService *bots.Service - userService *users.Service -} - -func NewContactsHandler(service *contacts.Service, botService *bots.Service, userService *users.Service) *ContactsHandler { - return &ContactsHandler{ - service: service, - botService: botService, - userService: userService, - } -} - -func (h *ContactsHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/contacts") - group.GET("", h.List) - group.GET("/:id", h.Get) - group.POST("", h.Create) - group.PATCH("/:id", h.Update) -} - -func (h *ContactsHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - query := strings.TrimSpace(c.QueryParam("q")) - items, err := h.service.Search(c.Request().Context(), botID, query) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, map[string]any{"items": items}) -} - -func (h *ContactsHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "contact id is required") - } - item, err := h.service.GetByID(c.Request().Context(), id) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) -} - -func (h *ContactsHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - var req contacts.CreateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - req.BotID = botID - item, err := h.service.Create(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) -} - -func (h *ContactsHandler) Update(c echo.Context) error { - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "contact id is required") - } - var req contacts.UpdateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - userID, err := h.requireUserID(c) - if err == nil { - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - item, err := h.service.Update(c.Request().Context(), id, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) - } - - sessionToken, tokenErr := auth.SessionTokenFromContext(c) - if tokenErr != nil { - return err - } - if sessionToken.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "session token mismatch") - } - if strings.TrimSpace(sessionToken.ContactID) == "" || sessionToken.ContactID != id { - return echo.NewHTTPError(http.StatusForbidden, "contact mismatch") - } - if req.Tags != nil || req.Status != nil { - return echo.NewHTTPError(http.StatusForbidden, "session token cannot update tags or status") - } - item, err := h.service.Update(c.Request().Context(), id, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) -} - -func (h *ContactsHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateUserID(userID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return userID, nil -} - -func (h *ContactsHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil -} diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index 7c736e42..8e0753bc 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -22,32 +22,35 @@ import ( "github.com/labstack/echo/v4" "github.com/opencontainers/runtime-spec/specs-go" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" + "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/mcp" - "github.com/memohai/memoh/internal/users" + "github.com/memohai/memoh/internal/policy" ) type ContainerdHandler struct { - service ctr.Service - cfg config.MCPConfig - namespace string - logger *slog.Logger - mcpMu sync.Mutex - mcpSess map[string]*mcpSession - mcpStdioMu sync.Mutex - mcpStdioSess map[string]*mcpStdioSession - botService *bots.Service - userService *users.Service - queries *dbsqlc.Queries + service ctr.Service + cfg config.MCPConfig + namespace string + logger *slog.Logger + toolGateway *mcp.ToolGatewayService + mcpMu sync.Mutex + mcpSess map[string]*mcpSession + mcpStdioMu sync.Mutex + mcpStdioSess map[string]*mcpStdioSession + botService *bots.Service + accountService *accounts.Service + policyService *policy.Service + queries *dbsqlc.Queries } type CreateContainerRequest struct { - Image string `json:"image,omitempty"` Snapshotter string `json:"snapshotter,omitempty"` } @@ -95,17 +98,18 @@ type ListSnapshotsResponse struct { Snapshots []SnapshotInfo `json:"snapshots"` } -func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.Config, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler { +func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, accountService *accounts.Service, policyService *policy.Service, queries *dbsqlc.Queries) *ContainerdHandler { return &ContainerdHandler{ - service: service, - cfg: cfg.MCP, - namespace: cfg.Containerd.Namespace, - logger: log.With(slog.String("handler", "containerd")), - mcpSess: make(map[string]*mcpSession), - mcpStdioSess: make(map[string]*mcpStdioSession), - botService: botService, - userService: userService, - queries: queries, + service: service, + cfg: cfg, + namespace: namespace, + logger: log.With(slog.String("handler", "containerd")), + mcpSess: make(map[string]*mcpSession), + mcpStdioSess: make(map[string]*mcpStdioSession), + botService: botService, + accountService: accountService, + policyService: policyService, + queries: queries, } } @@ -121,20 +125,10 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { group.GET("/skills", h.ListSkills) group.POST("/skills", h.UpsertSkills) group.DELETE("/skills", h.DeleteSkills) - group.POST("/fs-mcp", h.HandleMCPFS) - root := e.Group("/bots/:bot_id") - fs := e.Group("/bots/:bot_id/container/fs") - fs.GET("", h.ListFS) - fs.GET("/file", h.ReadFSFile) - fs.GET("/stat", h.StatFS) - fs.GET("/usage", h.UsageFS) - fs.POST("/file", h.WriteFSFile) - fs.POST("/dir", h.MkdirFS) - fs.POST("/upload", h.UploadFS) - fs.DELETE("", h.DeleteFS) root.POST("/mcp-stdio", h.CreateMCPStdio) - root.POST("/mcp-stdio/:session_id", h.HandleMCPStdio) + root.POST("/mcp-stdio/:connection_id", h.HandleMCPStdio) + root.POST("/tools", h.HandleMCPTools) } // CreateContainer godoc @@ -158,13 +152,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } containerID := mcp.ContainerPrefix + botID - image := strings.TrimSpace(req.Image) - if image == "" { - image = h.cfg.BusyboxImage - } - if image == "" { - image = config.DefaultBusyboxImg - } + image := mcp.DefaultImageRef snapshotter := strings.TrimSpace(req.Snapshotter) if snapshotter == "" { snapshotter = h.cfg.Snapshotter @@ -203,12 +191,6 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -233,13 +215,13 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } if h.queries != nil { - pgBotID, parseErr := parsePgUUID(botID) + pgBotID, parseErr := db.ParseUUID(botID) if parseErr == nil { ns := strings.TrimSpace(h.namespace) if ns == "" { ns = "default" } - _ = h.queries.UpsertContainer(c.Request().Context(), dbsqlc.UpsertContainerParams{ + if dbErr := h.queries.UpsertContainer(c.Request().Context(), dbsqlc.UpsertContainerParams{ BotID: pgBotID, ContainerID: containerID, ContainerName: containerID, @@ -249,7 +231,10 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { AutoStart: true, HostPath: pgtype.Text{String: dataDir, Valid: true}, ContainerPath: dataMount, - }) + }); dbErr != nil { + h.logger.Error("failed to upsert container record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } @@ -260,8 +245,11 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { if netErr := ctr.SetupNetwork(ctx, task, containerID); netErr == nil { started = true if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStarted(c.Request().Context(), pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStarted(c.Request().Context(), pgBotID); dbErr != nil { + h.logger.Error("failed to update container started status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } } else { @@ -286,7 +274,25 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { }) } -func (h *ContainerdHandler) ensureTaskRunning(ctx context.Context, containerID string) error { +// ensureContainerAndTask verifies the container exists in containerd and its task is +// running. If the container is missing (e.g. after a VM restart) it is recreated via +// SetupBotContainer. This prevents permanent desync between DB and containerd state. +func (h *ContainerdHandler) ensureContainerAndTask(ctx context.Context, containerID, botID string) error { + // Check whether the container exists in containerd. + _, err := h.service.GetContainer(ctx, containerID) + if err != nil { + if !errdefs.IsNotFound(err) { + return err + } + // Container gone — rebuild from scratch. + h.logger.Warn("container missing in containerd, rebuilding", + slog.String("bot_id", botID), + slog.String("container_id", containerID), + ) + return h.SetupBotContainer(ctx, botID) + } + + // Container exists — make sure the task is running. tasks, err := h.service.ListTasks(ctx, &ctr.ListTasksOptions{ Filter: "container.id==" + containerID, }) @@ -310,13 +316,13 @@ func (h *ContainerdHandler) ensureTaskRunning(ctx context.Context, containerID s _ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}) return err } - return err + return nil } // botContainerID resolves container_id for a bot from the database. func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (string, error) { if h.queries != nil { - pgBotID, err := parsePgUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err == nil { row, err := h.queries.GetContainerByBotID(ctx, pgBotID) if err == nil && strings.TrimSpace(row.ContainerID) != "" { @@ -367,7 +373,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error { ctx := c.Request().Context() if h.queries != nil { - pgBotID, parseErr := parsePgUUID(botID) + pgBotID, parseErr := db.ParseUUID(botID) if parseErr == nil { row, dbErr := h.queries.GetContainerByBotID(ctx, pgBotID) if dbErr == nil { @@ -466,12 +472,15 @@ func (h *ContainerdHandler) StartContainer(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") } - if err := h.ensureTaskRunning(ctx, containerID); err != nil { + if err := h.ensureContainerAndTask(ctx, containerID, botID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStarted(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStarted(ctx, pgBotID); dbErr != nil { + h.logger.Error("failed to update container started status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } return c.JSON(http.StatusOK, map[string]bool{"started": true}) @@ -503,8 +512,11 @@ func (h *ContainerdHandler) StopContainer(c echo.Context) error { } _ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}) if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStopped(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStopped(ctx, pgBotID); dbErr != nil { + h.logger.Error("failed to update container stopped status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } return c.JSON(http.StatusOK, map[string]bool{"stopped": true}) @@ -612,7 +624,7 @@ func (h *ContainerdHandler) ListSnapshots(c echo.Context) error { // requireBotAccess extracts bot_id from path, validates user auth, and authorizes bot access. func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return "", err } @@ -620,36 +632,45 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) { if botID == "" { return "", echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return "", err } return botID, nil } -func (h *ContainerdHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") } + if errors.Is(err, bots.ErrBotAccessDenied) && h.policyService != nil { + allowGuest, policyErr := h.policyService.AllowGuest(ctx, botID) + if policyErr == nil && allowGuest { + bot, getErr := h.botService.Get(ctx, botID) + if getErr == nil { + return bot, nil + } + } + } if errors.Is(err, bots.ErrBotAccessDenied) { return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") } @@ -662,10 +683,7 @@ func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, bot func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) error { containerID := mcp.ContainerPrefix + botID - image := strings.TrimSpace(h.cfg.BusyboxImage) - if image == "" { - image = config.DefaultBusyboxImg - } + image := mcp.DefaultImageRef snapshotter := strings.TrimSpace(h.cfg.Snapshotter) if strings.TrimSpace(h.namespace) != "" { @@ -701,12 +719,6 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -731,13 +743,13 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) } if h.queries != nil { - pgBotID, parseErr := parsePgUUID(botID) + pgBotID, parseErr := db.ParseUUID(botID) if parseErr == nil { ns := strings.TrimSpace(h.namespace) if ns == "" { ns = "default" } - _ = h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{ + if dbErr := h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{ BotID: pgBotID, ContainerID: containerID, ContainerName: containerID, @@ -747,7 +759,10 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) AutoStart: true, HostPath: pgtype.Text{String: dataDir, Valid: true}, ContainerPath: dataMount, - }) + }); dbErr != nil { + h.logger.Error("setup bot container: failed to upsert container record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } @@ -756,8 +771,11 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) }); err == nil { if netErr := ctr.SetupNetwork(ctx, task, containerID); netErr == nil { if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.UpdateContainerStarted(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.UpdateContainerStarted(ctx, pgBotID); dbErr != nil { + h.logger.Error("setup bot container: failed to update container started status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } } else { @@ -780,36 +798,62 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) // CleanupBotContainer removes the containerd container and DB record for a bot. func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID string) error { + h.logger.Info("CleanupBotContainer starting", slog.String("bot_id", botID)) containerID, err := h.botContainerID(ctx, botID) if err != nil { + h.logger.Warn("CleanupBotContainer: container not found for bot, cleaning up DB only", + slog.String("bot_id", botID), + slog.Any("error", err), + ) if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.DeleteContainerByBotID(ctx, pgBotID); dbErr != nil { + h.logger.Error("CleanupBotContainer: failed to delete DB record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } return nil } + h.logger.Info("CleanupBotContainer: found container", + slog.String("bot_id", botID), + slog.String("container_id", containerID), + ) + if task, taskErr := h.service.GetTask(ctx, containerID); taskErr == nil { + h.logger.Info("CleanupBotContainer: removing network", slog.String("container_id", containerID)) _ = ctr.RemoveNetwork(ctx, task, containerID) } + h.logger.Info("CleanupBotContainer: stopping task", slog.String("container_id", containerID)) _ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{ Timeout: 5 * time.Second, Force: true, }) + h.logger.Info("CleanupBotContainer: deleting task", slog.String("container_id", containerID)) _ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}) + h.logger.Info("CleanupBotContainer: deleting container", slog.String("container_id", containerID)) if err := h.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{ CleanupSnapshot: true, }); err != nil && !errdefs.IsNotFound(err) { + h.logger.Error("CleanupBotContainer: failed to delete container", + slog.String("container_id", containerID), + slog.Any("error", err), + ) return err } if h.queries != nil { - if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { - _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) + h.logger.Info("CleanupBotContainer: deleting container record from DB", slog.String("bot_id", botID)) + if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { + if dbErr := h.queries.DeleteContainerByBotID(ctx, pgBotID); dbErr != nil { + h.logger.Error("CleanupBotContainer: failed to delete DB record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } } + h.logger.Info("CleanupBotContainer finished", slog.String("bot_id", botID)) return nil } @@ -820,13 +864,99 @@ func (h *ContainerdHandler) isTaskRunning(ctx context.Context, containerID strin return err == nil && len(tasks) > 0 && tasks[0].Status == tasktypes.Status_RUNNING } -func parsePgUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, err +// ReconcileContainers compares the DB containers table against actual containerd +// state on startup. For each auto_start container in DB it verifies the container +// and task exist; if missing they are rebuilt via SetupBotContainer. Containers that +// the DB claims are running but are not present in containerd get corrected. +func (h *ContainerdHandler) ReconcileContainers(ctx context.Context) { + if h.queries == nil { + return } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil + rows, err := h.queries.ListAutoStartContainers(ctx) + if err != nil { + h.logger.Error("reconcile: failed to list containers from DB", slog.Any("error", err)) + return + } + if len(rows) == 0 { + h.logger.Info("reconcile: no auto-start containers in DB") + return + } + + h.logger.Info("reconcile: checking containers", slog.Int("count", len(rows))) + for _, row := range rows { + containerID := row.ContainerID + botID := uuid.UUID(row.BotID.Bytes).String() + + _, err := h.service.GetContainer(ctx, containerID) + if err != nil { + if !errdefs.IsNotFound(err) { + h.logger.Error("reconcile: failed to get container", + slog.String("container_id", containerID), slog.Any("error", err)) + continue + } + // Container missing in containerd — rebuild. + h.logger.Warn("reconcile: container missing, rebuilding", + slog.String("bot_id", botID), slog.String("container_id", containerID)) + if setupErr := h.SetupBotContainer(ctx, botID); setupErr != nil { + h.logger.Error("reconcile: rebuild failed", + slog.String("bot_id", botID), slog.Any("error", setupErr)) + if dbErr := h.queries.UpdateContainerStatus(ctx, dbsqlc.UpdateContainerStatusParams{ + Status: "error", + BotID: row.BotID, + }); dbErr != nil { + h.logger.Error("reconcile: failed to mark container as error", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } + continue + } + + // Container exists — ensure the task is running. + running := h.isTaskRunning(ctx, containerID) + if running { + if row.Status != "running" { + if dbErr := h.queries.UpdateContainerStarted(ctx, row.BotID); dbErr != nil { + h.logger.Error("reconcile: failed to update DB status to running", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } + h.logger.Info("reconcile: container healthy", + slog.String("bot_id", botID), slog.String("container_id", containerID)) + continue + } + + // Task not running — try to start it. + h.logger.Warn("reconcile: task not running, starting", + slog.String("bot_id", botID), slog.String("container_id", containerID)) + if err := h.ensureContainerAndTask(ctx, containerID, botID); err != nil { + h.logger.Error("reconcile: failed to start task", + slog.String("bot_id", botID), slog.Any("error", err)) + if dbErr := h.queries.UpdateContainerStopped(ctx, row.BotID); dbErr != nil { + h.logger.Error("reconcile: failed to mark container as stopped", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } else { + if dbErr := h.queries.UpdateContainerStarted(ctx, row.BotID); dbErr != nil { + h.logger.Error("reconcile: failed to update DB status to running", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + } + } + h.logger.Info("reconcile: completed") +} + +func (h *ContainerdHandler) ensureBotDataRoot(botID string) (string, error) { + dataRoot := strings.TrimSpace(h.cfg.DataRoot) + if dataRoot == "" { + dataRoot = config.DefaultDataRoot + } + dataRoot, err := filepath.Abs(dataRoot) + if err != nil { + return "", err + } + root := filepath.Join(dataRoot, "bots", botID) + if err := os.MkdirAll(root, 0o755); err != nil { + return "", err + } + return root, nil } diff --git a/internal/handlers/embeddings.go b/internal/handlers/embeddings.go index 4f4d411c..63b788e6 100644 --- a/internal/handlers/embeddings.go +++ b/internal/handlers/embeddings.go @@ -101,7 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { ImageURL: req.Input.ImageURL, VideoURL: req.Input.VideoURL, }, - UserID: userID, + ChannelIdentityID: userID, }) if err != nil { message := err.Error() diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go index 05b3b48f..17ef77de 100644 --- a/internal/handlers/fs.go +++ b/internal/handlers/fs.go @@ -19,82 +19,13 @@ import ( "github.com/containerd/containerd/v2/pkg/namespaces" "github.com/containerd/errdefs" "github.com/labstack/echo/v4" + sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" ctr "github.com/memohai/memoh/internal/containerd" mcptools "github.com/memohai/memoh/internal/mcp" ) -// HandleMCPFS godoc -// @Summary MCP filesystem tools (JSON-RPC) -// @Description Forwards MCP JSON-RPC requests to the MCP server inside the container. -// @Description Required: -// @Description - container task is running -// @Description - container has data mount (default /data) bound to /users/ -// @Description - container image contains the "mcp" binary -// @Description Auth: Bearer JWT is used to determine user_id (sub or user_id). -// @Description Paths must be relative (no leading slash) and must not contain "..". -// @Description -// @Description Example: tools/list -// @Description {"jsonrpc":"2.0","id":1,"method":"tools/list"} -// @Description -// @Description Example: tools/call (fs.read) -// @Description {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}} -// @Tags containerd -// @Param Authorization header string true "Bearer " -// @Param bot_id path string true "Bot ID" -// @Param payload body object true "JSON-RPC request" -// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}" -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs-mcp [post] -func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - ctx := c.Request().Context() - containerID, err := h.botContainerID(ctx, botID) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") - } - - var req mcptools.JSONRPCRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if req.JSONRPC != "" && req.JSONRPC != "2.0" { - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32600, "invalid jsonrpc version")) - } - - if err := h.validateMCPContainer(ctx, containerID, botID); err != nil { - h.logger.Error("mcp fs validate failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) - } - if err := h.ensureTaskRunning(ctx, containerID); err != nil { - h.logger.Error("mcp fs ensure task failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) - } - - if strings.TrimSpace(req.Method) == "" { - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32601, "method not found")) - } - if mcptools.IsNotification(req) { - if err := h.notifyMCPServer(ctx, containerID, req); err != nil { - h.logger.Error("mcp fs notify failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - // MCP Streamable HTTP spec: notifications must be answered with 202 Accepted and no body. - return c.NoContent(http.StatusAccepted) - } - payload, err := h.callMCPServer(ctx, containerID, req) - if err != nil { - h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID)) - return c.JSON(http.StatusOK, mcptools.JSONRPCErrorResponse(req.ID, -32603, err.Error())) - } - return c.JSON(http.StatusOK, payload) -} - func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, botID string) error { if strings.TrimSpace(botID) == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") @@ -143,16 +74,27 @@ type mcpSession struct { stdout io.ReadCloser stderr io.ReadCloser cmd *exec.Cmd - initOnce sync.Once - writeMu sync.Mutex + initMu sync.Mutex + initState mcpSessionInitState + initWait chan struct{} pendingMu sync.Mutex - pending map[string]chan mcptools.JSONRPCResponse + pending map[string]chan *sdkjsonrpc.Response + conn sdkmcp.Connection closed chan struct{} closeOnce sync.Once closeErr error onClose func() } +type mcpSessionInitState uint8 + +const ( + mcpSessionInitStateNone mcpSessionInitState = iota + mcpSessionInitStateInitializing + mcpSessionInitStateInitialized + mcpSessionInitStateReady +) + func (h *ContainerdHandler) getMCPSession(ctx context.Context, containerID string) (*mcpSession, error) { h.mcpMu.Lock() if sess, ok := h.mcpSess[containerID]; ok { @@ -201,9 +143,19 @@ func (h *ContainerdHandler) startContainerdMCPSession(ctx context.Context, conta stdin: execSession.Stdin, stdout: execSession.Stdout, stderr: execSession.Stderr, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(ctx) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(execSession.Stderr, containerID) go sess.readLoop() @@ -272,9 +224,19 @@ func (h *ContainerdHandler) startLimaMCPSession(containerID string) (*mcpSession stdout: stdout, stderr: stderr, cmd: cmd, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(context.Background()) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(stderr, containerID) go sess.readLoop() @@ -302,11 +264,20 @@ func (s *mcpSession) closeWithError(err error) { for _, ch := range s.pending { close(ch) } - s.pending = map[string]chan mcptools.JSONRPCResponse{} + s.pending = map[string]chan *sdkjsonrpc.Response{} s.pendingMu.Unlock() - _ = s.stdin.Close() - _ = s.stdout.Close() - _ = s.stderr.Close() + if s.conn != nil { + _ = s.conn.Close() + } + if s.stdin != nil { + _ = s.stdin.Close() + } + if s.stdout != nil { + _ = s.stdout.Close() + } + if s.stderr != nil { + _ = s.stderr.Close() + } if s.cmd != nil && s.cmd.Process != nil { _ = s.cmd.Process.Kill() } @@ -358,18 +329,25 @@ func (h *ContainerdHandler) mcpFIFODir() string { } func (s *mcpSession) readLoop() { - scanner := bufio.NewScanner(s.stdout) - scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { + if s.conn == nil { + s.closeWithError(io.EOF) + return + } + for { + msg, err := s.conn.Read(context.Background()) + if err != nil { + if errors.Is(err, io.EOF) { + s.closeWithError(io.EOF) + return + } + s.closeWithError(err) + return + } + resp, ok := msg.(*sdkjsonrpc.Response) + if !ok || !resp.ID.IsValid() { continue } - var resp mcptools.JSONRPCResponse - if err := json.Unmarshal([]byte(line), &resp); err != nil { - continue - } - id := strings.TrimSpace(string(resp.ID)) + id := sdkIDKey(resp.ID) if id == "" { continue } @@ -384,29 +362,43 @@ func (s *mcpSession) readLoop() { close(ch) } } - if err := scanner.Err(); err != nil { - s.closeWithError(err) - } else { - s.closeWithError(io.EOF) - } } func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { - payloads, targetID, err := mcptools.BuildPayloads(req, &s.initOnce) + method := strings.TrimSpace(req.Method) + if method == "initialize" { + return s.callInitialize(ctx, req) + } + if method != "notifications/initialized" { + if err := s.ensureInitialized(ctx); err != nil { + return nil, err + } + } + + targetID, err := parseRawJSONRPCID(req.ID) if err != nil { return nil, err } - target := strings.TrimSpace(string(targetID)) + target := sdkIDKey(targetID) if target == "" { return nil, fmt.Errorf("missing request id") } + if s.conn == nil { + return nil, io.EOF + } - respCh := make(chan mcptools.JSONRPCResponse, 1) + respCh := make(chan *sdkjsonrpc.Response, 1) s.pendingMu.Lock() s.pending[target] = respCh s.pendingMu.Unlock() - if err := s.writePayloads(payloads); err != nil { + callReq := &sdkjsonrpc.Request{ + ID: targetID, + Method: method, + Params: req.Params, + } + if err := s.conn.Write(ctx, callReq); err != nil { + s.removePending(target) return nil, err } @@ -418,46 +410,347 @@ func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map } return nil, io.EOF } - if resp.Error != nil { - return map[string]any{ - "jsonrpc": "2.0", - "id": resp.ID, - "error": map[string]any{ - "code": resp.Error.Code, - "message": resp.Error.Message, - }, - }, nil + if method == "notifications/initialized" { + s.setInitStateAtLeast(mcpSessionInitStateReady) } - return map[string]any{ - "jsonrpc": "2.0", - "id": resp.ID, - "result": resp.Result, - }, nil + return sdkResponsePayload(resp) case <-s.closed: if s.closeErr != nil { return nil, s.closeErr } return nil, io.EOF case <-ctx.Done(): + s.removePending(target) + return nil, ctx.Err() + } +} + +func (s *mcpSession) callInitialize(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { + payload, err := s.callRaw(ctx, req) + if err != nil { + return nil, err + } + if err := mcptools.PayloadError(payload); err != nil { + return payload, nil + } + s.setInitStateAtLeast(mcpSessionInitStateInitialized) + return payload, nil +} + +func (s *mcpSession) callRaw(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { + method := strings.TrimSpace(req.Method) + targetID, err := parseRawJSONRPCID(req.ID) + if err != nil { + return nil, err + } + target := sdkIDKey(targetID) + if target == "" { + return nil, fmt.Errorf("missing request id") + } + if s.conn == nil { + return nil, io.EOF + } + + respCh := make(chan *sdkjsonrpc.Response, 1) + s.pendingMu.Lock() + s.pending[target] = respCh + s.pendingMu.Unlock() + + callReq := &sdkjsonrpc.Request{ + ID: targetID, + Method: method, + Params: req.Params, + } + if err := s.conn.Write(ctx, callReq); err != nil { + s.removePending(target) + return nil, err + } + + select { + case resp, ok := <-respCh: + if !ok { + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + } + return sdkResponsePayload(resp) + case <-s.closed: + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + case <-ctx.Done(): + s.removePending(target) return nil, ctx.Err() } } func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error { - payloads, err := mcptools.BuildNotificationPayloads(req) - if err != nil { + if s.conn == nil { + return io.EOF + } + method := strings.TrimSpace(req.Method) + notification := &sdkjsonrpc.Request{ + Method: method, + Params: req.Params, + } + if err := s.conn.Write(ctx, notification); err != nil { return err } - return s.writePayloads(payloads) -} - -func (s *mcpSession) writePayloads(payloads []string) error { - s.writeMu.Lock() - defer s.writeMu.Unlock() - for _, payload := range payloads { - if _, err := s.stdin.Write([]byte(payload + "\n")); err != nil { - return err - } + if method == "notifications/initialized" { + s.setInitStateAtLeast(mcpSessionInitStateReady) } return nil } + +func (s *mcpSession) ensureInitialized(ctx context.Context) error { + for { + s.initMu.Lock() + switch s.initState { + case mcpSessionInitStateReady: + s.initMu.Unlock() + return nil + case mcpSessionInitStateInitializing: + waitCh := s.initWait + s.initMu.Unlock() + if waitCh == nil { + continue + } + select { + case <-waitCh: + continue + case <-ctx.Done(): + return ctx.Err() + case <-s.closed: + if s.closeErr != nil { + return s.closeErr + } + return io.EOF + } + case mcpSessionInitStateInitialized: + waitCh := make(chan struct{}) + s.initState = mcpSessionInitStateInitializing + s.initWait = waitCh + s.initMu.Unlock() + + err := s.sendInitializedNotification(ctx) + + s.initMu.Lock() + if err == nil { + s.initState = mcpSessionInitStateReady + } else { + s.initState = mcpSessionInitStateInitialized + } + s.initWait = nil + close(waitCh) + s.initMu.Unlock() + + if err != nil { + return err + } + return nil + default: + waitCh := make(chan struct{}) + s.initState = mcpSessionInitStateInitializing + s.initWait = waitCh + s.initMu.Unlock() + + nextState, err := s.initializeHandshake(ctx) + + s.initMu.Lock() + s.initState = nextState + s.initWait = nil + close(waitCh) + s.initMu.Unlock() + + if err != nil { + return err + } + if nextState == mcpSessionInitStateReady { + return nil + } + } + } +} + +func (s *mcpSession) initializeHandshake(ctx context.Context) (mcpSessionInitState, error) { + params, err := json.Marshal(map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{ + "roots": map[string]any{ + "listChanged": false, + }, + }, + "clientInfo": map[string]any{ + "name": "memoh-http-proxy", + "version": "v0", + }, + }) + if err != nil { + return mcpSessionInitStateNone, err + } + initID, err := sdkjsonrpc.MakeID("init-1") + if err != nil { + return mcpSessionInitStateNone, err + } + initResp, err := s.invokeCall(ctx, &sdkjsonrpc.Request{ + ID: initID, + Method: "initialize", + Params: params, + }) + if err != nil { + return mcpSessionInitStateNone, err + } + if initResp.Error != nil { + return mcpSessionInitStateNone, initResp.Error + } + if err := s.sendInitializedNotification(ctx); err != nil { + return mcpSessionInitStateInitialized, err + } + return mcpSessionInitStateReady, nil +} + +func (s *mcpSession) sendInitializedNotification(ctx context.Context) error { + if s.conn == nil { + return io.EOF + } + return s.conn.Write(ctx, &sdkjsonrpc.Request{ + Method: "notifications/initialized", + }) +} + +func (s *mcpSession) invokeCall(ctx context.Context, req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + if s.conn == nil { + return nil, io.EOF + } + if req == nil || !req.ID.IsValid() { + return nil, fmt.Errorf("missing request id") + } + key := sdkIDKey(req.ID) + if key == "" { + return nil, fmt.Errorf("invalid request id") + } + + respCh := make(chan *sdkjsonrpc.Response, 1) + s.pendingMu.Lock() + s.pending[key] = respCh + s.pendingMu.Unlock() + + if err := s.conn.Write(ctx, req); err != nil { + s.removePending(key) + return nil, err + } + + select { + case resp, ok := <-respCh: + if !ok { + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + } + return resp, nil + case <-s.closed: + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + case <-ctx.Done(): + s.removePending(key) + return nil, ctx.Err() + } +} + +func (s *mcpSession) removePending(key string) { + if strings.TrimSpace(key) == "" { + return + } + s.pendingMu.Lock() + delete(s.pending, key) + s.pendingMu.Unlock() +} + +func (s *mcpSession) setInitStateAtLeast(next mcpSessionInitState) { + s.initMu.Lock() + if s.initState != mcpSessionInitStateInitializing && s.initState < next { + s.initState = next + } + s.initMu.Unlock() +} + +func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { + if len(raw) == 0 { + return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + } + var idValue any + if err := json.Unmarshal(raw, &idValue); err != nil { + return sdkjsonrpc.ID{}, err + } + id, err := sdkjsonrpc.MakeID(idValue) + if err != nil { + return sdkjsonrpc.ID{}, err + } + if !id.IsValid() { + return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + } + return id, nil +} + +func sdkIDKey(id sdkjsonrpc.ID) string { + if !id.IsValid() { + return "" + } + raw, err := json.Marshal(id.Raw()) + if err != nil { + return "" + } + return string(raw) +} + +func sdkIDRaw(id sdkjsonrpc.ID) json.RawMessage { + if !id.IsValid() { + return nil + } + raw, err := json.Marshal(id.Raw()) + if err != nil { + return nil + } + return json.RawMessage(raw) +} + +func sdkResponsePayload(resp *sdkjsonrpc.Response) (map[string]any, error) { + if resp == nil { + return nil, io.EOF + } + if resp.Error != nil { + code := int64(-32603) + message := strings.TrimSpace(resp.Error.Error()) + if wireErr, ok := resp.Error.(*sdkjsonrpc.Error); ok { + code = wireErr.Code + message = strings.TrimSpace(wireErr.Message) + } + if message == "" { + message = "internal error" + } + return map[string]any{ + "jsonrpc": "2.0", + "id": sdkIDRaw(resp.ID), + "error": map[string]any{ + "code": code, + "message": message, + }, + }, nil + } + var result any + if len(resp.Result) > 0 { + if err := json.Unmarshal(resp.Result, &result); err != nil { + return nil, err + } + } + return map[string]any{ + "jsonrpc": "2.0", + "id": sdkIDRaw(resp.ID), + "result": result, + }, nil +} diff --git a/internal/handlers/fs_mcp_session_test.go b/internal/handlers/fs_mcp_session_test.go new file mode 100644 index 00000000..3ef000ca --- /dev/null +++ b/internal/handlers/fs_mcp_session_test.go @@ -0,0 +1,255 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sync" + "testing" + "time" + + sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" + + mcptools "github.com/memohai/memoh/internal/mcp" +) + +type fakeMCPConnection struct { + mu sync.Mutex + writes []*sdkjsonrpc.Request + readCh chan sdkjsonrpc.Message + closed chan struct{} + closeMu sync.Once + onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) +} + +func newFakeMCPConnection(onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error)) *fakeMCPConnection { + return &fakeMCPConnection{ + writes: make([]*sdkjsonrpc.Request, 0, 16), + readCh: make(chan sdkjsonrpc.Message, 32), + closed: make(chan struct{}), + onWrite: onWrite, + } +} + +func (c *fakeMCPConnection) Read(ctx context.Context) (sdkjsonrpc.Message, error) { + select { + case <-c.closed: + return nil, io.EOF + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-c.readCh: + if !ok { + return nil, io.EOF + } + return msg, nil + } +} + +func (c *fakeMCPConnection) Write(ctx context.Context, msg sdkjsonrpc.Message) error { + req, ok := msg.(*sdkjsonrpc.Request) + if !ok { + return fmt.Errorf("unsupported message type: %T", msg) + } + cloned := cloneJSONRPCRequest(req) + c.mu.Lock() + c.writes = append(c.writes, cloned) + c.mu.Unlock() + + if c.onWrite == nil { + return nil + } + resp, err := c.onWrite(cloned) + if err != nil { + return err + } + if resp == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closed: + return io.EOF + case c.readCh <- resp: + return nil + } +} + +func (c *fakeMCPConnection) Close() error { + c.closeMu.Do(func() { + close(c.closed) + close(c.readCh) + }) + return nil +} + +func (c *fakeMCPConnection) SessionID() string { + return "test-session" +} + +func cloneJSONRPCRequest(req *sdkjsonrpc.Request) *sdkjsonrpc.Request { + if req == nil { + return nil + } + params := append([]byte(nil), req.Params...) + return &sdkjsonrpc.Request{ + ID: req.ID, + Method: req.Method, + Params: params, + Extra: req.Extra, + } +} + +func jsonRPCSuccessResponse(id sdkjsonrpc.ID, payload map[string]any) *sdkjsonrpc.Response { + body, _ := json.Marshal(payload) + return &sdkjsonrpc.Response{ + ID: id, + Result: body, + } +} + +func newTestMCPSession(conn *fakeMCPConnection) *mcpSession { + return &mcpSession{ + pending: map[string]chan *sdkjsonrpc.Response{}, + conn: conn, + closed: make(chan struct{}), + } +} + +func TestMCPSessionRetriesInitializeAfterFailure(t *testing.T) { + initCalls := 0 + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + switch req.Method { + case "initialize": + initCalls++ + if initCalls == 1 { + return &sdkjsonrpc.Response{ + ID: req.ID, + Error: &sdkjsonrpc.Error{ + Code: -32603, + Message: "temporary init failure", + }, + }, nil + } + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "protocolVersion": "2025-06-18", + }), nil + case "tools/list": + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "tools": []any{}, + }), nil + default: + return nil, nil + } + }) + session := newTestMCPSession(conn) + go session.readLoop() + defer session.closeWithError(io.EOF) + + _, firstErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("1"), + Method: "tools/list", + }) + if firstErr == nil { + t.Fatalf("first call should fail when initialize fails") + } + + secondPayload, secondErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("2"), + Method: "tools/list", + }) + if secondErr != nil { + t.Fatalf("second call should recover by retrying initialize: %v", secondErr) + } + if initCalls != 2 { + t.Fatalf("initialize should be retried once, got calls: %d", initCalls) + } + result, ok := secondPayload["result"].(map[string]any) + if !ok { + t.Fatalf("missing tools/list result: %#v", secondPayload) + } + if _, ok := result["tools"].([]any); !ok { + t.Fatalf("missing tools field: %#v", result) + } +} + +func TestMCPSessionExplicitInitializeDoesNotDuplicateInitialize(t *testing.T) { + initializeCalls := 0 + initializedNotifications := 0 + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + switch req.Method { + case "initialize": + initializeCalls++ + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "protocolVersion": "2025-06-18", + }), nil + case "notifications/initialized": + initializedNotifications++ + return nil, nil + case "tools/list": + return jsonRPCSuccessResponse(req.ID, map[string]any{ + "tools": []any{}, + }), nil + default: + return nil, nil + } + }) + session := newTestMCPSession(conn) + go session.readLoop() + defer session.closeWithError(io.EOF) + + _, initErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("100"), + Method: "initialize", + Params: json.RawMessage(`{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"v1"}}`), + }) + if initErr != nil { + t.Fatalf("explicit initialize should succeed: %v", initErr) + } + + _, listErr := session.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("101"), + Method: "tools/list", + }) + if listErr != nil { + t.Fatalf("tools/list after initialize should succeed: %v", listErr) + } + if initializeCalls != 1 { + t.Fatalf("initialize should not be duplicated, got: %d", initializeCalls) + } + if initializedNotifications != 1 { + t.Fatalf("should send exactly one notifications/initialized, got: %d", initializedNotifications) + } +} + +func TestMCPSessionRemovesPendingOnContextCancel(t *testing.T) { + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + // Intentionally do not reply; caller should timeout. + return nil, nil + }) + session := newTestMCPSession(conn) + session.initState = mcpSessionInitStateReady + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + _, err := session.call(ctx, mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("200"), + Method: "tools/list", + }) + if err == nil { + t.Fatalf("call should fail on context timeout") + } + + session.pendingMu.Lock() + pendingCount := len(session.pending) + session.pendingMu.Unlock() + if pendingCount != 0 { + t.Fatalf("pending map should be empty after cancellation, got: %d", pendingCount) + } +} diff --git a/internal/handlers/fs_rest.go b/internal/handlers/fs_rest.go deleted file mode 100644 index 96e529d8..00000000 --- a/internal/handlers/fs_rest.go +++ /dev/null @@ -1,585 +0,0 @@ -package handlers - -import ( - "errors" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/config" -) - -type FSListResponse struct { - Path string `json:"path"` - Entries []FSRestEntry `json:"entries"` -} - -type FSRestEntry struct { - Path string `json:"path"` - IsDir bool `json:"is_dir"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type FSReadResponse struct { - Path string `json:"path"` - Content string `json:"content"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type FSStatResponse struct { - Path string `json:"path"` - IsDir bool `json:"is_dir"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type FSUsageResponse struct { - Path string `json:"path"` - TotalBytes int64 `json:"total_bytes"` - FileCount int64 `json:"file_count"` - DirCount int64 `json:"dir_count"` -} - -type FSWriteRequest struct { - Path string `json:"path"` - Content string `json:"content"` - Overwrite *bool `json:"overwrite"` -} - -type FSWriteResponse struct { - OK bool `json:"ok"` -} - -type FSMkdirRequest struct { - Path string `json:"path"` - Parents *bool `json:"parents"` -} - -type FSDeleteResponse struct { - OK bool `json:"ok"` -} - -// ListFS godoc -// @Summary List files for a bot -// @Description List entries under a relative path -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string false "Relative directory path" -// @Param recursive query bool false "Recursive listing" -// @Success 200 {object} FSListResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs [get] -func (h *ContainerdHandler) ListFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - recursive, err := parseBoolQuery(c, "recursive") - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), true) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - if !info.IsDir() { - return echo.NewHTTPError(http.StatusBadRequest, "path is not a directory") - } - - entries := []FSRestEntry{} - if recursive { - err = filepath.WalkDir(target, func(p string, d os.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - if p == target { - return nil - } - entryInfo, err := d.Info() - if err != nil { - return err - } - entry, err := entryForBotPath(root, p, entryInfo) - if err != nil { - return err - } - entries = append(entries, entry) - return nil - }) - } else { - dirEntries, err := os.ReadDir(target) - if err != nil { - return fsHTTPError(err) - } - for _, entry := range dirEntries { - entryInfo, err := entry.Info() - if err != nil { - return fsHTTPError(err) - } - fullPath := filepath.Join(target, entry.Name()) - fileEntry, err := entryForBotPath(root, fullPath, entryInfo) - if err != nil { - return fsHTTPError(err) - } - entries = append(entries, fileEntry) - } - } - if err != nil { - return fsHTTPError(err) - } - - listedPath := strings.TrimSpace(rel) - if listedPath == "" || listedPath == "." { - listedPath = "." - } - return c.JSON(http.StatusOK, FSListResponse{Path: listedPath, Entries: entries}) -} - -// ReadFSFile godoc -// @Summary Read file content -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string true "Relative file path" -// @Success 200 {object} FSReadResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/file [get] -func (h *ContainerdHandler) ReadFSFile(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), false) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - if info.IsDir() { - return echo.NewHTTPError(http.StatusBadRequest, "path is a directory") - } - data, err := os.ReadFile(target) - if err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSReadResponse{ - Path: rel, - Content: string(data), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }) -} - -// StatFS godoc -// @Summary Get file or directory metadata -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string true "Relative path" -// @Success 200 {object} FSStatResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/stat [get] -func (h *ContainerdHandler) StatFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), false) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSStatResponse{ - Path: rel, - IsDir: info.IsDir(), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }) -} - -// UsageFS godoc -// @Summary Get usage under a path -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string false "Relative directory path" -// @Success 200 {object} FSUsageResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/usage [get] -func (h *ContainerdHandler) UsageFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), true) - if err != nil { - return fsHTTPError(err) - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - - var totalBytes int64 - var fileCount int64 - var dirCount int64 - if info.IsDir() { - err = filepath.WalkDir(target, func(p string, d os.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - if p == target { - return nil - } - entryInfo, err := d.Info() - if err != nil { - return err - } - if entryInfo.IsDir() { - dirCount++ - return nil - } - fileCount++ - totalBytes += entryInfo.Size() - return nil - }) - if err != nil { - return fsHTTPError(err) - } - } else { - fileCount = 1 - totalBytes = info.Size() - } - - usagePath := strings.TrimSpace(rel) - if usagePath == "" || usagePath == "." { - usagePath = "." - } - return c.JSON(http.StatusOK, FSUsageResponse{ - Path: usagePath, - TotalBytes: totalBytes, - FileCount: fileCount, - DirCount: dirCount, - }) -} - -// WriteFSFile godoc -// @Summary Create or overwrite a file -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param payload body FSWriteRequest true "File write payload" -// @Success 200 {object} FSWriteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 409 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/file [post] -func (h *ContainerdHandler) WriteFSFile(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - var req FSWriteRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if strings.TrimSpace(req.Path) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "path is required") - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, _, err := resolveBotPath(root, req.Path, false) - if err != nil { - return fsHTTPError(err) - } - overwrite := true - if req.Overwrite != nil { - overwrite = *req.Overwrite - } - if _, err := os.Stat(target); err == nil && !overwrite { - return echo.NewHTTPError(http.StatusConflict, "file already exists") - } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if err := os.WriteFile(target, []byte(req.Content), 0o644); err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSWriteResponse{OK: true}) -} - -// MkdirFS godoc -// @Summary Create a directory -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param payload body FSMkdirRequest true "Directory payload" -// @Success 200 {object} FSWriteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/dir [post] -func (h *ContainerdHandler) MkdirFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - var req FSMkdirRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if strings.TrimSpace(req.Path) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "path is required") - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, _, err := resolveBotPath(root, req.Path, false) - if err != nil { - return fsHTTPError(err) - } - parents := true - if req.Parents != nil { - parents = *req.Parents - } - if parents { - if err := os.MkdirAll(target, 0o755); err != nil { - return fsHTTPError(err) - } - } else if err := os.Mkdir(target, 0o755); err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSWriteResponse{OK: true}) -} - -// UploadFS godoc -// @Summary Upload a file -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string false "Relative file path or directory" -// @Param file formData file true "File to upload" -// @Success 200 {object} FSWriteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/upload [post] -func (h *ContainerdHandler) UploadFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - file, err := c.FormFile("file") - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "file is required") - } - rawPath := strings.TrimSpace(c.FormValue("path")) - if rawPath == "" { - rawPath = strings.TrimSpace(c.QueryParam("path")) - } - if rawPath == "" { - rawPath = file.Filename - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - targetPath := rawPath - if strings.HasSuffix(rawPath, "/") || strings.HasSuffix(rawPath, string(os.PathSeparator)) { - targetPath = filepath.ToSlash(filepath.Join(rawPath, file.Filename)) - } - target, _, err := resolveBotPath(root, targetPath, false) - if err != nil { - return fsHTTPError(err) - } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - src, err := file.Open() - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - defer src.Close() - dst, err := os.Create(target) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - defer dst.Close() - if _, err := io.Copy(dst, src); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, FSWriteResponse{OK: true}) -} - -// DeleteFS godoc -// @Summary Delete a file or directory -// @Tags fs -// @Param bot_id path string true "Bot ID" -// @Param path query string true "Relative path" -// @Param recursive query bool false "Recursive delete for directories" -// @Success 200 {object} FSDeleteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs [delete] -func (h *ContainerdHandler) DeleteFS(c echo.Context) error { - botID, err := h.requireBotAccess(c) - if err != nil { - return err - } - recursive, err := parseBoolQuery(c, "recursive") - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - root, err := h.ensureBotDataRoot(botID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - target, rel, err := resolveBotPath(root, c.QueryParam("path"), false) - if err != nil { - return fsHTTPError(err) - } - if rel == "." || rel == "" { - return echo.NewHTTPError(http.StatusBadRequest, "refuse to delete root") - } - info, err := os.Stat(target) - if err != nil { - return fsHTTPError(err) - } - if info.IsDir() && recursive { - if err := os.RemoveAll(target); err != nil { - return fsHTTPError(err) - } - } else if err := os.Remove(target); err != nil { - return fsHTTPError(err) - } - return c.JSON(http.StatusOK, FSDeleteResponse{OK: true}) -} - -func (h *ContainerdHandler) ensureBotDataRoot(botID string) (string, error) { - dataRoot := strings.TrimSpace(h.cfg.DataRoot) - if dataRoot == "" { - dataRoot = config.DefaultDataRoot - } - dataRoot, err := filepath.Abs(dataRoot) - if err != nil { - return "", err - } - root := filepath.Join(dataRoot, "bots", botID) - if err := os.MkdirAll(root, 0o755); err != nil { - return "", err - } - return root, nil -} - -func resolveBotPath(root, requestPath string, allowRoot bool) (string, string, error) { - raw := strings.TrimSpace(requestPath) - if raw == "" { - if allowRoot { - return root, ".", nil - } - return "", "", os.ErrInvalid - } - clean := filepath.Clean(filepath.FromSlash(raw)) - if clean == "." || clean == "" { - if allowRoot { - return root, ".", nil - } - return "", "", os.ErrInvalid - } - if filepath.IsAbs(clean) || strings.HasPrefix(clean, "..") { - return "", "", os.ErrInvalid - } - target := filepath.Join(root, clean) - rel, err := filepath.Rel(root, target) - if err != nil || strings.HasPrefix(rel, "..") { - return "", "", os.ErrInvalid - } - return target, filepath.ToSlash(rel), nil -} - -func entryForBotPath(root, target string, info os.FileInfo) (FSRestEntry, error) { - rel, err := filepath.Rel(root, target) - if err != nil { - return FSRestEntry{}, err - } - if strings.HasPrefix(rel, "..") { - return FSRestEntry{}, os.ErrInvalid - } - if rel == "." { - rel = "" - } - return FSRestEntry{ - Path: filepath.ToSlash(rel), - IsDir: info.IsDir(), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }, nil -} - -func parseBoolQuery(c echo.Context, key string) (bool, error) { - raw := strings.TrimSpace(c.QueryParam(key)) - if raw == "" { - return false, nil - } - return strconv.ParseBool(raw) -} - -func fsHTTPError(err error) error { - if err == nil { - return nil - } - if errors.Is(err, os.ErrInvalid) { - return echo.NewHTTPError(http.StatusBadRequest, "invalid path") - } - if os.IsNotExist(err) { - return echo.NewHTTPError(http.StatusNotFound, "path not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) -} diff --git a/internal/handlers/history.go b/internal/handlers/history.go deleted file mode 100644 index d4e79c1d..00000000 --- a/internal/handlers/history.go +++ /dev/null @@ -1,259 +0,0 @@ -package handlers - -import ( - "context" - "errors" - "fmt" - "log/slog" - "net/http" - "strings" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" - "github.com/memohai/memoh/internal/history" - "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" -) - -type HistoryHandler struct { - service *history.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger -} - -func NewHistoryHandler(log *slog.Logger, service *history.Service, botService *bots.Service, userService *users.Service) *HistoryHandler { - return &HistoryHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "history")), - } -} - -func (h *HistoryHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/history") - group.POST("", h.Create) - group.GET("", h.List) - group.GET("/:id", h.Get) - group.DELETE("/:id", h.Delete) - group.DELETE("", h.DeleteAll) -} - -// Create godoc -// @Summary Create history record -// @Description Create a history record for current user -// @Tags history -// @Param bot_id path string true "Bot ID" -// @Param payload body history.CreateRequest true "History payload" -// @Success 201 {object} history.Record -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history [post] -func (h *HistoryHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - var req history.CreateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - resp, err := h.service.Create(c.Request().Context(), botID, sessionID, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusCreated, resp) -} - -// Get godoc -// @Summary Get history record -// @Description Get a history record by ID (must belong to current user) -// @Tags history -// @Param bot_id path string true "Bot ID" -// @Param id path string true "History ID" -// @Success 200 {object} history.Record -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history/{id} [get] -func (h *HistoryHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "id is required") - } - record, err := h.service.Get(c.Request().Context(), id) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, err.Error()) - } - if record.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - return c.JSON(http.StatusOK, record) -} - -// List godoc -// @Summary List history records -// @Description List history records for current user -// @Tags history -// @Param bot_id path string true "Bot ID" -// @Param limit query int false "Limit" -// @Success 200 {object} history.ListResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history [get] -func (h *HistoryHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - limit := 0 - if raw := c.QueryParam("limit"); raw != "" { - if _, err := fmt.Sscanf(raw, "%d", &limit); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid limit") - } - } - items, err := h.service.List(c.Request().Context(), botID, sessionID, limit) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, history.ListResponse{Items: items}) -} - -// Delete godoc -// @Summary Delete history record -// @Description Delete a history record by ID (must belong to current user) -// @Tags history -// @Param bot_id path string true "Bot ID" -// @Param id path string true "History ID" -// @Success 204 "No Content" -// @Failure 400 {object} ErrorResponse -// @Failure 403 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history/{id} [delete] -func (h *HistoryHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "id is required") - } - record, err := h.service.Get(c.Request().Context(), id) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, err.Error()) - } - if record.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - if err := h.service.Delete(c.Request().Context(), id); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.NoContent(http.StatusNoContent) -} - -// DeleteAll godoc -// @Summary Delete all history records -// @Description Delete all history records for current user -// @Tags history -// @Param bot_id path string true "Bot ID" -// @Success 204 "No Content" -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history [delete] -func (h *HistoryHandler) DeleteAll(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - if err := h.service.DeleteBySession(c.Request().Context(), botID, sessionID); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.NoContent(http.StatusNoContent) -} - -func (h *HistoryHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateUserID(userID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return userID, nil -} - -func (h *HistoryHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil -} \ No newline at end of file diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index 9100a014..23591bb7 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -10,52 +10,52 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/adapters/local" + "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" ) +// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history. type LocalChannelHandler struct { channelType channel.ChannelType channelManager *channel.Manager channelService *channel.Service - sessionHub *local.SessionHub + chatService *conversation.Service + routeHub *local.RouteHub botService *bots.Service - userService *users.Service + accountService *accounts.Service } -func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, userService *users.Service) *LocalChannelHandler { +// NewLocalChannelHandler creates a local channel handler. +func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, chatService *conversation.Service, routeHub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *LocalChannelHandler { return &LocalChannelHandler{ channelType: channelType, channelManager: channelManager, channelService: channelService, - sessionHub: sessionHub, + chatService: chatService, + routeHub: routeHub, botService: botService, - userService: userService, + accountService: accountService, } } +// Register registers the local channel routes. func (h *LocalChannelHandler) Register(e *echo.Echo) { prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String()) group := e.Group(prefix) - group.POST("/sessions", h.CreateSession) - group.GET("/sessions/:session_id/stream", h.StreamSession) - group.POST("/sessions/:session_id/messages", h.PostMessage) + group.GET("/stream", h.StreamMessages) + group.POST("/messages", h.PostMessage) } -type localSessionResponse struct { - SessionID string `json:"session_id"` - StreamURL string `json:"stream_url"` -} - -func (h *LocalChannelHandler) CreateSession(c echo.Context) error { - userID, err := h.requireUserID(c) +// StreamMessages streams responses for the bot route. +func (h *LocalChannelHandler) StreamMessages(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -63,41 +63,14 @@ func (h *LocalChannelHandler) CreateSession(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if h.channelService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "channel service not configured") - } - sessionID := fmt.Sprintf("%s:%s", h.channelType.String(), uuid.NewString()) - if err := h.channelService.UpsertChannelSession(c.Request().Context(), sessionID, botID, "", userID, "", h.channelType.String(), sessionID, "", nil); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - streamURL := fmt.Sprintf("/bots/%s/%s/sessions/%s/stream", botID, h.channelType.String(), sessionID) - return c.JSON(http.StatusOK, localSessionResponse{SessionID: sessionID, StreamURL: streamURL}) -} - -func (h *LocalChannelHandler) StreamSession(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { + if err := h.ensureBotParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.Param("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - if err := h.ensureSessionOwner(c.Request().Context(), botID, sessionID, userID); err != nil { - return err - } - if h.sessionHub == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "session hub not configured") + if h.routeHub == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "route hub not configured") } c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") @@ -111,7 +84,7 @@ func (h *LocalChannelHandler) StreamSession(c echo.Context) error { } writer := bufio.NewWriter(c.Response().Writer) - _, stream, cancel := h.sessionHub.Subscribe(sessionID) + _, stream, cancel := h.routeHub.Subscribe(botID) defer cancel() for { @@ -123,8 +96,8 @@ func (h *LocalChannelHandler) StreamSession(c echo.Context) error { return nil } payload := map[string]any{ - "target": msg.Target, - "message": msg.Message, + "target": msg.Target, + "event": msg.Event, } data, err := json.Marshal(payload) if err != nil { @@ -141,8 +114,9 @@ type localMessageRequest struct { Message channel.Message `json:"message"` } +// PostMessage sends a message through the local channel. func (h *LocalChannelHandler) PostMessage(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -150,14 +124,10 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - sessionID := strings.TrimSpace(c.Param("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if err := h.ensureSessionOwner(c.Request().Context(), botID, sessionID, userID); err != nil { + if err := h.ensureBotParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { return err } if h.channelManager == nil || h.channelService == nil { @@ -167,28 +137,28 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - text := strings.TrimSpace(req.Message.PlainText()) - if text == "" { + if req.Message.IsEmpty() { return echo.NewHTTPError(http.StatusBadRequest, "message is required") } cfg, err := h.channelService.ResolveEffectiveConfig(c.Request().Context(), botID, h.channelType) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + routeKey := botID msg := channel.InboundMessage{ Channel: h.channelType, Message: req.Message, BotID: botID, - ReplyTarget: sessionID, - SessionKey: sessionID, + ReplyTarget: routeKey, + RouteKey: routeKey, Sender: channel.Identity{ - ExternalID: userID, + SubjectID: channelIdentityID, Attributes: map[string]string{ - "user_id": userID, + "user_id": channelIdentityID, }, }, Conversation: channel.Conversation{ - ID: sessionID, + ID: routeKey, Type: "p2p", }, ReceivedAt: time.Now().UTC(), @@ -200,46 +170,40 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } -func (h *LocalChannelHandler) ensureSessionOwner(ctx context.Context, botID, sessionID, userID string) error { - if h.channelService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "channel service not configured") +func (h *LocalChannelHandler) ensureBotParticipant(ctx context.Context, botID, channelIdentityID string) error { + if h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } - session, err := h.channelService.GetChannelSession(ctx, sessionID) + ok, err := h.chatService.IsParticipant(ctx, botID, channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - if strings.TrimSpace(session.SessionID) == "" { - return echo.NewHTTPError(http.StatusNotFound, "session not found") - } - if session.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "session access denied") - } - if session.UserID != userID { - return echo.NewHTTPError(http.StatusForbidden, "session access denied") + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "bot access denied") } return nil } -func (h *LocalChannelHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go index 0bb4a36d..784150d4 100644 --- a/internal/handlers/mcp.go +++ b/internal/handlers/mcp.go @@ -10,26 +10,26 @@ import ( "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/mcp" - "github.com/memohai/memoh/internal/users" ) type MCPHandler struct { - service *mcp.ConnectionService - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *mcp.ConnectionService + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService *bots.Service, userService *users.Service) *MCPHandler { +func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService *bots.Service, accountService *accounts.Service) *MCPHandler { return &MCPHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "mcp")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "mcp")), } } @@ -46,7 +46,6 @@ func (h *MCPHandler) Register(e *echo.Echo) { // @Summary List MCP connections // @Description List MCP connections for a bot // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Success 200 {object} mcp.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse @@ -54,7 +53,7 @@ func (h *MCPHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp [get] func (h *MCPHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -76,7 +75,6 @@ func (h *MCPHandler) List(c echo.Context) error { // @Summary Create MCP connection // @Description Create a MCP connection for a bot // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param payload body mcp.UpsertRequest true "MCP payload" // @Success 201 {object} mcp.Connection // @Failure 400 {object} ErrorResponse @@ -85,7 +83,7 @@ func (h *MCPHandler) List(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp [post] func (h *MCPHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -111,7 +109,6 @@ func (h *MCPHandler) Create(c echo.Context) error { // @Summary Get MCP connection // @Description Get a MCP connection by ID // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param id path string true "MCP ID" // @Success 200 {object} mcp.Connection // @Failure 400 {object} ErrorResponse @@ -120,7 +117,7 @@ func (h *MCPHandler) Create(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp/{id} [get] func (h *MCPHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -149,7 +146,6 @@ func (h *MCPHandler) Get(c echo.Context) error { // @Summary Update MCP connection // @Description Update a MCP connection by ID // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param id path string true "MCP ID" // @Param payload body mcp.UpsertRequest true "MCP payload" // @Success 200 {object} mcp.Connection @@ -159,7 +155,7 @@ func (h *MCPHandler) Get(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp/{id} [put] func (h *MCPHandler) Update(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -192,7 +188,6 @@ func (h *MCPHandler) Update(c echo.Context) error { // @Summary Delete MCP connection // @Description Delete a MCP connection by ID // @Tags mcp -// @Param bot_id path string true "Bot ID" // @Param id path string true "MCP ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse @@ -201,7 +196,7 @@ func (h *MCPHandler) Update(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp/{id} [delete] func (h *MCPHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -222,26 +217,26 @@ func (h *MCPHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *MCPHandler) requireUserID(c echo.Context) (string, error) { +func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) { userID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(userID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } return userID, nil } -func (h *MCPHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *MCPHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/mcp_federation_gateway.go b/internal/handlers/mcp_federation_gateway.go new file mode 100644 index 00000000..ea9bd57e --- /dev/null +++ b/internal/handlers/mcp_federation_gateway.go @@ -0,0 +1,480 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "time" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type MCPFederationGateway struct { + handler *ContainerdHandler + logger *slog.Logger + client *http.Client +} + +func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPFederationGateway { + if log == nil { + log = slog.Default() + } + return &MCPFederationGateway{ + handler: handler, + logger: log.With(slog.String("gateway", "mcp_federation")), + client: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + session, err := g.connectStreamableSession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{}) + if err != nil { + return nil, err + } + return convertSDKTools(result.Tools), nil +} + +func (g *MCPFederationGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + session, err := g.connectStreamableSession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: strings.TrimSpace(toolName), + Arguments: args, + }) + if err != nil { + return nil, err + } + return wrapSDKToolResult(result) +} + +func (g *MCPFederationGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + session, err := g.connectSSESession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.ListTools(ctx, &sdkmcp.ListToolsParams{}) + if err != nil { + return nil, err + } + return convertSDKTools(result.Tools), nil +} + +func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + session, err := g.connectSSESession(ctx, connection) + if err != nil { + return nil, err + } + defer func() { _ = session.Close() }() + result, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: strings.TrimSpace(toolName), + Arguments: args, + }) + if err != nil { + return nil, err + } + return wrapSDKToolResult(result) +} + +func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { + url := strings.TrimSpace(anyToString(connection.Config["url"])) + if url == "" { + return nil, fmt.Errorf("http mcp url is required") + } + client := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "memoh-federation-client", + Version: "v1", + }, nil) + transport := &sdkmcp.StreamableClientTransport{ + Endpoint: url, + HTTPClient: g.connectionHTTPClient(connection), + MaxRetries: -1, + } + return client.Connect(ctx, transport, nil) +} + +func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { + endpoints := resolveSSEEndpointCandidates(connection.Config) + if len(endpoints) == 0 { + return nil, fmt.Errorf("sse mcp url is required") + } + var lastErr error + for _, endpoint := range endpoints { + client := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "memoh-federation-client", + Version: "v1", + }, nil) + transport := &sdkmcp.SSEClientTransport{ + Endpoint: endpoint, + HTTPClient: g.connectionHTTPClient(connection), + } + session, err := client.Connect(ctx, transport, nil) + if err == nil { + return session, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = fmt.Errorf("no sse endpoint candidate available") + } + return nil, fmt.Errorf("connect sse mcp failed: %w", lastErr) +} + +func resolveSSEEndpointCandidates(config map[string]any) []string { + if config == nil { + return []string{} + } + + seen := map[string]struct{}{} + out := make([]string, 0, 4) + appendEndpoint := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + out = append(out, value) + } + + for _, key := range []string{"sse_url", "sseUrl"} { + appendEndpoint(anyToString(config[key])) + } + + baseURL := strings.TrimSpace(anyToString(config["url"])) + appendEndpoint(baseURL) + + var messageURL string + for _, key := range []string{"message_url", "messageUrl"} { + if value := strings.TrimSpace(anyToString(config[key])); value != "" { + messageURL = value + break + } + } + if messageURL != "" { + normalized := strings.TrimSuffix(messageURL, "/") + if strings.HasSuffix(normalized, "/message") { + appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse") + } + appendEndpoint(messageURL) + } + + if baseURL != "" { + normalized := strings.TrimSuffix(baseURL, "/") + if strings.HasSuffix(normalized, "/message") { + appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse") + } + } + + return out +} + +func (g *MCPFederationGateway) connectionHTTPClient(connection mcpgw.Connection) *http.Client { + base := g.client + if base == nil { + base = &http.Client{Timeout: 30 * time.Second} + } + headers := normalizeHeaderMap(connection.Config["headers"]) + if len(headers) == 0 { + return base + } + transport := base.Transport + if transport == nil { + transport = http.DefaultTransport + } + return &http.Client{ + Timeout: base.Timeout, + CheckRedirect: base.CheckRedirect, + Jar: base.Jar, + Transport: &staticHeaderRoundTripper{ + next: transport, + headers: headers, + }, + } +} + +func (g *MCPFederationGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + sess, err := g.startStdioConnectionSession(ctx, botID, connection) + if err != nil { + return nil, err + } + defer sess.closeWithError(io.EOF) + + payload, err := sess.call(ctx, mcpgw.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcpgw.RawStringID("federated-stdio-tools-list"), + Method: "tools/list", + }) + if err != nil { + return nil, err + } + return parseGatewayToolsListPayload(payload) +} + +func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + sess, err := g.startStdioConnectionSession(ctx, botID, connection) + if err != nil { + return nil, err + } + defer sess.closeWithError(io.EOF) + + params, err := json.Marshal(map[string]any{ + "name": toolName, + "arguments": args, + }) + if err != nil { + return nil, err + } + return sess.call(ctx, mcpgw.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcpgw.RawStringID("federated-stdio-tools-call"), + Method: "tools/call", + Params: params, + }) +} + +func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, botID string, connection mcpgw.Connection) (*mcpSession, error) { + if g.handler == nil { + return nil, fmt.Errorf("containerd handler not configured") + } + containerID, err := g.handler.botContainerID(ctx, botID) + if err != nil { + return nil, err + } + if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { + return nil, err + } + if err := g.handler.ensureContainerAndTask(ctx, containerID, botID); err != nil { + return nil, err + } + + command := strings.TrimSpace(anyToString(connection.Config["command"])) + if command == "" { + return nil, fmt.Errorf("stdio mcp command is required") + } + request := MCPStdioRequest{ + Name: strings.TrimSpace(connection.Name), + Command: command, + Args: normalizeStringSlice(connection.Config["args"]), + Env: normalizeStringMap(connection.Config["env"]), + Cwd: strings.TrimSpace(anyToString(connection.Config["cwd"])), + } + return g.handler.startContainerdMCPCommandSession(ctx, containerID, request) +} + +func parseGatewayToolsListPayload(payload map[string]any) ([]mcpgw.ToolDescriptor, error) { + if err := mcpgw.PayloadError(payload); err != nil { + return nil, err + } + result, ok := payload["result"].(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid tools/list result") + } + rawTools, ok := result["tools"].([]any) + if !ok { + return nil, fmt.Errorf("invalid tools/list tools field") + } + tools := make([]mcpgw.ToolDescriptor, 0, len(rawTools)) + for _, rawTool := range rawTools { + item, ok := rawTool.(map[string]any) + if !ok { + continue + } + name := strings.TrimSpace(anyToString(item["name"])) + if name == "" { + continue + } + description := strings.TrimSpace(anyToString(item["description"])) + inputSchema, _ := item["inputSchema"].(map[string]any) + if inputSchema == nil { + inputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + tools = append(tools, mcpgw.ToolDescriptor{ + Name: name, + Description: description, + InputSchema: inputSchema, + }) + } + return tools, nil +} + +func convertSDKTools(items []*sdkmcp.Tool) []mcpgw.ToolDescriptor { + if len(items) == 0 { + return []mcpgw.ToolDescriptor{} + } + tools := make([]mcpgw.ToolDescriptor, 0, len(items)) + for _, item := range items { + if item == nil { + continue + } + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + tools = append(tools, mcpgw.ToolDescriptor{ + Name: name, + Description: strings.TrimSpace(item.Description), + InputSchema: normalizeToolInputSchema(item.InputSchema), + }) + } + return tools +} + +func normalizeToolInputSchema(raw any) map[string]any { + if schema, ok := raw.(map[string]any); ok && schema != nil { + return schema + } + if raw != nil { + payload, err := json.Marshal(raw) + if err == nil { + var schema map[string]any + if err := json.Unmarshal(payload, &schema); err == nil && schema != nil { + return schema + } + } + } + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func wrapSDKToolResult(result *sdkmcp.CallToolResult) (map[string]any, error) { + if result == nil { + return map[string]any{ + "result": mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), + }, nil + } + payload, err := json.Marshal(result) + if err != nil { + return nil, err + } + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err != nil { + return nil, err + } + if parsed == nil { + parsed = map[string]any{} + } + return map[string]any{"result": parsed}, nil +} + +func normalizeHeaderMap(raw any) map[string]string { + switch value := raw.(type) { + case map[string]string: + return value + case map[string]any: + out := make(map[string]string, len(value)) + for k, v := range value { + key := strings.TrimSpace(k) + val := strings.TrimSpace(anyToString(v)) + if key == "" || val == "" { + continue + } + out[key] = val + } + return out + default: + return map[string]string{} + } +} + +func normalizeStringSlice(raw any) []string { + switch value := raw.(type) { + case []string: + out := make([]string, 0, len(value)) + for _, item := range value { + item = strings.TrimSpace(item) + if item != "" { + out = append(out, item) + } + } + return out + case []any: + out := make([]string, 0, len(value)) + for _, item := range value { + val := strings.TrimSpace(anyToString(item)) + if val != "" { + out = append(out, val) + } + } + return out + default: + return []string{} + } +} + +func normalizeStringMap(raw any) map[string]string { + switch value := raw.(type) { + case map[string]string: + return value + case map[string]any: + out := make(map[string]string, len(value)) + for k, v := range value { + key := strings.TrimSpace(k) + val := strings.TrimSpace(anyToString(v)) + if key == "" { + continue + } + out[key] = val + } + return out + default: + return map[string]string{} + } +} + +func anyToString(v any) string { + if v == nil { + return "" + } + switch value := v.(type) { + case string: + return value + default: + return fmt.Sprintf("%v", v) + } +} + +type staticHeaderRoundTripper struct { + next http.RoundTripper + headers map[string]string +} + +func (t *staticHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + next := t.next + if next == nil { + next = http.DefaultTransport + } + clone := req.Clone(req.Context()) + clone.Header = req.Header.Clone() + for key, value := range t.headers { + headerKey := strings.TrimSpace(key) + headerVal := strings.TrimSpace(value) + if headerKey == "" || headerVal == "" { + continue + } + clone.Header.Set(headerKey, headerVal) + } + return next.RoundTrip(clone) +} diff --git a/internal/handlers/mcp_federation_gateway_test.go b/internal/handlers/mcp_federation_gateway_test.go new file mode 100644 index 00000000..ff453626 --- /dev/null +++ b/internal/handlers/mcp_federation_gateway_test.go @@ -0,0 +1,188 @@ +package handlers + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type testToolInput struct { + Query string `json:"query"` +} + +type testToolOutput struct { + Echo string `json:"echo"` +} + +func newTestMCPServer() *sdkmcp.Server { + server := sdkmcp.NewServer(&sdkmcp.Implementation{ + Name: "test-federation-server", + Version: "v1", + }, nil) + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "echo", + Description: "Echo query", + }, func(ctx context.Context, request *sdkmcp.CallToolRequest, input testToolInput) (*sdkmcp.CallToolResult, testToolOutput, error) { + return nil, testToolOutput{Echo: input.Query}, nil + }) + return server +} + +func withAuthHeader(next http.Handler, token string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != token { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func TestFederationGatewayHTTPConnectionViaSDK(t *testing.T) { + server := newTestMCPServer() + handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server { + return server + }, nil) + httpServer := httptest.NewServer(withAuthHeader(handler, "Bearer test-token")) + defer httpServer.Close() + + gateway := &MCPFederationGateway{ + client: httpServer.Client(), + } + connection := mcpgw.Connection{ + Config: map[string]any{ + "url": httpServer.URL, + "headers": map[string]any{ + "Authorization": "Bearer test-token", + }, + }, + } + + tools, err := gateway.ListHTTPConnectionTools(context.Background(), connection) + if err != nil { + t.Fatalf("list http tools failed: %v", err) + } + if len(tools) != 1 || tools[0].Name != "echo" { + t.Fatalf("unexpected tool list: %#v", tools) + } + + payload, err := gateway.CallHTTPConnectionTool(context.Background(), connection, "echo", map[string]any{ + "query": "hello-http", + }) + if err != nil { + t.Fatalf("call http tool failed: %v", err) + } + assertEchoResult(t, payload, "hello-http") +} + +func TestFederationGatewaySSEConnectionViaSDK(t *testing.T) { + server := newTestMCPServer() + handler := sdkmcp.NewSSEHandler(func(*http.Request) *sdkmcp.Server { + return server + }, nil) + httpServer := httptest.NewServer(withAuthHeader(handler, "Bearer test-token")) + defer httpServer.Close() + + gateway := &MCPFederationGateway{ + client: httpServer.Client(), + } + connection := mcpgw.Connection{ + Config: map[string]any{ + "url": httpServer.URL, + "headers": map[string]any{ + "Authorization": "Bearer test-token", + }, + }, + } + + tools, err := gateway.ListSSEConnectionTools(context.Background(), connection) + if err != nil { + t.Fatalf("list sse tools failed: %v", err) + } + if len(tools) != 1 || tools[0].Name != "echo" { + t.Fatalf("unexpected tool list: %#v", tools) + } + + payload, err := gateway.CallSSEConnectionTool(context.Background(), connection, "echo", map[string]any{ + "query": "hello-sse", + }) + if err != nil { + t.Fatalf("call sse tool failed: %v", err) + } + assertEchoResult(t, payload, "hello-sse") +} + +func TestResolveSSEEndpointCandidatesCompatibility(t *testing.T) { + tests := []struct { + name string + config map[string]any + contains string + firstWant string + }{ + { + name: "prefer explicit sse_url", + config: map[string]any{"sse_url": "http://example.com/custom-sse", "url": "http://example.com/sse"}, + firstWant: "http://example.com/custom-sse", + contains: "http://example.com/sse", + }, + { + name: "fallback to url as endpoint", + config: map[string]any{"url": "http://example.com/sse"}, + firstWant: "http://example.com/sse", + contains: "http://example.com/sse", + }, + { + name: "derive endpoint from message url", + config: map[string]any{"message_url": "http://example.com/message"}, + firstWant: "http://example.com/sse", + contains: "http://example.com/message", + }, + { + name: "derive endpoint from url message suffix", + config: map[string]any{"url": "http://example.com/message"}, + firstWant: "http://example.com/message", + contains: "http://example.com/sse", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resolveSSEEndpointCandidates(tt.config) + if len(got) == 0 { + t.Fatalf("resolve sse endpoints should not be empty") + } + if got[0] != tt.firstWant { + t.Fatalf("unexpected first endpoint: got=%s want=%s", got[0], tt.firstWant) + } + found := false + for _, item := range got { + if item == tt.contains { + found = true + break + } + } + if !found { + t.Fatalf("endpoint candidates missing expected value: %s in %#v", tt.contains, got) + } + }) + } +} + +func assertEchoResult(t *testing.T, payload map[string]any, expected string) { + t.Helper() + result, ok := payload["result"].(map[string]any) + if !ok { + t.Fatalf("missing result payload: %#v", payload) + } + structured, ok := result["structuredContent"].(map[string]any) + if !ok { + t.Fatalf("missing structured content: %#v", result) + } + if got := anyToString(structured["echo"]); got != expected { + t.Fatalf("unexpected echo result: got=%s want=%s", got, expected) + } +} diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go index d213a2d7..51fff854 100644 --- a/internal/handlers/mcp_stdio.go +++ b/internal/handlers/mcp_stdio.go @@ -14,6 +14,8 @@ import ( "github.com/google/uuid" "github.com/labstack/echo/v4" + sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" ctr "github.com/memohai/memoh/internal/containerd" mcptools "github.com/memohai/memoh/internal/mcp" @@ -28,9 +30,9 @@ type MCPStdioRequest struct { } type MCPStdioResponse struct { - SessionID string `json:"session_id"` - URL string `json:"url"` - Tools []string `json:"tools,omitempty"` + ConnectionID string `json:"connection_id"` + URL string `json:"url"` + Tools []string `json:"tools,omitempty"` } type mcpStdioSession struct { @@ -74,7 +76,7 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { if err := h.validateMCPContainer(ctx, containerID, botID); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.ensureTaskRunning(ctx, containerID); err != nil { + if err := h.ensureContainerAndTask(ctx, containerID, botID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -83,9 +85,9 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } tools := h.probeMCPTools(ctx, sess, botID, strings.TrimSpace(req.Name)) - sessionID := uuid.NewString() + connectionID := uuid.NewString() record := &mcpStdioSession{ - id: sessionID, + id: connectionID, botID: botID, containerID: containerID, name: strings.TrimSpace(req.Name), @@ -95,19 +97,19 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { } sess.onClose = func() { h.mcpStdioMu.Lock() - if current, ok := h.mcpStdioSess[sessionID]; ok && current == record { - delete(h.mcpStdioSess, sessionID) + if current, ok := h.mcpStdioSess[connectionID]; ok && current == record { + delete(h.mcpStdioSess, connectionID) } h.mcpStdioMu.Unlock() } h.mcpStdioMu.Lock() - h.mcpStdioSess[sessionID] = record + h.mcpStdioSess[connectionID] = record h.mcpStdioMu.Unlock() return c.JSON(http.StatusOK, MCPStdioResponse{ - SessionID: sessionID, - URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, sessionID), - Tools: tools, + ConnectionID: connectionID, + URL: fmt.Sprintf("/bots/%s/mcp-stdio/%s", botID, connectionID), + Tools: tools, }) } @@ -116,31 +118,31 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { // @Description Proxies MCP JSON-RPC requests to a stdio MCP process in the container. // @Tags containerd // @Param bot_id path string true "Bot ID" -// @Param session_id path string true "Session ID" +// @Param connection_id path string true "Connection ID" // @Param payload body object true "JSON-RPC request" // @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}" // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-stdio/{session_id} [post] +// @Router /bots/{bot_id}/mcp-stdio/{connection_id} [post] func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { return err } - sessionID := strings.TrimSpace(c.Param("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") + connectionID := strings.TrimSpace(c.Param("connection_id")) + if connectionID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "connection_id is required") } h.mcpStdioMu.Lock() - session := h.mcpStdioSess[sessionID] + session := h.mcpStdioSess[connectionID] h.mcpStdioMu.Unlock() if session == nil || session.session == nil || session.botID != botID { - return echo.NewHTTPError(http.StatusNotFound, "mcp session not found") + return echo.NewHTTPError(http.StatusNotFound, "mcp connection not found") } select { case <-session.session.closed: - return echo.NewHTTPError(http.StatusNotFound, "mcp session closed") + return echo.NewHTTPError(http.StatusNotFound, "mcp connection closed") default: } @@ -188,9 +190,19 @@ func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context stdin: execSession.Stdin, stdout: execSession.Stdout, stderr: execSession.Stderr, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(ctx) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(execSession.Stderr, containerID) go sess.readLoop() go func() { @@ -339,9 +351,19 @@ func (h *ContainerdHandler) startLimaMCPCommandSession(containerID string, req M stdout: stdout, stderr: stderr, cmd: cmd, - pending: make(map[string]chan mcptools.JSONRPCResponse), + pending: make(map[string]chan *sdkjsonrpc.Response), closed: make(chan struct{}), } + transport := &sdkmcp.IOTransport{ + Reader: sess.stdout, + Writer: sess.stdin, + } + conn, err := transport.Connect(context.Background()) + if err != nil { + sess.closeWithError(err) + return nil, err + } + sess.conn = conn h.startMCPStderrLogger(stderr, containerID) go sess.readLoop() diff --git a/internal/handlers/mcp_tools.go b/internal/handlers/mcp_tools.go new file mode 100644 index 00000000..07b93ef3 --- /dev/null +++ b/internal/handlers/mcp_tools.go @@ -0,0 +1,241 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/labstack/echo/v4" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/memohai/memoh/internal/auth" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const ( + headerChannelIdentityID = "X-Memoh-Channel-Identity-Id" + headerSessionToken = "X-Memoh-Session-Token" + headerCurrentPlatform = "X-Memoh-Current-Platform" + headerReplyTarget = "X-Memoh-Reply-Target" +) + +func (h *ContainerdHandler) SetToolGatewayService(service *mcpgw.ToolGatewayService) { + h.toolGateway = service +} + +// HandleMCPTools godoc +// @Summary Unified MCP tools gateway +// @Description MCP endpoint for tool discovery and invocation. +// @Tags containerd +// @Param bot_id path string true "Bot ID" +// @Param payload body object true "JSON-RPC request" +// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}" +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/tools [post] +func (h *ContainerdHandler) HandleMCPTools(c echo.Context) error { + if h.toolGateway == nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, "tool gateway not configured") + } + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } + return h.handleMCPToolsWithBotID(c, botID) +} + +func (h *ContainerdHandler) handleMCPToolsWithBotID(c echo.Context, botID string) error { + session := h.buildToolSessionContext(c, botID) + + req := c.Request() + ensureStreamableAcceptHeader(req) + ctx := context.WithValue(req.Context(), toolSessionContextKey{}, session) + req = req.WithContext(ctx) + + handler := sdkmcp.NewStreamableHTTPHandler( + func(r *http.Request) *sdkmcp.Server { + return h.buildToolMCPServer(r.Context()) + }, + &sdkmcp.StreamableHTTPOptions{ + Stateless: true, + JSONResponse: true, + Logger: h.logger, + }, + ) + handler.ServeHTTP(c.Response().Writer, req) + return nil +} + +func ensureStreamableAcceptHeader(req *http.Request) { + if req == nil { + return + } + acceptValues := req.Header.Values("Accept") + joined := strings.ToLower(strings.Join(acceptValues, ",")) + hasJSON := strings.Contains(joined, "application/json") || strings.Contains(joined, "application/*") || strings.Contains(joined, "*/*") + hasStream := strings.Contains(joined, "text/event-stream") || strings.Contains(joined, "text/*") || strings.Contains(joined, "*/*") + if hasJSON && hasStream { + return + } + + base := strings.TrimSpace(strings.Join(acceptValues, ",")) + parts := make([]string, 0, 3) + if base != "" { + parts = append(parts, base) + } + if !hasJSON { + parts = append(parts, "application/json") + } + if !hasStream { + parts = append(parts, "text/event-stream") + } + if len(parts) == 0 { + parts = append(parts, "application/json", "text/event-stream") + } + req.Header.Set("Accept", strings.Join(parts, ", ")) +} + +type toolSessionContextKey struct{} + +func (h *ContainerdHandler) buildToolMCPServer(ctx context.Context) *sdkmcp.Server { + if h.toolGateway == nil { + return nil + } + session, ok := ctx.Value(toolSessionContextKey{}).(mcpgw.ToolSessionContext) + if !ok { + return nil + } + + server := sdkmcp.NewServer( + &sdkmcp.Implementation{ + Name: "memoh-tools-gateway", + Version: "1.0.0", + }, + &sdkmcp.ServerOptions{ + Capabilities: &sdkmcp.ServerCapabilities{ + Tools: &sdkmcp.ToolCapabilities{ + ListChanged: false, + }, + }, + }, + ) + server.AddReceivingMiddleware(h.toolGatewayMiddleware(session)) + return server +} + +func (h *ContainerdHandler) toolGatewayMiddleware(session mcpgw.ToolSessionContext) sdkmcp.Middleware { + return func(next sdkmcp.MethodHandler) sdkmcp.MethodHandler { + return func(ctx context.Context, method string, req sdkmcp.Request) (sdkmcp.Result, error) { + switch strings.TrimSpace(method) { + case "tools/list": + tools, err := h.toolGateway.ListTools(ctx, session) + if err != nil { + return nil, err + } + return &sdkmcp.ListToolsResult{ + Tools: convertGatewayToolsToSDK(tools), + }, nil + case "tools/call": + callReq, ok := req.(*sdkmcp.ServerRequest[*sdkmcp.CallToolParamsRaw]) + if !ok || callReq == nil || callReq.Params == nil { + return nil, fmt.Errorf("tools/call params is required") + } + payload, err := buildToolCallPayloadFromRaw(callReq.Params) + if err != nil { + return nil, err + } + result, err := h.toolGateway.CallTool(ctx, session, payload) + if err != nil { + return nil, err + } + return convertGatewayCallResultToSDK(result) + default: + return next(ctx, method, req) + } + } + } +} + +func buildToolCallPayloadFromRaw(params *sdkmcp.CallToolParamsRaw) (mcpgw.ToolCallPayload, error) { + if params == nil { + return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call params is required") + } + name := strings.TrimSpace(params.Name) + if name == "" { + return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call name is required") + } + arguments := map[string]any{} + if len(params.Arguments) > 0 { + if err := json.Unmarshal(params.Arguments, &arguments); err != nil { + return mcpgw.ToolCallPayload{}, err + } + } + if arguments == nil { + arguments = map[string]any{} + } + return mcpgw.ToolCallPayload{ + Name: name, + Arguments: arguments, + }, nil +} + +func convertGatewayToolsToSDK(items []mcpgw.ToolDescriptor) []*sdkmcp.Tool { + if len(items) == 0 { + return []*sdkmcp.Tool{} + } + tools := make([]*sdkmcp.Tool, 0, len(items)) + for _, item := range items { + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + inputSchema := item.InputSchema + if inputSchema == nil { + inputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + tools = append(tools, &sdkmcp.Tool{ + Name: name, + Description: strings.TrimSpace(item.Description), + InputSchema: inputSchema, + }) + } + return tools +} + +func convertGatewayCallResultToSDK(result map[string]any) (*sdkmcp.CallToolResult, error) { + if result == nil { + result = mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}) + } + payload, err := json.Marshal(result) + if err != nil { + return nil, err + } + var out sdkmcp.CallToolResult + if err := json.Unmarshal(payload, &out); err != nil { + return nil, err + } + return &out, nil +} + +func (h *ContainerdHandler) buildToolSessionContext(c echo.Context, botID string) mcpgw.ToolSessionContext { + channelIdentityID := strings.TrimSpace(c.Request().Header.Get(headerChannelIdentityID)) + if channelIdentityID == "" { + if ctxIdentityID, err := auth.UserIDFromContext(c); err == nil { + channelIdentityID = strings.TrimSpace(ctxIdentityID) + } + } + return mcpgw.ToolSessionContext{ + BotID: strings.TrimSpace(botID), + ChatID: strings.TrimSpace(botID), + ChannelIdentityID: channelIdentityID, + SessionToken: strings.TrimSpace(c.Request().Header.Get(headerSessionToken)), + CurrentPlatform: strings.TrimSpace(c.Request().Header.Get(headerCurrentPlatform)), + ReplyTarget: strings.TrimSpace(c.Request().Header.Get(headerReplyTarget)), + } +} diff --git a/internal/handlers/mcp_tools_test.go b/internal/handlers/mcp_tools_test.go new file mode 100644 index 00000000..f9ea36b8 --- /dev/null +++ b/internal/handlers/mcp_tools_test.go @@ -0,0 +1,165 @@ +package handlers + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v4" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +func TestBuildToolCallPayloadFromRaw(t *testing.T) { + params := &sdkmcp.CallToolParamsRaw{ + Name: " tool_a ", + Arguments: json.RawMessage(`{"x":1}`), + } + payload, err := buildToolCallPayloadFromRaw(params) + if err != nil { + t.Fatalf("valid payload should parse: %v", err) + } + if payload.Name != "tool_a" { + t.Fatalf("unexpected tool name: %s", payload.Name) + } + if _, ok := payload.Arguments["x"]; !ok { + t.Fatalf("expected argument x") + } + + invalid := &sdkmcp.CallToolParamsRaw{ + Name: "", + Arguments: json.RawMessage(`{}`), + } + if _, err := buildToolCallPayloadFromRaw(invalid); err == nil { + t.Fatalf("empty tool name should fail") + } +} + +func TestHandleMCPToolsWithoutGateway(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPath("/bots/:bot_id/tools") + c.SetParamNames("bot_id") + c.SetParamValues("bot-1") + + handler := &ContainerdHandler{} + err := handler.HandleMCPTools(c) + if err == nil { + t.Fatalf("expected service unavailable error") + } + httpErr, ok := err.(*echo.HTTPError) + if !ok { + t.Fatalf("expected echo HTTP error, got %T", err) + } + if httpErr.Code != http.StatusServiceUnavailable { + t.Fatalf("unexpected status code: %d", httpErr.Code) + } +} + +type mcpToolsTestExecutor struct { + lastSession mcpgw.ToolSessionContext +} + +func (e *mcpToolsTestExecutor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + e.lastSession = session + return []mcpgw.ToolDescriptor{ + { + Name: "echo_tool", + Description: "echo input", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "input": map[string]any{"type": "string"}, + }, + }, + }, + }, nil +} + +func (e *mcpToolsTestExecutor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + e.lastSession = session + if strings.TrimSpace(toolName) != "echo_tool" { + return nil, mcpgw.ErrToolNotFound + } + return mcpgw.BuildToolSuccessResult(map[string]any{ + "ok": true, + "echo": mcpgw.StringArg(arguments, "input"), + "chat_id": session.ChatID, + "channel_identity_id": session.ChannelIdentityID, + }), nil +} + +func TestHandleMCPToolsWithGatewayAcceptCompatibility(t *testing.T) { + e := echo.New() + executor := &mcpToolsTestExecutor{} + toolGateway := mcpgw.NewToolGatewayService(slog.Default(), []mcpgw.ToolExecutor{executor}, nil) + handler := &ContainerdHandler{ + logger: slog.Default(), + toolGateway: toolGateway, + } + + listReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) + listReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + listReq.Header.Set("Accept", "application/json") + listReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") + listRec := httptest.NewRecorder() + listCtx := e.NewContext(listReq, listRec) + + if err := handler.handleMCPToolsWithBotID(listCtx, "bot-1"); err != nil { + t.Fatalf("list tools should succeed: %v", err) + } + if listRec.Code != http.StatusOK { + t.Fatalf("unexpected list status: %d body=%s", listRec.Code, listRec.Body.String()) + } + if !strings.Contains(strings.ToLower(listReq.Header.Get("Accept")), "text/event-stream") { + t.Fatalf("accept header should include text/event-stream: %s", listReq.Header.Get("Accept")) + } + + var listPayload map[string]any + if err := json.Unmarshal(listRec.Body.Bytes(), &listPayload); err != nil { + t.Fatalf("decode list payload failed: %v", err) + } + result, _ := listPayload["result"].(map[string]any) + tools, _ := result["tools"].([]any) + if len(tools) != 1 { + t.Fatalf("expected one tool, got: %#v", result["tools"]) + } + + callReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"2","method":"tools/call","params":{"name":"echo_tool","arguments":{"input":"hello"}}}`)) + callReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + callReq.Header.Set("Accept", "application/json") + callReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") + callRec := httptest.NewRecorder() + callCtx := e.NewContext(callReq, callRec) + + if err := handler.handleMCPToolsWithBotID(callCtx, "bot-1"); err != nil { + t.Fatalf("call tool should succeed: %v", err) + } + if callRec.Code != http.StatusOK { + t.Fatalf("unexpected call status: %d body=%s", callRec.Code, callRec.Body.String()) + } + + var callPayload map[string]any + if err := json.Unmarshal(callRec.Body.Bytes(), &callPayload); err != nil { + t.Fatalf("decode call payload failed: %v", err) + } + callResult, _ := callPayload["result"].(map[string]any) + structured, _ := callResult["structuredContent"].(map[string]any) + if echoValue := strings.TrimSpace(mcpgw.StringArg(structured, "echo")); echoValue != "hello" { + t.Fatalf("unexpected echo value: %#v", structured["echo"]) + } + if strings.TrimSpace(mcpgw.StringArg(structured, "chat_id")) != "bot-1" { + t.Fatalf("unexpected chat id: %#v", structured["chat_id"]) + } + if strings.TrimSpace(mcpgw.StringArg(structured, "channel_identity_id")) != "user-1" { + t.Fatalf("unexpected channel identity id: %#v", structured["channel_identity_id"]) + } +} diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index ce75f63f..ee0595d5 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -2,31 +2,32 @@ package handlers import ( "context" - "errors" - "fmt" "log/slog" "net/http" + "sort" "strings" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" + "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/memory" - "github.com/memohai/memoh/internal/users" ) +// MemoryHandler handles memory CRUD operations scoped by conversation. type MemoryHandler struct { - service *memory.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *memory.Service + chatService *conversation.Service + accountService *accounts.Service + logger *slog.Logger } type memoryAddPayload struct { Message string `json:"message,omitempty"` Messages []memory.Message `json:"messages,omitempty"` + Namespace string `json:"namespace,omitempty"` RunID string `json:"run_id,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` Filters map[string]any `json:"filters,omitempty"` @@ -43,40 +44,31 @@ type memorySearchPayload struct { EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } -type memoryEmbedUpsertPayload struct { - Type string `json:"type"` - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - Input memory.EmbedInput `json:"input"` - Source string `json:"source,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - Filters map[string]any `json:"filters,omitempty"` +// namespaceScope holds namespace + scopeId for a single memory scope. +type namespaceScope struct { + Namespace string + ScopeID string } -type memoryDeleteAllPayload struct { - RunID string `json:"run_id,omitempty"` -} +const sharedMemoryNamespace = "bot" -func NewMemoryHandler(log *slog.Logger, service *memory.Service, botService *bots.Service, userService *users.Service) *MemoryHandler { +// NewMemoryHandler creates a MemoryHandler. +func NewMemoryHandler(log *slog.Logger, service *memory.Service, chatService *conversation.Service, accountService *accounts.Service) *MemoryHandler { return &MemoryHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "memory")), + service: service, + chatService: chatService, + accountService: accountService, + logger: log.With(slog.String("handler", "memory")), } } +// Register registers chat-level memory routes. func (h *MemoryHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/memory") - group.POST("/add", h.Add) - group.POST("/embed", h.EmbedUpsert) - group.POST("/search", h.Search) - group.POST("/update", h.Update) - group.GET("/memories/:memoryId", h.Get) - group.GET("/memories", h.GetAll) - group.DELETE("/memories/:memoryId", h.Delete) - group.DELETE("/memories", h.DeleteAll) + chatGroup := e.Group("/bots/:bot_id/memory") + chatGroup.POST("", h.ChatAdd) + chatGroup.POST("/search", h.ChatSearch) + chatGroup.GET("", h.ChatGetAll) + chatGroup.DELETE("", h.ChatDeleteAll) } func (h *MemoryHandler) checkService() error { @@ -86,108 +78,52 @@ func (h *MemoryHandler) checkService() error { return nil } -// EmbedUpsert godoc -// @Summary Embed and upsert memory -// @Description Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param payload body memoryEmbedUpsertPayload true "Embed upsert request" -// @Success 200 {object} memory.EmbedUpsertResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/embed [post] -func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { +// --- Chat-level memory endpoints --- + +// ChatAdd adds memory into the bot-shared namespace. +func (h *MemoryHandler) ChatAdd(c echo.Context) error { if err := h.checkService(); err != nil { return err } - - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - - var payload memoryEmbedUpsertPayload - if err := c.Bind(&payload); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - req := memory.EmbedUpsertRequest{ - Type: payload.Type, - Provider: payload.Provider, - Model: payload.Model, - Input: payload.Input, - Source: payload.Source, - BotID: botID, - SessionID: sessionID, - RunID: payload.RunID, - Metadata: payload.Metadata, - Filters: payload.Filters, - } - - resp, err := h.service.EmbedUpsert(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - -// Add godoc -// @Summary Add memory -// @Description Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param payload body memoryAddPayload true "Add request" -// @Success 200 {object} memory.SearchResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/add [post] -func (h *MemoryHandler) Add(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } - - userID, err := h.requireUserID(c) + containerID, err := h.resolveBotContainerID(c) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { return err } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } var payload memoryAddPayload if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } + + namespace, err := normalizeSharedMemoryNamespace(payload.Namespace) + if err != nil { + return err + } + + // Resolve bot scope for shared memory. + scopeID, botID, err := h.resolveWriteScope(c.Request().Context(), containerID) + if err != nil { + return err + } + + filters := buildNamespaceFilters(namespace, scopeID, payload.Filters) req := memory.AddRequest{ Message: payload.Message, Messages: payload.Messages, BotID: botID, - SessionID: sessionID, RunID: payload.RunID, Metadata: payload.Metadata, - Filters: payload.Filters, + Filters: filters, Infer: payload.Infer, EmbeddingEnabled: payload.EmbeddingEnabled, } - resp, err := h.service.Add(c.Request().Context(), req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -195,324 +131,259 @@ func (h *MemoryHandler) Add(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -// Search godoc -// @Summary Search memories -// @Description Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param payload body memorySearchPayload true "Search request" -// @Success 200 {object} memory.SearchResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/search [post] -func (h *MemoryHandler) Search(c echo.Context) error { +// ChatSearch searches memory in the bot-shared namespace. +func (h *MemoryHandler) ChatSearch(c echo.Context) error { if err := h.checkService(); err != nil { return err } - - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + containerID, err := h.resolveBotContainerID(c) + if err != nil { return err } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { + return err } var payload memorySearchPayload if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - req := memory.SearchRequest{ - Query: payload.Query, - BotID: botID, - SessionID: sessionID, - RunID: payload.RunID, - Limit: payload.Limit, - Filters: payload.Filters, - Sources: payload.Sources, - EmbeddingEnabled: payload.EmbeddingEnabled, + + scopes, err := h.resolveEnabledScopes(c.Request().Context(), containerID) + if err != nil { + return err + } + chatObj, err := h.chatService.Get(c.Request().Context(), containerID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + botID := strings.TrimSpace(chatObj.BotID) + + // Search shared namespace and merge results. + var allResults []memory.MemoryItem + for _, scope := range scopes { + filters := buildNamespaceFilters(scope.Namespace, scope.ScopeID, payload.Filters) + if botID != "" { + filters["botId"] = botID + } + req := memory.SearchRequest{ + Query: payload.Query, + BotID: botID, + RunID: payload.RunID, + Limit: payload.Limit, + Filters: filters, + Sources: payload.Sources, + EmbeddingEnabled: payload.EmbeddingEnabled, + } + resp, err := h.service.Search(c.Request().Context(), req) + if err != nil { + h.logger.Warn("search namespace failed", slog.String("namespace", scope.Namespace), slog.Any("error", err)) + continue + } + allResults = append(allResults, resp.Results...) } - resp, err := h.service.Search(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + // Deduplicate by ID and sort by score descending. + allResults = deduplicateMemoryItems(allResults) + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].Score > allResults[j].Score + }) + if payload.Limit > 0 && len(allResults) > payload.Limit { + allResults = allResults[:payload.Limit] } - return c.JSON(http.StatusOK, resp) + + return c.JSON(http.StatusOK, memory.SearchResponse{Results: allResults}) } -// Update godoc -// @Summary Update memory -// @Description Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param payload body memory.UpdateRequest true "Update request" -// @Success 200 {object} memory.MemoryItem -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/update [post] -func (h *MemoryHandler) Update(c echo.Context) error { +// ChatGetAll lists all memories in the bot-shared namespace. +func (h *MemoryHandler) ChatGetAll(c echo.Context) error { if err := h.checkService(); err != nil { return err } - - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + containerID, err := h.resolveBotContainerID(c) + if err != nil { + return err } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { return err } - var req memory.UpdateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + scopes, err := h.resolveEnabledScopes(c.Request().Context(), containerID) + if err != nil { + return err } - if req.MemoryID != "" { - existing, err := h.service.Get(c.Request().Context(), req.MemoryID) + + var allResults []memory.MemoryItem + for _, scope := range scopes { + req := memory.GetAllRequest{ + Filters: buildNamespaceFilters(scope.Namespace, scope.ScopeID, nil), + } + resp, err := h.service.GetAll(c.Request().Context(), req) + if err != nil { + h.logger.Warn("getall namespace failed", slog.String("namespace", scope.Namespace), slog.Any("error", err)) + continue + } + allResults = append(allResults, resp.Results...) + } + allResults = deduplicateMemoryItems(allResults) + + return c.JSON(http.StatusOK, memory.SearchResponse{Results: allResults}) +} + +// ChatDeleteAll deletes all memories in the bot-shared namespace. +func (h *MemoryHandler) ChatDeleteAll(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + containerID, err := h.resolveBotContainerID(c) + if err != nil { + return err + } + if err := h.requireChatParticipant(c.Request().Context(), containerID, channelIdentityID); err != nil { + return err + } + + scopes, err := h.resolveEnabledScopes(c.Request().Context(), containerID) + if err != nil { + return err + } + + for _, scope := range scopes { + req := memory.DeleteAllRequest{ + Filters: buildNamespaceFilters(scope.Namespace, scope.ScopeID, nil), + } + if _, err := h.service.DeleteAll(c.Request().Context(), req); err != nil { + h.logger.Warn("deleteall namespace failed", slog.String("namespace", scope.Namespace), slog.Any("error", err)) + } + } + return c.JSON(http.StatusOK, memory.DeleteResponse{Message: "Memory deleted successfully!"}) +} + +// --- helpers --- + +// resolveEnabledScopes returns the bot-shared namespace scope for the conversation. +func (h *MemoryHandler) resolveEnabledScopes(ctx context.Context, chatID string) ([]namespaceScope, error) { + if h.chatService == nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") + } + chatObj, err := h.chatService.Get(ctx, chatID) + if err != nil { + return nil, echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + botID := strings.TrimSpace(chatObj.BotID) + if botID == "" { + return nil, echo.NewHTTPError(http.StatusInternalServerError, "chat bot id is empty") + } + return []namespaceScope{{ + Namespace: sharedMemoryNamespace, + ScopeID: botID, + }}, nil +} + +// resolveWriteScope returns (scopeID, botID) for shared bot memory. +func (h *MemoryHandler) resolveWriteScope(ctx context.Context, chatID string) (string, string, error) { + if h.chatService == nil { + return "", "", echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") + } + chatObj, err := h.chatService.Get(ctx, chatID) + if err != nil { + return "", "", echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + botID := strings.TrimSpace(chatObj.BotID) + if botID == "" { + return "", "", echo.NewHTTPError(http.StatusInternalServerError, "bot id is empty") + } + return botID, botID, nil +} + +func normalizeSharedMemoryNamespace(raw string) (string, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", sharedMemoryNamespace: + return sharedMemoryNamespace, nil + default: + return "", echo.NewHTTPError(http.StatusBadRequest, "invalid namespace: "+raw) + } +} + +func (h *MemoryHandler) resolveBotContainerID(c echo.Context) (string, error) { + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return "", echo.NewHTTPError(http.StatusBadRequest, "bot_id is required") + } + return botID, nil +} + +func buildNamespaceFilters(namespace, scopeID string, extra map[string]any) map[string]any { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + } + for k, v := range extra { + if k != "namespace" && k != "scopeId" { + filters[k] = v + } + } + return filters +} + +func deduplicateMemoryItems(items []memory.MemoryItem) []memory.MemoryItem { + if len(items) == 0 { + return items + } + seen := make(map[string]struct{}, len(items)) + result := make([]memory.MemoryItem, 0, len(items)) + for _, item := range items { + if _, ok := seen[item.ID]; ok { + continue + } + seen[item.ID] = struct{}{} + result = append(result, item) + } + return result +} + +func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, channelIdentityID string) error { + if h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") + } + if h.accountService != nil { + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - if existing.BotID != "" && existing.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") + if isAdmin { + return nil } } - - resp, err := h.service.Update(c.Request().Context(), req) + ok, err := h.chatService.IsParticipant(ctx, chatID, channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.JSON(http.StatusOK, resp) + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "not a chat participant") + } + return nil } -// Get godoc -// @Summary Get memory -// @Description Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param memoryId path string true "Memory ID" -// @Success 200 {object} memory.MemoryItem -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories/{memoryId} [get] -func (h *MemoryHandler) Get(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } - - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - - memoryID := c.Param("memoryId") - if memoryID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "memory ID required") - } - - resp, err := h.service.Get(c.Request().Context(), memoryID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if resp.BotID != "" && resp.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - return c.JSON(http.StatusOK, resp) -} - -// GetAll godoc -// @Summary List memories -// @Description List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param run_id query string false "Run ID" -// @Param limit query int false "Limit" -// @Success 200 {object} memory.SearchResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories [get] -func (h *MemoryHandler) GetAll(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } - - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - - req := memory.GetAllRequest{ - BotID: botID, - SessionID: sessionID, - AgentID: c.QueryParam("agent_id"), - RunID: c.QueryParam("run_id"), - } - if limit := c.QueryParam("limit"); limit != "" { - var parsed int - if _, err := fmt.Sscanf(limit, "%d", &parsed); err == nil { - req.Limit = parsed - } - } - - resp, err := h.service.GetAll(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - -// Delete godoc -// @Summary Delete memory -// @Description Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param memoryId path string true "Memory ID" -// @Success 200 {object} memory.DeleteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories/{memoryId} [delete] -func (h *MemoryHandler) Delete(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } - - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - - memoryID := c.Param("memoryId") - if memoryID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "memory ID required") - } - - existing, err := h.service.Get(c.Request().Context(), memoryID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if existing.BotID != "" && existing.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - - resp, err := h.service.Delete(c.Request().Context(), memoryID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - -// DeleteAll godoc -// @Summary Delete memories -// @Description Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param bot_id path string true "Bot ID" -// @Param payload body memoryDeleteAllPayload true "Delete all request" -// @Success 200 {object} memory.DeleteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories [delete] -func (h *MemoryHandler) DeleteAll(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } - - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - - var payload memoryDeleteAllPayload - if err := c.Bind(&payload); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - req := memory.DeleteAllRequest{ - BotID: botID, - SessionID: sessionID, - RunID: payload.RunID, - } - - resp, err := h.service.DeleteAll(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - -func (h *MemoryHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil -} - -func (h *MemoryHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil + return channelIdentityID, nil } diff --git a/internal/handlers/message.go b/internal/handlers/message.go new file mode 100644 index 00000000..31c25874 --- /dev/null +++ b/internal/handlers/message.go @@ -0,0 +1,583 @@ +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "strconv" + "strings" + "time" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/accounts" + "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/bots" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/conversation/flow" + "github.com/memohai/memoh/internal/identity" + messagepkg "github.com/memohai/memoh/internal/message" + messageevent "github.com/memohai/memoh/internal/message/event" +) + +// MessageHandler handles bot-scoped messaging endpoints. +type MessageHandler struct { + runner flow.Runner + conversationService conversation.Accessor + messageService messagepkg.Service + messageEvents messageevent.Subscriber + botService *bots.Service + accountService *accounts.Service + channelIdentitySvc *identities.Service + logger *slog.Logger +} + +// NewMessageHandler creates a MessageHandler. +func NewMessageHandler(log *slog.Logger, runner flow.Runner, conversationService conversation.Accessor, messageService messagepkg.Service, botService *bots.Service, accountService *accounts.Service, channelIdentitySvc *identities.Service, eventSubscribers ...messageevent.Subscriber) *MessageHandler { + var messageEvents messageevent.Subscriber + if len(eventSubscribers) > 0 { + messageEvents = eventSubscribers[0] + } + return &MessageHandler{ + runner: runner, + conversationService: conversationService, + messageService: messageService, + messageEvents: messageEvents, + botService: botService, + accountService: accountService, + channelIdentitySvc: channelIdentitySvc, + logger: log.With(slog.String("handler", "conversation")), + } +} + +// Register registers all conversation routes. +func (h *MessageHandler) Register(e *echo.Echo) { + // Bot-scoped message container (single shared history per bot). + botGroup := e.Group("/bots/:bot_id") + botGroup.POST("/messages", h.SendMessage) + botGroup.POST("/messages/stream", h.StreamMessage) + botGroup.GET("/messages", h.ListMessages) + botGroup.GET("/messages/events", h.StreamMessageEvents) + botGroup.DELETE("/messages", h.DeleteMessages) +} + +// --- Messages --- + +// SendMessage sends a synchronous conversation message. +func (h *MessageHandler) SendMessage(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + + var req flow.ChatRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if req.Query == "" { + return echo.NewHTTPError(http.StatusBadRequest, "query is required") + } + req.BotID = botID + req.ChatID = botID + req.Token = c.Request().Header.Get("Authorization") + req.UserID = channelIdentityID + req.SourceChannelIdentityID = channelIdentityID + if strings.TrimSpace(req.CurrentChannel) == "" { + req.CurrentChannel = "web" + } + if len(req.Channels) == 0 { + req.Channels = []string{req.CurrentChannel} + } + channelIdentityID = h.resolveWebChannelIdentity(c.Request().Context(), channelIdentityID, &req) + + if h.runner == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation runner not configured") + } + resp, err := h.runner.Chat(c.Request().Context(), req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// StreamMessage sends a streaming conversation message. +func (h *MessageHandler) StreamMessage(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireParticipant(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + + var req flow.ChatRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if req.Query == "" { + return echo.NewHTTPError(http.StatusBadRequest, "query is required") + } + req.BotID = botID + req.ChatID = botID + req.Token = c.Request().Header.Get("Authorization") + req.UserID = channelIdentityID + req.SourceChannelIdentityID = channelIdentityID + if strings.TrimSpace(req.CurrentChannel) == "" { + req.CurrentChannel = "web" + } + if len(req.Channels) == 0 { + req.Channels = []string{req.CurrentChannel} + } + channelIdentityID = h.resolveWebChannelIdentity(c.Request().Context(), channelIdentityID, &req) + + if h.runner == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation runner not configured") + } + c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") + c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") + c.Response().Header().Set(echo.HeaderConnection, "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + chunkChan, errChan := h.runner.StreamChat(c.Request().Context(), req) + flusher, ok := c.Response().Writer.(http.Flusher) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") + } + writer := bufio.NewWriter(c.Response().Writer) + processingState := "started" + if err := writeSSEJSON(writer, flusher, map[string]string{"type": "processing_started"}); err != nil { + return nil + } + + for { + select { + case chunk, ok := <-chunkChan: + if !ok { + if processingState == "started" { + processingState = "completed" + if err := writeSSEJSON(writer, flusher, map[string]string{"type": "processing_completed"}); err != nil { + return nil + } + } + if err := writeSSEData(writer, flusher, "[DONE]"); err != nil { + return nil + } + return nil + } + if processingState == "started" { + processingState = "completed" + if err := writeSSEJSON(writer, flusher, map[string]string{"type": "processing_completed"}); err != nil { + return nil + } + } + if err := writeSSEData(writer, flusher, string(chunk)); err != nil { + return nil + } + case err := <-errChan: + if err != nil { + h.logger.Error("conversation stream failed", slog.Any("error", err)) + if processingState == "started" { + processingState = "failed" + _ = writeSSEJSON(writer, flusher, map[string]string{ + "type": "processing_failed", + "error": err.Error(), + }) + } + errData := map[string]string{ + "type": "error", + "error": err.Error(), + "message": err.Error(), + } + if writeErr := writeSSEJSON(writer, flusher, errData); writeErr != nil { + return nil + } + return nil + } + } + } +} + +func writeSSEData(writer *bufio.Writer, flusher http.Flusher, payload string) error { + if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", payload)); err != nil { + return err + } + if err := writer.Flush(); err != nil { + return err + } + flusher.Flush() + return nil +} + +func writeSSEJSON(writer *bufio.Writer, flusher http.Flusher, payload any) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + return writeSSEData(writer, flusher, string(data)) +} + +func parseSinceParam(raw string) (time.Time, bool, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return time.Time{}, false, nil + } + layouts := []string{time.RFC3339Nano, time.RFC3339} + for _, layout := range layouts { + parsed, err := time.Parse(layout, trimmed) + if err == nil { + return parsed.UTC(), true, nil + } + } + if epochMillis, err := strconv.ParseInt(trimmed, 10, 64); err == nil { + return time.UnixMilli(epochMillis).UTC(), true, nil + } + return time.Time{}, false, fmt.Errorf("invalid since parameter") +} + +// ListMessages lists messages for a conversation with optional pagination. +// Query: limit (default 30), before (optional ISO8601 or unix ms) for older messages. +// Returns items in ascending created_at order (oldest first). +func (h *MessageHandler) ListMessages(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireReadable(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + + if h.messageService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured") + } + + limit := int32(30) + if s := strings.TrimSpace(c.QueryParam("limit")); s != "" { + if n, err := strconv.ParseInt(s, 10, 32); err == nil && n > 0 && n <= 100 { + limit = int32(n) + } + } + + before, hasBefore := parseBeforeParam(c.QueryParam("before")) + + var messages []messagepkg.Message + if hasBefore { + messages, err = h.messageService.ListBefore(c.Request().Context(), botID, before, limit) + } else { + messages, err = h.messageService.ListLatest(c.Request().Context(), botID, limit) + if err == nil { + reverseMessages(messages) + } + } + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, map[string]any{"items": messages}) +} + +func parseBeforeParam(s string) (time.Time, bool) { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return time.Time{}, false + } + if t, err := time.Parse(time.RFC3339Nano, trimmed); err == nil { + return t.UTC(), true + } + if t, err := time.Parse(time.RFC3339, trimmed); err == nil { + return t.UTC(), true + } + if epochMillis, err := strconv.ParseInt(trimmed, 10, 64); err == nil { + return time.UnixMilli(epochMillis).UTC(), true + } + return time.Time{}, false +} + +func reverseMessages(m []messagepkg.Message) { + for i, j := 0, len(m)-1; i < j; i, j = i+1, j-1 { + m[i], m[j] = m[j], m[i] + } +} + +// StreamMessageEvents streams bot-scoped message events to clients. +func (h *MessageHandler) StreamMessageEvents(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if err := h.requireReadable(c.Request().Context(), botID, channelIdentityID); err != nil { + return err + } + if h.messageService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured") + } + if h.messageEvents == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message events not configured") + } + + since, hasSince, err := parseSinceParam(c.QueryParam("since")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") + c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") + c.Response().Header().Set(echo.HeaderConnection, "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + flusher, ok := c.Response().Writer.(http.Flusher) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") + } + writer := bufio.NewWriter(c.Response().Writer) + + sentMessageIDs := map[string]struct{}{} + writeCreatedEvent := func(message messagepkg.Message) error { + msgID := strings.TrimSpace(message.ID) + if msgID != "" { + if _, exists := sentMessageIDs[msgID]; exists { + return nil + } + sentMessageIDs[msgID] = struct{}{} + } + return writeSSEJSON(writer, flusher, map[string]any{ + "type": string(messageevent.EventTypeMessageCreated), + "bot_id": botID, + "message": message, + }) + } + + _, stream, cancel := h.messageEvents.Subscribe(botID, 128) + defer cancel() + + if hasSince { + backlog, err := h.messageService.ListSince(c.Request().Context(), botID, since) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + for _, message := range backlog { + if err := writeCreatedEvent(message); err != nil { + return nil + } + } + } + + heartbeatTicker := time.NewTicker(20 * time.Second) + defer heartbeatTicker.Stop() + + for { + select { + case <-c.Request().Context().Done(): + return nil + case <-heartbeatTicker.C: + if err := writeSSEJSON(writer, flusher, map[string]any{"type": "ping"}); err != nil { + return nil + } + case event, ok := <-stream: + if !ok { + return nil + } + if strings.TrimSpace(event.BotID) != botID { + continue + } + if event.Type != messageevent.EventTypeMessageCreated { + continue + } + if len(event.Data) == 0 { + continue + } + var message messagepkg.Message + if err := json.Unmarshal(event.Data, &message); err != nil { + h.logger.Warn("decode message event failed", slog.Any("error", err)) + continue + } + if err := writeCreatedEvent(message); err != nil { + return nil + } + } + } +} + +// DeleteMessages clears all persisted bot-level history messages. +func (h *MessageHandler) DeleteMessages(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotManage(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + if h.messageService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured") + } + if err := h.messageService.DeleteByBot(c.Request().Context(), botID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// --- helpers --- + +// resolveWebChannelIdentity resolves (web, user_id) to a channel identity and sets req.SourceChannelIdentityID. +// Web uses user_id as the channel subject id (like Feishu open_id); the resolved ci has display_name and is linked to the user. +// Returns the channel_identity_id to use for the rest of the flow, or the original userID if resolution is skipped/fails. +func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID string, req *flow.ChatRequest) string { + if strings.TrimSpace(req.CurrentChannel) != "web" || h.channelIdentitySvc == nil || strings.TrimSpace(userID) == "" { + return userID + } + displayName := "" + if h.accountService != nil { + if account, err := h.accountService.Get(ctx, userID); err == nil { + displayName = strings.TrimSpace(account.DisplayName) + if displayName == "" { + displayName = strings.TrimSpace(account.Username) + } + } + } + ci, err := h.channelIdentitySvc.ResolveByChannelIdentity(ctx, "web", userID, displayName) + if err != nil { + return userID + } + _ = h.channelIdentitySvc.LinkChannelIdentityToUser(ctx, ci.ID, userID) + req.SourceChannelIdentityID = ci.ID + return ci.ID +} + +func (h *MessageHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) + if err != nil { + return "", err + } + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { + return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return channelIdentityID, nil +} + +func (h *MessageHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} + +func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot management access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} + +func (h *MessageHandler) requireParticipant(ctx context.Context, conversationID, channelIdentityID string) error { + if h.conversationService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation service not configured") + } + // Admin bypass. + if h.accountService != nil { + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if isAdmin { + return nil + } + } + ok, err := h.conversationService.IsParticipant(ctx, conversationID, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "not a participant") + } + return nil +} + +func (h *MessageHandler) requireReadable(ctx context.Context, conversationID, channelIdentityID string) error { + if h.conversationService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "conversation service not configured") + } + // Admin bypass. + if h.accountService != nil { + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if isAdmin { + return nil + } + } + _, err := h.conversationService.GetReadAccess(ctx, conversationID, channelIdentityID) + if err != nil { + if errors.Is(err, conversation.ErrPermissionDenied) { + return echo.NewHTTPError(http.StatusForbidden, "not allowed to read conversation") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return nil +} diff --git a/internal/handlers/preauth.go b/internal/handlers/preauth.go index 4b0c965b..2f5ed413 100644 --- a/internal/handlers/preauth.go +++ b/internal/handlers/preauth.go @@ -9,24 +9,24 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/preauth" - "github.com/memohai/memoh/internal/users" ) type PreauthHandler struct { - service *preauth.Service - botService *bots.Service - userService *users.Service + service *preauth.Service + botService *bots.Service + accountService *accounts.Service } -func NewPreauthHandler(service *preauth.Service, botService *bots.Service, userService *users.Service) *PreauthHandler { +func NewPreauthHandler(service *preauth.Service, botService *bots.Service, accountService *accounts.Service) *PreauthHandler { return &PreauthHandler{ - service: service, - botService: botService, - userService: userService, + service: service, + botService: botService, + accountService: accountService, } } @@ -71,21 +71,21 @@ func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) { if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(userID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } return userID, nil } -func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, userID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/schedule.go b/internal/handlers/schedule.go index 11511a2e..42643408 100644 --- a/internal/handlers/schedule.go +++ b/internal/handlers/schedule.go @@ -9,26 +9,26 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/schedule" - "github.com/memohai/memoh/internal/users" ) type ScheduleHandler struct { - service *schedule.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *schedule.Service + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewScheduleHandler(log *slog.Logger, service *schedule.Service, botService *bots.Service, userService *users.Service) *ScheduleHandler { +func NewScheduleHandler(log *slog.Logger, service *schedule.Service, botService *bots.Service, accountService *accounts.Service) *ScheduleHandler { return &ScheduleHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "schedule")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "schedule")), } } @@ -45,7 +45,6 @@ func (h *ScheduleHandler) Register(e *echo.Echo) { // @Summary Create schedule // @Description Create a schedule for current user // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param payload body schedule.CreateRequest true "Schedule payload" // @Success 201 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse @@ -78,7 +77,6 @@ func (h *ScheduleHandler) Create(c echo.Context) error { // @Summary List schedules // @Description List schedules for current user // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Success 200 {object} schedule.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -106,7 +104,6 @@ func (h *ScheduleHandler) List(c echo.Context) error { // @Summary Get schedule // @Description Get a schedule by ID // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param id path string true "Schedule ID" // @Success 200 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse @@ -143,7 +140,6 @@ func (h *ScheduleHandler) Get(c echo.Context) error { // @Summary Update schedule // @Description Update a schedule by ID // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param id path string true "Schedule ID" // @Param payload body schedule.UpdateRequest true "Schedule payload" // @Success 200 {object} schedule.Schedule @@ -188,7 +184,6 @@ func (h *ScheduleHandler) Update(c echo.Context) error { // @Summary Delete schedule // @Description Delete a schedule by ID // @Tags schedule -// @Param bot_id path string true "Bot ID" // @Param id path string true "Schedule ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse @@ -228,21 +223,21 @@ func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) { if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(userID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } return userID, nil } -func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, userID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -253,4 +248,4 @@ func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} \ No newline at end of file +} diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index a0950701..7c3d57d0 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -9,26 +9,26 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/settings" - "github.com/memohai/memoh/internal/users" ) type SettingsHandler struct { - service *settings.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *settings.Service + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewSettingsHandler(log *slog.Logger, service *settings.Service, botService *bots.Service, userService *users.Service) *SettingsHandler { +func NewSettingsHandler(log *slog.Logger, service *settings.Service, botService *bots.Service, accountService *accounts.Service) *SettingsHandler { return &SettingsHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "settings")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "settings")), } } @@ -44,13 +44,12 @@ func (h *SettingsHandler) Register(e *echo.Echo) { // @Summary Get user settings // @Description Get agent settings for current user // @Tags settings -// @Param bot_id path string true "Bot ID" // @Success 200 {object} settings.Settings // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/settings [get] func (h *SettingsHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -58,7 +57,7 @@ func (h *SettingsHandler) Get(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } resp, err := h.service.GetBot(c.Request().Context(), botID) @@ -72,7 +71,6 @@ func (h *SettingsHandler) Get(c echo.Context) error { // @Summary Update user settings // @Description Update or create agent settings for current user // @Tags settings -// @Param bot_id path string true "Bot ID" // @Param payload body settings.UpsertRequest true "Settings payload" // @Success 200 {object} settings.Settings // @Failure 400 {object} ErrorResponse @@ -80,7 +78,7 @@ func (h *SettingsHandler) Get(c echo.Context) error { // @Router /bots/{bot_id}/settings [put] // @Router /bots/{bot_id}/settings [post] func (h *SettingsHandler) Upsert(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -88,7 +86,7 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req settings.UpsertRequest @@ -97,6 +95,9 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { } resp, err := h.service.UpsertBot(c.Request().Context(), botID, req) if err != nil { + if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) { + return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) @@ -106,13 +107,12 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { // @Summary Delete user settings // @Description Remove agent settings for current user // @Tags settings -// @Param bot_id path string true "Bot ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/settings [delete] func (h *SettingsHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -120,7 +120,7 @@ func (h *SettingsHandler) Delete(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if err := h.service.Delete(c.Request().Context(), botID); err != nil { @@ -129,26 +129,26 @@ func (h *SettingsHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *SettingsHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -159,4 +159,4 @@ func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} \ No newline at end of file +} diff --git a/internal/handlers/subagent.go b/internal/handlers/subagent.go index 21ce1335..40d86c15 100644 --- a/internal/handlers/subagent.go +++ b/internal/handlers/subagent.go @@ -9,26 +9,26 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/subagent" - "github.com/memohai/memoh/internal/users" ) type SubagentHandler struct { - service *subagent.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *subagent.Service + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewSubagentHandler(log *slog.Logger, service *subagent.Service, botService *bots.Service, userService *users.Service) *SubagentHandler { +func NewSubagentHandler(log *slog.Logger, service *subagent.Service, botService *bots.Service, accountService *accounts.Service) *SubagentHandler { return &SubagentHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "subagent")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "subagent")), } } @@ -50,14 +50,13 @@ func (h *SubagentHandler) Register(e *echo.Echo) { // @Summary Create subagent // @Description Create a subagent for current user // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param payload body subagent.CreateRequest true "Subagent payload" // @Success 201 {object} subagent.Subagent // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents [post] func (h *SubagentHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -65,7 +64,7 @@ func (h *SubagentHandler) Create(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req subagent.CreateRequest @@ -83,13 +82,12 @@ func (h *SubagentHandler) Create(c echo.Context) error { // @Summary List subagents // @Description List subagents for current user // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Success 200 {object} subagent.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents [get] func (h *SubagentHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -97,7 +95,7 @@ func (h *SubagentHandler) List(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } items, err := h.service.List(c.Request().Context(), botID) @@ -111,7 +109,6 @@ func (h *SubagentHandler) List(c echo.Context) error { // @Summary Get subagent // @Description Get a subagent by ID // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 200 {object} subagent.Subagent // @Failure 400 {object} ErrorResponse @@ -119,7 +116,7 @@ func (h *SubagentHandler) List(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id} [get] func (h *SubagentHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -138,7 +135,7 @@ func (h *SubagentHandler) Get(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } return c.JSON(http.StatusOK, item) @@ -148,7 +145,6 @@ func (h *SubagentHandler) Get(c echo.Context) error { // @Summary Update subagent // @Description Update a subagent by ID // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.UpdateRequest true "Subagent payload" // @Success 200 {object} subagent.Subagent @@ -157,7 +153,7 @@ func (h *SubagentHandler) Get(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id} [put] func (h *SubagentHandler) Update(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -180,7 +176,7 @@ func (h *SubagentHandler) Update(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } resp, err := h.service.Update(c.Request().Context(), id, req) @@ -194,7 +190,6 @@ func (h *SubagentHandler) Update(c echo.Context) error { // @Summary Delete subagent // @Description Delete a subagent by ID // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse @@ -202,7 +197,7 @@ func (h *SubagentHandler) Update(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id} [delete] func (h *SubagentHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -221,7 +216,7 @@ func (h *SubagentHandler) Delete(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if err := h.service.Delete(c.Request().Context(), id); err != nil { @@ -234,7 +229,6 @@ func (h *SubagentHandler) Delete(c echo.Context) error { // @Summary Get subagent context // @Description Get a subagent's message context // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 200 {object} subagent.ContextResponse // @Failure 400 {object} ErrorResponse @@ -242,7 +236,7 @@ func (h *SubagentHandler) Delete(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/context [get] func (h *SubagentHandler) GetContext(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -261,7 +255,7 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } return c.JSON(http.StatusOK, subagent.ContextResponse{Messages: item.Messages}) @@ -271,7 +265,6 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { // @Summary Update subagent context // @Description Update a subagent's message context // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.UpdateContextRequest true "Context payload" // @Success 200 {object} subagent.ContextResponse @@ -280,7 +273,7 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/context [put] func (h *SubagentHandler) UpdateContext(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -303,7 +296,7 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } updated, err := h.service.UpdateContext(c.Request().Context(), id, req) @@ -317,7 +310,6 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { // @Summary Get subagent skills // @Description Get a subagent's skills // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Success 200 {object} subagent.SkillsResponse // @Failure 400 {object} ErrorResponse @@ -325,7 +317,7 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/skills [get] func (h *SubagentHandler) GetSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -344,7 +336,7 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } return c.JSON(http.StatusOK, subagent.SkillsResponse{Skills: item.Skills}) @@ -354,7 +346,6 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { // @Summary Update subagent skills // @Description Replace a subagent's skills // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.UpdateSkillsRequest true "Skills payload" // @Success 200 {object} subagent.SkillsResponse @@ -363,7 +354,7 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/skills [put] func (h *SubagentHandler) UpdateSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -386,7 +377,7 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } updated, err := h.service.UpdateSkills(c.Request().Context(), id, req) @@ -400,7 +391,6 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { // @Summary Add subagent skills // @Description Add skills to a subagent // @Tags subagent -// @Param bot_id path string true "Bot ID" // @Param id path string true "Subagent ID" // @Param payload body subagent.AddSkillsRequest true "Skills payload" // @Success 200 {object} subagent.SkillsResponse @@ -409,7 +399,7 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/skills [post] func (h *SubagentHandler) AddSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -432,7 +422,7 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "user mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } updated, err := h.service.AddSkills(c.Request().Context(), id, req) @@ -442,26 +432,26 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error { return c.JSON(http.StatusOK, subagent.SkillsResponse{Skills: updated.Skills}) } -func (h *SubagentHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -472,4 +462,4 @@ func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} \ No newline at end of file +} diff --git a/internal/handlers/users.go b/internal/handlers/users.go index f07d69e8..d8d975f9 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -10,39 +10,53 @@ import ( "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/channel/route" "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" ) +// UsersHandler manages user/account CRUD and bot operations via REST API. type UsersHandler struct { - service *users.Service - botService *bots.Service - channelService *channel.Service - channelManager *channel.Manager - registry *channel.Registry - logger *slog.Logger + service *accounts.Service + channelIdentityService *identities.Service + botService *bots.Service + routeService route.Service + channelService *channel.Service + channelManager *channel.Manager + registry *channel.Registry + logger *slog.Logger } -func NewUsersHandler(log *slog.Logger, service *users.Service, botService *bots.Service, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *UsersHandler { +type listMyIdentitiesResponse struct { + UserID string `json:"user_id"` + Items []identities.ChannelIdentity `json:"items"` +} + +// NewUsersHandler creates a UsersHandler with channel identity support. +func NewUsersHandler(log *slog.Logger, service *accounts.Service, channelIdentityService *identities.Service, botService *bots.Service, routeService route.Service, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *UsersHandler { if log == nil { log = slog.Default() } return &UsersHandler{ - service: service, - botService: botService, - channelService: channelService, - channelManager: channelManager, - registry: registry, - logger: log.With(slog.String("handler", "users")), + service: service, + channelIdentityService: channelIdentityService, + botService: botService, + routeService: routeService, + channelService: channelService, + channelManager: channelManager, + registry: registry, + logger: log.With(slog.String("handler", "users")), } } func (h *UsersHandler) Register(e *echo.Echo) { userGroup := e.Group("/users") userGroup.GET("/me", h.GetMe) + userGroup.GET("/me/identities", h.ListMyIdentities) userGroup.PUT("/me", h.UpdateMe) userGroup.PUT("/me/password", h.UpdateMyPassword) userGroup.GET("", h.ListUsers) @@ -55,6 +69,7 @@ func (h *UsersHandler) Register(e *echo.Echo) { botGroup.POST("", h.CreateBot) botGroup.GET("", h.ListBots) botGroup.GET("/:id", h.GetBot) + botGroup.GET("/:id/checks", h.ListBotChecks) botGroup.PUT("/:id", h.UpdateBot) botGroup.PUT("/:id/owner", h.TransferBotOwner) botGroup.DELETE("/:id", h.DeleteBot) @@ -64,48 +79,75 @@ func (h *UsersHandler) Register(e *echo.Echo) { botGroup.GET("/:id/channel/:platform", h.GetBotChannelConfig) botGroup.PUT("/:id/channel/:platform", h.UpsertBotChannelConfig) botGroup.POST("/:id/channel/:platform/send", h.SendBotMessage) - botGroup.POST("/:id/channel/:platform/send_session", h.SendBotMessageSession) + botGroup.POST("/:id/channel/:platform/send_chat", h.SendBotMessageSession) } // GetMe godoc // @Summary Get current user // @Description Get current user profile // @Tags users -// @Success 200 {object} users.User +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me [get] func (h *UsersHandler) GetMe(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - resp, err := h.service.Get(c.Request().Context(), userID) + resp, err := h.service.Get(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) } +// ListMyIdentities godoc +// @Summary List current user's channel identities +// @Description List all channel identities linked to current user +// @Tags users +// @Success 200 {object} listMyIdentitiesResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /users/me/identities [get] +func (h *UsersHandler) ListMyIdentities(c echo.Context) error { + userID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + if h.channelIdentityService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "channel identity service not configured") + } + items, err := h.channelIdentityService.ListUserChannelIdentities(c.Request().Context(), userID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, listMyIdentitiesResponse{ + UserID: userID, + Items: items, + }) +} + // UpdateMe godoc // @Summary Update current user profile // @Description Update current user display name or avatar // @Tags users -// @Param payload body users.UpdateProfileRequest true "Profile payload" -// @Success 200 {object} users.User +// @Param payload body accounts.UpdateProfileRequest true "Profile payload" +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me [put] func (h *UsersHandler) UpdateMe(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - var req users.UpdateProfileRequest + var req accounts.UpdateProfileRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.UpdateProfile(c.Request().Context(), userID, req) + resp, err := h.service.UpdateProfile(c.Request().Context(), channelIdentityID, req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -116,22 +158,22 @@ func (h *UsersHandler) UpdateMe(c echo.Context) error { // @Summary Update current user password // @Description Update current user password with current password check // @Tags users -// @Param payload body users.UpdatePasswordRequest true "Password payload" +// @Param payload body accounts.UpdatePasswordRequest true "Password payload" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me/password [put] func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - var req users.UpdatePasswordRequest + var req accounts.UpdatePasswordRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.service.UpdatePassword(c.Request().Context(), userID, req.CurrentPassword, req.NewPassword); err != nil { - if errors.Is(err, users.ErrInvalidPassword) { + if err := h.service.UpdatePassword(c.Request().Context(), channelIdentityID, req.CurrentPassword, req.NewPassword); err != nil { + if errors.Is(err, accounts.ErrInvalidPassword) { return echo.NewHTTPError(http.StatusBadRequest, "current password mismatch") } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -143,17 +185,17 @@ func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { // @Summary List users (admin only) // @Description List users // @Tags users -// @Success 200 {object} users.ListUsersResponse +// @Success 200 {object} accounts.ListAccountsResponse // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users [get] func (h *UsersHandler) ListUsers(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -163,11 +205,11 @@ func (h *UsersHandler) ListUsers(c echo.Context) error { if strings.TrimSpace(c.QueryParam("user_type")) != "" || strings.TrimSpace(c.QueryParam("owner_id")) != "" { return echo.NewHTTPError(http.StatusBadRequest, "user_type and owner_id are not supported") } - items, err := h.service.ListUsers(c.Request().Context()) + items, err := h.service.ListAccounts(c.Request().Context()) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.JSON(http.StatusOK, users.ListUsersResponse{Items: items}) + return c.JSON(http.StatusOK, accounts.ListAccountsResponse{Items: items}) } // GetUser godoc @@ -175,14 +217,14 @@ func (h *UsersHandler) ListUsers(c echo.Context) error { // @Description Get user details (self or admin only) // @Tags users // @Param id path string true "User ID" -// @Success 200 {object} users.User +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/{id} [get] func (h *UsersHandler) GetUser(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -190,8 +232,8 @@ func (h *UsersHandler) GetUser(c echo.Context) error { if targetID == "" { return echo.NewHTTPError(http.StatusBadRequest, "user id is required") } - if targetID != actorID { - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + if targetID != channelIdentityID { + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -214,19 +256,19 @@ func (h *UsersHandler) GetUser(c echo.Context) error { // @Description Update user profile and status // @Tags users // @Param id path string true "User ID" -// @Param payload body users.UpdateUserRequest true "User update payload" -// @Success 200 {object} users.User +// @Param payload body accounts.UpdateAccountRequest true "User update payload" +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/{id} [put] func (h *UsersHandler) UpdateUser(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -244,11 +286,11 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - var req users.UpdateUserRequest + var req accounts.UpdateAccountRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.UpdateUserAdmin(c.Request().Context(), targetID, req) + resp, err := h.service.UpdateAdmin(c.Request().Context(), targetID, req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -260,7 +302,7 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { // @Description Reset a user password // @Tags users // @Param id path string true "User ID" -// @Param payload body users.ResetPasswordRequest true "Password payload" +// @Param payload body accounts.ResetPasswordRequest true "Password payload" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse @@ -268,11 +310,11 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /users/{id}/password [put] func (h *UsersHandler) ResetUserPassword(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -289,7 +331,7 @@ func (h *UsersHandler) ResetUserPassword(c echo.Context) error { } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - var req users.ResetPasswordRequest + var req accounts.ResetPasswordRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -303,29 +345,29 @@ func (h *UsersHandler) ResetUserPassword(c echo.Context) error { // @Summary Create human user (admin only) // @Description Create a new human user account // @Tags users -// @Param payload body users.CreateUserRequest true "User payload" -// @Success 201 {object} users.User +// @Param payload body accounts.CreateAccountRequest true "User payload" +// @Success 201 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users [post] func (h *UsersHandler) CreateUser(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } if !isAdmin { return echo.NewHTTPError(http.StatusForbidden, "admin role required") } - var req users.CreateUserRequest + var req accounts.CreateAccountRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.CreateHuman(c.Request().Context(), req) + resp, err := h.service.CreateHuman(c.Request().Context(), "", req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -343,7 +385,7 @@ func (h *UsersHandler) CreateUser(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots [post] func (h *UsersHandler) CreateBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -351,19 +393,53 @@ func (h *UsersHandler) CreateBot(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - ownerID := actorID + ownerID := channelIdentityID + ownerFromToken := true if raw := strings.TrimSpace(c.QueryParam("owner_id")); raw != "" { - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } if !isAdmin { return echo.NewHTTPError(http.StatusForbidden, "admin role required for owner override") } + if err := identity.ValidateChannelIdentityID(raw); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } ownerID = raw + ownerFromToken = false + } + if ownerFromToken { + if _, err := h.service.Get(c.Request().Context(), ownerID); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + // Backward-compatible token path: token user_id might be a channel identity ID. + // Try to resolve to linked user first; if still missing, force re-login. + linkedUserID := "" + if h.channelIdentityService != nil { + linkedUserID, err = h.channelIdentityService.GetLinkedUserID(c.Request().Context(), ownerID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + } + linkedUserID = strings.TrimSpace(linkedUserID) + if linkedUserID != "" { + ownerID = linkedUserID + } else { + return echo.NewHTTPError(http.StatusUnauthorized, "owner user not found, please login again") + } + } else { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + } } resp, err := h.botService.Create(c.Request().Context(), ownerID, req) if err != nil { + if errors.Is(err, bots.ErrOwnerUserNotFound) { + if ownerFromToken { + return echo.NewHTTPError(http.StatusUnauthorized, "owner user not found, please login again") + } + return echo.NewHTTPError(http.StatusBadRequest, "owner user not found") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusCreated, resp) @@ -380,13 +456,13 @@ func (h *UsersHandler) CreateBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots [get] func (h *UsersHandler) ListBots(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } ownerID := strings.TrimSpace(c.QueryParam("owner_id")) if ownerID != "" { - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -399,7 +475,7 @@ func (h *UsersHandler) ListBots(c echo.Context) error { } return c.JSON(http.StatusOK, bots.ListBotsResponse{Items: items}) } - items, err := h.botService.ListAccessible(c.Request().Context(), actorID) + items, err := h.botService.ListAccessible(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -418,7 +494,7 @@ func (h *UsersHandler) ListBots(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id} [get] func (h *UsersHandler) GetBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -426,13 +502,46 @@ func (h *UsersHandler) GetBot(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - bot, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID) + bot, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID) if err != nil { return err } return c.JSON(http.StatusOK, bot) } +// ListBotChecks godoc +// @Summary List bot runtime checks +// @Description Evaluate bot attached resource checks in runtime +// @Tags bots +// @Param id path string true "Bot ID" +// @Success 200 {object} bots.ListChecksResponse +// @Failure 400 {object} ErrorResponse +// @Failure 403 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{id}/checks [get] +func (h *UsersHandler) ListBotChecks(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + items, err := h.botService.ListChecks(c.Request().Context(), botID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, bots.ListChecksResponse{Items: items}) +} + // UpdateBot godoc // @Summary Update bot details // @Description Update bot profile (owner/admin only) @@ -446,7 +555,7 @@ func (h *UsersHandler) GetBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id} [put] func (h *UsersHandler) UpdateBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -454,7 +563,7 @@ func (h *UsersHandler) UpdateBot(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req bots.UpdateBotRequest @@ -481,11 +590,11 @@ func (h *UsersHandler) UpdateBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/owner [put] func (h *UsersHandler) TransferBotOwner(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -505,6 +614,9 @@ func (h *UsersHandler) TransferBotOwner(c echo.Context) error { if errors.Is(err, pgx.ErrNoRows) { return echo.NewHTTPError(http.StatusNotFound, "bot not found") } + if errors.Is(err, bots.ErrOwnerUserNotFound) { + return echo.NewHTTPError(http.StatusBadRequest, "owner user not found") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) @@ -515,14 +627,14 @@ func (h *UsersHandler) TransferBotOwner(c echo.Context) error { // @Description Delete a bot user (owner/admin only) // @Tags bots // @Param id path string true "Bot ID" -// @Success 204 "No Content" +// @Success 202 {object} map[string]string // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{id} [delete] func (h *UsersHandler) DeleteBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -530,7 +642,7 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if err := h.botService.Delete(c.Request().Context(), botID); err != nil { @@ -539,7 +651,10 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.NoContent(http.StatusNoContent) + return c.JSON(http.StatusAccepted, map[string]string{ + "id": botID, + "status": bots.BotStatusDeleting, + }) } // ListBotMembers godoc @@ -554,7 +669,7 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/members [get] func (h *UsersHandler) ListBotMembers(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -562,7 +677,7 @@ func (h *UsersHandler) ListBotMembers(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } items, err := h.botService.ListMembers(c.Request().Context(), botID) @@ -585,7 +700,7 @@ func (h *UsersHandler) ListBotMembers(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/members [put] func (h *UsersHandler) UpsertBotMember(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -593,7 +708,7 @@ func (h *UsersHandler) UpsertBotMember(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req bots.UpsertMemberRequest @@ -624,7 +739,7 @@ func (h *UsersHandler) UpsertBotMember(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/members/{user_id} [delete] func (h *UsersHandler) DeleteBotMember(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -632,14 +747,14 @@ func (h *UsersHandler) DeleteBotMember(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - userID := strings.TrimSpace(c.Param("user_id")) - if userID == "" { + memberUserID := strings.TrimSpace(c.Param("user_id")) + if memberUserID == "" { return echo.NewHTTPError(http.StatusBadRequest, "user id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if err := h.botService.DeleteMember(c.Request().Context(), botID, userID); err != nil { + if err := h.botService.DeleteMember(c.Request().Context(), botID, memberUserID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.NoContent(http.StatusNoContent) @@ -658,7 +773,7 @@ func (h *UsersHandler) DeleteBotMember(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/channel/{platform} [get] func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -666,7 +781,7 @@ func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } channelType, err := h.registry.ParseChannelType(c.Param("platform")) @@ -697,7 +812,7 @@ func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/channel/{platform} [put] func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -705,7 +820,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } channelType, err := h.registry.ParseChannelType(c.Param("platform")) @@ -740,7 +855,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/channel/{platform}/send [post] func (h *UsersHandler) SendBotMessage(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -748,7 +863,7 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if h.channelManager == nil { @@ -783,9 +898,9 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { // @Failure 401 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform}/send_session [post] +// @Router /bots/{id}/channel/{platform}/send_chat [post] func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { - sessionToken, err := auth.SessionTokenFromContext(c) + chatToken, err := auth.ChatTokenFromContext(c) if err != nil { return err } @@ -793,16 +908,24 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - channelType, err := h.registry.ParseChannelType(c.Param("platform")) + if chatToken.BotID != botID { + return echo.NewHTTPError(http.StatusForbidden, "token bot mismatch") + } + if h.channelManager == nil || h.routeService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "services not configured") + } + route, err := h.routeService.GetByID(c.Request().Context(), chatToken.RouteID) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "route not found") + } + if strings.TrimSpace(route.ReplyTarget) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "reply target missing in route") + } + channelType, err := h.registry.ParseChannelType(route.Platform) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if sessionToken.BotID != botID || sessionToken.Platform != channelType.String() { - return echo.NewHTTPError(http.StatusForbidden, "session token mismatch") - } - if h.channelManager == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "channel manager not configured") - } + var req channel.SendRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) @@ -810,11 +933,8 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { if req.Message.IsEmpty() { return echo.NewHTTPError(http.StatusBadRequest, "message is required") } - if strings.TrimSpace(sessionToken.ReplyTarget) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "reply target missing") - } if err := h.channelManager.Send(c.Request().Context(), botID, channelType, channel.SendRequest{ - Target: sessionToken.ReplyTarget, + Target: route.ReplyTarget, Message: req.Message, }); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) @@ -822,12 +942,12 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } -func (h *UsersHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - isAdmin, err := h.service.IsAdmin(ctx, actorID) +func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + isAdmin, err := h.service.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -840,13 +960,13 @@ func (h *UsersHandler) authorizeBotAccess(ctx context.Context, actorID, botID st return bot, nil } -func (h *UsersHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } diff --git a/internal/history/service.go b/internal/history/service.go deleted file mode 100644 index a407557c..00000000 --- a/internal/history/service.go +++ /dev/null @@ -1,237 +0,0 @@ -package history - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "strings" - "time" - - "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - - "github.com/memohai/memoh/internal/db/sqlc" -) - -const defaultListLimit = 50 - -type Service struct { - queries *sqlc.Queries - logger *slog.Logger -} - -func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { - return &Service{ - queries: queries, - logger: log.With(slog.String("service", "history")), - } -} - -func (s *Service) Create(ctx context.Context, botID, sessionID string, req CreateRequest) (Record, error) { - if len(req.Messages) == 0 { - return Record{}, fmt.Errorf("messages are required") - } - botUUID, err := parseUUID(botID) - if err != nil { - return Record{}, err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return Record{}, fmt.Errorf("session id is required") - } - payload, err := json.Marshal(req.Messages) - if err != nil { - return Record{}, err - } - meta := req.Metadata - if meta == nil { - meta = map[string]any{} - } - metaPayload, err := json.Marshal(meta) - if err != nil { - return Record{}, err - } - row, err := s.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{ - BotID: botUUID, - SessionID: trimmedSession, - Messages: payload, - Metadata: metaPayload, - Skills: normalizeSkills(req.Skills), - Timestamp: pgtype.Timestamptz{ - Time: time.Now().UTC(), - Valid: true, - }, - }) - if err != nil { - return Record{}, err - } - return toRecord(row) -} - -func (s *Service) Get(ctx context.Context, id string) (Record, error) { - pgID, err := parseUUID(id) - if err != nil { - return Record{}, err - } - row, err := s.queries.GetHistoryByID(ctx, pgID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return Record{}, fmt.Errorf("history not found") - } - return Record{}, err - } - return toRecord(row) -} - -func (s *Service) List(ctx context.Context, botID, sessionID string, limit int) ([]Record, error) { - botUUID, err := parseUUID(botID) - if err != nil { - return nil, err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return nil, fmt.Errorf("session id is required") - } - if limit <= 0 { - limit = defaultListLimit - } - rows, err := s.queries.ListHistoryByBotSession(ctx, sqlc.ListHistoryByBotSessionParams{ - BotID: botUUID, - SessionID: trimmedSession, - Limit: int32(limit), - }) - if err != nil { - return nil, err - } - items := make([]Record, 0, len(rows)) - for _, row := range rows { - record, err := toRecord(row) - if err != nil { - return nil, err - } - items = append(items, record) - } - return items, nil -} - -func (s *Service) ListBySessionSince(ctx context.Context, botID, sessionID string, since time.Time) ([]Record, error) { - botUUID, err := parseUUID(botID) - if err != nil { - return nil, err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return nil, fmt.Errorf("session id is required") - } - rows, err := s.queries.ListHistoryByBotSessionSince(ctx, sqlc.ListHistoryByBotSessionSinceParams{ - BotID: botUUID, - SessionID: trimmedSession, - Timestamp: pgtype.Timestamptz{ - Time: since, - Valid: true, - }, - }) - if err != nil { - return nil, err - } - items := make([]Record, 0, len(rows)) - for _, row := range rows { - record, err := toRecord(row) - if err != nil { - return nil, err - } - items = append(items, record) - } - return items, nil -} - -func (s *Service) Delete(ctx context.Context, id string) error { - pgID, err := parseUUID(id) - if err != nil { - return err - } - return s.queries.DeleteHistoryByID(ctx, pgID) -} - -func (s *Service) DeleteBySession(ctx context.Context, botID, sessionID string) error { - botUUID, err := parseUUID(botID) - if err != nil { - return err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return fmt.Errorf("session id is required") - } - return s.queries.DeleteHistoryByBotSession(ctx, sqlc.DeleteHistoryByBotSessionParams{ - BotID: botUUID, - SessionID: trimmedSession, - }) -} - -func toRecord(row sqlc.History) (Record, error) { - var messages []map[string]any - if len(row.Messages) > 0 { - if err := json.Unmarshal(row.Messages, &messages); err != nil { - return Record{}, err - } - } - var metadata map[string]any - if len(row.Metadata) > 0 { - if err := json.Unmarshal(row.Metadata, &metadata); err != nil { - return Record{}, err - } - } - record := Record{ - Messages: messages, - Metadata: metadata, - Skills: normalizeSkills(row.Skills), - } - if row.Timestamp.Valid { - record.Timestamp = row.Timestamp.Time - } - if row.ID.Valid { - id, err := uuid.FromBytes(row.ID.Bytes[:]) - if err == nil { - record.ID = id.String() - } - } - if row.BotID.Valid { - uid, err := uuid.FromBytes(row.BotID.Bytes[:]) - if err == nil { - record.BotID = uid.String() - } - } - record.SessionID = row.SessionID - return record, nil -} - -func normalizeSkills(skills []string) []string { - seen := map[string]struct{}{} - normalized := make([]string, 0, len(skills)) - for _, skill := range skills { - trimmed := strings.TrimSpace(skill) - if trimmed == "" { - continue - } - if _, ok := seen[trimmed]; ok { - continue - } - seen[trimmed] = struct{}{} - normalized = append(normalized, trimmed) - } - return normalized -} - -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/history/types.go b/internal/history/types.go deleted file mode 100644 index 088d8cfb..00000000 --- a/internal/history/types.go +++ /dev/null @@ -1,23 +0,0 @@ -package history - -import "time" - -type Record struct { - ID string `json:"id"` - Messages []map[string]any `json:"messages"` - Metadata map[string]any `json:"metadata,omitempty"` - Skills []string `json:"skills"` - Timestamp time.Time `json:"timestamp"` - BotID string `json:"bot_id"` - SessionID string `json:"session_id"` -} - -type CreateRequest struct { - Messages []map[string]any `json:"messages"` - Metadata map[string]any `json:"metadata,omitempty"` - Skills []string `json:"skills,omitempty"` -} - -type ListResponse struct { - Items []Record `json:"items"` -} diff --git a/internal/identity/types.go b/internal/identity/types.go index 32f65acb..125e325a 100644 --- a/internal/identity/types.go +++ b/internal/identity/types.go @@ -3,10 +3,11 @@ package identity import "strings" const ( - UserTypeHuman = "human" - UserTypeBot = "bot" + IdentityTypeHuman = "human" + IdentityTypeBot = "bot" ) -func IsBotUserType(userType string) bool { - return strings.EqualFold(strings.TrimSpace(userType), UserTypeBot) +// IsBotIdentityType checks if the identity type is a bot. +func IsBotIdentityType(identityType string) bool { + return strings.EqualFold(strings.TrimSpace(identityType), IdentityTypeBot) } diff --git a/internal/identity/user.go b/internal/identity/user.go index 3f210c43..6e5b9d41 100644 --- a/internal/identity/user.go +++ b/internal/identity/user.go @@ -6,14 +6,14 @@ import ( ctr "github.com/memohai/memoh/internal/containerd" ) -// ValidateUserID enforces a conservative ID charset for isolation. -func ValidateUserID(userID string) error { - if userID == "" { - return fmt.Errorf("%w: user id required", ctr.ErrInvalidArgument) +// ValidateChannelIdentityID enforces a conservative ID charset for isolation. +func ValidateChannelIdentityID(channelIdentityID string) error { + if channelIdentityID == "" { + return fmt.Errorf("%w: channel identity id required", ctr.ErrInvalidArgument) } - for _, r := range userID { + for _, r := range channelIdentityID { if !(r == '-' || r == '_' || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) { - return fmt.Errorf("%w: invalid user id", ctr.ErrInvalidArgument) + return fmt.Errorf("%w: invalid channel identity id", ctr.ErrInvalidArgument) } } return nil diff --git a/internal/logger/logger.go b/internal/logger/logger.go index fc8f970d..81afe073 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -14,7 +14,7 @@ var ( logKey = ctxKey{} ) -// Init 初始化全局日志 +// Init initializes the global logger with the given level and format (e.g. "debug", "json"). func Init(level, format string) { var handler slog.Handler opts := &slog.HandlerOptions{ @@ -31,7 +31,7 @@ func Init(level, format string) { slog.SetDefault(L) } -// FromContext 从 context 中获取 logger,如果不存在则返回全局 logger +// FromContext returns the logger from ctx, or the global logger if not set. func FromContext(ctx context.Context) *slog.Logger { if l, ok := ctx.Value(logKey).(*slog.Logger); ok { return l @@ -39,7 +39,7 @@ func FromContext(ctx context.Context) *slog.Logger { return L } -// WithContext 将 logger 注入 context +// WithContext stores the logger in ctx and returns the new context. func WithContext(ctx context.Context, l *slog.Logger) context.Context { return context.WithValue(ctx, logKey, l) } @@ -59,7 +59,7 @@ func parseLevel(level string) slog.Level { } } -// 快捷方法,支持强类型 slog.Attr 或松散的 key-value 对 +// Debug, Info, Warn, Error log with the global logger (slog.Attr or key-value pairs). func Debug(msg string, args ...any) { L.Debug(msg, args...) } func Info(msg string, args ...any) { L.Info(msg, args...) } func Warn(msg string, args ...any) { L.Warn(msg, args...) } diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index 676cd358..72d3eea8 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -7,29 +7,25 @@ import ( ) func TestInitAndLogging(t *testing.T) { - // 测试 JSON 格式 Init("debug", "json") - + if L.Enabled(context.Background(), slog.LevelDebug) != true { t.Error("expected debug level to be enabled") } - // 验证是否能正常输出(不崩溃) Info("test info message", "key", "value") } func TestContextLogger(t *testing.T) { Init("info", "text") - - // 创建一个带特定属性的 logger + expectedKey := "request_id" expectedValue := "12345" customLogger := L.With(expectedKey, expectedValue) - + ctx := WithContext(context.Background(), customLogger) extracted := FromContext(ctx) - - // 这里简单验证提取出来的是否是同一个(或者功能一致) + if extracted == nil { t.Fatal("extracted logger should not be nil") } diff --git a/internal/mcp/connections.go b/internal/mcp/connections.go index 0a829b1d..86804776 100644 --- a/internal/mcp/connections.go +++ b/internal/mcp/connections.go @@ -14,14 +14,14 @@ import ( // Connection represents a stored MCP connection for a bot. type Connection struct { - ID string `json:"id" validate:"required"` - BotID string `json:"bot_id" validate:"required"` - Name string `json:"name" validate:"required"` - Type string `json:"type" validate:"required"` - Config map[string]any `json:"config" validate:"required"` - Active bool `json:"active" validate:"required"` - CreatedAt time.Time `json:"created_at" validate:"required"` - UpdatedAt time.Time `json:"updated_at" validate:"required"` + ID string `json:"id"` + BotID string `json:"bot_id"` + Name string `json:"name"` + Type string `json:"type"` + Config map[string]any `json:"config"` + Active bool `json:"active"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // UpsertRequest is the payload for creating or updating MCP connections. @@ -222,8 +222,8 @@ func normalizeMCPConnection(row sqlc.McpConnection) (Connection, error) { return Connection{}, err } return Connection{ - ID: db.UUIDToString(row.ID), - BotID: db.UUIDToString(row.BotID), + ID: row.ID.String(), + BotID: row.BotID.String(), Name: strings.TrimSpace(row.Name), Type: strings.TrimSpace(row.Type), Config: config, diff --git a/internal/mcp/jsonrpc.go b/internal/mcp/jsonrpc.go index d6d4933e..18912243 100644 --- a/internal/mcp/jsonrpc.go +++ b/internal/mcp/jsonrpc.go @@ -2,9 +2,7 @@ package mcp import ( "encoding/json" - "fmt" "strings" - "sync" ) func IsNotification(req JSONRPCRequest) bool { @@ -18,78 +16,3 @@ func JSONRPCErrorResponse(id json.RawMessage, code int, message string) JSONRPCR Error: &JSONRPCError{Code: code, Message: message}, } } - -func BuildPayloads(req JSONRPCRequest, initOnce *sync.Once) ([]string, json.RawMessage, error) { - if req.JSONRPC == "" { - req.JSONRPC = "2.0" - } - targetID := req.ID - payloads := []string{} - shouldInit := req.Method != "initialize" && req.Method != "notifications/initialized" - if initOnce != nil { - ran := false - initOnce.Do(func() { - ran = true - }) - if ran { - // This is the first call on the session. - } else { - shouldInit = false - } - } - if shouldInit { - initReq := map[string]any{ - "jsonrpc": "2.0", - "id": "init-1", - "method": "initialize", - "params": map[string]any{ - "protocolVersion": "2025-06-18", - "capabilities": map[string]any{ - "roots": map[string]any{ - "listChanged": false, - }, - }, - "clientInfo": map[string]any{ - "name": "memoh-http-proxy", - "version": "v0", - }, - }, - } - initBytes, err := json.Marshal(initReq) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(initBytes)) - - initialized := map[string]any{ - "jsonrpc": "2.0", - "method": "notifications/initialized", - } - initializedBytes, err := json.Marshal(initialized) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(initializedBytes)) - } - - reqBytes, err := json.Marshal(req) - if err != nil { - return nil, nil, err - } - payloads = append(payloads, string(reqBytes)) - return payloads, targetID, nil -} - -func BuildNotificationPayloads(req JSONRPCRequest) ([]string, error) { - if req.JSONRPC == "" { - req.JSONRPC = "2.0" - } - if strings.TrimSpace(req.Method) == "" { - return nil, fmt.Errorf("missing method") - } - reqBytes, err := json.Marshal(req) - if err != nil { - return nil, err - } - return []string{string(reqBytes)}, nil -} diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index d492db1d..9ad442b8 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -1,11 +1,15 @@ package mcp import ( + "bytes" "context" + "errors" "fmt" "log/slog" "os" + "os/exec" "path/filepath" + "runtime" "strings" "time" @@ -23,6 +27,7 @@ import ( const ( BotLabelKey = "mcp.bot_id" ContainerPrefix = "mcp-" + DefaultImageRef = "memoh-mcp:dev" ) type ExecRequest struct { @@ -38,22 +43,34 @@ type ExecResult struct { ExitCode uint32 } +// ExecWithCaptureResult holds stdout, stderr and exit code from container exec. +type ExecWithCaptureResult struct { + Stdout string + Stderr string + ExitCode uint32 +} + type Manager struct { service ctr.Service cfg config.MCPConfig + namespace string containerID func(string) string db *pgxpool.Pool queries *dbsqlc.Queries logger *slog.Logger } -func NewManager(log *slog.Logger, service ctr.Service, cfg config.Config, db *pgxpool.Pool) *Manager { +func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, conn *pgxpool.Pool) *Manager { + if namespace == "" { + namespace = config.DefaultNamespace + } return &Manager{ - db: db, - queries: dbsqlc.New(db), - service: service, - cfg: cfg.MCP, - logger: log.With(slog.String("component", "mcp")), + service: service, + cfg: cfg, + namespace: namespace, + db: conn, + queries: dbsqlc.New(conn), + logger: log.With(slog.String("component", "mcp")), containerID: func(botID string) string { return ContainerPrefix + botID }, @@ -61,10 +78,7 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.Config, db *pg } func (m *Manager) Init(ctx context.Context) error { - image := m.cfg.BusyboxImage - if image == "" { - image = config.DefaultBusyboxImg - } + image := DefaultImageRef _, err := m.service.PullImage(ctx, image, &ctr.PullImageOptions{ Unpack: true, @@ -99,12 +113,6 @@ func (m *Manager) EnsureBot(ctx context.Context, botID string) error { Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -238,6 +246,116 @@ func (m *Manager) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error return &ExecResult{ExitCode: result.ExitCode}, nil } +// ExecWithCapture runs a command in the bot container and returns stdout, stderr and exit code. +// Use this when the caller needs command output (e.g. MCP exec tool). +// The container must already be running; use Start(botID) or the container/start API to start it. +// On darwin, it uses Lima SSH to avoid virtiofs FIFO synchronization issues. +func (m *Manager) ExecWithCapture(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { + if err := validateBotID(req.BotID); err != nil { + return nil, err + } + if len(req.Command) == 0 { + return nil, fmt.Errorf("%w: empty command", ctr.ErrInvalidArgument) + } + if m.queries == nil { + return nil, fmt.Errorf("db is not configured") + } + + if runtime.GOOS == "darwin" { + return m.execWithCaptureLima(ctx, req) + } + return m.execWithCaptureContainerd(ctx, req) +} + +// execWithCaptureLima runs exec through Lima SSH so that all FIFO I/O stays +// inside the VM, avoiding virtiofs FIFO synchronization issues on macOS. +func (m *Manager) execWithCaptureLima(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { + containerID := m.containerID(req.BotID) + execID := fmt.Sprintf("exec-%d", time.Now().UnixNano()) + + // Each element becomes a separate OS arg to limactl. Lima/SSH joins + // them with spaces and passes the result to the remote shell, so only + // values that may contain shell-special characters need quoting. + args := []string{"shell", "default", "--", + "sudo", "ctr", "-n", m.namespace, + "tasks", "exec", "--exec-id", execID, + } + if req.WorkDir != "" { + args = append(args, "--cwd", req.WorkDir) + } + for _, e := range req.Env { + args = append(args, "--env", e) + } + args = append(args, containerID) + // Pass command args as-is; Lima shell-quotes each OS arg for the + // remote SSH shell, preserving argument boundaries correctly. + args = append(args, req.Command...) + + cmd := exec.CommandContext(ctx, "limactl", args...) + var stdoutBuf, stderrBuf bytes.Buffer + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + + exitCode := uint32(0) + if err := cmd.Run(); err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode = uint32(exitErr.ExitCode()) + } else { + return nil, fmt.Errorf("lima exec: %w", err) + } + } + + // ctr tasks exec may write its own errors to stderr; separate them from + // the container command's stderr output by checking for the ctr prefix. + stderr := stderrBuf.String() + if exitCode != 0 && strings.HasPrefix(stderr, "ctr:") { + return nil, fmt.Errorf("container exec failed: %s", strings.TrimSpace(stderr)) + } + + return &ExecWithCaptureResult{ + Stdout: stdoutBuf.String(), + Stderr: stderr, + ExitCode: exitCode, + }, nil +} + +// execWithCaptureContainerd uses the containerd ExecTask API with FIFO pipes. +// This works reliably on Linux where FIFO I/O stays on the same filesystem. +func (m *Manager) execWithCaptureContainerd(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { + fifoDir, err := os.MkdirTemp(m.dataRoot(), "exec-fifo-") + if err != nil { + return nil, fmt.Errorf("create fifo dir: %w", err) + } + defer os.RemoveAll(fifoDir) + + var stdoutBuf, stderrBuf bytes.Buffer + result, err := m.service.ExecTask(ctx, m.containerID(req.BotID), ctr.ExecTaskRequest{ + Args: req.Command, + Env: req.Env, + WorkDir: req.WorkDir, + Stderr: &stderrBuf, + Stdout: &stdoutBuf, + FIFODir: fifoDir, + }) + if err != nil { + return nil, err + } + return &ExecWithCaptureResult{ + Stdout: stdoutBuf.String(), + Stderr: stderrBuf.String(), + ExitCode: result.ExitCode, + }, nil +} + +// sshShellQuote wraps a string in single quotes for safe SSH transport. +func sshShellQuote(s string) string { + if s == "" { + return "''" + } + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + // DataDir returns the host data directory for a bot. func (m *Manager) DataDir(botID string) (string, error) { if err := validateBotID(botID); err != nil { @@ -270,12 +388,9 @@ func (m *Manager) dataMount() string { } func (m *Manager) imageRef() string { - if m.cfg.BusyboxImage == "" { - return config.DefaultBusyboxImg - } - return m.cfg.BusyboxImage + return DefaultImageRef } func validateBotID(botID string) error { - return identity.ValidateUserID(botID) + return identity.ValidateChannelIdentityID(botID) } diff --git a/internal/mcp/providers/container/fsops.go b/internal/mcp/providers/container/fsops.go new file mode 100644 index 00000000..49ac9009 --- /dev/null +++ b/internal/mcp/providers/container/fsops.go @@ -0,0 +1,288 @@ +package container + +import ( + "context" + "encoding/base64" + "fmt" + "path" + "strconv" + "strings" + "time" + "unicode" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +type fileEntry struct { + Path string + IsDir bool + Size int64 + Mode uint32 + ModTime time.Time +} + +// execRead reads a file inside the container via cat. +func execRead(ctx context.Context, runner ExecRunner, botID, workDir, filePath string) (string, error) { + result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{"/bin/sh", "-c", "cat " + shellQuote(filePath)}, + WorkDir: workDir, + }) + if err != nil { + return "", err + } + if result.ExitCode != 0 { + return "", fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) + } + return result.Stdout, nil +} + +// execWrite writes content to a file inside the container using base64 encoding +// to avoid shell escaping issues. +func execWrite(ctx context.Context, runner ExecRunner, botID, workDir, filePath, content string) error { + encoded := base64.StdEncoding.EncodeToString([]byte(content)) + dir := path.Dir(filePath) + script := fmt.Sprintf("mkdir -p %s && echo %s | base64 -d > %s", + shellQuote(dir), shellQuote(encoded), shellQuote(filePath)) + result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{"/bin/sh", "-c", script}, + WorkDir: workDir, + }) + if err != nil { + return err + } + if result.ExitCode != 0 { + return fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) + } + return nil +} + +// execList lists directory entries inside the container via find + stat. +// Output format per line: |||| +func execList(ctx context.Context, runner ExecRunner, botID, workDir, dirPath string, recursive bool) ([]fileEntry, error) { + depthFlag := "-maxdepth 1" + if recursive { + depthFlag = "" + } + // Use find to get entries, skip the root dir itself, then stat each entry. + // busybox stat -c format: %n=name, %F=type, %s=size, %a=octal mode, %Y=mtime epoch + script := fmt.Sprintf( + `find %s %s ! -path %s -exec stat -c '%%n|%%F|%%s|%%a|%%Y' {} \;`, + shellQuote(dirPath), depthFlag, shellQuote(dirPath), + ) + result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{"/bin/sh", "-c", script}, + WorkDir: workDir, + }) + if err != nil { + return nil, err + } + if result.ExitCode != 0 { + return nil, fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) + } + return parseStatOutput(result.Stdout, dirPath), nil +} + +// parseStatOutput parses lines of "fullpath|type|size|mode|mtime" into fileEntry slices. +func parseStatOutput(output, basePath string) []fileEntry { + lines := strings.Split(strings.TrimSpace(output), "\n") + entries := make([]fileEntry, 0, len(lines)) + // Normalize base path for computing relative paths. + base := strings.TrimSuffix(basePath, "/") + if base == "" || base == "." { + base = "" + } + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.SplitN(line, "|", 5) + if len(parts) < 5 { + continue + } + fullPath := parts[0] + fileType := parts[1] + sizeStr := parts[2] + modeStr := parts[3] + mtimeStr := parts[4] + + // Compute relative path from base. + rel := fullPath + if base != "" { + rel = strings.TrimPrefix(fullPath, base+"/") + } + if rel == "" || rel == "." { + continue + } + + isDir := strings.Contains(fileType, "directory") + size, _ := strconv.ParseInt(sizeStr, 10, 64) + mode64, _ := strconv.ParseUint(modeStr, 8, 32) + mtimeEpoch, _ := strconv.ParseInt(mtimeStr, 10, 64) + modTime := time.Unix(mtimeEpoch, 0) + + entries = append(entries, fileEntry{ + Path: rel, + IsDir: isDir, + Size: size, + Mode: uint32(mode64), + ModTime: modTime, + }) + } + return entries +} + +// applyEdit performs the fuzzy text replacement logic on raw file content. +// Returns the updated content or an error. +func applyEdit(raw, filePath, oldText, newText string) (string, error) { + bom, content := stripBOM(raw) + originalEnding := detectLineEnding(content) + normalizedContent := normalizeToLF(content) + normalizedOld := normalizeToLF(oldText) + normalizedNew := normalizeToLF(newText) + match := fuzzyFindText(normalizedContent, normalizedOld) + if !match.Found { + return "", fmt.Errorf( + "could not find the exact text in %s. the old text must match exactly including all whitespace and newlines", + filePath, + ) + } + fuzzyContent := normalizeForFuzzyMatch(normalizedContent) + fuzzyOld := normalizeForFuzzyMatch(normalizedOld) + occurrences := strings.Count(fuzzyContent, fuzzyOld) + if occurrences > 1 { + return "", fmt.Errorf( + "found %d occurrences of the text in %s. the text must be unique. please provide more context to make it unique", + occurrences, + filePath, + ) + } + baseContent := match.ContentForReplacement + updated := baseContent[:match.Index] + normalizedNew + baseContent[match.Index+match.MatchLength:] + if baseContent == updated { + return "", fmt.Errorf( + "no changes made to %s. the replacement produced identical content. this might indicate an issue with special characters or the text not existing as expected", + filePath, + ) + } + return bom + restoreLineEndings(updated, originalEnding), nil +} + +// shellQuote wraps a string in single quotes, escaping embedded single quotes. +func shellQuote(s string) string { + if s == "" { + return "''" + } + if strings.IndexByte(s, '\'') < 0 { + return "'" + s + "'" + } + var b strings.Builder + b.WriteByte('\'') + for _, c := range s { + if c == '\'' { + b.WriteString("'\\''") + } else { + b.WriteRune(c) + } + } + b.WriteByte('\'') + return b.String() +} + +// ---------- fuzzy matching helpers (pure string processing, unchanged) ---------- + +type fuzzyMatchResult struct { + Found bool + Index int + MatchLength int + ContentForReplacement string +} + +func detectLineEnding(content string) string { + crlfIdx := strings.Index(content, "\r\n") + lfIdx := strings.Index(content, "\n") + if lfIdx == -1 { + return "\n" + } + if crlfIdx == -1 { + return "\n" + } + if crlfIdx < lfIdx { + return "\r\n" + } + return "\n" +} + +func normalizeToLF(text string) string { + text = strings.ReplaceAll(text, "\r\n", "\n") + return strings.ReplaceAll(text, "\r", "\n") +} + +func restoreLineEndings(text, ending string) string { + if ending == "\r\n" { + return strings.ReplaceAll(text, "\n", "\r\n") + } + return text +} + +func stripBOM(content string) (string, string) { + const bom = "\uFEFF" + if strings.HasPrefix(content, bom) { + return bom, content[len(bom):] + } + return "", content +} + +func normalizeForFuzzyMatch(text string) string { + lines := strings.Split(text, "\n") + for i, line := range lines { + lines[i] = strings.TrimRightFunc(line, unicode.IsSpace) + } + trimmed := strings.Join(lines, "\n") + return strings.Map(func(r rune) rune { + switch r { + case '\u2018', '\u2019', '\u201A', '\u201B': + return '\'' + case '\u201C', '\u201D', '\u201E', '\u201F': + return '"' + case '\u2010', '\u2011', '\u2012', '\u2013', '\u2014', '\u2015', '\u2212': + return '-' + case '\u00A0', '\u2002', '\u2003', '\u2004', '\u2005', '\u2006', '\u2007', '\u2008', '\u2009', '\u200A', '\u202F', '\u205F', '\u3000': + return ' ' + default: + return r + } + }, trimmed) +} + +func fuzzyFindText(content, oldText string) fuzzyMatchResult { + exactIndex := strings.Index(content, oldText) + if exactIndex != -1 { + return fuzzyMatchResult{ + Found: true, + Index: exactIndex, + MatchLength: len(oldText), + ContentForReplacement: content, + } + } + fuzzyContent := normalizeForFuzzyMatch(content) + fuzzyOld := normalizeForFuzzyMatch(oldText) + fuzzyIndex := strings.Index(fuzzyContent, fuzzyOld) + if fuzzyIndex == -1 { + return fuzzyMatchResult{ + Found: false, + Index: -1, + MatchLength: 0, + ContentForReplacement: content, + } + } + return fuzzyMatchResult{ + Found: true, + Index: fuzzyIndex, + MatchLength: len(fuzzyOld), + ContentForReplacement: fuzzyContent, + } +} diff --git a/internal/mcp/providers/container/fsops_test.go b/internal/mcp/providers/container/fsops_test.go new file mode 100644 index 00000000..acfb2f7d --- /dev/null +++ b/internal/mcp/providers/container/fsops_test.go @@ -0,0 +1,148 @@ +package container + +import "testing" + +func TestShellQuote(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"hello", "'hello'"}, + {"", "''"}, + {"it's", `'it'\''s'`}, + {"a b", "'a b'"}, + } + for _, tt := range tests { + got := shellQuote(tt.in) + if got != tt.want { + t.Errorf("shellQuote(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestParseStatOutput(t *testing.T) { + output := `./file.txt|regular file|123|644|1700000000 +./subdir|directory|4096|755|1700000000 +` + entries := parseStatOutput(output, ".") + if len(entries) != 2 { + t.Fatalf("got %d entries, want 2", len(entries)) + } + if entries[0].Path != "./file.txt" { + t.Errorf("path[0] = %q", entries[0].Path) + } + if entries[0].IsDir { + t.Error("file.txt should not be a directory") + } + if entries[0].Size != 123 { + t.Errorf("size[0] = %d", entries[0].Size) + } + if entries[1].Path != "./subdir" { + t.Errorf("path[1] = %q", entries[1].Path) + } + if !entries[1].IsDir { + t.Error("subdir should be a directory") + } +} + +func TestParseStatOutput_WithBasePath(t *testing.T) { + output := `/data/test/file.txt|regular file|10|644|1700000000 +/data/test/sub|directory|4096|755|1700000000 +` + entries := parseStatOutput(output, "/data/test") + if len(entries) != 2 { + t.Fatalf("got %d entries, want 2", len(entries)) + } + if entries[0].Path != "file.txt" { + t.Errorf("path[0] = %q, want %q", entries[0].Path, "file.txt") + } + if entries[1].Path != "sub" { + t.Errorf("path[1] = %q, want %q", entries[1].Path, "sub") + } +} + +func TestParseStatOutput_Empty(t *testing.T) { + entries := parseStatOutput("", ".") + if len(entries) != 0 { + t.Errorf("got %d entries for empty output", len(entries)) + } +} + +func TestApplyEdit(t *testing.T) { + raw := "hello world\n" + updated, err := applyEdit(raw, "test.txt", "hello", "goodbye") + if err != nil { + t.Fatal(err) + } + if updated != "goodbye world\n" { + t.Errorf("updated = %q", updated) + } +} + +func TestApplyEdit_NotFound(t *testing.T) { + raw := "hello world\n" + _, err := applyEdit(raw, "test.txt", "missing text", "new") + if err == nil { + t.Error("expected error for missing text") + } +} + +func TestApplyEdit_MultipleOccurrences(t *testing.T) { + raw := "foo bar foo\n" + _, err := applyEdit(raw, "test.txt", "foo", "baz") + if err == nil { + t.Error("expected error for multiple occurrences") + } +} + +func TestApplyEdit_NoChange(t *testing.T) { + raw := "hello world\n" + _, err := applyEdit(raw, "test.txt", "hello", "hello") + if err == nil { + t.Error("expected error for identical replacement") + } +} + +func TestFuzzyFindText(t *testing.T) { + tests := []struct { + name string + content string + old string + found bool + }{ + {"exact match", "hello world", "hello", true}, + {"no match", "hello world", "missing", false}, + {"smart quote match", "it\u2019s a test", "it's a test", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fuzzyFindText(tt.content, tt.old) + if result.Found != tt.found { + t.Errorf("found = %v, want %v", result.Found, tt.found) + } + }) + } +} + +func TestDetectLineEnding(t *testing.T) { + if detectLineEnding("foo\r\nbar") != "\r\n" { + t.Error("expected CRLF") + } + if detectLineEnding("foo\nbar") != "\n" { + t.Error("expected LF") + } + if detectLineEnding("foo") != "\n" { + t.Error("expected LF default") + } +} + +func TestStripBOM(t *testing.T) { + bom, content := stripBOM("\uFEFFhello") + if bom != "\uFEFF" || content != "hello" { + t.Errorf("bom=%q content=%q", bom, content) + } + bom2, content2 := stripBOM("hello") + if bom2 != "" || content2 != "hello" { + t.Errorf("bom=%q content=%q", bom2, content2) + } +} diff --git a/internal/mcp/providers/container/provider.go b/internal/mcp/providers/container/provider.go new file mode 100644 index 00000000..f0b95036 --- /dev/null +++ b/internal/mcp/providers/container/provider.go @@ -0,0 +1,250 @@ +package container + +import ( + "context" + "log/slog" + "strings" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const ( + toolRead = "read" + toolWrite = "write" + toolList = "list" + toolEdit = "edit" + toolExec = "exec" + + defaultExecWorkDir = "/data" + shellCommandName = "/bin/sh" + shellCommandFlag = "-c" +) + +// ExecRunner runs a command in the bot container and returns stdout, stderr and exit code. +type ExecRunner interface { + ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) +} + +// Executor provides filesystem and exec tools (read, write, list, edit, exec) that +// operate inside the bot container via ExecRunner. All I/O goes through the container +// sandbox — no direct host filesystem access. +type Executor struct { + execRunner ExecRunner + execWorkDir string + logger *slog.Logger +} + +// NewExecutor returns a tool executor. execRunner is required — all tools delegate +// to it for container-side I/O. execWorkDir is the default working directory inside +// the container (e.g. /data). +func NewExecutor(log *slog.Logger, execRunner ExecRunner, execWorkDir string) *Executor { + if log == nil { + log = slog.Default() + } + wd := strings.TrimSpace(execWorkDir) + if wd == "" { + wd = defaultExecWorkDir + } + return &Executor{ + execRunner: execRunner, + execWorkDir: wd, + logger: log.With(slog.String("provider", "container_tool")), + } +} + +// ListTools returns read, write, list, edit, and exec tool descriptors. +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + return []mcpgw.ToolDescriptor{ + { + Name: toolRead, + Description: "Read file content inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "file path (relative to /data or absolute inside container)"}, + }, + "required": []string{"path"}, + }, + }, + { + Name: toolWrite, + Description: "Write file content inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "file path (relative to /data or absolute inside container)"}, + "content": map[string]any{"type": "string", "description": "file content"}, + }, + "required": []string{"path", "content"}, + }, + }, + { + Name: toolList, + Description: "List directory entries inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "directory path (relative to /data or absolute inside container)"}, + "recursive": map[string]any{"type": "boolean", "description": "list recursively"}, + }, + "required": []string{"path"}, + }, + }, + { + Name: toolEdit, + Description: "Replace exact text in a file inside the bot container.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string", "description": "file path (relative to /data or absolute inside container)"}, + "old_text": map[string]any{"type": "string", "description": "exact text to find"}, + "new_text": map[string]any{"type": "string", "description": "replacement text"}, + }, + "required": []string{"path", "old_text", "new_text"}, + }, + }, + { + Name: toolExec, + Description: "Execute a command in the bot container. Runs in the bot's data directory (/data) by default.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "Shell command to run (e.g. ls -la, cat file.txt)", + }, + "work_dir": map[string]any{ + "type": "string", + "description": "Working directory inside the container (default: /data)", + }, + }, + "required": []string{"command"}, + }, + }, + }, nil +} + +// normalizePath converts paths that the LLM may send as /data/... into relative +// paths under the working directory. e.g. /data/test.txt -> test.txt, /data -> . +func normalizePath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return path + } + const prefix = "/data" + if path == prefix { + return "." + } + if strings.HasPrefix(path, prefix+"/") { + return strings.TrimLeft(strings.TrimPrefix(path, prefix+"/"), "/") + } + return path +} + +// CallTool dispatches to the appropriate container-exec backed implementation. +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + + switch toolName { + case toolRead: + filePath := normalizePath(mcpgw.StringArg(arguments, "path")) + if filePath == "" { + return mcpgw.BuildToolErrorResult("path is required"), nil + } + content, err := execRead(ctx, p.execRunner, botID, p.execWorkDir, filePath) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"content": content}), nil + + case toolWrite: + filePath := normalizePath(mcpgw.StringArg(arguments, "path")) + content := mcpgw.StringArg(arguments, "content") + if filePath == "" { + return mcpgw.BuildToolErrorResult("path is required"), nil + } + if err := execWrite(ctx, p.execRunner, botID, p.execWorkDir, filePath, content); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil + + case toolList: + dirPath := normalizePath(mcpgw.StringArg(arguments, "path")) + if dirPath == "" { + dirPath = "." + } + recursive, _, _ := mcpgw.BoolArg(arguments, "recursive") + entries, err := execList(ctx, p.execRunner, botID, p.execWorkDir, dirPath, recursive) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + entriesMaps := make([]map[string]any, len(entries)) + for i, e := range entries { + entriesMaps[i] = map[string]any{ + "path": e.Path, + "is_dir": e.IsDir, + "size": e.Size, + "mode": e.Mode, + "mod_time": e.ModTime, + } + } + return mcpgw.BuildToolSuccessResult(map[string]any{"path": dirPath, "entries": entriesMaps}), nil + + case toolEdit: + filePath := normalizePath(mcpgw.StringArg(arguments, "path")) + oldText := mcpgw.StringArg(arguments, "old_text") + newText := mcpgw.StringArg(arguments, "new_text") + if filePath == "" || oldText == "" { + return mcpgw.BuildToolErrorResult("path, old_text and new_text are required"), nil + } + // Step 1: read via exec + raw, err := execRead(ctx, p.execRunner, botID, p.execWorkDir, filePath) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + // Step 2: fuzzy match in Go + updated, err := applyEdit(raw, filePath, oldText, newText) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + // Step 3: write back via exec + if err := execWrite(ctx, p.execRunner, botID, p.execWorkDir, filePath, updated); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil + + case toolExec: + command := strings.TrimSpace(mcpgw.StringArg(arguments, "command")) + if command == "" { + return mcpgw.BuildToolErrorResult("command is required"), nil + } + workDir := strings.TrimSpace(mcpgw.StringArg(arguments, "work_dir")) + if workDir == "" { + workDir = p.execWorkDir + } + result, err := p.execRunner.ExecWithCapture(ctx, mcpgw.ExecRequest{ + BotID: botID, + Command: []string{shellCommandName, shellCommandFlag, command}, + WorkDir: workDir, + }) + if err != nil { + p.logger.Warn("exec failed", slog.String("bot_id", botID), slog.String("command", command), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + stderr := result.Stderr + if result.ExitCode != 0 && strings.Contains(stderr, "no running task") { + stderr = strings.TrimSpace(stderr) + "\n\nHint: Container exists but has no running task (main process exited). Start it first: POST /bots/" + botID + "/container/start or use the container start action in the UI." + } + return mcpgw.BuildToolSuccessResult(map[string]any{ + "stdout": result.Stdout, + "stderr": stderr, + "exit_code": result.ExitCode, + }), nil + + default: + return nil, mcpgw.ErrToolNotFound + } +} diff --git a/internal/mcp/providers/container/provider_test.go b/internal/mcp/providers/container/provider_test.go new file mode 100644 index 00000000..d921d4d9 --- /dev/null +++ b/internal/mcp/providers/container/provider_test.go @@ -0,0 +1,236 @@ +package container + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + "testing" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +// fakeExecRunner records the last request and returns a preset result. +type fakeExecRunner struct { + result *mcpgw.ExecWithCaptureResult + err error + lastReq mcpgw.ExecRequest + handler func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) +} + +func (f *fakeExecRunner) ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { + f.lastReq = req + if f.handler != nil { + return f.handler(req) + } + if f.err != nil { + return nil, f.err + } + return f.result, nil +} + +func TestExecutor_ListTools(t *testing.T) { + runner := &fakeExecRunner{result: &mcpgw.ExecWithCaptureResult{}} + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "test-bot"} + tools, err := exec.ListTools(ctx, session) + if err != nil { + t.Fatal(err) + } + want := map[string]bool{"read": true, "write": true, "list": true, "edit": true, "exec": true} + if len(tools) != len(want) { + t.Errorf("got %d tools, want %d", len(tools), len(want)) + } + for _, tool := range tools { + if !want[tool.Name] { + t.Errorf("unexpected tool %q", tool.Name) + } + } +} + +func TestExecutor_CallTool_Read(t *testing.T) { + runner := &fakeExecRunner{ + result: &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0}, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "read", map[string]any{"path": "test.txt"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content["content"] != "hello world" { + t.Errorf("content = %v", content["content"]) + } + // Verify the exec command contains cat. + cmd := strings.Join(runner.lastReq.Command, " ") + if !strings.Contains(cmd, "cat") { + t.Errorf("expected cat command, got %q", cmd) + } +} + +func TestExecutor_CallTool_Write(t *testing.T) { + runner := &fakeExecRunner{ + handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { + cmd := strings.Join(req.Command, " ") + if !strings.Contains(cmd, "base64 -d") { + return nil, fmt.Errorf("expected base64 write, got %q", cmd) + } + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "write", map[string]any{ + "path": "hello.txt", "content": "world", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } +} + +func TestExecutor_CallTool_List(t *testing.T) { + runner := &fakeExecRunner{ + result: &mcpgw.ExecWithCaptureResult{ + Stdout: "./test.txt|regular file|42|644|1700000000\n./subdir|directory|4096|755|1700000000\n", + ExitCode: 0, + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "list", map[string]any{"path": "."}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + entries, ok := content["entries"].([]map[string]any) + if !ok { + t.Fatalf("entries type = %T", content["entries"]) + } + if len(entries) != 2 { + t.Fatalf("got %d entries, want 2", len(entries)) + } +} + +func TestExecutor_CallTool_Edit(t *testing.T) { + callCount := 0 + runner := &fakeExecRunner{ + handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { + callCount++ + cmd := strings.Join(req.Command, " ") + if strings.Contains(cmd, "cat") { + // Read step: return original content. + return &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0}, nil + } + if strings.Contains(cmd, "base64 -d") { + // Write step: verify the written content contains the replacement. + // Extract base64 from: echo '' | base64 -d > 'path' + parts := strings.Split(cmd, "'") + for _, p := range parts { + decoded, err := base64.StdEncoding.DecodeString(p) + if err == nil && strings.Contains(string(decoded), "goodbye world") { + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + } + } + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + } + return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + + result, err := exec.CallTool(ctx, session, "edit", map[string]any{ + "path": "test.txt", "old_text": "hello", "new_text": "goodbye", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + if callCount < 2 { + t.Errorf("expected at least 2 exec calls (read+write), got %d", callCount) + } +} + +func TestExecutor_CallTool_Exec(t *testing.T) { + runner := &fakeExecRunner{ + result: &mcpgw.ExecWithCaptureResult{ + Stdout: "hello\n", + Stderr: "", + ExitCode: 0, + }, + } + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(ctx, session, toolExec, map[string]any{"command": "echo hello"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if content["stdout"] != "hello\n" { + t.Errorf("stdout = %v", content["stdout"]) + } + if content["exit_code"].(uint32) != 0 { + t.Errorf("exit_code = %v", content["exit_code"]) + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + runner := &fakeExecRunner{result: &mcpgw.ExecWithCaptureResult{}} + exec := NewExecutor(nil, runner, "/data") + ctx := context.Background() + session := mcpgw.ToolSessionContext{} + result, err := exec.CallTool(ctx, session, "read", map[string]any{"path": "x"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestNormalizePath(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"/data/test.txt", "test.txt"}, + {"/data/foo/bar.txt", "foo/bar.txt"}, + {"/data", "."}, + {"test.txt", "test.txt"}, + {"", ""}, + {".", "."}, + } + for _, tt := range tests { + got := normalizePath(tt.in) + if got != tt.want { + t.Errorf("normalizePath(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} diff --git a/internal/mcp/providers/directory/provider.go b/internal/mcp/providers/directory/provider.go new file mode 100644 index 00000000..92fa0753 --- /dev/null +++ b/internal/mcp/providers/directory/provider.go @@ -0,0 +1,162 @@ +package directory + +import ( + "context" + "log/slog" + "strings" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const toolLookupChannelUser = "lookup_channel_user" + +// ConfigResolver resolves effective channel config for a bot (used to call directory APIs). +type ConfigResolver interface { + ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) +} + +// ChannelTypeResolver parses platform name to channel type. +type ChannelTypeResolver interface { + ParseChannelType(raw string) (channel.ChannelType, error) +} + +// Executor exposes channel directory lookup as an MCP tool for the LLM. +type Executor struct { + registry *channel.Registry + configResolver ConfigResolver + typeResolver ChannelTypeResolver + logger *slog.Logger +} + +// NewExecutor creates a directory tool executor. +func NewExecutor(log *slog.Logger, registry *channel.Registry, configResolver ConfigResolver, typeResolver ChannelTypeResolver) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + registry: registry, + configResolver: configResolver, + typeResolver: typeResolver, + logger: log.With(slog.String("provider", "directory_tool")), + } +} + +// ListTools returns the lookup_channel_user tool descriptor. +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.registry == nil || p.configResolver == nil || p.typeResolver == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolLookupChannelUser, + Description: "Look up a user or group on a channel by platform identifier (e.g. open_id, user_id, chat_id). Returns display name, handle, and id.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "platform": map[string]any{ + "type": "string", + "description": "Channel platform (e.g. feishu, telegram). Defaults to current session platform.", + }, + "bot_id": map[string]any{ + "type": "string", + "description": "Bot ID. Defaults to current session bot.", + }, + "input": map[string]any{ + "type": "string", + "description": "Platform-specific identifier: user id (feishu open_id/user_id, telegram chat_id for private), or \"chat_id:user_id\" for a user in a group (telegram).", + }, + "kind": map[string]any{ + "type": "string", + "description": "Entry kind: \"user\" or \"group\". Default \"user\".", + "enum": []any{"user", "group"}, + }, + }, + "required": []string{"input"}, + }, + }, + }, nil +} + +// CallTool runs lookup_channel_user. +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if toolName != toolLookupChannelUser { + return nil, mcpgw.ErrToolNotFound + } + if p.registry == nil || p.configResolver == nil || p.typeResolver == nil { + return mcpgw.BuildToolErrorResult("directory lookup not available"), nil + } + + botID := mcpgw.FirstStringArg(arguments, "bot_id") + if botID == "" { + botID = strings.TrimSpace(session.BotID) + } + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + if strings.TrimSpace(session.BotID) != "" && botID != strings.TrimSpace(session.BotID) { + return mcpgw.BuildToolErrorResult("bot_id mismatch"), nil + } + + platform := mcpgw.FirstStringArg(arguments, "platform") + if platform == "" { + platform = strings.TrimSpace(session.CurrentPlatform) + } + if platform == "" { + return mcpgw.BuildToolErrorResult("platform is required"), nil + } + channelType, err := p.typeResolver.ParseChannelType(platform) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + dirAdapter, ok := p.registry.DirectoryAdapter(channelType) + if !ok || dirAdapter == nil { + return mcpgw.BuildToolErrorResult("channel does not support directory lookup"), nil + } + + input := strings.TrimSpace(mcpgw.FirstStringArg(arguments, "input")) + if input == "" { + return mcpgw.BuildToolErrorResult("input is required"), nil + } + + kindStr := strings.ToLower(strings.TrimSpace(mcpgw.FirstStringArg(arguments, "kind"))) + if kindStr == "" { + kindStr = "user" + } + var kind channel.DirectoryEntryKind + switch kindStr { + case "user": + kind = channel.DirectoryEntryUser + case "group": + kind = channel.DirectoryEntryGroup + default: + return mcpgw.BuildToolErrorResult("kind must be user or group"), nil + } + + cfg, err := p.configResolver.ResolveEffectiveConfig(ctx, botID, channelType) + if err != nil { + p.logger.Warn("resolve config failed", slog.String("bot_id", botID), slog.String("platform", platform), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + entry, err := dirAdapter.ResolveEntry(ctx, cfg, input, kind) + if err != nil { + p.logger.Warn("resolve entry failed", slog.String("input", input), slog.String("kind", kindStr), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + payload := map[string]any{ + "ok": true, + "platform": channelType.String(), + "kind": string(entry.Kind), + "id": entry.ID, + "name": entry.Name, + "handle": entry.Handle, + "metadata": entry.Metadata, + } + if entry.AvatarURL != "" { + payload["avatar_url"] = entry.AvatarURL + } + return mcpgw.BuildToolSuccessResult(payload), nil +} diff --git a/internal/mcp/providers/directory/provider_test.go b/internal/mcp/providers/directory/provider_test.go new file mode 100644 index 00000000..1f42969a --- /dev/null +++ b/internal/mcp/providers/directory/provider_test.go @@ -0,0 +1,72 @@ +package directory + +import ( + "context" + "testing" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +func TestExecutor_ListTools(t *testing.T) { + reg := channel.NewRegistry() + reg.MustRegister(&dirMockAdapter{channelType: "dir-test"}) + svc := &fakeConfigResolver{} + exec := NewExecutor(nil, reg, svc, reg) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatalf("ListTools: %v", err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != toolLookupChannelUser { + t.Errorf("tool name = %q", tools[0].Name) + } +} + +func TestExecutor_ListTools_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatalf("ListTools: %v", err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when deps nil, got %d", len(tools)) + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + exec := NewExecutor(nil, channel.NewRegistry(), &fakeConfigResolver{}, channel.NewRegistry()) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +type dirMockAdapter struct { + channelType channel.ChannelType +} + +func (d *dirMockAdapter) Type() channel.ChannelType { return d.channelType } +func (d *dirMockAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{Type: d.channelType, DisplayName: "DirTest"} +} +func (d *dirMockAdapter) ListPeers(context.Context, channel.ChannelConfig, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} +func (d *dirMockAdapter) ListGroups(context.Context, channel.ChannelConfig, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} +func (d *dirMockAdapter) ListGroupMembers(context.Context, channel.ChannelConfig, string, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} +func (d *dirMockAdapter) ResolveEntry(context.Context, channel.ChannelConfig, string, channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + return channel.DirectoryEntry{Kind: channel.DirectoryEntryUser, ID: "id1", Name: "Test User"}, nil +} + +type fakeConfigResolver struct{} + +func (f *fakeConfigResolver) ResolveEffectiveConfig(context.Context, string, channel.ChannelType) (channel.ChannelConfig, error) { + return channel.ChannelConfig{}, nil +} diff --git a/internal/mcp/providers/memory/provider.go b/internal/mcp/providers/memory/provider.go new file mode 100644 index 00000000..3ba0e975 --- /dev/null +++ b/internal/mcp/providers/memory/provider.go @@ -0,0 +1,205 @@ +package memory + +import ( + "context" + "log/slog" + "sort" + "strings" + + "github.com/memohai/memoh/internal/conversation" + mcpgw "github.com/memohai/memoh/internal/mcp" + mem "github.com/memohai/memoh/internal/memory" +) + +const ( + toolSearchMemory = "search_memory" + defaultMemoryToolLimit = 8 + maxMemoryToolLimit = 50 + sharedMemoryNamespace = "bot" +) + +type MemorySearcher interface { + Search(ctx context.Context, req mem.SearchRequest) (mem.SearchResponse, error) +} + +type AdminChecker interface { + IsAdmin(ctx context.Context, channelIdentityID string) (bool, error) +} + +type Executor struct { + searcher MemorySearcher + chatAccessor conversation.Accessor + adminChecker AdminChecker + logger *slog.Logger +} + +func NewExecutor(log *slog.Logger, searcher MemorySearcher, chatAccessor conversation.Accessor, adminChecker AdminChecker) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + searcher: searcher, + chatAccessor: chatAccessor, + adminChecker: adminChecker, + logger: log.With(slog.String("provider", "memory_tool")), + } +} + +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.searcher == nil || p.chatAccessor == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolSearchMemory, + Description: "Search for memories relevant to the current chat", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "The query to search memories", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of memory results", + }, + }, + "required": []string{"query"}, + }, + }, + }, nil +} + +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if toolName != toolSearchMemory { + return nil, mcpgw.ErrToolNotFound + } + if p.searcher == nil || p.chatAccessor == nil { + return mcpgw.BuildToolErrorResult("memory service not available"), nil + } + + query := mcpgw.StringArg(arguments, "query") + if query == "" { + return mcpgw.BuildToolErrorResult("query is required"), nil + } + botID := strings.TrimSpace(session.BotID) + chatID := strings.TrimSpace(session.ChatID) + channelIdentityID := strings.TrimSpace(session.ChannelIdentityID) + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + if chatID == "" { + chatID = botID + } + + limit := defaultMemoryToolLimit + if value, ok, err := mcpgw.IntArg(arguments, "limit"); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } else if ok { + limit = value + } + if limit <= 0 { + limit = defaultMemoryToolLimit + } + if limit > maxMemoryToolLimit { + limit = maxMemoryToolLimit + } + + // When ChatID equals BotID (e.g. tools called without conversation context), search by bot scope only. + // Otherwise require the conversation to exist and the caller to be a participant. + if chatID != botID { + chatObj, err := p.chatAccessor.Get(ctx, chatID) + if err != nil { + return mcpgw.BuildToolErrorResult("chat not found"), nil + } + if strings.TrimSpace(chatObj.BotID) != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + if channelIdentityID != "" { + allowed, err := p.canAccessChat(ctx, chatID, channelIdentityID) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if !allowed { + return mcpgw.BuildToolErrorResult("not a chat participant"), nil + } + } + } + + resp, err := p.searcher.Search(ctx, mem.SearchRequest{ + Query: query, + BotID: botID, + Limit: limit, + Filters: map[string]any{ + "namespace": sharedMemoryNamespace, + "scopeId": botID, + "botId": botID, + }, + }) + if err != nil { + p.logger.Warn("memory search namespace failed", slog.String("namespace", sharedMemoryNamespace), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult("memory search failed"), nil + } + allResults := make([]mem.MemoryItem, 0, len(resp.Results)) + allResults = append(allResults, resp.Results...) + + allResults = deduplicateMemoryItems(allResults) + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].Score > allResults[j].Score + }) + if len(allResults) > limit { + allResults = allResults[:limit] + } + + results := make([]map[string]any, 0, len(allResults)) + for _, item := range allResults { + results = append(results, map[string]any{ + "id": item.ID, + "memory": item.Memory, + "score": item.Score, + }) + } + + return mcpgw.BuildToolSuccessResult(map[string]any{ + "query": query, + "total": len(results), + "results": results, + }), nil +} + +func (p *Executor) canAccessChat(ctx context.Context, chatID, channelIdentityID string) (bool, error) { + if p.adminChecker != nil { + isAdmin, err := p.adminChecker.IsAdmin(ctx, channelIdentityID) + if err != nil { + return false, err + } + if isAdmin { + return true, nil + } + } + return p.chatAccessor.IsParticipant(ctx, chatID, channelIdentityID) +} + +func deduplicateMemoryItems(items []mem.MemoryItem) []mem.MemoryItem { + if len(items) == 0 { + return items + } + seen := make(map[string]struct{}, len(items)) + result := make([]mem.MemoryItem, 0, len(items)) + for _, item := range items { + id := strings.TrimSpace(item.ID) + if id == "" { + id = strings.TrimSpace(item.Memory) + } + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + result = append(result, item) + } + return result +} diff --git a/internal/mcp/providers/memory/provider_test.go b/internal/mcp/providers/memory/provider_test.go new file mode 100644 index 00000000..edc0d22e --- /dev/null +++ b/internal/mcp/providers/memory/provider_test.go @@ -0,0 +1,284 @@ +package memory + +import ( + "context" + "errors" + "testing" + + "github.com/memohai/memoh/internal/conversation" + mcpgw "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/memory" +) + +type fakeSearcher struct { + resp memory.SearchResponse + err error +} + +func (f *fakeSearcher) Search(ctx context.Context, req memory.SearchRequest) (memory.SearchResponse, error) { + if f.err != nil { + return memory.SearchResponse{}, f.err + } + return f.resp, nil +} + +type fakeChatAccessor struct { + chat conversation.Chat + getErr error + participant bool + participantErr error +} + +func (f *fakeChatAccessor) Get(ctx context.Context, conversationID string) (conversation.Chat, error) { + if f.getErr != nil { + return conversation.Chat{}, f.getErr + } + return f.chat, nil +} + +func (f *fakeChatAccessor) IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) { + if f.participantErr != nil { + return false, f.participantErr + } + return f.participant, nil +} + +func (f *fakeChatAccessor) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (conversation.ChatReadAccess, error) { + return conversation.ChatReadAccess{}, nil +} + +type fakeAdminChecker struct { + admin bool + err error +} + +func (f *fakeAdminChecker) IsAdmin(ctx context.Context, channelIdentityID string) (bool, error) { + if f.err != nil { + return false, f.err + } + return f.admin, nil +} + +func TestExecutor_ListTools_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when deps nil, got %d", len(tools)) + } +} + +func TestExecutor_ListTools(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != toolSearchMemory { + t.Errorf("tool name = %q, want %q", tools[0].Name, toolSearchMemory) + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +func TestExecutor_CallTool_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolSearchMemory, map[string]any{"query": "x"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error result when deps nil") + } +} + +func TestExecutor_CallTool_NoQuery(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolSearchMemory, map[string]any{}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when query is empty") + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestExecutor_CallTool_Success_BotScope(t *testing.T) { + searcher := &fakeSearcher{ + resp: memory.SearchResponse{ + Results: []memory.MemoryItem{ + {ID: "id1", Memory: "mem1", Score: 0.9}, + }, + }, + } + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + ctx := context.Background() + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "bot1"} + result, err := exec.CallTool(ctx, session, toolSearchMemory, map[string]any{"query": "test"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if content["query"] != "test" { + t.Errorf("query = %v", content["query"]) + } + if content["total"] != 1 { + t.Errorf("total = %v", content["total"]) + } +} + +func TestExecutor_CallTool_ChatNotFound(t *testing.T) { + searcher := &fakeSearcher{} + accessor := &fakeChatAccessor{getErr: errors.New("not found")} + exec := NewExecutor(nil, searcher, accessor, nil) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "chat-other"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when chat not found") + } +} + +func TestExecutor_CallTool_BotMismatch(t *testing.T) { + accessor := &fakeChatAccessor{ + chat: conversation.Chat{BotID: "other-bot", ID: "c1"}, + } + searcher := &fakeSearcher{} + exec := NewExecutor(nil, searcher, accessor, nil) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "c1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot mismatch") + } +} + +func TestExecutor_CallTool_NotParticipant(t *testing.T) { + accessor := &fakeChatAccessor{ + chat: conversation.Chat{BotID: "bot1", ID: "c1"}, + participant: false, + } + searcher := &fakeSearcher{} + exec := NewExecutor(nil, searcher, accessor, &fakeAdminChecker{admin: false}) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "c1", ChannelIdentityID: "user1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when not participant") + } +} + +func TestExecutor_CallTool_AdminBypass(t *testing.T) { + searcher := &fakeSearcher{ + resp: memory.SearchResponse{Results: []memory.MemoryItem{{ID: "id1", Memory: "m1", Score: 0.8}}}, + } + accessor := &fakeChatAccessor{ + chat: conversation.Chat{BotID: "bot1", ID: "c1"}, + participant: false, + } + admin := &fakeAdminChecker{admin: true} + exec := NewExecutor(nil, searcher, accessor, admin) + session := mcpgw.ToolSessionContext{BotID: "bot1", ChatID: "c1", ChannelIdentityID: "admin1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if v, ok := content["total"].(int); !ok || v != 1 { + t.Errorf("total = %v", content["total"]) + } +} + +func TestExecutor_CallTool_SearchError(t *testing.T) { + searcher := &fakeSearcher{err: errors.New("search failed")} + accessor := &fakeChatAccessor{} + exec := NewExecutor(nil, searcher, accessor, nil) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSearchMemory, map[string]any{"query": "q"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when search fails") + } +} + +func TestDeduplicateMemoryItems(t *testing.T) { + tests := []struct { + name string + items []memory.MemoryItem + wantLen int + }{ + {"empty", nil, 0}, + {"single", []memory.MemoryItem{{ID: "a", Memory: "m", Score: 1}}, 1}, + {"dedup by id", []memory.MemoryItem{ + {ID: "a", Memory: "m1", Score: 1}, + {ID: "a", Memory: "m2", Score: 0.9}, + }, 1}, + {"dedup by memory when id empty", []memory.MemoryItem{ + {ID: "", Memory: "same", Score: 1}, + {ID: "", Memory: "same", Score: 0.9}, + }, 1}, + {"no dedup", []memory.MemoryItem{ + {ID: "a", Memory: "m1", Score: 1}, + {ID: "b", Memory: "m2", Score: 0.9}, + }, 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := deduplicateMemoryItems(tt.items) + if len(got) != tt.wantLen { + t.Errorf("deduplicateMemoryItems() length = %d, want %d", len(got), tt.wantLen) + } + }) + } +} diff --git a/internal/mcp/providers/message/provider.go b/internal/mcp/providers/message/provider.go new file mode 100644 index 00000000..81911da2 --- /dev/null +++ b/internal/mcp/providers/message/provider.go @@ -0,0 +1,179 @@ +package message + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const toolSendMessage = "send_message" + +type Sender interface { + Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error +} + +type ChannelTypeResolver interface { + ParseChannelType(raw string) (channel.ChannelType, error) +} + +type Executor struct { + sender Sender + resolver ChannelTypeResolver + logger *slog.Logger +} + +func NewExecutor(log *slog.Logger, sender Sender, resolver ChannelTypeResolver) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + sender: sender, + resolver: resolver, + logger: log.With(slog.String("provider", "message_tool")), + } +} + +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.sender == nil || p.resolver == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolSendMessage, + Description: "Send a message to a channel or session", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "bot_id": map[string]any{ + "type": "string", + "description": "Bot ID, optional and defaults to current bot", + }, + "platform": map[string]any{ + "type": "string", + "description": "Channel platform name", + }, + "target": map[string]any{ + "type": "string", + "description": "Channel target (chat/group/thread ID)", + }, + "channel_identity_id": map[string]any{ + "type": "string", + "description": "Target identity ID when direct target is absent", + }, + "to_user_id": map[string]any{ + "type": "string", + "description": "Alias for channel_identity_id", + }, + "text": map[string]any{ + "type": "string", + "description": "Message text shortcut when message object is omitted", + }, + "message": map[string]any{ + "type": "object", + "description": "Structured message payload with text/parts/attachments", + }, + }, + "required": []string{}, + }, + }, + }, nil +} + +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if toolName != toolSendMessage { + return nil, mcpgw.ErrToolNotFound + } + if p.sender == nil || p.resolver == nil { + return mcpgw.BuildToolErrorResult("message service not available"), nil + } + + botID := mcpgw.FirstStringArg(arguments, "bot_id") + if botID == "" { + botID = strings.TrimSpace(session.BotID) + } + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + if strings.TrimSpace(session.BotID) != "" && botID != strings.TrimSpace(session.BotID) { + return mcpgw.BuildToolErrorResult("bot_id mismatch"), nil + } + + platform := mcpgw.FirstStringArg(arguments, "platform") + if platform == "" { + platform = strings.TrimSpace(session.CurrentPlatform) + } + if platform == "" { + return mcpgw.BuildToolErrorResult("platform is required"), nil + } + channelType, err := p.resolver.ParseChannelType(platform) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + messageText := mcpgw.FirstStringArg(arguments, "text") + outboundMessage, parseErr := parseOutboundMessage(arguments, messageText) + if parseErr != nil { + return mcpgw.BuildToolErrorResult(parseErr.Error()), nil + } + + target := mcpgw.FirstStringArg(arguments, "target") + if target == "" { + target = strings.TrimSpace(session.ReplyTarget) + } + channelIdentityID := mcpgw.FirstStringArg(arguments, "channel_identity_id", "to_user_id") + if target == "" && channelIdentityID == "" { + return mcpgw.BuildToolErrorResult("target or channel_identity_id is required"), nil + } + + sendReq := channel.SendRequest{ + Target: target, + ChannelIdentityID: channelIdentityID, + Message: outboundMessage, + } + if err := p.sender.Send(ctx, botID, channelType, sendReq); err != nil { + p.logger.Warn("send message failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("platform", platform)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + payload := map[string]any{ + "ok": true, + "bot_id": botID, + "platform": channelType.String(), + "target": target, + "channel_identity_id": channelIdentityID, + "instruction": "Message delivered successfully. You have completed your response. Please STOP now and do not call any more tools.", + } + return mcpgw.BuildToolSuccessResult(payload), nil +} + +func parseOutboundMessage(arguments map[string]any, fallbackText string) (channel.Message, error) { + var msg channel.Message + if raw, ok := arguments["message"]; ok && raw != nil { + switch value := raw.(type) { + case string: + msg.Text = strings.TrimSpace(value) + case map[string]any: + data, err := json.Marshal(value) + if err != nil { + return channel.Message{}, err + } + if err := json.Unmarshal(data, &msg); err != nil { + return channel.Message{}, err + } + default: + return channel.Message{}, fmt.Errorf("message must be object or string") + } + } + if msg.IsEmpty() && strings.TrimSpace(fallbackText) != "" { + msg.Text = strings.TrimSpace(fallbackText) + } + if msg.IsEmpty() { + return channel.Message{}, fmt.Errorf("message is required") + } + return msg, nil +} diff --git a/internal/mcp/providers/message/provider_test.go b/internal/mcp/providers/message/provider_test.go new file mode 100644 index 00000000..df6b9fa3 --- /dev/null +++ b/internal/mcp/providers/message/provider_test.go @@ -0,0 +1,247 @@ +package message + +import ( + "context" + "errors" + "testing" + + "github.com/memohai/memoh/internal/channel" + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +type fakeSender struct { + err error +} + +func (f *fakeSender) Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error { + return f.err +} + +type fakeResolver struct { + ct channel.ChannelType + err error +} + +func (f *fakeResolver) ParseChannelType(raw string) (channel.ChannelType, error) { + if f.err != nil { + return "", f.err + } + return f.ct, nil +} + +func TestExecutor_ListTools_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when deps nil, got %d", len(tools)) + } +} + +func TestExecutor_ListTools(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != toolSendMessage { + t.Errorf("tool name = %q, want %q", tools[0].Name, toolSendMessage) + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +func TestExecutor_CallTool_NilDeps(t *testing.T) { + exec := NewExecutor(nil, nil, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolSendMessage, map[string]any{ + "platform": "feishu", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error result when deps nil") + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, toolSendMessage, map[string]any{ + "platform": "feishu", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestExecutor_CallTool_BotIDMismatch(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "bot_id": "other", "platform": "feishu", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id mismatch") + } +} + +func TestExecutor_CallTool_NoPlatform(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when platform is missing") + } +} + +func TestExecutor_CallTool_PlatformParseError(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{err: errors.New("unknown platform")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "feishu"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "bad", "target": "t1", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when platform parse fails") + } +} + +func TestExecutor_CallTool_NoMessage(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "feishu", "target": "t1", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when message/text is missing") + } +} + +func TestExecutor_CallTool_NoTarget(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "feishu", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when target and channel_identity_id are missing") + } +} + +func TestExecutor_CallTool_SendError(t *testing.T) { + sender := &fakeSender{err: errors.New("send failed")} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1", ReplyTarget: "t1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "platform": "feishu", "text": "hi", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when Send fails") + } +} + +func TestExecutor_CallTool_Success(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + exec := NewExecutor(nil, sender, resolver) + session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "feishu", ReplyTarget: "chat1"} + result, err := exec.CallTool(context.Background(), session, toolSendMessage, map[string]any{ + "text": "hello", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if content["ok"] != true { + t.Errorf("ok = %v", content["ok"]) + } + if content["platform"] != "feishu" { + t.Errorf("platform = %v", content["platform"]) + } +} + +func TestParseOutboundMessage(t *testing.T) { + tests := []struct { + name string + args map[string]any + fallback string + wantEmpty bool + wantErr bool + }{ + {"text fallback", map[string]any{}, "hello", false, false}, + {"message string", map[string]any{"message": "msg"}, "", false, false}, + {"message object", map[string]any{"message": map[string]any{"text": "obj"}}, "", false, false}, + {"empty", map[string]any{}, "", true, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg, err := parseOutboundMessage(tt.args, tt.fallback) + if (err != nil) != tt.wantErr { + t.Errorf("parseOutboundMessage() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantEmpty && !msg.IsEmpty() { + t.Error("expected empty message") + } + if !tt.wantEmpty && msg.IsEmpty() { + t.Error("expected non-empty message") + } + }) + } +} diff --git a/internal/mcp/providers/schedule/provider.go b/internal/mcp/providers/schedule/provider.go new file mode 100644 index 00000000..71e37b3e --- /dev/null +++ b/internal/mcp/providers/schedule/provider.go @@ -0,0 +1,259 @@ +package schedule + +import ( + "context" + "log/slog" + "strings" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sched "github.com/memohai/memoh/internal/schedule" +) + +const ( + toolScheduleList = "schedule_list" + toolScheduleGet = "schedule_get" + toolScheduleCreate = "schedule_create" + toolScheduleUpdate = "schedule_update" + toolScheduleDelete = "schedule_delete" +) + +type Scheduler interface { + List(ctx context.Context, botID string) ([]sched.Schedule, error) + Get(ctx context.Context, id string) (sched.Schedule, error) + Create(ctx context.Context, botID string, req sched.CreateRequest) (sched.Schedule, error) + Update(ctx context.Context, id string, req sched.UpdateRequest) (sched.Schedule, error) + Delete(ctx context.Context, id string) error +} + +type Executor struct { + service Scheduler + logger *slog.Logger +} + +func NewExecutor(log *slog.Logger, service Scheduler) *Executor { + if log == nil { + log = slog.Default() + } + return &Executor{ + service: service, + logger: log.With(slog.String("provider", "schedule_tool")), + } +} + +func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + if p.service == nil { + return []mcpgw.ToolDescriptor{}, nil + } + return []mcpgw.ToolDescriptor{ + { + Name: toolScheduleList, + Description: "List schedules for current bot", + InputSchema: emptyObjectSchema(), + }, + { + Name: toolScheduleGet, + Description: "Get a schedule by id", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string", "description": "Schedule ID"}, + }, + "required": []string{"id"}, + }, + }, + { + Name: toolScheduleCreate, + Description: "Create a new schedule", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "description": map[string]any{"type": "string"}, + "pattern": map[string]any{"type": "string"}, + "max_calls": map[string]any{ + "type": []string{"integer", "null"}, + "description": "Optional max calls, null means unlimited", + }, + "enabled": map[string]any{"type": "boolean"}, + "command": map[string]any{"type": "string"}, + }, + "required": []string{"name", "description", "pattern", "command"}, + }, + }, + { + Name: toolScheduleUpdate, + Description: "Update an existing schedule", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string"}, + "name": map[string]any{"type": "string"}, + "description": map[string]any{"type": "string"}, + "pattern": map[string]any{"type": "string"}, + "max_calls": map[string]any{"type": []string{"integer", "null"}}, + "enabled": map[string]any{"type": "boolean"}, + "command": map[string]any{"type": "string"}, + }, + "required": []string{"id"}, + }, + }, + { + Name: toolScheduleDelete, + Description: "Delete a schedule by id", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string", "description": "Schedule ID"}, + }, + "required": []string{"id"}, + }, + }, + }, nil +} + +func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if p.service == nil { + return mcpgw.BuildToolErrorResult("schedule service not available"), nil + } + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + + switch toolName { + case toolScheduleList: + items, err := p.service.List(ctx, botID) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{ + "items": items, + }), nil + case toolScheduleGet: + id := mcpgw.StringArg(arguments, "id") + if id == "" { + return mcpgw.BuildToolErrorResult("id is required"), nil + } + item, err := p.service.Get(ctx, id) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if item.BotID != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + return mcpgw.BuildToolSuccessResult(item), nil + case toolScheduleCreate: + name := mcpgw.StringArg(arguments, "name") + description := mcpgw.StringArg(arguments, "description") + pattern := mcpgw.StringArg(arguments, "pattern") + command := mcpgw.StringArg(arguments, "command") + if name == "" || description == "" || pattern == "" || command == "" { + return mcpgw.BuildToolErrorResult("name, description, pattern, command are required"), nil + } + + req := sched.CreateRequest{ + Name: name, + Description: description, + Pattern: pattern, + Command: command, + } + maxCalls, err := parseNullableIntArg(arguments, "max_calls") + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + req.MaxCalls = maxCalls + if enabled, ok, err := mcpgw.BoolArg(arguments, "enabled"); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } else if ok { + req.Enabled = &enabled + } + item, err := p.service.Create(ctx, botID, req) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(item), nil + case toolScheduleUpdate: + id := mcpgw.StringArg(arguments, "id") + if id == "" { + return mcpgw.BuildToolErrorResult("id is required"), nil + } + req := sched.UpdateRequest{} + maxCalls, err := parseNullableIntArg(arguments, "max_calls") + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + req.MaxCalls = maxCalls + if value := mcpgw.StringArg(arguments, "name"); value != "" { + req.Name = &value + } + if value := mcpgw.StringArg(arguments, "description"); value != "" { + req.Description = &value + } + if value := mcpgw.StringArg(arguments, "pattern"); value != "" { + req.Pattern = &value + } + if value := mcpgw.StringArg(arguments, "command"); value != "" { + req.Command = &value + } + if enabled, ok, err := mcpgw.BoolArg(arguments, "enabled"); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } else if ok { + req.Enabled = &enabled + } + item, err := p.service.Update(ctx, id, req) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if item.BotID != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + return mcpgw.BuildToolSuccessResult(item), nil + case toolScheduleDelete: + id := mcpgw.StringArg(arguments, "id") + if id == "" { + return mcpgw.BuildToolErrorResult("id is required"), nil + } + item, err := p.service.Get(ctx, id) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if item.BotID != botID { + return mcpgw.BuildToolErrorResult("bot mismatch"), nil + } + if err := p.service.Delete(ctx, id); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"success": true}), nil + default: + return nil, mcpgw.ErrToolNotFound + } +} + +func parseNullableIntArg(arguments map[string]any, key string) (sched.NullableInt, error) { + req := sched.NullableInt{} + if arguments == nil { + return req, nil + } + raw, exists := arguments[key] + if !exists { + return req, nil + } + req.Set = true + if raw == nil { + req.Value = nil + return req, nil + } + value, _, err := mcpgw.IntArg(arguments, key) + if err != nil { + return sched.NullableInt{}, err + } + req.Value = &value + return req, nil +} + +func emptyObjectSchema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} diff --git a/internal/mcp/providers/schedule/provider_test.go b/internal/mcp/providers/schedule/provider_test.go new file mode 100644 index 00000000..43d7d544 --- /dev/null +++ b/internal/mcp/providers/schedule/provider_test.go @@ -0,0 +1,374 @@ +package schedule + +import ( + "context" + "errors" + "testing" + "time" + + mcpgw "github.com/memohai/memoh/internal/mcp" + sched "github.com/memohai/memoh/internal/schedule" +) + +type fakeScheduler struct { + list []sched.Schedule + get sched.Schedule + getErr error + create sched.Schedule + createErr error + update sched.Schedule + updateErr error + deleteErr error +} + +func (f *fakeScheduler) List(ctx context.Context, botID string) ([]sched.Schedule, error) { + return f.list, nil +} + +func (f *fakeScheduler) Get(ctx context.Context, id string) (sched.Schedule, error) { + if f.getErr != nil { + return sched.Schedule{}, f.getErr + } + return f.get, nil +} + +func (f *fakeScheduler) Create(ctx context.Context, botID string, req sched.CreateRequest) (sched.Schedule, error) { + if f.createErr != nil { + return sched.Schedule{}, f.createErr + } + return f.create, nil +} + +func (f *fakeScheduler) Update(ctx context.Context, id string, req sched.UpdateRequest) (sched.Schedule, error) { + if f.updateErr != nil { + return sched.Schedule{}, f.updateErr + } + return f.update, nil +} + +func (f *fakeScheduler) Delete(ctx context.Context, id string) error { + return f.deleteErr +} + +func TestExecutor_ListTools_NilService(t *testing.T) { + exec := NewExecutor(nil, nil) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Errorf("expected 0 tools when service nil, got %d", len(tools)) + } +} + +func TestExecutor_ListTools(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) + if err != nil { + t.Fatal(err) + } + wantNames := []string{toolScheduleList, toolScheduleGet, toolScheduleCreate, toolScheduleUpdate, toolScheduleDelete} + if len(tools) != len(wantNames) { + t.Fatalf("expected %d tools, got %d", len(wantNames), len(tools)) + } + for i, name := range wantNames { + if tools[i].Name != name { + t.Errorf("tools[%d].Name = %q, want %q", i, tools[i].Name, name) + } + } +} + +func TestExecutor_CallTool_NotFound(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) + if err != mcpgw.ErrToolNotFound { + t.Errorf("expected ErrToolNotFound, got %v", err) + } +} + +func TestExecutor_CallTool_NilService(t *testing.T) { + exec := NewExecutor(nil, nil) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, toolScheduleList, nil) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when service nil") + } +} + +func TestExecutor_CallTool_NoBotID(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, toolScheduleList, nil) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot_id is missing") + } +} + +func TestExecutor_CallTool_List(t *testing.T) { + svc := &fakeScheduler{ + list: []sched.Schedule{ + {ID: "id1", Name: "n1", BotID: "bot1"}, + }, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleList, nil) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + items, _ := content["items"].([]sched.Schedule) + if len(items) != 1 { + t.Errorf("items length = %d", len(items)) + } +} + +func TestExecutor_CallTool_Get_IdRequired(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when id is missing") + } +} + +func TestExecutor_CallTool_Get_BotMismatch(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", BotID: "other-bot"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot mismatch") + } +} + +func TestExecutor_CallTool_Get_Success(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", Name: "job1", BotID: "bot1"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + item, ok := result["structuredContent"].(sched.Schedule) + if !ok { + t.Fatal("structuredContent is not Schedule") + } + if item.ID != "s1" { + t.Errorf("id = %v", item.ID) + } +} + +func TestExecutor_CallTool_Create_RequiredFields(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleCreate, map[string]any{ + "name": "n", "description": "d", "pattern": "* * * * *", + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when command is missing") + } +} + +func TestExecutor_CallTool_Create_Success(t *testing.T) { + svc := &fakeScheduler{ + create: sched.Schedule{ + ID: "new1", Name: "n1", Description: "d1", Pattern: "* * * * *", Command: "echo", + BotID: "bot1", CreatedAt: time.Now(), UpdatedAt: time.Now(), + }, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleCreate, map[string]any{ + "name": "n1", "description": "d1", "pattern": "* * * * *", "command": "echo", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + item, ok := result["structuredContent"].(sched.Schedule) + if !ok { + t.Fatal("structuredContent is not Schedule") + } + if item.ID != "new1" { + t.Errorf("id = %v", item.ID) + } +} + +func TestExecutor_CallTool_Update_IdRequired(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleUpdate, map[string]any{"name": "n"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when id is missing") + } +} + +func TestExecutor_CallTool_Update_Success(t *testing.T) { + svc := &fakeScheduler{ + update: sched.Schedule{ID: "s1", Name: "updated", BotID: "bot1"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleUpdate, map[string]any{ + "id": "s1", "name": "updated", + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } +} + +func TestExecutor_CallTool_Delete_IdRequired(t *testing.T) { + svc := &fakeScheduler{} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleDelete, map[string]any{}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when id is missing") + } +} + +func TestExecutor_CallTool_Delete_BotMismatch(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", BotID: "other-bot"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleDelete, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when bot mismatch on delete") + } +} + +func TestExecutor_CallTool_Delete_Success(t *testing.T) { + svc := &fakeScheduler{ + get: sched.Schedule{ID: "s1", BotID: "bot1"}, + } + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleDelete, map[string]any{"id": "s1"}) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + content, _ := result["structuredContent"].(map[string]any) + if content == nil { + t.Fatal("no structuredContent") + } + if success, _ := content["success"].(bool); !success { + t.Errorf("success = %v", content["success"]) + } +} + +func TestExecutor_CallTool_Get_ServiceError(t *testing.T) { + svc := &fakeScheduler{getErr: errors.New("not found")} + exec := NewExecutor(nil, svc) + session := mcpgw.ToolSessionContext{BotID: "bot1"} + result, err := exec.CallTool(context.Background(), session, toolScheduleGet, map[string]any{"id": "missing"}) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Error("expected error when Get fails") + } +} + +func TestParseNullableIntArg(t *testing.T) { + tests := []struct { + name string + args map[string]any + key string + wantSet bool + wantVal *int + wantErr bool + }{ + {"nil args", nil, "x", false, nil, false}, + {"missing key", map[string]any{}, "x", false, nil, false}, + {"null value", map[string]any{"x": nil}, "x", true, nil, false}, + {"int value", map[string]any{"x": 5}, "x", true, intPtr(5), false}, + {"invalid type", map[string]any{"x": "bad"}, "x", false, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseNullableIntArg(tt.args, tt.key) + if (err != nil) != tt.wantErr { + t.Errorf("parseNullableIntArg() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got.Set != tt.wantSet { + t.Errorf("Set = %v, want %v", got.Set, tt.wantSet) + } + if tt.wantVal == nil { + if got.Value != nil { + t.Errorf("Value = %v, want nil", got.Value) + } + } else { + if got.Value == nil || *got.Value != *tt.wantVal { + t.Errorf("Value = %v, want %v", got.Value, tt.wantVal) + } + } + }) + } +} + +func TestEmptyObjectSchema(t *testing.T) { + m := emptyObjectSchema() + if m["type"] != "object" { + t.Errorf("type = %v", m["type"]) + } + if m["properties"] == nil { + t.Error("properties should be non-nil") + } +} + +func intPtr(n int) *int { + return &n +} diff --git a/internal/mcp/sources/federation/source.go b/internal/mcp/sources/federation/source.go new file mode 100644 index 00000000..e7e34059 --- /dev/null +++ b/internal/mcp/sources/federation/source.go @@ -0,0 +1,276 @@ +package federation + +import ( + "context" + "fmt" + "log/slog" + "sort" + "strconv" + "strings" + "sync" + "time" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +const cacheTTL = 5 * time.Second + +type ConnectionLister interface { + ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) +} + +type Gateway interface { + ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) + CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) + + ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) + CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) + + ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) + CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) +} + +type toolRoute struct { + sourceType string + originalName string + connection mcpgw.Connection +} + +type cacheEntry struct { + expiresAt time.Time + routes map[string]toolRoute + tools []mcpgw.ToolDescriptor +} + +type Source struct { + logger *slog.Logger + gateway Gateway + connections ConnectionLister + + mu sync.Mutex + cache map[string]cacheEntry +} + +func NewSource(log *slog.Logger, gateway Gateway, connections ConnectionLister) *Source { + if log == nil { + log = slog.Default() + } + return &Source{ + logger: log.With(slog.String("source", "federated_mcp_tool")), + gateway: gateway, + connections: connections, + cache: map[string]cacheEntry{}, + } +} + +func (s *Source) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { + botID := strings.TrimSpace(session.BotID) + if botID == "" || s.gateway == nil { + return []mcpgw.ToolDescriptor{}, nil + } + if cached, ok := s.getCache(botID); ok { + return cloneTools(cached.tools), nil + } + tools, routes := s.buildToolsAndRoutes(ctx, botID) + s.setCache(botID, cacheEntry{ + expiresAt: time.Now().Add(cacheTTL), + routes: routes, + tools: tools, + }) + return cloneTools(tools), nil +} + +func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if s.gateway == nil { + return mcpgw.BuildToolErrorResult("federation gateway not available"), nil + } + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return mcpgw.BuildToolErrorResult("bot_id is required"), nil + } + + route, ok := s.getRoute(botID, toolName) + if !ok { + _, _ = s.ListTools(ctx, session) + route, ok = s.getRoute(botID, toolName) + if !ok { + return nil, mcpgw.ErrToolNotFound + } + } + if arguments == nil { + arguments = map[string]any{} + } + + var ( + payload map[string]any + err error + ) + switch route.sourceType { + case "http": + payload, err = s.gateway.CallHTTPConnectionTool(ctx, route.connection, route.originalName, arguments) + case "sse": + payload, err = s.gateway.CallSSEConnectionTool(ctx, route.connection, route.originalName, arguments) + case "stdio": + payload, err = s.gateway.CallStdioConnectionTool(ctx, botID, route.connection, route.originalName, arguments) + default: + return mcpgw.BuildToolErrorResult("unsupported federated source"), nil + } + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if err := mcpgw.PayloadError(payload); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if result, ok := payload["result"].(map[string]any); ok { + return result, nil + } + return mcpgw.BuildToolSuccessResult(payload), nil +} + +func (s *Source) buildToolsAndRoutes(ctx context.Context, botID string) ([]mcpgw.ToolDescriptor, map[string]toolRoute) { + routes := map[string]toolRoute{} + tools := make([]mcpgw.ToolDescriptor, 0, 16) + + addTool := func(descriptor mcpgw.ToolDescriptor, route toolRoute) { + name := strings.TrimSpace(descriptor.Name) + if name == "" { + return + } + finalName := name + if _, exists := routes[finalName]; exists { + seed := strings.ReplaceAll(finalName, ".", "_") + if seed == "" { + seed = "tool" + } + for i := 2; ; i++ { + candidate := seed + "_" + strconv.Itoa(i) + if _, ok := routes[candidate]; ok { + continue + } + finalName = candidate + break + } + } + descriptor.Name = finalName + routes[finalName] = route + tools = append(tools, descriptor) + } + + if s.connections != nil { + items, err := s.connections.ListActiveByBot(ctx, botID) + if err != nil { + s.logger.Warn("list mcp connections failed", slog.String("bot_id", botID), slog.Any("error", err)) + } else { + sort.Slice(items, func(i, j int) bool { + if items[i].Name == items[j].Name { + return items[i].ID < items[j].ID + } + return items[i].Name < items[j].Name + }) + for _, connection := range items { + var connTools []mcpgw.ToolDescriptor + switch strings.ToLower(strings.TrimSpace(connection.Type)) { + case "http": + connTools, err = s.gateway.ListHTTPConnectionTools(ctx, connection) + case "sse": + connTools, err = s.gateway.ListSSEConnectionTools(ctx, connection) + case "stdio": + connTools, err = s.gateway.ListStdioConnectionTools(ctx, botID, connection) + default: + s.logger.Warn("unsupported mcp connection type", slog.String("connection_id", connection.ID), slog.String("type", connection.Type)) + continue + } + if err != nil { + s.logger.Warn("list tools from connection failed", slog.String("connection_id", connection.ID), slog.String("name", connection.Name), slog.Any("error", err)) + continue + } + prefix := sanitizePrefix(connection.Name) + for _, tool := range connTools { + origin := strings.TrimSpace(tool.Name) + alias := origin + if prefix != "" { + alias = prefix + "." + origin + } + tool.Name = alias + if strings.TrimSpace(tool.Description) != "" { + tool.Description = "[" + strings.TrimSpace(connection.Name) + "] " + tool.Description + } else { + tool.Description = "[" + strings.TrimSpace(connection.Name) + "] " + origin + } + addTool(tool, toolRoute{ + sourceType: strings.ToLower(strings.TrimSpace(connection.Type)), + originalName: origin, + connection: connection, + }) + } + } + } + } + return tools, routes +} + +func sanitizePrefix(raw string) string { + raw = strings.TrimSpace(strings.ToLower(raw)) + if raw == "" { + return "mcp" + } + builder := strings.Builder{} + for _, ch := range raw { + if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' { + builder.WriteRune(ch) + continue + } + builder.WriteRune('_') + } + normalized := strings.Trim(builder.String(), "._-") + if normalized == "" { + return "mcp" + } + return normalized +} + +func cloneTools(items []mcpgw.ToolDescriptor) []mcpgw.ToolDescriptor { + if len(items) == 0 { + return []mcpgw.ToolDescriptor{} + } + out := make([]mcpgw.ToolDescriptor, 0, len(items)) + for _, item := range items { + out = append(out, mcpgw.ToolDescriptor{ + Name: item.Name, + Description: item.Description, + InputSchema: item.InputSchema, + }) + } + return out +} + +func (s *Source) getCache(botID string) (cacheEntry, bool) { + s.mu.Lock() + defer s.mu.Unlock() + cached, ok := s.cache[botID] + if !ok || time.Now().After(cached.expiresAt) { + return cacheEntry{}, false + } + return cached, true +} + +func (s *Source) setCache(botID string, entry cacheEntry) { + s.mu.Lock() + s.cache[botID] = entry + s.mu.Unlock() +} + +func (s *Source) getRoute(botID, toolName string) (toolRoute, bool) { + s.mu.Lock() + defer s.mu.Unlock() + cached, ok := s.cache[botID] + if !ok || time.Now().After(cached.expiresAt) { + return toolRoute{}, false + } + route, exists := cached.routes[strings.TrimSpace(toolName)] + return route, exists +} + +func (s *Source) String() string { + return fmt.Sprintf("FederationSource(%p)", s) +} diff --git a/internal/mcp/sources/federation/source_test.go b/internal/mcp/sources/federation/source_test.go new file mode 100644 index 00000000..b591ef44 --- /dev/null +++ b/internal/mcp/sources/federation/source_test.go @@ -0,0 +1,126 @@ +package federation + +import ( + "context" + "log/slog" + "testing" + + mcpgw "github.com/memohai/memoh/internal/mcp" +) + +type testConnectionLister struct { + items []mcpgw.Connection + err error +} + +func (l *testConnectionLister) ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) { + if l.err != nil { + return nil, l.err + } + return l.items, nil +} + +type testGateway struct { + listHTTP []mcpgw.ToolDescriptor + listSSE []mcpgw.ToolDescriptor + listStdio []mcpgw.ToolDescriptor + + lastCallType string +} + +func (g *testGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + return g.listHTTP, nil +} + +func (g *testGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + g.lastCallType = "http" + return map[string]any{"result": map[string]any{"ok": true, "route": "http"}}, nil +} + +func (g *testGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + return g.listSSE, nil +} + +func (g *testGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + g.lastCallType = "sse" + return map[string]any{"result": map[string]any{"ok": true, "route": "sse"}}, nil +} + +func (g *testGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { + return g.listStdio, nil +} + +func (g *testGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { + g.lastCallType = "stdio" + return map[string]any{"result": map[string]any{"ok": true, "route": "stdio"}}, nil +} + +func TestSourceListToolsIncludesSSETools(t *testing.T) { + gateway := &testGateway{ + listSSE: []mcpgw.ToolDescriptor{ + { + Name: "search", + Description: "search remote data", + InputSchema: map[string]any{"type": "object"}, + }, + }, + } + lister := &testConnectionLister{ + items: []mcpgw.Connection{ + { + ID: "conn-1", + Name: "Remote SSE", + Type: "sse", + Active: true, + Config: map[string]any{"url": "http://example.com/sse"}, + }, + }, + } + + source := NewSource(slog.Default(), gateway, lister) + tools, err := source.ListTools(context.Background(), mcpgw.ToolSessionContext{BotID: "bot-1"}) + if err != nil { + t.Fatalf("list tools failed: %v", err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Name != "remote_sse.search" { + t.Fatalf("unexpected tool alias: %s", tools[0].Name) + } +} + +func TestSourceCallToolRoutesToSSEConnection(t *testing.T) { + gateway := &testGateway{ + listSSE: []mcpgw.ToolDescriptor{ + { + Name: "search", + Description: "search remote data", + InputSchema: map[string]any{"type": "object"}, + }, + }, + } + lister := &testConnectionLister{ + items: []mcpgw.Connection{ + { + ID: "conn-1", + Name: "Remote SSE", + Type: "sse", + Active: true, + Config: map[string]any{"url": "http://example.com/sse"}, + }, + }, + } + source := NewSource(slog.Default(), gateway, lister) + + result, err := source.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot-1"}, "remote_sse.search", map[string]any{"query": "hello"}) + if err != nil { + t.Fatalf("call tool failed: %v", err) + } + if gateway.lastCallType != "sse" { + t.Fatalf("expected sse route, got: %s", gateway.lastCallType) + } + if ok, _ := result["ok"].(bool); !ok { + t.Fatalf("expected ok=true in result") + } +} diff --git a/internal/mcp/tool_gateway_service.go b/internal/mcp/tool_gateway_service.go new file mode 100644 index 00000000..f7a921d5 --- /dev/null +++ b/internal/mcp/tool_gateway_service.go @@ -0,0 +1,168 @@ +package mcp + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" +) + +const ( + defaultToolRegistryCacheTTL = 5 * time.Second +) + +type cachedToolRegistry struct { + expiresAt time.Time + registry *ToolRegistry +} + +// ToolGatewayService federates tools from executors and sources. +type ToolGatewayService struct { + logger *slog.Logger + executors []ToolExecutor + sources []ToolSource + cacheTTL time.Duration + + mu sync.Mutex + cache map[string]cachedToolRegistry +} + +func NewToolGatewayService(log *slog.Logger, executors []ToolExecutor, sources []ToolSource) *ToolGatewayService { + if log == nil { + log = slog.Default() + } + filteredExecutors := make([]ToolExecutor, 0, len(executors)) + for _, executor := range executors { + if executor != nil { + filteredExecutors = append(filteredExecutors, executor) + } + } + filteredSources := make([]ToolSource, 0, len(sources)) + for _, source := range sources { + if source != nil { + filteredSources = append(filteredSources, source) + } + } + return &ToolGatewayService{ + logger: log.With(slog.String("service", "tool_gateway")), + executors: filteredExecutors, + sources: filteredSources, + cacheTTL: defaultToolRegistryCacheTTL, + cache: map[string]cachedToolRegistry{}, + } +} + +func (s *ToolGatewayService) InitializeResult() map[string]any { + return map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{ + "tools": map[string]any{ + "listChanged": false, + }, + }, + "serverInfo": map[string]any{ + "name": "memoh-tools-gateway", + "version": "1.0.0", + }, + } +} + +func (s *ToolGatewayService) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { + registry, err := s.getRegistry(ctx, session, false) + if err != nil { + return nil, err + } + return registry.List(), nil +} + +func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionContext, payload ToolCallPayload) (map[string]any, error) { + toolName := strings.TrimSpace(payload.Name) + if toolName == "" { + return nil, fmt.Errorf("tool name is required") + } + + registry, err := s.getRegistry(ctx, session, false) + if err != nil { + return nil, err + } + executor, _, ok := registry.Lookup(toolName) + if !ok { + // Refresh once for dynamic executors/sources. + registry, err = s.getRegistry(ctx, session, true) + if err != nil { + return nil, err + } + executor, _, ok = registry.Lookup(toolName) + if !ok { + return BuildToolErrorResult("tool not found: " + toolName), nil + } + } + + arguments := payload.Arguments + if arguments == nil { + arguments = map[string]any{} + } + result, err := executor.CallTool(ctx, session, toolName, arguments) + if err != nil { + if err == ErrToolNotFound { + return BuildToolErrorResult("tool not found: " + toolName), nil + } + return BuildToolErrorResult(err.Error()), nil + } + if result == nil { + return BuildToolSuccessResult(map[string]any{"ok": true}), nil + } + return result, nil +} + +func (s *ToolGatewayService) getRegistry(ctx context.Context, session ToolSessionContext, force bool) (*ToolRegistry, error) { + botID := strings.TrimSpace(session.BotID) + if botID == "" { + return nil, fmt.Errorf("bot id is required") + } + if !force { + s.mu.Lock() + cached, ok := s.cache[botID] + if ok && time.Now().Before(cached.expiresAt) && cached.registry != nil { + s.mu.Unlock() + return cached.registry, nil + } + s.mu.Unlock() + } + + registry := NewToolRegistry() + for _, executor := range s.executors { + tools, err := executor.ListTools(ctx, session) + if err != nil { + s.logger.Warn("list tools from executor failed", slog.Any("error", err)) + continue + } + for _, tool := range tools { + if err := registry.Register(executor, tool); err != nil { + s.logger.Warn("skip duplicated/invalid tool", slog.String("tool", tool.Name), slog.Any("error", err)) + } + } + } + for _, source := range s.sources { + tools, err := source.ListTools(ctx, session) + if err != nil { + s.logger.Warn("list tools from source failed", slog.Any("error", err)) + continue + } + for _, tool := range tools { + if err := registry.Register(source, tool); err != nil { + s.logger.Warn("skip duplicated/invalid tool", slog.String("tool", tool.Name), slog.Any("error", err)) + } + } + } + + s.mu.Lock() + s.cache[botID] = cachedToolRegistry{ + expiresAt: time.Now().Add(s.cacheTTL), + registry: registry, + } + s.mu.Unlock() + return registry, nil +} diff --git a/internal/mcp/tool_gateway_service_test.go b/internal/mcp/tool_gateway_service_test.go new file mode 100644 index 00000000..3509f7ef --- /dev/null +++ b/internal/mcp/tool_gateway_service_test.go @@ -0,0 +1,126 @@ +package mcp + +import ( + "context" + "errors" + "log/slog" + "testing" +) + +type gatewayTestProvider struct { + tools []ToolDescriptor + callResult map[string]map[string]any + callErr map[string]error +} + +func (p *gatewayTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { + return p.tools, nil +} + +func (p *gatewayTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + if err, ok := p.callErr[toolName]; ok { + return nil, err + } + if result, ok := p.callResult[toolName]; ok { + return result, nil + } + return nil, ErrToolNotFound +} + +func TestToolGatewayServiceListTools(t *testing.T) { + providerA := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "tool_a", InputSchema: map[string]any{"type": "object"}}, + {Name: "dup_tool", InputSchema: map[string]any{"type": "object"}}, + }, + } + providerB := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "tool_b", InputSchema: map[string]any{"type": "object"}}, + {Name: "dup_tool", InputSchema: map[string]any{"type": "object"}}, + }, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{providerA, providerB}, nil) + + tools, err := service.ListTools(context.Background(), ToolSessionContext{BotID: "bot-1"}) + if err != nil { + t.Fatalf("list tools failed: %v", err) + } + if len(tools) != 3 { + t.Fatalf("expected 3 tools after dedupe, got %d", len(tools)) + } +} + +func TestToolGatewayServiceCallToolSuccess(t *testing.T) { + provider := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "echo_tool", InputSchema: map[string]any{"type": "object"}}, + }, + callResult: map[string]map[string]any{ + "echo_tool": { + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + }, + }, + callErr: map[string]error{}, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{provider}, nil) + + result, err := service.CallTool(context.Background(), ToolSessionContext{BotID: "bot-1"}, ToolCallPayload{ + Name: "echo_tool", + Arguments: map[string]any{"value": "hello"}, + }) + if err != nil { + t.Fatalf("call tool should not fail: %v", err) + } + if _, ok := result["content"]; !ok { + t.Fatalf("expected content in tool result") + } +} + +func TestToolGatewayServiceCallToolNotFound(t *testing.T) { + provider := &gatewayTestProvider{ + tools: []ToolDescriptor{}, + callResult: map[string]map[string]any{}, + callErr: map[string]error{}, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{provider}, nil) + + result, err := service.CallTool(context.Background(), ToolSessionContext{BotID: "bot-1"}, ToolCallPayload{ + Name: "missing_tool", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatalf("call should return mcp error result instead of failing: %v", err) + } + isErr, _ := result["isError"].(bool) + if !isErr { + t.Fatalf("expected isError=true for missing tool") + } +} + +func TestToolGatewayServiceCallToolProviderError(t *testing.T) { + provider := &gatewayTestProvider{ + tools: []ToolDescriptor{ + {Name: "broken_tool", InputSchema: map[string]any{"type": "object"}}, + }, + callResult: map[string]map[string]any{}, + callErr: map[string]error{ + "broken_tool": errors.New("boom"), + }, + } + service := NewToolGatewayService(slog.Default(), []ToolExecutor{provider}, nil) + + result, err := service.CallTool(context.Background(), ToolSessionContext{BotID: "bot-1"}, ToolCallPayload{ + Name: "broken_tool", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatalf("call should not return hard error: %v", err) + } + isErr, _ := result["isError"].(bool) + if !isErr { + t.Fatalf("expected isError=true for provider failure") + } +} diff --git a/internal/mcp/tool_registry.go b/internal/mcp/tool_registry.go new file mode 100644 index 00000000..edd7552c --- /dev/null +++ b/internal/mcp/tool_registry.go @@ -0,0 +1,72 @@ +package mcp + +import ( + "fmt" + "sort" + "strings" +) + +type registryItem struct { + executor ToolExecutor + tool ToolDescriptor +} + +// ToolRegistry stores provider ownership and descriptor metadata. +type ToolRegistry struct { + items map[string]registryItem +} + +func NewToolRegistry() *ToolRegistry { + return &ToolRegistry{ + items: map[string]registryItem{}, + } +} + +func (r *ToolRegistry) Register(executor ToolExecutor, tool ToolDescriptor) error { + if executor == nil { + return fmt.Errorf("tool executor is required") + } + name := strings.TrimSpace(tool.Name) + if name == "" { + return fmt.Errorf("tool name is required") + } + if tool.InputSchema == nil { + tool.InputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + if _, exists := r.items[name]; exists { + return fmt.Errorf("tool already registered: %s", name) + } + tool.Name = name + r.items[name] = registryItem{ + executor: executor, + tool: tool, + } + return nil +} + +func (r *ToolRegistry) Lookup(name string) (ToolExecutor, ToolDescriptor, bool) { + item, ok := r.items[strings.TrimSpace(name)] + if !ok { + return nil, ToolDescriptor{}, false + } + return item.executor, item.tool, true +} + +func (r *ToolRegistry) List() []ToolDescriptor { + if len(r.items) == 0 { + return []ToolDescriptor{} + } + names := make([]string, 0, len(r.items)) + for name := range r.items { + names = append(names, name) + } + sort.Strings(names) + tools := make([]ToolDescriptor, 0, len(names)) + for _, name := range names { + tools = append(tools, r.items[name].tool) + } + return tools +} diff --git a/internal/mcp/tool_registry_test.go b/internal/mcp/tool_registry_test.go new file mode 100644 index 00000000..f5001d9d --- /dev/null +++ b/internal/mcp/tool_registry_test.go @@ -0,0 +1,83 @@ +package mcp + +import ( + "context" + "testing" +) + +type registryTestProvider struct{} + +func (p *registryTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { + return nil, nil +} + +func (p *registryTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { + return nil, nil +} + +func TestToolRegistryRegisterAndLookup(t *testing.T) { + registry := NewToolRegistry() + provider := ®istryTestProvider{} + if err := registry.Register(provider, ToolDescriptor{ + Name: "tool_a", + Description: "test", + InputSchema: map[string]any{"type": "object"}, + }); err != nil { + t.Fatalf("register should succeed: %v", err) + } + + gotProvider, descriptor, ok := registry.Lookup("tool_a") + if !ok { + t.Fatalf("lookup should find registered tool") + } + if gotProvider != provider { + t.Fatalf("lookup provider mismatch") + } + if descriptor.Name != "tool_a" { + t.Fatalf("lookup descriptor mismatch, got: %s", descriptor.Name) + } +} + +func TestToolRegistryRegisterDuplicate(t *testing.T) { + registry := NewToolRegistry() + provider := ®istryTestProvider{} + first := ToolDescriptor{ + Name: "dup_tool", + Description: "first", + InputSchema: map[string]any{"type": "object"}, + } + second := ToolDescriptor{ + Name: "dup_tool", + Description: "second", + InputSchema: map[string]any{"type": "object"}, + } + if err := registry.Register(provider, first); err != nil { + t.Fatalf("first register should succeed: %v", err) + } + if err := registry.Register(provider, second); err == nil { + t.Fatalf("duplicate register should fail") + } +} + +func TestToolRegistryListStableOrder(t *testing.T) { + registry := NewToolRegistry() + provider := ®istryTestProvider{} + tools := []ToolDescriptor{ + {Name: "b_tool", InputSchema: map[string]any{"type": "object"}}, + {Name: "a_tool", InputSchema: map[string]any{"type": "object"}}, + {Name: "c_tool", InputSchema: map[string]any{"type": "object"}}, + } + for _, tool := range tools { + if err := registry.Register(provider, tool); err != nil { + t.Fatalf("register %s failed: %v", tool.Name, err) + } + } + + list := registry.List() + if len(list) != 3 { + t.Fatalf("expected 3 tools, got %d", len(list)) + } + if list[0].Name != "a_tool" || list[1].Name != "b_tool" || list[2].Name != "c_tool" { + t.Fatalf("unexpected order: %#v", list) + } +} diff --git a/internal/mcp/tool_types.go b/internal/mcp/tool_types.go new file mode 100644 index 00000000..9a556ec5 --- /dev/null +++ b/internal/mcp/tool_types.go @@ -0,0 +1,197 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strings" +) + +// ToolSessionContext carries request-scoped identity for tool execution. +type ToolSessionContext struct { + BotID string + ChatID string + ChannelIdentityID string + SessionToken string + CurrentPlatform string + ReplyTarget string +} + +// ToolDescriptor is the MCP tools/list item shape used by the gateway. +type ToolDescriptor struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]any `json:"inputSchema"` +} + +// ToolExecutor represents business-facing tools (message/schedule/memory). +type ToolExecutor interface { + ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) + CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) +} + +// ToolSource represents infrastructure-level tool sources (federation/connectors). +// A source is not a business tool itself; it supplies and routes downstream tools. +type ToolSource interface { + ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) + CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) +} + +// ToolCallPayload is the MCP tools/call params payload. +type ToolCallPayload struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +// ErrToolNotFound indicates the provider does not own the requested tool. +var ErrToolNotFound = fmt.Errorf("tool not found") + +// BuildToolSuccessResult builds a standard MCP tool success result object. +func BuildToolSuccessResult(structured any) map[string]any { + result := map[string]any{} + if structured != nil { + result["structuredContent"] = structured + if text := stringifyStructuredContent(structured); text != "" { + result["content"] = []map[string]any{ + { + "type": "text", + "text": text, + }, + } + } + } + if len(result) == 0 { + result["content"] = []map[string]any{ + { + "type": "text", + "text": "ok", + }, + } + } + return result +} + +// BuildToolErrorResult builds a standard MCP tool error result object. +func BuildToolErrorResult(message string) map[string]any { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "tool execution failed" + } + return map[string]any{ + "isError": true, + "content": []map[string]any{ + { + "type": "text", + "text": msg, + }, + }, + } +} + +func stringifyStructuredContent(v any) string { + if v == nil { + return "" + } + switch value := v.(type) { + case string: + return strings.TrimSpace(value) + default: + payload, err := json.Marshal(value) + if err != nil { + return "" + } + return string(payload) + } +} + +func StringArg(arguments map[string]any, key string) string { + if arguments == nil { + return "" + } + raw, ok := arguments[key] + if !ok { + return "" + } + switch value := raw.(type) { + case string: + return strings.TrimSpace(value) + default: + return strings.TrimSpace(fmt.Sprintf("%v", raw)) + } +} + +func FirstStringArg(arguments map[string]any, keys ...string) string { + for _, key := range keys { + if value := StringArg(arguments, key); value != "" { + return value + } + } + return "" +} + +func IntArg(arguments map[string]any, key string) (int, bool, error) { + if arguments == nil { + return 0, false, nil + } + raw, ok := arguments[key] + if !ok || raw == nil { + return 0, false, nil + } + switch value := raw.(type) { + case int: + return value, true, nil + case int8: + return int(value), true, nil + case int16: + return int(value), true, nil + case int32: + return int(value), true, nil + case int64: + return int(value), true, nil + case uint: + return int(value), true, nil + case uint8: + return int(value), true, nil + case uint16: + return int(value), true, nil + case uint32: + return int(value), true, nil + case uint64: + return int(value), true, nil + case float32: + f := float64(value) + if math.IsNaN(f) || math.IsInf(f, 0) { + return 0, true, fmt.Errorf("%s must be a valid number", key) + } + return int(f), true, nil + case float64: + if math.IsNaN(value) || math.IsInf(value, 0) { + return 0, true, fmt.Errorf("%s must be a valid number", key) + } + return int(value), true, nil + case json.Number: + i, err := value.Int64() + if err != nil { + return 0, true, fmt.Errorf("%s must be an integer", key) + } + return int(i), true, nil + default: + return 0, true, fmt.Errorf("%s must be a number", key) + } +} + +func BoolArg(arguments map[string]any, key string) (bool, bool, error) { + if arguments == nil { + return false, false, nil + } + raw, ok := arguments[key] + if !ok || raw == nil { + return false, false, nil + } + value, ok := raw.(bool) + if !ok { + return false, true, fmt.Errorf("%s must be a boolean", key) + } + return value, true, nil +} diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go deleted file mode 100644 index 7c440871..00000000 --- a/internal/mcp/tools.go +++ /dev/null @@ -1,421 +0,0 @@ -package mcp - -import ( - "bytes" - "context" - "fmt" - "io/fs" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - "unicode" - - sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type FSReadInput struct { - Path string `json:"path" jsonschema:"relative file path"` -} - -type FSReadOutput struct { - Content string `json:"content" jsonschema:"file content"` -} - -type FSWriteInput struct { - Path string `json:"path" jsonschema:"relative file path"` - Content string `json:"content" jsonschema:"file content"` -} - -type FSWriteOutput struct { - OK bool `json:"ok" jsonschema:"write result"` -} - -type FSListInput struct { - Path string `json:"path" jsonschema:"relative directory path"` - Recursive bool `json:"recursive" jsonschema:"recursive listing"` -} - -type FSFileEntry struct { - Path string `json:"path" jsonschema:"relative entry path"` - IsDir bool `json:"is_dir" jsonschema:"is directory"` - Size int64 `json:"size" jsonschema:"entry size"` - Mode uint32 `json:"mode" jsonschema:"file mode"` - ModTime time.Time `json:"mod_time" jsonschema:"modification time"` -} - -type FSListOutput struct { - Path string `json:"path" jsonschema:"listed path"` - Entries []FSFileEntry `json:"entries" jsonschema:"entries"` -} - -type FSEditInput struct { - Path string `json:"path" jsonschema:"relative file path"` - OldText string `json:"old_text" jsonschema:"exact text to find"` - NewText string `json:"new_text" jsonschema:"replacement text"` -} - -type FSEditOutput struct { - OK bool `json:"ok" jsonschema:"apply result"` -} - -type ExecInput struct { - Command string `json:"command" jsonschema:"command to run"` - Args []string `json:"args" jsonschema:"command arguments"` -} - -type ExecOutput struct { - OK bool `json:"ok" jsonschema:"execution success"` - ExitCode int `json:"exit_code" jsonschema:"process exit code"` - Stdout string `json:"stdout" jsonschema:"standard output"` - Stderr string `json:"stderr" jsonschema:"standard error"` -} - -func RegisterTools(server *sdkmcp.Server) { - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "read", Description: "read file content"}, fsReadTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "write", Description: "write file content"}, fsWriteTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "list", Description: "list directory entries"}, fsListTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "edit", Description: "replace exact text in a file"}, fsEditTool) - sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "exec", Description: "execute command"}, execTool) -} - -func fsReadTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSReadInput) ( - *sdkmcp.CallToolResult, - FSReadOutput, - error, -) { - root := dataRoot() - target, err := resolvePath(root, input.Path) - if err != nil { - return nil, FSReadOutput{}, err - } - data, err := os.ReadFile(target) - if err != nil { - return nil, FSReadOutput{}, err - } - return nil, FSReadOutput{Content: string(data)}, nil -} - -func fsWriteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSWriteInput) ( - *sdkmcp.CallToolResult, - FSWriteOutput, - error, -) { - root := dataRoot() - target, err := resolvePath(root, input.Path) - if err != nil { - return nil, FSWriteOutput{}, err - } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { - return nil, FSWriteOutput{}, err - } - if err := os.WriteFile(target, []byte(input.Content), 0o644); err != nil { - return nil, FSWriteOutput{}, err - } - return nil, FSWriteOutput{OK: true}, nil -} - -func fsListTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSListInput) ( - *sdkmcp.CallToolResult, - FSListOutput, - error, -) { - root := dataRoot() - target, err := resolvePathAllowRoot(root, input.Path) - if err != nil { - return nil, FSListOutput{}, err - } - info, err := os.Stat(target) - if err != nil { - return nil, FSListOutput{}, err - } - if !info.IsDir() { - return nil, FSListOutput{}, fmt.Errorf("path is not a directory") - } - - entries := []FSFileEntry{} - if input.Recursive { - err = filepath.WalkDir(target, func(p string, d fs.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - if p == target { - return nil - } - entryInfo, err := d.Info() - if err != nil { - return err - } - entry, err := entryForPath(root, p, entryInfo) - if err != nil { - return err - } - entries = append(entries, entry) - return nil - }) - } else { - dirEntries, err := os.ReadDir(target) - if err != nil { - return nil, FSListOutput{}, err - } - for _, entry := range dirEntries { - entryInfo, err := entry.Info() - if err != nil { - return nil, FSListOutput{}, err - } - fullPath := filepath.Join(target, entry.Name()) - fileEntry, err := entryForPath(root, fullPath, entryInfo) - if err != nil { - return nil, FSListOutput{}, err - } - entries = append(entries, fileEntry) - } - } - if err != nil { - return nil, FSListOutput{}, err - } - - listedPath := strings.TrimSpace(input.Path) - if listedPath == "" { - listedPath = "." - } - return nil, FSListOutput{Path: listedPath, Entries: entries}, nil -} - -func fsEditTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSEditInput) ( - *sdkmcp.CallToolResult, - FSEditOutput, - error, -) { - root := dataRoot() - target, err := resolvePath(root, input.Path) - if err != nil { - return nil, FSEditOutput{}, err - } - orig, err := os.ReadFile(target) - if err != nil { - return nil, FSEditOutput{}, err - } - raw := string(orig) - bom, content := stripBOM(raw) - originalEnding := detectLineEnding(content) - normalizedContent := normalizeToLF(content) - normalizedOld := normalizeToLF(input.OldText) - normalizedNew := normalizeToLF(input.NewText) - - match := fuzzyFindText(normalizedContent, normalizedOld) - if !match.Found { - return nil, FSEditOutput{}, fmt.Errorf( - "could not find the exact text in %s. the old text must match exactly including all whitespace and newlines", - input.Path, - ) - } - - fuzzyContent := normalizeForFuzzyMatch(normalizedContent) - fuzzyOld := normalizeForFuzzyMatch(normalizedOld) - occurrences := strings.Count(fuzzyContent, fuzzyOld) - if occurrences > 1 { - return nil, FSEditOutput{}, fmt.Errorf( - "found %d occurrences of the text in %s. the text must be unique. please provide more context to make it unique", - occurrences, - input.Path, - ) - } - - baseContent := match.ContentForReplacement - updated := baseContent[:match.Index] + normalizedNew + baseContent[match.Index+match.MatchLength:] - if baseContent == updated { - return nil, FSEditOutput{}, fmt.Errorf( - "no changes made to %s. the replacement produced identical content. this might indicate an issue with special characters or the text not existing as expected", - input.Path, - ) - } - - finalContent := bom + restoreLineEndings(updated, originalEnding) - info, err := os.Stat(target) - if err != nil { - return nil, FSEditOutput{}, err - } - if err := os.WriteFile(target, []byte(finalContent), info.Mode().Perm()); err != nil { - return nil, FSEditOutput{}, err - } - return nil, FSEditOutput{OK: true}, nil -} - -func execTool(ctx context.Context, req *sdkmcp.CallToolRequest, input ExecInput) ( - *sdkmcp.CallToolResult, - ExecOutput, - error, -) { - if strings.TrimSpace(input.Command) == "" { - return nil, ExecOutput{}, fmt.Errorf("command is required") - } - cmd := exec.CommandContext(ctx, input.Command, input.Args...) - var stdout bytes.Buffer - var stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - return nil, ExecOutput{ - OK: false, - ExitCode: exitErr.ExitCode(), - Stdout: stdout.String(), - Stderr: stderr.String(), - }, nil - } - return nil, ExecOutput{}, err - } - - return nil, ExecOutput{ - OK: true, - ExitCode: 0, - Stdout: stdout.String(), - Stderr: stderr.String(), - }, nil -} - -func dataRoot() string { - root := strings.TrimSpace(os.Getenv("MCP_DATA_DIR")) - if root == "" { - root = "/data" - } - return root -} - -func resolvePathAllowRoot(root, requestPath string) (string, error) { - if strings.TrimSpace(requestPath) == "" { - return root, nil - } - return resolvePath(root, requestPath) -} - -func resolvePath(root, requestPath string) (string, error) { - clean := filepath.Clean(requestPath) - if clean == "." || clean == "" { - return "", os.ErrInvalid - } - if filepath.IsAbs(clean) || strings.HasPrefix(clean, "..") { - return "", os.ErrInvalid - } - return filepath.Join(root, clean), nil -} - -func entryForPath(root, target string, info os.FileInfo) (FSFileEntry, error) { - rel, err := filepath.Rel(root, target) - if err != nil { - return FSFileEntry{}, err - } - if strings.HasPrefix(rel, "..") { - return FSFileEntry{}, os.ErrInvalid - } - if rel == "." { - rel = "" - } - return FSFileEntry{ - Path: filepath.ToSlash(rel), - IsDir: info.IsDir(), - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }, nil -} - -type FuzzyMatchResult struct { - Found bool - Index int - MatchLength int - UsedFuzzyMatch bool - ContentForReplacement string -} - -func detectLineEnding(content string) string { - crlfIdx := strings.Index(content, "\r\n") - lfIdx := strings.Index(content, "\n") - if lfIdx == -1 { - return "\n" - } - if crlfIdx == -1 { - return "\n" - } - if crlfIdx < lfIdx { - return "\r\n" - } - return "\n" -} - -func normalizeToLF(text string) string { - text = strings.ReplaceAll(text, "\r\n", "\n") - return strings.ReplaceAll(text, "\r", "\n") -} - -func restoreLineEndings(text, ending string) string { - if ending == "\r\n" { - return strings.ReplaceAll(text, "\n", "\r\n") - } - return text -} - -func stripBOM(content string) (string, string) { - if strings.HasPrefix(content, "\uFEFF") { - return "\uFEFF", content[1:] - } - return "", content -} - -func normalizeForFuzzyMatch(text string) string { - lines := strings.Split(text, "\n") - for i, line := range lines { - lines[i] = strings.TrimRightFunc(line, unicode.IsSpace) - } - trimmed := strings.Join(lines, "\n") - return strings.Map(func(r rune) rune { - switch r { - case '\u2018', '\u2019', '\u201A', '\u201B': - return '\'' - case '\u201C', '\u201D', '\u201E', '\u201F': - return '"' - case '\u2010', '\u2011', '\u2012', '\u2013', '\u2014', '\u2015', '\u2212': - return '-' - case '\u00A0', '\u2002', '\u2003', '\u2004', '\u2005', '\u2006', '\u2007', '\u2008', '\u2009', '\u200A', '\u202F', '\u205F', '\u3000': - return ' ' - default: - return r - } - }, trimmed) -} - -func fuzzyFindText(content, oldText string) FuzzyMatchResult { - exactIndex := strings.Index(content, oldText) - if exactIndex != -1 { - return FuzzyMatchResult{ - Found: true, - Index: exactIndex, - MatchLength: len(oldText), - UsedFuzzyMatch: false, - ContentForReplacement: content, - } - } - - fuzzyContent := normalizeForFuzzyMatch(content) - fuzzyOld := normalizeForFuzzyMatch(oldText) - fuzzyIndex := strings.Index(fuzzyContent, fuzzyOld) - if fuzzyIndex == -1 { - return FuzzyMatchResult{ - Found: false, - Index: -1, - MatchLength: 0, - UsedFuzzyMatch: false, - ContentForReplacement: content, - } - } - return FuzzyMatchResult{ - Found: true, - Index: fuzzyIndex, - MatchLength: len(fuzzyOld), - UsedFuzzyMatch: true, - ContentForReplacement: fuzzyContent, - } -} diff --git a/internal/mcp/versioning.go b/internal/mcp/versioning.go index 0ceb603d..f9de5043 100644 --- a/internal/mcp/versioning.go +++ b/internal/mcp/versioning.go @@ -4,18 +4,17 @@ import ( "context" "encoding/json" "fmt" - "strings" "time" "github.com/containerd/containerd/v2/pkg/oci" "github.com/containerd/errdefs" - "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" "github.com/opencontainers/runtime-spec/specs-go" "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" + "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" ) @@ -88,12 +87,6 @@ func (m *Manager) CreateVersion(ctx context.Context, userID string) (*VersionInf Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -224,12 +217,6 @@ func (m *Manager) RollbackVersion(ctx context.Context, userID string, version in Source: dataDir, Options: []string{"rbind", "rw"}, }, - { - Destination: "/app", - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -291,7 +278,7 @@ func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, runti if err != nil { return pgtype.UUID{}, err } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return pgtype.UUID{}, err } @@ -383,13 +370,3 @@ func (m *Manager) insertEvent(ctx context.Context, containerID, eventType string }) } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/memory/indexer_test.go b/internal/memory/indexer_test.go index 2fc6a19c..e6fdfa09 100644 --- a/internal/memory/indexer_test.go +++ b/internal/memory/indexer_test.go @@ -62,18 +62,16 @@ func TestBM25Indexer_TermFrequencies(t *testing.T) { func TestBM25Indexer_BM25Logic(t *testing.T) { indexer := NewBM25Indexer(nil) - // 1. 添加一个包含 "golang" 的文档 lang := "en" tf1 := map[string]int{"golang": 1, "programming": 1} len1 := 2 indices1, values1 := indexer.AddDocument(lang, tf1, len1) - // 2. 添加另一个包含 "golang" 但更长的文档 tf2 := map[string]int{"golang": 1, "tutorial": 1, "advanced": 1, "topics": 1} len2 := 4 indices2, values2 := indexer.AddDocument(lang, tf2, len2) - // 验证:在 BM25 中,相同词项在短文档中的权重应该比在长文档中高(惩罚长文档) + // In BM25, same term in a shorter doc should have higher weight than in a longer doc. var weight1, weight2 float32 for i, idx := range indices1 { if idx == termHash("golang") { @@ -90,11 +88,10 @@ func TestBM25Indexer_BM25Logic(t *testing.T) { t.Errorf("Expected weight in shorter doc (%f) to be higher than in longer doc (%f)", weight1, weight2) } - // 3. 添加一个不包含 "golang" 的文档,增加文档总数,验证 IDF 变化 - // IDF 应该随着包含该词的文档比例减少而增加 + // Add a doc without "golang" to increase doc count; IDF should increase. oldWeight1 := weight1 indexer.AddDocument(lang, map[string]int{"rust": 1}, 1) - indices3, values3 := indexer.AddDocument(lang, tf1, len1) // 再次生成相同文档的向量 + indices3, values3 := indexer.AddDocument(lang, tf1, len1) for i, idx := range indices3 { if idx == termHash("golang") { @@ -112,7 +109,6 @@ func TestBM25Indexer_RemoveDocument(t *testing.T) { lang := "en" term := "test" - // 添加文档 tf, docLen, _ := indexer.TermFrequencies(lang, term) indexer.AddDocument(lang, tf, docLen) @@ -123,7 +119,6 @@ func TestBM25Indexer_RemoveDocument(t *testing.T) { } indexer.mu.RUnlock() - // 删除文档 indexer.RemoveDocument(lang, tf, docLen) indexer.mu.RLock() @@ -134,7 +129,7 @@ func TestBM25Indexer_RemoveDocument(t *testing.T) { } func TestTermHash_CollisionResistance(t *testing.T) { - // 验证不同词项生成的哈希索引在 20bit 空间内是否分布合理(简单检查不冲突) + // Check that different terms get distinct hashes in 20-bit space (no collision in small sample). h1 := termHash("apple") h2 := termHash("orange") h3 := termHash("banana") @@ -143,7 +138,6 @@ func TestTermHash_CollisionResistance(t *testing.T) { t.Errorf("Detected unexpected hash collision in small sample: %d, %d, %d", h1, h2, h3) } - // 验证掩码是否生效 if h1 > sparseDimMask { t.Errorf("Hash %d exceeds mask %d", h1, sparseDimMask) } diff --git a/internal/memory/qdrant_store.go b/internal/memory/qdrant_store.go index ce5a826e..2aada92c 100644 --- a/internal/memory/qdrant_store.go +++ b/internal/memory/qdrant_store.go @@ -506,7 +506,7 @@ func (s *QdrantStore) ensurePayloadIndexes(ctx context.Context) error { if s.client == nil { return nil } - fields := []string{"botId", "sessionId", "runId"} + fields := []string{"botId", "runId"} wait := true for _, field := range fields { _, err := s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{ diff --git a/internal/memory/service.go b/internal/memory/service.go index 343ee1b3..3ff076e2 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -409,12 +409,12 @@ func (s *Service) Get(ctx context.Context, memoryID string) (MemoryItem, error) func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error) { filters := map[string]any{} + for k, v := range req.Filters { + filters[k] = v + } if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -445,12 +445,12 @@ func (s *Service) Delete(ctx context.Context, memoryID string) (DeleteResponse, func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error) { filters := map[string]any{} + for k, v := range req.Filters { + filters[k] = v + } if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -756,9 +756,6 @@ func buildFilters(req AddRequest) map[string]any { if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -773,9 +770,6 @@ func buildSearchFilters(req SearchRequest) map[string]any { if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -790,9 +784,6 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]any { if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -883,9 +874,6 @@ func payloadToMemoryItem(id string, payload map[string]any) MemoryItem { if v, ok := payload["botId"].(string); ok { item.BotID = v } - if v, ok := payload["sessionId"].(string); ok { - item.SessionID = v - } if v, ok := payload["runId"].(string); ok { item.RunID = v } diff --git a/internal/memory/service_test.go b/internal/memory/service_test.go index 92db7d63..bec57bfa 100644 --- a/internal/memory/service_test.go +++ b/internal/memory/service_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -// MockLLM 模拟 LLM 行为 +// MockLLM mocks LLM for tests. type MockLLM struct { ExtractFunc func(ctx context.Context, req ExtractRequest) (ExtractResponse, error) DecideFunc func(ctx context.Context, req DecideRequest) (DecideResponse, error) @@ -25,11 +25,9 @@ func (m *MockLLM) DetectLanguage(ctx context.Context, text string) (string, erro } func TestService_Add_FullFlow(t *testing.T) { - // 这是一个高质量的集成逻辑测试,验证 Service.Add 的完整决策流 ctx := context.Background() logger := slog.Default() - // 1. Mock LLM: 模拟从对话中提取事实,并决定添加新记忆 mockLLM := &MockLLM{ ExtractFunc: func(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { return ExtractResponse{Facts: []string{"User likes Go"}}, nil @@ -46,20 +44,7 @@ func TestService_Add_FullFlow(t *testing.T) { }, } - // 2. 初始化依赖 - // 注意:由于 QdrantStore 涉及网络,我们这里仅测试逻辑流。 - // 如果要跑通,需要一个 MockStore,但为了保持示例简洁且高质量, - // 我们重点展示如何组织 Service 的测试架构。 - - // 假设我们有一个内存版的 Store 或者 MockStore (此处略,实际项目中建议实现 MockStore) - // 这里演示逻辑链路的正确性 - t.Run("Decision Flow - ADD", func(t *testing.T) { - // 验证 Service 是否正确调用了 LLM 的 Extract 和 Decide - // 并且根据 Decide 的结果执行了相应的 Action - - // 提示:在实际代码中,Service.Add 会依次调用 Extract -> collectCandidates -> Decide -> applyAdd - // 我们可以通过在 Mock 中增加计数器来验证调用链路。 extractCalled := false decideCalled := false @@ -75,24 +60,17 @@ func TestService_Add_FullFlow(t *testing.T) { return DecideResponse{Actions: []DecisionAction{{Event: "ADD", Text: "Fact 1"}}}, nil } - // 由于 Service 结构体字段是私有的且依赖较多, - // 高质量的测试通常会配合接口或构造函数注入。 - // 这里我们验证核心逻辑:Decide 的 Action 映射 - s := &Service{ llm: mockLLM, logger: logger, bm25: NewBM25Indexer(nil), - // store: mockStore, // 实际测试中需要注入 MockStore } - // 模拟一个 Add 请求 req := AddRequest{ Message: "I love coding in Go", BotID: "bot-123", } - // 由于没有注入真实的 Store,这里会报错,但我们可以验证到报错前的逻辑 _, err := s.Add(ctx, req) if !extractCalled { @@ -102,26 +80,21 @@ func TestService_Add_FullFlow(t *testing.T) { t.Error("Expected LLM.Decide to be called") } - // 如果 err 是因为 store 为 nil 导致的,说明前面的 LLM 链路已经跑通 if err == nil || !reflectContains(err.Error(), "qdrant store") { - // 如果没报错或者报了别的错,说明逻辑有误 + // Expected either nil (if mock store added) or qdrant store error. } }) } func reflectContains(s, substr string) bool { - return fmt.Sprintf("%s", s) != "" // 简化逻辑 + return fmt.Sprintf("%s", s) != "" } func TestRankFusion_Logic(t *testing.T) { - // 测试 RRF (Reciprocal Rank Fusion) 逻辑 - // 验证不同来源的结果是否能被正确合并和排序 - p1 := qdrantPoint{ID: "1", Payload: map[string]any{"data": "result 1"}} p2 := qdrantPoint{ID: "2", Payload: map[string]any{"data": "result 2"}} - // 来源 A: 1 号排第一,2 号排第二 - // 来源 B: 2 号排第一,1 号排第二 + // Source A: 1 first, 2 second; Source B: 2 first, 1 second. pointsBySource := map[string][]qdrantPoint{ "source_a": {p1, p2}, "source_b": {p2, p1}, @@ -137,8 +110,7 @@ func TestRankFusion_Logic(t *testing.T) { t.Fatalf("Expected 2 results, got %d", len(results)) } - // 在这个对称的情况下,两者的 RRF 分数应该相同 if results[0].Score != results[1].Score { - // 理论上 1/(60+1) + 1/(60+2) + // Symmetric case: both get same RRF score (e.g. 1/(k+1)+1/(k+2) for k=60). } } diff --git a/internal/memory/types.go b/internal/memory/types.go index 22299457..606f102a 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -15,28 +15,26 @@ type Message struct { } type AddRequest struct { - Message string `json:"message,omitempty"` - Messages []Message `json:"messages,omitempty"` - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` + Message string `json:"message,omitempty"` + Messages []Message `json:"messages,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` Filters map[string]any `json:"filters,omitempty"` - Infer *bool `json:"infer,omitempty"` - EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` + Infer *bool `json:"infer,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type SearchRequest struct { - Query string `json:"query"` - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` + Query string `json:"query"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` Filters map[string]any `json:"filters,omitempty"` - Sources []string `json:"sources,omitempty"` - EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` + Sources []string `json:"sources,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type UpdateRequest struct { @@ -46,18 +44,18 @@ type UpdateRequest struct { } type GetAllRequest struct { - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type DeleteAllRequest struct { - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type EmbedInput struct { @@ -67,17 +65,16 @@ type EmbedInput struct { } type EmbedUpsertRequest struct { - Type string `json:"type"` - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - Input EmbedInput `json:"input"` - Source string `json:"source,omitempty"` - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - Filters map[string]any `json:"filters,omitempty"` + Type string `json:"type"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Input EmbedInput `json:"input"` + Source string `json:"source,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type EmbedUpsertResponse struct { @@ -88,17 +85,16 @@ type EmbedUpsertResponse struct { } type MemoryItem struct { - ID string `json:"id"` - Memory string `json:"memory"` - Hash string `json:"hash,omitempty"` - CreatedAt string `json:"createdAt,omitempty"` - UpdatedAt string `json:"updatedAt,omitempty"` - Score float64 `json:"score,omitempty"` + ID string `json:"id"` + Memory string `json:"memory"` + Hash string `json:"hash,omitempty"` + CreatedAt string `json:"createdAt,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` + Score float64 `json:"score,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` - BotID string `json:"botId,omitempty"` - SessionID string `json:"sessionId,omitempty"` - AgentID string `json:"agentId,omitempty"` - RunID string `json:"runId,omitempty"` + BotID string `json:"botId,omitempty"` + AgentID string `json:"agentId,omitempty"` + RunID string `json:"runId,omitempty"` } type SearchResponse struct { @@ -111,7 +107,7 @@ type DeleteResponse struct { } type ExtractRequest struct { - Messages []Message `json:"messages"` + Messages []Message `json:"messages"` Filters map[string]any `json:"filters,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } @@ -121,16 +117,16 @@ type ExtractResponse struct { } type CandidateMemory struct { - ID string `json:"id"` - Memory string `json:"memory"` + ID string `json:"id"` + Memory string `json:"memory"` Metadata map[string]any `json:"metadata,omitempty"` } type DecideRequest struct { - Facts []string `json:"facts"` - Candidates []CandidateMemory `json:"candidates"` - Filters map[string]any `json:"filters,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Facts []string `json:"facts"` + Candidates []CandidateMemory `json:"candidates"` + Filters map[string]any `json:"filters,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type DecisionAction struct { diff --git a/internal/message/event/hub.go b/internal/message/event/hub.go new file mode 100644 index 00000000..eec37cb0 --- /dev/null +++ b/internal/message/event/hub.go @@ -0,0 +1,124 @@ +package event + +import ( + "encoding/json" + "strings" + "sync" + + "github.com/google/uuid" +) + +const ( + // DefaultBufferSize is the default per-subscriber channel buffer. + DefaultBufferSize = 64 +) + +// EventType identifies the event category published by the message event hub. +type EventType string + +const ( + // EventTypeMessageCreated is emitted after a message is persisted successfully. + EventTypeMessageCreated EventType = "message_created" +) + +// Event is the normalized payload emitted by the in-process message event hub. +type Event struct { + Type EventType `json:"type"` + BotID string `json:"bot_id"` + Data json.RawMessage `json:"data,omitempty"` +} + +// Publisher publishes events to subscribers. +type Publisher interface { + Publish(event Event) +} + +// Subscriber subscribes to bot-scoped events. +type Subscriber interface { + Subscribe(botID string, buffer int) (string, <-chan Event, func()) +} + +// Hub is an in-process pub/sub dispatcher for bot-scoped message events. +type Hub struct { + mu sync.RWMutex + streams map[string]map[string]chan Event +} + +// NewHub creates an empty message event hub. +func NewHub() *Hub { + return &Hub{ + streams: map[string]map[string]chan Event{}, + } +} + +// Publish broadcasts one event to all subscribers under the same bot ID. +// Slow subscribers are dropped in a non-blocking way. +func (h *Hub) Publish(event Event) { + if h == nil { + return + } + botID := strings.TrimSpace(event.BotID) + if botID == "" { + return + } + h.mu.RLock() + defer h.mu.RUnlock() + for _, ch := range h.streams[botID] { + select { + case ch <- event: + default: + // Drop if receiver is slow to avoid blocking persistence path. + } + } +} + +// Subscribe registers one subscriber under a bot ID. +// It returns a stream ID, read-only event channel, and a cancel function. +func (h *Hub) Subscribe(botID string, buffer int) (string, <-chan Event, func()) { + if h == nil { + ch := make(chan Event) + close(ch) + return "", ch, func() {} + } + botID = strings.TrimSpace(botID) + if botID == "" { + ch := make(chan Event) + close(ch) + return "", ch, func() {} + } + if buffer <= 0 { + buffer = DefaultBufferSize + } + + streamID := uuid.NewString() + ch := make(chan Event, buffer) + + h.mu.Lock() + streams, ok := h.streams[botID] + if !ok { + streams = map[string]chan Event{} + h.streams[botID] = streams + } + streams[streamID] = ch + h.mu.Unlock() + + var once sync.Once + cancel := func() { + once.Do(func() { + h.mu.Lock() + streams := h.streams[botID] + if streams != nil { + if current, ok := streams[streamID]; ok { + delete(streams, streamID) + close(current) + } + if len(streams) == 0 { + delete(h.streams, botID) + } + } + h.mu.Unlock() + }) + } + + return streamID, ch, cancel +} diff --git a/internal/message/event/hub_test.go b/internal/message/event/hub_test.go new file mode 100644 index 00000000..987c3861 --- /dev/null +++ b/internal/message/event/hub_test.go @@ -0,0 +1,59 @@ +package event + +import ( + "testing" + "time" +) + +func TestHubPublishScopedByBotID(t *testing.T) { + hub := NewHub() + _, botAStream, cancelA := hub.Subscribe("bot-a", 8) + defer cancelA() + _, botBStream, cancelB := hub.Subscribe("bot-b", 8) + defer cancelB() + + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + + select { + case <-botAStream: + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected event for bot-a subscriber") + } + + select { + case <-botBStream: + t.Fatalf("did not expect bot-b subscriber to receive bot-a event") + case <-time.After(120 * time.Millisecond): + } +} + +func TestHubCancelUnsubscribe(t *testing.T) { + hub := NewHub() + _, stream, cancel := hub.Subscribe("bot-a", 8) + cancel() + + select { + case _, ok := <-stream: + if ok { + t.Fatalf("expected stream to be closed after cancel") + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("timed out waiting for stream close") + } +} + +func TestHubSlowSubscriberDoesNotBlockPublish(t *testing.T) { + hub := NewHub() + _, stream, cancel := hub.Subscribe("bot-a", 1) + defer cancel() + + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + + select { + case <-stream: + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected at least one event in buffer") + } +} diff --git a/internal/message/service.go b/internal/message/service.go new file mode 100644 index 00000000..60051230 --- /dev/null +++ b/internal/message/service.go @@ -0,0 +1,358 @@ +package message + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgtype" + + dbpkg "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/message/event" +) + +// DBService persists and reads bot history messages. +type DBService struct { + queries *sqlc.Queries + logger *slog.Logger + publisher event.Publisher +} + +// NewService creates a message service. +func NewService(log *slog.Logger, queries *sqlc.Queries, publishers ...event.Publisher) *DBService { + if log == nil { + log = slog.Default() + } + var publisher event.Publisher + if len(publishers) > 0 { + publisher = publishers[0] + } + return &DBService{ + queries: queries, + logger: log.With(slog.String("service", "message")), + publisher: publisher, + } +} + +// Persist writes a single message to bot_history_messages. +func (s *DBService) Persist(ctx context.Context, input PersistInput) (Message, error) { + pgBotID, err := dbpkg.ParseUUID(input.BotID) + if err != nil { + return Message{}, fmt.Errorf("invalid bot id: %w", err) + } + + pgRouteID, err := parseOptionalUUID(input.RouteID) + if err != nil { + return Message{}, fmt.Errorf("invalid route id: %w", err) + } + pgSenderChannelIdentityID, err := parseOptionalUUID(input.SenderChannelIdentityID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender channel identity id: %w", err) + } + pgSenderUserID, err := parseOptionalUUID(input.SenderUserID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender user id: %w", err) + } + + metaBytes, err := json.Marshal(nonNilMap(input.Metadata)) + if err != nil { + return Message{}, fmt.Errorf("marshal message metadata: %w", err) + } + + content := input.Content + if len(content) == 0 { + content = []byte("{}") + } + + row, err := s.queries.CreateMessage(ctx, sqlc.CreateMessageParams{ + BotID: pgBotID, + RouteID: pgRouteID, + SenderChannelIdentityID: pgSenderChannelIdentityID, + SenderUserID: pgSenderUserID, + Platform: toPgText(input.Platform), + ExternalMessageID: toPgText(input.ExternalMessageID), + SourceReplyToMessageID: toPgText(input.SourceReplyToMessageID), + Role: input.Role, + Content: content, + Metadata: metaBytes, + }) + if err != nil { + return Message{}, err + } + + result := toMessageFromCreate(row) + s.publishMessageCreated(result) + return result, nil +} + +// List returns all messages for a bot. +func (s *DBService) List(ctx context.Context, botID string) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessages(ctx, pgBotID) + if err != nil { + return nil, err + } + return toMessagesFromList(rows), nil +} + +// ListSince returns bot messages since a given time. +func (s *DBService) ListSince(ctx context.Context, botID string, since time.Time) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesSince(ctx, sqlc.ListMessagesSinceParams{ + BotID: pgBotID, + CreatedAt: pgtype.Timestamptz{Time: since, Valid: true}, + }) + if err != nil { + return nil, err + } + return toMessagesFromSince(rows), nil +} + +// ListLatest returns the latest N bot messages (newest first in DB; caller may reverse for ASC). +func (s *DBService) ListLatest(ctx context.Context, botID string, limit int32) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesLatest(ctx, sqlc.ListMessagesLatestParams{ + BotID: pgBotID, + MaxCount: limit, + }) + if err != nil { + return nil, err + } + return toMessagesFromLatest(rows), nil +} + +// ListBefore returns up to limit messages older than before (created_at < before), ordered oldest-first. +func (s *DBService) ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]Message, error) { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListMessagesBefore(ctx, sqlc.ListMessagesBeforeParams{ + BotID: pgBotID, + CreatedAt: pgtype.Timestamptz{Time: before, Valid: true}, + MaxCount: limit, + }) + if err != nil { + return nil, err + } + return toMessagesFromBefore(rows), nil +} + +// DeleteByBot deletes all messages for a bot. +func (s *DBService) DeleteByBot(ctx context.Context, botID string) error { + pgBotID, err := dbpkg.ParseUUID(botID) + if err != nil { + return err + } + return s.queries.DeleteMessagesByBot(ctx, pgBotID) +} + +func toMessageFromCreate(row sqlc.CreateMessageRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromListRow(row sqlc.ListMessagesRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromSinceRow(row sqlc.ListMessagesSinceRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFromLatestRow(row sqlc.ListMessagesLatestRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +func toMessageFields( + id pgtype.UUID, + botID pgtype.UUID, + routeID pgtype.UUID, + senderChannelIdentityID pgtype.UUID, + senderUserID pgtype.UUID, + platform pgtype.Text, + externalMessageID pgtype.Text, + sourceReplyToMessageID pgtype.Text, + role string, + content []byte, + metadata []byte, + createdAt pgtype.Timestamptz, +) Message { + return Message{ + ID: id.String(), + BotID: botID.String(), + RouteID: routeID.String(), + SenderChannelIdentityID: senderChannelIdentityID.String(), + SenderUserID: senderUserID.String(), + Platform: dbpkg.TextToString(platform), + ExternalMessageID: dbpkg.TextToString(externalMessageID), + SourceReplyToMessageID: dbpkg.TextToString(sourceReplyToMessageID), + Role: role, + Content: json.RawMessage(content), + Metadata: parseJSONMap(metadata), + CreatedAt: createdAt.Time, + } +} + +func toMessagesFromList(rows []sqlc.ListMessagesRow) []Message { + messages := make([]Message, 0, len(rows)) + for _, row := range rows { + messages = append(messages, toMessageFromListRow(row)) + } + return messages +} + +func toMessagesFromSince(rows []sqlc.ListMessagesSinceRow) []Message { + messages := make([]Message, 0, len(rows)) + for _, row := range rows { + messages = append(messages, toMessageFromSinceRow(row)) + } + return messages +} + +func toMessagesFromLatest(rows []sqlc.ListMessagesLatestRow) []Message { + messages := make([]Message, 0, len(rows)) + for _, row := range rows { + messages = append(messages, toMessageFromLatestRow(row)) + } + return messages +} + +func toMessageFromBeforeRow(row sqlc.ListMessagesBeforeRow) Message { + return toMessageFields( + row.ID, + row.BotID, + row.RouteID, + row.SenderChannelIdentityID, + row.SenderUserID, + row.Platform, + row.ExternalMessageID, + row.SourceReplyToMessageID, + row.Role, + row.Content, + row.Metadata, + row.CreatedAt, + ) +} + +// toMessagesFromBefore returns messages in oldest-first order (ListMessagesBefore returns DESC; we reverse). +func toMessagesFromBefore(rows []sqlc.ListMessagesBeforeRow) []Message { + messages := make([]Message, 0, len(rows)) + for i := len(rows) - 1; i >= 0; i-- { + messages = append(messages, toMessageFromBeforeRow(rows[i])) + } + return messages +} + +func parseOptionalUUID(id string) (pgtype.UUID, error) { + if strings.TrimSpace(id) == "" { + return pgtype.UUID{}, nil + } + return dbpkg.ParseUUID(id) +} + +func toPgText(value string) pgtype.Text { + value = strings.TrimSpace(value) + if value == "" { + return pgtype.Text{} + } + return pgtype.Text{String: value, Valid: true} +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} + +func (s *DBService) publishMessageCreated(message Message) { + if s.publisher == nil { + return + } + payload, err := json.Marshal(message) + if err != nil { + if s.logger != nil { + s.logger.Warn("marshal message event failed", slog.Any("error", err)) + } + return + } + s.publisher.Publish(event.Event{ + Type: event.EventTypeMessageCreated, + BotID: strings.TrimSpace(message.BotID), + Data: payload, + }) +} diff --git a/internal/message/types.go b/internal/message/types.go new file mode 100644 index 00000000..f5938474 --- /dev/null +++ b/internal/message/types.go @@ -0,0 +1,52 @@ +package message + +import ( + "context" + "encoding/json" + "time" +) + +// Message represents a single persisted bot message. +type Message struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + RouteID string `json:"route_id,omitempty"` + SenderChannelIdentityID string `json:"sender_channel_identity_id,omitempty"` + SenderUserID string `json:"sender_user_id,omitempty"` + Platform string `json:"platform,omitempty"` + ExternalMessageID string `json:"external_message_id,omitempty"` + SourceReplyToMessageID string `json:"source_reply_to_message_id,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// PersistInput is the input for persisting a message. +type PersistInput struct { + BotID string + RouteID string + SenderChannelIdentityID string + SenderUserID string + Platform string + ExternalMessageID string + SourceReplyToMessageID string + Role string + Content json.RawMessage + Metadata map[string]any +} + +// Writer defines write behavior needed by the inbound router. +type Writer interface { + Persist(ctx context.Context, input PersistInput) (Message, error) +} + +// Service defines message read/write behavior. +type Service interface { + Writer + List(ctx context.Context, botID string) ([]Message, error) + ListSince(ctx context.Context, botID string, since time.Time) ([]Message, error) + ListLatest(ctx context.Context, botID string, limit int32) ([]Message, error) + ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]Message, error) + DeleteByBot(ctx context.Context, botID string) error +} diff --git a/internal/models/models.go b/internal/models/models.go index 45993883..e648c8f0 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -33,7 +34,7 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro } // Convert to sqlc params - llmProviderID, err := parseUUID(model.LlmProviderID) + llmProviderID, err := db.ParseUUID(model.LlmProviderID) if err != nil { return AddResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) } @@ -78,7 +79,7 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro // GetByID retrieves a model by its internal UUID func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) { - uuid, err := parseUUID(id) + uuid, err := db.ParseUUID(id) if err != nil { return GetResponse{}, fmt.Errorf("invalid ID: %w", err) } @@ -148,7 +149,7 @@ func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]Ge if strings.TrimSpace(providerID) == "" { return nil, fmt.Errorf("provider id is required") } - uuid, err := parseUUID(providerID) + uuid, err := db.ParseUUID(providerID) if err != nil { return nil, fmt.Errorf("invalid provider id: %w", err) } @@ -167,7 +168,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string if strings.TrimSpace(providerID) == "" { return nil, fmt.Errorf("provider id is required") } - uuid, err := parseUUID(providerID) + uuid, err := db.ParseUUID(providerID) if err != nil { return nil, fmt.Errorf("invalid provider id: %w", err) } @@ -183,7 +184,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string // UpdateByID updates a model by its internal UUID func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { - uuid, err := parseUUID(id) + uuid, err := db.ParseUUID(id) if err != nil { return GetResponse{}, fmt.Errorf("invalid ID: %w", err) } @@ -199,7 +200,7 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) Type: string(model.Type), } - llmProviderID, err := parseUUID(model.LlmProviderID) + llmProviderID, err := db.ParseUUID(model.LlmProviderID) if err != nil { return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) } @@ -238,7 +239,7 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat Type: string(model.Type), } - llmProviderID, err := parseUUID(model.LlmProviderID) + llmProviderID, err := db.ParseUUID(model.LlmProviderID) if err != nil { return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) } @@ -262,7 +263,7 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat // DeleteByID deletes a model by its internal UUID func (s *Service) DeleteByID(ctx context.Context, id string) error { - uuid, err := parseUUID(id) + uuid, err := db.ParseUUID(id) if err != nil { return fmt.Errorf("invalid ID: %w", err) } @@ -311,19 +312,6 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, // Helper functions -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID format: %w", err) - } - - var pgUUID pgtype.UUID - copy(pgUUID.Bytes[:], parsed[:]) - pgUUID.Valid = true - - return pgUUID, nil -} - func convertToGetResponse(dbModel sqlc.Model) GetResponse { resp := GetResponse{ ModelId: dbModel.ModelID, @@ -335,8 +323,8 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { }, } - if llmProviderID, ok := uuidStringFromPgUUID(dbModel.LlmProviderID); ok { - resp.Model.LlmProviderID = llmProviderID + if dbModel.LlmProviderID.Valid { + resp.Model.LlmProviderID = dbModel.LlmProviderID.String() } if dbModel.Name.Valid { @@ -382,17 +370,6 @@ func isValidClientType(clientType ClientType) bool { } } -func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) { - if !value.Valid { - return "", false - } - id, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "", false - } - return id.String(), true -} - // SelectMemoryModel selects a chat model for memory operations. func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) { if modelsService == nil { @@ -415,7 +392,7 @@ func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID st if strings.TrimSpace(providerID) == "" { return sqlc.LlmProvider{}, fmt.Errorf("provider id missing") } - parsed, err := parseUUID(providerID) + parsed, err := db.ParseUUID(providerID) if err != nil { return sqlc.LlmProvider{}, err } diff --git a/internal/models/models_test.go b/internal/models/models_test.go index 42573164..35fe4b31 100644 --- a/internal/models/models_test.go +++ b/internal/models/models_test.go @@ -13,7 +13,7 @@ import ( func ExampleService_Create() { // Example usage - in real code, you would initialize with actual database connection // service := models.NewService(queries) - + // ctx := context.Background() // req := models.AddRequest{ // ModelID: "gpt-4", @@ -21,7 +21,7 @@ func ExampleService_Create() { // LlmProviderID: "11111111-1111-1111-1111-111111111111", // Type: models.ModelTypeChat, // } - + // resp, err := service.Create(ctx, req) // if err != nil { // // handle error @@ -32,7 +32,7 @@ func ExampleService_Create() { func ExampleService_GetByModelID() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // resp, err := service.GetByModelID(ctx, "gpt-4") // if err != nil { @@ -44,7 +44,7 @@ func ExampleService_GetByModelID() { func ExampleService_List() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // models, err := service.List(ctx) // if err != nil { @@ -58,7 +58,7 @@ func ExampleService_List() { func ExampleService_ListByType() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // chatModels, err := service.ListByType(ctx, models.ModelTypeChat) // if err != nil { @@ -70,7 +70,7 @@ func ExampleService_ListByType() { func ExampleService_UpdateByModelID() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // req := models.UpdateRequest{ // ModelID: "gpt-4", @@ -78,7 +78,7 @@ func ExampleService_UpdateByModelID() { // LlmProviderID: "11111111-1111-1111-1111-111111111111", // Type: models.ModelTypeChat, // } - + // resp, err := service.UpdateByModelID(ctx, "gpt-4", req) // if err != nil { // // handle error @@ -89,7 +89,7 @@ func ExampleService_UpdateByModelID() { func ExampleService_DeleteByModelID() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // err := service.DeleteByModelID(ctx, "gpt-4") // if err != nil { @@ -208,7 +208,7 @@ func TestModelTypes(t *testing.T) { // } // // ctx := context.Background() -// +// // // Setup database connection // pool, err := db.Open(ctx, config.PostgresConfig{ // Host: "localhost", @@ -271,4 +271,3 @@ func TestModelTypes(t *testing.T) { // err = service.DeleteByModelID(ctx, "test-gpt-4") // require.NoError(t, err) // } - diff --git a/internal/models/types.go b/internal/models/types.go index c8252bb4..c0ef4df2 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -32,13 +32,13 @@ const ( ) type Model struct { - ModelID string `json:"model_id" validate:"required"` - Name string `json:"name" validate:"required"` - LlmProviderID string `json:"llm_provider_id" validate:"required"` - IsMultimodal bool `json:"is_multimodal"` - Input []string `json:"input"` - Type ModelType `json:"type" validate:"required"` - Dimensions int `json:"dimensions"` + ModelID string `json:"model_id"` + Name string `json:"name"` + LlmProviderID string `json:"llm_provider_id"` + IsMultimodal bool `json:"is_multimodal"` + Input []string `json:"input"` + Type ModelType `json:"type"` + Dimensions int `json:"dimensions"` } func (m *Model) Validate() error { diff --git a/internal/policy/service.go b/internal/policy/service.go index 518e21bb..2c476d1e 100644 --- a/internal/policy/service.go +++ b/internal/policy/service.go @@ -33,6 +33,7 @@ func NewService(log *slog.Logger, botsService *bots.Service, settingsService *se } } +// Resolve evaluates the full access policy for a bot. func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) { if s == nil || s.bots == nil || s.settings == nil { return Decision{}, fmt.Errorf("policy service not configured") @@ -59,3 +60,33 @@ func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) { } return decision, nil } + +// AllowGuest checks if the bot allows guest access. Implements router.PolicyService. +func (s *Service) AllowGuest(ctx context.Context, botID string) (bool, error) { + decision, err := s.Resolve(ctx, botID) + if err != nil { + return false, err + } + return decision.AllowGuest, nil +} + +// BotType returns the normalized bot type. Implements router.PolicyService. +func (s *Service) BotType(ctx context.Context, botID string) (string, error) { + decision, err := s.Resolve(ctx, botID) + if err != nil { + return "", err + } + return decision.BotType, nil +} + +// BotOwnerUserID returns bot owner's user id. Implements router.PolicyService. +func (s *Service) BotOwnerUserID(ctx context.Context, botID string) (string, error) { + if s == nil || s.bots == nil { + return "", fmt.Errorf("policy service not configured") + } + bot, err := s.bots.Get(ctx, strings.TrimSpace(botID)) + if err != nil { + return "", err + } + return strings.TrimSpace(bot.OwnerUserID), nil +} diff --git a/internal/preauth/service.go b/internal/preauth/service.go index 7aa8f5b9..fe017eee 100644 --- a/internal/preauth/service.go +++ b/internal/preauth/service.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -24,6 +25,7 @@ func NewService(queries *sqlc.Queries) *Service { return &Service{queries: queries} } +// Issue creates a new preauth key for the given bot. func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl time.Duration) (Key, error) { if s.queries == nil { return Key{}, fmt.Errorf("preauth queries not configured") @@ -31,13 +33,13 @@ func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl t if ttl <= 0 { ttl = 24 * time.Hour } - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return Key{}, err } pgIssuedBy := pgtype.UUID{Valid: false} if strings.TrimSpace(issuedByUserID) != "" { - parsed, err := parseUUID(issuedByUserID) + parsed, err := db.ParseUUID(issuedByUserID) if err != nil { return Key{}, err } @@ -75,7 +77,7 @@ func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { if s.queries == nil { return Key{}, fmt.Errorf("preauth queries not configured") } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Key{}, err } @@ -88,38 +90,16 @@ func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { func normalizeKey(row sqlc.BotPreauthKey) Key { return Key{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), - Token: strings.TrimSpace(row.Token), - IssuedByUserID: toUUIDString(row.IssuedByUserID), - ExpiresAt: timeFromPg(row.ExpiresAt), - UsedAt: timeFromPg(row.UsedAt), - CreatedAt: timeFromPg(row.CreatedAt), + ID: row.ID.String(), + BotID: row.BotID.String(), + Token: strings.TrimSpace(row.Token), + IssuedByChannelIdentityID: row.IssuedByUserID.String(), + ExpiresAt: timeFromPg(row.ExpiresAt), + UsedAt: timeFromPg(row.UsedAt), + CreatedAt: timeFromPg(row.CreatedAt), } } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} - func timeFromPg(value pgtype.Timestamptz) time.Time { if value.Valid { return value.Time diff --git a/internal/preauth/types.go b/internal/preauth/types.go index cd26b086..2527f38d 100644 --- a/internal/preauth/types.go +++ b/internal/preauth/types.go @@ -2,12 +2,13 @@ package preauth import "time" +// Key represents a bot pre-authorization key. type Key struct { - ID string - BotID string - Token string - IssuedByUserID string - ExpiresAt time.Time - UsedAt time.Time - CreatedAt time.Time + ID string + BotID string + Token string + IssuedByChannelIdentityID string + ExpiresAt time.Time + UsedAt time.Time + CreatedAt time.Time } diff --git a/internal/providers/service.go b/internal/providers/service.go index d2118f5f..dba23335 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -7,9 +7,7 @@ import ( "log/slog" "strings" - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgtype" - + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -57,7 +55,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e // Get retrieves a provider by ID func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { - providerID, err := parseUUID(id) + providerID, err := db.ParseUUID(id) if err != nil { return GetResponse{}, err } @@ -114,7 +112,7 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ( // Update updates an existing provider func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { - providerID, err := parseUUID(id) + providerID, err := db.ParseUUID(id) if err != nil { return GetResponse{}, err } @@ -176,7 +174,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get // Delete deletes a provider by ID func (s *Service) Delete(ctx context.Context, id string) error { - providerID, err := parseUUID(id) + providerID, err := db.ParseUUID(id) if err != nil { return err } @@ -219,13 +217,8 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { // Mask API key (show only first 8 characters) maskedAPIKey := maskAPIKey(provider.ApiKey) - // Convert pgtype.UUID to string - var id [16]byte - copy(id[:], provider.ID.Bytes[:]) - idUUID := uuid.UUID(id) - return GetResponse{ - ID: idUUID.String(), + ID: provider.ID.String(), Name: provider.Name, ClientType: provider.ClientType, BaseURL: provider.BaseUrl, @@ -236,18 +229,6 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { } } -// parseUUID parses a UUID string -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - // isValidClientType checks if a client type is valid func isValidClientType(clientType ClientType) bool { switch clientType { diff --git a/internal/providers/types.go b/internal/providers/types.go index e2556f64..4eec664e 100644 --- a/internal/providers/types.go +++ b/internal/providers/types.go @@ -33,14 +33,14 @@ type UpdateRequest struct { // GetResponse represents the response for getting a provider type GetResponse struct { - ID string `json:"id" validate:"required"` - Name string `json:"name" validate:"required"` - ClientType string `json:"client_type" validate:"required"` - BaseURL string `json:"base_url" validate:"required"` + ID string `json:"id"` + Name string `json:"name"` + ClientType string `json:"client_type"` + BaseURL string `json:"base_url"` APIKey string `json:"api_key,omitempty"` // masked in response Metadata map[string]any `json:"metadata,omitempty"` - CreatedAt time.Time `json:"created_at" validate:"required"` - UpdatedAt time.Time `json:"updated_at" validate:"required"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // ListResponse represents the response for listing providers diff --git a/internal/router/channel.go b/internal/router/channel.go index 3fadd484..cefc25d3 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -12,61 +12,74 @@ import ( "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/chat" - "github.com/memohai/memoh/internal/contacts" + "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/conversation" + "github.com/memohai/memoh/internal/conversation/flow" + messagepkg "github.com/memohai/memoh/internal/message" ) -// ChatGateway 抽象聊天能力,避免路由层直接依赖具体实现。 -type ChatGateway interface { - Chat(ctx context.Context, req chat.ChatRequest) (chat.ChatResponse, error) -} - -type ContactService interface { - GetByID(ctx context.Context, contactID string) (contacts.Contact, error) - GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error) - GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) - Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) - CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) - UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) -} - const ( - silentReplyToken = "NO_REPLY" - minDuplicateTextLength = 10 + silentReplyToken = "NO_REPLY" + minDuplicateTextLength = 10 + processingStatusTimeout = 60 * time.Second ) var ( whitespacePattern = regexp.MustCompile(`\s+`) ) -// ChannelInboundProcessor 将 channel 入站消息路由到 chat,并返回可发送的回复。 -type ChannelInboundProcessor struct { - chat ChatGateway - registry *channel.Registry - logger *slog.Logger - jwtSecret string - tokenTTL time.Duration - identity *IdentityResolver +// RouteResolver resolves and manages channel routes. +type RouteResolver interface { + ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error) } -func NewChannelInboundProcessor(log *slog.Logger, registry *channel.Registry, store channel.ConfigStore, chatGateway ChatGateway, contactService ContactService, policyService PolicyService, preauthService PreauthService, jwtSecret string, tokenTTL time.Duration) *ChannelInboundProcessor { +// ChannelInboundProcessor routes channel inbound messages to the chat gateway. +type ChannelInboundProcessor struct { + runner flow.Runner + routeResolver RouteResolver + message messagepkg.Writer + registry *channel.Registry + logger *slog.Logger + jwtSecret string + tokenTTL time.Duration + identity *IdentityResolver +} + +// NewChannelInboundProcessor creates a processor with channel identity-based resolution. +func NewChannelInboundProcessor( + log *slog.Logger, + registry *channel.Registry, + routeResolver RouteResolver, + messageWriter messagepkg.Writer, + runner flow.Runner, + channelIdentityService ChannelIdentityService, + memberService BotMemberService, + policyService PolicyService, + preauthService PreauthService, + bindService BindService, + jwtSecret string, + tokenTTL time.Duration, +) *ChannelInboundProcessor { if log == nil { log = slog.Default() } if tokenTTL <= 0 { tokenTTL = 5 * time.Minute } - identityResolver := NewIdentityResolver(log, registry, store, contactService, policyService, preauthService, "", "") + identityResolver := NewIdentityResolver(log, registry, channelIdentityService, memberService, policyService, preauthService, bindService, "", "") return &ChannelInboundProcessor{ - chat: chatGateway, - registry: registry, - logger: log.With(slog.String("component", "channel_router")), - jwtSecret: strings.TrimSpace(jwtSecret), - tokenTTL: tokenTTL, - identity: identityResolver, + runner: runner, + routeResolver: routeResolver, + message: messageWriter, + registry: registry, + logger: log.With(slog.String("component", "channel_router")), + jwtSecret: strings.TrimSpace(jwtSecret), + tokenTTL: tokenTTL, + identity: identityResolver, } } +// IdentityMiddleware returns the identity resolution middleware. func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware { if p == nil || p.identity == nil { return nil @@ -74,8 +87,9 @@ func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware { return p.identity.Middleware() } -func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.ReplySender) error { - if p.chat == nil { +// HandleInbound processes an inbound channel message through identity resolution and chat gateway. +func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error { + if p.runner == nil { return fmt.Errorf("channel inbound processor not configured") } if sender == nil { @@ -96,29 +110,80 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel Message: state.Decision.Reply, }) } + if p.logger != nil { + p.logger.Info( + "inbound dropped by identity policy (no reply sent)", + slog.String("channel", msg.Channel.String()), + slog.String("bot_id", strings.TrimSpace(state.Identity.BotID)), + slog.String("conversation_type", strings.TrimSpace(msg.Conversation.Type)), + slog.String("conversation_id", strings.TrimSpace(msg.Conversation.ID)), + ) + } return nil } identity := state.Identity - sessionToken := "" + // Resolve or create the route via channel_routes. + if p.routeResolver == nil { + return fmt.Errorf("route resolver not configured") + } + resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{ + BotID: identity.BotID, + Platform: msg.Channel.String(), + ConversationID: msg.Conversation.ID, + ThreadID: extractThreadID(msg), + ConversationType: msg.Conversation.Type, + ChannelIdentityID: identity.UserID, + ChannelConfigID: identity.ChannelConfigID, + ReplyTarget: strings.TrimSpace(msg.ReplyTarget), + }) + if err != nil { + return fmt.Errorf("resolve route conversation: %w", err) + } + // Bot-centric history container: + // always persist channel traffic under bot_id so WebUI can view unified cross-platform history. + activeChatID := strings.TrimSpace(identity.BotID) + if activeChatID == "" { + activeChatID = strings.TrimSpace(resolved.ChatID) + } + if !shouldTriggerAssistantResponse(msg) && !identity.ForceReply { + if p.logger != nil { + p.logger.Info( + "inbound not triggering assistant (group trigger condition not met)", + slog.String("channel", msg.Channel.String()), + slog.String("bot_id", strings.TrimSpace(identity.BotID)), + slog.String("route_id", strings.TrimSpace(resolved.RouteID)), + slog.Bool("is_mentioned", metadataBool(msg.Metadata, "is_mentioned")), + slog.Bool("is_reply_to_bot", metadataBool(msg.Metadata, "is_reply_to_bot")), + slog.String("conversation_type", strings.TrimSpace(msg.Conversation.Type)), + ) + } + p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, "passive_sync") + return nil + } + userMessagePersisted := p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, "active_chat") + + // Issue chat token for reply routing. + chatToken := "" if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTarget) != "" { - signed, _, err := auth.GenerateSessionToken(auth.SessionToken{ - BotID: identity.BotID, - Platform: msg.Channel.String(), - ReplyTarget: strings.TrimSpace(msg.ReplyTarget), - SessionID: identity.SessionID, - ContactID: identity.ContactID, + signed, _, err := auth.GenerateChatToken(auth.ChatToken{ + BotID: identity.BotID, + ChatID: activeChatID, + RouteID: resolved.RouteID, + UserID: identity.UserID, + ChannelIdentityID: identity.ChannelIdentityID, }, p.jwtSecret, p.tokenTTL) if err != nil { if p.logger != nil { - p.logger.Warn("issue session token failed", slog.Any("error", err)) + p.logger.Warn("issue chat token failed", slog.Any("error", err)) } } else { - sessionToken = signed + chatToken = signed } } + // Issue user JWT for downstream calls (MCP, schedule, etc.). For guests use chat token as Bearer. token := "" if identity.UserID != "" && p.jwtSecret != "" { signed, _, err := auth.GenerateToken(identity.UserID, p.jwtSecret, p.tokenTTL) @@ -130,42 +195,182 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel token = "Bearer " + signed } } + if token == "" && chatToken != "" { + token = "Bearer " + chatToken + } + var desc channel.Descriptor if p.registry != nil { desc, _ = p.registry.GetDescriptor(msg.Channel) } - resp, err := p.chat.Chat(ctx, chat.ChatRequest{ - BotID: identity.BotID, - SessionID: identity.SessionID, - Token: token, - UserID: identity.UserID, - ContactID: identity.ContactID, - ContactName: strings.TrimSpace(identity.Contact.DisplayName), - ContactAlias: strings.TrimSpace(identity.Contact.Alias), - ReplyTarget: strings.TrimSpace(msg.ReplyTarget), - SessionToken: sessionToken, - Query: text, - CurrentChannel: msg.Channel.String(), - Channels: []string{msg.Channel.String()}, - }) - if err != nil { - if p.logger != nil { - p.logger.Error("chat gateway failed", slog.String("channel", msg.Channel.String()), slog.String("user_id", identity.UserID), slog.Any("error", err)) - } - return err + statusInfo := channel.ProcessingStatusInfo{ + BotID: identity.BotID, + ChatID: activeChatID, + RouteID: resolved.RouteID, + ChannelIdentityID: identity.ChannelIdentityID, + UserID: identity.UserID, + Query: text, + ReplyTarget: strings.TrimSpace(msg.ReplyTarget), + SourceMessageID: strings.TrimSpace(msg.Message.ID), } - outputs := chat.ExtractAssistantOutputs(resp.Messages) - if len(outputs) == 0 { - return nil + statusNotifier := p.resolveProcessingStatusNotifier(msg.Channel) + statusHandle := channel.ProcessingStatusHandle{} + if statusNotifier != nil { + handle, notifyErr := p.notifyProcessingStarted(ctx, statusNotifier, cfg, msg, statusInfo) + if notifyErr != nil { + p.logProcessingStatusError("processing_started", msg, identity, notifyErr) + } else { + statusHandle = handle + } } target := strings.TrimSpace(msg.ReplyTarget) if target == "" { - return fmt.Errorf("reply target missing") + err := fmt.Errorf("reply target missing") + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return err } - sentTexts, suppressReplies := collectMessageToolContext(p.registry, resp.Messages, msg.Channel, target) + sourceMessageID := strings.TrimSpace(msg.Message.ID) + replyRef := &channel.ReplyRef{Target: target} + if sourceMessageID != "" { + replyRef.MessageID = sourceMessageID + } + stream, err := sender.OpenStream(ctx, target, channel.StreamOptions{ + Reply: replyRef, + SourceMessageID: sourceMessageID, + Metadata: map[string]any{ + "route_id": resolved.RouteID, + }, + }) + if err != nil { + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return err + } + defer func() { + _ = stream.Close(context.WithoutCancel(ctx)) + }() + + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventStatus, + Status: channel.StreamStatusStarted, + }); err != nil { + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return err + } + + chunkCh, streamErrCh := p.runner.StreamChat(ctx, flow.ChatRequest{ + BotID: identity.BotID, + ChatID: activeChatID, + Token: token, + UserID: identity.UserID, + SourceChannelIdentityID: identity.ChannelIdentityID, + DisplayName: identity.DisplayName, + RouteID: resolved.RouteID, + ChatToken: chatToken, + ExternalMessageID: sourceMessageID, + Query: text, + CurrentChannel: msg.Channel.String(), + Channels: []string{msg.Channel.String()}, + UserMessagePersisted: userMessagePersisted, + }) + + var ( + finalMessages []conversation.ModelMessage + streamErr error + ) + for chunkCh != nil || streamErrCh != nil { + select { + case chunk, ok := <-chunkCh: + if !ok { + chunkCh = nil + continue + } + events, messages, parseErr := mapStreamChunkToChannelEvents(chunk) + if parseErr != nil { + if p.logger != nil { + p.logger.Warn( + "stream chunk parse failed", + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", identity.ChannelIdentityID), + slog.String("user_id", identity.UserID), + slog.Any("error", parseErr), + ) + } + continue + } + for _, event := range events { + if pushErr := stream.Push(ctx, event); pushErr != nil { + streamErr = pushErr + break + } + } + if len(messages) > 0 { + finalMessages = messages + } + case err, ok := <-streamErrCh: + if !ok { + streamErrCh = nil + continue + } + if err != nil { + streamErr = err + } + } + if streamErr != nil { + break + } + } + + if streamErr != nil { + if p.logger != nil { + p.logger.Error( + "chat gateway stream failed", + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", identity.ChannelIdentityID), + slog.String("user_id", identity.UserID), + slog.Any("error", streamErr), + ) + } + _ = stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventError, + Error: streamErr.Error(), + }) + if statusNotifier != nil { + if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, streamErr); notifyErr != nil { + p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) + } + } + return streamErr + } + + sentTexts, suppressReplies := collectMessageToolContext(p.registry, finalMessages, msg.Channel, target) if suppressReplies { + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventStatus, + Status: channel.StreamStatusCompleted, + }); err != nil { + return err + } + if statusNotifier != nil { + if notifyErr := p.notifyProcessingCompleted(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle); notifyErr != nil { + p.logProcessingStatusError("processing_completed", msg, identity, notifyErr) + } + } return nil } + + outputs := flow.ExtractAssistantOutputs(finalMessages) for _, output := range outputs { outMessage := buildChannelMessage(output, desc.Capabilities) if outMessage.IsEmpty() { @@ -178,17 +383,170 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel if isMessagingToolDuplicate(plainText, sentTexts) { continue } - if err := sender.Send(ctx, channel.OutboundMessage{ - Target: target, - Message: outMessage, + if outMessage.Reply == nil && sourceMessageID != "" { + outMessage.Reply = &channel.ReplyRef{ + Target: target, + MessageID: sourceMessageID, + } + } + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventFinal, + Final: &channel.StreamFinalizePayload{ + Message: outMessage, + }, }); err != nil { return err } } + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventStatus, + Status: channel.StreamStatusCompleted, + }); err != nil { + return err + } + if statusNotifier != nil { + if notifyErr := p.notifyProcessingCompleted(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle); notifyErr != nil { + p.logProcessingStatusError("processing_completed", msg, identity, notifyErr) + } + } return nil } -func buildChannelMessage(output chat.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message { +func shouldTriggerAssistantResponse(msg channel.InboundMessage) bool { + if isDirectConversationType(msg.Conversation.Type) { + return true + } + if metadataBool(msg.Metadata, "is_mentioned") { + return true + } + if metadataBool(msg.Metadata, "is_reply_to_bot") { + return true + } + return hasCommandPrefix(msg.Message.PlainText(), msg.Metadata) +} + +func isDirectConversationType(conversationType string) bool { + ct := strings.ToLower(strings.TrimSpace(conversationType)) + return ct == "" || ct == "p2p" || ct == "private" || ct == "direct" +} + +func hasCommandPrefix(text string, metadata map[string]any) bool { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return false + } + prefixes := []string{"/"} + if metadata != nil { + if raw, ok := metadata["command_prefix"]; ok { + if value := strings.TrimSpace(fmt.Sprint(raw)); value != "" { + prefixes = []string{value} + } + } + if raw, ok := metadata["command_prefixes"]; ok { + if parsed := parseCommandPrefixes(raw); len(parsed) > 0 { + prefixes = parsed + } + } + } + for _, prefix := range prefixes { + if strings.HasPrefix(trimmed, prefix) { + return true + } + } + return false +} + +func parseCommandPrefixes(raw any) []string { + if items, ok := raw.([]string); ok { + result := make([]string, 0, len(items)) + for _, item := range items { + value := strings.TrimSpace(item) + if value == "" { + continue + } + result = append(result, value) + } + return result + } + items, ok := raw.([]any) + if !ok { + return nil + } + result := make([]string, 0, len(items)) + for _, item := range items { + value := strings.TrimSpace(fmt.Sprint(item)) + if value == "" { + continue + } + result = append(result, value) + } + return result +} + +func metadataBool(metadata map[string]any, key string) bool { + if metadata == nil { + return false + } + raw, ok := metadata[key] + if !ok { + return false + } + switch value := raw.(type) { + case bool: + return value + case string: + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "true", "yes", "on": + return true + default: + return false + } + default: + return false + } +} + +func (p *ChannelInboundProcessor) persistInboundUser(ctx context.Context, routeID string, identity InboundIdentity, msg channel.InboundMessage, query string, triggerMode string) bool { + if p.message == nil { + return false + } + botID := strings.TrimSpace(identity.BotID) + if botID == "" { + return false + } + payload, err := json.Marshal(conversation.ModelMessage{ + Role: "user", + Content: conversation.NewTextContent(query), + }) + if err != nil { + if p.logger != nil { + p.logger.Warn("marshal inbound user message failed", slog.Any("error", err)) + } + return false + } + meta := map[string]any{ + "route_id": strings.TrimSpace(routeID), + "platform": msg.Channel.String(), + "trigger_mode": strings.TrimSpace(triggerMode), + } + if _, err := p.message.Persist(ctx, messagepkg.PersistInput{ + BotID: botID, + RouteID: strings.TrimSpace(routeID), + SenderChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID), + SenderUserID: strings.TrimSpace(identity.UserID), + Platform: msg.Channel.String(), + ExternalMessageID: strings.TrimSpace(msg.Message.ID), + Role: "user", + Content: payload, + Metadata: meta, + }); err != nil && p.logger != nil { + p.logger.Warn("persist inbound user message failed", slog.Any("error", err)) + return false + } + return true +} + +func buildChannelMessage(output conversation.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message { msg := channel.Message{} if strings.TrimSpace(output.Content) != "" { msg.Text = strings.TrimSpace(output.Content) @@ -207,13 +565,13 @@ func buildChannelMessage(output chat.AssistantOutput, capabilities channel.Chann } partType := normalizeContentPartType(part.Type) parts = append(parts, channel.MessagePart{ - Type: partType, - Text: part.Text, - URL: part.URL, - Styles: normalizeContentPartStyles(part.Styles), - Language: part.Language, - UserID: part.UserID, - Emoji: part.Emoji, + Type: partType, + Text: part.Text, + URL: part.URL, + Styles: normalizeContentPartStyles(part.Styles), + Language: part.Language, + ChannelIdentityID: part.ChannelIdentityID, + Emoji: part.Emoji, }) } if len(parts) > 0 { @@ -261,7 +619,7 @@ func containsMarkdown(text string) bool { return false } -func contentPartHasValue(part chat.ContentPart) bool { +func contentPartHasValue(part conversation.ContentPart) bool { if strings.TrimSpace(part.Text) != "" { return true } @@ -274,7 +632,7 @@ func contentPartHasValue(part chat.ContentPart) bool { return false } -func contentPartText(part chat.ContentPart) string { +func contentPartText(part conversation.ContentPart) string { if strings.TrimSpace(part.Text) != "" { return part.Text } @@ -287,6 +645,79 @@ func contentPartText(part chat.ContentPart) string { return "" } +type gatewayStreamEnvelope struct { + Type string `json:"type"` + Delta string `json:"delta"` + Error string `json:"error"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + Messages []conversation.ModelMessage `json:"messages"` +} + +type gatewayStreamDoneData struct { + Messages []conversation.ModelMessage `json:"messages"` +} + +func mapStreamChunkToChannelEvents(chunk conversation.StreamChunk) ([]channel.StreamEvent, []conversation.ModelMessage, error) { + if len(chunk) == 0 { + return nil, nil, nil + } + var envelope gatewayStreamEnvelope + if err := json.Unmarshal(chunk, &envelope); err != nil { + return nil, nil, err + } + finalMessages := make([]conversation.ModelMessage, 0, len(envelope.Messages)) + finalMessages = append(finalMessages, envelope.Messages...) + if len(finalMessages) == 0 && len(envelope.Data) > 0 { + var done gatewayStreamDoneData + if err := json.Unmarshal(envelope.Data, &done); err == nil && len(done.Messages) > 0 { + finalMessages = append(finalMessages, done.Messages...) + } + } + eventType := strings.ToLower(strings.TrimSpace(envelope.Type)) + switch eventType { + case "text_delta": + if envelope.Delta == "" { + return nil, finalMessages, nil + } + return []channel.StreamEvent{ + { + Type: channel.StreamEventDelta, + Delta: envelope.Delta, + }, + }, finalMessages, nil + case "reasoning_delta": + if envelope.Delta == "" { + return nil, finalMessages, nil + } + return []channel.StreamEvent{ + { + Type: channel.StreamEventDelta, + Delta: envelope.Delta, + Metadata: map[string]any{ + "phase": "reasoning", + }, + }, + }, finalMessages, nil + case "error": + streamError := strings.TrimSpace(envelope.Error) + if streamError == "" { + streamError = strings.TrimSpace(envelope.Message) + } + if streamError == "" { + streamError = "stream error" + } + return []channel.StreamEvent{ + { + Type: channel.StreamEventError, + Error: streamError, + }, + }, finalMessages, nil + default: + return nil, finalMessages, nil + } +} + func buildInboundQuery(message channel.Message) string { text := strings.TrimSpace(message.PlainText()) if len(message.Attachments) == 0 { @@ -299,7 +730,7 @@ func buildInboundQuery(message channel.Message) string { for _, att := range message.Attachments { label := strings.TrimSpace(att.Name) if label == "" { - label = strings.TrimSpace(att.URL) + label = strings.TrimSpace(att.Reference()) } if label == "" { label = "unknown" @@ -350,14 +781,14 @@ func normalizeContentPartStyles(styles []string) []channel.MessageTextStyle { } type sendMessageToolArgs struct { - Platform string `json:"platform"` - Target string `json:"target"` - UserID string `json:"user_id"` - Text string `json:"text"` - Message *channel.Message `json:"message"` + Platform string `json:"platform"` + Target string `json:"target"` + ChannelIdentityID string `json:"channel_identity_id"` + Text string `json:"text"` + Message *channel.Message `json:"message"` } -func collectMessageToolContext(registry *channel.Registry, messages []chat.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { +func collectMessageToolContext(registry *channel.Registry, messages []conversation.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { if len(messages) == 0 { return nil, false } @@ -419,7 +850,7 @@ func shouldSuppressForToolCall(registry *channel.Registry, args sendMessageToolA return false } target := strings.TrimSpace(args.Target) - if target == "" && strings.TrimSpace(args.UserID) == "" { + if target == "" && strings.TrimSpace(args.ChannelIdentityID) == "" { target = replyTarget } if strings.TrimSpace(target) == "" || strings.TrimSpace(replyTarget) == "" { @@ -526,12 +957,88 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool { return false } +// requireIdentity resolves identity for the current message. Always resolves from msg so each sender is identified correctly (no reuse of context state across messages). func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { - if state, ok := IdentityStateFromContext(ctx); ok { - return state, nil - } if p.identity == nil { return IdentityState{}, fmt.Errorf("identity resolver not configured") } return p.identity.Resolve(ctx, cfg, msg) } + +func (p *ChannelInboundProcessor) resolveProcessingStatusNotifier(channelType channel.ChannelType) channel.ProcessingStatusNotifier { + if p == nil || p.registry == nil { + return nil + } + notifier, ok := p.registry.GetProcessingStatusNotifier(channelType) + if !ok { + return nil + } + return notifier +} + +func (p *ChannelInboundProcessor) notifyProcessingStarted( + ctx context.Context, + notifier channel.ProcessingStatusNotifier, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + info channel.ProcessingStatusInfo, +) (channel.ProcessingStatusHandle, error) { + if notifier == nil { + return channel.ProcessingStatusHandle{}, nil + } + statusCtx, cancel := context.WithTimeout(ctx, processingStatusTimeout) + defer cancel() + return notifier.ProcessingStarted(statusCtx, cfg, msg, info) +} + +func (p *ChannelInboundProcessor) notifyProcessingCompleted( + ctx context.Context, + notifier channel.ProcessingStatusNotifier, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + info channel.ProcessingStatusInfo, + handle channel.ProcessingStatusHandle, +) error { + if notifier == nil { + return nil + } + statusCtx, cancel := context.WithTimeout(ctx, processingStatusTimeout) + defer cancel() + return notifier.ProcessingCompleted(statusCtx, cfg, msg, info, handle) +} + +func (p *ChannelInboundProcessor) notifyProcessingFailed( + ctx context.Context, + notifier channel.ProcessingStatusNotifier, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + info channel.ProcessingStatusInfo, + handle channel.ProcessingStatusHandle, + cause error, +) error { + if notifier == nil { + return nil + } + statusCtx, cancel := context.WithTimeout(ctx, processingStatusTimeout) + defer cancel() + return notifier.ProcessingFailed(statusCtx, cfg, msg, info, handle, cause) +} + +func (p *ChannelInboundProcessor) logProcessingStatusError( + stage string, + msg channel.InboundMessage, + identity InboundIdentity, + err error, +) { + if p == nil || p.logger == nil || err == nil { + return + } + p.logger.Warn( + "processing status notify failed", + slog.String("stage", stage), + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", identity.ChannelIdentityID), + slog.String("user_id", identity.UserID), + slog.Any("error", err), + ) +} diff --git a/internal/router/channel_test.go b/internal/router/channel_test.go index 1dc061d6..748bc36b 100644 --- a/internal/router/channel_test.go +++ b/internal/router/channel_test.go @@ -2,117 +2,68 @@ package router import ( "context" - "fmt" + "encoding/json" + "errors" "log/slog" "strings" "testing" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/chat" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/conversation" + messagepkg "github.com/memohai/memoh/internal/message" + "github.com/memohai/memoh/internal/schedule" ) -type fakeConfigStore struct { - session channel.ChannelSession - boundUserID string -} - -func (f *fakeConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) { - return channel.ChannelConfig{}, nil -} - -func (f *fakeConfigStore) GetUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, fmt.Errorf("not implemented") -} - -func (f *fakeConfigStore) UpsertUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType, req channel.UpsertUserConfigRequest) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, nil -} - -func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType channel.ChannelType) ([]channel.ChannelConfig, error) { - return nil, nil -} - -func (f *fakeConfigStore) ResolveUserBinding(ctx context.Context, channelType channel.ChannelType, criteria channel.BindingCriteria) (string, error) { - if f.boundUserID == "" { - return "", fmt.Errorf("channel user binding not found") - } - return f.boundUserID, nil -} - -func (f *fakeConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]channel.ChannelSession, error) { - return nil, nil -} - -func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID string) (channel.ChannelSession, error) { - if f.session.SessionID == sessionID { - return f.session, nil - } - return channel.ChannelSession{}, nil -} - -func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { - return nil -} - type fakeChatGateway struct { - resp chat.ChatResponse + resp conversation.ChatResponse err error - gotReq chat.ChatRequest + gotReq conversation.ChatRequest + onChat func(conversation.ChatRequest) } -func (f *fakeChatGateway) Chat(ctx context.Context, req chat.ChatRequest) (chat.ChatResponse, error) { +func (f *fakeChatGateway) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) { f.gotReq = req + if f.onChat != nil { + f.onChat(req) + } return f.resp, f.err } -type fakeContactService struct { - contactID string -} - -func (f *fakeContactService) GetByID(ctx context.Context, contactID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") -} - -func (f *fakeContactService) GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") -} - -func (f *fakeContactService) GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) { - return contacts.ContactChannel{}, fmt.Errorf("not found") -} - -func (f *fakeContactService) Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) { - return contacts.Contact{ID: "contact-1", BotID: req.BotID, UserID: req.UserID}, nil -} - -func (f *fakeContactService) CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) { - return contacts.Contact{ID: "contact-guest", BotID: botID}, nil -} - -func (f *fakeContactService) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) { - return contacts.ContactChannel{ID: "channel-1", ContactID: contactID}, nil -} - -type fakePolicyService struct { - decision policy.Decision - err error -} - -func (f *fakePolicyService) Resolve(ctx context.Context, botID string) (policy.Decision, error) { +func (f *fakeChatGateway) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) { + f.gotReq = req + if f.onChat != nil { + f.onChat(req) + } + chunks := make(chan conversation.StreamChunk, 1) + errs := make(chan error, 1) if f.err != nil { - return policy.Decision{}, f.err + errs <- f.err + close(chunks) + close(errs) + return chunks, errs } - decision := f.decision - if decision.BotID == "" { - decision.BotID = botID + payload := map[string]any{ + "type": "agent_end", + "messages": f.resp.Messages, } - return decision, nil + data, err := json.Marshal(payload) + if err == nil { + chunks <- conversation.StreamChunk(data) + } + close(chunks) + close(errs) + return chunks, errs +} + +func (f *fakeChatGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { + return nil } type fakeReplySender struct { - sent []channel.OutboundMessage + sent []channel.OutboundMessage + events []channel.StreamEvent } func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) error { @@ -120,28 +71,151 @@ func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) return nil } -func TestChannelInboundProcessorBoundUser(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", +func (s *fakeReplySender) OpenStream(ctx context.Context, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + return &fakeOutboundStream{ + sender: s, + target: strings.TrimSpace(target), + }, nil +} + +type fakeOutboundStream struct { + sender *fakeReplySender + target string +} + +func (s *fakeOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.sender == nil { + return nil + } + s.sender.events = append(s.sender.events, event) + if event.Type == channel.StreamEventFinal && event.Final != nil && !event.Final.Message.IsEmpty() { + s.sender.sent = append(s.sender.sent, channel.OutboundMessage{ + Target: s.target, + Message: event.Final.Message, + }) + } + return nil +} + +func (s *fakeOutboundStream) Close(ctx context.Context) error { + return nil +} + +type fakeProcessingStatusNotifier struct { + startedHandle channel.ProcessingStatusHandle + startedErr error + completedErr error + failedErr error + events []string + info []channel.ProcessingStatusInfo + completedSeen channel.ProcessingStatusHandle + failedSeen channel.ProcessingStatusHandle + failedCause error +} + +func (n *fakeProcessingStatusNotifier) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + n.events = append(n.events, "started") + n.info = append(n.info, info) + return n.startedHandle, n.startedErr +} + +func (n *fakeProcessingStatusNotifier) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { + n.events = append(n.events, "completed") + n.info = append(n.info, info) + n.completedSeen = handle + return n.completedErr +} + +func (n *fakeProcessingStatusNotifier) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { + n.events = append(n.events, "failed") + n.info = append(n.info, info) + n.failedSeen = handle + n.failedCause = cause + return n.failedErr +} + +type fakeProcessingStatusAdapter struct { + notifier *fakeProcessingStatusNotifier +} + +func (a *fakeProcessingStatusAdapter) Type() channel.ChannelType { + return channel.ChannelType("feishu") +} + +func (a *fakeProcessingStatusAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: channel.ChannelType("feishu"), + Capabilities: channel.ChannelCapabilities{ + Text: true, + Reply: true, }, } +} + +func (a *fakeProcessingStatusAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + return a.notifier.ProcessingStarted(ctx, cfg, msg, info) +} + +func (a *fakeProcessingStatusAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { + return a.notifier.ProcessingCompleted(ctx, cfg, msg, info, handle) +} + +func (a *fakeProcessingStatusAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { + return a.notifier.ProcessingFailed(ctx, cfg, msg, info, handle, cause) +} + +type fakeChatService struct { + resolveResult route.ResolveConversationResult + resolveErr error + persisted []messagepkg.Message +} + +func (f *fakeChatService) ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error) { + if f.resolveErr != nil { + return route.ResolveConversationResult{}, f.resolveErr + } + return f.resolveResult, nil +} + +func (f *fakeChatService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { + msg := messagepkg.Message{ + BotID: input.BotID, + RouteID: input.RouteID, + SenderChannelIdentityID: input.SenderChannelIdentityID, + SenderUserID: input.SenderUserID, + Platform: input.Platform, + ExternalMessageID: input.ExternalMessageID, + SourceReplyToMessageID: input.SourceReplyToMessageID, + Role: input.Role, + Content: input.Content, + Metadata: input.Metadata, + } + f.persisted = append(f.persisted, msg) + return msg, nil +} + +func TestChannelInboundProcessorWithIdentity(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-1", RouteID: "route-1"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, + Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-1", DisplayName: "User1"}, Conversation: channel.Conversation{ ID: "chat-1", Type: "p2p", @@ -150,48 +224,62 @@ func TestChannelInboundProcessorBoundUser(t *testing.T) { err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if gateway.gotReq.Query != "你好" { - t.Errorf("Chat 请求 Query 错误: %s", gateway.gotReq.Query) + if gateway.gotReq.Query != "hello" { + t.Errorf("expected query 'hello', got: %s", gateway.gotReq.Query) } - if gateway.gotReq.SessionID != "feishu:bot-1:chat-1" { - t.Errorf("SessionID 传递错误: %s", gateway.gotReq.SessionID) + if gateway.gotReq.UserID != "channelIdentity-1" { + t.Errorf("expected user_id 'channelIdentity-1', got: %s", gateway.gotReq.UserID) } - if len(sender.sent) != 1 || sender.sent[0].Message.PlainText() != "AI回复内容" { - t.Fatalf("应发送 AI 回复,实际: %+v", sender.sent) + if gateway.gotReq.SourceChannelIdentityID != "channelIdentity-1" { + t.Errorf("expected source_channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.SourceChannelIdentityID) + } + if gateway.gotReq.ChatID != "bot-1" { + t.Errorf("expected bot-scoped chat id 'bot-1', got: %s", gateway.gotReq.ChatID) + } + if len(sender.sent) != 1 || sender.sent[0].Message.PlainText() != "AI reply" { + t.Fatalf("expected AI reply, got: %+v", sender.sent) } } -func TestChannelInboundProcessorUnboundUser(t *testing.T) { - store := &fakeConfigStore{} +func TestChannelInboundProcessorDenied(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false} + chatSvc := &fakeChatService{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, + Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "stranger-1"}, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(sender.sent) != 1 || !strings.Contains(sender.sent[0].Message.PlainText(), "陌生人") { - t.Fatalf("应发送绑定提示,实际: %+v", sender.sent) + if len(sender.sent) != 1 || !strings.Contains(sender.sent[0].Message.PlainText(), "denied") { + t.Fatalf("expected access denied reply, got: %+v", sender.sent) } if gateway.gotReq.Query != "" { - t.Error("未绑定用户不应触发 Chat 调用") + t.Error("denied user should not trigger chat call") } } func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { - store := &fakeConfigStore{} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false} + chatSvc := &fakeChatService{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1"} @@ -199,147 +287,420 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("空消息不应报错: %v", err) + t.Fatalf("empty message should not error: %v", err) } if len(sender.sent) != 0 { - t.Fatalf("空消息不应发送回复: %+v", sender.sent) + t.Fatalf("empty message should not produce reply: %+v", sender.sent) } if gateway.gotReq.Query != "" { - t.Error("空消息不应触发 Chat 调用") + t.Error("empty message should not trigger chat call") } } func TestChannelInboundProcessorSilentReply(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", - }, - } + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-4", RouteID: "route-4"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("NO_REPLY")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("NO_REPLY")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ - Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, - ReplyTarget: "target-id", + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "test"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "user-1"}, Conversation: channel.Conversation{ - ID: "chat-1", + ID: "conv-1", Type: "p2p", }, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } if len(sender.sent) != 0 { - t.Fatalf("NO_REPLY 不应发送回复,实际: %+v", sender.sent) + t.Fatalf("NO_REPLY should suppress output: %+v", sender.sent) } } -func TestChannelInboundProcessorSuppressOnToolSend(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", - }, - } +func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-5"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-5", RouteID: "route-5"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - { - Role: "assistant", - ToolCalls: []chat.ToolCall{ - { - Type: "function", - Function: chat.ToolCallFunction{ - Name: "send_message", - Arguments: `{"platform":"feishu","target":"target-id","message":{"text":"AI回复内容"}}`, - }, - }, - }, - }, - {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "msg-1", Text: "hello everyone"}, + ReplyTarget: "chat_id:oc_123", + Sender: channel.Identity{SubjectID: "user-1"}, + Conversation: channel.Conversation{ + ID: "oc_123", + Type: "group", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gateway.gotReq.Query != "" { + t.Fatalf("group passive sync should not trigger chat call") + } + if len(sender.sent) != 0 { + t.Fatalf("group passive sync should not send reply: %+v", sender.sent) + } + if len(chatSvc.persisted) != 1 { + t.Fatalf("expected 1 passive persisted message, got: %d", len(chatSvc.persisted)) + } + if chatSvc.persisted[0].Role != "user" { + t.Fatalf("expected persisted role user, got: %s", chatSvc.persisted[0].Role) + } + if chatSvc.persisted[0].BotID != "bot-1" { + t.Fatalf("expected passive persisted bot_id bot-1, got: %s", chatSvc.persisted[0].BotID) + } +} + +func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-6"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-6", RouteID: "route-6"}} + gateway := &fakeChatGateway{ + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "msg-2", Text: "@bot ping"}, + ReplyTarget: "chat_id:oc_123", + Sender: channel.Identity{SubjectID: "user-1"}, + Conversation: channel.Conversation{ + ID: "oc_123", + Type: "group", + }, + Metadata: map[string]any{ + "is_mentioned": true, + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gateway.gotReq.Query == "" { + t.Fatalf("group mention should trigger chat call") + } + if len(sender.sent) != 1 { + t.Fatalf("expected one outbound reply, got %d", len(sender.sent)) + } + if len(chatSvc.persisted) != 1 { + t.Fatalf("triggered group message should persist inbound user once, got: %d", len(chatSvc.persisted)) + } + if got := chatSvc.persisted[0].Metadata["trigger_mode"]; got != "active_chat" { + t.Fatalf("expected trigger_mode active_chat, got: %v", got) + } + if !gateway.gotReq.UserMessagePersisted { + t.Fatalf("expected UserMessagePersisted=true for pre-persisted inbound message") + } +} + +func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-member"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-personal-1", RouteID: "route-personal-1"}} + gateway := &fakeChatGateway{ + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "msg-personal-1", Text: "hello"}, + ReplyTarget: "chat_id:oc_personal", + Sender: channel.Identity{SubjectID: "ext-member-1"}, + Conversation: channel.Conversation{ + ID: "oc_personal", + Type: "group", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gateway.gotReq.Query != "" { + t.Fatalf("non-owner should not trigger chat call") + } + if len(sender.sent) != 0 { + t.Fatalf("non-owner should be ignored silently: %+v", sender.sent) + } + if len(chatSvc.persisted) != 0 { + t.Fatalf("ignored message should not persist in passive mode") + } +} + +func TestChannelInboundProcessorPersonalGroupOwnerWithoutMentionUsesPassivePersistence(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-personal-2", RouteID: "route-personal-2"}} + gateway := &fakeChatGateway{ + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "msg-personal-2", Text: "owner says hi"}, + ReplyTarget: "chat_id:oc_personal", + Sender: channel.Identity{SubjectID: "ext-owner-1"}, + Conversation: channel.Conversation{ + ID: "oc_personal", + Type: "group", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gateway.gotReq.Query != "" { + t.Fatalf("owner group message without mention should not trigger chat call") + } + if len(sender.sent) != 0 { + t.Fatalf("owner group message without mention should not send reply") + } + if len(chatSvc.persisted) != 1 { + t.Fatalf("expected one passive persisted message, got: %d", len(chatSvc.persisted)) + } + if got := chatSvc.persisted[0].Metadata["trigger_mode"]; got != "passive_sync" { + t.Fatalf("expected trigger_mode passive_sync, got: %v", got) + } +} + +func TestChannelInboundProcessorProcessingStatusSuccessLifecycle(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedHandle: channel.ProcessingStatusHandle{Token: "reaction-1"}, + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-1", RouteID: "route-1"}} + gateway := &fakeChatGateway{ + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, + }, + }, + onChat: func(req conversation.ChatRequest) { + if len(notifier.events) != 1 || notifier.events[0] != "started" { + t.Fatalf("expected started before chat call, got events: %+v", notifier.events) + } + }, + } + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, - ReplyTarget: "target-id", + Message: channel.Message{ID: "om_123", Text: "hello"}, + ReplyTarget: "chat_id:oc_123", + Sender: channel.Identity{SubjectID: "ext-1"}, Conversation: channel.Conversation{ - ID: "chat-1", + ID: "oc_123", Type: "p2p", }, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(sender.sent) != 0 { - t.Fatalf("工具已发送当前会话消息,应抑制普通回复,实际: %+v", sender.sent) + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "completed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) + } + if notifier.completedSeen.Token != "reaction-1" { + t.Fatalf("expected completed token reaction-1, got: %q", notifier.completedSeen.Token) + } + if notifier.failedCause != nil { + t.Fatalf("expected failed cause nil, got: %v", notifier.failedCause) + } + if len(notifier.info) == 0 || notifier.info[0].SourceMessageID != "om_123" { + t.Fatalf("expected processing info source message id om_123, got: %+v", notifier.info) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one outbound reply, got %d", len(sender.sent)) } } -func TestChannelInboundProcessorDedupeWithToolSend(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", +func TestChannelInboundProcessorProcessingStatusFailureLifecycle(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedHandle: channel.ProcessingStatusHandle{Token: "reaction-2"}, + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-2", RouteID: "route-2"}} + chatErr := errors.New("chat gateway unavailable") + gateway := &fakeChatGateway{err: chatErr} + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "om_456", Text: "hello"}, + ReplyTarget: "chat_id:oc_456", + Sender: channel.Identity{SubjectID: "ext-2"}, + Conversation: channel.Conversation{ + ID: "oc_456", + Type: "p2p", }, } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if !errors.Is(err, chatErr) { + t.Fatalf("expected chat error, got: %v", err) + } + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "failed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) + } + if !errors.Is(notifier.failedCause, chatErr) { + t.Fatalf("expected failed cause chat error, got: %v", notifier.failedCause) + } + if notifier.failedSeen.Token != "reaction-2" { + t.Fatalf("expected failed token reaction-2, got: %q", notifier.failedSeen.Token) + } + if len(sender.sent) != 0 { + t.Fatalf("expected no outbound reply on chat failure, got: %+v", sender.sent) + } +} + +func TestChannelInboundProcessorProcessingStatusErrorsAreBestEffort(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedErr: errors.New("start notify failed"), + completedErr: errors.New("completed notify failed"), + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-3", RouteID: "route-3"}} gateway := &fakeChatGateway{ - resp: chat.ChatResponse{ - Messages: []chat.ModelMessage{ - { - Role: "assistant", - ToolCalls: []chat.ToolCall{ - { - Type: "function", - Function: chat.ToolCallFunction{ - Name: "send_message", - Arguments: `{"platform":"feishu","target":"other-target","message":{"text":"AI回复内容"}}`, - }, - }, - }, - }, - {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, - ReplyTarget: "target-id", + Message: channel.Message{ID: "om_789", Text: "hello"}, + ReplyTarget: "chat_id:oc_789", + Sender: channel.Identity{SubjectID: "ext-3"}, Conversation: channel.Conversation{ - ID: "chat-1", + ID: "oc_789", Type: "p2p", }, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(sender.sent) != 0 { - t.Fatalf("工具发送文本与普通回复重复,应去重,实际: %+v", sender.sent) + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "completed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) + } + if notifier.completedSeen.Token != "" { + t.Fatalf("expected empty completed token after started failure, got: %q", notifier.completedSeen.Token) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one outbound reply, got %d", len(sender.sent)) + } +} + +func TestChannelInboundProcessorProcessingFailedNotifyErrorDoesNotOverrideChatError(t *testing.T) { + notifier := &fakeProcessingStatusNotifier{ + startedHandle: channel.ProcessingStatusHandle{Token: "reaction-4"}, + failedErr: errors.New("failed notify error"), + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier}) + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-4", RouteID: "route-4"}} + chatErr := errors.New("chat failed") + gateway := &fakeChatGateway{err: chatErr} + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &fakeReplySender{} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "om_999", Text: "hello"}, + ReplyTarget: "chat_id:oc_999", + Sender: channel.Identity{SubjectID: "ext-4"}, + Conversation: channel.Conversation{ + ID: "oc_999", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if !errors.Is(err, chatErr) { + t.Fatalf("expected original chat error, got: %v", err) + } + if len(notifier.events) != 2 || notifier.events[0] != "started" || notifier.events[1] != "failed" { + t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events) } } diff --git a/internal/router/identity.go b/internal/router/identity.go index 863a7036..ae1995e4 100644 --- a/internal/router/identity.go +++ b/internal/router/identity.go @@ -8,27 +8,30 @@ import ( "strings" "time" + "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/preauth" ) +// IdentityDecision indicates whether the inbound message should be stopped with an optional reply. type IdentityDecision struct { Stop bool Reply channel.Message } +// InboundIdentity carries the resolved channel identity for an inbound message. type InboundIdentity struct { - BotID string - SessionID string - ChannelConfigID string - ExternalID string - UserID string - ContactID string - Contact contacts.Contact + BotID string + ChannelConfigID string + SubjectID string + ChannelIdentityID string + UserID string + DisplayName string + ForceReply bool } +// IdentityState bundles resolved identity with an optional early-exit decision. type IdentityState struct { Identity InboundIdentity Decision *IdentityDecision @@ -36,10 +39,12 @@ type IdentityState struct { type identityContextKey struct{} +// WithIdentityState stores IdentityState in the context. func WithIdentityState(ctx context.Context, state IdentityState) context.Context { return context.WithValue(ctx, identityContextKey{}, state) } +// IdentityStateFromContext retrieves IdentityState from the context. func IdentityStateFromContext(ctx context.Context) (IdentityState, bool) { if ctx == nil { return IdentityState{}, false @@ -52,54 +57,88 @@ func IdentityStateFromContext(ctx context.Context) (IdentityState, bool) { return state, ok } -// IdentityStore is the minimal persistence interface required by IdentityResolver. -type IdentityStore interface { - channel.BindingStore - channel.SessionStore +// ChannelIdentityService is the minimal interface for channel identity resolution. +type ChannelIdentityService interface { + ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string) (identities.ChannelIdentity, error) + Canonicalize(ctx context.Context, channelIdentityID string) (string, error) + GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) + LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error } -type IdentityResolver struct { - registry *channel.Registry - store IdentityStore - contacts ContactService - policy PolicyService - preauth PreauthService - logger *slog.Logger - unboundReply string - preauthReply string +// BotMemberService checks and manages bot membership. +type BotMemberService interface { + IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) + UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error } +// PolicyService resolves access policy for a bot. type PolicyService interface { - Resolve(ctx context.Context, botID string) (policy.Decision, error) + AllowGuest(ctx context.Context, botID string) (bool, error) + BotType(ctx context.Context, botID string) (string, error) + BotOwnerUserID(ctx context.Context, botID string) (string, error) } +// PreauthService handles preauth key validation. type PreauthService interface { Get(ctx context.Context, token string) (preauth.Key, error) MarkUsed(ctx context.Context, id string) (preauth.Key, error) } -func NewIdentityResolver(log *slog.Logger, registry *channel.Registry, store IdentityStore, contacts ContactService, policyService PolicyService, preauthService PreauthService, unboundReply, preauthReply string) *IdentityResolver { +// BindService handles channel identity bind code validation and consumption. +type BindService interface { + Get(ctx context.Context, token string) (bind.Code, error) + Consume(ctx context.Context, code bind.Code, channelIdentityID string) error +} + +// IdentityResolver implements identity resolution with bind code, preauth, and guest fallback. +type IdentityResolver struct { + registry *channel.Registry + channelIdentities ChannelIdentityService + members BotMemberService + policy PolicyService + preauth PreauthService + bind BindService + logger *slog.Logger + unboundReply string + preauthReply string + bindReply string +} + +// NewIdentityResolver creates an IdentityResolver. +func NewIdentityResolver( + log *slog.Logger, + registry *channel.Registry, + channelIdentityService ChannelIdentityService, + memberService BotMemberService, + policyService PolicyService, + preauthService PreauthService, + bindService BindService, + unboundReply, preauthReply string, +) *IdentityResolver { if log == nil { log = slog.Default() } if strings.TrimSpace(unboundReply) == "" { - unboundReply = "当前不允许陌生人访问,请联系管理员。" + unboundReply = "Access denied. Please contact the administrator." } if strings.TrimSpace(preauthReply) == "" { - preauthReply = "授权成功,请继续使用。" + preauthReply = "Authorization successful." } return &IdentityResolver{ - registry: registry, - store: store, - contacts: contacts, - policy: policyService, - preauth: preauthService, - logger: log.With(slog.String("component", "channel_identity")), - unboundReply: unboundReply, - preauthReply: preauthReply, + registry: registry, + channelIdentities: channelIdentityService, + members: memberService, + policy: policyService, + preauth: preauthService, + bind: bindService, + logger: log.With(slog.String("component", "channel_identity")), + unboundReply: unboundReply, + preauthReply: preauthReply, + bindReply: "Binding successful! Your identity has been linked.", } } +// Middleware returns a channel middleware that resolves identity before processing. func (r *IdentityResolver) Middleware() channel.Middleware { return func(next channel.InboundHandler) channel.InboundHandler { return func(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) error { @@ -112,8 +151,11 @@ func (r *IdentityResolver) Middleware() channel.Middleware { } } +// Resolve performs two-phase identity resolution: +// 1. Global identity: (channel, channel_subject_id) -> channel_identity_id (unconditional) +// 2. Authorization: bot membership check with guest/preauth fallback func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { - if r.store == nil || r.contacts == nil || r.policy == nil { + if r.channelIdentities == nil { return IdentityState{}, fmt.Errorf("identity resolver not configured") } @@ -121,111 +163,157 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi if botID == "" { botID = cfg.BotID } - normalizedMsg := msg - normalizedMsg.BotID = botID - sessionID := normalizedMsg.SessionID() channelConfigID := cfg.ID if r.registry != nil && r.registry.IsConfigless(msg.Channel) { channelConfigID = "" } - externalID := extractExternalIdentity(msg) + subjectID := extractSubjectIdentity(msg) + displayName := r.resolveDisplayName(ctx, cfg, msg, subjectID) state := IdentityState{ Identity: InboundIdentity{ BotID: botID, - SessionID: sessionID, ChannelConfigID: channelConfigID, - ExternalID: externalID, + SubjectID: subjectID, }, } - session, err := r.store.GetChannelSession(ctx, sessionID) - if err != nil && r.logger != nil { - r.logger.Error("get user by session failed", slog.String("session_id", sessionID), slog.Any("error", err)) - } - userID := strings.TrimSpace(session.UserID) - contactID := strings.TrimSpace(session.ContactID) - - if userID == "" { - userID, err = r.store.ResolveUserBinding(ctx, msg.Channel, channel.BindingCriteriaFromIdentity(msg.Sender)) - if err == nil && userID != "" { - _ = r.store.UpsertChannelSession(ctx, sessionID, botID, channelConfigID, userID, contactID, string(msg.Channel), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) - } + // Phase 1: Global identity resolution (unconditional). + if subjectID == "" { + return state, fmt.Errorf("cannot resolve identity: no channel_subject_id") } - var contact contacts.Contact - if contactID == "" && userID != "" { - contact, err = r.contacts.GetByUserID(ctx, botID, userID) - if err != nil { - displayName := extractDisplayName(msg) - contact, err = r.contacts.Create(ctx, contacts.CreateRequest{ - BotID: botID, - UserID: userID, - DisplayName: displayName, - Status: "active", - }) - } - if err == nil { - contactID = contact.ID - if externalID != "" { - _, _ = r.contacts.UpsertChannel(ctx, botID, contactID, msg.Channel.String(), externalID, nil) - } + channelIdentityID, linkedUserID, err := r.resolveIdentityWithLinkedUser(ctx, msg, subjectID, displayName) + if err != nil { + return state, err + } + state.Identity.ChannelIdentityID = channelIdentityID + state.Identity.UserID = strings.TrimSpace(linkedUserID) + if strings.TrimSpace(state.Identity.UserID) == "" { + state.Identity.UserID = r.tryLinkConfiglessChannelIdentityToUser(ctx, msg, channelIdentityID) + } + state.Identity.DisplayName = displayName + + // Bind code check runs before membership/guest checks so linking is always reachable. + if handled, decision, newUserID, err := r.tryHandleBindCode(ctx, msg, channelIdentityID, subjectID); handled { + if strings.TrimSpace(newUserID) != "" { + state.Identity.UserID = strings.TrimSpace(newUserID) } + state.Decision = &decision + return state, err } - if contactID == "" && externalID != "" { - binding, err := r.contacts.GetByChannelIdentity(ctx, botID, msg.Channel.String(), externalID) - if err == nil { - contactID = binding.ContactID - } - } - - if contactID == "" { - decision, err := r.policy.Resolve(ctx, botID) + // Personal bots are owner-only and must not depend on member/guest/preauth bypass. + if r.policy != nil { + botType, err := r.policy.BotType(ctx, botID) if err != nil { return state, err } - if decision.AllowGuest { - displayName := extractDisplayName(msg) - contact, err = r.contacts.CreateGuest(ctx, botID, displayName) - if err == nil { - contactID = contact.ID - if externalID != "" { - _, _ = r.contacts.UpsertChannel(ctx, botID, contactID, msg.Channel.String(), externalID, nil) - } - } - } else { - if handled, decision, err := r.tryHandlePreauthKey(ctx, normalizedMsg, externalID); handled { - state.Decision = &decision + if strings.EqualFold(strings.TrimSpace(botType), "personal") { + ownerUserID, err := r.policy.BotOwnerUserID(ctx, botID) + if err != nil { return state, err } - state.Decision = &IdentityDecision{ - Stop: true, - Reply: channel.Message{Text: r.unboundReply}, + isOwner := strings.TrimSpace(state.Identity.UserID) != "" && + strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID) + if !isOwner { + // Ignore all non-owner messages for personal bots. + state.Decision = &IdentityDecision{Stop: true} + return state, nil } + // Owner is authorized, but group trigger policy is still decided by + // shouldTriggerAssistantResponse in channel routing. return state, nil } } - if contactID != "" && contact.ID == "" { - loaded, err := r.contacts.GetByID(ctx, contactID) - if err == nil { - contact = loaded + // Phase 2: Authorization (bot membership check). + if r.members != nil { + if strings.TrimSpace(state.Identity.UserID) != "" { + isMember, err := r.members.IsMember(ctx, botID, state.Identity.UserID) + if err != nil { + return state, fmt.Errorf("check bot membership: %w", err) + } + if isMember { + return state, nil + } + } + } + if r.policy != nil && strings.TrimSpace(state.Identity.UserID) != "" { + ownerUserID, err := r.policy.BotOwnerUserID(ctx, botID) + if err != nil { + return state, err + } + // Bot owner should not depend on bot_members linkage. + if strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID) { + return state, nil } } - if contactID != "" { - _ = r.store.UpsertChannelSession(ctx, sessionID, botID, channelConfigID, userID, contactID, string(msg.Channel), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) + // Guest policy check. + if r.policy != nil { + allowed, err := r.policy.AllowGuest(ctx, botID) + if err != nil { + return state, err + } + if allowed { + return state, nil + } } - state.Identity.UserID = userID - state.Identity.ContactID = contactID - state.Identity.Contact = contact + // Preauth key check. + if handled, decision, err := r.tryHandlePreauthKey(ctx, msg, botID, state.Identity.UserID, subjectID); handled { + state.Decision = &decision + return state, err + } + + // In group conversations, silently drop unauthorized messages to avoid spamming + // the channel with "access denied" replies (same behavior as personal bot non-owner). + if isGroupConversationType(msg.Conversation.Type) { + state.Decision = &IdentityDecision{Stop: true} + return state, nil + } + + state.Decision = &IdentityDecision{ + Stop: true, + Reply: channel.Message{Text: r.unboundReply}, + } return state, nil } -func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.InboundMessage, externalID string) (bool, IdentityDecision, error) { +func (r *IdentityResolver) resolveIdentityWithLinkedUser(ctx context.Context, msg channel.InboundMessage, primarySubjectID, displayName string) (string, string, error) { + candidates := identitySubjectCandidates(msg, primarySubjectID) + if len(candidates) == 0 { + return "", "", fmt.Errorf("cannot resolve identity: no channel_subject_id") + } + + firstChannelIdentityID := "" + for _, subjectID := range candidates { + channelIdentity, err := r.channelIdentities.ResolveByChannelIdentity(ctx, msg.Channel.String(), subjectID, displayName) + if err != nil { + return "", "", fmt.Errorf("resolve channel identity: %w", err) + } + channelIdentityID := strings.TrimSpace(channelIdentity.ID) + if channelIdentityID == "" { + continue + } + if firstChannelIdentityID == "" { + firstChannelIdentityID = channelIdentityID + } + linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID) + if err != nil { + return "", "", err + } + linkedUserID = strings.TrimSpace(linkedUserID) + if linkedUserID != "" { + return channelIdentityID, linkedUserID, nil + } + } + return firstChannelIdentityID, "", nil +} + +func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.InboundMessage, botID, userID, subjectID string) (bool, IdentityDecision, error) { tokenText := strings.TrimSpace(msg.Message.PlainText()) if tokenText == "" || r.preauth == nil { return false, IdentityDecision{}, nil @@ -244,33 +332,95 @@ func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel. } } if !key.UsedAt.IsZero() { - return true, reply("预授权码已使用。"), nil + return true, reply("Preauth key already used."), nil } if !key.ExpiresAt.IsZero() && time.Now().UTC().After(key.ExpiresAt) { - return true, reply("预授权码已过期,请重新获取。"), nil + return true, reply("Preauth key expired."), nil } - if key.BotID != msg.BotID { - return true, reply("预授权码不匹配。"), nil + if key.BotID != botID { + return true, reply("Preauth key mismatch."), nil } - if externalID == "" { - return true, reply("无法识别当前账号,授权失败。"), nil + if subjectID == "" { + return true, reply("Cannot identify current account."), nil } - displayName := extractDisplayName(msg) - contact, err := r.contacts.CreateGuest(ctx, msg.BotID, displayName) - if err != nil { - return true, reply("授权失败,请稍后重试。"), nil + + // Grant membership via preauth. + if strings.TrimSpace(userID) == "" { + return true, reply("Current channel account is not linked to a user."), nil } - if _, err := r.contacts.UpsertChannel(ctx, msg.BotID, contact.ID, msg.Channel.String(), externalID, nil); err != nil { - return true, reply("授权失败,请稍后重试。"), nil + if r.members != nil { + if err := r.members.UpsertMemberSimple(ctx, botID, userID, "member"); err != nil { + return true, IdentityDecision{}, fmt.Errorf("upsert preauth member: %w", err) + } + } + if _, err := r.preauth.MarkUsed(ctx, key.ID); err != nil { + return true, IdentityDecision{}, fmt.Errorf("mark preauth key used: %w", err) } - _ = r.store.UpsertChannelSession(ctx, msg.SessionID(), msg.BotID, "", "", contact.ID, msg.Channel.String(), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) - _, _ = r.preauth.MarkUsed(ctx, key.ID) return true, reply(r.preauthReply), nil } -func extractExternalIdentity(msg channel.InboundMessage) string { - if strings.TrimSpace(msg.Sender.ExternalID) != "" { - return strings.TrimSpace(msg.Sender.ExternalID) +func (r *IdentityResolver) tryHandleBindCode(ctx context.Context, msg channel.InboundMessage, channelIdentityID, subjectID string) (bool, IdentityDecision, string, error) { + tokenText := strings.TrimSpace(msg.Message.PlainText()) + if tokenText == "" || r.bind == nil { + return false, IdentityDecision{}, "", nil + } + code, err := r.bind.Get(ctx, tokenText) + if err != nil { + if errors.Is(err, bind.ErrCodeNotFound) { + return false, IdentityDecision{}, "", nil + } + return true, IdentityDecision{}, "", err + } + reply := func(text string) IdentityDecision { + return IdentityDecision{Stop: true, Reply: channel.Message{Text: text}} + } + if !code.UsedAt.IsZero() { + return true, reply("Bind code already used."), "", nil + } + if !code.ExpiresAt.IsZero() && time.Now().UTC().After(code.ExpiresAt) { + return true, reply("Bind code expired."), "", nil + } + if strings.TrimSpace(code.Platform) != "" && !strings.EqualFold(strings.TrimSpace(code.Platform), msg.Channel.String()) { + return true, reply("Bind code mismatch."), "", nil + } + if subjectID == "" { + return true, reply("Cannot identify current account."), "", nil + } + + // Consume: mark used + link source channel identity to issuer user. + if err := r.bind.Consume(ctx, code, channelIdentityID); err != nil { + switch { + case errors.Is(err, bind.ErrCodeUsed): + return true, reply("Bind code already used."), "", nil + case errors.Is(err, bind.ErrCodeExpired): + return true, reply("Bind code expired."), "", nil + case errors.Is(err, bind.ErrCodeMismatch): + return true, reply("Bind code mismatch."), "", nil + case errors.Is(err, bind.ErrLinkConflict): + return true, reply("Current identity has already been linked to another account."), "", nil + default: + return true, IdentityDecision{}, "", fmt.Errorf("consume bind code: %w", err) + } + } + + // Resolve linked user after binding. + newUserID := code.IssuedByUserID + if r.channelIdentities != nil { + linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID) + if err != nil { + return true, IdentityDecision{}, "", fmt.Errorf("resolve linked user after bind: %w", err) + } + if strings.TrimSpace(linkedUserID) != "" { + newUserID = linkedUserID + } + } + + return true, reply(r.bindReply), newUserID, nil +} + +func extractSubjectIdentity(msg channel.InboundMessage) string { + if strings.TrimSpace(msg.Sender.SubjectID) != "" { + return strings.TrimSpace(msg.Sender.SubjectID) } if value := strings.TrimSpace(msg.Sender.Attribute("open_id")); value != "" { return value @@ -284,21 +434,85 @@ func extractExternalIdentity(msg channel.InboundMessage) string { return strings.TrimSpace(msg.Sender.DisplayName) } +func identitySubjectCandidates(msg channel.InboundMessage, primary string) []string { + candidates := make([]string, 0, 3) + appendUnique := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + for _, existing := range candidates { + if existing == value { + return + } + } + candidates = append(candidates, value) + } + + appendUnique(primary) + appendUnique(msg.Sender.Attribute("open_id")) + appendUnique(msg.Sender.Attribute("user_id")) + return candidates +} + func extractDisplayName(msg channel.InboundMessage) string { if strings.TrimSpace(msg.Sender.DisplayName) != "" { return strings.TrimSpace(msg.Sender.DisplayName) } - if strings.TrimSpace(msg.Sender.ExternalID) != "" { - return strings.TrimSpace(msg.Sender.ExternalID) + if value := strings.TrimSpace(msg.Sender.Attribute("display_name")); value != "" { + return value + } + if value := strings.TrimSpace(msg.Sender.Attribute("name")); value != "" { + return value } if value := strings.TrimSpace(msg.Sender.Attribute("username")); value != "" { return value } - if value := strings.TrimSpace(msg.Sender.Attribute("user_id")); value != "" { - return value + return "" +} + +func (r *IdentityResolver) resolveDisplayName(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, subjectID string) string { + displayName := extractDisplayName(msg) + if displayName != "" { + return displayName } - if value := strings.TrimSpace(msg.Sender.Attribute("open_id")); value != "" { - return value + return r.resolveDisplayNameFromDirectory(ctx, cfg, msg, subjectID) +} + +func (r *IdentityResolver) resolveDisplayNameFromDirectory(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, subjectID string) string { + if r.registry == nil { + return "" + } + subjectID = strings.TrimSpace(subjectID) + if subjectID == "" { + return "" + } + directoryAdapter, ok := r.registry.DirectoryAdapter(msg.Channel) + if !ok || directoryAdapter == nil { + return "" + } + if ctx == nil { + ctx = context.Background() + } + lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + entry, err := directoryAdapter.ResolveEntry(lookupCtx, cfg, subjectID, channel.DirectoryEntryUser) + if err != nil { + if r.logger != nil { + r.logger.Debug( + "resolve display name from directory failed", + slog.String("channel", msg.Channel.String()), + slog.String("subject_id", subjectID), + slog.Any("error", err), + ) + } + return "" + } + if name := strings.TrimSpace(entry.Name); name != "" { + return name + } + if handle := strings.TrimSpace(entry.Handle); handle != "" { + return handle } return "" } @@ -313,6 +527,39 @@ func extractThreadID(msg channel.InboundMessage) string { return "" } +func isGroupConversationType(conversationType string) bool { + ct := strings.ToLower(strings.TrimSpace(conversationType)) + if ct == "" { + return false + } + return ct != "p2p" && ct != "private" && ct != "direct" +} + +func (r *IdentityResolver) tryLinkConfiglessChannelIdentityToUser(ctx context.Context, msg channel.InboundMessage, channelIdentityID string) string { + if r.registry == nil || !r.registry.IsConfigless(msg.Channel) { + return "" + } + if r.channelIdentities == nil { + return "" + } + candidateUserID := strings.TrimSpace(msg.Sender.Attribute("user_id")) + if candidateUserID == "" { + return "" + } + if err := r.channelIdentities.LinkChannelIdentityToUser(ctx, channelIdentityID, candidateUserID); err != nil { + if r.logger != nil { + r.logger.Warn("auto link configless channel identity failed", + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", channelIdentityID), + slog.String("user_id", candidateUserID), + slog.Any("error", err), + ) + } + return "" + } + return candidateUserID +} + func buildSessionMetadata(msg channel.InboundMessage) map[string]any { metadata := map[string]any{} if strings.TrimSpace(msg.Source) != "" { diff --git a/internal/router/identity_test.go b/internal/router/identity_test.go index 90c87bd8..843ba72e 100644 --- a/internal/router/identity_test.go +++ b/internal/router/identity_test.go @@ -2,105 +2,119 @@ package router import ( "context" - "fmt" + "errors" "log/slog" "testing" "time" + "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/preauth" ) -type fakePolicyServiceIdentity struct { - decision policy.Decision - err error +type fakeChannelIdentityService struct { + channelIdentity identities.ChannelIdentity + bySubject map[string]identities.ChannelIdentity + err error + canonical map[string]string + linked map[string]string + calls int + lastDisplayName string } -func (f *fakePolicyServiceIdentity) Resolve(ctx context.Context, botID string) (policy.Decision, error) { +func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Context, platform, externalID, displayName string) (identities.ChannelIdentity, error) { + f.calls++ + f.lastDisplayName = displayName if f.err != nil { - return policy.Decision{}, f.err + return identities.ChannelIdentity{}, f.err } - decision := f.decision - if decision.BotID == "" { - decision.BotID = botID + if f.bySubject != nil { + if identity, ok := f.bySubject[externalID]; ok { + return identity, nil + } + return identities.ChannelIdentity{}, nil } - return decision, nil + return f.channelIdentity, nil } -type fakeIdentityConfigStore struct{} - -func (f *fakeIdentityConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) { - return channel.ChannelConfig{}, nil +func (f *fakeChannelIdentityService) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { + if f.canonical != nil { + if value, ok := f.canonical[channelIdentityID]; ok { + return value, nil + } + } + return channelIdentityID, nil } -func (f *fakeIdentityConfigStore) GetUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, fmt.Errorf("not implemented") +func (f *fakeChannelIdentityService) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { + if f.linked != nil { + if value, ok := f.linked[channelIdentityID]; ok { + return value, nil + } + return "", nil + } + // Default to one-to-one mapping for tests that do not set explicit links. + return channelIdentityID, nil } -func (f *fakeIdentityConfigStore) UpsertUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType, req channel.UpsertUserConfigRequest) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, nil -} - -func (f *fakeIdentityConfigStore) ListConfigsByType(ctx context.Context, channelType channel.ChannelType) ([]channel.ChannelConfig, error) { - return nil, nil -} - -func (f *fakeIdentityConfigStore) ResolveUserBinding(ctx context.Context, channelType channel.ChannelType, criteria channel.BindingCriteria) (string, error) { - return "", fmt.Errorf("channel user binding not found") -} - -func (f *fakeIdentityConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]channel.ChannelSession, error) { - return nil, nil -} - -func (f *fakeIdentityConfigStore) GetChannelSession(ctx context.Context, sessionID string) (channel.ChannelSession, error) { - return channel.ChannelSession{}, nil -} - -func (f *fakeIdentityConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { +func (f *fakeChannelIdentityService) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { + if f.linked == nil { + f.linked = map[string]string{} + } + f.linked[channelIdentityID] = userID return nil } -type fakeIdentityContactService struct { - createGuestCalled bool - upsertCalled bool +type fakeMemberService struct { + isMember bool + upsertCalled bool } -func (f *fakeIdentityContactService) GetByID(ctx context.Context, contactID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") +func (f *fakeMemberService) IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) { + return f.isMember, nil } -func (f *fakeIdentityContactService) GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") -} - -func (f *fakeIdentityContactService) GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) { - return contacts.ContactChannel{}, fmt.Errorf("not found") -} - -func (f *fakeIdentityContactService) Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) { - return contacts.Contact{ID: "contact-1", BotID: req.BotID}, nil -} - -func (f *fakeIdentityContactService) CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) { - f.createGuestCalled = true - return contacts.Contact{ID: "contact-guest", BotID: botID}, nil -} - -func (f *fakeIdentityContactService) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) { +func (f *fakeMemberService) UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error { f.upsertCalled = true - return contacts.ContactChannel{ID: "channel-1", ContactID: contactID}, nil + return nil } -type fakePreauthService struct { +type fakePolicyService struct { + allow bool + botType string + ownerUserID string + err error +} + +func (f *fakePolicyService) AllowGuest(ctx context.Context, botID string) (bool, error) { + if f.err != nil { + return false, f.err + } + return f.allow, nil +} + +func (f *fakePolicyService) BotType(ctx context.Context, botID string) (string, error) { + if f.err != nil { + return "", f.err + } + return f.botType, nil +} + +func (f *fakePolicyService) BotOwnerUserID(ctx context.Context, botID string) (string, error) { + if f.err != nil { + return "", f.err + } + return f.ownerUserID, nil +} + +type fakePreauthServiceIdentity struct { key preauth.Key err error markUsed bool } -func (f *fakePreauthService) Get(ctx context.Context, token string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) Get(ctx context.Context, token string) (preauth.Key, error) { if f.err != nil { return preauth.Key{}, f.err } @@ -110,41 +124,220 @@ func (f *fakePreauthService) Get(ctx context.Context, token string) (preauth.Key return f.key, nil } -func (f *fakePreauthService) MarkUsed(ctx context.Context, id string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) MarkUsed(ctx context.Context, id string) (preauth.Key, error) { f.markUsed = true return f.key, nil } -func TestIdentityResolverAllowGuestCreatesContact(t *testing.T) { - store := &fakeIdentityConfigStore{} - contactsService := &fakeIdentityContactService{} - policyService := &fakePolicyServiceIdentity{decision: policy.Decision{AllowGuest: true}} - resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, nil, "禁止访问", "授权成功") +type fakeBindService struct { + code bind.Code + getErr error + consumeErr error + consumeCalled bool + onConsume func(channelChannelIdentityID string) +} + +func (f *fakeBindService) Get(ctx context.Context, token string) (bind.Code, error) { + if f.getErr != nil { + return bind.Code{}, f.getErr + } + if f.code.Token == "" || f.code.Token != token { + return bind.Code{}, bind.ErrCodeNotFound + } + return f.code, nil +} + +func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelChannelIdentityID string) error { + f.consumeCalled = true + if f.onConsume != nil { + f.onConsume(channelChannelIdentityID) + } + return f.consumeErr +} + +type fakeDirectoryAdapter struct { + channelType channel.ChannelType + resolveFn func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) +} + +func (f *fakeDirectoryAdapter) Type() channel.ChannelType { + return f.channelType +} + +func (f *fakeDirectoryAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: f.channelType, + DisplayName: "FakeDirectory", + Capabilities: channel.ChannelCapabilities{}, + } +} + +func (f *fakeDirectoryAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (f *fakeDirectoryAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (f *fakeDirectoryAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + return nil, nil +} + +func (f *fakeDirectoryAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + if f.resolveFn != nil { + return f.resolveFn(ctx, cfg, input, kind) + } + return channel.DirectoryEntry{}, errors.New("resolve not implemented") +} + +func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: true, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") msg := channel.InboundMessage{ BotID: "bot-1", Channel: channel.ChannelType("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", - Sender: channel.Identity{ExternalID: "user-1", DisplayName: "访客"}, + Sender: channel.Identity{SubjectID: "ext-1", DisplayName: "Guest"}, } state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if state.Identity.ContactID != "contact-guest" { - t.Fatalf("应创建访客联系人,实际: %s", state.Identity.ContactID) + if state.Identity.ChannelIdentityID != "channelIdentity-1" { + t.Fatalf("expected channelIdentity-1, got: %s", state.Identity.ChannelIdentityID) } - if !contactsService.createGuestCalled { - t.Fatalf("应调用 CreateGuest") + if memberSvc.upsertCalled { + t.Fatal("guest allow should not upsert membership") + } + if state.Decision != nil { + t.Fatal("expected no decision for allowed guest") } } -func TestIdentityResolverPreauthKeyAllowsGuest(t *testing.T) { - store := &fakeIdentityConfigStore{} - contactsService := &fakeIdentityContactService{} - policyService := &fakePolicyServiceIdentity{} - preauthService := &fakePreauthService{ +func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) { + registry := channel.NewRegistry() + directoryAdapter := &fakeDirectoryAdapter{ + channelType: channel.ChannelType("feishu"), + resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + if kind != channel.DirectoryEntryUser { + t.Fatalf("expected kind user, got %s", kind) + } + if input != "ou-directory" { + t.Fatalf("expected subject id ou-directory, got %s", input) + } + return channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + Name: "Directory Name", + }, nil + }, + } + if err := registry.Register(directoryAdapter); err != nil { + t.Fatalf("register directory adapter failed: %v", err) + } + + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-directory"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), registry, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "target-id", + Sender: channel.Identity{ + SubjectID: "ou-directory", + Attributes: map[string]string{ + "open_id": "ou-directory", + }, + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Identity.DisplayName != "Directory Name" { + t.Fatalf("expected directory display name, got %q", state.Identity.DisplayName) + } + if channelIdentitySvc.lastDisplayName != "Directory Name" { + t.Fatalf("expected upsert display name Directory Name, got %q", channelIdentitySvc.lastDisplayName) + } +} + +func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testing.T) { + registry := channel.NewRegistry() + directoryAdapter := &fakeDirectoryAdapter{ + channelType: channel.ChannelType("feishu"), + resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + return channel.DirectoryEntry{}, errors.New("lookup failed") + }, + } + if err := registry.Register(directoryAdapter); err != nil { + t.Fatalf("register directory adapter failed: %v", err) + } + + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-directory-fail"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), registry, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "target-id", + Sender: channel.Identity{ + SubjectID: "ou-directory-fail", + Attributes: map[string]string{ + "open_id": "ou-directory-fail", + "user_id": "u-directory-fail", + }, + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Identity.DisplayName != "" { + t.Fatalf("expected empty display name when directory lookup fails, got %q", state.Identity.DisplayName) + } + if channelIdentitySvc.lastDisplayName != "" { + t.Fatalf("expected empty upsert display name on lookup failure, got %q", channelIdentitySvc.lastDisplayName) + } +} + +func TestIdentityResolverExistingMemberPasses(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "tg-user-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("existing member should pass without decision") + } +} + +func TestIdentityResolverPreauthKey(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + preauthSvc := &fakePreauthServiceIdentity{ key: preauth.Key{ ID: "key-1", BotID: "bot-1", @@ -152,35 +345,35 @@ func TestIdentityResolverPreauthKeyAllowsGuest(t *testing.T) { ExpiresAt: time.Now().UTC().Add(1 * time.Hour), }, } - resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, preauthSvc, nil, "", "") msg := channel.InboundMessage{ BotID: "bot-1", Channel: channel.ChannelType("feishu"), Message: channel.Message{Text: "PREAUTH123"}, ReplyTarget: "target-id", - Sender: channel.Identity{ExternalID: "user-1"}, + Sender: channel.Identity{SubjectID: "ext-1"}, } state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } if state.Decision == nil || !state.Decision.Stop { - t.Fatalf("应返回授权确认") + t.Fatal("preauth key should return stop decision") } - if !contactsService.upsertCalled { - t.Fatalf("应执行联系人绑定") + if !preauthSvc.markUsed { + t.Fatal("preauth key should be marked used") } - if !preauthService.markUsed { - t.Fatalf("应标记预授权码已使用") + if !memberSvc.upsertCalled { + t.Fatal("membership should be upserted via preauth") } } func TestIdentityResolverPreauthKeyExpired(t *testing.T) { - store := &fakeIdentityConfigStore{} - contactsService := &fakeIdentityContactService{} - policyService := &fakePolicyServiceIdentity{} - preauthService := &fakePreauthService{ + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + preauthSvc := &fakePreauthServiceIdentity{ key: preauth.Key{ ID: "key-1", BotID: "bot-1", @@ -188,23 +381,427 @@ func TestIdentityResolverPreauthKeyExpired(t *testing.T) { ExpiresAt: time.Now().UTC().Add(-1 * time.Hour), }, } - resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, preauthSvc, nil, "", "") msg := channel.InboundMessage{ BotID: "bot-1", Channel: channel.ChannelType("feishu"), Message: channel.Message{Text: "PREAUTH123"}, ReplyTarget: "target-id", - Sender: channel.Identity{ExternalID: "user-1"}, + Sender: channel.Identity{SubjectID: "ext-1"}, } state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } if state.Decision == nil || !state.Decision.Stop { - t.Fatalf("过期预授权码应被拒绝") + t.Fatal("expired preauth key should be rejected") } - if preauthService.markUsed { - t.Fatalf("过期预授权码不应被使用") + if preauthSvc.markUsed { + t.Fatal("expired preauth key should not be marked used") + } +} + +func TestIdentityResolverDenied(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-5"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "stranger-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("stranger without guest access should be denied") + } +} + +func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-group"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + Sender: channel.Identity{SubjectID: "ext-group-1"}, + Conversation: channel.Conversation{ + ID: "group-1", + Type: "group", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("personal bot should reject group messages") + } + if channelIdentitySvc.calls != 1 { + t.Fatalf("expected channelIdentity resolution once before owner check, got %d", channelIdentitySvc.calls) + } + if !state.Decision.Reply.IsEmpty() { + t.Fatal("non-owner group message should be silently ignored") + } +} + +func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from owner"}, + Sender: channel.Identity{SubjectID: "ext-owner-1"}, + Conversation: channel.Conversation{ + ID: "group-1", + Type: "group", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("owner group message should pass") + } + if state.Identity.ForceReply { + t.Fatal("owner group message should not force reply") + } +} + +func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner-direct"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner-direct"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from owner"}, + Sender: channel.Identity{SubjectID: "ext-owner-direct"}, + Conversation: channel.Conversation{ + ID: "p2p-1", + Type: "p2p", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("owner direct message should pass") + } + if state.Identity.ForceReply { + t.Fatal("owner direct message should not force reply") + } +} + +func TestIdentityResolverPersonalBotOwnerFallbackByAlternateSubject(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{ + bySubject: map[string]identities.ChannelIdentity{ + "ou-open-owner": {ID: "channelIdentity-open-owner"}, + "u-owner": {ID: "channelIdentity-user-owner"}, + }, + linked: map[string]string{ + "channelIdentity-user-owner": "owner-user-1", + }, + } + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "owner-user-1"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from owner"}, + Sender: channel.Identity{ + SubjectID: "ou-open-owner", + Attributes: map[string]string{ + "open_id": "ou-open-owner", + "user_id": "u-owner", + }, + }, + Conversation: channel.Conversation{ + ID: "p2p-1", + Type: "p2p", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("owner direct message should pass after alternate subject fallback") + } + if state.Identity.UserID != "owner-user-1" { + t.Fatalf("expected owner-user-1, got: %s", state.Identity.UserID) + } + if state.Identity.ChannelIdentityID != "channelIdentity-user-owner" { + t.Fatalf("expected fallback channel identity, got: %s", state.Identity.ChannelIdentityID) + } + if channelIdentitySvc.calls < 2 { + t.Fatalf("expected fallback resolution attempts, got calls=%d", channelIdentitySvc.calls) + } +} + +func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-non-owner"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: true, botType: "personal", ownerUserID: "channelIdentity-owner"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from non-owner"}, + Sender: channel.Identity{SubjectID: "ext-non-owner"}, + Conversation: channel.Conversation{ + ID: "p2p-2", + Type: "p2p", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("non-owner direct message should be rejected for personal bot") + } + if !state.Decision.Reply.IsEmpty() { + t.Fatal("non-owner direct message should be silently ignored") + } +} + +func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) { + shadowID := "channelIdentity-shadow" + humanID := "channelIdentity-human" + channelIdentitySvc := &fakeChannelIdentityService{ + channelIdentity: identities.ChannelIdentity{ID: shadowID}, + linked: map[string]string{ + shadowID: shadowID, + }, + } + memberSvc := &fakeMemberService{isMember: true} + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-1", + Platform: "feishu", + Token: "BIND123", + IssuedByUserID: humanID, + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + onConsume: func(channelChannelIdentityID string) { + channelIdentitySvc.linked[channelChannelIdentityID] = humanID + }, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "BIND123"}, + ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-bind-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bindSvc.consumeCalled { + t.Fatal("expected bind consume to run before membership shortcut") + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("bind flow should return stop decision") + } + if state.Identity.UserID != humanID { + t.Fatalf("expected linked user to switch to %s, got %s", humanID, state.Identity.UserID) + } + if memberSvc.upsertCalled { + t.Fatal("bind should not upsert bot membership") + } +} + +func TestIdentityResolverBindConsumeErrorHandledAsDecision(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-shadow"}} + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-2", + Platform: "telegram", + Token: "BINDUSED", + IssuedByUserID: "channelIdentity-human", + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + consumeErr: bind.ErrCodeUsed, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, &fakeMemberService{}, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "BINDUSED"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "ext-bind-2"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("bind consume errors should be converted into stop decision") + } +} + +func TestIdentityResolverBindCodeNotScopedToCurrentBot(t *testing.T) { + shadowID := "channelIdentity-shadow-any-bot" + humanID := "channelIdentity-human-any-bot" + channelIdentitySvc := &fakeChannelIdentityService{ + channelIdentity: identities.ChannelIdentity{ID: shadowID}, + linked: map[string]string{ + shadowID: shadowID, + }, + } + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-any-bot", + Platform: "feishu", + Token: "BINDANYBOT", + IssuedByUserID: humanID, + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + onConsume: func(channelChannelIdentityID string) { + channelIdentitySvc.linked[channelChannelIdentityID] = humanID + }, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, &fakeMemberService{}, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-2", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "BINDANYBOT"}, + ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-bind-any-bot"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bindSvc.consumeCalled { + t.Fatal("bind consume should run even when message bot differs") + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("bind flow should return stop decision") + } + if state.Identity.UserID != humanID { + t.Fatalf("expected linked user to switch to %s, got %s", humanID, state.Identity.UserID) + } +} + +func TestIdentityResolverPublicBotGroupDeniedSilently(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-group-denied"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "group-target", + Sender: channel.Identity{SubjectID: "stranger-group"}, + Conversation: channel.Conversation{ + ID: "group-1", + Type: "group", + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("unauthorized group message should be stopped") + } + if !state.Decision.Reply.IsEmpty() { + t.Fatal("unauthorized group message should be silently dropped, not replied") + } +} + +func TestIdentityResolverPublicBotDirectDeniedWithReply(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-direct-denied"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "direct-target", + Sender: channel.Identity{SubjectID: "stranger-direct"}, + Conversation: channel.Conversation{ + ID: "p2p-1", + Type: "p2p", + }, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("unauthorized direct message should be stopped") + } + if state.Decision.Reply.IsEmpty() { + t.Fatal("unauthorized direct message should reply with access denied") + } +} + +func TestIdentityResolverBindCodePlatformMismatch(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-platform-mismatch"}} + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-platform", + Platform: "telegram", + Token: "BINDPLATFORM", + IssuedByUserID: "channelIdentity-human-platform", + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, &fakeMemberService{}, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "BINDPLATFORM"}, + ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-bind-platform"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if bindSvc.consumeCalled { + t.Fatal("bind consume should not run when platform mismatches") + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("platform mismatch should return stop decision") } } diff --git a/internal/schedule/service.go b/internal/schedule/service.go index 42aff905..c9cd2e43 100644 --- a/internal/schedule/service.go +++ b/internal/schedule/service.go @@ -9,13 +9,13 @@ import ( "sync" "time" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/robfig/cron/v3" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/boot" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -72,7 +72,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( if _, err := s.parser.Parse(req.Pattern); err != nil { return Schedule{}, fmt.Errorf("invalid cron pattern: %w", err) } - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return Schedule{}, err } @@ -105,7 +105,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( } func (s *Service) Get(ctx context.Context, id string) (Schedule, error) { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Schedule{}, err } @@ -120,7 +120,7 @@ func (s *Service) Get(ctx context.Context, id string) (Schedule, error) { } func (s *Service) List(ctx context.Context, botID string) ([]Schedule, error) { - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return nil, err } @@ -136,7 +136,7 @@ func (s *Service) List(ctx context.Context, botID string) ([]Schedule, error) { } func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Schedule, error) { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Schedule{}, err } @@ -192,7 +192,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sch } func (s *Service) Delete(ctx context.Context, id string) error { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return err } @@ -254,7 +254,7 @@ func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error { // resolveBotOwner returns the owner user ID for the given bot. func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, error) { - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return "", err } @@ -262,7 +262,7 @@ func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, er if err != nil { return "", fmt.Errorf("get bot: %w", err) } - ownerID := toUUIDString(bot.OwnerUserID) + ownerID := bot.OwnerUserID.String() if ownerID == "" { return "", fmt.Errorf("bot owner not found") } @@ -282,7 +282,7 @@ func (s *Service) generateTriggerToken(userID string) (string, error) { } func (s *Service) scheduleJob(schedule sqlc.Schedule) error { - id := toUUIDString(schedule.ID) + id := schedule.ID.String() if id == "" { return fmt.Errorf("schedule id missing") } @@ -300,7 +300,7 @@ func (s *Service) scheduleJob(schedule sqlc.Schedule) error { } func (s *Service) rescheduleJob(schedule sqlc.Schedule) { - id := toUUIDString(schedule.ID) + id := schedule.ID.String() if id == "" { return } @@ -322,14 +322,14 @@ func (s *Service) removeJob(id string) { func toSchedule(row sqlc.Schedule) Schedule { item := Schedule{ - ID: toUUIDString(row.ID), + ID: row.ID.String(), Name: row.Name, Description: row.Description, Pattern: row.Pattern, CurrentCalls: int(row.CurrentCalls), Enabled: row.Enabled, Command: row.Command, - BotID: toUUIDString(row.BotID), + BotID: row.BotID.String(), } if row.MaxCalls.Valid { max := int(row.MaxCalls.Int32) @@ -344,32 +344,10 @@ func toSchedule(row sqlc.Schedule) Schedule { return item } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - func toUUID(id string) pgtype.UUID { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return pgtype.UUID{} } return pgID } - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - id, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return id.String() -} diff --git a/internal/schedule/trigger.go b/internal/schedule/trigger.go index e2d3b5ab..b15ed3e3 100644 --- a/internal/schedule/trigger.go +++ b/internal/schedule/trigger.go @@ -11,9 +11,10 @@ type TriggerPayload struct { MaxCalls *int Command string OwnerUserID string + ChatID string } -// Triggerer 负责触发与聊天相关的调度执行。 +// Triggerer triggers schedule execution for chat-related jobs. type Triggerer interface { TriggerSchedule(ctx context.Context, botID string, payload TriggerPayload, token string) error } diff --git a/internal/schedule/types.go b/internal/schedule/types.go index ffa0cbc7..7c95b1f6 100644 --- a/internal/schedule/types.go +++ b/internal/schedule/types.go @@ -50,21 +50,21 @@ func (n *NullableInt) UnmarshalJSON(data []byte) error { } type CreateRequest struct { - Name string `json:"name"` - Description string `json:"description"` - Pattern string `json:"pattern"` + Name string `json:"name"` + Description string `json:"description"` + Pattern string `json:"pattern"` MaxCalls NullableInt `json:"max_calls,omitempty"` - Command string `json:"command"` - Enabled *bool `json:"enabled,omitempty"` + Command string `json:"command"` + Enabled *bool `json:"enabled,omitempty"` } type UpdateRequest struct { - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Pattern *string `json:"pattern,omitempty"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Pattern *string `json:"pattern,omitempty"` MaxCalls NullableInt `json:"max_calls,omitempty"` - Command *string `json:"command,omitempty"` - Enabled *bool `json:"enabled,omitempty"` + Command *string `json:"command,omitempty"` + Enabled *bool `json:"enabled,omitempty"` } type ListResponse struct { diff --git a/internal/settings/service.go b/internal/settings/service.go index 2607c807..3b3bf9de 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -7,10 +7,10 @@ import ( "log/slog" "strings" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -19,6 +19,8 @@ type Service struct { logger *slog.Logger } +var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access") + func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -26,8 +28,9 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } +// Get returns user-level settings. func (s *Service) Get(ctx context.Context, userID string) (Settings, error) { - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return Settings{}, err } @@ -47,11 +50,12 @@ func (s *Service) Get(ctx context.Context, userID string) (Settings, error) { return normalizeUserSetting(row), nil } +// Upsert creates or updates user-level settings. func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) (Settings, error) { if s.queries == nil { return Settings{}, fmt.Errorf("settings queries not configured") } - pgID, err := parseUUID(userID) + pgID, err := db.ParseUUID(userID) if err != nil { return Settings{}, err } @@ -88,7 +92,7 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) } _, err = s.queries.UpsertUserSettings(ctx, sqlc.UpsertUserSettingsParams{ - UserID: pgID, + ID: pgID, ChatModelID: pgtype.Text{String: current.ChatModelID, Valid: current.ChatModelID != ""}, MemoryModelID: pgtype.Text{String: current.MemoryModelID, Valid: current.MemoryModelID != ""}, EmbeddingModelID: pgtype.Text{String: current.EmbeddingModelID, Valid: current.EmbeddingModelID != ""}, @@ -102,93 +106,99 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) } func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) { - pgID, err := parseUUID(botID) + pgID, err := db.ParseUUID(botID) if err != nil { return Settings{}, err } row, err := s.queries.GetSettingsByBotID(ctx, pgID) if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - settings := Settings{ - MaxContextLoadTime: DefaultMaxContextLoadTime, - Language: DefaultLanguage, - AllowGuest: false, - } - if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil { - return Settings{}, err - } - return settings, nil - } return Settings{}, err } - settings := normalizeBotSetting(row) - if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil { - return Settings{}, err - } - return settings, nil + return normalizeBotSettingsReadRow(row), nil } func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) { if s.queries == nil { return Settings{}, fmt.Errorf("settings queries not configured") } - pgID, err := parseUUID(botID) + pgID, err := db.ParseUUID(botID) if err != nil { return Settings{}, err } - - current := Settings{ - MaxContextLoadTime: DefaultMaxContextLoadTime, - Language: DefaultLanguage, - AllowGuest: false, - } - existing, err := s.queries.GetSettingsByBotID(ctx, pgID) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { + botRow, err := s.queries.GetBotByID(ctx, pgID) + if err != nil { return Settings{}, err } - if err == nil { - current = normalizeBotSetting(existing) - } + isPersonalBot := strings.EqualFold(strings.TrimSpace(botRow.Type), "personal") + + current := normalizeBotSetting(botRow.MaxContextLoadTime, botRow.Language, botRow.AllowGuest) if req.MaxContextLoadTime != nil && *req.MaxContextLoadTime > 0 { current.MaxContextLoadTime = *req.MaxContextLoadTime } if strings.TrimSpace(req.Language) != "" { current.Language = strings.TrimSpace(req.Language) } - if req.AllowGuest != nil { + if isPersonalBot { + if req.AllowGuest != nil && *req.AllowGuest { + return Settings{}, ErrPersonalBotGuestAccessUnsupported + } + current.AllowGuest = false + } else if req.AllowGuest != nil { current.AllowGuest = *req.AllowGuest } - _, err = s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{ - BotID: pgID, + chatModelUUID := pgtype.UUID{} + if value := strings.TrimSpace(req.ChatModelID); value != "" { + modelID, err := s.resolveModelUUID(ctx, value) + if err != nil { + return Settings{}, err + } + chatModelUUID = modelID + } + memoryModelUUID := pgtype.UUID{} + if value := strings.TrimSpace(req.MemoryModelID); value != "" { + modelID, err := s.resolveModelUUID(ctx, value) + if err != nil { + return Settings{}, err + } + memoryModelUUID = modelID + } + embeddingModelUUID := pgtype.UUID{} + if value := strings.TrimSpace(req.EmbeddingModelID); value != "" { + modelID, err := s.resolveModelUUID(ctx, value) + if err != nil { + return Settings{}, err + } + embeddingModelUUID = modelID + } + + updated, err := s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{ + ID: pgID, MaxContextLoadTime: int32(current.MaxContextLoadTime), Language: current.Language, AllowGuest: current.AllowGuest, + ChatModelID: chatModelUUID, + MemoryModelID: memoryModelUUID, + EmbeddingModelID: embeddingModelUUID, }) if err != nil { return Settings{}, err } - if err := s.upsertBotModelConfig(ctx, pgID, req); err != nil { - return Settings{}, err - } - if err := s.attachBotModelConfig(ctx, pgID, ¤t); err != nil { - return Settings{}, err - } - return current, nil + return normalizeBotSettingsWriteRow(updated), nil } func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { return fmt.Errorf("settings queries not configured") } - pgID, err := parseUUID(botID) + pgID, err := db.ParseUUID(botID) if err != nil { return err } return s.queries.DeleteSettingsByBotID(ctx, pgID) } -func normalizeUserSetting(row sqlc.UserSetting) Settings { +func normalizeUserSetting(row sqlc.GetSettingsByUserIDRow) Settings { settings := Settings{ ChatModelID: strings.TrimSpace(row.ChatModelID.String), MemoryModelID: strings.TrimSpace(row.MemoryModelID.String), @@ -205,11 +215,11 @@ func normalizeUserSetting(row sqlc.UserSetting) Settings { return settings } -func normalizeBotSetting(row sqlc.BotSetting) Settings { +func normalizeBotSetting(maxContextLoadTime int32, language string, allowGuest bool) Settings { settings := Settings{ - MaxContextLoadTime: int(row.MaxContextLoadTime), - Language: strings.TrimSpace(row.Language), - AllowGuest: row.AllowGuest, + MaxContextLoadTime: int(maxContextLoadTime), + Language: strings.TrimSpace(language), + AllowGuest: allowGuest, } if settings.MaxContextLoadTime <= 0 { settings.MaxContextLoadTime = DefaultMaxContextLoadTime @@ -220,60 +230,41 @@ func normalizeBotSetting(row sqlc.BotSetting) Settings { return settings } -func (s *Service) attachBotModelConfig(ctx context.Context, botID pgtype.UUID, target *Settings) error { - if s.queries == nil || target == nil { - return nil - } - row, err := s.queries.GetBotModelConfigByBotID(ctx, botID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return nil - } - return err - } - target.ChatModelID = strings.TrimSpace(row.ChatModelID.String) - target.MemoryModelID = strings.TrimSpace(row.MemoryModelID.String) - target.EmbeddingModelID = strings.TrimSpace(row.EmbeddingModelID.String) - return nil +func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings { + return normalizeBotSettingsFields( + row.MaxContextLoadTime, + row.Language, + row.AllowGuest, + row.ChatModelID, + row.MemoryModelID, + row.EmbeddingModelID, + ) } -func (s *Service) upsertBotModelConfig(ctx context.Context, botID pgtype.UUID, req UpsertRequest) error { - if s.queries == nil { - return fmt.Errorf("settings queries not configured") - } - params := sqlc.UpsertBotModelConfigParams{ - BotID: botID, - } - hasUpdate := false - if value := strings.TrimSpace(req.ChatModelID); value != "" { - modelID, err := s.resolveModelUUID(ctx, value) - if err != nil { - return err - } - params.ChatModelID = modelID - hasUpdate = true - } - if value := strings.TrimSpace(req.MemoryModelID); value != "" { - modelID, err := s.resolveModelUUID(ctx, value) - if err != nil { - return err - } - params.MemoryModelID = modelID - hasUpdate = true - } - if value := strings.TrimSpace(req.EmbeddingModelID); value != "" { - modelID, err := s.resolveModelUUID(ctx, value) - if err != nil { - return err - } - params.EmbeddingModelID = modelID - hasUpdate = true - } - if !hasUpdate { - return nil - } - _, err := s.queries.UpsertBotModelConfig(ctx, params) - return err +func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings { + return normalizeBotSettingsFields( + row.MaxContextLoadTime, + row.Language, + row.AllowGuest, + row.ChatModelID, + row.MemoryModelID, + row.EmbeddingModelID, + ) +} + +func normalizeBotSettingsFields( + maxContextLoadTime int32, + language string, + allowGuest bool, + chatModelID pgtype.Text, + memoryModelID pgtype.Text, + embeddingModelID pgtype.Text, +) Settings { + settings := normalizeBotSetting(maxContextLoadTime, language, allowGuest) + settings.ChatModelID = strings.TrimSpace(chatModelID.String) + settings.MemoryModelID = strings.TrimSpace(memoryModelID.String) + settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String) + return settings } func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) { @@ -287,13 +278,3 @@ func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype. return row.ID, nil } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/settings/types.go b/internal/settings/types.go index 9950566c..f750f8ff 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -6,12 +6,12 @@ const ( ) type Settings struct { - ChatModelID string `json:"chat_model_id" validate:"required"` - MemoryModelID string `json:"memory_model_id" validate:"required"` - EmbeddingModelID string `json:"embedding_model_id" validate:"required"` - MaxContextLoadTime int `json:"max_context_load_time" validate:"required"` - Language string `json:"language" validate:"required"` - AllowGuest bool `json:"allow_guest" validate:"required"` + ChatModelID string `json:"chat_model_id"` + MemoryModelID string `json:"memory_model_id"` + EmbeddingModelID string `json:"embedding_model_id"` + MaxContextLoadTime int `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` } type UpsertRequest struct { diff --git a/internal/subagent/service.go b/internal/subagent/service.go index c864a4db..c9e5ce77 100644 --- a/internal/subagent/service.go +++ b/internal/subagent/service.go @@ -8,10 +8,9 @@ import ( "log/slog" "strings" - "github.com/google/uuid" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -39,7 +38,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( if description == "" { return Subagent{}, fmt.Errorf("description is required") } - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return Subagent{}, err } @@ -70,7 +69,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( } func (s *Service) Get(ctx context.Context, id string) (Subagent, error) { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -85,7 +84,7 @@ func (s *Service) Get(ctx context.Context, id string) (Subagent, error) { } func (s *Service) List(ctx context.Context, botID string) ([]Subagent, error) { - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return nil, err } @@ -131,7 +130,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sub if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -152,7 +151,7 @@ func (s *Service) UpdateContext(ctx context.Context, id string, req UpdateContex if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -171,7 +170,7 @@ func (s *Service) UpdateSkills(ctx context.Context, id string, req UpdateSkillsR if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -195,7 +194,7 @@ func (s *Service) AddSkills(ctx context.Context, id string, req AddSkillsRequest if err != nil { return Subagent{}, err } - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return Subagent{}, err } @@ -210,7 +209,7 @@ func (s *Service) AddSkills(ctx context.Context, id string, req AddSkillsRequest } func (s *Service) Delete(ctx context.Context, id string) error { - pgID, err := parseUUID(id) + pgID, err := db.ParseUUID(id) if err != nil { return err } @@ -231,10 +230,10 @@ func toSubagent(row sqlc.Subagent) (Subagent, error) { return Subagent{}, err } item := Subagent{ - ID: toUUIDString(row.ID), + ID: row.ID.String(), Name: row.Name, Description: row.Description, - BotID: toUUIDString(row.BotID), + BotID: row.BotID.String(), Messages: messages, Metadata: metadata, Skills: skills, @@ -336,24 +335,3 @@ func mergeSkills(existing []string, incoming []string) []string { return normalizeSkills(merged) } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - id, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return id.String() -} diff --git a/internal/subagent/types.go b/internal/subagent/types.go index 77498a12..38207e0d 100644 --- a/internal/subagent/types.go +++ b/internal/subagent/types.go @@ -3,30 +3,30 @@ package subagent import "time" type Subagent struct { - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - BotID string `json:"bot_id"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + BotID string `json:"bot_id"` Messages []map[string]any `json:"messages"` Metadata map[string]any `json:"metadata"` - Skills []string `json:"skills"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - Deleted bool `json:"deleted"` - DeletedAt *time.Time `json:"deleted_at,omitempty"` + Skills []string `json:"skills"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` } type CreateRequest struct { - Name string `json:"name"` - Description string `json:"description"` + Name string `json:"name"` + Description string `json:"description"` Messages []map[string]any `json:"messages,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` - Skills []string `json:"skills,omitempty"` + Skills []string `json:"skills,omitempty"` } type UpdateRequest struct { - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } diff --git a/packages/sdk/src/@pinia/colada.gen.ts b/packages/sdk/src/@pinia/colada.gen.ts index acb165a5..3417e0c2 100644 --- a/packages/sdk/src/@pinia/colada.gen.ts +++ b/packages/sdk/src/@pinia/colada.gen.ts @@ -4,8 +4,8 @@ import { type _JSONValue, defineQueryOptions, type UseMutationOptions } from '@p import { serializeQueryKeyValue } from '../client'; import { client } from '../client.gen'; -import { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerFs, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdHistory, deleteBotsByBotIdHistoryById, deleteBotsByBotIdMcpById, deleteBotsByBotIdMemoryMemories, deleteBotsByBotIdMemoryMemoriesByMemoryId, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerFs, getBotsByBotIdContainerFsFile, getBotsByBotIdContainerFsStat, getBotsByBotIdContainerFsUsage, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdHistory, getBotsByBotIdHistoryById, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdMemoryMemories, getBotsByBotIdMemoryMemoriesByMemoryId, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, type Options, postAuthLogin, postBots, postBotsByBotIdChat, postBotsByBotIdChatStream, postBotsByBotIdContainer, postBotsByBotIdContainerFsDir, postBotsByBotIdContainerFsFile, postBotsByBotIdContainerFsMcp, postBotsByBotIdContainerFsUpload, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdHistory, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioBySessionId, postBotsByBotIdMemoryAdd, postBotsByBotIdMemoryEmbed, postBotsByBotIdMemorySearch, postBotsByBotIdMemoryUpdate, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendSession, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from '../sdk.gen'; -import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerFsData, DeleteBotsByBotIdContainerFsError, DeleteBotsByBotIdContainerFsResponse, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdHistoryByIdData, DeleteBotsByBotIdHistoryByIdError, DeleteBotsByBotIdHistoryData, DeleteBotsByBotIdHistoryError, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdMemoryMemoriesByMemoryIdData, DeleteBotsByBotIdMemoryMemoriesByMemoryIdError, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponse, DeleteBotsByBotIdMemoryMemoriesData, DeleteBotsByBotIdMemoryMemoriesError, DeleteBotsByBotIdMemoryMemoriesResponse, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteProvidersByIdData, DeleteProvidersByIdError, GetBotsByBotIdContainerData, GetBotsByBotIdContainerFsData, GetBotsByBotIdContainerFsFileData, GetBotsByBotIdContainerFsStatData, GetBotsByBotIdContainerFsUsageData, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdHistoryByIdData, GetBotsByBotIdHistoryData, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpData, GetBotsByBotIdMemoryMemoriesByMemoryIdData, GetBotsByBotIdMemoryMemoriesData, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleData, GetBotsByBotIdSettingsData, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsData, GetBotsByIdChannelByPlatformData, GetBotsByIdData, GetBotsByIdMembersData, GetBotsData, GetChannelsByPlatformData, GetChannelsData, GetModelsByIdData, GetModelsCountData, GetModelsData, GetModelsModelByModelIdData, GetProvidersByIdData, GetProvidersByIdModelsData, GetProvidersCountData, GetProvidersData, GetProvidersNameByNameData, GetUsersByIdData, GetUsersData, GetUsersMeChannelsByPlatformData, GetUsersMeData, PostAuthLoginData, PostAuthLoginError, PostAuthLoginResponse, PostBotsByBotIdChatData, PostBotsByBotIdChatError, PostBotsByBotIdChatResponse, PostBotsByBotIdChatStreamData, PostBotsByBotIdChatStreamError, PostBotsByBotIdChatStreamResponse, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerFsDirData, PostBotsByBotIdContainerFsDirError, PostBotsByBotIdContainerFsDirResponse, PostBotsByBotIdContainerFsFileData, PostBotsByBotIdContainerFsFileError, PostBotsByBotIdContainerFsFileResponse, PostBotsByBotIdContainerFsMcpData, PostBotsByBotIdContainerFsMcpError, PostBotsByBotIdContainerFsMcpResponse, PostBotsByBotIdContainerFsUploadData, PostBotsByBotIdContainerFsUploadError, PostBotsByBotIdContainerFsUploadResponse, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdHistoryData, PostBotsByBotIdHistoryError, PostBotsByBotIdHistoryResponse, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpStdioBySessionIdData, PostBotsByBotIdMcpStdioBySessionIdError, PostBotsByBotIdMcpStdioBySessionIdResponse, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdMemoryAddData, PostBotsByBotIdMemoryAddError, PostBotsByBotIdMemoryAddResponse, PostBotsByBotIdMemoryEmbedData, PostBotsByBotIdMemoryEmbedError, PostBotsByBotIdMemoryEmbedResponse, PostBotsByBotIdMemorySearchData, PostBotsByBotIdMemorySearchError, PostBotsByBotIdMemorySearchResponse, PostBotsByBotIdMemoryUpdateData, PostBotsByBotIdMemoryUpdateError, PostBotsByBotIdMemoryUpdateResponse, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleResponse, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsResponse, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendResponse, PostBotsByIdChannelByPlatformSendSessionData, PostBotsByIdChannelByPlatformSendSessionError, PostBotsByIdChannelByPlatformSendSessionResponse, PostBotsData, PostBotsError, PostBotsResponse, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsResponse, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableResponse, PostModelsError, PostModelsResponse, PostProvidersData, PostProvidersError, PostProvidersResponse, PostUsersData, PostUsersError, PostUsersResponse, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformResponse, PutBotsByIdData, PutBotsByIdError, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersResponse, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerResponse, PutBotsByIdResponse, PutModelsByIdData, PutModelsByIdError, PutModelsByIdResponse, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdResponse, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdResponse, PutUsersByIdData, PutUsersByIdError, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdResponse, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformResponse, PutUsersMeData, PutUsersMeError, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMeResponse } from '../types.gen'; +import { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdMcpById, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdChecks, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, getUsersMeIdentities, type Options, postAuthLogin, postBots, postBotsByBotIdContainer, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioByConnectionId, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByBotIdTools, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendChat, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from '../sdk.gen'; +import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteBotsByIdResponse, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteProvidersByIdData, DeleteProvidersByIdError, GetBotsByBotIdContainerData, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpData, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleData, GetBotsByBotIdSettingsData, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsData, GetBotsByIdChannelByPlatformData, GetBotsByIdChecksData, GetBotsByIdData, GetBotsByIdMembersData, GetBotsData, GetChannelsByPlatformData, GetChannelsData, GetModelsByIdData, GetModelsCountData, GetModelsData, GetModelsModelByModelIdData, GetProvidersByIdData, GetProvidersByIdModelsData, GetProvidersCountData, GetProvidersData, GetProvidersNameByNameData, GetUsersByIdData, GetUsersData, GetUsersMeChannelsByPlatformData, GetUsersMeData, GetUsersMeIdentitiesData, PostAuthLoginData, PostAuthLoginError, PostAuthLoginResponse, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpStdioByConnectionIdData, PostBotsByBotIdMcpStdioByConnectionIdError, PostBotsByBotIdMcpStdioByConnectionIdResponse, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleResponse, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsResponse, PostBotsByBotIdToolsData, PostBotsByBotIdToolsError, PostBotsByBotIdToolsResponse, PostBotsByIdChannelByPlatformSendChatData, PostBotsByIdChannelByPlatformSendChatError, PostBotsByIdChannelByPlatformSendChatResponse, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendResponse, PostBotsData, PostBotsError, PostBotsResponse, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsResponse, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableResponse, PostModelsError, PostModelsResponse, PostProvidersData, PostProvidersError, PostProvidersResponse, PostUsersData, PostUsersError, PostUsersResponse, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformResponse, PutBotsByIdData, PutBotsByIdError, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersResponse, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerResponse, PutBotsByIdResponse, PutModelsByIdData, PutModelsByIdError, PutModelsByIdResponse, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdResponse, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdResponse, PutUsersByIdData, PutUsersByIdError, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdResponse, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformResponse, PutUsersMeData, PutUsersMeError, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMeResponse } from '../types.gen'; /** * Login @@ -93,38 +93,6 @@ export const postBotsMutation = (options?: Partial>): UseM } }); -/** - * Chat with AI - * - * Send a chat message and get a response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChatMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdChatError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdChat({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Stream chat with AI - * - * Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChatStreamMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdChatStreamError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdChatStream({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - /** * Delete MCP container for bot */ @@ -170,160 +138,6 @@ export const postBotsByBotIdContainerMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdContainerFsError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdContainerFs({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFs', options); - -/** - * List files for a bot - * - * List entries under a relative path - */ -export const getBotsByBotIdContainerFsQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFs({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * MCP filesystem tools (JSON-RPC) - * - * Forwards MCP JSON-RPC requests to the MCP server inside the container. - * Required: - * - container task is running - * - container has data mount (default /data) bound to /users/ - * - container image contains the "mcp" binary - * Auth: Bearer JWT is used to determine user_id (sub or user_id). - * Paths must be relative (no leading slash) and must not contain "..". - * - * Example: tools/list - * {"jsonrpc":"2.0","id":1,"method":"tools/list"} - * - * Example: tools/call (fs.read) - * {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}} - */ -export const postBotsByBotIdContainerFsMcpMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsMcpError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsMcp({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Create a directory - */ -export const postBotsByBotIdContainerFsDirMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsDirError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsDir({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsFileQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFsFile', options); - -/** - * Read file content - */ -export const getBotsByBotIdContainerFsFileQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsFileQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFsFile({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Create or overwrite a file - */ -export const postBotsByBotIdContainerFsFileMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsFileError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsFile({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsStatQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFsStat', options); - -/** - * Get file or directory metadata - */ -export const getBotsByBotIdContainerFsStatQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsStatQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFsStat({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Upload a file - */ -export const postBotsByBotIdContainerFsUploadMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdContainerFsUploadError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdContainerFsUpload({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdContainerFsUsageQueryKey = (options: Options) => createQueryKey('getBotsByBotIdContainerFsUsage', options); - -/** - * Get usage under a path - */ -export const getBotsByBotIdContainerFsUsageQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdContainerFsUsageQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdContainerFsUsage({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - /** * Delete skills from data directory */ @@ -428,100 +242,14 @@ export const postBotsByBotIdContainerStopMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdHistoryError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdHistory({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdHistoryQueryKey = (options: Options) => createQueryKey('getBotsByBotIdHistory', options); - -/** - * List history records - * - * List history records for current user - */ -export const getBotsByBotIdHistoryQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdHistoryQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdHistory({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Create history record - * - * Create a history record for current user - */ -export const postBotsByBotIdHistoryMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdHistoryError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdHistory({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Delete history record - * - * Delete a history record by ID (must belong to current user) - */ -export const deleteBotsByBotIdHistoryByIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdHistoryByIdError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdHistoryById({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdHistoryByIdQueryKey = (options: Options) => createQueryKey('getBotsByBotIdHistoryById', options); - -/** - * Get history record - * - * Get a history record by ID (must belong to current user) - */ -export const getBotsByBotIdHistoryByIdQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdHistoryByIdQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdHistoryById({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -export const getBotsByBotIdMcpQueryKey = (options: Options) => createQueryKey('getBotsByBotIdMcp', options); +export const getBotsByBotIdMcpQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdMcp', options); /** * List MCP connections * * List MCP connections for a bot */ -export const getBotsByBotIdMcpQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdMcpQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdMcpQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdMcp({ @@ -570,9 +298,9 @@ export const postBotsByBotIdMcpStdioMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMcpStdioBySessionIdError> => ({ +export const postBotsByBotIdMcpStdioByConnectionIdMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMcpStdioByConnectionIdError> => ({ mutation: async (vars) => { - const { data } = await postBotsByBotIdMcpStdioBySessionId({ + const { data } = await postBotsByBotIdMcpStdioByConnectionId({ ...options, ...vars, throwOnError: true @@ -632,148 +360,14 @@ export const putBotsByBotIdMcpByIdMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemoryAddError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemoryAdd({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Embed and upsert memory - * - * Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryEmbedMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemoryEmbedError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemoryEmbed({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Delete memories - * - * Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemoriesMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdMemoryMemoriesError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdMemoryMemories({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdMemoryMemoriesQueryKey = (options: Options) => createQueryKey('getBotsByBotIdMemoryMemories', options); - -/** - * List memories - * - * List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemoriesQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdMemoryMemoriesQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdMemoryMemories({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Delete memory - * - * Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemoriesByMemoryIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByBotIdMemoryMemoriesByMemoryIdError> => ({ - mutation: async (vars) => { - const { data } = await deleteBotsByBotIdMemoryMemoriesByMemoryId({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdMemoryMemoriesByMemoryIdQueryKey = (options: Options) => createQueryKey('getBotsByBotIdMemoryMemoriesByMemoryId', options); - -/** - * Get memory - * - * Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemoriesByMemoryIdQuery = defineQueryOptions((options: Options) => ({ - key: getBotsByBotIdMemoryMemoriesByMemoryIdQueryKey(options), - query: async (context) => { - const { data } = await getBotsByBotIdMemoryMemoriesByMemoryId({ - ...options, - ...context, - throwOnError: true - }); - return data; - } -})); - -/** - * Search memories - * - * Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemorySearchMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemorySearchError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemorySearch({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -/** - * Update memory - * - * Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryUpdateMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdMemoryUpdateError> => ({ - mutation: async (vars) => { - const { data } = await postBotsByBotIdMemoryUpdate({ - ...options, - ...vars, - throwOnError: true - }); - return data; - } -}); - -export const getBotsByBotIdScheduleQueryKey = (options: Options) => createQueryKey('getBotsByBotIdSchedule', options); +export const getBotsByBotIdScheduleQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdSchedule', options); /** * List schedules * * List schedules for current user */ -export const getBotsByBotIdScheduleQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdScheduleQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdScheduleQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdSchedule({ @@ -868,14 +462,14 @@ export const deleteBotsByBotIdSettingsMutation = (options?: Partial) => createQueryKey('getBotsByBotIdSettings', options); +export const getBotsByBotIdSettingsQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdSettings', options); /** * Get user settings * * Get agent settings for current user */ -export const getBotsByBotIdSettingsQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdSettingsQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdSettingsQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdSettings({ @@ -919,14 +513,14 @@ export const putBotsByBotIdSettingsMutation = (options?: Partial) => createQueryKey('getBotsByBotIdSubagents', options); +export const getBotsByBotIdSubagentsQueryKey = (options?: Options) => createQueryKey('getBotsByBotIdSubagents', options); /** * List subagents * * List subagents for current user */ -export const getBotsByBotIdSubagentsQuery = defineQueryOptions((options: Options) => ({ +export const getBotsByBotIdSubagentsQuery = defineQueryOptions((options?: Options) => ({ key: getBotsByBotIdSubagentsQueryKey(options), query: async (context) => { const { data } = await getBotsByBotIdSubagents({ @@ -1091,12 +685,28 @@ export const putBotsByBotIdSubagentsByIdSkillsMutation = (options?: Partial>): UseMutationOptions, PostBotsByBotIdToolsError> => ({ + mutation: async (vars) => { + const { data } = await postBotsByBotIdTools({ + ...options, + ...vars, + throwOnError: true + }); + return data; + } +}); + /** * Delete bot * * Delete a bot user (owner/admin only) */ -export const deleteBotsByIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByIdError> => ({ +export const deleteBotsByIdMutation = (options?: Partial>): UseMutationOptions, DeleteBotsByIdError> => ({ mutation: async (vars) => { const { data } = await deleteBotsById({ ...options, @@ -1198,9 +808,9 @@ export const postBotsByIdChannelByPlatformSendMutation = (options?: Partial>): UseMutationOptions, PostBotsByIdChannelByPlatformSendSessionError> => ({ +export const postBotsByIdChannelByPlatformSendChatMutation = (options?: Partial>): UseMutationOptions, PostBotsByIdChannelByPlatformSendChatError> => ({ mutation: async (vars) => { - const { data } = await postBotsByIdChannelByPlatformSendSession({ + const { data } = await postBotsByIdChannelByPlatformSendChat({ ...options, ...vars, throwOnError: true @@ -1209,6 +819,25 @@ export const postBotsByIdChannelByPlatformSendSessionMutation = (options?: Parti } }); +export const getBotsByIdChecksQueryKey = (options: Options) => createQueryKey('getBotsByIdChecks', options); + +/** + * List bot runtime checks + * + * Evaluate bot attached resource checks in runtime + */ +export const getBotsByIdChecksQuery = defineQueryOptions((options: Options) => ({ + key: getBotsByIdChecksQueryKey(options), + query: async (context) => { + const { data } = await getBotsByIdChecks({ + ...options, + ...context, + throwOnError: true + }); + return data; + } +})); + export const getBotsByIdMembersQueryKey = (options: Options) => createQueryKey('getBotsByIdMembers', options); /** @@ -1750,6 +1379,25 @@ export const putUsersMeChannelsByPlatformMutation = (options?: Partial) => createQueryKey('getUsersMeIdentities', options); + +/** + * List current user's channel identities + * + * List all channel identities linked to current user + */ +export const getUsersMeIdentitiesQuery = defineQueryOptions((options?: Options) => ({ + key: getUsersMeIdentitiesQueryKey(options), + query: async (context) => { + const { data } = await getUsersMeIdentities({ + ...options, + ...context, + throwOnError: true + }); + return data; + } +})); + /** * Update current user password * diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 904fad11..bcb5c2ea 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -1,4 +1,4 @@ // This file is auto-generated by @hey-api/openapi-ts -export { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerFs, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdHistory, deleteBotsByBotIdHistoryById, deleteBotsByBotIdMcpById, deleteBotsByBotIdMemoryMemories, deleteBotsByBotIdMemoryMemoriesByMemoryId, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerFs, getBotsByBotIdContainerFsFile, getBotsByBotIdContainerFsStat, getBotsByBotIdContainerFsUsage, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdHistory, getBotsByBotIdHistoryById, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdMemoryMemories, getBotsByBotIdMemoryMemoriesByMemoryId, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, type Options, postAuthLogin, postBots, postBotsByBotIdChat, postBotsByBotIdChatStream, postBotsByBotIdContainer, postBotsByBotIdContainerFsDir, postBotsByBotIdContainerFsFile, postBotsByBotIdContainerFsMcp, postBotsByBotIdContainerFsUpload, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdHistory, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioBySessionId, postBotsByBotIdMemoryAdd, postBotsByBotIdMemoryEmbed, postBotsByBotIdMemorySearch, postBotsByBotIdMemoryUpdate, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendSession, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from './sdk.gen'; -export type { BotsBot, BotsBotMember, BotsCreateBotRequest, BotsListBotsResponse, BotsListMembersResponse, BotsTransferBotRequest, BotsUpdateBotRequest, BotsUpsertMemberRequest, ChannelAction, ChannelAttachment, ChannelAttachmentType, ChannelChannelCapabilities, ChannelChannelConfig, ChannelChannelUserBinding, ChannelConfigSchema, ChannelFieldSchema, ChannelFieldType, ChannelMessage, ChannelMessageFormat, ChannelMessagePart, ChannelMessagePartType, ChannelMessageTextStyle, ChannelReplyRef, ChannelSendRequest, ChannelTargetHint, ChannelTargetSpec, ChannelThreadRef, ChannelUpsertConfigRequest, ChannelUpsertUserConfigRequest, ChatChatRequest, ChatChatResponse, ChatModelMessage, ChatToolCall, ChatToolCallFunction, ClientOptions, DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerFsData, DeleteBotsByBotIdContainerFsError, DeleteBotsByBotIdContainerFsErrors, DeleteBotsByBotIdContainerFsResponse, DeleteBotsByBotIdContainerFsResponses, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdHistoryByIdData, DeleteBotsByBotIdHistoryByIdError, DeleteBotsByBotIdHistoryByIdErrors, DeleteBotsByBotIdHistoryByIdResponses, DeleteBotsByBotIdHistoryData, DeleteBotsByBotIdHistoryError, DeleteBotsByBotIdHistoryErrors, DeleteBotsByBotIdHistoryResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdMemoryMemoriesByMemoryIdData, DeleteBotsByBotIdMemoryMemoriesByMemoryIdError, DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponse, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses, DeleteBotsByBotIdMemoryMemoriesData, DeleteBotsByBotIdMemoryMemoriesError, DeleteBotsByBotIdMemoryMemoriesErrors, DeleteBotsByBotIdMemoryMemoriesResponse, DeleteBotsByBotIdMemoryMemoriesResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdError, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerError, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerFsData, GetBotsByBotIdContainerFsError, GetBotsByBotIdContainerFsErrors, GetBotsByBotIdContainerFsFileData, GetBotsByBotIdContainerFsFileError, GetBotsByBotIdContainerFsFileErrors, GetBotsByBotIdContainerFsFileResponse, GetBotsByBotIdContainerFsFileResponses, GetBotsByBotIdContainerFsResponse, GetBotsByBotIdContainerFsResponses, GetBotsByBotIdContainerFsStatData, GetBotsByBotIdContainerFsStatError, GetBotsByBotIdContainerFsStatErrors, GetBotsByBotIdContainerFsStatResponse, GetBotsByBotIdContainerFsStatResponses, GetBotsByBotIdContainerFsUsageData, GetBotsByBotIdContainerFsUsageError, GetBotsByBotIdContainerFsUsageErrors, GetBotsByBotIdContainerFsUsageResponse, GetBotsByBotIdContainerFsUsageResponses, GetBotsByBotIdContainerResponse, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsError, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponse, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponse, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdHistoryByIdData, GetBotsByBotIdHistoryByIdError, GetBotsByBotIdHistoryByIdErrors, GetBotsByBotIdHistoryByIdResponse, GetBotsByBotIdHistoryByIdResponses, GetBotsByBotIdHistoryData, GetBotsByBotIdHistoryError, GetBotsByBotIdHistoryErrors, GetBotsByBotIdHistoryResponse, GetBotsByBotIdHistoryResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdError, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponse, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpError, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponse, GetBotsByBotIdMcpResponses, GetBotsByBotIdMemoryMemoriesByMemoryIdData, GetBotsByBotIdMemoryMemoriesByMemoryIdError, GetBotsByBotIdMemoryMemoriesByMemoryIdErrors, GetBotsByBotIdMemoryMemoriesByMemoryIdResponse, GetBotsByBotIdMemoryMemoriesByMemoryIdResponses, GetBotsByBotIdMemoryMemoriesData, GetBotsByBotIdMemoryMemoriesError, GetBotsByBotIdMemoryMemoriesErrors, GetBotsByBotIdMemoryMemoriesResponse, GetBotsByBotIdMemoryMemoriesResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdError, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponse, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleError, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponse, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsError, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponse, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextError, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponse, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdError, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponse, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsError, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponse, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsError, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponse, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformError, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponse, GetBotsByIdChannelByPlatformResponses, GetBotsByIdData, GetBotsByIdError, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersError, GetBotsByIdMembersErrors, GetBotsByIdMembersResponse, GetBotsByIdMembersResponses, GetBotsByIdResponse, GetBotsByIdResponses, GetBotsData, GetBotsError, GetBotsErrors, GetBotsResponse, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformError, GetChannelsByPlatformErrors, GetChannelsByPlatformResponse, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsError, GetChannelsErrors, GetChannelsResponse, GetChannelsResponses, GetModelsByIdData, GetModelsByIdError, GetModelsByIdErrors, GetModelsByIdResponse, GetModelsByIdResponses, GetModelsCountData, GetModelsCountError, GetModelsCountErrors, GetModelsCountResponse, GetModelsCountResponses, GetModelsData, GetModelsError, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdError, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponse, GetModelsModelByModelIdResponses, GetModelsResponse, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdError, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsError, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponse, GetProvidersByIdModelsResponses, GetProvidersByIdResponse, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountError, GetProvidersCountErrors, GetProvidersCountResponse, GetProvidersCountResponses, GetProvidersData, GetProvidersError, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameError, GetProvidersNameByNameErrors, GetProvidersNameByNameResponse, GetProvidersNameByNameResponses, GetProvidersResponse, GetProvidersResponses, GetUsersByIdData, GetUsersByIdError, GetUsersByIdErrors, GetUsersByIdResponse, GetUsersByIdResponses, GetUsersData, GetUsersError, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformError, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponse, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeError, GetUsersMeErrors, GetUsersMeResponse, GetUsersMeResponses, GetUsersResponse, GetUsersResponses, GithubComMemohaiMemohInternalMcpConnection, HandlersChannelMeta, HandlersCreateContainerRequest, HandlersCreateContainerResponse, HandlersCreateSnapshotRequest, HandlersCreateSnapshotResponse, HandlersEmbeddingsInput, HandlersEmbeddingsRequest, HandlersEmbeddingsResponse, HandlersEmbeddingsUsage, HandlersEnableModelRequest, HandlersErrorResponse, HandlersFsDeleteResponse, HandlersFsListResponse, HandlersFsMkdirRequest, HandlersFsReadResponse, HandlersFsRestEntry, HandlersFsStatResponse, HandlersFsUsageResponse, HandlersFsWriteRequest, HandlersFsWriteResponse, HandlersGetContainerResponse, HandlersListSnapshotsResponse, HandlersLoginRequest, HandlersLoginResponse, HandlersMcpStdioRequest, HandlersMcpStdioResponse, HandlersMemoryAddPayload, HandlersMemoryDeleteAllPayload, HandlersMemoryEmbedUpsertPayload, HandlersMemorySearchPayload, HandlersSkillItem, HandlersSkillsDeleteRequest, HandlersSkillsOpResponse, HandlersSkillsResponse, HandlersSkillsUpsertRequest, HandlersSnapshotInfo, HistoryCreateRequest, HistoryListResponse, HistoryRecord, McpListResponse, McpUpsertRequest, MemoryDeleteResponse, MemoryEmbedInput, MemoryEmbedUpsertResponse, MemoryMemoryItem, MemoryMessage, MemorySearchResponse, MemoryUpdateRequest, ModelsAddRequest, ModelsAddResponse, ModelsCountResponse, ModelsGetResponse, ModelsModelType, ModelsUpdateRequest, PostAuthLoginData, PostAuthLoginError, PostAuthLoginErrors, PostAuthLoginResponse, PostAuthLoginResponses, PostBotsByBotIdChatData, PostBotsByBotIdChatError, PostBotsByBotIdChatErrors, PostBotsByBotIdChatResponse, PostBotsByBotIdChatResponses, PostBotsByBotIdChatStreamData, PostBotsByBotIdChatStreamError, PostBotsByBotIdChatStreamErrors, PostBotsByBotIdChatStreamResponse, PostBotsByBotIdChatStreamResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerFsDirData, PostBotsByBotIdContainerFsDirError, PostBotsByBotIdContainerFsDirErrors, PostBotsByBotIdContainerFsDirResponse, PostBotsByBotIdContainerFsDirResponses, PostBotsByBotIdContainerFsFileData, PostBotsByBotIdContainerFsFileError, PostBotsByBotIdContainerFsFileErrors, PostBotsByBotIdContainerFsFileResponse, PostBotsByBotIdContainerFsFileResponses, PostBotsByBotIdContainerFsMcpData, PostBotsByBotIdContainerFsMcpError, PostBotsByBotIdContainerFsMcpErrors, PostBotsByBotIdContainerFsMcpResponse, PostBotsByBotIdContainerFsMcpResponses, PostBotsByBotIdContainerFsUploadData, PostBotsByBotIdContainerFsUploadError, PostBotsByBotIdContainerFsUploadErrors, PostBotsByBotIdContainerFsUploadResponse, PostBotsByBotIdContainerFsUploadResponses, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdHistoryData, PostBotsByBotIdHistoryError, PostBotsByBotIdHistoryErrors, PostBotsByBotIdHistoryResponse, PostBotsByBotIdHistoryResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioBySessionIdData, PostBotsByBotIdMcpStdioBySessionIdError, PostBotsByBotIdMcpStdioBySessionIdErrors, PostBotsByBotIdMcpStdioBySessionIdResponse, PostBotsByBotIdMcpStdioBySessionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdMemoryAddData, PostBotsByBotIdMemoryAddError, PostBotsByBotIdMemoryAddErrors, PostBotsByBotIdMemoryAddResponse, PostBotsByBotIdMemoryAddResponses, PostBotsByBotIdMemoryEmbedData, PostBotsByBotIdMemoryEmbedError, PostBotsByBotIdMemoryEmbedErrors, PostBotsByBotIdMemoryEmbedResponse, PostBotsByBotIdMemoryEmbedResponses, PostBotsByBotIdMemorySearchData, PostBotsByBotIdMemorySearchError, PostBotsByBotIdMemorySearchErrors, PostBotsByBotIdMemorySearchResponse, PostBotsByBotIdMemorySearchResponses, PostBotsByBotIdMemoryUpdateData, PostBotsByBotIdMemoryUpdateError, PostBotsByBotIdMemoryUpdateErrors, PostBotsByBotIdMemoryUpdateResponse, PostBotsByBotIdMemoryUpdateResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponse, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponse, PostBotsByBotIdSubagentsResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponse, PostBotsByIdChannelByPlatformSendResponses, PostBotsByIdChannelByPlatformSendSessionData, PostBotsByIdChannelByPlatformSendSessionError, PostBotsByIdChannelByPlatformSendSessionErrors, PostBotsByIdChannelByPlatformSendSessionResponse, PostBotsByIdChannelByPlatformSendSessionResponses, PostBotsData, PostBotsError, PostBotsErrors, PostBotsResponse, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsErrors, PostEmbeddingsResponse, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableErrors, PostModelsEnableResponse, PostModelsEnableResponses, PostModelsError, PostModelsErrors, PostModelsResponse, PostModelsResponses, PostProvidersData, PostProvidersError, PostProvidersErrors, PostProvidersResponse, PostProvidersResponses, PostUsersData, PostUsersError, PostUsersErrors, PostUsersResponse, PostUsersResponses, ProvidersClientType, ProvidersCountResponse, ProvidersCreateRequest, ProvidersGetResponse, ProvidersUpdateRequest, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponse, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdError, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersErrors, PutBotsByIdMembersResponse, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponse, PutBotsByIdOwnerResponses, PutBotsByIdResponse, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdError, PutModelsByIdErrors, PutModelsByIdResponse, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponse, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdErrors, PutProvidersByIdResponse, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdError, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponse, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponse, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeError, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponse, PutUsersMeResponses, ScheduleCreateRequest, ScheduleListResponse, ScheduleNullableInt, ScheduleSchedule, ScheduleUpdateRequest, SettingsSettings, SettingsUpsertRequest, SubagentAddSkillsRequest, SubagentContextResponse, SubagentCreateRequest, SubagentListResponse, SubagentSkillsResponse, SubagentSubagent, SubagentUpdateContextRequest, SubagentUpdateRequest, SubagentUpdateSkillsRequest, UsersCreateUserRequest, UsersListUsersResponse, UsersResetPasswordRequest, UsersUpdatePasswordRequest, UsersUpdateProfileRequest, UsersUpdateUserRequest, UsersUser } from './types.gen'; +export { deleteBotsByBotIdContainer, deleteBotsByBotIdContainerSkills, deleteBotsByBotIdMcpById, deleteBotsByBotIdScheduleById, deleteBotsByBotIdSettings, deleteBotsByBotIdSubagentsById, deleteBotsById, deleteBotsByIdMembersByUserId, deleteModelsById, deleteModelsModelByModelId, deleteProvidersById, getBots, getBotsByBotIdContainer, getBotsByBotIdContainerSkills, getBotsByBotIdContainerSnapshots, getBotsByBotIdMcp, getBotsByBotIdMcpById, getBotsByBotIdSchedule, getBotsByBotIdScheduleById, getBotsByBotIdSettings, getBotsByBotIdSubagents, getBotsByBotIdSubagentsById, getBotsByBotIdSubagentsByIdContext, getBotsByBotIdSubagentsByIdSkills, getBotsById, getBotsByIdChannelByPlatform, getBotsByIdChecks, getBotsByIdMembers, getChannels, getChannelsByPlatform, getModels, getModelsById, getModelsCount, getModelsModelByModelId, getProviders, getProvidersById, getProvidersByIdModels, getProvidersCount, getProvidersNameByName, getUsers, getUsersById, getUsersMe, getUsersMeChannelsByPlatform, getUsersMeIdentities, type Options, postAuthLogin, postBots, postBotsByBotIdContainer, postBotsByBotIdContainerSkills, postBotsByBotIdContainerSnapshots, postBotsByBotIdContainerStart, postBotsByBotIdContainerStop, postBotsByBotIdMcp, postBotsByBotIdMcpStdio, postBotsByBotIdMcpStdioByConnectionId, postBotsByBotIdSchedule, postBotsByBotIdSettings, postBotsByBotIdSubagents, postBotsByBotIdSubagentsByIdSkills, postBotsByBotIdTools, postBotsByIdChannelByPlatformSend, postBotsByIdChannelByPlatformSendChat, postEmbeddings, postModels, postModelsEnable, postProviders, postUsers, putBotsByBotIdMcpById, putBotsByBotIdScheduleById, putBotsByBotIdSettings, putBotsByBotIdSubagentsById, putBotsByBotIdSubagentsByIdContext, putBotsByBotIdSubagentsByIdSkills, putBotsById, putBotsByIdChannelByPlatform, putBotsByIdMembers, putBotsByIdOwner, putModelsById, putModelsModelByModelId, putProvidersById, putUsersById, putUsersByIdPassword, putUsersMe, putUsersMeChannelsByPlatform, putUsersMePassword } from './sdk.gen'; +export type { AccountsAccount, AccountsCreateAccountRequest, AccountsListAccountsResponse, AccountsResetPasswordRequest, AccountsUpdateAccountRequest, AccountsUpdatePasswordRequest, AccountsUpdateProfileRequest, BotsBot, BotsBotCheck, BotsBotMember, BotsCreateBotRequest, BotsListBotsResponse, BotsListChecksResponse, BotsListMembersResponse, BotsTransferBotRequest, BotsUpdateBotRequest, BotsUpsertMemberRequest, ChannelAction, ChannelAttachment, ChannelAttachmentType, ChannelChannelCapabilities, ChannelChannelConfig, ChannelChannelIdentityBinding, ChannelConfigSchema, ChannelFieldSchema, ChannelFieldType, ChannelMessage, ChannelMessageFormat, ChannelMessagePart, ChannelMessagePartType, ChannelMessageTextStyle, ChannelReplyRef, ChannelSendRequest, ChannelTargetHint, ChannelTargetSpec, ChannelThreadRef, ChannelUpsertChannelIdentityConfigRequest, ChannelUpsertConfigRequest, ClientOptions, DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerError, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsError, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponse, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdError, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdError, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsError, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdError, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdError, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdError, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponse, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdError, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdError, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdError, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerError, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerResponse, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsError, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponse, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponse, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdError, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponse, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpError, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponse, GetBotsByBotIdMcpResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdError, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponse, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleError, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponse, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsError, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponse, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextError, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponse, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdError, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponse, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsError, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponse, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsError, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponse, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformError, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponse, GetBotsByIdChannelByPlatformResponses, GetBotsByIdChecksData, GetBotsByIdChecksError, GetBotsByIdChecksErrors, GetBotsByIdChecksResponse, GetBotsByIdChecksResponses, GetBotsByIdData, GetBotsByIdError, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersError, GetBotsByIdMembersErrors, GetBotsByIdMembersResponse, GetBotsByIdMembersResponses, GetBotsByIdResponse, GetBotsByIdResponses, GetBotsData, GetBotsError, GetBotsErrors, GetBotsResponse, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformError, GetChannelsByPlatformErrors, GetChannelsByPlatformResponse, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsError, GetChannelsErrors, GetChannelsResponse, GetChannelsResponses, GetModelsByIdData, GetModelsByIdError, GetModelsByIdErrors, GetModelsByIdResponse, GetModelsByIdResponses, GetModelsCountData, GetModelsCountError, GetModelsCountErrors, GetModelsCountResponse, GetModelsCountResponses, GetModelsData, GetModelsError, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdError, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponse, GetModelsModelByModelIdResponses, GetModelsResponse, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdError, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsError, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponse, GetProvidersByIdModelsResponses, GetProvidersByIdResponse, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountError, GetProvidersCountErrors, GetProvidersCountResponse, GetProvidersCountResponses, GetProvidersData, GetProvidersError, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameError, GetProvidersNameByNameErrors, GetProvidersNameByNameResponse, GetProvidersNameByNameResponses, GetProvidersResponse, GetProvidersResponses, GetUsersByIdData, GetUsersByIdError, GetUsersByIdErrors, GetUsersByIdResponse, GetUsersByIdResponses, GetUsersData, GetUsersError, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformError, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponse, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeError, GetUsersMeErrors, GetUsersMeIdentitiesData, GetUsersMeIdentitiesError, GetUsersMeIdentitiesErrors, GetUsersMeIdentitiesResponse, GetUsersMeIdentitiesResponses, GetUsersMeResponse, GetUsersMeResponses, GetUsersResponse, GetUsersResponses, GithubComMemohaiMemohInternalMcpConnection, HandlersChannelMeta, HandlersCreateContainerRequest, HandlersCreateContainerResponse, HandlersCreateSnapshotRequest, HandlersCreateSnapshotResponse, HandlersEmbeddingsInput, HandlersEmbeddingsRequest, HandlersEmbeddingsResponse, HandlersEmbeddingsUsage, HandlersEnableModelRequest, HandlersErrorResponse, HandlersGetContainerResponse, HandlersListMyIdentitiesResponse, HandlersListSnapshotsResponse, HandlersLoginRequest, HandlersLoginResponse, HandlersMcpStdioRequest, HandlersMcpStdioResponse, HandlersSkillItem, HandlersSkillsDeleteRequest, HandlersSkillsOpResponse, HandlersSkillsResponse, HandlersSkillsUpsertRequest, HandlersSnapshotInfo, IdentitiesChannelIdentity, McpListResponse, McpUpsertRequest, ModelsAddRequest, ModelsAddResponse, ModelsCountResponse, ModelsGetResponse, ModelsModelType, ModelsUpdateRequest, PostAuthLoginData, PostAuthLoginError, PostAuthLoginErrors, PostAuthLoginResponse, PostAuthLoginResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerError, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerResponse, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsError, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponse, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsError, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponse, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartError, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponse, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopError, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponse, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpError, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponse, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioByConnectionIdData, PostBotsByBotIdMcpStdioByConnectionIdError, PostBotsByBotIdMcpStdioByConnectionIdErrors, PostBotsByBotIdMcpStdioByConnectionIdResponse, PostBotsByBotIdMcpStdioByConnectionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioError, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponse, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleError, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponse, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsError, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponse, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsError, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponse, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsError, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponse, PostBotsByBotIdSubagentsResponses, PostBotsByBotIdToolsData, PostBotsByBotIdToolsError, PostBotsByBotIdToolsErrors, PostBotsByBotIdToolsResponse, PostBotsByBotIdToolsResponses, PostBotsByIdChannelByPlatformSendChatData, PostBotsByIdChannelByPlatformSendChatError, PostBotsByIdChannelByPlatformSendChatErrors, PostBotsByIdChannelByPlatformSendChatResponse, PostBotsByIdChannelByPlatformSendChatResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendError, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponse, PostBotsByIdChannelByPlatformSendResponses, PostBotsData, PostBotsError, PostBotsErrors, PostBotsResponse, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsError, PostEmbeddingsErrors, PostEmbeddingsResponse, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableError, PostModelsEnableErrors, PostModelsEnableResponse, PostModelsEnableResponses, PostModelsError, PostModelsErrors, PostModelsResponse, PostModelsResponses, PostProvidersData, PostProvidersError, PostProvidersErrors, PostProvidersResponse, PostProvidersResponses, PostUsersData, PostUsersError, PostUsersErrors, PostUsersResponse, PostUsersResponses, ProvidersClientType, ProvidersCountResponse, ProvidersCreateRequest, ProvidersGetResponse, ProvidersUpdateRequest, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdError, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponse, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdError, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponse, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsError, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponse, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextError, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponse, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdError, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponse, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsError, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponse, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformError, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponse, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdError, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersError, PutBotsByIdMembersErrors, PutBotsByIdMembersResponse, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerError, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponse, PutBotsByIdOwnerResponses, PutBotsByIdResponse, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdError, PutModelsByIdErrors, PutModelsByIdResponse, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdError, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponse, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdError, PutProvidersByIdErrors, PutProvidersByIdResponse, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdError, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordError, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponse, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformError, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponse, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeError, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordError, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponse, PutUsersMeResponses, ScheduleCreateRequest, ScheduleListResponse, ScheduleNullableInt, ScheduleSchedule, ScheduleUpdateRequest, SettingsSettings, SettingsUpsertRequest, SubagentAddSkillsRequest, SubagentContextResponse, SubagentCreateRequest, SubagentListResponse, SubagentSkillsResponse, SubagentSubagent, SubagentUpdateContextRequest, SubagentUpdateRequest, SubagentUpdateSkillsRequest } from './types.gen'; diff --git a/packages/sdk/src/sdk.gen.ts b/packages/sdk/src/sdk.gen.ts index 44e22042..91a60eae 100644 --- a/packages/sdk/src/sdk.gen.ts +++ b/packages/sdk/src/sdk.gen.ts @@ -2,7 +2,7 @@ import type { Client, Options as Options2, TDataShape } from './client'; import { client } from './client.gen'; -import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerFsData, DeleteBotsByBotIdContainerFsErrors, DeleteBotsByBotIdContainerFsResponses, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdHistoryByIdData, DeleteBotsByBotIdHistoryByIdErrors, DeleteBotsByBotIdHistoryByIdResponses, DeleteBotsByBotIdHistoryData, DeleteBotsByBotIdHistoryErrors, DeleteBotsByBotIdHistoryResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdMemoryMemoriesByMemoryIdData, DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors, DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses, DeleteBotsByBotIdMemoryMemoriesData, DeleteBotsByBotIdMemoryMemoriesErrors, DeleteBotsByBotIdMemoryMemoriesResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerFsData, GetBotsByBotIdContainerFsErrors, GetBotsByBotIdContainerFsFileData, GetBotsByBotIdContainerFsFileErrors, GetBotsByBotIdContainerFsFileResponses, GetBotsByBotIdContainerFsResponses, GetBotsByBotIdContainerFsStatData, GetBotsByBotIdContainerFsStatErrors, GetBotsByBotIdContainerFsStatResponses, GetBotsByBotIdContainerFsUsageData, GetBotsByBotIdContainerFsUsageErrors, GetBotsByBotIdContainerFsUsageResponses, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdHistoryByIdData, GetBotsByBotIdHistoryByIdErrors, GetBotsByBotIdHistoryByIdResponses, GetBotsByBotIdHistoryData, GetBotsByBotIdHistoryErrors, GetBotsByBotIdHistoryResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponses, GetBotsByBotIdMemoryMemoriesByMemoryIdData, GetBotsByBotIdMemoryMemoriesByMemoryIdErrors, GetBotsByBotIdMemoryMemoriesByMemoryIdResponses, GetBotsByBotIdMemoryMemoriesData, GetBotsByBotIdMemoryMemoriesErrors, GetBotsByBotIdMemoryMemoriesResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponses, GetBotsByIdData, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersErrors, GetBotsByIdMembersResponses, GetBotsByIdResponses, GetBotsData, GetBotsErrors, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformErrors, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsErrors, GetChannelsResponses, GetModelsByIdData, GetModelsByIdErrors, GetModelsByIdResponses, GetModelsCountData, GetModelsCountErrors, GetModelsCountResponses, GetModelsData, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponses, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponses, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountErrors, GetProvidersCountResponses, GetProvidersData, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameErrors, GetProvidersNameByNameResponses, GetProvidersResponses, GetUsersByIdData, GetUsersByIdErrors, GetUsersByIdResponses, GetUsersData, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeErrors, GetUsersMeResponses, GetUsersResponses, PostAuthLoginData, PostAuthLoginErrors, PostAuthLoginResponses, PostBotsByBotIdChatData, PostBotsByBotIdChatErrors, PostBotsByBotIdChatResponses, PostBotsByBotIdChatStreamData, PostBotsByBotIdChatStreamErrors, PostBotsByBotIdChatStreamResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerFsDirData, PostBotsByBotIdContainerFsDirErrors, PostBotsByBotIdContainerFsDirResponses, PostBotsByBotIdContainerFsFileData, PostBotsByBotIdContainerFsFileErrors, PostBotsByBotIdContainerFsFileResponses, PostBotsByBotIdContainerFsMcpData, PostBotsByBotIdContainerFsMcpErrors, PostBotsByBotIdContainerFsMcpResponses, PostBotsByBotIdContainerFsUploadData, PostBotsByBotIdContainerFsUploadErrors, PostBotsByBotIdContainerFsUploadResponses, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdHistoryData, PostBotsByBotIdHistoryErrors, PostBotsByBotIdHistoryResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioBySessionIdData, PostBotsByBotIdMcpStdioBySessionIdErrors, PostBotsByBotIdMcpStdioBySessionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdMemoryAddData, PostBotsByBotIdMemoryAddErrors, PostBotsByBotIdMemoryAddResponses, PostBotsByBotIdMemoryEmbedData, PostBotsByBotIdMemoryEmbedErrors, PostBotsByBotIdMemoryEmbedResponses, PostBotsByBotIdMemorySearchData, PostBotsByBotIdMemorySearchErrors, PostBotsByBotIdMemorySearchResponses, PostBotsByBotIdMemoryUpdateData, PostBotsByBotIdMemoryUpdateErrors, PostBotsByBotIdMemoryUpdateResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponses, PostBotsByIdChannelByPlatformSendSessionData, PostBotsByIdChannelByPlatformSendSessionErrors, PostBotsByIdChannelByPlatformSendSessionResponses, PostBotsData, PostBotsErrors, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsErrors, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableErrors, PostModelsEnableResponses, PostModelsErrors, PostModelsResponses, PostProvidersData, PostProvidersErrors, PostProvidersResponses, PostUsersData, PostUsersErrors, PostUsersResponses, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersErrors, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponses, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdErrors, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdErrors, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponses } from './types.gen'; +import type { DeleteBotsByBotIdContainerData, DeleteBotsByBotIdContainerErrors, DeleteBotsByBotIdContainerResponses, DeleteBotsByBotIdContainerSkillsData, DeleteBotsByBotIdContainerSkillsErrors, DeleteBotsByBotIdContainerSkillsResponses, DeleteBotsByBotIdMcpByIdData, DeleteBotsByBotIdMcpByIdErrors, DeleteBotsByBotIdMcpByIdResponses, DeleteBotsByBotIdScheduleByIdData, DeleteBotsByBotIdScheduleByIdErrors, DeleteBotsByBotIdScheduleByIdResponses, DeleteBotsByBotIdSettingsData, DeleteBotsByBotIdSettingsErrors, DeleteBotsByBotIdSettingsResponses, DeleteBotsByBotIdSubagentsByIdData, DeleteBotsByBotIdSubagentsByIdErrors, DeleteBotsByBotIdSubagentsByIdResponses, DeleteBotsByIdData, DeleteBotsByIdErrors, DeleteBotsByIdMembersByUserIdData, DeleteBotsByIdMembersByUserIdErrors, DeleteBotsByIdMembersByUserIdResponses, DeleteBotsByIdResponses, DeleteModelsByIdData, DeleteModelsByIdErrors, DeleteModelsByIdResponses, DeleteModelsModelByModelIdData, DeleteModelsModelByModelIdErrors, DeleteModelsModelByModelIdResponses, DeleteProvidersByIdData, DeleteProvidersByIdErrors, DeleteProvidersByIdResponses, GetBotsByBotIdContainerData, GetBotsByBotIdContainerErrors, GetBotsByBotIdContainerResponses, GetBotsByBotIdContainerSkillsData, GetBotsByBotIdContainerSkillsErrors, GetBotsByBotIdContainerSkillsResponses, GetBotsByBotIdContainerSnapshotsData, GetBotsByBotIdContainerSnapshotsResponses, GetBotsByBotIdMcpByIdData, GetBotsByBotIdMcpByIdErrors, GetBotsByBotIdMcpByIdResponses, GetBotsByBotIdMcpData, GetBotsByBotIdMcpErrors, GetBotsByBotIdMcpResponses, GetBotsByBotIdScheduleByIdData, GetBotsByBotIdScheduleByIdErrors, GetBotsByBotIdScheduleByIdResponses, GetBotsByBotIdScheduleData, GetBotsByBotIdScheduleErrors, GetBotsByBotIdScheduleResponses, GetBotsByBotIdSettingsData, GetBotsByBotIdSettingsErrors, GetBotsByBotIdSettingsResponses, GetBotsByBotIdSubagentsByIdContextData, GetBotsByBotIdSubagentsByIdContextErrors, GetBotsByBotIdSubagentsByIdContextResponses, GetBotsByBotIdSubagentsByIdData, GetBotsByBotIdSubagentsByIdErrors, GetBotsByBotIdSubagentsByIdResponses, GetBotsByBotIdSubagentsByIdSkillsData, GetBotsByBotIdSubagentsByIdSkillsErrors, GetBotsByBotIdSubagentsByIdSkillsResponses, GetBotsByBotIdSubagentsData, GetBotsByBotIdSubagentsErrors, GetBotsByBotIdSubagentsResponses, GetBotsByIdChannelByPlatformData, GetBotsByIdChannelByPlatformErrors, GetBotsByIdChannelByPlatformResponses, GetBotsByIdChecksData, GetBotsByIdChecksErrors, GetBotsByIdChecksResponses, GetBotsByIdData, GetBotsByIdErrors, GetBotsByIdMembersData, GetBotsByIdMembersErrors, GetBotsByIdMembersResponses, GetBotsByIdResponses, GetBotsData, GetBotsErrors, GetBotsResponses, GetChannelsByPlatformData, GetChannelsByPlatformErrors, GetChannelsByPlatformResponses, GetChannelsData, GetChannelsErrors, GetChannelsResponses, GetModelsByIdData, GetModelsByIdErrors, GetModelsByIdResponses, GetModelsCountData, GetModelsCountErrors, GetModelsCountResponses, GetModelsData, GetModelsErrors, GetModelsModelByModelIdData, GetModelsModelByModelIdErrors, GetModelsModelByModelIdResponses, GetModelsResponses, GetProvidersByIdData, GetProvidersByIdErrors, GetProvidersByIdModelsData, GetProvidersByIdModelsErrors, GetProvidersByIdModelsResponses, GetProvidersByIdResponses, GetProvidersCountData, GetProvidersCountErrors, GetProvidersCountResponses, GetProvidersData, GetProvidersErrors, GetProvidersNameByNameData, GetProvidersNameByNameErrors, GetProvidersNameByNameResponses, GetProvidersResponses, GetUsersByIdData, GetUsersByIdErrors, GetUsersByIdResponses, GetUsersData, GetUsersErrors, GetUsersMeChannelsByPlatformData, GetUsersMeChannelsByPlatformErrors, GetUsersMeChannelsByPlatformResponses, GetUsersMeData, GetUsersMeErrors, GetUsersMeIdentitiesData, GetUsersMeIdentitiesErrors, GetUsersMeIdentitiesResponses, GetUsersMeResponses, GetUsersResponses, PostAuthLoginData, PostAuthLoginErrors, PostAuthLoginResponses, PostBotsByBotIdContainerData, PostBotsByBotIdContainerErrors, PostBotsByBotIdContainerResponses, PostBotsByBotIdContainerSkillsData, PostBotsByBotIdContainerSkillsErrors, PostBotsByBotIdContainerSkillsResponses, PostBotsByBotIdContainerSnapshotsData, PostBotsByBotIdContainerSnapshotsErrors, PostBotsByBotIdContainerSnapshotsResponses, PostBotsByBotIdContainerStartData, PostBotsByBotIdContainerStartErrors, PostBotsByBotIdContainerStartResponses, PostBotsByBotIdContainerStopData, PostBotsByBotIdContainerStopErrors, PostBotsByBotIdContainerStopResponses, PostBotsByBotIdMcpData, PostBotsByBotIdMcpErrors, PostBotsByBotIdMcpResponses, PostBotsByBotIdMcpStdioByConnectionIdData, PostBotsByBotIdMcpStdioByConnectionIdErrors, PostBotsByBotIdMcpStdioByConnectionIdResponses, PostBotsByBotIdMcpStdioData, PostBotsByBotIdMcpStdioErrors, PostBotsByBotIdMcpStdioResponses, PostBotsByBotIdScheduleData, PostBotsByBotIdScheduleErrors, PostBotsByBotIdScheduleResponses, PostBotsByBotIdSettingsData, PostBotsByBotIdSettingsErrors, PostBotsByBotIdSettingsResponses, PostBotsByBotIdSubagentsByIdSkillsData, PostBotsByBotIdSubagentsByIdSkillsErrors, PostBotsByBotIdSubagentsByIdSkillsResponses, PostBotsByBotIdSubagentsData, PostBotsByBotIdSubagentsErrors, PostBotsByBotIdSubagentsResponses, PostBotsByBotIdToolsData, PostBotsByBotIdToolsErrors, PostBotsByBotIdToolsResponses, PostBotsByIdChannelByPlatformSendChatData, PostBotsByIdChannelByPlatformSendChatErrors, PostBotsByIdChannelByPlatformSendChatResponses, PostBotsByIdChannelByPlatformSendData, PostBotsByIdChannelByPlatformSendErrors, PostBotsByIdChannelByPlatformSendResponses, PostBotsData, PostBotsErrors, PostBotsResponses, PostEmbeddingsData, PostEmbeddingsErrors, PostEmbeddingsResponses, PostModelsData, PostModelsEnableData, PostModelsEnableErrors, PostModelsEnableResponses, PostModelsErrors, PostModelsResponses, PostProvidersData, PostProvidersErrors, PostProvidersResponses, PostUsersData, PostUsersErrors, PostUsersResponses, PutBotsByBotIdMcpByIdData, PutBotsByBotIdMcpByIdErrors, PutBotsByBotIdMcpByIdResponses, PutBotsByBotIdScheduleByIdData, PutBotsByBotIdScheduleByIdErrors, PutBotsByBotIdScheduleByIdResponses, PutBotsByBotIdSettingsData, PutBotsByBotIdSettingsErrors, PutBotsByBotIdSettingsResponses, PutBotsByBotIdSubagentsByIdContextData, PutBotsByBotIdSubagentsByIdContextErrors, PutBotsByBotIdSubagentsByIdContextResponses, PutBotsByBotIdSubagentsByIdData, PutBotsByBotIdSubagentsByIdErrors, PutBotsByBotIdSubagentsByIdResponses, PutBotsByBotIdSubagentsByIdSkillsData, PutBotsByBotIdSubagentsByIdSkillsErrors, PutBotsByBotIdSubagentsByIdSkillsResponses, PutBotsByIdChannelByPlatformData, PutBotsByIdChannelByPlatformErrors, PutBotsByIdChannelByPlatformResponses, PutBotsByIdData, PutBotsByIdErrors, PutBotsByIdMembersData, PutBotsByIdMembersErrors, PutBotsByIdMembersResponses, PutBotsByIdOwnerData, PutBotsByIdOwnerErrors, PutBotsByIdOwnerResponses, PutBotsByIdResponses, PutModelsByIdData, PutModelsByIdErrors, PutModelsByIdResponses, PutModelsModelByModelIdData, PutModelsModelByModelIdErrors, PutModelsModelByModelIdResponses, PutProvidersByIdData, PutProvidersByIdErrors, PutProvidersByIdResponses, PutUsersByIdData, PutUsersByIdErrors, PutUsersByIdPasswordData, PutUsersByIdPasswordErrors, PutUsersByIdPasswordResponses, PutUsersByIdResponses, PutUsersMeChannelsByPlatformData, PutUsersMeChannelsByPlatformErrors, PutUsersMeChannelsByPlatformResponses, PutUsersMeData, PutUsersMeErrors, PutUsersMePasswordData, PutUsersMePasswordErrors, PutUsersMePasswordResponses, PutUsersMeResponses } from './types.gen'; export type Options = Options2 & { /** @@ -53,34 +53,6 @@ export const postBots = (options: Options< } }); -/** - * Chat with AI - * - * Send a chat message and get a response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChat = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/chat', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Stream chat with AI - * - * Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database. - */ -export const postBotsByBotIdChatStream = (options: Options) => (options.client ?? client).sse.post({ - url: '/bots/{bot_id}/chat/stream', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - /** * Delete MCP container for bot */ @@ -103,88 +75,6 @@ export const postBotsByBotIdContainer = (o } }); -/** - * Delete a file or directory - */ -export const deleteBotsByBotIdContainerFs = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/container/fs', ...options }); - -/** - * List files for a bot - * - * List entries under a relative path - */ -export const getBotsByBotIdContainerFs = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs', ...options }); - -/** - * MCP filesystem tools (JSON-RPC) - * - * Forwards MCP JSON-RPC requests to the MCP server inside the container. - * Required: - * - container task is running - * - container has data mount (default /data) bound to /users/ - * - container image contains the "mcp" binary - * Auth: Bearer JWT is used to determine user_id (sub or user_id). - * Paths must be relative (no leading slash) and must not contain "..". - * - * Example: tools/list - * {"jsonrpc":"2.0","id":1,"method":"tools/list"} - * - * Example: tools/call (fs.read) - * {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}} - */ -export const postBotsByBotIdContainerFsMcp = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/container/fs-mcp', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Create a directory - */ -export const postBotsByBotIdContainerFsDir = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/container/fs/dir', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Read file content - */ -export const getBotsByBotIdContainerFsFile = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs/file', ...options }); - -/** - * Create or overwrite a file - */ -export const postBotsByBotIdContainerFsFile = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/container/fs/file', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Get file or directory metadata - */ -export const getBotsByBotIdContainerFsStat = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs/stat', ...options }); - -/** - * Upload a file - */ -export const postBotsByBotIdContainerFsUpload = (options: Options) => (options.client ?? client).post({ url: '/bots/{bot_id}/container/fs/upload', ...options }); - -/** - * Get usage under a path - */ -export const getBotsByBotIdContainerFsUsage = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/container/fs/usage', ...options }); - /** * Delete skills from data directory */ @@ -241,54 +131,12 @@ export const postBotsByBotIdContainerStart = (options: Options) => (options.client ?? client).post({ url: '/bots/{bot_id}/container/stop', ...options }); -/** - * Delete all history records - * - * Delete all history records for current user - */ -export const deleteBotsByBotIdHistory = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/history', ...options }); - -/** - * List history records - * - * List history records for current user - */ -export const getBotsByBotIdHistory = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/history', ...options }); - -/** - * Create history record - * - * Create a history record for current user - */ -export const postBotsByBotIdHistory = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/history', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Delete history record - * - * Delete a history record by ID (must belong to current user) - */ -export const deleteBotsByBotIdHistoryById = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/history/{id}', ...options }); - -/** - * Get history record - * - * Get a history record by ID (must belong to current user) - */ -export const getBotsByBotIdHistoryById = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/history/{id}', ...options }); - /** * List MCP connections * * List MCP connections for a bot */ -export const getBotsByBotIdMcp = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/mcp', ...options }); +export const getBotsByBotIdMcp = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/mcp', ...options }); /** * Create MCP connection @@ -323,8 +171,8 @@ export const postBotsByBotIdMcpStdio = (op * * Proxies MCP JSON-RPC requests to a stdio MCP process in the container. */ -export const postBotsByBotIdMcpStdioBySessionId = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/mcp-stdio/{session_id}', +export const postBotsByBotIdMcpStdioByConnectionId = (options: Options) => (options.client ?? client).post({ + url: '/bots/{bot_id}/mcp-stdio/{connection_id}', ...options, headers: { 'Content-Type': 'application/json', @@ -360,103 +208,12 @@ export const putBotsByBotIdMcpById = (opti } }); -/** - * Add memory - * - * Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryAdd = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/add', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Embed and upsert memory - * - * Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryEmbed = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/embed', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Delete memories - * - * Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemories = (options: Options) => (options.client ?? client).delete({ - url: '/bots/{bot_id}/memory/memories', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * List memories - * - * List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemories = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/memory/memories', ...options }); - -/** - * Delete memory - * - * Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const deleteBotsByBotIdMemoryMemoriesByMemoryId = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/memory/memories/{memoryId}', ...options }); - -/** - * Get memory - * - * Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const getBotsByBotIdMemoryMemoriesByMemoryId = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/memory/memories/{memoryId}', ...options }); - -/** - * Search memories - * - * Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemorySearch = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/search', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - -/** - * Update memory - * - * Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). - */ -export const postBotsByBotIdMemoryUpdate = (options: Options) => (options.client ?? client).post({ - url: '/bots/{bot_id}/memory/update', - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers - } -}); - /** * List schedules * * List schedules for current user */ -export const getBotsByBotIdSchedule = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/schedule', ...options }); +export const getBotsByBotIdSchedule = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/schedule', ...options }); /** * Create schedule @@ -505,14 +262,14 @@ export const putBotsByBotIdScheduleById = * * Remove agent settings for current user */ -export const deleteBotsByBotIdSettings = (options: Options) => (options.client ?? client).delete({ url: '/bots/{bot_id}/settings', ...options }); +export const deleteBotsByBotIdSettings = (options?: Options) => (options?.client ?? client).delete({ url: '/bots/{bot_id}/settings', ...options }); /** * Get user settings * * Get agent settings for current user */ -export const getBotsByBotIdSettings = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/settings', ...options }); +export const getBotsByBotIdSettings = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/settings', ...options }); /** * Update user settings @@ -547,7 +304,7 @@ export const putBotsByBotIdSettings = (opt * * List subagents for current user */ -export const getBotsByBotIdSubagents = (options: Options) => (options.client ?? client).get({ url: '/bots/{bot_id}/subagents', ...options }); +export const getBotsByBotIdSubagents = (options?: Options) => (options?.client ?? client).get({ url: '/bots/{bot_id}/subagents', ...options }); /** * Create subagent @@ -647,6 +404,20 @@ export const putBotsByBotIdSubagentsByIdSkills = (options: Options) => (options.client ?? client).post({ + url: '/bots/{bot_id}/tools', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options.headers + } +}); + /** * Delete bot * @@ -715,8 +486,8 @@ export const postBotsByIdChannelByPlatformSend = (options: Options) => (options.client ?? client).post({ - url: '/bots/{id}/channel/{platform}/send_session', +export const postBotsByIdChannelByPlatformSendChat = (options: Options) => (options.client ?? client).post({ + url: '/bots/{id}/channel/{platform}/send_chat', ...options, headers: { 'Content-Type': 'application/json', @@ -724,6 +495,13 @@ export const postBotsByIdChannelByPlatformSendSession = (options: Options) => (options.client ?? client).get({ url: '/bots/{id}/checks', ...options }); + /** * List bot members * @@ -1025,6 +803,13 @@ export const putUsersMeChannelsByPlatform = (options?: Options) => (options?.client ?? client).get({ url: '/users/me/identities', ...options }); + /** * Update current user password * diff --git a/packages/sdk/src/types.gen.ts b/packages/sdk/src/types.gen.ts index 9a91b061..53952883 100644 --- a/packages/sdk/src/types.gen.ts +++ b/packages/sdk/src/types.gen.ts @@ -4,18 +4,79 @@ export type ClientOptions = { baseUrl: string; }; +export type AccountsAccount = { + avatar_url?: string; + created_at?: string; + display_name?: string; + email?: string; + id?: string; + is_active?: boolean; + last_login_at?: string; + role?: string; + updated_at?: string; + username?: string; +}; + +export type AccountsCreateAccountRequest = { + avatar_url?: string; + display_name?: string; + email?: string; + is_active?: boolean; + password?: string; + role?: string; + username?: string; +}; + +export type AccountsListAccountsResponse = { + items?: Array; +}; + +export type AccountsResetPasswordRequest = { + new_password?: string; +}; + +export type AccountsUpdateAccountRequest = { + avatar_url?: string; + display_name?: string; + is_active?: boolean; + role?: string; +}; + +export type AccountsUpdatePasswordRequest = { + current_password?: string; + new_password?: string; +}; + +export type AccountsUpdateProfileRequest = { + avatar_url?: string; + display_name?: string; +}; + export type BotsBot = { avatar_url?: string; - created_at: string; - display_name: string; - id: string; - is_active: boolean; + check_issue_count?: number; + check_state?: string; + created_at?: string; + display_name?: string; + id?: string; + is_active?: boolean; metadata?: { [key: string]: unknown; }; - owner_user_id: string; - type: string; - updated_at: string; + owner_user_id?: string; + status?: string; + type?: string; + updated_at?: string; +}; + +export type BotsBotCheck = { + check_key?: string; + detail?: string; + metadata?: { + [key: string]: unknown; + }; + status?: string; + summary?: string; }; export type BotsBotMember = { @@ -36,7 +97,11 @@ export type BotsCreateBotRequest = { }; export type BotsListBotsResponse = { - items: Array; + items?: Array; +}; + +export type BotsListChecksResponse = { + items?: Array; }; export type BotsListMembersResponse = { @@ -77,7 +142,9 @@ export type ChannelAttachment = { }; mime?: string; name?: string; + platform_key?: string; size?: number; + source_platform?: string; thumbnail_url?: string; type?: ChannelAttachmentType; url?: string; @@ -125,7 +192,8 @@ export type ChannelChannelConfig = { verifiedAt?: string; }; -export type ChannelChannelUserBinding = { +export type ChannelChannelIdentityBinding = { + channelIdentityID?: string; channelType?: string; config?: { [key: string]: unknown; @@ -133,7 +201,6 @@ export type ChannelChannelUserBinding = { createdAt?: string; id?: string; updatedAt?: string; - userID?: string; }; export type ChannelConfigSchema = { @@ -171,6 +238,7 @@ export type ChannelMessage = { export type ChannelMessageFormat = 'plain' | 'markdown' | 'rich'; export type ChannelMessagePart = { + channel_identity_id?: string; emoji?: string; language?: string; metadata?: { @@ -180,7 +248,6 @@ export type ChannelMessagePart = { text?: string; type?: ChannelMessagePartType; url?: string; - user_id?: string; }; export type ChannelMessagePartType = 'text' | 'link' | 'code_block' | 'mention' | 'emoji'; @@ -193,9 +260,9 @@ export type ChannelReplyRef = { }; export type ChannelSendRequest = { + channel_identity_id?: string; message?: ChannelMessage; target?: string; - user_id?: string; }; export type ChannelTargetHint = { @@ -212,6 +279,12 @@ export type ChannelThreadRef = { id?: string; }; +export type ChannelUpsertChannelIdentityConfigRequest = { + config?: { + [key: string]: unknown; + }; +}; + export type ChannelUpsertConfigRequest = { credentials?: { [key: string]: unknown; @@ -227,76 +300,30 @@ export type ChannelUpsertConfigRequest = { verified_at?: string; }; -export type ChannelUpsertUserConfigRequest = { +export type GithubComMemohaiMemohInternalMcpConnection = { + active?: boolean; + bot_id?: string; config?: { [key: string]: unknown; }; -}; - -export type ChatChatRequest = { - allowed_actions?: Array; - channels?: Array; - current_channel?: string; - language?: string; - max_context_load_time?: number; - messages?: Array; - model?: string; - provider?: string; - query?: string; - skills?: Array; -}; - -export type ChatChatResponse = { - messages?: Array; - model?: string; - provider?: string; - skills?: Array; -}; - -export type ChatModelMessage = { - content?: Array; - name?: string; - role?: string; - tool_call_id?: string; - tool_calls?: Array; -}; - -export type ChatToolCall = { - function?: ChatToolCallFunction; + created_at?: string; id?: string; - type?: string; -}; - -export type ChatToolCallFunction = { - arguments?: string; name?: string; -}; - -export type GithubComMemohaiMemohInternalMcpConnection = { - active: boolean; - bot_id: string; - config: { - [key: string]: unknown; - }; - created_at: string; - id: string; - name: string; - type: string; - updated_at: string; + type?: string; + updated_at?: string; }; export type HandlersChannelMeta = { - capabilities: ChannelChannelCapabilities; - config_schema: ChannelConfigSchema; + capabilities?: ChannelChannelCapabilities; + config_schema?: ChannelConfigSchema; configless?: boolean; - display_name: string; + display_name?: string; target_spec?: ChannelTargetSpec; - type: string; + type?: string; user_config_schema?: ChannelConfigSchema; }; export type HandlersCreateContainerRequest = { - image?: string; snapshotter?: string; }; @@ -356,61 +383,6 @@ export type HandlersErrorResponse = { message?: string; }; -export type HandlersFsDeleteResponse = { - ok?: boolean; -}; - -export type HandlersFsListResponse = { - entries?: Array; - path?: string; -}; - -export type HandlersFsMkdirRequest = { - parents?: boolean; - path?: string; -}; - -export type HandlersFsReadResponse = { - content?: string; - mod_time?: string; - mode?: number; - path?: string; - size?: number; -}; - -export type HandlersFsRestEntry = { - is_dir?: boolean; - mod_time?: string; - mode?: number; - path?: string; - size?: number; -}; - -export type HandlersFsStatResponse = { - is_dir?: boolean; - mod_time?: string; - mode?: number; - path?: string; - size?: number; -}; - -export type HandlersFsUsageResponse = { - dir_count?: number; - file_count?: number; - path?: string; - total_bytes?: number; -}; - -export type HandlersFsWriteRequest = { - content?: string; - overwrite?: boolean; - path?: string; -}; - -export type HandlersFsWriteResponse = { - ok?: boolean; -}; - export type HandlersGetContainerResponse = { container_id?: string; container_path?: string; @@ -429,18 +401,18 @@ export type HandlersListSnapshotsResponse = { }; export type HandlersLoginRequest = { - password: string; - username: string; + password?: string; + username?: string; }; export type HandlersLoginResponse = { - access_token: string; - display_name: string; - expires_at: string; - role: string; - token_type: string; - user_id: string; - username: string; + access_token?: string; + display_name?: string; + expires_at?: string; + role?: string; + token_type?: string; + user_id?: string; + username?: string; }; export type HandlersMcpStdioRequest = { @@ -454,7 +426,7 @@ export type HandlersMcpStdioRequest = { }; export type HandlersMcpStdioResponse = { - session_id?: string; + connection_id?: string; tools?: Array; url?: string; }; @@ -492,80 +464,26 @@ export type HandlersSnapshotInfo = { updated_at?: string; }; -export type HandlersMemoryAddPayload = { - embedding_enabled?: boolean; - filters?: { - [key: string]: unknown; - }; - infer?: boolean; - message?: string; - messages?: Array; - metadata?: { - [key: string]: unknown; - }; - run_id?: string; -}; - -export type HandlersMemoryDeleteAllPayload = { - run_id?: string; -}; - -export type HandlersMemoryEmbedUpsertPayload = { - filters?: { - [key: string]: unknown; - }; - input?: MemoryEmbedInput; - metadata?: { - [key: string]: unknown; - }; - model?: string; - provider?: string; - run_id?: string; - source?: string; - type?: string; -}; - -export type HandlersMemorySearchPayload = { - embedding_enabled?: boolean; - filters?: { - [key: string]: unknown; - }; - limit?: number; - query?: string; - run_id?: string; - sources?: Array; +export type HandlersListMyIdentitiesResponse = { + items?: Array; + user_id?: string; }; export type HandlersSkillsOpResponse = { ok?: boolean; }; -export type HistoryCreateRequest = { - messages?: Array<{ - [key: string]: unknown; - }>; - metadata?: { - [key: string]: unknown; - }; - skills?: Array; -}; - -export type HistoryListResponse = { - items?: Array; -}; - -export type HistoryRecord = { - bot_id?: string; +export type IdentitiesChannelIdentity = { + channel?: string; + channel_subject_id?: string; + created_at?: string; + display_name?: string; id?: string; - messages?: Array<{ - [key: string]: unknown; - }>; metadata?: { [key: string]: unknown; }; - session_id?: string; - skills?: Array; - timestamp?: string; + updated_at?: string; + user_id?: string; }; export type McpListResponse = { @@ -581,63 +499,14 @@ export type McpUpsertRequest = { type?: string; }; -export type MemoryDeleteResponse = { - message?: string; -}; - -export type MemoryEmbedInput = { - image_url?: string; - text?: string; - video_url?: string; -}; - -export type MemoryEmbedUpsertResponse = { - dimensions?: number; - item?: MemoryMemoryItem; - model?: string; - provider?: string; -}; - -export type MemoryMemoryItem = { - agentId?: string; - botId?: string; - createdAt?: string; - hash?: string; - id?: string; - memory?: string; - metadata?: { - [key: string]: unknown; - }; - runId?: string; - score?: number; - sessionId?: string; - updatedAt?: string; -}; - -export type MemoryMessage = { - content?: string; - role?: string; -}; - -export type MemorySearchResponse = { - relations?: Array; - results?: Array; -}; - -export type MemoryUpdateRequest = { - embedding_enabled?: boolean; - memory?: string; - memory_id?: string; -}; - export type ModelsAddRequest = { dimensions?: number; input?: Array; is_multimodal?: boolean; - llm_provider_id: string; - model_id: string; - name: string; - type: ModelsModelType; + llm_provider_id?: string; + model_id?: string; + name?: string; + type?: ModelsModelType; }; export type ModelsAddResponse = { @@ -653,10 +522,10 @@ export type ModelsGetResponse = { dimensions?: number; input?: Array; is_multimodal?: boolean; - llm_provider_id: string; - model_id: string; - name: string; - type: ModelsModelType; + llm_provider_id?: string; + model_id?: string; + name?: string; + type?: ModelsModelType; }; export type ModelsModelType = 'chat' | 'embedding'; @@ -665,10 +534,10 @@ export type ModelsUpdateRequest = { dimensions?: number; input?: Array; is_multimodal?: boolean; - llm_provider_id: string; - model_id: string; - name: string; - type: ModelsModelType; + llm_provider_id?: string; + model_id?: string; + name?: string; + type?: ModelsModelType; }; export type ProvidersClientType = 'openai' | 'openai-compat' | 'anthropic' | 'google' | 'ollama'; @@ -692,15 +561,15 @@ export type ProvidersGetResponse = { * masked in response */ api_key?: string; - base_url: string; - client_type: string; - created_at: string; - id: string; + base_url?: string; + client_type?: string; + created_at?: string; + id?: string; metadata?: { [key: string]: unknown; }; - name: string; - updated_at: string; + name?: string; + updated_at?: string; }; export type ProvidersUpdateRequest = { @@ -755,12 +624,12 @@ export type ScheduleUpdateRequest = { }; export type SettingsSettings = { - allow_guest: boolean; - chat_model_id: string; - embedding_model_id: string; - language: string; - max_context_load_time: number; - memory_model_id: string; + allow_guest?: boolean; + chat_model_id?: string; + embedding_model_id?: string; + language?: string; + max_context_load_time?: number; + memory_model_id?: string; }; export type SettingsUpsertRequest = { @@ -838,54 +707,6 @@ export type SubagentUpdateSkillsRequest = { skills?: Array; }; -export type UsersCreateUserRequest = { - avatar_url?: string; - display_name?: string; - email?: string; - is_active?: boolean; - password?: string; - role?: string; - username?: string; -}; - -export type UsersListUsersResponse = { - items?: Array; -}; - -export type UsersResetPasswordRequest = { - new_password?: string; -}; - -export type UsersUpdatePasswordRequest = { - current_password?: string; - new_password?: string; -}; - -export type UsersUpdateProfileRequest = { - avatar_url?: string; - display_name?: string; -}; - -export type UsersUpdateUserRequest = { - avatar_url?: string; - display_name?: string; - is_active?: boolean; - role?: string; -}; - -export type UsersUser = { - avatar_url?: string; - created_at?: string; - display_name?: string; - email?: string; - id?: string; - is_active?: boolean; - last_login_at?: string; - role?: string; - updated_at?: string; - username?: string; -}; - export type PostAuthLoginData = { /** * Login request @@ -996,80 +817,6 @@ export type PostBotsResponses = { export type PostBotsResponse = PostBotsResponses[keyof PostBotsResponses]; -export type PostBotsByBotIdChatData = { - /** - * Chat request - */ - body: ChatChatRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/chat'; -}; - -export type PostBotsByBotIdChatErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdChatError = PostBotsByBotIdChatErrors[keyof PostBotsByBotIdChatErrors]; - -export type PostBotsByBotIdChatResponses = { - /** - * OK - */ - 200: ChatChatResponse; -}; - -export type PostBotsByBotIdChatResponse = PostBotsByBotIdChatResponses[keyof PostBotsByBotIdChatResponses]; - -export type PostBotsByBotIdChatStreamData = { - /** - * Chat request - */ - body: ChatChatRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/chat/stream'; -}; - -export type PostBotsByBotIdChatStreamErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdChatStreamError = PostBotsByBotIdChatStreamErrors[keyof PostBotsByBotIdChatStreamErrors]; - -export type PostBotsByBotIdChatStreamResponses = { - /** - * OK - */ - 200: string; -}; - -export type PostBotsByBotIdChatStreamResponse = PostBotsByBotIdChatStreamResponses[keyof PostBotsByBotIdChatStreamResponses]; - export type DeleteBotsByBotIdContainerData = { body?: never; path: { @@ -1173,409 +920,6 @@ export type PostBotsByBotIdContainerResponses = { export type PostBotsByBotIdContainerResponse = PostBotsByBotIdContainerResponses[keyof PostBotsByBotIdContainerResponses]; -export type DeleteBotsByBotIdContainerFsData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query: { - /** - * Relative path - */ - path: string; - /** - * Recursive delete for directories - */ - recursive?: boolean; - }; - url: '/bots/{bot_id}/container/fs'; -}; - -export type DeleteBotsByBotIdContainerFsErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdContainerFsError = DeleteBotsByBotIdContainerFsErrors[keyof DeleteBotsByBotIdContainerFsErrors]; - -export type DeleteBotsByBotIdContainerFsResponses = { - /** - * OK - */ - 200: HandlersFsDeleteResponse; -}; - -export type DeleteBotsByBotIdContainerFsResponse = DeleteBotsByBotIdContainerFsResponses[keyof DeleteBotsByBotIdContainerFsResponses]; - -export type GetBotsByBotIdContainerFsData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Relative directory path - */ - path?: string; - /** - * Recursive listing - */ - recursive?: boolean; - }; - url: '/bots/{bot_id}/container/fs'; -}; - -export type GetBotsByBotIdContainerFsErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsError = GetBotsByBotIdContainerFsErrors[keyof GetBotsByBotIdContainerFsErrors]; - -export type GetBotsByBotIdContainerFsResponses = { - /** - * OK - */ - 200: HandlersFsListResponse; -}; - -export type GetBotsByBotIdContainerFsResponse = GetBotsByBotIdContainerFsResponses[keyof GetBotsByBotIdContainerFsResponses]; - -export type PostBotsByBotIdContainerFsMcpData = { - /** - * JSON-RPC request - */ - body: { - [key: string]: unknown; - }; - headers: { - /** - * Bearer - */ - Authorization: string; - }; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/container/fs-mcp'; -}; - -export type PostBotsByBotIdContainerFsMcpErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsMcpError = PostBotsByBotIdContainerFsMcpErrors[keyof PostBotsByBotIdContainerFsMcpErrors]; - -export type PostBotsByBotIdContainerFsMcpResponses = { - /** - * JSON-RPC response: {jsonrpc,id,result|error} - */ - 200: { - [key: string]: unknown; - }; -}; - -export type PostBotsByBotIdContainerFsMcpResponse = PostBotsByBotIdContainerFsMcpResponses[keyof PostBotsByBotIdContainerFsMcpResponses]; - -export type PostBotsByBotIdContainerFsDirData = { - /** - * Directory payload - */ - body: HandlersFsMkdirRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/container/fs/dir'; -}; - -export type PostBotsByBotIdContainerFsDirErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsDirError = PostBotsByBotIdContainerFsDirErrors[keyof PostBotsByBotIdContainerFsDirErrors]; - -export type PostBotsByBotIdContainerFsDirResponses = { - /** - * OK - */ - 200: HandlersFsWriteResponse; -}; - -export type PostBotsByBotIdContainerFsDirResponse = PostBotsByBotIdContainerFsDirResponses[keyof PostBotsByBotIdContainerFsDirResponses]; - -export type GetBotsByBotIdContainerFsFileData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query: { - /** - * Relative file path - */ - path: string; - }; - url: '/bots/{bot_id}/container/fs/file'; -}; - -export type GetBotsByBotIdContainerFsFileErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsFileError = GetBotsByBotIdContainerFsFileErrors[keyof GetBotsByBotIdContainerFsFileErrors]; - -export type GetBotsByBotIdContainerFsFileResponses = { - /** - * OK - */ - 200: HandlersFsReadResponse; -}; - -export type GetBotsByBotIdContainerFsFileResponse = GetBotsByBotIdContainerFsFileResponses[keyof GetBotsByBotIdContainerFsFileResponses]; - -export type PostBotsByBotIdContainerFsFileData = { - /** - * File write payload - */ - body: HandlersFsWriteRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/container/fs/file'; -}; - -export type PostBotsByBotIdContainerFsFileErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Conflict - */ - 409: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsFileError = PostBotsByBotIdContainerFsFileErrors[keyof PostBotsByBotIdContainerFsFileErrors]; - -export type PostBotsByBotIdContainerFsFileResponses = { - /** - * OK - */ - 200: HandlersFsWriteResponse; -}; - -export type PostBotsByBotIdContainerFsFileResponse = PostBotsByBotIdContainerFsFileResponses[keyof PostBotsByBotIdContainerFsFileResponses]; - -export type GetBotsByBotIdContainerFsStatData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query: { - /** - * Relative path - */ - path: string; - }; - url: '/bots/{bot_id}/container/fs/stat'; -}; - -export type GetBotsByBotIdContainerFsStatErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsStatError = GetBotsByBotIdContainerFsStatErrors[keyof GetBotsByBotIdContainerFsStatErrors]; - -export type GetBotsByBotIdContainerFsStatResponses = { - /** - * OK - */ - 200: HandlersFsStatResponse; -}; - -export type GetBotsByBotIdContainerFsStatResponse = GetBotsByBotIdContainerFsStatResponses[keyof GetBotsByBotIdContainerFsStatResponses]; - -export type PostBotsByBotIdContainerFsUploadData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Relative file path or directory - */ - path?: string; - }; - url: '/bots/{bot_id}/container/fs/upload'; -}; - -export type PostBotsByBotIdContainerFsUploadErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdContainerFsUploadError = PostBotsByBotIdContainerFsUploadErrors[keyof PostBotsByBotIdContainerFsUploadErrors]; - -export type PostBotsByBotIdContainerFsUploadResponses = { - /** - * OK - */ - 200: HandlersFsWriteResponse; -}; - -export type PostBotsByBotIdContainerFsUploadResponse = PostBotsByBotIdContainerFsUploadResponses[keyof PostBotsByBotIdContainerFsUploadResponses]; - -export type GetBotsByBotIdContainerFsUsageData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Relative directory path - */ - path?: string; - }; - url: '/bots/{bot_id}/container/fs/usage'; -}; - -export type GetBotsByBotIdContainerFsUsageErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdContainerFsUsageError = GetBotsByBotIdContainerFsUsageErrors[keyof GetBotsByBotIdContainerFsUsageErrors]; - -export type GetBotsByBotIdContainerFsUsageResponses = { - /** - * OK - */ - 200: HandlersFsUsageResponse; -}; - -export type GetBotsByBotIdContainerFsUsageResponse = GetBotsByBotIdContainerFsUsageResponses[keyof GetBotsByBotIdContainerFsUsageResponses]; - export type DeleteBotsByBotIdContainerSkillsData = { /** * Delete skills payload @@ -1831,204 +1175,9 @@ export type PostBotsByBotIdContainerStopResponses = { export type PostBotsByBotIdContainerStopResponse = PostBotsByBotIdContainerStopResponses[keyof PostBotsByBotIdContainerStopResponses]; -export type DeleteBotsByBotIdHistoryData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/history'; -}; - -export type DeleteBotsByBotIdHistoryErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdHistoryError = DeleteBotsByBotIdHistoryErrors[keyof DeleteBotsByBotIdHistoryErrors]; - -export type DeleteBotsByBotIdHistoryResponses = { - /** - * No Content - */ - 204: unknown; -}; - -export type GetBotsByBotIdHistoryData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Limit - */ - limit?: number; - }; - url: '/bots/{bot_id}/history'; -}; - -export type GetBotsByBotIdHistoryErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdHistoryError = GetBotsByBotIdHistoryErrors[keyof GetBotsByBotIdHistoryErrors]; - -export type GetBotsByBotIdHistoryResponses = { - /** - * OK - */ - 200: HistoryListResponse; -}; - -export type GetBotsByBotIdHistoryResponse = GetBotsByBotIdHistoryResponses[keyof GetBotsByBotIdHistoryResponses]; - -export type PostBotsByBotIdHistoryData = { - /** - * History payload - */ - body: HistoryCreateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/history'; -}; - -export type PostBotsByBotIdHistoryErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdHistoryError = PostBotsByBotIdHistoryErrors[keyof PostBotsByBotIdHistoryErrors]; - -export type PostBotsByBotIdHistoryResponses = { - /** - * Created - */ - 201: HistoryRecord; -}; - -export type PostBotsByBotIdHistoryResponse = PostBotsByBotIdHistoryResponses[keyof PostBotsByBotIdHistoryResponses]; - -export type DeleteBotsByBotIdHistoryByIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * History ID - */ - id: string; - }; - query?: never; - url: '/bots/{bot_id}/history/{id}'; -}; - -export type DeleteBotsByBotIdHistoryByIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Forbidden - */ - 403: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdHistoryByIdError = DeleteBotsByBotIdHistoryByIdErrors[keyof DeleteBotsByBotIdHistoryByIdErrors]; - -export type DeleteBotsByBotIdHistoryByIdResponses = { - /** - * No Content - */ - 204: unknown; -}; - -export type GetBotsByBotIdHistoryByIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * History ID - */ - id: string; - }; - query?: never; - url: '/bots/{bot_id}/history/{id}'; -}; - -export type GetBotsByBotIdHistoryByIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Not Found - */ - 404: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdHistoryByIdError = GetBotsByBotIdHistoryByIdErrors[keyof GetBotsByBotIdHistoryByIdErrors]; - -export type GetBotsByBotIdHistoryByIdResponses = { - /** - * OK - */ - 200: HistoryRecord; -}; - -export type GetBotsByBotIdHistoryByIdResponse = GetBotsByBotIdHistoryByIdResponses[keyof GetBotsByBotIdHistoryByIdResponses]; - export type GetBotsByBotIdMcpData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/mcp'; }; @@ -2068,12 +1217,7 @@ export type PostBotsByBotIdMcpData = { * MCP payload */ body: McpUpsertRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/mcp'; }; @@ -2149,7 +1293,7 @@ export type PostBotsByBotIdMcpStdioResponses = { export type PostBotsByBotIdMcpStdioResponse = PostBotsByBotIdMcpStdioResponses[keyof PostBotsByBotIdMcpStdioResponses]; -export type PostBotsByBotIdMcpStdioBySessionIdData = { +export type PostBotsByBotIdMcpStdioByConnectionIdData = { /** * JSON-RPC request */ @@ -2162,15 +1306,15 @@ export type PostBotsByBotIdMcpStdioBySessionIdData = { */ bot_id: string; /** - * Session ID + * Connection ID */ - session_id: string; + connection_id: string; }; query?: never; - url: '/bots/{bot_id}/mcp-stdio/{session_id}'; + url: '/bots/{bot_id}/mcp-stdio/{connection_id}'; }; -export type PostBotsByBotIdMcpStdioBySessionIdErrors = { +export type PostBotsByBotIdMcpStdioByConnectionIdErrors = { /** * Bad Request */ @@ -2185,9 +1329,9 @@ export type PostBotsByBotIdMcpStdioBySessionIdErrors = { 500: HandlersErrorResponse; }; -export type PostBotsByBotIdMcpStdioBySessionIdError = PostBotsByBotIdMcpStdioBySessionIdErrors[keyof PostBotsByBotIdMcpStdioBySessionIdErrors]; +export type PostBotsByBotIdMcpStdioByConnectionIdError = PostBotsByBotIdMcpStdioByConnectionIdErrors[keyof PostBotsByBotIdMcpStdioByConnectionIdErrors]; -export type PostBotsByBotIdMcpStdioBySessionIdResponses = { +export type PostBotsByBotIdMcpStdioByConnectionIdResponses = { /** * JSON-RPC response: {jsonrpc,id,result|error} */ @@ -2196,15 +1340,11 @@ export type PostBotsByBotIdMcpStdioBySessionIdResponses = { }; }; -export type PostBotsByBotIdMcpStdioBySessionIdResponse = PostBotsByBotIdMcpStdioBySessionIdResponses[keyof PostBotsByBotIdMcpStdioBySessionIdResponses]; +export type PostBotsByBotIdMcpStdioByConnectionIdResponse = PostBotsByBotIdMcpStdioByConnectionIdResponses[keyof PostBotsByBotIdMcpStdioByConnectionIdResponses]; export type DeleteBotsByBotIdMcpByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * MCP ID */ @@ -2245,10 +1385,6 @@ export type DeleteBotsByBotIdMcpByIdResponses = { export type GetBotsByBotIdMcpByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * MCP ID */ @@ -2294,10 +1430,6 @@ export type PutBotsByBotIdMcpByIdData = { */ body: McpUpsertRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * MCP ID */ @@ -2337,318 +1469,9 @@ export type PutBotsByBotIdMcpByIdResponses = { export type PutBotsByBotIdMcpByIdResponse = PutBotsByBotIdMcpByIdResponses[keyof PutBotsByBotIdMcpByIdResponses]; -export type PostBotsByBotIdMemoryAddData = { - /** - * Add request - */ - body: HandlersMemoryAddPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/add'; -}; - -export type PostBotsByBotIdMemoryAddErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemoryAddError = PostBotsByBotIdMemoryAddErrors[keyof PostBotsByBotIdMemoryAddErrors]; - -export type PostBotsByBotIdMemoryAddResponses = { - /** - * OK - */ - 200: MemorySearchResponse; -}; - -export type PostBotsByBotIdMemoryAddResponse = PostBotsByBotIdMemoryAddResponses[keyof PostBotsByBotIdMemoryAddResponses]; - -export type PostBotsByBotIdMemoryEmbedData = { - /** - * Embed upsert request - */ - body: HandlersMemoryEmbedUpsertPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/embed'; -}; - -export type PostBotsByBotIdMemoryEmbedErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemoryEmbedError = PostBotsByBotIdMemoryEmbedErrors[keyof PostBotsByBotIdMemoryEmbedErrors]; - -export type PostBotsByBotIdMemoryEmbedResponses = { - /** - * OK - */ - 200: MemoryEmbedUpsertResponse; -}; - -export type PostBotsByBotIdMemoryEmbedResponse = PostBotsByBotIdMemoryEmbedResponses[keyof PostBotsByBotIdMemoryEmbedResponses]; - -export type DeleteBotsByBotIdMemoryMemoriesData = { - /** - * Delete all request - */ - body: HandlersMemoryDeleteAllPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/memories'; -}; - -export type DeleteBotsByBotIdMemoryMemoriesErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesError = DeleteBotsByBotIdMemoryMemoriesErrors[keyof DeleteBotsByBotIdMemoryMemoriesErrors]; - -export type DeleteBotsByBotIdMemoryMemoriesResponses = { - /** - * OK - */ - 200: MemoryDeleteResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesResponse = DeleteBotsByBotIdMemoryMemoriesResponses[keyof DeleteBotsByBotIdMemoryMemoriesResponses]; - -export type GetBotsByBotIdMemoryMemoriesData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: { - /** - * Run ID - */ - run_id?: string; - /** - * Limit - */ - limit?: number; - }; - url: '/bots/{bot_id}/memory/memories'; -}; - -export type GetBotsByBotIdMemoryMemoriesErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdMemoryMemoriesError = GetBotsByBotIdMemoryMemoriesErrors[keyof GetBotsByBotIdMemoryMemoriesErrors]; - -export type GetBotsByBotIdMemoryMemoriesResponses = { - /** - * OK - */ - 200: MemorySearchResponse; -}; - -export type GetBotsByBotIdMemoryMemoriesResponse = GetBotsByBotIdMemoryMemoriesResponses[keyof GetBotsByBotIdMemoryMemoriesResponses]; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * Memory ID - */ - memoryId: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/memories/{memoryId}'; -}; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdError = DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors[keyof DeleteBotsByBotIdMemoryMemoriesByMemoryIdErrors]; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses = { - /** - * OK - */ - 200: MemoryDeleteResponse; -}; - -export type DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponse = DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses[keyof DeleteBotsByBotIdMemoryMemoriesByMemoryIdResponses]; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdData = { - body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - /** - * Memory ID - */ - memoryId: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/memories/{memoryId}'; -}; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdError = GetBotsByBotIdMemoryMemoriesByMemoryIdErrors[keyof GetBotsByBotIdMemoryMemoriesByMemoryIdErrors]; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdResponses = { - /** - * OK - */ - 200: MemoryMemoryItem; -}; - -export type GetBotsByBotIdMemoryMemoriesByMemoryIdResponse = GetBotsByBotIdMemoryMemoriesByMemoryIdResponses[keyof GetBotsByBotIdMemoryMemoriesByMemoryIdResponses]; - -export type PostBotsByBotIdMemorySearchData = { - /** - * Search request - */ - body: HandlersMemorySearchPayload; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/search'; -}; - -export type PostBotsByBotIdMemorySearchErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemorySearchError = PostBotsByBotIdMemorySearchErrors[keyof PostBotsByBotIdMemorySearchErrors]; - -export type PostBotsByBotIdMemorySearchResponses = { - /** - * OK - */ - 200: MemorySearchResponse; -}; - -export type PostBotsByBotIdMemorySearchResponse = PostBotsByBotIdMemorySearchResponses[keyof PostBotsByBotIdMemorySearchResponses]; - -export type PostBotsByBotIdMemoryUpdateData = { - /** - * Update request - */ - body: MemoryUpdateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; - query?: never; - url: '/bots/{bot_id}/memory/update'; -}; - -export type PostBotsByBotIdMemoryUpdateErrors = { - /** - * Bad Request - */ - 400: HandlersErrorResponse; - /** - * Internal Server Error - */ - 500: HandlersErrorResponse; -}; - -export type PostBotsByBotIdMemoryUpdateError = PostBotsByBotIdMemoryUpdateErrors[keyof PostBotsByBotIdMemoryUpdateErrors]; - -export type PostBotsByBotIdMemoryUpdateResponses = { - /** - * OK - */ - 200: MemoryMemoryItem; -}; - -export type PostBotsByBotIdMemoryUpdateResponse = PostBotsByBotIdMemoryUpdateResponses[keyof PostBotsByBotIdMemoryUpdateResponses]; - export type GetBotsByBotIdScheduleData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/schedule'; }; @@ -2680,12 +1503,7 @@ export type PostBotsByBotIdScheduleData = { * Schedule payload */ body: ScheduleCreateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/schedule'; }; @@ -2715,10 +1533,6 @@ export type PostBotsByBotIdScheduleResponse = PostBotsByBotIdScheduleResponses[k export type DeleteBotsByBotIdScheduleByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Schedule ID */ @@ -2751,10 +1565,6 @@ export type DeleteBotsByBotIdScheduleByIdResponses = { export type GetBotsByBotIdScheduleByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Schedule ID */ @@ -2796,10 +1606,6 @@ export type PutBotsByBotIdScheduleByIdData = { */ body: ScheduleUpdateRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Schedule ID */ @@ -2833,12 +1639,7 @@ export type PutBotsByBotIdScheduleByIdResponse = PutBotsByBotIdScheduleByIdRespo export type DeleteBotsByBotIdSettingsData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2865,12 +1666,7 @@ export type DeleteBotsByBotIdSettingsResponses = { export type GetBotsByBotIdSettingsData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2902,12 +1698,7 @@ export type PostBotsByBotIdSettingsData = { * Settings payload */ body: SettingsUpsertRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2939,12 +1730,7 @@ export type PutBotsByBotIdSettingsData = { * Settings payload */ body: SettingsUpsertRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/settings'; }; @@ -2973,12 +1759,7 @@ export type PutBotsByBotIdSettingsResponse = PutBotsByBotIdSettingsResponses[key export type GetBotsByBotIdSubagentsData = { body?: never; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/subagents'; }; @@ -3010,12 +1791,7 @@ export type PostBotsByBotIdSubagentsData = { * Subagent payload */ body: SubagentCreateRequest; - path: { - /** - * Bot ID - */ - bot_id: string; - }; + path?: never; query?: never; url: '/bots/{bot_id}/subagents'; }; @@ -3045,10 +1821,6 @@ export type PostBotsByBotIdSubagentsResponse = PostBotsByBotIdSubagentsResponses export type DeleteBotsByBotIdSubagentsByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3085,10 +1857,6 @@ export type DeleteBotsByBotIdSubagentsByIdResponses = { export type GetBotsByBotIdSubagentsByIdData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3130,10 +1898,6 @@ export type PutBotsByBotIdSubagentsByIdData = { */ body: SubagentUpdateRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3172,10 +1936,6 @@ export type PutBotsByBotIdSubagentsByIdResponse = PutBotsByBotIdSubagentsByIdRes export type GetBotsByBotIdSubagentsByIdContextData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3217,10 +1977,6 @@ export type PutBotsByBotIdSubagentsByIdContextData = { */ body: SubagentUpdateContextRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3259,10 +2015,6 @@ export type PutBotsByBotIdSubagentsByIdContextResponse = PutBotsByBotIdSubagents export type GetBotsByBotIdSubagentsByIdSkillsData = { body?: never; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3304,10 +2056,6 @@ export type PostBotsByBotIdSubagentsByIdSkillsData = { */ body: SubagentAddSkillsRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3349,10 +2097,6 @@ export type PutBotsByBotIdSubagentsByIdSkillsData = { */ body: SubagentUpdateSkillsRequest; path: { - /** - * Bot ID - */ - bot_id: string; /** * Subagent ID */ @@ -3388,6 +2132,51 @@ export type PutBotsByBotIdSubagentsByIdSkillsResponses = { export type PutBotsByBotIdSubagentsByIdSkillsResponse = PutBotsByBotIdSubagentsByIdSkillsResponses[keyof PutBotsByBotIdSubagentsByIdSkillsResponses]; +export type PostBotsByBotIdToolsData = { + /** + * JSON-RPC request + */ + body: { + [key: string]: unknown; + }; + path: { + /** + * Bot ID + */ + bot_id: string; + }; + query?: never; + url: '/bots/{bot_id}/tools'; +}; + +export type PostBotsByBotIdToolsErrors = { + /** + * Bad Request + */ + 400: HandlersErrorResponse; + /** + * Not Found + */ + 404: HandlersErrorResponse; + /** + * Internal Server Error + */ + 500: HandlersErrorResponse; +}; + +export type PostBotsByBotIdToolsError = PostBotsByBotIdToolsErrors[keyof PostBotsByBotIdToolsErrors]; + +export type PostBotsByBotIdToolsResponses = { + /** + * JSON-RPC response: {jsonrpc,id,result|error} + */ + 200: { + [key: string]: unknown; + }; +}; + +export type PostBotsByBotIdToolsResponse = PostBotsByBotIdToolsResponses[keyof PostBotsByBotIdToolsResponses]; + export type DeleteBotsByIdData = { body?: never; path: { @@ -3423,11 +2212,15 @@ export type DeleteBotsByIdError = DeleteBotsByIdErrors[keyof DeleteBotsByIdError export type DeleteBotsByIdResponses = { /** - * No Content + * Accepted */ - 204: unknown; + 202: { + [key: string]: string; + }; }; +export type DeleteBotsByIdResponse = DeleteBotsByIdResponses[keyof DeleteBotsByIdResponses]; + export type GetBotsByIdData = { body?: never; path: { @@ -3661,7 +2454,7 @@ export type PostBotsByIdChannelByPlatformSendResponses = { export type PostBotsByIdChannelByPlatformSendResponse = PostBotsByIdChannelByPlatformSendResponses[keyof PostBotsByIdChannelByPlatformSendResponses]; -export type PostBotsByIdChannelByPlatformSendSessionData = { +export type PostBotsByIdChannelByPlatformSendChatData = { /** * Send payload */ @@ -3677,10 +2470,10 @@ export type PostBotsByIdChannelByPlatformSendSessionData = { platform: string; }; query?: never; - url: '/bots/{id}/channel/{platform}/send_session'; + url: '/bots/{id}/channel/{platform}/send_chat'; }; -export type PostBotsByIdChannelByPlatformSendSessionErrors = { +export type PostBotsByIdChannelByPlatformSendChatErrors = { /** * Bad Request */ @@ -3699,9 +2492,9 @@ export type PostBotsByIdChannelByPlatformSendSessionErrors = { 500: HandlersErrorResponse; }; -export type PostBotsByIdChannelByPlatformSendSessionError = PostBotsByIdChannelByPlatformSendSessionErrors[keyof PostBotsByIdChannelByPlatformSendSessionErrors]; +export type PostBotsByIdChannelByPlatformSendChatError = PostBotsByIdChannelByPlatformSendChatErrors[keyof PostBotsByIdChannelByPlatformSendChatErrors]; -export type PostBotsByIdChannelByPlatformSendSessionResponses = { +export type PostBotsByIdChannelByPlatformSendChatResponses = { /** * OK */ @@ -3710,7 +2503,49 @@ export type PostBotsByIdChannelByPlatformSendSessionResponses = { }; }; -export type PostBotsByIdChannelByPlatformSendSessionResponse = PostBotsByIdChannelByPlatformSendSessionResponses[keyof PostBotsByIdChannelByPlatformSendSessionResponses]; +export type PostBotsByIdChannelByPlatformSendChatResponse = PostBotsByIdChannelByPlatformSendChatResponses[keyof PostBotsByIdChannelByPlatformSendChatResponses]; + +export type GetBotsByIdChecksData = { + body?: never; + path: { + /** + * Bot ID + */ + id: string; + }; + query?: never; + url: '/bots/{id}/checks'; +}; + +export type GetBotsByIdChecksErrors = { + /** + * Bad Request + */ + 400: HandlersErrorResponse; + /** + * Forbidden + */ + 403: HandlersErrorResponse; + /** + * Not Found + */ + 404: HandlersErrorResponse; + /** + * Internal Server Error + */ + 500: HandlersErrorResponse; +}; + +export type GetBotsByIdChecksError = GetBotsByIdChecksErrors[keyof GetBotsByIdChecksErrors]; + +export type GetBotsByIdChecksResponses = { + /** + * OK + */ + 200: BotsListChecksResponse; +}; + +export type GetBotsByIdChecksResponse = GetBotsByIdChecksResponses[keyof GetBotsByIdChecksResponses]; export type GetBotsByIdMembersData = { body?: never; @@ -4677,7 +3512,7 @@ export type GetUsersResponses = { /** * OK */ - 200: UsersListUsersResponse; + 200: AccountsListAccountsResponse; }; export type GetUsersResponse = GetUsersResponses[keyof GetUsersResponses]; @@ -4686,7 +3521,7 @@ export type PostUsersData = { /** * User payload */ - body: UsersCreateUserRequest; + body: AccountsCreateAccountRequest; path?: never; query?: never; url: '/users'; @@ -4713,7 +3548,7 @@ export type PostUsersResponses = { /** * Created */ - 201: UsersUser; + 201: AccountsAccount; }; export type PostUsersResponse = PostUsersResponses[keyof PostUsersResponses]; @@ -4742,7 +3577,7 @@ export type GetUsersMeResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type GetUsersMeResponse = GetUsersMeResponses[keyof GetUsersMeResponses]; @@ -4751,7 +3586,7 @@ export type PutUsersMeData = { /** * Profile payload */ - body: UsersUpdateProfileRequest; + body: AccountsUpdateProfileRequest; path?: never; query?: never; url: '/users/me'; @@ -4774,7 +3609,7 @@ export type PutUsersMeResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type PutUsersMeResponse = PutUsersMeResponses[keyof PutUsersMeResponses]; @@ -4812,7 +3647,7 @@ export type GetUsersMeChannelsByPlatformResponses = { /** * OK */ - 200: ChannelChannelUserBinding; + 200: ChannelChannelIdentityBinding; }; export type GetUsersMeChannelsByPlatformResponse = GetUsersMeChannelsByPlatformResponses[keyof GetUsersMeChannelsByPlatformResponses]; @@ -4821,7 +3656,7 @@ export type PutUsersMeChannelsByPlatformData = { /** * Channel user config payload */ - body: ChannelUpsertUserConfigRequest; + body: ChannelUpsertChannelIdentityConfigRequest; path: { /** * Channel platform @@ -4849,16 +3684,49 @@ export type PutUsersMeChannelsByPlatformResponses = { /** * OK */ - 200: ChannelChannelUserBinding; + 200: ChannelChannelIdentityBinding; }; export type PutUsersMeChannelsByPlatformResponse = PutUsersMeChannelsByPlatformResponses[keyof PutUsersMeChannelsByPlatformResponses]; +export type GetUsersMeIdentitiesData = { + body?: never; + path?: never; + query?: never; + url: '/users/me/identities'; +}; + +export type GetUsersMeIdentitiesErrors = { + /** + * Bad Request + */ + 400: HandlersErrorResponse; + /** + * Not Found + */ + 404: HandlersErrorResponse; + /** + * Internal Server Error + */ + 500: HandlersErrorResponse; +}; + +export type GetUsersMeIdentitiesError = GetUsersMeIdentitiesErrors[keyof GetUsersMeIdentitiesErrors]; + +export type GetUsersMeIdentitiesResponses = { + /** + * OK + */ + 200: HandlersListMyIdentitiesResponse; +}; + +export type GetUsersMeIdentitiesResponse = GetUsersMeIdentitiesResponses[keyof GetUsersMeIdentitiesResponses]; + export type PutUsersMePasswordData = { /** * Password payload */ - body: UsersUpdatePasswordRequest; + body: AccountsUpdatePasswordRequest; path?: never; query?: never; url: '/users/me/password'; @@ -4921,7 +3789,7 @@ export type GetUsersByIdResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type GetUsersByIdResponse = GetUsersByIdResponses[keyof GetUsersByIdResponses]; @@ -4930,7 +3798,7 @@ export type PutUsersByIdData = { /** * User update payload */ - body: UsersUpdateUserRequest; + body: AccountsUpdateAccountRequest; path: { /** * User ID @@ -4966,7 +3834,7 @@ export type PutUsersByIdResponses = { /** * OK */ - 200: UsersUser; + 200: AccountsAccount; }; export type PutUsersByIdResponse = PutUsersByIdResponses[keyof PutUsersByIdResponses]; @@ -4975,7 +3843,7 @@ export type PutUsersByIdPasswordData = { /** * Password payload */ - body: UsersResetPasswordRequest; + body: AccountsResetPasswordRequest; path: { /** * User ID diff --git a/packages/shared/README.md b/packages/shared/README.md new file mode 100644 index 00000000..6554adbe --- /dev/null +++ b/packages/shared/README.md @@ -0,0 +1 @@ +# @memoh/shared diff --git a/packages/shared/package.json b/packages/shared/package.json new file mode 100644 index 00000000..6e841596 --- /dev/null +++ b/packages/shared/package.json @@ -0,0 +1,13 @@ +{ + "name": "@memoh/shared", + "version": "1.0.0", + "description": "", + "exports": { + ".": "./src/index.ts" + }, + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "packageManager": "pnpm@10.27.0" +} diff --git a/packages/shared/src/chatInfo.ts b/packages/shared/src/chatInfo.ts new file mode 100644 index 00000000..3a11d13b --- /dev/null +++ b/packages/shared/src/chatInfo.ts @@ -0,0 +1,15 @@ +export interface robot{ + description: string + time: Date, + id: string | number, + type: string, + action: 'robot', + state:'thinking'|'generate'|'complete' +} + +export interface user{ + description: string, + time: Date, + id: number | string, + action:'user' +} \ No newline at end of file diff --git a/packages/shared/src/index.ts b/packages/shared/src/index.ts new file mode 100644 index 00000000..f23cc870 --- /dev/null +++ b/packages/shared/src/index.ts @@ -0,0 +1,5 @@ +export * from './model' +export * from './schedule' +export * from './platform' +export * from './mcp' +export * from './chatInfo' diff --git a/packages/shared/src/mcp.ts b/packages/shared/src/mcp.ts new file mode 100644 index 00000000..da2993b1 --- /dev/null +++ b/packages/shared/src/mcp.ts @@ -0,0 +1,48 @@ +export interface BaseMCPConnection { + type: string + name: string +} + +export interface StdioMCPConnection extends BaseMCPConnection { + type: 'stdio' + command: string + args: string[] + env: Record + cwd: string +} + +export interface BaseHTTPMCPConnection extends BaseMCPConnection { + url: string + headers: Record +} + +export interface HTTPMCPConnection extends BaseHTTPMCPConnection { + type: 'http' +} + +export interface SSEMCPConnection extends BaseHTTPMCPConnection { + type: 'sse' +} + +export type MCPConnection = + | StdioMCPConnection + | HTTPMCPConnection + | SSEMCPConnection + + +export interface MCPListItem{ + id: string; + type: string; + name: string; + config: { + cwd: string; + env: Record; + args: string[]; + type: string; + command: string; + }; + active: boolean; + user: string; + createdAt: string; + updatedAt: string; +} \ No newline at end of file diff --git a/packages/shared/src/model.ts b/packages/shared/src/model.ts new file mode 100644 index 00000000..b48c52e0 --- /dev/null +++ b/packages/shared/src/model.ts @@ -0,0 +1,101 @@ +export enum ModelClientType { + OPENAI = 'openai', + ANTHROPIC = 'anthropic', + GOOGLE = 'google', +} + +export enum ModelType { + CHAT = 'chat', + EMBEDDING = 'embedding', +} + +export interface BaseModel { + /** + * @description The unique identifier for the model + * @example 'gpt-4o' + */ + modelId: string + + /** + * @description The base URL for the model + * @example 'https://api.openai.com/v1' + */ + baseUrl: string + + /** + * @description The API key for the model + * @example 'sk-1234567890' + */ + apiKey: string + + /** + * @description The client type for the model + * @enum {ModelClientType} + */ + clientType: ModelClientType + + /** + * @description The display name for the model + * @example 'GPT 4o' + */ + name?: string + + /** + * @description The model type + * @enum {ModelType} + * @default {ModelType.CHAT} + */ + type?: ModelType +} + +export interface EmbeddingModel extends BaseModel { + type?: ModelType.EMBEDDING + + /** + * @description The dimensions of the embedding + * @example 1536 + */ + dimensions: number +} + +export interface ChatModel extends BaseModel { + type?: ModelType.CHAT +} + +export type Model = EmbeddingModel | ChatModel + + +/** Model row type for list/table views. */ +export interface ModelList { + apiKey: string, + baseUrl: string, + clientType: 'OpenAI' | 'Anthropic' | 'Google', + modelId: string, + name: string, + type: 'chat' | 'embedding', + id: string, + defaultChatModel: boolean, + defaultEmbeddingModel: boolean, + defaultSummaryModel: boolean +} + +export interface ProviderInfo{ + api_key: string; + base_url: string; + client_type: string; + metadata: Record<'additionalProp1',object>; + name: string; +} + +export interface ModelInfo{ + dimensions:number + is_multimodal:boolean + input?: string[] + llm_provider_id:string + model_id:string + name:string + type: string + enable_as?:string +} + +export const clientType = ['openai', 'anthropic', 'google', 'ollama'] as const diff --git a/packages/shared/src/platform.ts b/packages/shared/src/platform.ts new file mode 100644 index 00000000..bc371c73 --- /dev/null +++ b/packages/shared/src/platform.ts @@ -0,0 +1,7 @@ +export interface Platform { + id: string + name: string + // endpoint: string + config: Record + active: boolean +} \ No newline at end of file diff --git a/packages/shared/src/schedule.ts b/packages/shared/src/schedule.ts new file mode 100644 index 00000000..e08ca076 --- /dev/null +++ b/packages/shared/src/schedule.ts @@ -0,0 +1,8 @@ +export interface Schedule { + id?: string + pattern: string + name: string + description: string + command: string + maxCalls?: number | null +} \ No newline at end of file diff --git a/packages/web/mise.toml b/packages/web/mise.toml index 78461797..82fe917e 100644 --- a/packages/web/mise.toml +++ b/packages/web/mise.toml @@ -2,23 +2,13 @@ alias = "dev" description = "Start web development server" run = "pnpm dev" -depends = [ - "//:pnpm-install", - "//:sdk-generate", -] +depends = ["//:pnpm-install"] [tasks.build] description = "Build web" run = "pnpm build" -depends = [ - "//:pnpm-install", - "//:sdk-generate", -] +depends = ["//:pnpm-install"] [tasks.start] description = "Start web" -run = "pnpm start" -depends = [ - "//:pnpm-install", - "//:sdk-generate", -] \ No newline at end of file +run = "pnpm start" \ No newline at end of file diff --git a/packages/web/package.json b/packages/web/package.json index 626e4107..58a831d4 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -9,8 +9,14 @@ "start": "vite preview" }, "dependencies": { - "@memoh/ui": "workspace:*", + "@fortawesome/fontawesome-svg-core": "^7.0.0", + "@fortawesome/free-brands-svg-icons": "^7.0.0", + "@fortawesome/free-regular-svg-icons": "^7.0.0", + "@fortawesome/free-solid-svg-icons": "^7.0.0", + "@fortawesome/vue-fontawesome": "^3.1.1", "@memoh/sdk": "workspace:*", + "@memoh/shared": "workspace:*", + "@memoh/ui": "workspace:*", "@pinia/colada": "^0.21.1", "@tailwindcss/vite": "^4.1.18", "@tanstack/vue-table": "^8.21.3", @@ -33,12 +39,7 @@ "vue-i18n": "^11.2.8", "vue-router": "^4.6.4", "vue-sonner": "^2.0.9", - "zod": "^4.3.5", - "@fortawesome/fontawesome-svg-core": "^7.0.0", - "@fortawesome/free-brands-svg-icons": "^7.0.0", - "@fortawesome/free-regular-svg-icons": "^7.0.0", - "@fortawesome/free-solid-svg-icons": "^7.0.0", - "@fortawesome/vue-fontawesome": "^3.1.1" + "zod": "^4.3.5" }, "devDependencies": { "@types/node": "^24.10.1", diff --git a/packages/web/public/channels/feishu.png b/packages/web/public/channels/feishu.png new file mode 100644 index 0000000000000000000000000000000000000000..73a55287721ed08c3cde07feb2dde6faa6eb715f GIT binary patch literal 669 zcmV;O0%HA%P)e~Pzk^tPgOn(0rE|?Qz z{|PskYBt)JPnX40krbS$+zs~;` zdZmk}|DLx0aC*4W-2Zcz|H8-M?(XsK)N50MKokaG7F-0@09Q0Gm6kUui%P;Y?e_nFREwgqdq8&VEANb7cJ`U`p4CxN z%9QzkT3Q9>Tt>FFswSpdlaZZnhU`hzy?|1af=UfYxw49K24(A1c_{}oWJgkpdseR` zA$Z$v&b>Z3sc_D5T7&DG;czsb@Yklk2s@m+y|aS4#ZtHLN9)uzoz6trJCt=E9)pIu zkM+bsjDg%gkn@~k403Df7-3gjbhoDS5~P?Kj1>VQf;=9!o;2!V^0bo>q|Btz*(Qb0mx!*nnKYp=&~R(O~l#nr6lBRbA$B^mj}1j zj6Wv#{mQ^3Ij#`WX|_{5dW;d0^)9BW9lxvVwA+wqm}Qv{_JNF9$l?WnUnGhwJ;!6yIjZR00000NkvXXu0mjf D0N72W literal 0 HcmV?d00001 diff --git a/packages/web/public/channels/telegram.webp b/packages/web/public/channels/telegram.webp new file mode 100644 index 0000000000000000000000000000000000000000..3afc19e3868ffdc9f365680a711948d0db305dab GIT binary patch literal 6724 zcmV-K8oT9ENk&FI8UO%SMM6+kP&iC48UO$ zJG@-?OFYe)44I9XK66Q#wR?6^mO{pbYmBqQC{P7hO!A)SM22)H=maO+vda$tv|aoE z6x&i>{XWB>cg@))UDaK-$0fraVq#u_H?VWgKE1oE{#D^W+Gy-@$5PxKPKcinjVv13 z3)d6j?s`+4J-DP0y1SE&`(#8O@dEDd?v}ZZ&*1J(#Nq*n=A^i5;}8;!?ZrF*o|6%| zEZUK6tG1g`k~a)7UfsJUuJfK)X82`Df*paQa1?X|KyoB0*gOCe4@<+zsf>(_$f#8a zbHBhh?&}n=Z9Cp($6wpFZQHhO+il9(wrwl;SKA!Y)Y)o#jlX9C=vv!<+IH^$HZsm$ z+P3%6wr%fY+lcmBbN&C{%pIKAdyExTu2`yDb+G%`))mZ*(}kB{+nljG8&_deuF~7~ z46eZFY;^pMe-WE<_%@5QQZU2S8mZM&UWn-By5 zEE{XO+U{ICwUrdt{|P{d=C}U{+PS(;*P#suw`VgfZk3yR$uUnIgR#8bju~(W!6DN} za>#~5HXX9(5a>JP-^LraXNgB@PsUH-aC;qsbI9{He8utBxJi~p!mXa2=D%9sp5H_< zaW29STz6?>ecO{f$GvQMKRaaDAzS@QI%MA=zug4wkZV>ca=Jn7iJ{~hUxHA)Qad;X z)y6(^2;MRDejOdMX!%s{uvgvOxb|4vp5B>Lfz_3>x>4=f>yGJi$Zo&pjyZ75faASy z^^{<*R*yV;dS~8;)lGY?V`lo^bjZ$o9QMe3mvNyIN2IkSjW%59Cbfs?d)Fak%iHRZ zins6n&UgPs*38k}cTYQ{v ztt0gVy>WZ??))#M6BScOf8sg%A$QEc1~ai>C7uo1<8gD~ci`>@PA%OI8km|CX8cpahjdaJy#3}ZvMTIRdennCK@L(9-MgssJ*m0lH~v@3~a z1FK!#5WH47SZAT9hZjVKak0#=4MA*`c{h%0!1I60U^fJ_MLKn^DQsqc`#-Dk6UuVQ z8v^=|mbv?Dz_cfZmOHc|ur0G-v4K#TS;r!DivYYW@@{?}F_5fakv~04mM zYs)T#*hjDq`_XKOkjouEU({h^%{worw+uPZ`VT4*{x_y>LkgTF* zvBY7NTI|MW17zmO+QMtYNd2T8M+N8a7E2mNtHmNldn7RKx`&T5j9E8(Z%#lOt$W2V z5V>2;4aO-9^U!c1dX2WpN<4v>qm^fa2)1yWGebr8+kArwc%J@@(co5}>_L5d23>S&BfuXBC^jX6t4J4{hJ&!boersrmd_%Xt#~(&2 z=|{hTuX@AGIvSc5e|D?Qg#bE#=z zQaRx@vJ#KxPE007jorFzeKhy+WO6i6<5XSrtmkBcG^31~EUmdSDME91B;RD}VK7%U z4N1~qFtt#zxu)r9qNkXU8d&|zrs?0kM5*begI3Q#hkMCVWA9h)|C%PJ3r!<-Y+p4k z<9Z46>~pGRD3@N+)O1><=A1okXnJ)4NuWv^%I_skO<$~#)k8>W*Hee4ssFEzhHv(g zr{N4sKsC3osfLzxG|fmMJwtm{L&KkYNmSE=O%)z^?ljUf zR6rFpgw+%zQceAu>Yko4)fAFShGrPnJtfyF8Iq~8N2j_~D0)gp(#gt-_Uo3ipRb%y z3#YY9&rmWYrL6pl?i9(nm6Pg=yYg!dZKV|@)Ahy5-B5Pn%5mzBp$JWzdLI_%zYSVu zlY!Q#tLJ5`SSRox5t@~Y?3FIeQ8F!fKp$GoMm0fY^-EBvbzI8=QV8r zZY&zQU(%G4iIFf*j2$!`s=cllS=FDvV0pUa8(%qvqoM%*ZYYX(r>#CjVldGcKZI3LeWmc52rG&TzEElR!;8dZcH9B;mM@f{$+Wk? z-VH;Cw)zx_)x3;_M$yO1P*elxHY_QM{y=5u(YDcunHCph51kSBSHLjn6vtSH()ENbf##(LIyw^;5*V>G<#8q>tYoI_SCyzd;;o*BS@3tc$FMaaZDyb`_wIe64JZvnDaPjyxh(Q zi~+tP&3W@znZUMqlR*{M#qkv99aw5c?Xso_ZfWfas<3*-E2KG~Jax;T1UNJ5f+{5H z%vefNn1txat2KX92y?6}12tLL_G2kc6hJ&_+o&55>M&)CDvvA2%A*_Lc zLseLqu?o2h79;m*{V5Rgo+nHXT&N0Zag3&rb|5jzs9iP`VQ>4(+hJ#V78~`qNjN@r zORoz?=o@wnTng#Jhp$Bf1%!so9aqBE7eBE$3@POAh<;>U& z(21+}ImAM(*9&^Ev@fkbd!8}kx80N*ap1}uc6DZM9vFDx@kLamU%>R>xqRb+MjZFL zVzZl2G;rlfp@^7m*wK+GP9Ea*gXK={srUZ)#Pg4XW&OZr*3nlsc5+h|vfm*V;^o|e zF#Gc6+>F)zL=eyWT;iV}K^+6ntA<}_Y`2VIs*KiWo zt-pjo$#A~8LGq#7fh2~*H4U=d8b&NMBjGV&D9)2yLVWkUV8`j<>wI(lMU#VA_GE3y za(f%`MWONLuE*_8oS)_^n9n!chqo+1EWfxcvd2(K;4+U1stb63Hs8a<`9n#;aM|PtKy`VC}9Hd(q6_VohSZtntvG7c;)O$1xl= z>0$3~;{L^yyWmdEd>`)8amf(lnGpwDfl16gp`i#A55<+Yjrec=A`h!s`0`W4db9&z zD=?)}$_RoUy6T&Y(+J>;pDwFe=*yowt!QHYO>DVzCg)TKz@4^#dJ7>u?Pbl<0MR%o zvF}PCY`Jv5059hn@O_}TFRr|%AcQ?mc=BDK15ee`JsrgU-75B6{G4l^`EAzv=DgX9 zLKs`TNxu9^wKz!H?@k=P1Fm4V$)NyTu30X~5lkoL- zw&UokfKhf}BTo0566x$Fb?s*? zVRh!^z6K0Q0()_n?TPz1Z;VfRrfXIUyqb;`!rC`XQ%geo3C10W{}5FRPkIpPA=Iq4 zp(Lozg2Km$K}qN!S?3^tuH+6+dPveUQnTF556YV1+FBC)i(v?0r!_6Xm!2Kz>8V*k z@q?tW^fnSm61=Xl=yApgVdIg8PkL6Q=l`0;`(qrRkI+|=@X?OcMi9rs_fcWJsF{g? z3!pgUc_iTzf{~Lju7gjS9XWZWW>ToSc@e?^cm4A(Fx|zpCy!7DD?5DB^F?%O>LvwB z`bz-p_P-xs@6+Xflasaxriq03q~@VCvuY>lpffQTIygJWZhre2rn~3@ore?59VGlx zb5_)0^^+tWaz1!2Pe*X?4or7Z2cKol2jOfq&_qCH%B1D&14)v6g;RRiZ9ic-**(^q zOn1|OqDFmumgYuT?M&G4wF7hoQjrAlVATj~=+@jtX zps!#`eei0Wm2yPdfn?R|mL{md@>T?-X^PUb=MzbiJeccm{&IrKOgxihRZsT%KVe-& z1PV$7rTJT$;9VmD>s!H*yR*=K%GNO!cI*0Z$Bqkzu-;JF9t@mtEh@534Qw|=0wa}eGjS-U@ektsqHY9sFrs1e=$ zMSX;}jIyV(uDG5h|H6wiyMV|3EN<9sM-X1;b0Q!;frye{_@sygf$3mH`v(DDbAJ(( z9bH6};tVjZ+MN(L?CByPMT*tWY=|huudwi9&pQz0?4%_EQXD}o}$m@55ncSXl{M1ccQvmhdDB6TtWstSD!y;TPeN zvG`5Y;BX7907)kT!gL+05#|aQRzdZ!9)nczgQAY%_Dq_5OK7C^sCJ78$TG+)@&k969T_`&q z3bPdsT5jeG$wx5alWHP%C^D6R&=P{-!+7?q1s(XL`gSKgV+aXtqHYHQPRkEb!!I;` zoixH+fQ07Bnw!2TXbOuje8Qx(UKejCDAkWap)H#pc5GM@LjgV^He$!p42OlL`3d`0 zp8}Lnb&XHBOFPS7Kv0-8U}!GTF_jI7E4AXwB%RJsgq_gPv_M(g_LqavX-RAFg*n$u zK!_-KXgUOC#%&vjhCLmhaNoiSW=lokMuZyx8l-62PB;vi30sAq^IRu zd>BEbBZVi@OQ)ESG>So@s=_Trg{SA7d}qjdY-Xb$RVTO z@5C?*yMQ9gVcM$0g5j*ymms6h!nvk|%v}H$Sx>^x`6{Zs;0g6bERpX@XCVVDvK+1} zZ{aEF&WD$!&C5Vx)&fR0Qd>#w!BLpLLYIYvrW789Kt`6s4do8pxH49Be5VLYo&D_x z)~s9tjclTCWi6AFL1pk+NFhgymV?554mPq}*9z|DqVZ5gmwDb`$*imVGT_KNR6-WcgbMdr|HbTo%sE;y0$!ggP2NvIl!^JW`qSmv-Ppx)MvL5=UmP zWG28@j`7AHuw`>W(G$=!lOm3mEDwcnICBGp-Pv=afyz{yUtYn^!$;V~kq<9InR_2X zic;4Cy6h;13QI3=BBkFA%+$AGq`1{bd z$!X{AVCMb;CFQ!s*tokN@PL-3e>7bZDO)fpN~AMAhJ|ev=*wWYGj|W16nECV3Ilb_ z09xk$yJgn|;Q=T``t>1eJ)}c-Xo5(??hBz6P>SLs8YZw(X~R^IXYM~%KeJUJ6P-Y% zxWF(oX+3O}AY}tBODA9zGj|iHp7NAHdDCvfTA}&7q#) z5UeUUUxBN-+FLj*=Co+PKvi3foQGXaWhB4I@Ngeq)z$v8u0dCS*u1VgGE84qBj+K5 zhNf6u9vtCm0J$4$PJZ%UuFTo6ueo!kD(F0Xh!wUEADle>2V_-MRUX_a(wSr=?Be>P zK~Hbc)rxD6qD?Dq3A0ob)I!Na6jr+7?e>(B8PjvCh0o$?M#A8pK&(1% zUrQI6fWgXVZ=bn&dL~uy^v;c|huq_tDzbD2){%!V(*#d<&*C?xS`_>s!qclMf~U!n zt%dbuVO)yf?Hc4Y6ENw4r-$EAn-}u*QflC75)E~+zO?e-X(Tc5APeiv${T5c2e_tI zb2HaZ9za6m!AkM+t2ov>)6wApDIM?tVcXtVcODd_18!Ww+BZCSp9*-;v#gEoJm6xrNeYVveOrWNC$d%~&?Si#L`=Xc&dQr8}W4&4@PIE=B*c zc15hQ;6D18hBxDkrE&N2(!E&vQ;cz=^FZiltDo6efMEedH_sWlC;IwbENG4`7L3n( zL>&E{z5Sh=1?_Rgf~nT)l0u))es%d-Fc?!Tn9ugJBKkek5o0YvJh5Q8RjZ@#ANJ2! zV!$V66|a!-BpM zz6fDB@5OT-#tnC2uYp13LOu%^5wrEZFtBtH$^ufXumGWpBN*NruwXS#SU_g(NDQ&y zL*?}`!h&_`vBfYyjODbm6htRW{r)b7de-m5LKkOf0g*|IFMq~>lOmD^`!PWaIt@?_ z2A%3?&;=nLSZa+FE*Scy$>LKi&|+X@S02MZMHiK{pvC|#5Zb=-1_a@=XfX~yTC6Z1 zZ-nBz$nq-mXz^WKK1d=UX)1^o>yV=biKg-!1m(Q1BtwfvxY1yk9Ib`WY*-LS7aK5J z{jd8%fNo0{A0S4Hjw2&b>B5(#=AcE3on@yGtjoQO77S3)Hij%^5wzrJd%75b6D_8> zD-fYO^>bP*!{~bZnTr74$(Q1xdydc+nYIADw1v=~XH=3vAk$Elw%$R9ZuRmUO9b?G zwAD>h09cyBlXOUfu-=Zj`k_KoM3$3gPw&hK@9k)65F#}7_x^iDkWZ1VaIzSM#%akS zeMI_Y7g^#8vXQBD3i;=h- z%gLj`&#+O~G}Ect7GshH#w55)iak<5ph-a?i+6WA1;%FUb@8?i#F3&tM5Guy=)*8h zM{cA;g5#~_cvwg=M_QMNVU98T+1BF7qeDT8wbp@W+f8Bg(i~H=I8TZR|DA^vw}{(* z3M2VuWRdoB(%L00@HT1DNion`9AWul++bACZIx~}CM|-GNQ+4GxJ{1Z|9gz`skPNZ zDQX=E{Lhw{H1ij~3Xa^4VXNgwibN8W9E9dPxIAgKeT%e~NwG; diff --git a/packages/web/src/components/Sidebar/index.vue b/packages/web/src/components/Sidebar/index.vue new file mode 100644 index 00000000..09b6bee3 --- /dev/null +++ b/packages/web/src/components/Sidebar/index.vue @@ -0,0 +1,124 @@ + + + diff --git a/packages/web/src/components/Sidebar/lists/chat-list-menu.vue b/packages/web/src/components/Sidebar/lists/chat-list-menu.vue new file mode 100644 index 00000000..5882de83 --- /dev/null +++ b/packages/web/src/components/Sidebar/lists/chat-list-menu.vue @@ -0,0 +1,114 @@ + + + + diff --git a/packages/web/src/components/Sidebar/lists/settings-list-menu.vue b/packages/web/src/components/Sidebar/lists/settings-list-menu.vue new file mode 100644 index 00000000..7da46acb --- /dev/null +++ b/packages/web/src/components/Sidebar/lists/settings-list-menu.vue @@ -0,0 +1,87 @@ + + + + diff --git a/packages/web/src/components/Sidebar/lists/types.ts b/packages/web/src/components/Sidebar/lists/types.ts new file mode 100644 index 00000000..7612130b --- /dev/null +++ b/packages/web/src/components/Sidebar/lists/types.ts @@ -0,0 +1,4 @@ +export interface SidebarListProps { + collapsible?: boolean +} + diff --git a/packages/web/src/components/add-provider/index.vue b/packages/web/src/components/add-provider/index.vue index 3997c6dd..e198257c 100644 --- a/packages/web/src/components/add-provider/index.vue +++ b/packages/web/src/components/add-provider/index.vue @@ -89,7 +89,7 @@ @@ -152,22 +152,12 @@ import { import { toTypedSchema } from '@vee-validate/zod' import z from 'zod' import { useForm } from 'vee-validate' -import { useMutation, useQueryCache } from '@pinia/colada' -import { postProviders } from '@memoh/sdk' -import type { ProvidersClientType } from '@memoh/sdk' - -const CLIENT_TYPES: ProvidersClientType[] = ['openai', 'openai-compat', 'anthropic', 'google', 'ollama'] +import { clientType } from '@memoh/shared' +import { useCreateProvider } from '@/composables/api/useProviders' const open = defineModel('open') -const queryCache = useQueryCache() -const { mutate: providerFetch, isLoading } = useMutation({ - mutation: async (data: Record) => { - const { data: result } = await postProviders({ body: data as any, throwOnError: true }) - return result - }, - onSettled: () => queryCache.invalidateQueries({ key: ['providers'] }), -}) +const { mutate: providerFetch, isLoading } = useCreateProvider() const providerSchema = toTypedSchema(z.object({ api_key: z.string().min(1), diff --git a/packages/web/src/components/chat-list/channel-badge/index.vue b/packages/web/src/components/chat-list/channel-badge/index.vue new file mode 100644 index 00000000..6ee24eab --- /dev/null +++ b/packages/web/src/components/chat-list/channel-badge/index.vue @@ -0,0 +1,46 @@ + + + diff --git a/packages/web/src/components/chat-list/index.vue b/packages/web/src/components/chat-list/index.vue index 03a4ba82..8b0e2365 100644 --- a/packages/web/src/components/chat-list/index.vue +++ b/packages/web/src/components/chat-list/index.vue @@ -1,19 +1,19 @@