From 06e8619a37d6918542a3e6b9daa7b61bf1dcc2bd Mon Sep 17 00:00:00 2001 From: BBQ Date: Wed, 11 Feb 2026 14:47:03 +0800 Subject: [PATCH] refactor(core): migrate channel identity and binding across app Align channel identity and bind flow across backend and app-facing layers, including generated swagger artifacts and package lock updates while excluding docs content changes. --- agent/src/agent.ts | 13 +- agent/src/config.ts | 2 +- agent/src/index.ts | 19 +- agent/src/models.ts | 6 +- agent/src/prompts/user.ts | 10 +- agent/src/tools/contact.ts | 129 +- agent/src/tools/index.ts | 9 +- agent/src/tools/mcp.ts | 29 +- agent/src/tools/memory.ts | 14 +- agent/src/tools/message.ts | 34 +- agent/src/tools/subagent.ts | 33 +- agent/src/types/agent.ts | 8 +- cmd/agent/main.go | 64 +- db/migrations/0001_init.down.sql | 24 +- db/migrations/0001_init.up.sql | 236 +-- db/queries/bind.sql | 22 + db/queries/bots.sql | 12 +- db/queries/channel_identities.sql | 49 + db/queries/channels.sql | 47 +- db/queries/chats.sql | 214 +++ db/queries/contacts.sql | 76 - db/queries/history.sql | 31 - db/queries/settings.sql | 93 +- db/queries/users.sql | 92 +- docs/docs.go | 1319 +++-------------- docs/swagger.json | 1319 +++-------------- docs/swagger.yaml | 894 ++--------- internal/{users => accounts}/service.go | 196 ++- internal/{users => accounts}/types.go | 19 +- internal/auth/jwt.go | 108 +- internal/bind/service.go | 262 ++++ internal/bind/service_integration_test.go | 562 +++++++ .../bind/service_link_integration_test.go | 168 +++ internal/bind/types.go | 26 + internal/bots/service.go | 103 +- internal/bots/types.go | 8 + internal/channel/adapters/feishu/config.go | 4 +- internal/channel/adapters/feishu/feishu.go | 30 +- .../channel/adapters/feishu/feishu_test.go | 30 + internal/channel/adapters/telegram/config.go | 4 +- .../channel/adapters/telegram/telegram.go | 39 +- .../adapters/telegram/telegram_test.go | 39 + internal/channel/config_test.go | 2 +- internal/channel/helpers_test.go | 45 +- internal/channel/manager.go | 24 +- internal/channel/manager_integration_test.go | 58 +- internal/channel/service.go | 258 +--- internal/channel/types.go | 65 +- internal/channelidentities/service.go | 291 ++++ .../service_identity_integration_test.go | 109 ++ .../service_integration_test.go | 99 ++ internal/channelidentities/service_test.go | 37 + internal/channelidentities/types.go | 15 + internal/chat/resolver.go | 459 ++++-- internal/chat/resolver_test.go | 11 +- internal/chat/service.go | 864 +++++++++++ .../chat/service_presence_integration_test.go | 269 ++++ internal/chat/types.go | 166 ++- internal/contacts/service.go | 410 ----- internal/contacts/types.go | 45 - internal/db/sqlc/bind.sql.go | 120 ++ internal/db/sqlc/bots.sql.go | 48 +- internal/db/sqlc/channel_identities.sql.go | 249 ++++ internal/db/sqlc/channels.sql.go | 178 +-- internal/db/sqlc/chats.sql.go | 988 ++++++++++++ internal/db/sqlc/contacts.sql.go | 380 ----- internal/db/sqlc/history.sql.go | 178 --- internal/db/sqlc/models.go | 211 +-- internal/db/sqlc/settings.sql.go | 312 ++-- internal/db/sqlc/users.sql.go | 442 ++++-- internal/directory/service.go | 226 --- internal/embeddings/dashscope.go | 2 +- internal/embeddings/resolver.go | 22 +- internal/handlers/auth.go | 38 +- internal/handlers/bind.go | 91 ++ internal/handlers/channel.go | 36 +- internal/handlers/chat.go | 592 ++++++-- internal/handlers/contacts.go | 183 --- internal/handlers/containerd.go | 84 +- internal/handlers/embeddings.go | 8 +- internal/handlers/history.go | 254 ---- internal/handlers/local_channel.go | 89 +- internal/handlers/mcp.go | 44 +- internal/handlers/memory.go | 615 ++++---- internal/handlers/models.go | 2 +- internal/handlers/preauth.go | 38 +- internal/handlers/schedule.go | 58 +- internal/handlers/settings.go | 53 +- internal/handlers/subagent.go | 78 +- internal/handlers/users.go | 248 ++-- internal/history/service.go | 237 --- internal/history/types.go | 23 - internal/identity/types.go | 9 +- internal/identity/user.go | 12 +- internal/logger/logger_test.go | 8 +- internal/mcp/manager.go | 2 +- internal/memory/service.go | 24 +- internal/memory/types.go | 102 +- internal/models/models_test.go | 19 +- internal/policy/service.go | 31 + internal/preauth/service.go | 15 +- internal/preauth/types.go | 15 +- internal/router/channel.go | 296 +++- internal/router/channel_test.go | 407 ++--- internal/router/identity.go | 384 +++-- internal/router/identity_test.go | 527 +++++-- internal/schedule/service_test.go | 4 +- internal/schedule/trigger.go | 1 + internal/schedule/types.go | 20 +- internal/server/server.go | 11 +- internal/settings/service.go | 181 ++- internal/subagent/types.go | 28 +- packages/web/src/components/Sidebar/index.vue | 169 ++- .../Sidebar/lists/chat-list-menu.vue | 215 +++ .../Sidebar/lists/settings-list-menu.vue | 94 ++ .../web/src/components/Sidebar/lists/types.ts | 4 + .../web/src/components/add-platform/index.vue | 172 --- .../components/chat-list/robot-chat/index.vue | 33 +- .../web/src/components/create-mcp/index.vue | 2 +- .../src/components/main-container/index.vue | 2 +- packages/web/src/composables/api/useAuth.ts | 3 + packages/web/src/composables/api/useBots.ts | 2 +- .../web/src/composables/api/useChannels.ts | 10 +- packages/web/src/composables/api/useChat.ts | 384 ++++- .../web/src/composables/api/usePlatform.ts | 38 - packages/web/src/composables/api/useUsers.ts | 80 + packages/web/src/i18n/locales/en.json | 100 +- packages/web/src/i18n/locales/zh.json | 100 +- packages/web/src/main.ts | 16 +- .../pages/bots/components/bot-channels.vue | 10 +- .../pages/bots/components/bot-settings.vue | 32 +- .../components/channel-settings-panel.vue | 17 +- .../src/pages/bots/components/create-bot.vue | 29 +- packages/web/src/pages/bots/detail.vue | 5 +- packages/web/src/pages/chat/index.vue | 154 +- packages/web/src/pages/login/index.vue | 7 +- packages/web/src/pages/mcp/index.vue | 22 +- packages/web/src/pages/models/index.vue | 30 - .../platform/components/platform-card.vue | 71 - packages/web/src/pages/platform/index.vue | 26 - packages/web/src/pages/settings/index.vue | 206 +-- packages/web/src/pages/settings/user.vue | 491 ++++++ packages/web/src/router.ts | 96 +- packages/web/src/store/User.ts | 12 + packages/web/src/store/chat-list.ts | 646 ++++---- packages/web/src/utils/request.ts | 1 + pnpm-lock.yaml | 40 +- 147 files changed, 11931 insertions(+), 9234 deletions(-) create mode 100644 db/queries/bind.sql create mode 100644 db/queries/channel_identities.sql create mode 100644 db/queries/chats.sql delete mode 100644 db/queries/contacts.sql delete mode 100644 db/queries/history.sql rename internal/{users => accounts}/service.go (58%) rename internal/{users => accounts}/types.go (68%) create mode 100644 internal/bind/service.go create mode 100644 internal/bind/service_integration_test.go create mode 100644 internal/bind/service_link_integration_test.go create mode 100644 internal/bind/types.go create mode 100644 internal/channelidentities/service.go create mode 100644 internal/channelidentities/service_identity_integration_test.go create mode 100644 internal/channelidentities/service_integration_test.go create mode 100644 internal/channelidentities/service_test.go create mode 100644 internal/channelidentities/types.go create mode 100644 internal/chat/service.go create mode 100644 internal/chat/service_presence_integration_test.go delete mode 100644 internal/contacts/service.go delete mode 100644 internal/contacts/types.go create mode 100644 internal/db/sqlc/bind.sql.go create mode 100644 internal/db/sqlc/channel_identities.sql.go create mode 100644 internal/db/sqlc/chats.sql.go delete mode 100644 internal/db/sqlc/contacts.sql.go delete mode 100644 internal/db/sqlc/history.sql.go delete mode 100644 internal/directory/service.go create mode 100644 internal/handlers/bind.go delete mode 100644 internal/handlers/contacts.go delete mode 100644 internal/handlers/history.go delete mode 100644 internal/history/service.go delete mode 100644 internal/history/types.go create mode 100644 packages/web/src/components/Sidebar/lists/chat-list-menu.vue create mode 100644 packages/web/src/components/Sidebar/lists/settings-list-menu.vue create mode 100644 packages/web/src/components/Sidebar/lists/types.ts delete mode 100644 packages/web/src/components/add-platform/index.vue delete mode 100644 packages/web/src/composables/api/usePlatform.ts create mode 100644 packages/web/src/composables/api/useUsers.ts delete mode 100644 packages/web/src/pages/platform/components/platform-card.vue delete mode 100644 packages/web/src/pages/platform/index.vue create mode 100644 packages/web/src/pages/settings/user.vue diff --git a/agent/src/agent.ts b/agent/src/agent.ts index 66ea9015..e66d1c13 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -28,8 +28,8 @@ export const createAgent = ({ botId: '', sessionId: '', containerId: '', - contactId: '', - contactName: '', + channelIdentityId: '', + displayName: '', }, auth, }: AgentParams, fetch: AuthFetcher) => { @@ -108,6 +108,7 @@ export const createAgent = ({ model: modelConfig, brave, identity, + auth, enableSkill, }) const defaultMCPConnections = getDefaultMCPConnections() @@ -130,8 +131,8 @@ export const createAgent = ({ const images = input.attachments.filter(attachment => attachment.type === 'image') const files = input.attachments.filter((a): a is ContainerFileAttachment => a.type === 'file') const text = user(input.query, { - contactId: identity.contactId, - contactName: identity.contactName, + channelIdentityId: identity.channelIdentityId || identity.contactId || '', + displayName: identity.displayName || identity.contactName || 'User', channel: currentChannel, date: new Date(), attachments: files, @@ -171,7 +172,7 @@ export const createAgent = ({ const { messages: strippedMessages, attachments: messageAttachments } = stripAttachmentsFromMessages(response.messages) const allAttachments = dedupeAttachments([...textAttachments, ...messageAttachments]) return { - messages: [userPrompt, ...strippedMessages], + messages: strippedMessages, reasoning: reasoning.map(part => part.text), usage, text: cleanedText, @@ -376,7 +377,7 @@ export const createAgent = ({ const { messages: strippedMessages } = stripAttachmentsFromMessages(result.messages) yield { type: 'agent_end', - messages: [userPrompt, ...strippedMessages], + messages: strippedMessages, reasoning: result.reasoning, usage: result.usage!, skills: getEnabledSkills(), diff --git a/agent/src/config.ts b/agent/src/config.ts index 9f3ca291..0cd0beda 100644 --- a/agent/src/config.ts +++ b/agent/src/config.ts @@ -9,7 +9,7 @@ type AgentGatewayConfig = { 'server': { addr?: string }, - 'brave': { + 'brave'?: { api_key?: string base_url?: string } diff --git a/agent/src/index.ts b/agent/src/index.ts index fcb0320a..d4ad2ac3 100644 --- a/agent/src/index.ts +++ b/agent/src/index.ts @@ -3,14 +3,18 @@ import { chatModule } from './modules/chat' import { corsMiddleware } from './middlewares/cors' import { errorMiddleware } from './middlewares/error' import { loadConfig } from './config' -import { join } from 'path' const config = loadConfig('../config.toml') export const getBraveConfig = () => { + const apiKey = config.brave?.api_key?.trim() ?? '' + if (!apiKey) { + return undefined + } + const baseUrl = config.brave?.base_url?.trim() || 'https://api.search.brave.com/res/v1/' return { - apiKey: config.brave.api_key ?? '', - baseUrl: config.brave.base_url ?? 'https://api.search.brave.com/res/v1/', + apiKey, + baseUrl, } } @@ -36,11 +40,16 @@ export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => { return async (url: string, options?: RequestInit) => { const requestOptions = options ?? {} const headers = new Headers(requestOptions.headers || {}) - if (bearer) { + if (bearer && !headers.has('Authorization')) { headers.set('Authorization', `Bearer ${bearer}`) } - return await fetch(join(getBaseUrl(), url), { + const baseURL = getBaseUrl() + const requestURL = /^https?:\/\//i.test(url) + ? url + : new URL(url, `${baseURL.replace(/\/$/, '')}/`).toString() + + return await fetch(requestURL, { ...requestOptions, headers, }) diff --git a/agent/src/models.ts b/agent/src/models.ts index f4aca341..acce0fa8 100644 --- a/agent/src/models.ts +++ b/agent/src/models.ts @@ -24,8 +24,10 @@ export const IdentityContextModel = z.object({ botId: z.string().min(1, 'Bot ID is required'), sessionId: z.string().min(1, 'Session ID is required'), containerId: z.string().min(1, 'Container ID is required'), - contactId: z.string().min(1, 'Contact ID is required'), - contactName: z.string().min(1, 'Contact name is required'), + channelIdentityId: z.string().min(1, 'Channel identity ID is required'), + displayName: z.string().min(1, 'Display name is required'), + contactId: z.string().optional(), + contactName: z.string().optional(), contactAlias: z.string().optional(), userId: z.string().optional(), currentPlatform: z.string().optional(), diff --git a/agent/src/prompts/user.ts b/agent/src/prompts/user.ts index ac46c23b..742b8a0b 100644 --- a/agent/src/prompts/user.ts +++ b/agent/src/prompts/user.ts @@ -1,8 +1,8 @@ import { ContainerFileAttachment } from '../types' export interface UserParams { - contactId: string - contactName: string + channelIdentityId: string + displayName: string channel: string date: Date attachments: ContainerFileAttachment[] @@ -10,11 +10,11 @@ export interface UserParams { export const user = ( query: string, - { contactId, contactName, channel, date, attachments }: UserParams + { channelIdentityId, displayName, channel, date, attachments }: UserParams ) => { const headers = { - 'contact-id': contactId, - 'contact-name': contactName, + 'channel-identity-id': channelIdentityId, + 'display-name': displayName, 'channel': channel, 'time': date.toISOString(), 'attachments': attachments.map(attachment => attachment.path), diff --git a/agent/src/tools/contact.ts b/agent/src/tools/contact.ts index 205eb280..de5e1c6a 100644 --- a/agent/src/tools/contact.ts +++ b/agent/src/tools/contact.ts @@ -11,100 +11,57 @@ export type ContactToolParams = { export const getContactTools = ({ fetch, identity }: ContactToolParams) => { const botId = identity.botId.trim() + const listMyIdentities = async () => { + const response = await fetch('/users/me/identities') + return response.json() + } + const contactSearch = tool({ - description: 'Search contacts by name or alias', + description: 'Search identity cards by platform, external id, or display name', inputSchema: z.object({ - query: z.string().describe('The query to search for contacts'), + query: z.string().describe('The query to search identities').optional().default(''), }), execute: async ({ query }) => { - const url = `/bots/${botId}/contacts?q=${encodeURIComponent(query)}` - const response = await fetch(url) - return response.json() + const payload = await listMyIdentities() + const keyword = query.trim().toLowerCase() + const items = Array.isArray(payload?.items) ? payload.items : [] + const filtered = keyword + ? items.filter((item: { platform?: string; external_id?: string; display_name?: string }) => { + const platform = String(item?.platform ?? '').toLowerCase() + const externalID = String(item?.external_id ?? '').toLowerCase() + const displayName = String(item?.display_name ?? '').toLowerCase() + return platform.includes(keyword) || externalID.includes(keyword) || displayName.includes(keyword) + }) + : items + return { + canonical_channel_identity_id: payload?.canonical_channel_identity_id ?? '', + total: filtered.length, + items: filtered, + } }, }) - const contactCreate = tool({ - description: 'Create a contact', + const contactCardMe = tool({ + description: 'Get my canonical identity card and all linked channel identities', + inputSchema: z.object({}), + execute: async () => { + return listMyIdentities() + }, + }) + + const contactIssueBindCode = tool({ + description: 'Issue a bind code for linking current channel identity to this account', inputSchema: z.object({ - name: z.string().describe('The display name of the contact'), - alias: z.string().describe('The alias of the contact').optional(), - tags: z.array(z.string()).describe('The tags of the contact').optional(), + ttl_seconds: z.number().int().positive().optional().describe('Bind code ttl in seconds'), }), - execute: async ({ name, alias, tags }) => { - const response = await fetch(`/bots/${botId}/contacts`, { + execute: async ({ ttl_seconds }) => { + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(`/bots/${botId}/bind_codes`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - display_name: name, - alias: alias, - tags: tags ?? [], - }), - }) - return response.json() - }, - }) - - const contactUpdate = tool({ - description: 'Update a contact', - inputSchema: z.object({ - contact_id: z.string().describe('The ID of the contact to update'), - name: z.string().describe('The display name of the contact').optional(), - alias: z.string().describe('The alias of the contact').optional(), - tags: z.array(z.string()).describe('The tags of the contact').optional(), - }), - execute: async ({ contact_id, name, alias, tags }) => { - const response = await fetch(`/bots/${botId}/contacts/${contact_id}`, { - method: 'PATCH', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - display_name: name, - alias: alias, - tags: tags ?? [], - }), - }) - return response.json() - }, - }) - - // const contactBindToken = tool({ - // description: 'Issue a one-time bind token for a contact', - // inputSchema: z.object({ - // contact_id: ContactID, - // target_platform: z.string().describe('The platform to bind the contact to'), - // target_external_id: z.string().describe('The external ID of the contact'), - // ttl_seconds: z.number().describe('The number of seconds the bind token is valid').optional(), - // }), - // execute: async ({ bot_id, contact_id, target_platform, target_external_id, ttl_seconds }) => { - // const response = await fetch(`/bots/${botId}/contacts/${contact_id}/bind_token`, { - // method: 'POST', - // headers: { 'Content-Type': 'application/json' }, - // body: JSON.stringify({ - // target_platform: target_platform, - // target_external_id: target_external_id, - // ttl_seconds: ttl_seconds, - // }), - // }) - // return response.json() - // }, - // }) - - const contactBind = tool({ - description: 'Bind a contact to a platform identity using a bind token', - inputSchema: z.object({ - contact_id: z.string().describe('The ID of the contact to bind'), - platform: z.string().describe('The platform to bind the contact to'), - external_id: z.string().describe('The external ID of the contact'), - bind_token: z.string().describe('The bind token to use'), - }), - execute: async ({ contact_id, platform, external_id, bind_token }) => { - const response = await fetch(`/bots/${botId}/contacts/${contact_id}/bind`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - platform: platform, - external_id: external_id, - bind_token: bind_token, - }), + body: JSON.stringify({ ttl_seconds }), }) return response.json() }, @@ -112,9 +69,7 @@ export const getContactTools = ({ fetch, identity }: ContactToolParams) => { return { 'contact_search': contactSearch, - 'contact_create': contactCreate, - 'contact_update': contactUpdate, - // 'contact_bind_token': contactBindToken, - 'contact_bind': contactBind, + 'contact_card_me': contactCardMe, + 'contact_issue_bind_code': contactIssueBindCode, } } diff --git a/agent/src/tools/index.ts b/agent/src/tools/index.ts index 34422aed..80456ac3 100644 --- a/agent/src/tools/index.ts +++ b/agent/src/tools/index.ts @@ -1,5 +1,5 @@ import { AuthFetcher } from '..' -import { AgentAction, BraveConfig, IdentityContext, ModelConfig } from '../types' +import { AgentAction, AgentAuthContext, BraveConfig, IdentityContext, ModelConfig } from '../types' import { ToolSet } from 'ai' import { getWebTools } from './web' import { getScheduleTools } from './schedule' @@ -14,12 +14,13 @@ export interface ToolsParams { model: ModelConfig brave?: BraveConfig identity: IdentityContext + auth: AgentAuthContext enableSkill: (skill: string) => void } export const getTools = ( actions: AgentAction[], - { fetch, model, brave, identity, enableSkill }: ToolsParams + { fetch, model, brave, identity, auth, enableSkill }: ToolsParams ) => { const tools: ToolSet = {} if (actions.includes(AgentAction.Web) && brave) { @@ -31,11 +32,11 @@ export const getTools = ( Object.assign(tools, scheduleTools) } if (actions.includes(AgentAction.Memory)) { - const memoryTools = getMemoryTools({ fetch }) + const memoryTools = getMemoryTools({ fetch, identity }) Object.assign(tools, memoryTools) } if (actions.includes(AgentAction.Subagent)) { - const subagentTools = getSubagentTools({ fetch, model, brave, identity }) + const subagentTools = getSubagentTools({ fetch, model, brave, identity, auth }) Object.assign(tools, subagentTools) } if (actions.includes(AgentAction.Contact)) { diff --git a/agent/src/tools/mcp.ts b/agent/src/tools/mcp.ts index 1bb4466f..d10a5387 100644 --- a/agent/src/tools/mcp.ts +++ b/agent/src/tools/mcp.ts @@ -79,16 +79,25 @@ export const getMCPTools = async (connections: MCPConnection[], options: MCPTool } const toolSets = await Promise.all(connections.map(async (connection) => { - switch (connection.type) { - case 'http': - return getHTTPTools(connection) - case 'sse': - return getSSETools(connection) - case 'stdio': - return getStdioTools(connection) - default: - console.warn('unknown mcp connection type', connection) - return {} + try { + switch (connection.type) { + case 'http': + return await getHTTPTools(connection) + case 'sse': + return await getSSETools(connection) + case 'stdio': + return await getStdioTools(connection) + default: + console.warn('unknown mcp connection type', connection) + return {} + } + } catch (error) { + console.warn('skip mcp connection due to initialization error', { + name: connection.name, + type: connection.type, + error: error instanceof Error ? error.message : String(error), + }) + return {} } })) diff --git a/agent/src/tools/memory.ts b/agent/src/tools/memory.ts index 3936fae9..083b2d48 100644 --- a/agent/src/tools/memory.ts +++ b/agent/src/tools/memory.ts @@ -1,9 +1,11 @@ import { tool } from 'ai' import { AuthFetcher } from '..' +import type { IdentityContext } from '../types' import { z } from 'zod' export type MemoryToolParams = { fetch: AuthFetcher + identity: IdentityContext } type MemorySearchItem = { @@ -16,20 +18,26 @@ type MemorySearchItem = { } } -export const getMemoryTools = ({ fetch }: MemoryToolParams) => { +export const getMemoryTools = ({ fetch, identity }: MemoryToolParams) => { const searchMemory = tool({ description: 'Search for memories', inputSchema: z.object({ query: z.string().describe('The query to search for memories'), + limit: z.number().int().positive().max(50).optional(), }), - execute: async ({ query }) => { - const response = await fetch('/memory/search', { + execute: async ({ query, limit }) => { + const chatId = identity.sessionId.trim() + if (!chatId) { + throw new Error('sessionId is required to search memory') + } + const response = await fetch(`/chats/${chatId}/memory/search`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({ query, + limit, }), }) const data = await response.json() diff --git a/agent/src/tools/message.ts b/agent/src/tools/message.ts index 77bbc906..c633c6a7 100644 --- a/agent/src/tools/message.ts +++ b/agent/src/tools/message.ts @@ -12,6 +12,7 @@ const SendMessageSchema = z.object({ bot_id: z.string().optional(), platform: z.string().optional(), target: z.string().optional(), + channel_identity_id: z.string().optional(), to_user_id: z.string().optional(), message: z.string(), }) @@ -25,38 +26,39 @@ export const getMessageTools = ({ fetch, identity }: MessageToolParams) => { const platform = (payload.platform ?? identity.currentPlatform ?? '').trim() const replyTarget = (identity.replyTarget ?? '').trim() const target = (payload.target ?? replyTarget).trim() - const toUserID = (payload.to_user_id ?? '').trim() + const channelIdentityID = (payload.channel_identity_id ?? payload.to_user_id ?? '').trim() if (!botId) { throw new Error('bot_id is required') } if (!platform) { throw new Error('platform is required') } - if (!target && !toUserID && !identity.sessionToken) { - throw new Error('target or to_user_id is required') + // Prefer chat token when there is no explicit target identity. + const useSessionToken = !!identity.sessionToken && !channelIdentityID + if (!target && !channelIdentityID && !useSessionToken) { + throw new Error('target or channel_identity_id is required') } - // Use session token if available and no explicit to_user_id specified - // This allows replying to current session without needing explicit auth - const useSessionToken = !!identity.sessionToken && !toUserID console.log('[Tool] send_message', { botId, platform, target: target || undefined, - toUserID: toUserID || undefined, + channelIdentityID: channelIdentityID || undefined, replyTarget, useSessionToken, }) - const body: Record = { message: payload.message } - if (!useSessionToken) { - if (target) { - body.to = target - } - if (toUserID) { - body.to_user_id = toUserID - } + const body: Record = { + message: { + text: payload.message, + }, + } + if (target) { + body.target = target + } + if (channelIdentityID) { + body.channel_identity_id = channelIdentityID } const url = useSessionToken - ? `/bots/${botId}/channel/${platform}/send_session` + ? `/bots/${botId}/channel/${platform}/send_chat` : `/bots/${botId}/channel/${platform}/send` const headers: Record = { 'Content-Type': 'application/json' } if (useSessionToken && identity.sessionToken) { diff --git a/agent/src/tools/subagent.ts b/agent/src/tools/subagent.ts index fe587e2a..65321660 100644 --- a/agent/src/tools/subagent.ts +++ b/agent/src/tools/subagent.ts @@ -1,7 +1,7 @@ import { tool } from 'ai' import { z } from 'zod' import { createAgent } from '../agent' -import { ModelConfig, BraveConfig } from '../types' +import { ModelConfig, BraveConfig, AgentAuthContext } from '../types' import { AuthFetcher } from '..' import { AgentAction, IdentityContext } from '../types/agent' @@ -10,14 +10,21 @@ export interface SubagentToolParams { model: ModelConfig brave?: BraveConfig identity: IdentityContext + auth: AgentAuthContext } -export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentToolParams) => { +export const getSubagentTools = ({ fetch, model, brave, identity, auth }: SubagentToolParams) => { + const botId = identity.botId.trim() + const base = `/bots/${botId}/subagents` + const listSubagents = tool({ description: 'List subagents for current user', inputSchema: z.object({}), execute: async () => { - const response = await fetch('/subagents', { method: 'GET' }) + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(base, { method: 'GET' }) return response.json() }, }) @@ -29,7 +36,10 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool description: z.string(), }), execute: async ({ name, description }) => { - const response = await fetch('/subagents', { + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(base, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ name, description }), @@ -44,7 +54,10 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool id: z.string().describe('Subagent ID'), }), execute: async ({ id }) => { - const response = await fetch(`/subagents/${id}`, { method: 'DELETE' }) + if (!botId) { + throw new Error('bot_id is required') + } + const response = await fetch(`${base}/${id}`, { method: 'DELETE' }) return response.status === 204 ? { success: true } : response.json() }, }) @@ -56,14 +69,17 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool query: z.string().describe('The prompt to ask the subagent to do.'), }), execute: async ({ name, query }) => { - const listResponse = await fetch('/subagents', { method: 'GET' }) + if (!botId) { + throw new Error('bot_id is required') + } + const listResponse = await fetch(base, { method: 'GET' }) const listPayload = await listResponse.json() const items = Array.isArray(listPayload?.items) ? listPayload.items : [] const target = items.find((item: { name?: string }) => item?.name === name) if (!target?.id) { throw new Error(`subagent not found: ${name}`) } - const contextResponse = await fetch(`/subagents/${target.id}/context`, { method: 'GET' }) + const contextResponse = await fetch(`${base}/${target.id}/context`, { method: 'GET' }) const contextPayload = await contextResponse.json() const contextMessages = Array.isArray(contextPayload?.messages) ? contextPayload.messages : [] const { askAsSubagent } = createAgent({ @@ -73,6 +89,7 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool AgentAction.Web, ], identity, + auth, }, fetch) const result = await askAsSubagent({ messages: contextMessages, @@ -81,7 +98,7 @@ export const getSubagentTools = ({ fetch, model, brave, identity }: SubagentTool description: target.description, }) const updatedMessages = [...contextMessages, ...result.messages] - await fetch(`/subagents/${target.id}/context`, { + await fetch(`${base}/${target.id}/context`, { method: 'PUT', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ messages: updatedMessages }), diff --git a/agent/src/types/agent.ts b/agent/src/types/agent.ts index 430c3850..eac7b312 100644 --- a/agent/src/types/agent.ts +++ b/agent/src/types/agent.ts @@ -8,8 +8,12 @@ export interface IdentityContext { sessionId: string containerId: string - contactId: string - contactName: string + channelIdentityId: string + displayName: string + + // Deprecated compatibility fields kept optional for older callers. + contactId?: string + contactName?: string contactAlias?: string userId?: string diff --git a/cmd/agent/main.go b/cmd/agent/main.go index bee71fee..6b2c4050 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -8,20 +8,21 @@ import ( "strings" "time" + "github.com/memohai/memoh/internal/accounts" + "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/telegram" + "github.com/memohai/memoh/internal/channelidentities" "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/config" - "github.com/memohai/memoh/internal/contacts" ctr "github.com/memohai/memoh/internal/containerd" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/embeddings" "github.com/memohai/memoh/internal/handlers" - "github.com/memohai/memoh/internal/history" "github.com/memohai/memoh/internal/logger" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" @@ -34,7 +35,6 @@ import ( "github.com/memohai/memoh/internal/server" "github.com/memohai/memoh/internal/settings" "github.com/memohai/memoh/internal/subagent" - "github.com/memohai/memoh/internal/users" "github.com/memohai/memoh/internal/version" "github.com/jackc/pgx/v5/pgtype" @@ -96,9 +96,9 @@ func main() { queries := dbsqlc.New(conn) modelsService := models.NewService(logger.L, queries) botService := bots.NewService(logger.L, queries) - usersService := users.NewService(logger.L, queries) + accountService := accounts.NewService(logger.L, queries) - containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, usersService, queries) + containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, accountService, queries) botService.SetContainerLifecycle(containerdHandler) if err := ensureAdminUser(ctx, logger.L, queries, cfg); err != nil { @@ -106,7 +106,7 @@ func main() { os.Exit(1) } - authHandler := handlers.NewAuthHandler(logger.L, usersService, cfg.Auth.JWTSecret, jwtExpiresIn) + authHandler := handlers.NewAuthHandler(logger.L, accountService, cfg.Auth.JWTSecret, jwtExpiresIn) // Initialize chat resolver after memory service is configured. var chatResolver *chat.Resolver @@ -134,7 +134,6 @@ func main() { bm25Indexer := memory.NewBM25Indexer(logger.L) memoryService := memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, bm25Indexer, textModel.ModelID, multimodalModel.ModelID) - memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService, botService, usersService) go func() { if err := memoryService.WarmupBM25(ctx, 200); err != nil { logger.Warn("bm25 warmup failed", slog.Any("error", err)) @@ -145,23 +144,23 @@ func main() { providersService := providers.NewService(logger.L, queries) providersHandler := handlers.NewProvidersHandler(logger.L, providersService, modelsService) settingsService := settings.NewService(logger.L, queries) - settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService, botService, usersService) + settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService, botService, accountService) modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService) policyService := policy.NewService(logger.L, botService, settingsService) - historyService := history.NewService(logger.L, queries) - historyHandler := handlers.NewHistoryHandler(logger.L, historyService, botService, usersService) - contactsService := contacts.NewService(queries) - contactsHandler := handlers.NewContactsHandler(contactsService, botService, usersService) + chatService := chat.NewService(logger.L, queries) + memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService, chatService, accountService) + actorService := channelidentities.NewService(logger.L, queries) preauthService := preauth.NewService(queries) - preauthHandler := handlers.NewPreauthHandler(preauthService, botService, usersService) + preauthHandler := handlers.NewPreauthHandler(preauthService, botService, accountService) + bindService := bind.NewService(logger.L, conn, queries) + bindHandler := handlers.NewBindHandler(logger.L, bindService) mcpConnectionsService := mcp.NewConnectionService(logger.L, queries) - mcpHandler := handlers.NewMCPHandler(logger.L, mcpConnectionsService, botService, usersService) - - chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, historyService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) + mcpHandler := handlers.NewMCPHandler(logger.L, mcpConnectionsService, botService, accountService) + chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, chatService, settingsService, mcpConnectionsService, cfg.AgentGateway.BaseURL(), 120*time.Second) chatResolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries) swaggerHandler := handlers.NewSwaggerHandler(logger.L) - chatHandler := handlers.NewChatHandler(logger.L, chatResolver, botService, usersService) + chatHandler := handlers.NewChatHandler(logger.L, chatResolver, chatService, botService, accountService) channelRegistry := channel.NewRegistry() sessionHub := local.NewSessionHub() channelRegistry.MustRegister(telegram.NewTelegramAdapter(logger.L)) @@ -169,26 +168,26 @@ func main() { channelRegistry.MustRegister(local.NewCLIAdapter(sessionHub)) channelRegistry.MustRegister(local.NewWebAdapter(sessionHub)) channelService := channel.NewService(queries, channelRegistry) - channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute) + channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, chatService, chatResolver, actorService, botService, policyService, preauthService, bindService, cfg.Auth.JWTSecret, 5*time.Minute) channelManager := channel.NewManager(logger.L, channelRegistry, channelService, channelRouter) if mw := channelRouter.IdentityMiddleware(); mw != nil { channelManager.Use(mw) } channelManager.Start(ctx) channelHandler := handlers.NewChannelHandler(channelService, channelRegistry) - usersHandler := handlers.NewUsersHandler(logger.L, usersService, botService, channelService, channelManager, channelRegistry) - cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService) - webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService) + usersHandler := handlers.NewUsersHandler(logger.L, accountService, actorService, botService, chatService, channelService, channelManager, channelRegistry) + cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, chatService, sessionHub, botService, accountService) + webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, chatService, sessionHub, botService, accountService) scheduleGateway := chat.NewScheduleGateway(chatResolver) scheduleService := schedule.NewService(logger.L, queries, scheduleGateway, cfg.Auth.JWTSecret) if err := scheduleService.Bootstrap(ctx); err != nil { logger.Error("schedule bootstrap", slog.Any("error", err)) os.Exit(1) } - scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, usersService) + scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, accountService) subagentService := subagent.NewService(logger.L, queries) - subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, usersService) - srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, contactsHandler, preauthHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler) + subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, accountService) + srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, preauthHandler, bindHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, mcpHandler, cliHandler, webHandler) if err := srv.Start(); err != nil { logger.Error("server failed", slog.Any("error", err)) @@ -249,7 +248,7 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer if queries == nil { return fmt.Errorf("db queries not configured") } - count, err := queries.CountUsers(ctx) + count, err := queries.CountAccounts(ctx) if err != nil { return err } @@ -272,6 +271,14 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer return err } + user, err := queries.CreateUser(ctx, dbsqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return fmt.Errorf("create admin user: %w", err) + } + emailValue := pgtype.Text{Valid: false} if email != "" { emailValue = pgtype.Text{String: email, Valid: true} @@ -279,10 +286,11 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer displayName := pgtype.Text{String: username, Valid: true} dataRoot := pgtype.Text{String: cfg.MCP.DataRoot, Valid: cfg.MCP.DataRoot != ""} - _, err = queries.CreateUser(ctx, dbsqlc.CreateUserParams{ - Username: username, + _, err = queries.CreateAccount(ctx, dbsqlc.CreateAccountParams{ + UserID: user.ID, + Username: pgtype.Text{String: username, Valid: true}, Email: emailValue, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, Role: "admin", DisplayName: displayName, AvatarUrl: pgtype.Text{Valid: false}, diff --git a/db/migrations/0001_init.down.sql b/db/migrations/0001_init.down.sql index 28f5128d..014bed2e 100644 --- a/db/migrations/0001_init.down.sql +++ b/db/migrations/0001_init.down.sql @@ -1,26 +1,24 @@ -DROP TABLE IF EXISTS user_settings; DROP TABLE IF EXISTS subagents; DROP TABLE IF EXISTS schedule; DROP TABLE IF EXISTS lifecycle_events; DROP TABLE IF EXISTS container_versions; DROP TABLE IF EXISTS snapshots; DROP TABLE IF EXISTS containers; -DROP TABLE IF EXISTS channel_sessions; -DROP TABLE IF EXISTS contact_channels; +DROP TABLE IF EXISTS chat_routes; +DROP TABLE IF EXISTS chat_messages; +DROP TABLE IF EXISTS chat_channel_identity_presence; +DROP TABLE IF EXISTS chat_participants; +DROP TABLE IF EXISTS chats; +DROP TABLE IF EXISTS channel_identity_bind_codes; DROP TABLE IF EXISTS bot_preauth_keys; DROP TABLE IF EXISTS bot_channel_configs; -DROP TABLE IF EXISTS user_channel_bindings; -DROP TABLE IF EXISTS history; -DROP TABLE IF EXISTS conversations; DROP TABLE IF EXISTS mcp_connections; -DROP TABLE IF EXISTS bot_model_configs; -DROP TABLE IF EXISTS bot_settings; DROP TABLE IF EXISTS bot_members; -DROP TABLE IF EXISTS contact_bind_tokens; DROP TABLE IF EXISTS bots; +DROP TABLE IF EXISTS model_variants; +DROP TABLE IF EXISTS models; +DROP TABLE IF EXISTS llm_providers; +DROP TABLE IF EXISTS user_channel_bindings; +DROP TABLE IF EXISTS channel_identities; DROP TABLE IF EXISTS users; -DROP TABLE IF EXISTS contacts; --- DROP TABLE IF EXISTS model_variants; --- DROP TABLE IF EXISTS models; --- DROP TABLE IF EXISTS llm_providers; DROP TYPE IF EXISTS user_role; diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index b9facd8b..39528a14 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -8,23 +8,58 @@ BEGIN END $$; +-- users: Memoh user principal CREATE TABLE IF NOT EXISTS users ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - username TEXT NOT NULL, + username TEXT, email TEXT, - password_hash TEXT NOT NULL, + password_hash TEXT, role user_role NOT NULL DEFAULT 'member', display_name TEXT, avatar_url TEXT, - is_active BOOLEAN NOT NULL DEFAULT true, data_root TEXT, + last_login_at TIMESTAMPTZ, + chat_model_id TEXT, + memory_model_id TEXT, + embedding_model_id TEXT, + max_context_load_time INTEGER NOT NULL DEFAULT 1440, + language TEXT NOT NULL DEFAULT 'auto', + is_active BOOLEAN NOT NULL DEFAULT true, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - last_login_at TIMESTAMPTZ, CONSTRAINT users_email_unique UNIQUE (email), CONSTRAINT users_username_unique UNIQUE (username) ); +-- channel_identities: unified inbound identity subject +CREATE TABLE IF NOT EXISTS channel_identities ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID REFERENCES users(id) ON DELETE SET NULL, + channel TEXT NOT NULL, + channel_subject_id TEXT NOT NULL, + display_name TEXT, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT channel_identities_channel_subject_unique UNIQUE (channel, channel_subject_id) +); + +CREATE INDEX IF NOT EXISTS idx_channel_identities_user_id ON channel_identities(user_id); + +-- user_channel_bindings: outbound delivery config +CREATE TABLE IF NOT EXISTS user_channel_bindings ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + platform TEXT NOT NULL, + config JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT user_channel_bindings_unique UNIQUE (user_id, platform) +); + +CREATE INDEX IF NOT EXISTS idx_user_channel_bindings_user_id ON user_channel_bindings(user_id); + CREATE TABLE IF NOT EXISTS llm_providers ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), name TEXT NOT NULL, @@ -73,6 +108,12 @@ CREATE TABLE IF NOT EXISTS bots ( display_name TEXT, avatar_url TEXT, is_active BOOLEAN NOT NULL DEFAULT true, + max_context_load_time INTEGER NOT NULL DEFAULT 1440, + language TEXT NOT NULL DEFAULT 'auto', + allow_guest BOOLEAN NOT NULL DEFAULT false, + chat_model_id UUID REFERENCES models(id) ON DELETE SET NULL, + memory_model_id UUID REFERENCES models(id) ON DELETE SET NULL, + embedding_model_id UUID REFERENCES models(id) ON DELETE SET NULL, metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), @@ -92,20 +133,6 @@ CREATE TABLE IF NOT EXISTS bot_members ( CREATE INDEX IF NOT EXISTS idx_bot_members_user_id ON bot_members(user_id); -CREATE TABLE IF NOT EXISTS bot_settings ( - bot_id UUID PRIMARY KEY REFERENCES bots(id) ON DELETE CASCADE, - max_context_load_time INTEGER NOT NULL DEFAULT 1440, - language TEXT NOT NULL DEFAULT 'auto', - allow_guest BOOLEAN NOT NULL DEFAULT false -); - -CREATE TABLE IF NOT EXISTS bot_model_configs ( - bot_id UUID PRIMARY KEY REFERENCES bots(id) ON DELETE CASCADE, - chat_model_id UUID REFERENCES models(id) ON DELETE SET NULL, - embedding_model_id UUID REFERENCES models(id) ON DELETE SET NULL, - memory_model_id UUID REFERENCES models(id) ON DELETE SET NULL -); - CREATE TABLE IF NOT EXISTS mcp_connections ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, @@ -121,45 +148,85 @@ CREATE TABLE IF NOT EXISTS mcp_connections ( CREATE INDEX IF NOT EXISTS idx_mcp_connections_bot_id ON mcp_connections(bot_id); -CREATE TABLE IF NOT EXISTS conversations ( +-- chats: first-class conversation container +CREATE TABLE IF NOT EXISTS chats ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - session_id TEXT NOT NULL, - channel_type TEXT NOT NULL, - chat_id TEXT, - sender_id TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT conversations_session_unique UNIQUE (bot_id, session_id) -); - -CREATE INDEX IF NOT EXISTS idx_conversations_bot_id ON conversations(bot_id); - -CREATE TABLE IF NOT EXISTS history ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - session_id TEXT NOT NULL, - messages JSONB NOT NULL, + kind TEXT NOT NULL CHECK (kind IN ('direct', 'group', 'thread')), + parent_chat_id UUID REFERENCES chats(id) ON DELETE CASCADE, + title TEXT, + created_by_user_id UUID REFERENCES users(id), metadata JSONB NOT NULL DEFAULT '{}'::jsonb, - skills TEXT[] NOT NULL DEFAULT '{}'::text[], - timestamp TIMESTAMPTZ NOT NULL -); - -CREATE INDEX IF NOT EXISTS idx_history_bot ON history(bot_id); -CREATE INDEX IF NOT EXISTS idx_history_session ON history(session_id); -CREATE INDEX IF NOT EXISTS idx_history_timestamp ON history(timestamp); - -CREATE TABLE IF NOT EXISTS user_channel_bindings ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, - channel_type TEXT NOT NULL, - config JSONB NOT NULL DEFAULT '{}'::jsonb, + enable_chat_memory BOOLEAN NOT NULL DEFAULT true, + enable_private_memory BOOLEAN NOT NULL DEFAULT true, + enable_public_memory BOOLEAN NOT NULL DEFAULT false, + model_id TEXT, + settings_metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT user_channel_bindings_unique UNIQUE (user_id, channel_type) + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() ); -CREATE INDEX IF NOT EXISTS idx_user_channel_bindings_user_id ON user_channel_bindings(user_id); +CREATE INDEX IF NOT EXISTS idx_chats_bot_id ON chats(bot_id); +CREATE INDEX IF NOT EXISTS idx_chats_parent ON chats(parent_chat_id); + +-- chat_participants: chat membership +CREATE TABLE IF NOT EXISTS chat_participants ( + chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + role TEXT NOT NULL DEFAULT 'member' CHECK (role IN ('owner', 'admin', 'member')), + joined_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (chat_id, user_id) +); + +CREATE INDEX IF NOT EXISTS idx_chat_participants_user ON chat_participants(user_id); + +-- chat_messages: per-message storage (replaces history) +CREATE TABLE IF NOT EXISTS chat_messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, + bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, + route_id UUID, + sender_channel_identity_id UUID REFERENCES channel_identities(id), + sender_user_id UUID REFERENCES users(id), + platform TEXT, + external_message_id TEXT, + role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system', 'tool')), + content JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Backfill newly introduced columns for existing deployments where chat_messages +-- was created before route/platform traceability fields were added. +ALTER TABLE IF EXISTS chat_messages + ADD COLUMN IF NOT EXISTS route_id UUID; + +ALTER TABLE IF EXISTS chat_messages + ADD COLUMN IF NOT EXISTS platform TEXT; + +ALTER TABLE IF EXISTS chat_messages + ADD COLUMN IF NOT EXISTS external_message_id TEXT; + +CREATE INDEX IF NOT EXISTS idx_chat_messages_chat_created ON chat_messages(chat_id, created_at); +CREATE INDEX IF NOT EXISTS idx_chat_messages_bot ON chat_messages(bot_id); +CREATE INDEX IF NOT EXISTS idx_chat_messages_route ON chat_messages(route_id); +CREATE INDEX IF NOT EXISTS idx_chat_messages_external_lookup + ON chat_messages(platform, external_message_id); + +-- chat_channel_identity_presence: derived cache of channel identities observed in chats +CREATE TABLE IF NOT EXISTS chat_channel_identity_presence ( + chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, + channel_identity_id UUID NOT NULL REFERENCES channel_identities(id) ON DELETE CASCADE, + first_seen_at TIMESTAMPTZ NOT NULL DEFAULT now(), + last_seen_at TIMESTAMPTZ NOT NULL DEFAULT now(), + message_count BIGINT NOT NULL DEFAULT 1, + PRIMARY KEY (chat_id, channel_identity_id) +); + +CREATE INDEX IF NOT EXISTS idx_chat_channel_identity_presence_identity_last_seen + ON chat_channel_identity_presence(channel_identity_id, last_seen_at DESC); +CREATE INDEX IF NOT EXISTS idx_chat_channel_identity_presence_chat_last_seen + ON chat_channel_identity_presence(chat_id, last_seen_at DESC); CREATE TABLE IF NOT EXISTS bot_channel_configs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -183,26 +250,6 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_channel_external_identity CREATE INDEX IF NOT EXISTS idx_bot_channel_bot_id ON bot_channel_configs(bot_id); -CREATE TABLE IF NOT EXISTS contacts ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - user_id UUID REFERENCES users(id) ON DELETE SET NULL, - display_name TEXT, - alias TEXT, - tags TEXT[] NOT NULL DEFAULT '{}'::text[], - status TEXT NOT NULL DEFAULT 'active', - metadata JSONB NOT NULL DEFAULT '{}'::jsonb, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT contacts_status_check CHECK (status IN ('active', 'blocked', 'pending')) -); - -CREATE UNIQUE INDEX IF NOT EXISTS idx_contacts_bot_user_unique - ON contacts(bot_id, user_id) - WHERE user_id IS NOT NULL; - -CREATE INDEX IF NOT EXISTS idx_contacts_bot_id ON contacts(bot_id); - CREATE TABLE IF NOT EXISTS bot_preauth_keys ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, @@ -217,37 +264,40 @@ CREATE TABLE IF NOT EXISTS bot_preauth_keys ( CREATE INDEX IF NOT EXISTS idx_bot_preauth_keys_bot_id ON bot_preauth_keys(bot_id); CREATE INDEX IF NOT EXISTS idx_bot_preauth_keys_expires ON bot_preauth_keys(expires_at); -CREATE TABLE IF NOT EXISTS contact_channels ( +-- channel_identity_bind_codes: one-time codes for channel identity->user linking +CREATE TABLE IF NOT EXISTS channel_identity_bind_codes ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - contact_id UUID NOT NULL REFERENCES contacts(id) ON DELETE CASCADE, - platform TEXT NOT NULL, - external_id TEXT NOT NULL, - metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + token TEXT NOT NULL, + issued_by_user_id UUID NOT NULL REFERENCES users(id), + platform TEXT, + expires_at TIMESTAMPTZ, + used_at TIMESTAMPTZ, + used_by_channel_identity_id UUID REFERENCES channel_identities(id), created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT contact_channels_unique UNIQUE (bot_id, platform, external_id) + CONSTRAINT channel_identity_bind_codes_token_unique UNIQUE (token) ); -CREATE INDEX IF NOT EXISTS idx_contact_channels_contact_id ON contact_channels(contact_id); -CREATE INDEX IF NOT EXISTS idx_contact_channels_platform_external ON contact_channels(platform, external_id); +CREATE INDEX IF NOT EXISTS idx_channel_identity_bind_codes_platform ON channel_identity_bind_codes(platform); -CREATE TABLE IF NOT EXISTS channel_sessions ( - session_id TEXT PRIMARY KEY, +-- chat_routes: routing mapping (replaces channel_sessions) +CREATE TABLE IF NOT EXISTS chat_routes ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - channel_config_id UUID REFERENCES bot_channel_configs(id) ON DELETE SET NULL, - user_id UUID REFERENCES users(id) ON DELETE CASCADE, - contact_id UUID REFERENCES contacts(id) ON DELETE SET NULL, platform TEXT NOT NULL, - reply_target TEXT, + channel_config_id UUID REFERENCES bot_channel_configs(id) ON DELETE SET NULL, + conversation_id TEXT NOT NULL, thread_id TEXT, + reply_target TEXT, metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now() ); -CREATE INDEX IF NOT EXISTS idx_channel_sessions_bot_id ON channel_sessions(bot_id); -CREATE INDEX IF NOT EXISTS idx_channel_sessions_user_id ON channel_sessions(user_id); +CREATE UNIQUE INDEX IF NOT EXISTS idx_chat_routes_unique + ON chat_routes (bot_id, platform, conversation_id, COALESCE(thread_id, '')); +CREATE INDEX IF NOT EXISTS idx_chat_routes_chat ON chat_routes(chat_id); +CREATE INDEX IF NOT EXISTS idx_chat_routes_bot ON chat_routes(bot_id); CREATE TABLE IF NOT EXISTS containers ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -339,11 +389,3 @@ CREATE TABLE IF NOT EXISTS subagents ( CREATE INDEX IF NOT EXISTS idx_subagents_bot_id ON subagents(bot_id); CREATE INDEX IF NOT EXISTS idx_subagents_deleted ON subagents(deleted); -CREATE TABLE IF NOT EXISTS user_settings ( - user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, - chat_model_id TEXT, - memory_model_id TEXT, - embedding_model_id TEXT, - max_context_load_time INTEGER NOT NULL DEFAULT 1440, - language TEXT NOT NULL DEFAULT 'auto' -); diff --git a/db/queries/bind.sql b/db/queries/bind.sql new file mode 100644 index 00000000..01233862 --- /dev/null +++ b/db/queries/bind.sql @@ -0,0 +1,22 @@ +-- name: CreateBindCode :one +INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, platform, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at; + +-- name: GetBindCode :one +SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1; + +-- name: GetBindCodeForUpdate :one +SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1 +FOR UPDATE; + +-- name: MarkBindCodeUsed :one +UPDATE channel_identity_bind_codes +SET used_at = now(), used_by_channel_identity_id = $2 +WHERE id = $1 + AND used_at IS NULL +RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at; diff --git a/db/queries/bots.sql b/db/queries/bots.sql index 5dd0e46a..2a16131d 100644 --- a/db/queries/bots.sql +++ b/db/queries/bots.sql @@ -1,21 +1,21 @@ -- name: CreateBot :one INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata) VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at; +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; -- name: GetBotByID :one -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE id = $1; -- name: ListBotsByOwner :many -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE owner_user_id = $1 ORDER BY created_at DESC; -- name: ListBotsByMember :many -SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.metadata, b.created_at, b.updated_at +SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at FROM bots b JOIN bot_members m ON m.bot_id = b.id WHERE m.user_id = $1 @@ -29,14 +29,14 @@ SET display_name = $2, metadata = $5, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at; +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; -- name: UpdateBotOwner :one UPDATE bots SET owner_user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at; +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at; -- name: DeleteBotByID :exec DELETE FROM bots WHERE id = $1; diff --git a/db/queries/channel_identities.sql b/db/queries/channel_identities.sql new file mode 100644 index 00000000..584e9116 --- /dev/null +++ b/db/queries/channel_identities.sql @@ -0,0 +1,49 @@ +-- name: CreateChannelIdentity :one +INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; + +-- name: GetChannelIdentityByID :one +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1; + +-- name: GetChannelIdentityByIDForUpdate :one +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1 +FOR UPDATE; + +-- name: GetChannelIdentityByChannelSubject :one +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE channel = $1 AND channel_subject_id = $2; + +-- name: UpsertChannelIdentityByChannelSubject :one +INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +ON CONFLICT (channel, channel_subject_id) +DO UPDATE SET + display_name = EXCLUDED.display_name, + metadata = EXCLUDED.metadata, + user_id = COALESCE(channel_identities.user_id, EXCLUDED.user_id), + updated_at = now() +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; + +-- name: ListChannelIdentitiesByUserID :many +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE user_id = $1 +ORDER BY created_at DESC; + +-- name: SetChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = $2, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; + +-- name: ClearChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = NULL, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at; diff --git a/db/queries/channels.sql b/db/queries/channels.sql index 9323a26f..99b1eb3f 100644 --- a/db/queries/channels.sql +++ b/db/queries/channels.sql @@ -34,54 +34,23 @@ WHERE channel_type = $1 ORDER BY created_at DESC; -- name: GetUserChannelBinding :one -SELECT id, user_id, channel_type, config, created_at, updated_at +SELECT id, user_id, platform, config, created_at, updated_at FROM user_channel_bindings -WHERE user_id = $1 AND channel_type = $2 +WHERE user_id = $1 AND platform = $2 LIMIT 1; -- name: UpsertUserChannelBinding :one -INSERT INTO user_channel_bindings (user_id, channel_type, config) +INSERT INTO user_channel_bindings (user_id, platform, config) VALUES ($1, $2, $3) -ON CONFLICT (user_id, channel_type) +ON CONFLICT (user_id, platform) DO UPDATE SET config = EXCLUDED.config, updated_at = now() -RETURNING id, user_id, channel_type, config, created_at, updated_at; +RETURNING id, user_id, platform, config, created_at, updated_at; --- name: ListUserChannelBindingsByType :many -SELECT id, user_id, channel_type, config, created_at, updated_at +-- name: ListUserChannelBindingsByPlatform :many +SELECT id, user_id, platform, config, created_at, updated_at FROM user_channel_bindings -WHERE channel_type = $1 +WHERE platform = $1 ORDER BY created_at DESC; --- name: GetChannelSessionByID :one -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE session_id = $1 -LIMIT 1; - --- name: ListChannelSessionsByBotPlatform :many -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE bot_id = $1 AND platform = $2 -ORDER BY updated_at DESC; - --- name: UpsertChannelSession :one -INSERT INTO channel_sessions (session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) -ON CONFLICT (session_id) -DO UPDATE SET - bot_id = EXCLUDED.bot_id, - channel_config_id = EXCLUDED.channel_config_id, - user_id = EXCLUDED.user_id, - contact_id = EXCLUDED.contact_id, - platform = EXCLUDED.platform, - reply_target = EXCLUDED.reply_target, - thread_id = EXCLUDED.thread_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at; - --- name: DeleteChannelSession :exec -DELETE FROM channel_sessions -WHERE session_id = $1; diff --git a/db/queries/chats.sql b/db/queries/chats.sql new file mode 100644 index 00000000..501249a2 --- /dev/null +++ b/db/queries/chats.sql @@ -0,0 +1,214 @@ +-- name: CreateChat :one +INSERT INTO chats (bot_id, kind, parent_chat_id, title, created_by_user_id, metadata) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at; + +-- name: GetChatByID :one +SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at +FROM chats +WHERE id = $1; + +-- name: ListChatsByBotAndUser :many +SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.enable_chat_memory, c.enable_private_memory, c.enable_public_memory, c.model_id, c.settings_metadata, c.created_at, c.updated_at +FROM chats c +JOIN chat_participants cp ON cp.chat_id = c.id +WHERE c.bot_id = $1 AND cp.user_id = $2 +ORDER BY c.updated_at DESC; + +-- name: ListVisibleChatsByBotAndUser :many +WITH participant_chats AS ( + SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, + 'participant'::text AS access_mode, + cp.role AS participant_role, + NULL::timestamptz AS last_observed_at + FROM chats c + JOIN chat_participants cp ON cp.chat_id = c.id + WHERE c.bot_id = $1 AND cp.user_id = $2 +), +observed_chats AS ( + SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, + 'channel_identity_observed'::text AS access_mode, + ''::text AS participant_role, + MAX(cap.last_seen_at) AS last_observed_at + FROM chats c + JOIN chat_channel_identity_presence cap ON cap.chat_id = c.id + JOIN channel_identities ci ON ci.id = cap.channel_identity_id + WHERE c.bot_id = $1 + AND ci.user_id = $2 + AND NOT EXISTS ( + SELECT 1 FROM chat_participants cp + WHERE cp.chat_id = c.id AND cp.user_id = $2 + ) + GROUP BY c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at +) +SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, + access_mode, participant_role, last_observed_at +FROM ( + SELECT * FROM participant_chats + UNION ALL + SELECT * FROM observed_chats +) v +ORDER BY v.updated_at DESC, v.last_observed_at DESC NULLS LAST; + +-- name: GetChatReadAccessByUser :one +WITH participant_access AS ( + SELECT 'participant'::text AS access_mode, + cp.role AS participant_role, + NULL::timestamptz AS last_observed_at + FROM chat_participants cp + WHERE cp.chat_id = $1 AND cp.user_id = $2 +), +observed_access AS ( + SELECT 'channel_identity_observed'::text AS access_mode, + ''::text AS participant_role, + MAX(cap.last_seen_at) AS last_observed_at + FROM chat_channel_identity_presence cap + JOIN channel_identities ci ON ci.id = cap.channel_identity_id + WHERE cap.chat_id = $1 AND ci.user_id = $2 + GROUP BY cap.chat_id +), +all_access AS ( + SELECT * FROM participant_access + UNION ALL + SELECT * FROM observed_access +) +SELECT access_mode, participant_role, last_observed_at +FROM all_access +ORDER BY CASE WHEN access_mode = 'participant' THEN 0 ELSE 1 END, last_observed_at DESC NULLS LAST +LIMIT 1; + +-- name: ListThreadsByParent :many +SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at +FROM chats +WHERE parent_chat_id = $1 AND kind = 'thread' +ORDER BY created_at DESC; + +-- name: UpdateChatTitle :one +UPDATE chats SET title = $2, updated_at = now() +WHERE id = $1 +RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at; + +-- name: TouchChat :exec +UPDATE chats SET updated_at = now() WHERE id = $1; + +-- name: DeleteChat :exec +DELETE FROM chats WHERE id = $1; + +-- chat_participants + +-- name: AddChatParticipant :one +INSERT INTO chat_participants (chat_id, user_id, role) +VALUES ($1, $2, $3) +ON CONFLICT (chat_id, user_id) DO UPDATE SET role = EXCLUDED.role +RETURNING chat_id, user_id, role, joined_at; + +-- name: GetChatParticipant :one +SELECT chat_id, user_id, role, joined_at +FROM chat_participants +WHERE chat_id = $1 AND user_id = $2; + +-- name: ListChatParticipants :many +SELECT chat_id, user_id, role, joined_at +FROM chat_participants +WHERE chat_id = $1 +ORDER BY joined_at ASC; + +-- name: RemoveChatParticipant :exec +DELETE FROM chat_participants WHERE chat_id = $1 AND user_id = $2; + +-- name: CopyParticipantsToChat :exec +INSERT INTO chat_participants (chat_id, user_id, role) +SELECT $2, cp.user_id, cp.role FROM chat_participants cp WHERE cp.chat_id = $1 +ON CONFLICT (chat_id, user_id) DO NOTHING; + +-- chat_settings + +-- name: UpsertChatSettings :one +UPDATE chats +SET enable_chat_memory = $2, + enable_private_memory = $3, + enable_public_memory = $4, + model_id = $5, + settings_metadata = $6 +WHERE id = $1 +RETURNING id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at; + +-- name: GetChatSettings :one +SELECT id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at +FROM chats +WHERE id = $1; + +-- chat_routes + +-- name: CreateChatRoute :one +INSERT INTO chat_routes (chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +RETURNING id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at; + +-- name: FindChatRoute :one +SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at +FROM chat_routes +WHERE bot_id = $1 AND platform = $2 AND conversation_id = $3 + AND COALESCE(thread_id, '') = COALESCE(sqlc.narg('thread_id'), '') +LIMIT 1; + +-- name: GetChatRouteByID :one +SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at +FROM chat_routes +WHERE id = $1; + +-- name: ListChatRoutes :many +SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at +FROM chat_routes +WHERE chat_id = $1 +ORDER BY created_at ASC; + +-- name: UpdateChatRouteReplyTarget :exec +UPDATE chat_routes SET reply_target = $2, updated_at = now() WHERE id = $1; + +-- name: DeleteChatRoute :exec +DELETE FROM chat_routes WHERE id = $1; + +-- chat_messages + +-- name: CreateChatMessage :one +INSERT INTO chat_messages (chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) +RETURNING id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at; + +-- name: UpsertChatChannelIdentityPresence :exec +INSERT INTO chat_channel_identity_presence (chat_id, channel_identity_id, first_seen_at, last_seen_at, message_count) +VALUES ($1, $2, now(), now(), 1) +ON CONFLICT (chat_id, channel_identity_id) +DO UPDATE SET + last_seen_at = now(), + message_count = chat_channel_identity_presence.message_count + 1; + +-- name: ListChatMessages :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 +ORDER BY created_at ASC; + +-- name: ListChatMessagesSince :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 AND created_at >= $2 +ORDER BY created_at ASC; + +-- name: ListChatMessagesBefore :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 AND created_at < $2 +ORDER BY created_at DESC +LIMIT $3; + +-- name: ListChatMessagesLatest :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 +ORDER BY created_at DESC +LIMIT $2; + +-- name: DeleteChatMessagesByChat :exec +DELETE FROM chat_messages WHERE chat_id = $1; diff --git a/db/queries/contacts.sql b/db/queries/contacts.sql deleted file mode 100644 index 7f5d9fe8..00000000 --- a/db/queries/contacts.sql +++ /dev/null @@ -1,76 +0,0 @@ --- name: CreateContact :one -INSERT INTO contacts (bot_id, user_id, display_name, alias, tags, status, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; - --- name: GetContactByID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE id = $1 -LIMIT 1; - --- name: GetContactByUserID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 AND user_id = $2 -LIMIT 1; - --- name: ListContactsByBot :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 -ORDER BY created_at DESC; - --- name: SearchContacts :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 - AND ( - display_name ILIKE sqlc.arg(query) - OR alias ILIKE sqlc.arg(query) - OR EXISTS ( - SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE sqlc.arg(query) - ) - ) -ORDER BY created_at DESC; - --- name: UpdateContact :one -UPDATE contacts -SET display_name = COALESCE(sqlc.narg(display_name), display_name), - alias = COALESCE(sqlc.narg(alias), alias), - tags = COALESCE(sqlc.narg(tags), tags), - status = COALESCE(NULLIF(sqlc.arg(status)::text, ''), status), - metadata = COALESCE(sqlc.narg(metadata), metadata), - updated_at = now() -WHERE id = sqlc.arg(id) -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; - --- name: UpdateContactUser :one -UPDATE contacts -SET user_id = $2, - updated_at = now() -WHERE id = $1 -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; - --- name: UpsertContactChannel :one -INSERT INTO contact_channels (bot_id, contact_id, platform, external_id, metadata) -VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (bot_id, platform, external_id) -DO UPDATE SET - contact_id = EXCLUDED.contact_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at; - --- name: GetContactChannelByIdentity :one -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE bot_id = $1 AND platform = $2 AND external_id = $3 -LIMIT 1; - --- name: ListContactChannelsByContact :many -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE contact_id = $1 -ORDER BY created_at DESC; - diff --git a/db/queries/history.sql b/db/queries/history.sql deleted file mode 100644 index 7e95576c..00000000 --- a/db/queries/history.sql +++ /dev/null @@ -1,31 +0,0 @@ --- name: CreateHistory :one -INSERT INTO history (bot_id, session_id, messages, metadata, skills, timestamp) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, bot_id, session_id, messages, metadata, skills, timestamp; - --- name: ListHistoryByBotSessionSince :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 AND timestamp >= $3 -ORDER BY timestamp ASC; - --- name: GetHistoryByID :one -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE id = $1; - --- name: ListHistoryByBotSession :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 -ORDER BY timestamp DESC -LIMIT $3; - --- name: DeleteHistoryByID :exec -DELETE FROM history -WHERE id = $1; - --- name: DeleteHistoryByBotSession :exec -DELETE FROM history -WHERE bot_id = $1 AND session_id = $2; - diff --git a/db/queries/settings.sql b/db/queries/settings.sql index 4f35be30..926fc0b4 100644 --- a/db/queries/settings.sql +++ b/db/queries/settings.sql @@ -1,55 +1,64 @@ -- name: GetSettingsByUserID :one -SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language -FROM user_settings -WHERE user_id = $1; +SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language +FROM users +WHERE id = $1; -- name: UpsertUserSettings :one -INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language) -VALUES ($1, $2, $3, $4, $5, $6) -ON CONFLICT (user_id) DO UPDATE SET - chat_model_id = EXCLUDED.chat_model_id, - memory_model_id = EXCLUDED.memory_model_id, - embedding_model_id = EXCLUDED.embedding_model_id, - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language -RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language; +UPDATE users +SET chat_model_id = $2, + memory_model_id = $3, + embedding_model_id = $4, + max_context_load_time = $5, + language = $6, + updated_at = now() +WHERE id = $1 +RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language; -- name: GetSettingsByBotID :one -SELECT bot_id, max_context_load_time, language, allow_guest -FROM bot_settings -WHERE bot_id = $1; - --- name: GetBotModelConfigByBotID :one SELECT - bot_model_configs.bot_id, + bots.id AS bot_id, + bots.max_context_load_time, + bots.language, + bots.allow_guest, chat_models.model_id AS chat_model_id, memory_models.model_id AS memory_model_id, embedding_models.model_id AS embedding_model_id -FROM bot_model_configs -LEFT JOIN models AS chat_models ON chat_models.id = bot_model_configs.chat_model_id -LEFT JOIN models AS memory_models ON memory_models.id = bot_model_configs.memory_model_id -LEFT JOIN models AS embedding_models ON embedding_models.id = bot_model_configs.embedding_model_id -WHERE bot_model_configs.bot_id = $1; +FROM bots +LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = bots.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = bots.embedding_model_id +WHERE bots.id = $1; -- name: UpsertBotSettings :one -INSERT INTO bot_settings (bot_id, max_context_load_time, language, allow_guest) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language, - allow_guest = EXCLUDED.allow_guest -RETURNING bot_id, max_context_load_time, language, allow_guest; - --- name: UpsertBotModelConfig :one -INSERT INTO bot_model_configs (bot_id, chat_model_id, memory_model_id, embedding_model_id) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - chat_model_id = COALESCE(EXCLUDED.chat_model_id, bot_model_configs.chat_model_id), - memory_model_id = COALESCE(EXCLUDED.memory_model_id, bot_model_configs.memory_model_id), - embedding_model_id = COALESCE(EXCLUDED.embedding_model_id, bot_model_configs.embedding_model_id) -RETURNING bot_id, chat_model_id, memory_model_id, embedding_model_id; +WITH updated AS ( + UPDATE bots + SET max_context_load_time = sqlc.arg(max_context_load_time), + language = sqlc.arg(language), + allow_guest = sqlc.arg(allow_guest), + chat_model_id = COALESCE(sqlc.narg(chat_model_id)::uuid, bots.chat_model_id), + memory_model_id = COALESCE(sqlc.narg(memory_model_id)::uuid, bots.memory_model_id), + embedding_model_id = COALESCE(sqlc.narg(embedding_model_id)::uuid, bots.embedding_model_id), + updated_at = now() + WHERE bots.id = sqlc.arg(id) + RETURNING bots.id, bots.max_context_load_time, bots.language, bots.allow_guest, bots.chat_model_id, bots.memory_model_id, bots.embedding_model_id +) +SELECT + updated.id AS bot_id, + updated.max_context_load_time, + updated.language, + updated.allow_guest, + chat_models.model_id AS chat_model_id, + memory_models.model_id AS memory_model_id, + embedding_models.model_id AS embedding_model_id +FROM updated +LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = updated.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = updated.embedding_model_id; -- name: DeleteSettingsByBotID :exec -DELETE FROM bot_settings -WHERE bot_id = $1; - +UPDATE bots +SET max_context_load_time = 1440, + language = 'auto', + allow_guest = false, + updated_at = now() +WHERE id = $1; diff --git a/db/queries/users.sql b/db/queries/users.sql index 6506d935..87dc6e50 100644 --- a/db/queries/users.sql +++ b/db/queries/users.sql @@ -1,20 +1,38 @@ -- name: CreateUser :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - sqlc.arg(username), - sqlc.arg(email), - sqlc.arg(password_hash), - sqlc.arg(role)::user_role, - sqlc.arg(display_name), - sqlc.arg(avatar_url), - sqlc.arg(is_active), - sqlc.arg(data_root) -) +INSERT INTO users (is_active, metadata) +VALUES ($1, $2) RETURNING *; --- name: UpsertUserByUsername :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) +-- name: GetUserByID :one +SELECT * +FROM users +WHERE id = $1; + +-- name: UpdateUserStatus :one +UPDATE users +SET is_active = $2, + updated_at = now() +WHERE id = $1 +RETURNING *; + +-- name: CreateAccount :one +UPDATE users +SET username = sqlc.arg(username), + email = sqlc.arg(email), + password_hash = sqlc.arg(password_hash), + role = sqlc.arg(role)::user_role, + display_name = sqlc.arg(display_name), + avatar_url = sqlc.arg(avatar_url), + is_active = sqlc.arg(is_active), + data_root = sqlc.arg(data_root), + updated_at = now() +WHERE id = sqlc.arg(user_id) +RETURNING *; + +-- name: UpsertAccountByUsername :one +INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, metadata) VALUES ( + sqlc.arg(user_id), sqlc.arg(username), sqlc.arg(email), sqlc.arg(password_hash), @@ -22,7 +40,8 @@ VALUES ( sqlc.arg(display_name), sqlc.arg(avatar_url), sqlc.arg(is_active), - sqlc.arg(data_root) + sqlc.arg(data_root), + '{}'::jsonb ) ON CONFLICT (username) DO UPDATE SET email = EXCLUDED.email, @@ -35,39 +54,27 @@ ON CONFLICT (username) DO UPDATE SET updated_at = now() RETURNING *; --- name: GetUserByUsername :one +-- name: GetAccountByUsername :one SELECT * FROM users WHERE username = sqlc.arg(username); --- name: GetUserByIdentity :one +-- name: GetAccountByIdentity :one SELECT * FROM users WHERE username = sqlc.arg(identity) OR email = sqlc.arg(identity); --- name: GetUserByID :one -SELECT * FROM users WHERE id = sqlc.arg(id); +-- name: GetAccountByUserID :one +SELECT * FROM users WHERE id = sqlc.arg(user_id); --- name: CreateUserWithID :one -INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - sqlc.arg(id), - sqlc.arg(username), - sqlc.arg(email), - sqlc.arg(password_hash), - sqlc.arg(role)::user_role, - sqlc.arg(display_name), - sqlc.arg(avatar_url), - sqlc.arg(is_active), - sqlc.arg(data_root) -) -RETURNING *; +-- name: CountAccounts :one +SELECT COUNT(*)::bigint AS count +FROM users +WHERE username IS NOT NULL + AND password_hash IS NOT NULL; --- name: CountUsers :one -SELECT COUNT(*)::bigint AS count FROM users; - --- name: ListUsers :many +-- name: ListAccounts :many SELECT * FROM users +WHERE username IS NOT NULL ORDER BY created_at DESC; - --- name: UpdateUserProfile :one +-- name: UpdateAccountProfile :one UPDATE users SET display_name = $2, avatar_url = $3, @@ -76,27 +83,26 @@ SET display_name = $2, WHERE id = $1 RETURNING *; --- name: UpdateUserAdmin :one +-- name: UpdateAccountAdmin :one UPDATE users SET role = sqlc.arg(role)::user_role, display_name = sqlc.arg(display_name), avatar_url = sqlc.arg(avatar_url), is_active = sqlc.arg(is_active), updated_at = now() -WHERE id = sqlc.arg(id) +WHERE id = sqlc.arg(user_id) RETURNING *; --- name: UpdateUserPassword :one +-- name: UpdateAccountPassword :one UPDATE users SET password_hash = $2, updated_at = now() WHERE id = $1 RETURNING *; --- name: UpdateUserLastLogin :one +-- name: UpdateAccountLastLogin :one UPDATE users SET last_login_at = now(), updated_at = now() WHERE id = $1 RETURNING *; - diff --git a/docs/docs.go b/docs/docs.go index 9187bec8..b8bd4e5f 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -148,98 +148,6 @@ const docTemplate = `{ } } }, - "/bots/{bot_id}/chat": { - "post": { - "description": "Send a chat message and get a response. The system will automatically select an appropriate chat model from the database.", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "chat" - ], - "summary": "Chat with AI", - "parameters": [ - { - "description": "Chat request", - "name": "request", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/chat.ChatRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/chat.ChatResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/chat/stream": { - "post": { - "description": "Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database.", - "consumes": [ - "application/json" - ], - "produces": [ - "text/event-stream" - ], - "tags": [ - "chat" - ], - "summary": "Stream chat with AI", - "parameters": [ - { - "description": "Chat request", - "name": "request", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/chat.ChatRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "type": "string" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/bots/{bot_id}/container": { "get": { "tags": [ @@ -1128,188 +1036,6 @@ const docTemplate = `{ } } }, - "/bots/{bot_id}/history": { - "get": { - "description": "List history records for current user", - "tags": [ - "history" - ], - "summary": "List history records", - "parameters": [ - { - "type": "integer", - "description": "Limit", - "name": "limit", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/history.ListResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "post": { - "description": "Create a history record for current user", - "tags": [ - "history" - ], - "summary": "Create history record", - "parameters": [ - { - "description": "History payload", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/history.CreateRequest" - } - } - ], - "responses": { - "201": { - "description": "Created", - "schema": { - "$ref": "#/definitions/history.Record" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete all history records for current user", - "tags": [ - "history" - ], - "summary": "Delete all history records", - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/history/{id}": { - "get": { - "description": "Get a history record by ID (must belong to current user)", - "tags": [ - "history" - ], - "summary": "Get history record", - "parameters": [ - { - "type": "string", - "description": "History ID", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/history.Record" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete a history record by ID (must belong to current user)", - "tags": [ - "history" - ], - "summary": "Delete history record", - "parameters": [ - { - "type": "string", - "description": "History ID", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/bots/{bot_id}/mcp": { "get": { "description": "List MCP connections for a bot", @@ -1666,321 +1392,6 @@ const docTemplate = `{ } } }, - "/bots/{bot_id}/memory/add": { - "post": { - "description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Add memory", - "parameters": [ - { - "description": "Add request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memoryAddPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.SearchResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/embed": { - "post": { - "description": "Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Embed and upsert memory", - "parameters": [ - { - "description": "Embed upsert request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memoryEmbedUpsertPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.EmbedUpsertResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/memories": { - "get": { - "description": "List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "List memories", - "parameters": [ - { - "type": "string", - "description": "Run ID", - "name": "run_id", - "in": "query" - }, - { - "type": "integer", - "description": "Limit", - "name": "limit", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.SearchResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Delete memories", - "parameters": [ - { - "description": "Delete all request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memoryDeleteAllPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.DeleteResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/memories/{memoryId}": { - "get": { - "description": "Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Get memory", - "parameters": [ - { - "type": "string", - "description": "Memory ID", - "name": "memoryId", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.MemoryItem" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Delete memory", - "parameters": [ - { - "type": "string", - "description": "Memory ID", - "name": "memoryId", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.DeleteResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/search": { - "post": { - "description": "Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Search memories", - "parameters": [ - { - "description": "Search request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memorySearchPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.SearchResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/update": { - "post": { - "description": "Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Update memory", - "parameters": [ - { - "description": "Update request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/memory.UpdateRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.MemoryItem" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/bots/{bot_id}/schedule": { "get": { "description": "List schedules for current user", @@ -3083,7 +2494,7 @@ const docTemplate = `{ } } }, - "/bots/{id}/channel/{platform}/send_session": { + "/bots/{id}/channel/{platform}/send_chat": { "post": { "description": "Send a message using a session-scoped token (reply only)", "tags": [ @@ -4316,7 +3727,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.ListUsersResponse" + "$ref": "#/definitions/accounts.ListAccountsResponse" } }, "400": { @@ -4352,7 +3763,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.CreateUserRequest" + "$ref": "#/definitions/accounts.CreateAccountRequest" } } ], @@ -4360,7 +3771,7 @@ const docTemplate = `{ "201": { "description": "Created", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4395,7 +3806,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4425,7 +3836,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.UpdateProfileRequest" + "$ref": "#/definitions/accounts.UpdateProfileRequest" } } ], @@ -4433,7 +3844,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4471,7 +3882,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelUserBinding" + "$ref": "#/definitions/channel.ChannelIdentityBinding" } }, "400": { @@ -4514,7 +3925,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/channel.UpsertUserConfigRequest" + "$ref": "#/definitions/channel.UpsertChannelIdentityConfigRequest" } } ], @@ -4522,7 +3933,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelUserBinding" + "$ref": "#/definitions/channel.ChannelIdentityBinding" } }, "400": { @@ -4540,6 +3951,41 @@ const docTemplate = `{ } } }, + "/users/me/identities": { + "get": { + "description": "List all channel identities linked to current user", + "tags": [ + "users" + ], + "summary": "List current user's channel identities", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.listMyIdentitiesResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/users/me/password": { "put": { "description": "Update current user password with current password check", @@ -4554,7 +4000,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.UpdatePasswordRequest" + "$ref": "#/definitions/accounts.UpdatePasswordRequest" } } ], @@ -4597,7 +4043,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4646,7 +4092,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.UpdateUserRequest" + "$ref": "#/definitions/accounts.UpdateAccountRequest" } } ], @@ -4654,7 +4100,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4705,7 +4151,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.ResetPasswordRequest" + "$ref": "#/definitions/accounts.ResetPasswordRequest" } } ], @@ -4742,6 +4188,125 @@ const docTemplate = `{ } }, "definitions": { + "accounts.Account": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "email": { + "type": "string" + }, + "id": { + "type": "string" + }, + "is_active": { + "type": "boolean" + }, + "last_login_at": { + "type": "string" + }, + "role": { + "type": "string" + }, + "updated_at": { + "type": "string" + }, + "username": { + "type": "string" + } + } + }, + "accounts.CreateAccountRequest": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "email": { + "type": "string" + }, + "is_active": { + "type": "boolean" + }, + "password": { + "type": "string" + }, + "role": { + "type": "string" + }, + "username": { + "type": "string" + } + } + }, + "accounts.ListAccountsResponse": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "$ref": "#/definitions/accounts.Account" + } + } + } + }, + "accounts.ResetPasswordRequest": { + "type": "object", + "properties": { + "new_password": { + "type": "string" + } + } + }, + "accounts.UpdateAccountRequest": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "is_active": { + "type": "boolean" + }, + "role": { + "type": "string" + } + } + }, + "accounts.UpdatePasswordRequest": { + "type": "object", + "properties": { + "current_password": { + "type": "string" + }, + "new_password": { + "type": "string" + } + } + }, + "accounts.UpdateProfileRequest": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "display_name": { + "type": "string" + } + } + }, "bots.Bot": { "type": "object", "properties": { @@ -5044,9 +4609,12 @@ const docTemplate = `{ } } }, - "channel.ChannelUserBinding": { + "channel.ChannelIdentityBinding": { "type": "object", "properties": { + "channelIdentityID": { + "type": "string" + }, "channelType": { "type": "string" }, @@ -5062,9 +4630,6 @@ const docTemplate = `{ }, "updatedAt": { "type": "string" - }, - "userID": { - "type": "string" } } }, @@ -5181,6 +4746,9 @@ const docTemplate = `{ "channel.MessagePart": { "type": "object", "properties": { + "channel_identity_id": { + "type": "string" + }, "emoji": { "type": "string" }, @@ -5205,9 +4773,6 @@ const docTemplate = `{ }, "url": { "type": "string" - }, - "user_id": { - "type": "string" } } }, @@ -5257,14 +4822,14 @@ const docTemplate = `{ "channel.SendRequest": { "type": "object", "properties": { + "channel_identity_id": { + "type": "string" + }, "message": { "$ref": "#/definitions/channel.Message" }, "target": { "type": "string" - }, - "user_id": { - "type": "string" } } }, @@ -5301,6 +4866,15 @@ const docTemplate = `{ } } }, + "channel.UpsertChannelIdentityConfigRequest": { + "type": "object", + "properties": { + "config": { + "type": "object", + "additionalProperties": {} + } + } + }, "channel.UpsertConfigRequest": { "type": "object", "properties": { @@ -5327,132 +4901,32 @@ const docTemplate = `{ } } }, - "channel.UpsertUserConfigRequest": { + "channelidentities.ChannelIdentity": { "type": "object", "properties": { - "config": { - "type": "object", - "additionalProperties": {} - } - } - }, - "chat.ChatRequest": { - "type": "object", - "properties": { - "allowed_actions": { - "type": "array", - "items": { - "type": "string" - } - }, - "channels": { - "type": "array", - "items": { - "type": "string" - } - }, - "current_channel": { + "channel": { "type": "string" }, - "language": { + "channel_subject_id": { "type": "string" }, - "max_context_load_time": { - "type": "integer" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.ModelMessage" - } - }, - "model": { + "created_at": { "type": "string" }, - "provider": { + "display_name": { "type": "string" }, - "query": { - "type": "string" - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, - "chat.ChatResponse": { - "type": "object", - "properties": { - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.ModelMessage" - } - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, - "chat.ModelMessage": { - "type": "object", - "properties": { - "content": { - "type": "array", - "items": { - "type": "integer" - } - }, - "name": { - "type": "string" - }, - "role": { - "type": "string" - }, - "tool_call_id": { - "type": "string" - }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.ToolCall" - } - } - } - }, - "chat.ToolCall": { - "type": "object", - "properties": { - "function": { - "$ref": "#/definitions/chat.ToolCallFunction" - }, "id": { "type": "string" }, - "type": { - "type": "string" - } - } - }, - "chat.ToolCallFunction": { - "type": "object", - "properties": { - "arguments": { + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "updated_at": { "type": "string" }, - "name": { + "user_id": { "type": "string" } } @@ -5997,103 +5471,20 @@ const docTemplate = `{ } } }, - "handlers.memoryAddPayload": { + "handlers.listMyIdentitiesResponse": { "type": "object", "properties": { - "embedding_enabled": { - "type": "boolean" - }, - "filters": { - "type": "object", - "additionalProperties": {} - }, - "infer": { - "type": "boolean" - }, - "message": { - "type": "string" - }, - "messages": { + "items": { "type": "array", "items": { - "$ref": "#/definitions/memory.Message" + "$ref": "#/definitions/channelidentities.ChannelIdentity" } }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "run_id": { + "user_id": { "type": "string" } } }, - "handlers.memoryDeleteAllPayload": { - "type": "object", - "properties": { - "run_id": { - "type": "string" - } - } - }, - "handlers.memoryEmbedUpsertPayload": { - "type": "object", - "properties": { - "filters": { - "type": "object", - "additionalProperties": {} - }, - "input": { - "$ref": "#/definitions/memory.EmbedInput" - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "source": { - "type": "string" - }, - "type": { - "type": "string" - } - } - }, - "handlers.memorySearchPayload": { - "type": "object", - "properties": { - "embedding_enabled": { - "type": "boolean" - }, - "filters": { - "type": "object", - "additionalProperties": {} - }, - "limit": { - "type": "integer" - }, - "query": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "sources": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, "handlers.skillsOpResponse": { "type": "object", "properties": { @@ -6102,73 +5493,6 @@ const docTemplate = `{ } } }, - "history.CreateRequest": { - "type": "object", - "properties": { - "messages": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": {} - } - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, - "history.ListResponse": { - "type": "object", - "properties": { - "items": { - "type": "array", - "items": { - "$ref": "#/definitions/history.Record" - } - } - } - }, - "history.Record": { - "type": "object", - "properties": { - "bot_id": { - "type": "string" - }, - "id": { - "type": "string" - }, - "messages": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": {} - } - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "session_id": { - "type": "string" - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - }, - "timestamp": { - "type": "string" - } - } - }, "mcp.ListResponse": { "type": "object", "properties": { @@ -6198,124 +5522,6 @@ const docTemplate = `{ } } }, - "memory.DeleteResponse": { - "type": "object", - "properties": { - "message": { - "type": "string" - } - } - }, - "memory.EmbedInput": { - "type": "object", - "properties": { - "image_url": { - "type": "string" - }, - "text": { - "type": "string" - }, - "video_url": { - "type": "string" - } - } - }, - "memory.EmbedUpsertResponse": { - "type": "object", - "properties": { - "dimensions": { - "type": "integer" - }, - "item": { - "$ref": "#/definitions/memory.MemoryItem" - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - } - } - }, - "memory.MemoryItem": { - "type": "object", - "properties": { - "agentId": { - "type": "string" - }, - "botId": { - "type": "string" - }, - "createdAt": { - "type": "string" - }, - "hash": { - "type": "string" - }, - "id": { - "type": "string" - }, - "memory": { - "type": "string" - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "runId": { - "type": "string" - }, - "score": { - "type": "number" - }, - "sessionId": { - "type": "string" - }, - "updatedAt": { - "type": "string" - } - } - }, - "memory.Message": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "role": { - "type": "string" - } - } - }, - "memory.SearchResponse": { - "type": "object", - "properties": { - "relations": { - "type": "array", - "items": {} - }, - "results": { - "type": "array", - "items": { - "$ref": "#/definitions/memory.MemoryItem" - } - } - } - }, - "memory.UpdateRequest": { - "type": "object", - "properties": { - "embedding_enabled": { - "type": "boolean" - }, - "memory": { - "type": "string" - }, - "memory_id": { - "type": "string" - } - } - }, "models.AddRequest": { "type": "object", "properties": { @@ -6844,125 +6050,6 @@ const docTemplate = `{ } } } - }, - "users.CreateUserRequest": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "display_name": { - "type": "string" - }, - "email": { - "type": "string" - }, - "is_active": { - "type": "boolean" - }, - "password": { - "type": "string" - }, - "role": { - "type": "string" - }, - "username": { - "type": "string" - } - } - }, - "users.ListUsersResponse": { - "type": "object", - "properties": { - "items": { - "type": "array", - "items": { - "$ref": "#/definitions/users.User" - } - } - } - }, - "users.ResetPasswordRequest": { - "type": "object", - "properties": { - "new_password": { - "type": "string" - } - } - }, - "users.UpdatePasswordRequest": { - "type": "object", - "properties": { - "current_password": { - "type": "string" - }, - "new_password": { - "type": "string" - } - } - }, - "users.UpdateProfileRequest": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "display_name": { - "type": "string" - } - } - }, - "users.UpdateUserRequest": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "display_name": { - "type": "string" - }, - "is_active": { - "type": "boolean" - }, - "role": { - "type": "string" - } - } - }, - "users.User": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "created_at": { - "type": "string" - }, - "display_name": { - "type": "string" - }, - "email": { - "type": "string" - }, - "id": { - "type": "string" - }, - "is_active": { - "type": "boolean" - }, - "last_login_at": { - "type": "string" - }, - "role": { - "type": "string" - }, - "updated_at": { - "type": "string" - }, - "username": { - "type": "string" - } - } } } }` diff --git a/docs/swagger.json b/docs/swagger.json index 57d5d4ba..465dd098 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -139,98 +139,6 @@ } } }, - "/bots/{bot_id}/chat": { - "post": { - "description": "Send a chat message and get a response. The system will automatically select an appropriate chat model from the database.", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "chat" - ], - "summary": "Chat with AI", - "parameters": [ - { - "description": "Chat request", - "name": "request", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/chat.ChatRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/chat.ChatResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/chat/stream": { - "post": { - "description": "Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database.", - "consumes": [ - "application/json" - ], - "produces": [ - "text/event-stream" - ], - "tags": [ - "chat" - ], - "summary": "Stream chat with AI", - "parameters": [ - { - "description": "Chat request", - "name": "request", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/chat.ChatRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "type": "string" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/bots/{bot_id}/container": { "get": { "tags": [ @@ -1119,188 +1027,6 @@ } } }, - "/bots/{bot_id}/history": { - "get": { - "description": "List history records for current user", - "tags": [ - "history" - ], - "summary": "List history records", - "parameters": [ - { - "type": "integer", - "description": "Limit", - "name": "limit", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/history.ListResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "post": { - "description": "Create a history record for current user", - "tags": [ - "history" - ], - "summary": "Create history record", - "parameters": [ - { - "description": "History payload", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/history.CreateRequest" - } - } - ], - "responses": { - "201": { - "description": "Created", - "schema": { - "$ref": "#/definitions/history.Record" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete all history records for current user", - "tags": [ - "history" - ], - "summary": "Delete all history records", - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/history/{id}": { - "get": { - "description": "Get a history record by ID (must belong to current user)", - "tags": [ - "history" - ], - "summary": "Get history record", - "parameters": [ - { - "type": "string", - "description": "History ID", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/history.Record" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete a history record by ID (must belong to current user)", - "tags": [ - "history" - ], - "summary": "Delete history record", - "parameters": [ - { - "type": "string", - "description": "History ID", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/bots/{bot_id}/mcp": { "get": { "description": "List MCP connections for a bot", @@ -1657,321 +1383,6 @@ } } }, - "/bots/{bot_id}/memory/add": { - "post": { - "description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Add memory", - "parameters": [ - { - "description": "Add request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memoryAddPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.SearchResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/embed": { - "post": { - "description": "Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Embed and upsert memory", - "parameters": [ - { - "description": "Embed upsert request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memoryEmbedUpsertPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.EmbedUpsertResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/memories": { - "get": { - "description": "List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "List memories", - "parameters": [ - { - "type": "string", - "description": "Run ID", - "name": "run_id", - "in": "query" - }, - { - "type": "integer", - "description": "Limit", - "name": "limit", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.SearchResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Delete memories", - "parameters": [ - { - "description": "Delete all request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memoryDeleteAllPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.DeleteResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/memories/{memoryId}": { - "get": { - "description": "Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Get memory", - "parameters": [ - { - "type": "string", - "description": "Memory ID", - "name": "memoryId", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.MemoryItem" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "delete": { - "description": "Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Delete memory", - "parameters": [ - { - "type": "string", - "description": "Memory ID", - "name": "memoryId", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.DeleteResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/search": { - "post": { - "description": "Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Search memories", - "parameters": [ - { - "description": "Search request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.memorySearchPayload" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.SearchResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/bots/{bot_id}/memory/update": { - "post": { - "description": "Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", - "tags": [ - "memory" - ], - "summary": "Update memory", - "parameters": [ - { - "description": "Update request", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/memory.UpdateRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/memory.MemoryItem" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/bots/{bot_id}/schedule": { "get": { "description": "List schedules for current user", @@ -3074,7 +2485,7 @@ } } }, - "/bots/{id}/channel/{platform}/send_session": { + "/bots/{id}/channel/{platform}/send_chat": { "post": { "description": "Send a message using a session-scoped token (reply only)", "tags": [ @@ -4307,7 +3718,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.ListUsersResponse" + "$ref": "#/definitions/accounts.ListAccountsResponse" } }, "400": { @@ -4343,7 +3754,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.CreateUserRequest" + "$ref": "#/definitions/accounts.CreateAccountRequest" } } ], @@ -4351,7 +3762,7 @@ "201": { "description": "Created", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4386,7 +3797,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4416,7 +3827,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.UpdateProfileRequest" + "$ref": "#/definitions/accounts.UpdateProfileRequest" } } ], @@ -4424,7 +3835,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4462,7 +3873,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelUserBinding" + "$ref": "#/definitions/channel.ChannelIdentityBinding" } }, "400": { @@ -4505,7 +3916,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/channel.UpsertUserConfigRequest" + "$ref": "#/definitions/channel.UpsertChannelIdentityConfigRequest" } } ], @@ -4513,7 +3924,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelUserBinding" + "$ref": "#/definitions/channel.ChannelIdentityBinding" } }, "400": { @@ -4531,6 +3942,41 @@ } } }, + "/users/me/identities": { + "get": { + "description": "List all channel identities linked to current user", + "tags": [ + "users" + ], + "summary": "List current user's channel identities", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.listMyIdentitiesResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/users/me/password": { "put": { "description": "Update current user password with current password check", @@ -4545,7 +3991,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.UpdatePasswordRequest" + "$ref": "#/definitions/accounts.UpdatePasswordRequest" } } ], @@ -4588,7 +4034,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4637,7 +4083,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.UpdateUserRequest" + "$ref": "#/definitions/accounts.UpdateAccountRequest" } } ], @@ -4645,7 +4091,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/users.User" + "$ref": "#/definitions/accounts.Account" } }, "400": { @@ -4696,7 +4142,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/users.ResetPasswordRequest" + "$ref": "#/definitions/accounts.ResetPasswordRequest" } } ], @@ -4733,6 +4179,125 @@ } }, "definitions": { + "accounts.Account": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "email": { + "type": "string" + }, + "id": { + "type": "string" + }, + "is_active": { + "type": "boolean" + }, + "last_login_at": { + "type": "string" + }, + "role": { + "type": "string" + }, + "updated_at": { + "type": "string" + }, + "username": { + "type": "string" + } + } + }, + "accounts.CreateAccountRequest": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "email": { + "type": "string" + }, + "is_active": { + "type": "boolean" + }, + "password": { + "type": "string" + }, + "role": { + "type": "string" + }, + "username": { + "type": "string" + } + } + }, + "accounts.ListAccountsResponse": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "$ref": "#/definitions/accounts.Account" + } + } + } + }, + "accounts.ResetPasswordRequest": { + "type": "object", + "properties": { + "new_password": { + "type": "string" + } + } + }, + "accounts.UpdateAccountRequest": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "is_active": { + "type": "boolean" + }, + "role": { + "type": "string" + } + } + }, + "accounts.UpdatePasswordRequest": { + "type": "object", + "properties": { + "current_password": { + "type": "string" + }, + "new_password": { + "type": "string" + } + } + }, + "accounts.UpdateProfileRequest": { + "type": "object", + "properties": { + "avatar_url": { + "type": "string" + }, + "display_name": { + "type": "string" + } + } + }, "bots.Bot": { "type": "object", "properties": { @@ -5035,9 +4600,12 @@ } } }, - "channel.ChannelUserBinding": { + "channel.ChannelIdentityBinding": { "type": "object", "properties": { + "channelIdentityID": { + "type": "string" + }, "channelType": { "type": "string" }, @@ -5053,9 +4621,6 @@ }, "updatedAt": { "type": "string" - }, - "userID": { - "type": "string" } } }, @@ -5172,6 +4737,9 @@ "channel.MessagePart": { "type": "object", "properties": { + "channel_identity_id": { + "type": "string" + }, "emoji": { "type": "string" }, @@ -5196,9 +4764,6 @@ }, "url": { "type": "string" - }, - "user_id": { - "type": "string" } } }, @@ -5248,14 +4813,14 @@ "channel.SendRequest": { "type": "object", "properties": { + "channel_identity_id": { + "type": "string" + }, "message": { "$ref": "#/definitions/channel.Message" }, "target": { "type": "string" - }, - "user_id": { - "type": "string" } } }, @@ -5292,6 +4857,15 @@ } } }, + "channel.UpsertChannelIdentityConfigRequest": { + "type": "object", + "properties": { + "config": { + "type": "object", + "additionalProperties": {} + } + } + }, "channel.UpsertConfigRequest": { "type": "object", "properties": { @@ -5318,132 +4892,32 @@ } } }, - "channel.UpsertUserConfigRequest": { + "channelidentities.ChannelIdentity": { "type": "object", "properties": { - "config": { - "type": "object", - "additionalProperties": {} - } - } - }, - "chat.ChatRequest": { - "type": "object", - "properties": { - "allowed_actions": { - "type": "array", - "items": { - "type": "string" - } - }, - "channels": { - "type": "array", - "items": { - "type": "string" - } - }, - "current_channel": { + "channel": { "type": "string" }, - "language": { + "channel_subject_id": { "type": "string" }, - "max_context_load_time": { - "type": "integer" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.ModelMessage" - } - }, - "model": { + "created_at": { "type": "string" }, - "provider": { + "display_name": { "type": "string" }, - "query": { - "type": "string" - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, - "chat.ChatResponse": { - "type": "object", - "properties": { - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.ModelMessage" - } - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, - "chat.ModelMessage": { - "type": "object", - "properties": { - "content": { - "type": "array", - "items": { - "type": "integer" - } - }, - "name": { - "type": "string" - }, - "role": { - "type": "string" - }, - "tool_call_id": { - "type": "string" - }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.ToolCall" - } - } - } - }, - "chat.ToolCall": { - "type": "object", - "properties": { - "function": { - "$ref": "#/definitions/chat.ToolCallFunction" - }, "id": { "type": "string" }, - "type": { - "type": "string" - } - } - }, - "chat.ToolCallFunction": { - "type": "object", - "properties": { - "arguments": { + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "updated_at": { "type": "string" }, - "name": { + "user_id": { "type": "string" } } @@ -5988,103 +5462,20 @@ } } }, - "handlers.memoryAddPayload": { + "handlers.listMyIdentitiesResponse": { "type": "object", "properties": { - "embedding_enabled": { - "type": "boolean" - }, - "filters": { - "type": "object", - "additionalProperties": {} - }, - "infer": { - "type": "boolean" - }, - "message": { - "type": "string" - }, - "messages": { + "items": { "type": "array", "items": { - "$ref": "#/definitions/memory.Message" + "$ref": "#/definitions/channelidentities.ChannelIdentity" } }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "run_id": { + "user_id": { "type": "string" } } }, - "handlers.memoryDeleteAllPayload": { - "type": "object", - "properties": { - "run_id": { - "type": "string" - } - } - }, - "handlers.memoryEmbedUpsertPayload": { - "type": "object", - "properties": { - "filters": { - "type": "object", - "additionalProperties": {} - }, - "input": { - "$ref": "#/definitions/memory.EmbedInput" - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "source": { - "type": "string" - }, - "type": { - "type": "string" - } - } - }, - "handlers.memorySearchPayload": { - "type": "object", - "properties": { - "embedding_enabled": { - "type": "boolean" - }, - "filters": { - "type": "object", - "additionalProperties": {} - }, - "limit": { - "type": "integer" - }, - "query": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "sources": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, "handlers.skillsOpResponse": { "type": "object", "properties": { @@ -6093,73 +5484,6 @@ } } }, - "history.CreateRequest": { - "type": "object", - "properties": { - "messages": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": {} - } - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - } - } - }, - "history.ListResponse": { - "type": "object", - "properties": { - "items": { - "type": "array", - "items": { - "$ref": "#/definitions/history.Record" - } - } - } - }, - "history.Record": { - "type": "object", - "properties": { - "bot_id": { - "type": "string" - }, - "id": { - "type": "string" - }, - "messages": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": {} - } - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "session_id": { - "type": "string" - }, - "skills": { - "type": "array", - "items": { - "type": "string" - } - }, - "timestamp": { - "type": "string" - } - } - }, "mcp.ListResponse": { "type": "object", "properties": { @@ -6189,124 +5513,6 @@ } } }, - "memory.DeleteResponse": { - "type": "object", - "properties": { - "message": { - "type": "string" - } - } - }, - "memory.EmbedInput": { - "type": "object", - "properties": { - "image_url": { - "type": "string" - }, - "text": { - "type": "string" - }, - "video_url": { - "type": "string" - } - } - }, - "memory.EmbedUpsertResponse": { - "type": "object", - "properties": { - "dimensions": { - "type": "integer" - }, - "item": { - "$ref": "#/definitions/memory.MemoryItem" - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - } - } - }, - "memory.MemoryItem": { - "type": "object", - "properties": { - "agentId": { - "type": "string" - }, - "botId": { - "type": "string" - }, - "createdAt": { - "type": "string" - }, - "hash": { - "type": "string" - }, - "id": { - "type": "string" - }, - "memory": { - "type": "string" - }, - "metadata": { - "type": "object", - "additionalProperties": {} - }, - "runId": { - "type": "string" - }, - "score": { - "type": "number" - }, - "sessionId": { - "type": "string" - }, - "updatedAt": { - "type": "string" - } - } - }, - "memory.Message": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "role": { - "type": "string" - } - } - }, - "memory.SearchResponse": { - "type": "object", - "properties": { - "relations": { - "type": "array", - "items": {} - }, - "results": { - "type": "array", - "items": { - "$ref": "#/definitions/memory.MemoryItem" - } - } - } - }, - "memory.UpdateRequest": { - "type": "object", - "properties": { - "embedding_enabled": { - "type": "boolean" - }, - "memory": { - "type": "string" - }, - "memory_id": { - "type": "string" - } - } - }, "models.AddRequest": { "type": "object", "properties": { @@ -6835,125 +6041,6 @@ } } } - }, - "users.CreateUserRequest": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "display_name": { - "type": "string" - }, - "email": { - "type": "string" - }, - "is_active": { - "type": "boolean" - }, - "password": { - "type": "string" - }, - "role": { - "type": "string" - }, - "username": { - "type": "string" - } - } - }, - "users.ListUsersResponse": { - "type": "object", - "properties": { - "items": { - "type": "array", - "items": { - "$ref": "#/definitions/users.User" - } - } - } - }, - "users.ResetPasswordRequest": { - "type": "object", - "properties": { - "new_password": { - "type": "string" - } - } - }, - "users.UpdatePasswordRequest": { - "type": "object", - "properties": { - "current_password": { - "type": "string" - }, - "new_password": { - "type": "string" - } - } - }, - "users.UpdateProfileRequest": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "display_name": { - "type": "string" - } - } - }, - "users.UpdateUserRequest": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "display_name": { - "type": "string" - }, - "is_active": { - "type": "boolean" - }, - "role": { - "type": "string" - } - } - }, - "users.User": { - "type": "object", - "properties": { - "avatar_url": { - "type": "string" - }, - "created_at": { - "type": "string" - }, - "display_name": { - "type": "string" - }, - "email": { - "type": "string" - }, - "id": { - "type": "string" - }, - "is_active": { - "type": "boolean" - }, - "last_login_at": { - "type": "string" - }, - "role": { - "type": "string" - }, - "updated_at": { - "type": "string" - }, - "username": { - "type": "string" - } - } } } } \ No newline at end of file diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 6f690b22..1e52852c 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,4 +1,81 @@ definitions: + accounts.Account: + properties: + avatar_url: + type: string + created_at: + type: string + display_name: + type: string + email: + type: string + id: + type: string + is_active: + type: boolean + last_login_at: + type: string + role: + type: string + updated_at: + type: string + username: + type: string + type: object + accounts.CreateAccountRequest: + properties: + avatar_url: + type: string + display_name: + type: string + email: + type: string + is_active: + type: boolean + password: + type: string + role: + type: string + username: + type: string + type: object + accounts.ListAccountsResponse: + properties: + items: + items: + $ref: '#/definitions/accounts.Account' + type: array + type: object + accounts.ResetPasswordRequest: + properties: + new_password: + type: string + type: object + accounts.UpdateAccountRequest: + properties: + avatar_url: + type: string + display_name: + type: string + is_active: + type: boolean + role: + type: string + type: object + accounts.UpdatePasswordRequest: + properties: + current_password: + type: string + new_password: + type: string + type: object + accounts.UpdateProfileRequest: + properties: + avatar_url: + type: string + display_name: + type: string + type: object bots.Bot: properties: avatar_url: @@ -202,8 +279,10 @@ definitions: verifiedAt: type: string type: object - channel.ChannelUserBinding: + channel.ChannelIdentityBinding: properties: + channelIdentityID: + type: string channelType: type: string config: @@ -215,8 +294,6 @@ definitions: type: string updatedAt: type: string - userID: - type: string type: object channel.ConfigSchema: properties: @@ -297,6 +374,8 @@ definitions: - MessageFormatRich channel.MessagePart: properties: + channel_identity_id: + type: string emoji: type: string language: @@ -314,8 +393,6 @@ definitions: $ref: '#/definitions/channel.MessagePartType' url: type: string - user_id: - type: string type: object channel.MessagePartType: enum: @@ -352,12 +429,12 @@ definitions: type: object channel.SendRequest: properties: + channel_identity_id: + type: string message: $ref: '#/definitions/channel.Message' target: type: string - user_id: - type: string type: object channel.TargetHint: properties: @@ -380,6 +457,12 @@ definitions: id: type: string type: object + channel.UpsertChannelIdentityConfigRequest: + properties: + config: + additionalProperties: {} + type: object + type: object channel.UpsertConfigRequest: properties: credentials: @@ -398,89 +481,24 @@ definitions: verified_at: type: string type: object - channel.UpsertUserConfigRequest: + channelidentities.ChannelIdentity: properties: - config: - additionalProperties: {} - type: object - type: object - chat.ChatRequest: - properties: - allowed_actions: - items: - type: string - type: array - channels: - items: - type: string - type: array - current_channel: + channel: type: string - language: + channel_subject_id: type: string - max_context_load_time: - type: integer - messages: - items: - $ref: '#/definitions/chat.ModelMessage' - type: array - model: + created_at: type: string - provider: + display_name: type: string - query: - type: string - skills: - items: - type: string - type: array - type: object - chat.ChatResponse: - properties: - messages: - items: - $ref: '#/definitions/chat.ModelMessage' - type: array - model: - type: string - provider: - type: string - skills: - items: - type: string - type: array - type: object - chat.ModelMessage: - properties: - content: - items: - type: integer - type: array - name: - type: string - role: - type: string - tool_call_id: - type: string - tool_calls: - items: - $ref: '#/definitions/chat.ToolCall' - type: array - type: object - chat.ToolCall: - properties: - function: - $ref: '#/definitions/chat.ToolCallFunction' id: type: string - type: + metadata: + additionalProperties: {} + type: object + updated_at: type: string - type: object - chat.ToolCallFunction: - properties: - arguments: - type: string - name: + user_id: type: string type: object github_com_memohai_memoh_internal_mcp.Connection: @@ -833,121 +851,20 @@ definitions: updated_at: type: string type: object - handlers.memoryAddPayload: + handlers.listMyIdentitiesResponse: properties: - embedding_enabled: - type: boolean - filters: - additionalProperties: {} - type: object - infer: - type: boolean - message: - type: string - messages: + items: items: - $ref: '#/definitions/memory.Message' + $ref: '#/definitions/channelidentities.ChannelIdentity' type: array - metadata: - additionalProperties: {} - type: object - run_id: + user_id: type: string type: object - handlers.memoryDeleteAllPayload: - properties: - run_id: - type: string - type: object - handlers.memoryEmbedUpsertPayload: - properties: - filters: - additionalProperties: {} - type: object - input: - $ref: '#/definitions/memory.EmbedInput' - metadata: - additionalProperties: {} - type: object - model: - type: string - provider: - type: string - run_id: - type: string - source: - type: string - type: - type: string - type: object - handlers.memorySearchPayload: - properties: - embedding_enabled: - type: boolean - filters: - additionalProperties: {} - type: object - limit: - type: integer - query: - type: string - run_id: - type: string - sources: - items: - type: string - type: array - type: object handlers.skillsOpResponse: properties: ok: type: boolean type: object - history.CreateRequest: - properties: - messages: - items: - additionalProperties: {} - type: object - type: array - metadata: - additionalProperties: {} - type: object - skills: - items: - type: string - type: array - type: object - history.ListResponse: - properties: - items: - items: - $ref: '#/definitions/history.Record' - type: array - type: object - history.Record: - properties: - bot_id: - type: string - id: - type: string - messages: - items: - additionalProperties: {} - type: object - type: array - metadata: - additionalProperties: {} - type: object - session_id: - type: string - skills: - items: - type: string - type: array - timestamp: - type: string - type: object mcp.ListResponse: properties: items: @@ -967,83 +884,6 @@ definitions: type: type: string type: object - memory.DeleteResponse: - properties: - message: - type: string - type: object - memory.EmbedInput: - properties: - image_url: - type: string - text: - type: string - video_url: - type: string - type: object - memory.EmbedUpsertResponse: - properties: - dimensions: - type: integer - item: - $ref: '#/definitions/memory.MemoryItem' - model: - type: string - provider: - type: string - type: object - memory.MemoryItem: - properties: - agentId: - type: string - botId: - type: string - createdAt: - type: string - hash: - type: string - id: - type: string - memory: - type: string - metadata: - additionalProperties: {} - type: object - runId: - type: string - score: - type: number - sessionId: - type: string - updatedAt: - type: string - type: object - memory.Message: - properties: - content: - type: string - role: - type: string - type: object - memory.SearchResponse: - properties: - relations: - items: {} - type: array - results: - items: - $ref: '#/definitions/memory.MemoryItem' - type: array - type: object - memory.UpdateRequest: - properties: - embedding_enabled: - type: boolean - memory: - type: string - memory_id: - type: string - type: object models.AddRequest: properties: dimensions: @@ -1396,83 +1236,6 @@ definitions: type: string type: array type: object - users.CreateUserRequest: - properties: - avatar_url: - type: string - display_name: - type: string - email: - type: string - is_active: - type: boolean - password: - type: string - role: - type: string - username: - type: string - type: object - users.ListUsersResponse: - properties: - items: - items: - $ref: '#/definitions/users.User' - type: array - type: object - users.ResetPasswordRequest: - properties: - new_password: - type: string - type: object - users.UpdatePasswordRequest: - properties: - current_password: - type: string - new_password: - type: string - type: object - users.UpdateProfileRequest: - properties: - avatar_url: - type: string - display_name: - type: string - type: object - users.UpdateUserRequest: - properties: - avatar_url: - type: string - display_name: - type: string - is_active: - type: boolean - role: - type: string - type: object - users.User: - properties: - avatar_url: - type: string - created_at: - type: string - display_name: - type: string - email: - type: string - id: - type: string - is_active: - type: boolean - last_login_at: - type: string - role: - type: string - updated_at: - type: string - username: - type: string - type: object info: contact: {} title: Memoh API @@ -1565,68 +1328,6 @@ paths: summary: Create bot user tags: - bots - /bots/{bot_id}/chat: - post: - consumes: - - application/json - description: Send a chat message and get a response. The system will automatically - select an appropriate chat model from the database. - parameters: - - description: Chat request - in: body - name: request - required: true - schema: - $ref: '#/definitions/chat.ChatRequest' - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/chat.ChatResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Chat with AI - tags: - - chat - /bots/{bot_id}/chat/stream: - post: - consumes: - - application/json - description: Send a chat message and get a streaming response. The system will - automatically select an appropriate chat model from the database. - parameters: - - description: Chat request - in: body - name: request - required: true - schema: - $ref: '#/definitions/chat.ChatRequest' - produces: - - text/event-stream - responses: - "200": - description: OK - schema: - type: string - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Stream chat with AI - tags: - - chat /bots/{bot_id}/container: delete: parameters: @@ -2224,126 +1925,6 @@ paths: summary: Stop container task for bot tags: - containerd - /bots/{bot_id}/history: - delete: - description: Delete all history records for current user - responses: - "204": - description: No Content - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Delete all history records - tags: - - history - get: - description: List history records for current user - parameters: - - description: Limit - in: query - name: limit - type: integer - responses: - "200": - description: OK - schema: - $ref: '#/definitions/history.ListResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: List history records - tags: - - history - post: - description: Create a history record for current user - parameters: - - description: History payload - in: body - name: payload - required: true - schema: - $ref: '#/definitions/history.CreateRequest' - responses: - "201": - description: Created - schema: - $ref: '#/definitions/history.Record' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Create history record - tags: - - history - /bots/{bot_id}/history/{id}: - delete: - description: Delete a history record by ID (must belong to current user) - parameters: - - description: History ID - in: path - name: id - required: true - type: string - responses: - "204": - description: No Content - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "403": - description: Forbidden - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Delete history record - tags: - - history - get: - description: Get a history record by ID (must belong to current user) - parameters: - - description: History ID - in: path - name: id - required: true - type: string - responses: - "200": - description: OK - schema: - $ref: '#/definitions/history.Record' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "404": - description: Not Found - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Get history record - tags: - - history /bots/{bot_id}/mcp: get: description: List MCP connections for a bot @@ -2581,220 +2162,6 @@ paths: summary: Update MCP connection tags: - mcp - /bots/{bot_id}/memory/add: - post: - description: 'Add memory for a user via memory. Auth: Bearer JWT determines - user_id (sub or user_id).' - parameters: - - description: Add request - in: body - name: payload - required: true - schema: - $ref: '#/definitions/handlers.memoryAddPayload' - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.SearchResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Add memory - tags: - - memory - /bots/{bot_id}/memory/embed: - post: - description: 'Embed text or multimodal input and upsert into memory store. Auth: - Bearer JWT determines user_id (sub or user_id).' - parameters: - - description: Embed upsert request - in: body - name: payload - required: true - schema: - $ref: '#/definitions/handlers.memoryEmbedUpsertPayload' - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.EmbedUpsertResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Embed and upsert memory - tags: - - memory - /bots/{bot_id}/memory/memories: - delete: - description: 'Delete all memories for a user via memory. Auth: Bearer JWT determines - user_id (sub or user_id).' - parameters: - - description: Delete all request - in: body - name: payload - required: true - schema: - $ref: '#/definitions/handlers.memoryDeleteAllPayload' - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.DeleteResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Delete memories - tags: - - memory - get: - description: 'List memories for a user via memory. Auth: Bearer JWT determines - user_id (sub or user_id).' - parameters: - - description: Run ID - in: query - name: run_id - type: string - - description: Limit - in: query - name: limit - type: integer - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.SearchResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: List memories - tags: - - memory - /bots/{bot_id}/memory/memories/{memoryId}: - delete: - description: 'Delete a memory by ID via memory. Auth: Bearer JWT determines - user_id (sub or user_id).' - parameters: - - description: Memory ID - in: path - name: memoryId - required: true - type: string - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.DeleteResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Delete memory - tags: - - memory - get: - description: 'Get a memory by ID via memory. Auth: Bearer JWT determines user_id - (sub or user_id).' - parameters: - - description: Memory ID - in: path - name: memoryId - required: true - type: string - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.MemoryItem' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Get memory - tags: - - memory - /bots/{bot_id}/memory/search: - post: - description: 'Search memories for a user via memory. Auth: Bearer JWT determines - user_id (sub or user_id).' - parameters: - - description: Search request - in: body - name: payload - required: true - schema: - $ref: '#/definitions/handlers.memorySearchPayload' - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.SearchResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Search memories - tags: - - memory - /bots/{bot_id}/memory/update: - post: - description: 'Update a memory by ID via memory. Auth: Bearer JWT determines - user_id (sub or user_id).' - parameters: - - description: Update request - in: body - name: payload - required: true - schema: - $ref: '#/definitions/memory.UpdateRequest' - responses: - "200": - description: OK - schema: - $ref: '#/definitions/memory.MemoryItem' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Update memory - tags: - - memory /bots/{bot_id}/schedule: get: description: List schedules for current user @@ -3526,7 +2893,7 @@ paths: summary: Send message via bot channel tags: - bots - /bots/{id}/channel/{platform}/send_session: + /bots/{id}/channel/{platform}/send_chat: post: description: Send a message using a session-scoped token (reply only) parameters: @@ -4343,7 +3710,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/users.ListUsersResponse' + $ref: '#/definitions/accounts.ListAccountsResponse' "400": description: Bad Request schema: @@ -4367,12 +3734,12 @@ paths: name: payload required: true schema: - $ref: '#/definitions/users.CreateUserRequest' + $ref: '#/definitions/accounts.CreateAccountRequest' responses: "201": description: Created schema: - $ref: '#/definitions/users.User' + $ref: '#/definitions/accounts.Account' "400": description: Bad Request schema: @@ -4401,7 +3768,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/users.User' + $ref: '#/definitions/accounts.Account' "400": description: Bad Request schema: @@ -4434,12 +3801,12 @@ paths: name: payload required: true schema: - $ref: '#/definitions/users.UpdateUserRequest' + $ref: '#/definitions/accounts.UpdateAccountRequest' responses: "200": description: OK schema: - $ref: '#/definitions/users.User' + $ref: '#/definitions/accounts.Account' "400": description: Bad Request schema: @@ -4473,7 +3840,7 @@ paths: name: payload required: true schema: - $ref: '#/definitions/users.ResetPasswordRequest' + $ref: '#/definitions/accounts.ResetPasswordRequest' responses: "204": description: No Content @@ -4503,7 +3870,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/users.User' + $ref: '#/definitions/accounts.Account' "400": description: Bad Request schema: @@ -4523,12 +3890,12 @@ paths: name: payload required: true schema: - $ref: '#/definitions/users.UpdateProfileRequest' + $ref: '#/definitions/accounts.UpdateProfileRequest' responses: "200": description: OK schema: - $ref: '#/definitions/users.User' + $ref: '#/definitions/accounts.Account' "400": description: Bad Request schema: @@ -4553,7 +3920,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/channel.ChannelUserBinding' + $ref: '#/definitions/channel.ChannelIdentityBinding' "400": description: Bad Request schema: @@ -4582,12 +3949,12 @@ paths: name: payload required: true schema: - $ref: '#/definitions/channel.UpsertUserConfigRequest' + $ref: '#/definitions/channel.UpsertChannelIdentityConfigRequest' responses: "200": description: OK schema: - $ref: '#/definitions/channel.ChannelUserBinding' + $ref: '#/definitions/channel.ChannelIdentityBinding' "400": description: Bad Request schema: @@ -4599,6 +3966,29 @@ paths: summary: Update channel user config tags: - channel + /users/me/identities: + get: + description: List all channel identities linked to current user + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.listMyIdentitiesResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: List current user's channel identities + tags: + - users /users/me/password: put: description: Update current user password with current password check @@ -4608,7 +3998,7 @@ paths: name: payload required: true schema: - $ref: '#/definitions/users.UpdatePasswordRequest' + $ref: '#/definitions/accounts.UpdatePasswordRequest' responses: "204": description: No Content diff --git a/internal/users/service.go b/internal/accounts/service.go similarity index 58% rename from internal/users/service.go rename to internal/accounts/service.go index 24c0c405..787dc251 100644 --- a/internal/users/service.go +++ b/internal/accounts/service.go @@ -1,4 +1,4 @@ -package users +package accounts import ( "context" @@ -16,6 +16,7 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Service provides account (credential) management for users. type Service struct { queries *sqlc.Queries logger *slog.Logger @@ -24,120 +25,125 @@ type Service struct { var ( ErrInvalidPassword = errors.New("invalid password") ErrInvalidCredentials = errors.New("invalid credentials") - ErrInactiveUser = errors.New("user is inactive") + ErrInactiveAccount = errors.New("account is inactive") ) +// NewService creates a new accounts service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { if log == nil { log = slog.Default() } return &Service{ queries: queries, - logger: log.With(slog.String("service", "users")), + logger: log.With(slog.String("service", "accounts")), } } -func (s *Service) Get(ctx context.Context, userID string) (User, error) { +// Get returns an account by user id. +func (s *Service) Get(ctx context.Context, userID string) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } pgID, err := parseUUID(userID) if err != nil { - return User{}, err + return Account{}, err } - row, err := s.queries.GetUserByID(ctx, pgID) + row, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) Login(ctx context.Context, identity, password string) (User, error) { +// Login authenticates by identity (username or email) and password. +func (s *Service) Login(ctx context.Context, identity, password string) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } identity = strings.TrimSpace(identity) if identity == "" || strings.TrimSpace(password) == "" { - return User{}, ErrInvalidCredentials + return Account{}, ErrInvalidCredentials } - row, err := s.queries.GetUserByIdentity(ctx, identity) + row, err := s.queries.GetAccountByIdentity(ctx, pgtype.Text{String: identity, Valid: true}) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return User{}, ErrInvalidCredentials + return Account{}, ErrInvalidCredentials } - return User{}, err + return Account{}, err } if !row.IsActive { - return User{}, ErrInactiveUser + return Account{}, ErrInactiveAccount } - if err := bcrypt.CompareHashAndPassword([]byte(row.PasswordHash), []byte(password)); err != nil { - return User{}, ErrInvalidCredentials + if !row.PasswordHash.Valid { + return Account{}, ErrInvalidCredentials } - if _, err := s.queries.UpdateUserLastLogin(ctx, row.ID); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(row.PasswordHash.String), []byte(password)); err != nil { + return Account{}, ErrInvalidCredentials + } + if _, err := s.queries.UpdateAccountLastLogin(ctx, row.ID); err != nil { if s.logger != nil { s.logger.Warn("touch last login failed", slog.Any("error", err)) } } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) ListUsers(ctx context.Context) ([]User, error) { +// ListAccounts returns all accounts. +func (s *Service) ListAccounts(ctx context.Context) ([]Account, error) { if s.queries == nil { - return nil, fmt.Errorf("user queries not configured") + return nil, fmt.Errorf("account queries not configured") } - rows, err := s.queries.ListUsers(ctx) + rows, err := s.queries.ListAccounts(ctx) if err != nil { return nil, err } - items := make([]User, 0, len(rows)) + items := make([]Account, 0, len(rows)) for _, row := range rows { - items = append(items, toUser(row)) + items = append(items, toAccount(row)) } return items, nil } -func (s *Service) ListUsersByType(ctx context.Context, userType string) ([]User, error) { - if s.queries == nil { - return nil, fmt.Errorf("user queries not configured") - } - return nil, fmt.Errorf("user type filtering is not supported") -} - +// IsAdmin checks if the user has admin role. func (s *Service) IsAdmin(ctx context.Context, userID string) (bool, error) { if s.queries == nil { - return false, fmt.Errorf("user queries not configured") + return false, fmt.Errorf("account queries not configured") } pgID, err := parseUUID(userID) if err != nil { return false, err } - row, err := s.queries.GetUserByID(ctx, pgID) + row, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, nil + } return false, err } return isAdminRole(row.Role), nil } -func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, error) { +// Create creates a new account for an existing user. +func (s *Service) Create(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } username := strings.TrimSpace(req.Username) if username == "" { - return User{}, fmt.Errorf("username is required") + return Account{}, fmt.Errorf("username is required") } password := strings.TrimSpace(req.Password) if password == "" { - return User{}, fmt.Errorf("password is required") + return Account{}, fmt.Errorf("password is required") } role, err := normalizeRole(req.Role) if err != nil { - return User{}, err + return Account{}, err } hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return User{}, err + return Account{}, err } displayName := strings.TrimSpace(req.DisplayName) @@ -151,6 +157,10 @@ func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, isActive = *req.IsActive } + pgUserID, err := parseUUID(userID) + if err != nil { + return Account{}, err + } emailValue := pgtype.Text{Valid: false} if email != "" { emailValue = pgtype.Text{String: email, Valid: true} @@ -161,10 +171,11 @@ func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, avatarValue = pgtype.Text{String: avatarURL, Valid: true} } - row, err := s.queries.CreateUser(ctx, sqlc.CreateUserParams{ - Username: username, + row, err := s.queries.CreateAccount(ctx, sqlc.CreateAccountParams{ + UserID: pgUserID, + Username: pgtype.Text{String: username, Valid: true}, Email: emailValue, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, Role: role, DisplayName: displayValue, AvatarUrl: avatarValue, @@ -172,28 +183,51 @@ func (s *Service) CreateHuman(ctx context.Context, req CreateUserRequest) (User, DataRoot: pgtype.Text{Valid: false}, }) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) UpdateUserAdmin(ctx context.Context, userID string, req UpdateUserRequest) (User, error) { +// CreateHuman keeps compatibility with older call sites. +func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) { + userID = strings.TrimSpace(userID) + if userID == "" { + if s.queries == nil { + return Account{}, fmt.Errorf("account queries not configured") + } + userRow, err := s.queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return Account{}, err + } + if !userRow.ID.Valid { + return Account{}, fmt.Errorf("create user: invalid id") + } + userID = uuid.UUID(userRow.ID.Bytes).String() + } + return s.Create(ctx, userID, req) +} + +// UpdateAdmin updates account fields as admin. +func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAccountRequest) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } pgID, err := parseUUID(userID) if err != nil { - return User{}, err + return Account{}, err } - existing, err := s.queries.GetUserByID(ctx, pgID) + existing, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { - return User{}, err + return Account{}, err } role := fmt.Sprint(existing.Role) if req.Role != nil { role, err = normalizeRole(*req.Role) if err != nil { - return User{}, err + return Account{}, err } } displayName := strings.TrimSpace(existing.DisplayName.String) @@ -201,7 +235,7 @@ func (s *Service) UpdateUserAdmin(ctx context.Context, userID string, req Update displayName = strings.TrimSpace(*req.DisplayName) } if displayName == "" { - displayName = existing.Username + displayName = strings.TrimSpace(existing.Username.String) } avatarURL := strings.TrimSpace(existing.AvatarUrl.String) if req.AvatarURL != nil { @@ -212,57 +246,59 @@ func (s *Service) UpdateUserAdmin(ctx context.Context, userID string, req Update isActive = *req.IsActive } - row, err := s.queries.UpdateUserAdmin(ctx, sqlc.UpdateUserAdminParams{ - ID: pgID, + row, err := s.queries.UpdateAccountAdmin(ctx, sqlc.UpdateAccountAdminParams{ + UserID: pgID, Role: role, DisplayName: pgtype.Text{String: displayName, Valid: displayName != ""}, AvatarUrl: pgtype.Text{String: avatarURL, Valid: avatarURL != ""}, IsActive: isActive, }) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } -func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdateProfileRequest) (User, error) { +// UpdateProfile updates the user's profile. +func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdateProfileRequest) (Account, error) { if s.queries == nil { - return User{}, fmt.Errorf("user queries not configured") + return Account{}, fmt.Errorf("account queries not configured") } pgID, err := parseUUID(userID) if err != nil { - return User{}, err + return Account{}, err } - existing, err := s.queries.GetUserByID(ctx, pgID) + existing, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { - return User{}, err + return Account{}, err } displayName := strings.TrimSpace(existing.DisplayName.String) if req.DisplayName != nil { displayName = strings.TrimSpace(*req.DisplayName) } if displayName == "" { - displayName = existing.Username + displayName = strings.TrimSpace(existing.Username.String) } avatarURL := strings.TrimSpace(existing.AvatarUrl.String) if req.AvatarURL != nil { avatarURL = strings.TrimSpace(*req.AvatarURL) } - row, err := s.queries.UpdateUserProfile(ctx, sqlc.UpdateUserProfileParams{ + row, err := s.queries.UpdateAccountProfile(ctx, sqlc.UpdateAccountProfileParams{ ID: pgID, DisplayName: pgtype.Text{String: displayName, Valid: displayName != ""}, AvatarUrl: pgtype.Text{String: avatarURL, Valid: avatarURL != ""}, IsActive: existing.IsActive, }) if err != nil { - return User{}, err + return Account{}, err } - return toUser(row), nil + return toAccount(row), nil } +// UpdatePassword changes the password after verifying the current one. func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, newPassword string) error { if s.queries == nil { - return fmt.Errorf("user queries not configured") + return fmt.Errorf("account queries not configured") } if strings.TrimSpace(newPassword) == "" { return fmt.Errorf("new password is required") @@ -271,30 +307,34 @@ func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, n if err != nil { return err } - existing, err := s.queries.GetUserByID(ctx, pgID) + existing, err := s.queries.GetAccountByUserID(ctx, pgID) if err != nil { return err } if strings.TrimSpace(currentPassword) == "" { return ErrInvalidPassword } - if err := bcrypt.CompareHashAndPassword([]byte(existing.PasswordHash), []byte(currentPassword)); err != nil { + if !existing.PasswordHash.Valid { + return ErrInvalidPassword + } + if err := bcrypt.CompareHashAndPassword([]byte(existing.PasswordHash.String), []byte(currentPassword)); err != nil { return ErrInvalidPassword } hashed, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { return err } - _, err = s.queries.UpdateUserPassword(ctx, sqlc.UpdateUserPasswordParams{ + _, err = s.queries.UpdateAccountPassword(ctx, sqlc.UpdateAccountPasswordParams{ ID: pgID, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, }) return err } +// ResetPassword sets a new password without requiring the current one. func (s *Service) ResetPassword(ctx context.Context, userID, newPassword string) error { if s.queries == nil { - return fmt.Errorf("user queries not configured") + return fmt.Errorf("account queries not configured") } if strings.TrimSpace(newPassword) == "" { return fmt.Errorf("new password is required") @@ -307,9 +347,9 @@ func (s *Service) ResetPassword(ctx context.Context, userID, newPassword string) if err != nil { return err } - _, err = s.queries.UpdateUserPassword(ctx, sqlc.UpdateUserPasswordParams{ + _, err = s.queries.UpdateAccountPassword(ctx, sqlc.UpdateAccountPasswordParams{ ID: pgID, - PasswordHash: string(hashed), + PasswordHash: pgtype.Text{String: string(hashed), Valid: true}, }) return err } @@ -339,7 +379,8 @@ func isAdminRole(role any) bool { } } -func toUser(row sqlc.User) User { +func toAccount(row sqlc.User) Account { + username := strings.TrimSpace(row.Username.String) email := "" if row.Email.Valid { email = row.Email.String @@ -348,6 +389,9 @@ func toUser(row sqlc.User) User { if row.DisplayName.Valid { displayName = row.DisplayName.String } + if displayName == "" { + displayName = username + } avatarURL := "" if row.AvatarUrl.Valid { avatarURL = row.AvatarUrl.String @@ -364,9 +408,9 @@ func toUser(row sqlc.User) User { if row.LastLoginAt.Valid { lastLogin = row.LastLoginAt.Time } - return User{ + return Account{ ID: toUUIDString(row.ID), - Username: row.Username, + Username: username, Email: email, Role: fmt.Sprint(row.Role), DisplayName: displayName, diff --git a/internal/users/types.go b/internal/accounts/types.go similarity index 68% rename from internal/users/types.go rename to internal/accounts/types.go index 431225fc..7a3b4f62 100644 --- a/internal/users/types.go +++ b/internal/accounts/types.go @@ -1,8 +1,9 @@ -package users +package accounts import "time" -type User struct { +// Account represents a human account credential record. +type Account struct { ID string `json:"id"` Username string `json:"username"` Email string `json:"email,omitempty"` @@ -15,7 +16,8 @@ type User struct { LastLoginAt time.Time `json:"last_login_at,omitempty"` } -type CreateUserRequest struct { +// CreateAccountRequest is the input for creating an account. +type CreateAccountRequest struct { Username string `json:"username"` Password string `json:"password"` Email string `json:"email,omitempty"` @@ -25,27 +27,32 @@ type CreateUserRequest struct { IsActive *bool `json:"is_active,omitempty"` } -type UpdateUserRequest struct { +// UpdateAccountRequest is the input for admin-level account updates. +type UpdateAccountRequest struct { Role *string `json:"role,omitempty"` DisplayName *string `json:"display_name,omitempty"` AvatarURL *string `json:"avatar_url,omitempty"` IsActive *bool `json:"is_active,omitempty"` } +// UpdateProfileRequest is the input for self-service profile updates. type UpdateProfileRequest struct { DisplayName *string `json:"display_name,omitempty"` AvatarURL *string `json:"avatar_url,omitempty"` } +// UpdatePasswordRequest is the input for password change. type UpdatePasswordRequest struct { CurrentPassword string `json:"current_password,omitempty"` NewPassword string `json:"new_password"` } +// ResetPasswordRequest is the input for admin password reset. type ResetPasswordRequest struct { NewPassword string `json:"new_password"` } -type ListUsersResponse struct { - Items []User `json:"items"` +// ListAccountsResponse wraps a list of accounts. +type ListAccountsResponse struct { + Items []Account `json:"items"` } diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 790035d9..2454aacc 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -13,15 +13,14 @@ import ( ) const ( - claimSubject = "sub" - claimUserID = "user_id" - claimType = "typ" - claimBotID = "bot_id" - claimPlatform = "platform" - claimReplyTarget = "reply_target" - claimSessionID = "session_id" - claimContactID = "contact_id" - sessionTokenType = "channel_session" + claimSubject = "sub" + claimUserID = "user_id" + claimChannelIdentityID = "channel_identity_id" + claimType = "typ" + claimBotID = "bot_id" + claimChatID = "chat_id" + claimRouteID = "route_id" + chatTokenType = "chat_route" ) // JWTMiddleware returns a JWT auth middleware configured for HS256 tokens. @@ -53,9 +52,17 @@ func UserIDFromContext(c echo.Context) (string, error) { if userID := claimString(claims, claimSubject); userID != "" { return userID, nil } + if legacyChannelIdentityID := claimString(claims, claimChannelIdentityID); legacyChannelIdentityID != "" { + return legacyChannelIdentityID, nil + } return "", echo.NewHTTPError(http.StatusUnauthorized, "user id missing") } +// ChannelIdentityIDFromContext is kept as compatibility alias and returns user id. +func ChannelIdentityIDFromContext(c echo.Context) (string, error) { + return UserIDFromContext(c) +} + // GenerateToken creates a signed JWT for the user. func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(userID) == "" { @@ -71,10 +78,11 @@ func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time now := time.Now().UTC() expiresAt := now.Add(expiresIn) claims := jwt.MapClaims{ - claimSubject: userID, - claimUserID: userID, - "iat": now.Unix(), - "exp": expiresAt.Unix(), + claimSubject: userID, + claimUserID: userID, + claimChannelIdentityID: userID, // legacy compatibility for handlers still reading channel_identity_id + "iat": now.Unix(), + "exp": expiresAt.Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) signed, err := token.SignedString([]byte(secret)) @@ -84,24 +92,28 @@ func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time return signed, expiresAt, nil } -type SessionToken struct { - BotID string - Platform string - ReplyTarget string - SessionID string - ContactID string +// ChatToken holds the claims for a chat-based JWT used for route-based reply. +type ChatToken struct { + BotID string + ChatID string + RouteID string + UserID string + ChannelIdentityID string } -// GenerateSessionToken creates a signed JWT for channel session reply. -func GenerateSessionToken(info SessionToken, secret string, expiresIn time.Duration) (string, time.Time, error) { +// GenerateChatToken creates a signed JWT for chat route reply. +func GenerateChatToken(info ChatToken, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(info.BotID) == "" { return "", time.Time{}, fmt.Errorf("bot id is required") } - if strings.TrimSpace(info.Platform) == "" { - return "", time.Time{}, fmt.Errorf("platform is required") + if strings.TrimSpace(info.ChatID) == "" { + return "", time.Time{}, fmt.Errorf("chat id is required") } - if strings.TrimSpace(info.ReplyTarget) == "" { - return "", time.Time{}, fmt.Errorf("reply target is required") + if strings.TrimSpace(info.UserID) == "" { + info.UserID = strings.TrimSpace(info.ChannelIdentityID) + } + if strings.TrimSpace(info.UserID) == "" { + return "", time.Time{}, fmt.Errorf("user id is required") } if strings.TrimSpace(secret) == "" { return "", time.Time{}, fmt.Errorf("jwt secret is required") @@ -113,14 +125,14 @@ func GenerateSessionToken(info SessionToken, secret string, expiresIn time.Durat now := time.Now().UTC() expiresAt := now.Add(expiresIn) claims := jwt.MapClaims{ - claimType: sessionTokenType, - claimBotID: info.BotID, - claimPlatform: info.Platform, - claimReplyTarget: info.ReplyTarget, - claimSessionID: info.SessionID, - claimContactID: info.ContactID, - "iat": now.Unix(), - "exp": expiresAt.Unix(), + claimType: chatTokenType, + claimBotID: info.BotID, + claimChatID: info.ChatID, + claimRouteID: info.RouteID, + claimUserID: info.UserID, + claimChannelIdentityID: info.ChannelIdentityID, + "iat": now.Unix(), + "exp": expiresAt.Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) signed, err := token.SignedString([]byte(secret)) @@ -130,26 +142,30 @@ func GenerateSessionToken(info SessionToken, secret string, expiresIn time.Durat return signed, expiresAt, nil } -// SessionTokenFromContext extracts the session token claims from context. -func SessionTokenFromContext(c echo.Context) (SessionToken, error) { +// ChatTokenFromContext extracts the chat token claims from context. +func ChatTokenFromContext(c echo.Context) (ChatToken, error) { token, ok := c.Get("user").(*jwt.Token) if !ok || token == nil || !token.Valid { - return SessionToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token") + return ChatToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token") } claims, ok := token.Claims.(jwt.MapClaims) if !ok { - return SessionToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token claims") + return ChatToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token claims") } - if claimString(claims, claimType) != sessionTokenType { - return SessionToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid session token") + if claimString(claims, claimType) != chatTokenType { + return ChatToken{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid chat token") } - return SessionToken{ - BotID: claimString(claims, claimBotID), - Platform: claimString(claims, claimPlatform), - ReplyTarget: claimString(claims, claimReplyTarget), - SessionID: claimString(claims, claimSessionID), - ContactID: claimString(claims, claimContactID), - }, nil + info := ChatToken{ + BotID: claimString(claims, claimBotID), + ChatID: claimString(claims, claimChatID), + RouteID: claimString(claims, claimRouteID), + UserID: claimString(claims, claimUserID), + ChannelIdentityID: claimString(claims, claimChannelIdentityID), + } + if strings.TrimSpace(info.UserID) == "" { + info.UserID = strings.TrimSpace(info.ChannelIdentityID) + } + return info, nil } func claimString(claims jwt.MapClaims, key string) string { diff --git a/internal/bind/service.go b/internal/bind/service.go new file mode 100644 index 00000000..c4a33dd3 --- /dev/null +++ b/internal/bind/service.go @@ -0,0 +1,262 @@ +package bind + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/db/sqlc" +) + +const ( + defaultTTL = 24 * time.Hour + maxTokenRetries = 5 +) + +// Service manages channel identity->user bind code lifecycle. +type Service struct { + pool *pgxpool.Pool + queries *sqlc.Queries + logger *slog.Logger +} + +// NewService creates a bind code service. +func NewService(log *slog.Logger, pool *pgxpool.Pool, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + pool: pool, + queries: queries, + logger: log.With(slog.String("service", "bind")), + } +} + +// Issue creates a new bind code issued by the given user. +// Platform is optional; when provided, bind consume must happen on the same channel platform. +func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, ttl time.Duration) (Code, error) { + if s.queries == nil { + return Code{}, fmt.Errorf("bind queries not configured") + } + if ttl <= 0 { + ttl = defaultTTL + } + + pgUserID, err := parseUUID(issuedByUserID) + if err != nil { + return Code{}, fmt.Errorf("invalid user id: %w", err) + } + normalizedPlatform := normalizePlatform(platform) + + expiresAt := time.Now().UTC().Add(ttl) + for i := 0; i < maxTokenRetries; i++ { + token := strings.ToUpper(strings.ReplaceAll(uuid.NewString(), "-", "")[:8]) + row, err := s.queries.CreateBindCode(ctx, sqlc.CreateBindCodeParams{ + Token: token, + IssuedByUserID: pgUserID, + Platform: pgtype.Text{ + String: normalizedPlatform, + Valid: normalizedPlatform != "", + }, + ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true}, + }) + if err == nil { + return toCode(row), nil + } + if isUniqueViolation(err) { + continue + } + return Code{}, fmt.Errorf("create bind code: %w", err) + } + return Code{}, fmt.Errorf("create bind code: token collision after retries") +} + +// Get looks up a bind code by token. +func (s *Service) Get(ctx context.Context, token string) (Code, error) { + if s.queries == nil { + return Code{}, fmt.Errorf("bind queries not configured") + } + row, err := s.queries.GetBindCode(ctx, strings.TrimSpace(token)) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Code{}, ErrCodeNotFound + } + return Code{}, err + } + return toCode(row), nil +} + +// Consume validates and consumes a bind code and links the channel identity to issuer user. +func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID string) error { + if s.queries == nil || s.pool == nil { + return fmt.Errorf("bind service not configured") + } + + // Fast-fail based on caller snapshot before opening a transaction. + if !code.UsedAt.IsZero() { + return ErrCodeUsed + } + if !code.ExpiresAt.IsZero() && time.Now().UTC().After(code.ExpiresAt) { + return ErrCodeExpired + } + token := strings.TrimSpace(code.Token) + if token == "" { + return ErrCodeNotFound + } + sourceIdentityID := strings.TrimSpace(channelIdentityID) + if sourceIdentityID == "" { + return fmt.Errorf("channel identity id is required") + } + pgSourceIdentityID, err := parseUUID(sourceIdentityID) + if err != nil { + return err + } + + tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return fmt.Errorf("begin bind consume tx: %w", err) + } + defer func() { _ = tx.Rollback(ctx) }() + qtx := s.queries.WithTx(tx) + + lockedCodeRow, err := qtx.GetBindCodeForUpdate(ctx, token) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrCodeNotFound + } + return fmt.Errorf("lock bind code: %w", err) + } + lockedCode := toCode(lockedCodeRow) + if !lockedCode.UsedAt.IsZero() { + return ErrCodeUsed + } + if !lockedCode.ExpiresAt.IsZero() && time.Now().UTC().After(lockedCode.ExpiresAt) { + return ErrCodeExpired + } + if strings.TrimSpace(code.Platform) != "" && !strings.EqualFold(lockedCode.Platform, strings.TrimSpace(code.Platform)) { + return ErrCodeMismatch + } + + targetUserID := strings.TrimSpace(lockedCode.IssuedByUserID) + if targetUserID == "" { + return fmt.Errorf("bind code issuer user is missing") + } + pgTargetUserID, err := parseUUID(targetUserID) + if err != nil { + return err + } + + if _, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("channel identity not found") + } + return fmt.Errorf("lock source identity: %w", err) + } + sourceIdentity, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("channel identity not found") + } + return fmt.Errorf("reload source identity: %w", err) + } + if sourceIdentity.UserID.Valid && uuidString(sourceIdentity.UserID) != targetUserID { + return ErrLinkConflict + } + if !sourceIdentity.UserID.Valid { + if _, err := qtx.SetChannelIdentityLinkedUser(ctx, sqlc.SetChannelIdentityLinkedUserParams{ + ID: pgSourceIdentityID, + UserID: pgTargetUserID, + }); err != nil { + return fmt.Errorf("link channel identity user: %w", err) + } + } + + if _, err := qtx.MarkBindCodeUsed(ctx, sqlc.MarkBindCodeUsedParams{ + ID: lockedCodeRow.ID, + UsedByChannelIdentityID: pgSourceIdentityID, + }); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrCodeUsed + } + return fmt.Errorf("mark bind code used: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("commit bind consume tx: %w", err) + } + + s.logger.Info("bind code consumed", + slog.String("code_id", lockedCode.ID), + slog.String("platform", lockedCode.Platform), + slog.String("channel_identity", sourceIdentityID), + slog.String("target_user", targetUserID), + ) + return nil +} + +func toCode(row sqlc.ChannelIdentityBindCode) Code { + c := Code{ + ID: uuidString(row.ID), + Token: row.Token, + IssuedByUserID: uuidString(row.IssuedByUserID), + CreatedAt: row.CreatedAt.Time, + } + if row.Platform.Valid { + c.Platform = normalizePlatform(row.Platform.String) + } + if row.ExpiresAt.Valid { + c.ExpiresAt = row.ExpiresAt.Time + } + if row.UsedAt.Valid { + c.UsedAt = row.UsedAt.Time + } + if row.UsedByChannelIdentityID.Valid { + c.UsedByChannelIdentityID = uuidString(row.UsedByChannelIdentityID) + } + return c +} + +func parseUUID(id string) (pgtype.UUID, error) { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return pgtype.UUID{}, fmt.Errorf("empty id") + } + var pgID pgtype.UUID + if err := pgID.Scan(trimmed); err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) + } + return pgID, nil +} + +func uuidString(id pgtype.UUID) string { + if !id.Valid { + return "" + } + b := id.Bytes + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) +} + +func isUniqueViolation(err error) bool { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) { + return false + } + if pgErr.Code != "23505" { + return false + } + return pgErr.ConstraintName == "" || pgErr.ConstraintName == "channel_identity_bind_codes_token_unique" +} + +func normalizePlatform(raw string) string { + return strings.ToLower(strings.TrimSpace(raw)) +} diff --git a/internal/bind/service_integration_test.go b/internal/bind/service_integration_test.go new file mode 100644 index 00000000..48a2c145 --- /dev/null +++ b/internal/bind/service_integration_test.go @@ -0,0 +1,562 @@ +//go:build ignore +// +build ignore + +package bind_test + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + channelIdentitySvc := channelidentities.NewService(logger, queries) + bindSvc := bind.NewService(logger, pool, queries) + + return queries, channelIdentitySvc, bindSvc, func() { pool.Close() } +} + +func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func createBot(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUser(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + botID, err := createBot(ctx, queries, ownerUserID) + if err != nil { + t.Fatalf("create bot failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { + t.Fatalf("consume bind code failed: %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if after.UsedAt.IsZero() { + t.Fatal("expected used_at to be set after consume") + } + if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { + t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) + } + + linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != ownerUserID { + t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) + } + + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrCodeUsed) { + t.Fatalf("expected ErrCodeUsed on second consume, got %v", err) + } +} + +func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUser(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + otherUserID, err := createUser(ctx, queries) + if err != nil { + t.Fatalf("create other user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { + t.Fatalf("pre-link source channelIdentity failed: %v", err) + } + botID, err := createBot(ctx, queries, ownerUserID) + if err != nil { + t.Fatalf("create bot failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { + t.Fatalf("expected ErrLinkConflict, got %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if !after.UsedAt.IsZero() { + t.Fatal("expected used_at to remain empty when consume fails") + } +} +package bind_test + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + channelIdentitySvc := channelidentities.NewService(logger, queries) + bindSvc := bind.NewService(logger, pool, queries) + + cleanup := func() { + pool.Close() + } + return queries, channelIdentitySvc, bindSvc, cleanup +} + +func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, + AvatarUrl: pgtype.Text{}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUserForBindTest(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + botID, err := createBotForBindTest(ctx, queries, ownerUserID) + if err != nil { + t.Fatalf("create bot failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { + t.Fatalf("consume bind code failed: %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if after.UsedAt.IsZero() { + t.Fatal("expected used_at to be set after successful consume") + } + if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { + t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) + } + + linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != ownerUserID { + t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) + } + + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrCodeUsed) { + t.Fatalf("expected ErrCodeUsed on second consume, got %v", err) + } +} + +func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUserForBindTest(ctx, queries) + if err != nil { + t.Fatalf("create owner user failed: %v", err) + } + otherUserID, err := createUserForBindTest(ctx, queries) + if err != nil { + t.Fatalf("create other user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { + t.Fatalf("pre-link source channelIdentity failed: %v", err) + } + botID, err := createBotForBindTest(ctx, queries, ownerUserID) + if err != nil { + t.Fatalf("create bot failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { + t.Fatalf("expected ErrLinkConflict, got %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if !after.UsedAt.IsZero() { + t.Fatal("expected used_at to remain empty when consume fails") + } +} +package bind_test + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + channelIdentitySvc := channelidentities.NewService(logger, queries) + bindSvc := bind.NewService(logger, pool, queries) + + cleanup := func() { + pool.Close() + } + return queries, channelIdentitySvc, bindSvc, cleanup +} + +func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerChannelIdentityID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerChannelIdentityID) + if err != nil { + return "", err + } + meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerChannelIdentityID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, + AvatarUrl: pgtype.Text{}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func createChatForBindTest(ctx context.Context, queries *sqlc.Queries, botID, channelIdentityID string) (string, error) { + pgBotID, err := db.ParseUUID(botID) + if err != nil { + return "", err + } + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return "", err + } + row, err := queries.CreateChat(ctx, sqlc.CreateChatParams{ + BotID: pgBotID, + Kind: "direct", + ParentChatID: pgtype.UUID{}, + Title: pgtype.Text{}, + CreatedBy: pgChannelIdentityID, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + human, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman) + if err != nil { + t.Fatalf("create human failed: %v", err) + } + shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow) + if err != nil { + t.Fatalf("create shadow failed: %v", err) + } + botID, err := createBotForBindTest(ctx, queries, human.ID) + if err != nil { + t.Fatalf("create bot failed: %v", err) + } + chatID, err := createChatForBindTest(ctx, queries, botID, human.ID) + if err != nil { + t.Fatalf("create chat failed: %v", err) + } + pgChatID, err := db.ParseUUID(chatID) + if err != nil { + t.Fatalf("parse chat id failed: %v", err) + } + pgShadowID, err := db.ParseUUID(shadow.ID) + if err != nil { + t.Fatalf("parse shadow id failed: %v", err) + } + pgHumanID, err := db.ParseUUID(human.ID) + if err != nil { + t.Fatalf("parse human id failed: %v", err) + } + if _, err := queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: pgChatID, + ChannelIdentityID: pgShadowID, + Role: "member", + }); err != nil { + t.Fatalf("add shadow participant failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, botID, human.ID, 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, shadow.ID); err != nil { + t.Fatalf("consume bind code failed: %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if after.UsedAt.IsZero() { + t.Fatal("expected used_at to be set after successful consume") + } + if after.UsedByChannelIdentityID != shadow.ID { + t.Fatalf("expected used_by_channel_identity_id=%s, got %s", shadow.ID, after.UsedByChannelIdentityID) + } + + canonical, err := channelIdentitySvc.Canonicalize(ctx, shadow.ID) + if err != nil { + t.Fatalf("canonicalize failed: %v", err) + } + if canonical != human.ID { + t.Fatalf("expected canonical=%s, got %s", human.ID, canonical) + } + if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ + ChatID: pgChatID, + ChannelIdentityID: pgHumanID, + }); err != nil { + t.Fatalf("expected human participant after bind, got error: %v", err) + } + if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ + ChatID: pgChatID, + ChannelIdentityID: pgShadowID, + }); !errors.Is(err, pgx.ErrNoRows) { + t.Fatalf("expected shadow participant removed after bind, got %v", err) + } + + if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrCodeUsed) { + t.Fatalf("expected ErrCodeUsed on second consume, got %v", err) + } +} + +func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + humanA, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman) + if err != nil { + t.Fatalf("create humanA failed: %v", err) + } + humanB, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman) + if err != nil { + t.Fatalf("create humanB failed: %v", err) + } + shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow) + if err != nil { + t.Fatalf("create shadow failed: %v", err) + } + botID, err := createBotForBindTest(ctx, queries, humanA.ID) + if err != nil { + t.Fatalf("create bot failed: %v", err) + } + + // Pre-link shadow to another user so bind consume hits link conflict. + if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, shadow.ID, humanB.ID); err != nil { + t.Fatalf("pre link shadow->humanB failed: %v", err) + } + + code, err := bindSvc.Issue(ctx, botID, humanA.ID, 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrLinkConflict) { + t.Fatalf("expected ErrLinkConflict, got %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if !after.UsedAt.IsZero() { + t.Fatal("expected used_at to remain empty when consume fails") + } +} diff --git a/internal/bind/service_link_integration_test.go b/internal/bind/service_link_integration_test.go new file mode 100644 index 00000000..19b8f78a --- /dev/null +++ b/internal/bind/service_link_integration_test.go @@ -0,0 +1,168 @@ +package bind_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupBindLinkIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + channelIdentitySvc := channelidentities.NewService(logger, queries) + bindSvc := bind.NewService(logger, pool, queries) + return queries, channelIdentitySvc, bindSvc, func() { pool.Close() } +} + +func createUserForBind(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func isLegacyBindSchemaError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "relation \"users\" does not exist") +} + +func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + ownerUserID, err := createUserForBind(ctx, queries) + if err != nil { + if isLegacyBindSchemaError(err) { + t.Skipf("skip integration test on legacy schema: %v", err) + } + t.Fatalf("create owner user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + code, err := bindSvc.Issue(ctx, ownerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { + t.Fatalf("consume bind code failed: %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if after.UsedAt.IsZero() { + t.Fatal("expected code used_at set after consume") + } + if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { + t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) + } + + linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != ownerUserID { + t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) + } +} + +func TestBindConsumeConflictDoesNotMarkUsed(t *testing.T) { + queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + issuerUserID, err := createUserForBind(ctx, queries) + if err != nil { + if isLegacyBindSchemaError(err) { + t.Skipf("skip integration test on legacy schema: %v", err) + } + t.Fatalf("create issuer user failed: %v", err) + } + otherUserID, err := createUserForBind(ctx, queries) + if err != nil { + t.Fatalf("create other user failed: %v", err) + } + sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-conflict-%d", time.Now().UnixNano()), "source") + if err != nil { + t.Fatalf("create source channelIdentity failed: %v", err) + } + if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { + t.Fatalf("pre-link source channelIdentity failed: %v", err) + } + code, err := bindSvc.Issue(ctx, issuerUserID, "feishu", 10*time.Minute) + if err != nil { + t.Fatalf("issue bind code failed: %v", err) + } + if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { + t.Fatalf("expected ErrLinkConflict, got %v", err) + } + + after, err := bindSvc.Get(ctx, code.Token) + if err != nil { + t.Fatalf("get bind code failed: %v", err) + } + if !after.UsedAt.IsZero() { + t.Fatal("expected code to remain unused after conflict") + } +} diff --git a/internal/bind/types.go b/internal/bind/types.go new file mode 100644 index 00000000..0f5182b0 --- /dev/null +++ b/internal/bind/types.go @@ -0,0 +1,26 @@ +package bind + +import ( + "errors" + "time" +) + +var ( + ErrCodeNotFound = errors.New("bind code not found") + ErrCodeUsed = errors.New("bind code already used") + ErrCodeExpired = errors.New("bind code expired") + ErrCodeMismatch = errors.New("bind code context mismatch") + ErrLinkConflict = errors.New("channel identity user link conflict") +) + +// Code represents a one-time bind code for linking channel identity to user. +type Code struct { + ID string `json:"id"` + Platform string `json:"platform,omitempty"` + Token string `json:"token"` + IssuedByUserID string `json:"issued_by_user_id"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + UsedAt time.Time `json:"used_at,omitempty"` + UsedByChannelIdentityID string `json:"used_by_channel_identity_id,omitempty"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/internal/bots/service.go b/internal/bots/service.go index 567115af..8c358fe9 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -16,6 +16,7 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Service provides bot CRUD and membership management. type Service struct { queries *sqlc.Queries logger *slog.Logger @@ -23,14 +24,17 @@ type Service struct { } var ( - ErrBotNotFound = errors.New("bot not found") - ErrBotAccessDenied = errors.New("bot access denied") + ErrBotNotFound = errors.New("bot not found") + ErrBotAccessDenied = errors.New("bot access denied") + ErrOwnerUserNotFound = errors.New("owner user not found") ) +// AccessPolicy controls bot access behavior. type AccessPolicy struct { AllowPublicMember bool } +// NewService creates a new bot service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { if log == nil { log = slog.Default() @@ -46,7 +50,8 @@ func (s *Service) SetContainerLifecycle(lc ContainerLifecycle) { s.containerLifecycle = lc } -func (s *Service) AuthorizeAccess(ctx context.Context, actorID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { +// AuthorizeAccess checks whether userID may access the given bot. +func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") } @@ -57,17 +62,18 @@ func (s *Service) AuthorizeAccess(ctx context.Context, actorID, botID string, is } return Bot{}, err } - if isAdmin || bot.OwnerUserID == actorID { + if isAdmin || bot.OwnerUserID == userID { return bot, nil } if policy.AllowPublicMember && bot.Type == BotTypePublic { - if _, err := s.GetMember(ctx, botID, actorID); err == nil { + if _, err := s.GetMember(ctx, botID, userID); err == nil { return bot, nil } } return Bot{}, ErrBotAccessDenied } +// Create creates a new bot owned by owner user. func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotRequest) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") @@ -80,6 +86,9 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR if err != nil { return Bot{}, err } + if err := s.ensureUserExists(ctx, ownerUUID); err != nil { + return Bot{}, err + } normalizedType, err := normalizeBotType(req.Type) if err != nil { return Bot{}, err @@ -127,6 +136,7 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR return bot, nil } +// Get returns a bot by its ID. func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") @@ -142,6 +152,7 @@ func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { return toBot(row) } +// ListByOwner returns bots owned by the given user. func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, error) { if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") @@ -165,15 +176,16 @@ func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, e return items, nil } -func (s *Service) ListByMember(ctx context.Context, userID string) ([]Bot, error) { +// ListByMember returns bots where the user is a member. +func (s *Service) ListByMember(ctx context.Context, channelIdentityID string) ([]Bot, error) { if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") } - userUUID, err := parseUUID(userID) + memberUUID, err := parseUUID(channelIdentityID) if err != nil { return nil, err } - rows, err := s.queries.ListBotsByMember(ctx, userUUID) + rows, err := s.queries.ListBotsByMember(ctx, memberUUID) if err != nil { return nil, err } @@ -188,12 +200,13 @@ func (s *Service) ListByMember(ctx context.Context, userID string) ([]Bot, error return items, nil } -func (s *Service) ListAccessible(ctx context.Context, userID string) ([]Bot, error) { - owned, err := s.ListByOwner(ctx, userID) +// ListAccessible returns all bots the user can access (owned or member). +func (s *Service) ListAccessible(ctx context.Context, channelIdentityID string) ([]Bot, error) { + owned, err := s.ListByOwner(ctx, channelIdentityID) if err != nil { return nil, err } - members, err := s.ListByMember(ctx, userID) + members, err := s.ListByMember(ctx, channelIdentityID) if err != nil { return nil, err } @@ -213,6 +226,7 @@ func (s *Service) ListAccessible(ctx context.Context, userID string) ([]Bot, err return items, nil } +// Update updates bot profile fields. func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") @@ -264,6 +278,7 @@ func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest return toBot(row) } +// TransferOwner transfers bot ownership to another user. func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID string) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") @@ -276,6 +291,9 @@ func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID s if err != nil { return Bot{}, err } + if err := s.ensureUserExists(ctx, ownerUUID); err != nil { + return Bot{}, err + } row, err := s.queries.UpdateBotOwner(ctx, sqlc.UpdateBotOwnerParams{ ID: botUUID, OwnerUserID: ownerUUID, @@ -286,6 +304,7 @@ func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID s return toBot(row) } +// Delete removes a bot and its associated resources. func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { return fmt.Errorf("bot queries not configured") @@ -298,16 +317,34 @@ func (s *Service) Delete(ctx context.Context, botID string) error { return err } if s.containerLifecycle != nil { + s.logger.Info("cleaning up bot container before deletion", slog.String("bot_id", botID)) if err := s.containerLifecycle.CleanupBotContainer(ctx, botID); err != nil { s.logger.Error("failed to cleanup bot container", slog.String("bot_id", botID), slog.Any("error", err), ) } + } else { + s.logger.Warn("container lifecycle not configured, skipping container cleanup", slog.String("bot_id", botID)) } return s.queries.DeleteBotByID(ctx, botUUID) } +func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) error { + if s.queries == nil { + return fmt.Errorf("bot queries not configured") + } + _, err := s.queries.GetUserByID(ctx, userID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrOwnerUserNotFound + } + return err + } + return nil +} + +// UpsertMember creates or updates a bot membership. func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemberRequest) (BotMember, error) { if s.queries == nil { return BotMember{}, fmt.Errorf("bot queries not configured") @@ -316,7 +353,7 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb if err != nil { return BotMember{}, err } - userUUID, err := parseUUID(req.UserID) + memberUUID, err := parseUUID(req.UserID) if err != nil { return BotMember{}, err } @@ -326,7 +363,7 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb } row, err := s.queries.UpsertBotMember(ctx, sqlc.UpsertBotMemberParams{ BotID: botUUID, - UserID: userUUID, + UserID: memberUUID, Role: role, }) if err != nil { @@ -335,6 +372,7 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb return toBotMember(row), nil } +// ListMembers returns all members of a bot. func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, error) { if s.queries == nil { return nil, fmt.Errorf("bot queries not configured") @@ -354,7 +392,8 @@ func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, e return items, nil } -func (s *Service) GetMember(ctx context.Context, botID, userID string) (BotMember, error) { +// GetMember returns a specific bot member. +func (s *Service) GetMember(ctx context.Context, botID, channelIdentityID string) (BotMember, error) { if s.queries == nil { return BotMember{}, fmt.Errorf("bot queries not configured") } @@ -362,13 +401,13 @@ func (s *Service) GetMember(ctx context.Context, botID, userID string) (BotMembe if err != nil { return BotMember{}, err } - userUUID, err := parseUUID(userID) + memberUUID, err := parseUUID(channelIdentityID) if err != nil { return BotMember{}, err } row, err := s.queries.GetBotMember(ctx, sqlc.GetBotMemberParams{ BotID: botUUID, - UserID: userUUID, + UserID: memberUUID, }) if err != nil { return BotMember{}, err @@ -376,7 +415,8 @@ func (s *Service) GetMember(ctx context.Context, botID, userID string) (BotMembe return toBotMember(row), nil } -func (s *Service) DeleteMember(ctx context.Context, botID, userID string) error { +// DeleteMember removes a member from a bot. +func (s *Service) DeleteMember(ctx context.Context, botID, channelIdentityID string) error { if s.queries == nil { return fmt.Errorf("bot queries not configured") } @@ -384,18 +424,43 @@ func (s *Service) DeleteMember(ctx context.Context, botID, userID string) error if err != nil { return err } - userUUID, err := parseUUID(userID) + memberUUID, err := parseUUID(channelIdentityID) if err != nil { return err } return s.queries.DeleteBotMember(ctx, sqlc.DeleteBotMemberParams{ BotID: botUUID, - UserID: userUUID, + UserID: memberUUID, }) } +// UpsertMemberSimple creates or updates a bot membership with a direct channel identity ID and role. +// This satisfies the router.BotMemberService interface. +func (s *Service) UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error { + _, err := s.UpsertMember(ctx, botID, UpsertMemberRequest{ + UserID: channelIdentityID, + Role: role, + }) + return err +} + +// IsMember checks if a user is a member of a bot. +func (s *Service) IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) { + _, err := s.GetMember(ctx, botID, channelIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, nil + } + return false, err + } + return true, nil +} + func normalizeBotType(raw string) (string, error) { normalized := strings.ToLower(strings.TrimSpace(raw)) + if normalized == "" { + return BotTypePersonal, nil + } switch normalized { case BotTypePersonal, BotTypePublic: return normalized, nil diff --git a/internal/bots/types.go b/internal/bots/types.go index 2ecc9fc8..7a9babf5 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -5,6 +5,7 @@ import ( "time" ) +// Bot represents a bot entity. type Bot struct { ID string `json:"id"` OwnerUserID string `json:"owner_user_id"` @@ -17,6 +18,7 @@ type Bot struct { UpdatedAt time.Time `json:"updated_at"` } +// BotMember represents a bot membership record. type BotMember struct { BotID string `json:"bot_id"` UserID string `json:"user_id"` @@ -24,6 +26,7 @@ type BotMember struct { CreatedAt time.Time `json:"created_at"` } +// CreateBotRequest is the input for creating a bot. type CreateBotRequest struct { Type string `json:"type"` DisplayName string `json:"display_name,omitempty"` @@ -32,6 +35,7 @@ type CreateBotRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } +// UpdateBotRequest is the input for updating a bot. type UpdateBotRequest struct { DisplayName *string `json:"display_name,omitempty"` AvatarURL *string `json:"avatar_url,omitempty"` @@ -39,19 +43,23 @@ type UpdateBotRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } +// TransferBotRequest is the input for transferring bot ownership. type TransferBotRequest struct { OwnerUserID string `json:"owner_user_id"` } +// UpsertMemberRequest is the input for upserting a bot member. type UpsertMemberRequest struct { UserID string `json:"user_id"` Role string `json:"role,omitempty"` } +// ListBotsResponse wraps a list of bots. type ListBotsResponse struct { Items []Bot `json:"items"` } +// ListMembersResponse wraps a list of bot members. type ListMembersResponse struct { Items []BotMember `json:"items"` } diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go index d7def25a..74133ed4 100644 --- a/internal/channel/adapters/feishu/config.go +++ b/internal/channel/adapters/feishu/config.go @@ -79,8 +79,8 @@ func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { return true } - if criteria.ExternalID != "" { - if criteria.ExternalID == cfg.OpenID || criteria.ExternalID == cfg.UserID { + if criteria.SubjectID != "" { + if criteria.SubjectID == cfg.OpenID || criteria.SubjectID == cfg.UserID { return true } } diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index e1cc6d7f..22f3bc66 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -449,6 +449,7 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa if message.Content != nil { _ = json.Unmarshal([]byte(*message.Content), &contentMap) } + isMentioned := hasFeishuMention(contentMap) if message.MessageType != nil { switch *message.MessageType { @@ -512,9 +513,9 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa if senderOpenID != "" { attrs["open_id"] = senderOpenID } - externalID := senderOpenID - if externalID == "" { - externalID = senderID + subjectID := senderOpenID + if subjectID == "" { + subjectID = senderID } return channel.InboundMessage{ @@ -522,7 +523,7 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa Message: msg, ReplyTarget: replyTo, Sender: channel.Identity{ - ExternalID: externalID, + SubjectID: subjectID, DisplayName: senderOpenID, Attributes: attrs, }, @@ -532,6 +533,27 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa }, ReceivedAt: time.Now().UTC(), Source: "feishu", + Metadata: map[string]any{ + "is_mentioned": isMentioned, + }, + } +} + +func hasFeishuMention(contentMap map[string]any) bool { + if len(contentMap) == 0 { + return false + } + raw, ok := contentMap["mentions"] + if !ok { + return false + } + switch mentions := raw.(type) { + case []any: + return len(mentions) > 0 + case []map[string]any: + return len(mentions) > 0 + default: + return false } } diff --git a/internal/channel/adapters/feishu/feishu_test.go b/internal/channel/adapters/feishu/feishu_test.go index 5c6c47dc..765f10a3 100644 --- a/internal/channel/adapters/feishu/feishu_test.go +++ b/internal/channel/adapters/feishu/feishu_test.go @@ -70,6 +70,9 @@ func TestExtractFeishuInboundP2P(t *testing.T) { if got.ReplyTarget != "ou_1" { t.Fatalf("unexpected reply target: %s", got.ReplyTarget) } + if mentioned, _ := got.Metadata["is_mentioned"].(bool); mentioned { + t.Fatalf("unexpected mention flag for p2p message") + } } func TestExtractFeishuInboundGroup(t *testing.T) { @@ -101,6 +104,9 @@ func TestExtractFeishuInboundGroup(t *testing.T) { if got.ReplyTarget != "chat_id:oc_2" { t.Fatalf("unexpected reply target: %s", got.ReplyTarget) } + if mentioned, _ := got.Metadata["is_mentioned"].(bool); mentioned { + t.Fatalf("unexpected mention flag for group message without mentions") + } } func TestExtractFeishuInboundNonText(t *testing.T) { @@ -119,3 +125,27 @@ func TestExtractFeishuInboundNonText(t *testing.T) { t.Fatalf("expected empty text, got %s", got.Message.PlainText()) } } + +func TestExtractFeishuInboundMention(t *testing.T) { + t.Parallel() + + text := `{"text":"@bot hi","mentions":[{"key":"@bot"}]}` + msgType := larkim.MsgTypeText + chatType := "group" + chatID := "oc_3" + event := &larkim.P2MessageReceiveV1{ + Event: &larkim.P2MessageReceiveV1Data{ + Message: &larkim.EventMessage{ + MessageType: &msgType, + Content: &text, + ChatType: &chatType, + ChatId: &chatID, + }, + }, + } + got := extractFeishuInbound(event) + mentioned, ok := got.Metadata["is_mentioned"].(bool) + if !ok || !mentioned { + t.Fatalf("expected mention flag to be true") + } +} diff --git a/internal/channel/adapters/telegram/config.go b/internal/channel/adapters/telegram/config.go index 51d2d3a6..d77e2feb 100644 --- a/internal/channel/adapters/telegram/config.go +++ b/internal/channel/adapters/telegram/config.go @@ -82,8 +82,8 @@ func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) { return true } - if criteria.ExternalID != "" { - if criteria.ExternalID == cfg.ChatID || criteria.ExternalID == cfg.UserID || strings.EqualFold(criteria.ExternalID, cfg.Username) { + if criteria.SubjectID != "" { + if criteria.SubjectID == cfg.ChatID || criteria.SubjectID == cfg.UserID || strings.EqualFold(criteria.SubjectID, cfg.Username) { return true } } diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index 90e46062..aa19ebe4 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -183,7 +183,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig if text == "" && len(attachments) == 0 { continue } - externalID, displayName, attrs := resolveTelegramSender(update.Message) + subjectID, displayName, attrs := resolveTelegramSender(update.Message) chatID := "" chatType := "" chatName := "" @@ -193,6 +193,10 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig chatName = strings.TrimSpace(update.Message.Chat.Title) } replyRef := buildTelegramReplyRef(update.Message, chatID) + isReplyToBot := update.Message.ReplyToMessage != nil && + update.Message.ReplyToMessage.From != nil && + update.Message.ReplyToMessage.From.IsBot + isMentioned := isTelegramBotMentioned(update.Message, bot.Self.UserName) msg := channel.InboundMessage{ Channel: Type, Message: channel.Message{ @@ -205,7 +209,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig BotID: cfg.BotID, ReplyTarget: chatID, Sender: channel.Identity{ - ExternalID: externalID, + SubjectID: subjectID, DisplayName: displayName, Attributes: attrs, }, @@ -216,6 +220,10 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig }, ReceivedAt: time.Unix(int64(update.Message.Date), 0).UTC(), Source: "telegram", + Metadata: map[string]any{ + "is_mentioned": isMentioned, + "is_reply_to_bot": isReplyToBot, + }, } if a.logger != nil { a.logger.Info( @@ -565,6 +573,33 @@ func resolveTelegramParseMode(format channel.MessageFormat) string { } } +func isTelegramBotMentioned(msg *tgbotapi.Message, botUsername string) bool { + if msg == nil { + return false + } + normalizedBot := strings.ToLower(strings.TrimPrefix(strings.TrimSpace(botUsername), "@")) + if normalizedBot != "" { + text := strings.TrimSpace(msg.Text) + if text == "" { + text = strings.TrimSpace(msg.Caption) + } + if text != "" { + if strings.Contains(strings.ToLower(text), "@"+normalizedBot) { + return true + } + } + } + entities := make([]tgbotapi.MessageEntity, 0, len(msg.Entities)+len(msg.CaptionEntities)) + entities = append(entities, msg.Entities...) + entities = append(entities, msg.CaptionEntities...) + for _, entity := range entities { + if entity.Type == "text_mention" && entity.User != nil && entity.User.IsBot { + return true + } + } + return false +} + func (a *TelegramAdapter) collectTelegramAttachments(bot *tgbotapi.BotAPI, msg *tgbotapi.Message) []channel.Attachment { if msg == nil { return nil diff --git a/internal/channel/adapters/telegram/telegram_test.go b/internal/channel/adapters/telegram/telegram_test.go index 6d3a5834..18c7492b 100644 --- a/internal/channel/adapters/telegram/telegram_test.go +++ b/internal/channel/adapters/telegram/telegram_test.go @@ -24,3 +24,42 @@ func TestResolveTelegramSender(t *testing.T) { t.Fatalf("unexpected attrs: %#v", attrs) } } + +func TestIsTelegramBotMentioned(t *testing.T) { + t.Parallel() + + t.Run("text mention", func(t *testing.T) { + t.Parallel() + msg := &tgbotapi.Message{ + Text: "hello @MemohBot", + } + if !isTelegramBotMentioned(msg, "memohbot") { + t.Fatalf("expected bot mention from text") + } + }) + + t.Run("entity text mention", func(t *testing.T) { + t.Parallel() + msg := &tgbotapi.Message{ + Entities: []tgbotapi.MessageEntity{ + { + Type: "text_mention", + User: &tgbotapi.User{IsBot: true}, + }, + }, + } + if !isTelegramBotMentioned(msg, "") { + t.Fatalf("expected bot mention from text_mention entity") + } + }) + + t.Run("not mentioned", func(t *testing.T) { + t.Parallel() + msg := &tgbotapi.Message{ + Text: "hello everyone", + } + if isTelegramBotMentioned(msg, "memohbot") { + t.Fatalf("expected no mention") + } + }) +} diff --git a/internal/channel/config_test.go b/internal/channel/config_test.go index 837b3b53..b71e232e 100644 --- a/internal/channel/config_test.go +++ b/internal/channel/config_test.go @@ -63,7 +63,7 @@ func (a *testConfigAdapter) ResolveTarget(raw map[string]any) (string, error) { func (a *testConfigAdapter) MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { value := channel.ReadString(raw, "user") - return value != "" && value == criteria.ExternalID + return value != "" && value == criteria.SubjectID } func (a *testConfigAdapter) BuildUserConfig(identity channel.Identity) map[string]any { diff --git a/internal/channel/helpers_test.go b/internal/channel/helpers_test.go index de50b163..50434c57 100644 --- a/internal/channel/helpers_test.go +++ b/internal/channel/helpers_test.go @@ -55,13 +55,52 @@ func TestBindingCriteriaFromIdentity(t *testing.T) { t.Parallel() criteria := BindingCriteriaFromIdentity(Identity{ - ExternalID: "u1", + SubjectID: "u1", Attributes: map[string]string{"username": "alice"}, }) - if criteria.ExternalID != "u1" { - t.Fatalf("unexpected external id: %s", criteria.ExternalID) + if criteria.SubjectID != "u1" { + t.Fatalf("unexpected subject id: %s", criteria.SubjectID) } if criteria.Attribute("username") != "alice" { t.Fatalf("unexpected username: %s", criteria.Attribute("username")) } } + +func TestNormalizeChannelConfigStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {name: "default pending", input: "", want: "pending"}, + {name: "pending passthrough", input: "pending", want: "pending"}, + {name: "verified passthrough", input: "verified", want: "verified"}, + {name: "disabled passthrough", input: "disabled", want: "disabled"}, + {name: "active alias", input: "active", want: "verified"}, + {name: "inactive alias", input: "inactive", want: "disabled"}, + {name: "unknown status", input: "paused", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := normalizeChannelConfigStatus(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got != tt.want { + t.Fatalf("unexpected status: got %s, want %s", got, tt.want) + } + }) + } +} diff --git a/internal/channel/manager.go b/internal/channel/manager.go index feead996..628e5340 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -18,19 +18,12 @@ type ConfigLister interface { // ConfigResolver resolves effective configs and user bindings. Used for outbound sending. type ConfigResolver interface { ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) - GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) + GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) } -// BindingStore resolves user-channel bindings. Used by identity resolution. +// BindingStore resolves channel-identity bindings. Used by identity resolution. type BindingStore interface { - ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) -} - -// SessionStore manages channel session lifecycle. Used by identity resolution. -type SessionStore interface { - GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) - UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error - ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]ChannelSession, error) + ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) } // ConfigStore is the full persistence interface. Components should depend on smaller @@ -39,8 +32,7 @@ type ConfigStore interface { ConfigLister ConfigResolver BindingStore - SessionStore - UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) + UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) } // Middleware wraps an InboundHandler to add cross-cutting behavior. @@ -187,14 +179,14 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp } target := strings.TrimSpace(req.Target) if target == "" { - targetUserID := strings.TrimSpace(req.UserID) - if targetUserID == "" { + targetChannelIdentityID := strings.TrimSpace(req.ChannelIdentityID) + if targetChannelIdentityID == "" { return fmt.Errorf("target or user_id is required") } - userCfg, err := m.service.GetUserConfig(ctx, targetUserID, channelType) + userCfg, err := m.service.GetChannelIdentityConfig(ctx, targetChannelIdentityID, channelType) if err != nil { if m.logger != nil { - m.logger.Warn("channel binding missing", slog.String("channel", channelType.String()), slog.String("user_id", targetUserID)) + m.logger.Warn("channel binding missing", slog.String("channel", channelType.String()), slog.String("channel_identity_id", targetChannelIdentityID)) } return fmt.Errorf("channel binding required") } diff --git a/internal/channel/manager_integration_test.go b/internal/channel/manager_integration_test.go index 82b910c0..fc5fe1bd 100644 --- a/internal/channel/manager_integration_test.go +++ b/internal/channel/manager_integration_test.go @@ -12,26 +12,25 @@ import ( ) type fakeConfigStore struct { - effectiveConfig ChannelConfig - userConfig ChannelUserBinding - configsByType map[ChannelType][]ChannelConfig - session ChannelSession - boundUserID string + effectiveConfig ChannelConfig + channelIdentityConfig ChannelIdentityBinding + configsByType map[ChannelType][]ChannelConfig + boundChannelIdentityID string } func (f *fakeConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { return f.effectiveConfig, nil } -func (f *fakeConfigStore) GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) { - if f.userConfig.ID == "" && len(f.userConfig.Config) == 0 { - return ChannelUserBinding{}, fmt.Errorf("channel user config not found") +func (f *fakeConfigStore) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { + if f.channelIdentityConfig.ID == "" && len(f.channelIdentityConfig.Config) == 0 { + return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") } - return f.userConfig, nil + return f.channelIdentityConfig, nil } -func (f *fakeConfigStore) UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) { - return f.userConfig, nil +func (f *fakeConfigStore) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { + return f.channelIdentityConfig, nil } func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { @@ -41,26 +40,11 @@ func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType Cha return f.configsByType[channelType], nil } -func (f *fakeConfigStore) ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { - if f.boundUserID == "" { +func (f *fakeConfigStore) ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { + if f.boundChannelIdentityID == "" { return "", fmt.Errorf("channel user binding not found") } - return f.boundUserID, nil -} - -func (f *fakeConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]ChannelSession, error) { - return nil, nil -} - -func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) { - if f.session.SessionID == sessionID { - return f.session, nil - } - return ChannelSession{}, nil -} - -func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { - return nil + return f.boundChannelIdentityID, nil } type fakeInboundProcessorIntegration struct { @@ -101,8 +85,8 @@ func (f *fakeAdapter) Descriptor() Descriptor { return Descriptor{Type: f.channelType, DisplayName: "Fake", Capabilities: ChannelCapabilities{Text: true}} } -func (f *fakeAdapter) ResolveTarget(userConfig map[string]any) (string, error) { - value := strings.TrimSpace(ReadString(userConfig, "target")) +func (f *fakeAdapter) ResolveTarget(channelIdentityConfig map[string]any) (string, error) { + value := strings.TrimSpace(ReadString(channelIdentityConfig, "target")) if value == "" { return "", fmt.Errorf("missing target") } @@ -135,13 +119,7 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { t.Parallel() log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) - store := &fakeConfigStore{ - session: ChannelSession{ - SessionID: "telegram:bot-1:chat-1", - BotID: "bot-1", - UserID: "user-1", - }, - } + store := &fakeConfigStore{} processor := &fakeInboundProcessorIntegration{ resp: &OutboundMessage{ Target: "123", @@ -202,7 +180,7 @@ func TestManagerSendUsesBinding(t *testing.T) { Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), }, - userConfig: ChannelUserBinding{ + channelIdentityConfig: ChannelIdentityBinding{ ID: "binding-1", Config: map[string]any{"target": "alice"}, }, @@ -213,7 +191,7 @@ func TestManagerSendUsesBinding(t *testing.T) { manager.RegisterAdapter(adapter) err := manager.Send(context.Background(), "bot-1", ChannelType("test"), SendRequest{ - UserID: "user-1", + ChannelIdentityID: "user-1", Message: Message{ Text: "hello", }, diff --git a/internal/channel/service.go b/internal/channel/service.go index d11437a2..c5c4f796 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -65,9 +65,9 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch if err != nil { return ChannelConfig{}, err } - status := strings.TrimSpace(req.Status) - if status == "" { - status = "pending" + status, err := normalizeChannelConfigStatus(req.Status) + if err != nil { + return ChannelConfig{}, err } verifiedAt := pgtype.Timestamptz{Valid: false} if req.VerifiedAt != nil { @@ -94,35 +94,52 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch return normalizeChannelConfig(row) } -// UpsertUserConfig creates or updates a user's channel binding. -func (s *Service) UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) { +func normalizeChannelConfigStatus(raw string) (string, error) { + status := strings.ToLower(strings.TrimSpace(raw)) + if status == "" { + return "pending", nil + } + switch status { + case "pending", "verified", "disabled": + return status, nil + case "active": + return "verified", nil + case "inactive": + return "disabled", nil + default: + return "", fmt.Errorf("invalid channel status: %s", raw) + } +} + +// UpsertChannelIdentityConfig creates or updates a channel identity's channel binding. +func (s *Service) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { if s.queries == nil { - return ChannelUserBinding{}, fmt.Errorf("channel queries not configured") + return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") } if channelType == "" { - return ChannelUserBinding{}, fmt.Errorf("channel type is required") + return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") } normalized, err := s.registry.NormalizeUserConfig(channelType, req.Config) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } payload, err := json.Marshal(normalized) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - pgUserID, err := db.ParseUUID(actorUserID) + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } row, err := s.queries.UpsertUserChannelBinding(ctx, sqlc.UpsertUserChannelBindingParams{ - UserID: pgUserID, - ChannelType: channelType.String(), - Config: payload, + UserID: pgChannelIdentityID, + Platform: channelType.String(), + Config: payload, }) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - return normalizeChannelUserBindingRow(row) + return normalizeChannelIdentityBinding(row) } // ResolveEffectiveConfig returns the active channel configuration for a bot. @@ -181,54 +198,54 @@ func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType return items, nil } -// GetUserConfig returns the user's channel binding for the given channel type. -func (s *Service) GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) { +// GetChannelIdentityConfig returns the channel identity's channel binding for the given channel type. +func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { if s.queries == nil { - return ChannelUserBinding{}, fmt.Errorf("channel queries not configured") + return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") } if channelType == "" { - return ChannelUserBinding{}, fmt.Errorf("channel type is required") + return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") } - pgUserID, err := db.ParseUUID(actorUserID) + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } row, err := s.queries.GetUserChannelBinding(ctx, sqlc.GetUserChannelBindingParams{ - UserID: pgUserID, - ChannelType: channelType.String(), + UserID: pgChannelIdentityID, + Platform: channelType.String(), }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return ChannelUserBinding{}, fmt.Errorf("channel user config not found") + return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") } - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } config, err := DecodeConfigMap(row.Config) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - return ChannelUserBinding{ - ID: db.UUIDToString(row.ID), - ChannelType: ChannelType(row.ChannelType), - UserID: db.UUIDToString(row.UserID), - Config: config, - CreatedAt: db.TimeFromPg(row.CreatedAt), - UpdatedAt: db.TimeFromPg(row.UpdatedAt), + return ChannelIdentityBinding{ + ID: db.UUIDToString(row.ID), + ChannelType: ChannelType(row.Platform), + ChannelIdentityID: db.UUIDToString(row.UserID), + Config: config, + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } -// ListUserConfigsByType returns all user bindings for the given channel type. -func (s *Service) ListUserConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelUserBinding, error) { +// ListChannelIdentityConfigsByType returns all channel identity bindings for the given channel type. +func (s *Service) ListChannelIdentityConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelIdentityBinding, error) { if s.queries == nil { return nil, fmt.Errorf("channel queries not configured") } - rows, err := s.queries.ListUserChannelBindingsByType(ctx, channelType.String()) + rows, err := s.queries.ListUserChannelBindingsByPlatform(ctx, channelType.String()) if err != nil { return nil, err } - items := make([]ChannelUserBinding, 0, len(rows)) + items := make([]ChannelIdentityBinding, 0, len(rows)) for _, row := range rows { - item, err := normalizeChannelUserBindingRow(row) + item, err := normalizeChannelIdentityBinding(row) if err != nil { return nil, err } @@ -237,119 +254,9 @@ func (s *Service) ListUserConfigsByType(ctx context.Context, channelType Channel return items, nil } -// GetChannelSession returns the session with the given ID, or an empty session if not found. -func (s *Service) GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) { - if s.queries == nil { - return ChannelSession{}, fmt.Errorf("channel queries not configured") - } - row, err := s.queries.GetChannelSessionByID(ctx, sessionID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return ChannelSession{}, nil - } - return ChannelSession{}, err - } - return normalizeChannelSession(row) -} - -// ListSessionsByBotPlatform returns all sessions for the given bot and platform. -func (s *Service) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]ChannelSession, error) { - if s.queries == nil { - return nil, fmt.Errorf("channel queries not configured") - } - botID = strings.TrimSpace(botID) - platform = strings.TrimSpace(platform) - if botID == "" { - return nil, fmt.Errorf("bot id is required") - } - if platform == "" { - return nil, fmt.Errorf("platform is required") - } - pgBotID, err := db.ParseUUID(botID) - if err != nil { - return nil, err - } - rows, err := s.queries.ListChannelSessionsByBotPlatform(ctx, sqlc.ListChannelSessionsByBotPlatformParams{ - BotID: pgBotID, - Platform: platform, - }) - if err != nil { - return nil, err - } - items := make([]ChannelSession, 0, len(rows)) - for _, row := range rows { - item, err := normalizeChannelSession(row) - if err != nil { - return nil, err - } - items = append(items, item) - } - return items, nil -} - -// UpsertChannelSession creates or updates a channel session record. -func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { - if s.queries == nil { - return fmt.Errorf("channel queries not configured") - } - pgUserID := pgtype.UUID{Valid: false} - if strings.TrimSpace(userID) != "" { - parsed, err := db.ParseUUID(userID) - if err != nil { - return err - } - pgUserID = parsed - } - botUUID, err := db.ParseUUID(botID) - if err != nil { - return err - } - var channelUUID pgtype.UUID - if strings.TrimSpace(channelConfigID) != "" { - channelUUID, err = db.ParseUUID(channelConfigID) - if err != nil { - return err - } - } - pgContactID := pgtype.UUID{Valid: false} - if strings.TrimSpace(contactID) != "" { - parsed, err := db.ParseUUID(contactID) - if err != nil { - return err - } - pgContactID = parsed - } - payload := metadata - if payload == nil { - payload = map[string]any{} - } - metaBytes, err := json.Marshal(payload) - if err != nil { - return err - } - _, err = s.queries.UpsertChannelSession(ctx, sqlc.UpsertChannelSessionParams{ - SessionID: sessionID, - BotID: botUUID, - ChannelConfigID: channelUUID, - UserID: pgUserID, - ContactID: pgContactID, - Platform: platform, - ReplyTarget: pgtype.Text{ - String: strings.TrimSpace(replyTarget), - Valid: strings.TrimSpace(replyTarget) != "", - }, - ThreadID: pgtype.Text{ - String: strings.TrimSpace(threadID), - Valid: strings.TrimSpace(threadID) != "", - }, - Metadata: metaBytes, - }) - return err -} - -// ResolveUserBinding finds the user ID whose channel binding matches the given criteria. -func (s *Service) ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { - rows, err := s.ListUserConfigsByType(ctx, channelType) +// ResolveChannelIdentityBinding finds the channel identity ID whose channel binding matches the given criteria. +func (s *Service) ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { + rows, err := s.ListChannelIdentityConfigsByType(ctx, channelType) if err != nil { return "", err } @@ -358,7 +265,7 @@ func (s *Service) ResolveUserBinding(ctx context.Context, channelType ChannelTyp } for _, row := range rows { if s.registry.MatchUserBinding(channelType, row.Config, criteria) { - return row.UserID, nil + return row.ChannelIdentityID, nil } } return "", fmt.Errorf("channel user binding not found") @@ -392,48 +299,25 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { Credentials: credentials, ExternalIdentity: externalIdentity, SelfIdentity: selfIdentity, - Routing: routing, - Status: strings.TrimSpace(row.Status), + Routing: routing, + Status: strings.TrimSpace(row.Status), VerifiedAt: verifiedAt, CreatedAt: db.TimeFromPg(row.CreatedAt), UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } -func normalizeChannelUserBindingRow(row sqlc.UserChannelBinding) (ChannelUserBinding, error) { +func normalizeChannelIdentityBinding(row sqlc.UserChannelBinding) (ChannelIdentityBinding, error) { config, err := DecodeConfigMap(row.Config) if err != nil { - return ChannelUserBinding{}, err + return ChannelIdentityBinding{}, err } - return ChannelUserBinding{ - ID: db.UUIDToString(row.ID), - ChannelType: ChannelType(row.ChannelType), - UserID: db.UUIDToString(row.UserID), - Config: config, - CreatedAt: db.TimeFromPg(row.CreatedAt), - UpdatedAt: db.TimeFromPg(row.UpdatedAt), + return ChannelIdentityBinding{ + ID: db.UUIDToString(row.ID), + ChannelType: ChannelType(row.Platform), + ChannelIdentityID: db.UUIDToString(row.UserID), + Config: config, + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } - -func normalizeChannelSession(row sqlc.ChannelSession) (ChannelSession, error) { - metadata, err := DecodeConfigMap(row.Metadata) - if err != nil { - return ChannelSession{}, err - } - return ChannelSession{ - SessionID: row.SessionID, - BotID: db.UUIDToString(row.BotID), - ChannelConfigID: db.UUIDToString(row.ChannelConfigID), - UserID: db.UUIDToString(row.UserID), - ContactID: db.UUIDToString(row.ContactID), - Platform: row.Platform, - ReplyTarget: strings.TrimSpace(row.ReplyTarget.String), - ThreadID: strings.TrimSpace(row.ThreadID.String), - Metadata: metadata, - CreatedAt: db.TimeFromPg(row.CreatedAt), - UpdatedAt: db.TimeFromPg(row.UpdatedAt), - }, nil -} - - - diff --git a/internal/channel/types.go b/internal/channel/types.go index 0bb2147e..8c6ef1b0 100644 --- a/internal/channel/types.go +++ b/internal/channel/types.go @@ -17,7 +17,7 @@ func (c ChannelType) String() string { // Identity represents a sender's identity on a channel. type Identity struct { - ExternalID string + SubjectID string DisplayName string Attributes map[string]string } @@ -59,7 +59,7 @@ func (m InboundMessage) SessionID() string { if strings.TrimSpace(m.SessionKey) != "" { return strings.TrimSpace(m.SessionKey) } - senderID := strings.TrimSpace(m.Sender.ExternalID) + senderID := strings.TrimSpace(m.Sender.SubjectID) if senderID == "" { senderID = strings.TrimSpace(m.Sender.DisplayName) } @@ -118,14 +118,14 @@ const ( // MessagePart is a single element within a rich-text message. type MessagePart struct { - Type MessagePartType `json:"type"` - Text string `json:"text,omitempty"` - URL string `json:"url,omitempty"` - Styles []MessageTextStyle `json:"styles,omitempty"` - Language string `json:"language,omitempty"` - UserID string `json:"user_id,omitempty"` - Emoji string `json:"emoji,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Type MessagePartType `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []MessageTextStyle `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // AttachmentType classifies the kind of binary attachment. @@ -227,7 +227,7 @@ func (m Message) PlainText() string { // BindingCriteria specifies conditions for matching a user-channel binding. type BindingCriteria struct { - ExternalID string + SubjectID string Attributes map[string]string } @@ -242,7 +242,7 @@ func (c BindingCriteria) Attribute(key string) string { // BindingCriteriaFromIdentity creates BindingCriteria from a channel Identity. func BindingCriteriaFromIdentity(identity Identity) BindingCriteria { return BindingCriteria{ - ExternalID: strings.TrimSpace(identity.ExternalID), + SubjectID: strings.TrimSpace(identity.SubjectID), Attributes: identity.Attributes, } } @@ -262,14 +262,14 @@ type ChannelConfig struct { UpdatedAt time.Time } -// ChannelUserBinding represents a user's binding to a specific channel type. -type ChannelUserBinding struct { - ID string - ChannelType ChannelType - UserID string - Config map[string]any - CreatedAt time.Time - UpdatedAt time.Time +// ChannelIdentityBinding represents a channel identity's binding to a specific channel type. +type ChannelIdentityBinding struct { + ID string + ChannelType ChannelType + ChannelIdentityID string + Config map[string]any + CreatedAt time.Time + UpdatedAt time.Time } // UpsertConfigRequest is the input for creating or updating a channel configuration. @@ -282,29 +282,14 @@ type UpsertConfigRequest struct { VerifiedAt *time.Time `json:"verified_at,omitempty"` } -// UpsertUserConfigRequest is the input for creating or updating a user-channel binding. -type UpsertUserConfigRequest struct { +// UpsertChannelIdentityConfigRequest is the input for creating or updating a channel-identity binding. +type UpsertChannelIdentityConfigRequest struct { Config map[string]any `json:"config"` } -// ChannelSession tracks an active conversation session on a channel. -type ChannelSession struct { - SessionID string - BotID string - ChannelConfigID string - UserID string - ContactID string - Platform string - ReplyTarget string - ThreadID string - Metadata map[string]any - CreatedAt time.Time - UpdatedAt time.Time -} - // SendRequest is the input for sending an outbound message through a channel. type SendRequest struct { - Target string `json:"target,omitempty"` - UserID string `json:"user_id,omitempty"` - Message Message `json:"message"` + Target string `json:"target,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Message Message `json:"message"` } diff --git a/internal/channelidentities/service.go b/internal/channelidentities/service.go new file mode 100644 index 00000000..a43b4dc0 --- /dev/null +++ b/internal/channelidentities/service.go @@ -0,0 +1,291 @@ +package channelidentities + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +// Service provides channel identity lifecycle operations. +type Service struct { + queries *sqlc.Queries + logger *slog.Logger +} + +var ( + ErrChannelIdentityNotFound = errors.New("channel identity not found") +) + +// NewService creates a new channel identity service. +func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + queries: queries, + logger: log.With(slog.String("service", "channelidentities")), + } +} + +// Create creates a new channel identity for the given channel subject. +func (s *Service) Create(ctx context.Context, channel, channelSubjectID, displayName string) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + channel = normalizeChannel(channel) + channelSubjectID = strings.TrimSpace(channelSubjectID) + if channel == "" || channelSubjectID == "" { + return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + } + row, err := s.queries.CreateChannelIdentity(ctx, sqlc.CreateChannelIdentityParams{ + UserID: pgtype.UUID{}, + Channel: channel, + ChannelSubjectID: channelSubjectID, + DisplayName: toPgText(displayName), + Metadata: emptyMetadataBytes(), + }) + if err != nil { + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// GetByID returns a channel identity by its ID. +func (s *Service) GetByID(ctx context.Context, channelIdentityID string) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + pgID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return ChannelIdentity{}, err + } + row, err := s.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ChannelIdentity{}, ErrChannelIdentityNotFound + } + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// Canonicalize validates and returns the same channel identity ID. +func (s *Service) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { + if s.queries == nil { + return "", fmt.Errorf("channel identity queries not configured") + } + pgID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return "", err + } + _, err = s.queries.GetChannelIdentityByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", ErrChannelIdentityNotFound + } + return "", err + } + return channelIdentityID, nil +} + +// ResolveByChannelIdentity looks up or creates a channel identity for (channel, channel_subject_id). +func (s *Service) ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + channel = normalizeChannel(channel) + channelSubjectID = strings.TrimSpace(channelSubjectID) + if channel == "" || channelSubjectID == "" { + return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + } + + row, err := s.queries.UpsertChannelIdentityByChannelSubject(ctx, sqlc.UpsertChannelIdentityByChannelSubjectParams{ + UserID: pgtype.UUID{}, + Channel: channel, + ChannelSubjectID: channelSubjectID, + DisplayName: toPgText(displayName), + Metadata: emptyMetadataBytes(), + }) + if err != nil { + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// UpsertChannelIdentity creates or updates a channel identity mapping. +func (s *Service) UpsertChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string, metadata map[string]any) (ChannelIdentity, error) { + if s.queries == nil { + return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + } + channel = normalizeChannel(channel) + channelSubjectID = strings.TrimSpace(channelSubjectID) + if metadata == nil { + metadata = map[string]any{} + } + metaBytes, err := json.Marshal(metadata) + if err != nil { + return ChannelIdentity{}, err + } + row, err := s.queries.UpsertChannelIdentityByChannelSubject(ctx, sqlc.UpsertChannelIdentityByChannelSubjectParams{ + UserID: pgtype.UUID{}, + Channel: channel, + ChannelSubjectID: channelSubjectID, + DisplayName: toPgText(displayName), + Metadata: metaBytes, + }) + if err != nil { + return ChannelIdentity{}, err + } + return toChannelIdentity(row), nil +} + +// ListCanonicalChannelIdentities lists channel identities under the same linked user. +func (s *Service) ListCanonicalChannelIdentities(ctx context.Context, channelIdentityID string) ([]ChannelIdentity, error) { + if s.queries == nil { + return nil, fmt.Errorf("channel identity queries not configured") + } + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return nil, err + } + row, err := s.queries.GetChannelIdentityByID(ctx, pgChannelIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrChannelIdentityNotFound + } + return nil, err + } + if !row.UserID.Valid { + return []ChannelIdentity{toChannelIdentity(row)}, nil + } + rows, err := s.queries.ListChannelIdentitiesByUserID(ctx, row.UserID) + if err != nil { + return nil, err + } + result := make([]ChannelIdentity, 0, len(rows)) + for _, item := range rows { + result = append(result, toChannelIdentity(item)) + } + return result, nil +} + +// ListUserChannelIdentities lists all channel identities linked to a user. +func (s *Service) ListUserChannelIdentities(ctx context.Context, userID string) ([]ChannelIdentity, error) { + if s.queries == nil { + return nil, fmt.Errorf("channel identity queries not configured") + } + pgUserID, err := db.ParseUUID(userID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChannelIdentitiesByUserID(ctx, pgUserID) + if err != nil { + return nil, err + } + result := make([]ChannelIdentity, 0, len(rows)) + for _, row := range rows { + result = append(result, toChannelIdentity(row)) + } + return result, nil +} + +// GetLinkedUserID returns the linked user ID for a channel identity. +func (s *Service) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { + if s.queries == nil { + return "", fmt.Errorf("channel identity queries not configured") + } + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return "", err + } + row, err := s.queries.GetChannelIdentityByID(ctx, pgChannelIdentityID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", nil + } + return "", err + } + if !row.UserID.Valid { + return "", nil + } + return db.UUIDToString(row.UserID), nil +} + +// LinkChannelIdentityToUser binds a channel identity to a user. +func (s *Service) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { + if s.queries == nil { + return fmt.Errorf("channel identity queries not configured") + } + pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) + if err != nil { + return err + } + pgUserID, err := db.ParseUUID(userID) + if err != nil { + return err + } + _, err = s.queries.SetChannelIdentityLinkedUser(ctx, sqlc.SetChannelIdentityLinkedUserParams{ + ID: pgChannelIdentityID, + UserID: pgUserID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrChannelIdentityNotFound + } + return err + } + return nil +} + +func toChannelIdentity(row sqlc.ChannelIdentity) ChannelIdentity { + var metadata map[string]any + if len(row.Metadata) > 0 { + _ = json.Unmarshal(row.Metadata, &metadata) + } + if metadata == nil { + metadata = map[string]any{} + } + displayName := "" + if row.DisplayName.Valid { + displayName = strings.TrimSpace(row.DisplayName.String) + } + userID := "" + if row.UserID.Valid { + userID = db.UUIDToString(row.UserID) + } + return ChannelIdentity{ + ID: db.UUIDToString(row.ID), + UserID: userID, + Channel: row.Channel, + ChannelSubjectID: row.ChannelSubjectID, + DisplayName: displayName, + Metadata: metadata, + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), + } +} + +func normalizeChannel(channel string) string { + return strings.ToLower(strings.TrimSpace(channel)) +} + +func toPgText(value string) pgtype.Text { + value = strings.TrimSpace(value) + return pgtype.Text{ + String: value, + Valid: value != "", + } +} + +func emptyMetadataBytes() []byte { + return []byte("{}") +} diff --git a/internal/channelidentities/service_identity_integration_test.go b/internal/channelidentities/service_identity_integration_test.go new file mode 100644 index 00000000..da24c05c --- /dev/null +++ b/internal/channelidentities/service_identity_integration_test.go @@ -0,0 +1,109 @@ +package channelidentities_test + +import ( + "context" + "fmt" + "log/slog" + "os" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupChannelIdentityIdentityIntegrationTest(t *testing.T) (*channelidentities.Service, *sqlc.Queries, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + svc := channelidentities.NewService(logger, queries) + return svc, queries, func() { pool.Close() } +} + +func formatUUID(bytes [16]byte) string { + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) +} + +func isLegacyChannelIdentitySchemaError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "channelidentities_kind_check") || + strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") || + strings.Contains(msg, "column \"channel_subject_id\" of relation \"channelidentities\" does not exist") +} + +func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) { + svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + externalID := fmt.Sprintf("stable_%d", time.Now().UnixNano()) + first, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "first") + if err != nil { + if isLegacyChannelIdentitySchemaError(err) { + t.Skipf("skip integration test on legacy schema: %v", err) + } + t.Fatalf("first resolve failed: %v", err) + } + second, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "second") + if err != nil { + t.Fatalf("second resolve failed: %v", err) + } + if first.ID != second.ID { + t.Fatalf("expected same channelIdentity id, got %s and %s", first.ID, second.ID) + } +} + +func TestChannelIdentityLinkToUser(t *testing.T) { + svc, queries, cleanup := setupChannelIdentityIdentityIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", fmt.Sprintf("link_%d", time.Now().UnixNano()), "tg") + if err != nil { + if isLegacyChannelIdentitySchemaError(err) { + t.Skipf("skip integration test on legacy schema: %v", err) + } + t.Fatalf("resolve channelIdentity failed: %v", err) + } + user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + t.Fatalf("create user failed: %v", err) + } + userID := formatUUID(user.ID.Bytes) + + if err := svc.LinkChannelIdentityToUser(ctx, channelIdentity.ID, userID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + linkedUserID, err := svc.GetLinkedUserID(ctx, channelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != userID { + t.Fatalf("expected linked user=%s, got %s", userID, linkedUserID) + } +} diff --git a/internal/channelidentities/service_integration_test.go b/internal/channelidentities/service_integration_test.go new file mode 100644 index 00000000..ddc08e85 --- /dev/null +++ b/internal/channelidentities/service_integration_test.go @@ -0,0 +1,99 @@ +//go:build ignore +// +build ignore + +package channelidentities_test + +import ( + "context" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/db/sqlc" +) + +func setupIntegrationTest(t *testing.T) (*channelidentities.Service, *sqlc.Queries, func()) { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + queries := sqlc.New(pool) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + svc := channelidentities.NewService(logger, queries) + + return svc, queries, func() { pool.Close() } +} + +func toUUIDString(v [16]byte) string { + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", v[0:4], v[4:6], v[6:8], v[8:10], v[10:16]) +} + +func TestIntegrationResolveByChannelIdentityStability(t *testing.T) { + svc, _, cleanup := setupIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + key := fmt.Sprintf("ext_%d", time.Now().UnixNano()) + + first, err := svc.ResolveByChannelIdentity(ctx, "feishu", key, "first") + if err != nil { + t.Fatalf("first resolve failed: %v", err) + } + second, err := svc.ResolveByChannelIdentity(ctx, "feishu", key, "second") + if err != nil { + t.Fatalf("second resolve failed: %v", err) + } + if first.ID != second.ID { + t.Fatalf("expected stable channelIdentity id, got %s and %s", first.ID, second.ID) + } +} + +func TestIntegrationLinkChannelIdentityToUser(t *testing.T) { + svc, queries, cleanup := setupIntegrationTest(t) + defer cleanup() + + ctx := context.Background() + key := fmt.Sprintf("bind_%d", time.Now().UnixNano()) + channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", key, "tg-user") + if err != nil { + t.Fatalf("resolve channelIdentity failed: %v", err) + } + + user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + t.Fatalf("create user failed: %v", err) + } + userID := toUUIDString(user.ID.Bytes) + + if err := svc.LinkChannelIdentityToUser(ctx, channelIdentity.ID, userID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + linkedUserID, err := svc.GetLinkedUserID(ctx, channelIdentity.ID) + if err != nil { + t.Fatalf("get linked user failed: %v", err) + } + if linkedUserID != userID { + t.Fatalf("expected linked user=%s, got %s", userID, linkedUserID) + } +} diff --git a/internal/channelidentities/service_test.go b/internal/channelidentities/service_test.go new file mode 100644 index 00000000..a5a73b9e --- /dev/null +++ b/internal/channelidentities/service_test.go @@ -0,0 +1,37 @@ +package channelidentities + +import "testing" + +func TestNormalizeChannel(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"feishu", "feishu"}, + {" FEISHU ", "feishu"}, + {"Web", "web"}, + {"", ""}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + result := normalizeChannel(tc.input) + if result != tc.expected { + t.Errorf("normalizeChannel(%q) = %q, want %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestToPgText(t *testing.T) { + value := toPgText(" display ") + if !value.Valid { + t.Fatal("expected valid text for non-empty input") + } + if value.String != "display" { + t.Fatalf("expected trimmed text display, got %q", value.String) + } + empty := toPgText(" ") + if empty.Valid { + t.Fatal("expected invalid text for empty input") + } +} diff --git a/internal/channelidentities/types.go b/internal/channelidentities/types.go new file mode 100644 index 00000000..ecffe0e4 --- /dev/null +++ b/internal/channelidentities/types.go @@ -0,0 +1,15 @@ +package channelidentities + +import "time" + +// ChannelIdentity is a unified inbound identity subject across channels. +type ChannelIdentity struct { + ID string `json:"id"` + UserID string `json:"user_id,omitempty"` + Channel string `json:"channel"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName string `json:"display_name,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 70a1b412..6e07202d 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -9,13 +9,13 @@ import ( "io" "log/slog" "net/http" + "sort" "strings" "time" "github.com/jackc/pgx/v5/pgtype" "github.com/memohai/memoh/internal/db/sqlc" - "github.com/memohai/memoh/internal/history" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" "github.com/memohai/memoh/internal/models" @@ -23,7 +23,12 @@ import ( "github.com/memohai/memoh/internal/settings" ) -const defaultMaxContextMinutes = 24 * 60 +const ( + defaultMaxContextMinutes = 24 * 60 + memoryContextLimitPerScope = 4 + memoryContextMaxItems = 8 + memoryContextItemMaxChars = 220 +) // SkillEntry represents a skill loaded from the container. type SkillEntry struct { @@ -43,7 +48,7 @@ type Resolver struct { modelsService *models.Service queries *sqlc.Queries memoryService *memory.Service - historyService *history.Service + chatService *Service settingsService *settings.Service mcpService *mcp.ConnectionService skillLoader SkillLoader @@ -60,7 +65,7 @@ func NewResolver( modelsService *models.Service, queries *sqlc.Queries, memoryService *memory.Service, - historyService *history.Service, + chatService *Service, settingsService *settings.Service, mcpService *mcp.ConnectionService, gatewayBaseURL string, @@ -77,12 +82,12 @@ func NewResolver( modelsService: modelsService, queries: queries, memoryService: memoryService, - historyService: historyService, + chatService: chatService, settingsService: settingsService, mcpService: mcpService, gatewayBaseURL: gatewayBaseURL, timeout: timeout, - logger: log.With(slog.String("service", "chat")), + logger: log.With(slog.String("service", "chat_resolver")), httpClient: &http.Client{Timeout: timeout}, streamingClient: &http.Client{}, } @@ -104,16 +109,14 @@ type gatewayModelConfig struct { } type gatewayIdentity struct { - BotID string `json:"botId"` - SessionID string `json:"sessionId"` - ContainerID string `json:"containerId"` - ContactID string `json:"contactId"` - ContactName string `json:"contactName"` - ContactAlias string `json:"contactAlias,omitempty"` - UserID string `json:"userId,omitempty"` - CurrentPlatform string `json:"currentPlatform,omitempty"` - ReplyTarget string `json:"replyTarget,omitempty"` - SessionToken string `json:"sessionToken,omitempty"` + BotID string `json:"botId"` + SessionID string `json:"sessionId"` + ContainerID string `json:"containerId"` + ChannelIdentityID string `json:"channelIdentityId"` + DisplayName string `json:"displayName"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + ReplyTarget string `json:"replyTarget,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` } type gatewaySkill struct { @@ -184,8 +187,8 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex if strings.TrimSpace(req.BotID) == "" { return resolvedContext{}, fmt.Errorf("bot id is required") } - if strings.TrimSpace(req.SessionID) == "" { - return resolvedContext{}, fmt.Errorf("session id is required") + if strings.TrimSpace(req.ChatID) == "" { + return resolvedContext{}, fmt.Errorf("chat id is required") } skipHistory := req.MaxContextLoadTime < 0 @@ -194,11 +197,22 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex if err != nil { return resolvedContext{}, err } - userSettings, err := r.loadUserSettings(ctx, req.UserID) + + // Check chat-level model override. + var chatSettings Settings + if r.chatService != nil { + chatSettings, err = r.chatService.GetSettings(ctx, req.ChatID) + if err != nil { + return resolvedContext{}, err + } + r.enforceGroupMemoryPolicy(ctx, req.ChatID, &chatSettings) + } + + userSettings, err := r.loadUserSettings(ctx, req.ChannelIdentityID) if err != nil { return resolvedContext{}, err } - chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings) + chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings) if err != nil { return resolvedContext{}, err } @@ -209,20 +223,18 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes) var messages []ModelMessage - var historySkills []string - if !skipHistory { - messages, err = r.loadHistoryMessages(ctx, req.BotID, req.SessionID, maxCtx) - if err != nil { - return resolvedContext{}, err - } - historySkills, err = r.loadHistorySkills(ctx, req.BotID, req.SessionID, maxCtx) + if !skipHistory && r.chatService != nil { + messages, err = r.loadChatMessages(ctx, req.ChatID, maxCtx) if err != nil { return resolvedContext{}, err } } + if memoryMsg := r.loadMemoryContextMessage(ctx, req, chatSettings); memoryMsg != nil { + messages = append(messages, *memoryMsg) + } messages = append(messages, req.Messages...) messages = sanitizeMessages(messages) - skills := dedup(append(historySkills, req.Skills...)) + skills := dedup(req.Skills) containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) var usableSkills []gatewaySkill @@ -277,21 +289,19 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex CurrentChannel: req.CurrentChannel, AllowedActions: req.AllowedActions, MCPConnections: mcpConnections, - Messages: nonNilMessages(messages), + Messages: nonNilModelMessages(messages), Skills: nonNilStrings(skills), UsableSkills: usableSkills, Query: req.Query, Identity: gatewayIdentity{ - BotID: req.BotID, - SessionID: req.SessionID, - ContainerID: containerID, - ContactID: firstNonEmpty(req.ContactID, req.UserID, req.BotID), - ContactName: firstNonEmpty(req.ContactName, "User"), - ContactAlias: req.ContactAlias, - UserID: req.UserID, - CurrentPlatform: req.CurrentChannel, - ReplyTarget: req.ReplyTarget, - SessionToken: req.SessionToken, + BotID: req.BotID, + SessionID: req.ChatID, + ContainerID: containerID, + ChannelIdentityID: firstNonEmpty(req.ChannelIdentityID, req.BotID), + DisplayName: firstNonEmpty(req.DisplayName, "User"), + CurrentPlatform: req.CurrentChannel, + ReplyTarget: "", + SessionToken: req.ChatToken, }, Attachments: []any{}, } @@ -311,7 +321,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err if err != nil { return ChatResponse{}, err } - if err := r.storeRound(ctx, req.BotID, req.SessionID, req.Query, resp.Messages, resp.Skills); err != nil { + if err := r.storeRound(ctx, req, resp.Messages); err != nil { return ChatResponse{}, err } return ChatResponse{ @@ -333,13 +343,16 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc return fmt.Errorf("schedule command is required") } - sessionID := "schedule:" + payload.ID + chatID := payload.ChatID + if strings.TrimSpace(chatID) == "" { + chatID = "schedule-" + payload.ID + } req := ChatRequest{ - BotID: botID, - SessionID: sessionID, - Query: payload.Command, - UserID: payload.OwnerUserID, - Token: token, + BotID: botID, + ChatID: chatID, + Query: payload.Command, + ChannelIdentityID: payload.OwnerUserID, + Token: token, } rc, err := r.resolve(ctx, req) if err != nil { @@ -357,12 +370,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc Skills: rc.payload.Skills, UsableSkills: rc.payload.UsableSkills, Identity: gatewayIdentity{ - BotID: rc.payload.Identity.BotID, - SessionID: rc.payload.Identity.SessionID, - ContainerID: rc.payload.Identity.ContainerID, - ContactID: botID, - ContactName: "Scheduler", - UserID: payload.OwnerUserID, + BotID: rc.payload.Identity.BotID, + SessionID: rc.payload.Identity.SessionID, + ContainerID: rc.payload.Identity.ContainerID, + ChannelIdentityID: firstNonEmpty(payload.OwnerUserID, botID), + DisplayName: "Scheduler", }, Attachments: rc.payload.Attachments, Schedule: gatewaySchedule{ @@ -379,7 +391,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc if err != nil { return err } - return r.storeRound(ctx, botID, sessionID, payload.Command, resp.Messages, resp.Skills) + return r.storeRound(ctx, req, resp.Messages) } // --- StreamChat --- @@ -390,7 +402,7 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre errCh := make(chan error, 1) r.logger.Info("gateway stream start", slog.String("bot_id", req.BotID), - slog.String("session_id", req.SessionID), + slog.String("chat_id", req.ChatID), ) go func() { @@ -401,16 +413,16 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre if err != nil { r.logger.Error("gateway stream resolve failed", slog.String("bot_id", req.BotID), - slog.String("session_id", req.SessionID), + slog.String("chat_id", req.ChatID), slog.Any("error", err), ) errCh <- err return } - if err := r.streamChat(ctx, rc.payload, req.BotID, req.SessionID, req.Query, req.Token, chunkCh); err != nil { + if err := r.streamChat(ctx, rc.payload, req, chunkCh); err != nil { r.logger.Error("gateway stream request failed", slog.String("bot_id", req.BotID), - slog.String("session_id", req.SessionID), + slog.String("chat_id", req.ChatID), slog.Any("error", err), ) errCh <- err @@ -502,7 +514,7 @@ func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerSched return parsed, nil } -func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error { +func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error { body, err := json.Marshal(payload) if err != nil { return err @@ -515,8 +527,8 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Accept", "text/event-stream") - if strings.TrimSpace(token) != "" { - httpReq.Header.Set("Authorization", token) + if strings.TrimSpace(req.Token) != "" { + httpReq.Header.Set("Authorization", req.Token) } resp, err := r.streamingClient.Do(httpReq) @@ -558,7 +570,7 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID if stored { continue } - if handled, storeErr := r.tryStoreStream(ctx, botID, sessionID, query, currentEvent, data); storeErr != nil { + if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil { return storeErr } else if handled { stored = true @@ -568,16 +580,16 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID } // tryStoreStream attempts to extract final messages from a stream event and persist them. -func (r *Resolver) tryStoreStream(ctx context.Context, botID, sessionID, query, eventType, data string) (bool, error) { +func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) { // event: done + data: {messages: [...]} if eventType == "done" { var resp gatewayResponse if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + return true, r.storeRound(ctx, req, resp.Messages) } } - // data: {"type":"agent_end"|"done", ...} + // data: {"type":"text_delta"|"agent_end"|"done", ...} var envelope struct { Type string `json:"type"` Data json.RawMessage `json:"data"` @@ -585,13 +597,13 @@ func (r *Resolver) tryStoreStream(ctx context.Context, botID, sessionID, query, Skills []string `json:"skills"` } if err := json.Unmarshal([]byte(data), &envelope); err == nil { - if envelope.Type == "agent_end" && len(envelope.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, envelope.Messages, envelope.Skills) + if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 { + return true, r.storeRound(ctx, req, envelope.Messages) } if envelope.Type == "done" && len(envelope.Data) > 0 { var resp gatewayResponse if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + return true, r.storeRound(ctx, req, resp.Messages) } } } @@ -599,7 +611,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, botID, sessionID, query, // fallback: data: {messages: [...]} var resp gatewayResponse if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { - return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + return true, r.storeRound(ctx, req, resp.Messages) } return false, nil } @@ -622,103 +634,206 @@ func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit strin return "mcp-" + botID } -// --- history helpers --- +// --- message loading --- -func (r *Resolver) loadHistoryMessages(ctx context.Context, botID, sessionID string, maxContextMinutes int) ([]ModelMessage, error) { - if r.historyService == nil { - return nil, fmt.Errorf("history service not configured") - } +func (r *Resolver) loadChatMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) { since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) - records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, since) + msgs, err := r.chatService.ListMessagesSince(ctx, chatID, since) if err != nil { return nil, err } - var messages []ModelMessage - for _, record := range records { - msgs, err := recordToMessages(record) + var result []ModelMessage + for _, m := range msgs { + var mm ModelMessage + if err := json.Unmarshal(m.Content, &mm); err != nil { + // Fallback: treat content as text string. + mm = ModelMessage{Role: m.Role, Content: m.Content} + } else { + mm.Role = m.Role + } + result = append(result, mm) + } + return result, nil +} + +type memoryContextItem struct { + Namespace string + Item memory.MemoryItem +} + +func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest, settings Settings) *ModelMessage { + if r.memoryService == nil { + return nil + } + if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + return nil + } + type memoryScope struct { + Namespace string + ScopeID string + } + var scopes []memoryScope + if settings.EnableChatMemory { + scopes = append(scopes, memoryScope{Namespace: "chat", ScopeID: req.ChatID}) + } + if settings.EnablePrivateMemory && strings.TrimSpace(req.ChannelIdentityID) != "" { + scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.ChannelIdentityID}) + } + if settings.EnablePublicMemory { + scopes = append(scopes, memoryScope{Namespace: "public", ScopeID: req.BotID}) + } + if len(scopes) == 0 { + return nil + } + + results := make([]memoryContextItem, 0, len(scopes)*memoryContextLimitPerScope) + seen := map[string]struct{}{} + for _, scope := range scopes { + resp, err := r.memoryService.Search(ctx, memory.SearchRequest{ + Query: req.Query, + BotID: req.BotID, + Limit: memoryContextLimitPerScope, + Filters: map[string]any{ + "namespace": scope.Namespace, + "scopeId": scope.ScopeID, + "botId": req.BotID, + }, + }) if err != nil { - r.logger.Warn("skip malformed history record", slog.String("record_id", record.ID), slog.Any("error", err)) + r.logger.Warn("memory search for context failed", + slog.String("namespace", scope.Namespace), + slog.Any("error", err), + ) continue } - messages = append(messages, msgs...) + for _, item := range resp.Results { + key := strings.TrimSpace(item.ID) + if key == "" { + key = scope.Namespace + ":" + strings.TrimSpace(item.Memory) + } + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + results = append(results, memoryContextItem{Namespace: scope.Namespace, Item: item}) + } + } + if len(results) == 0 { + return nil } - return messages, nil -} -func (r *Resolver) loadHistorySkills(ctx context.Context, botID, sessionID string, maxContextMinutes int) ([]string, error) { - if r.historyService == nil { - return nil, fmt.Errorf("history service not configured") + sort.Slice(results, func(i, j int) bool { + return results[i].Item.Score > results[j].Item.Score + }) + if len(results) > memoryContextMaxItems { + results = results[:memoryContextMaxItems] } - since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) - records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, since) - if err != nil { - return nil, err - } - var combined []string - for _, record := range records { - combined = append(combined, record.Skills...) - } - return dedup(combined), nil -} -// recordToMessages converts a history record (stored as []map[string]any) to typed ModelMessages. -func recordToMessages(record history.Record) ([]ModelMessage, error) { - if len(record.Messages) == 0 { - return nil, nil + var sb strings.Builder + sb.WriteString("Relevant memory context (use when helpful):\n") + for _, entry := range results { + text := strings.TrimSpace(entry.Item.Memory) + if text == "" { + continue + } + sb.WriteString("- [") + sb.WriteString(entry.Namespace) + sb.WriteString("] ") + sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars)) + sb.WriteString("\n") } - raw, err := json.Marshal(record.Messages) - if err != nil { - return nil, err + payload := strings.TrimSpace(sb.String()) + if payload == "" { + return nil } - var msgs []ModelMessage - if err := json.Unmarshal(raw, &msgs); err != nil { - return nil, err + msg := ModelMessage{ + Role: "system", + Content: NewTextContent(payload), } - return msgs, nil + return &msg } // --- store helpers --- -func (r *Resolver) storeRound(ctx context.Context, botID, sessionID, query string, messages []ModelMessage, skills []string) error { - if err := r.storeHistory(ctx, botID, sessionID, query, messages, skills); err != nil { - return err +func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error { + // Add user query as the first message if not already present in the round. + // This ensures the user's prompt is persisted alongside the assistant's response. + fullRound := make([]ModelMessage, 0, len(messages)+1) + hasUserQuery := false + for _, m := range messages { + if m.Role == "user" && m.TextContent() == req.Query { + hasUserQuery = true + break + } } - r.storeMemory(ctx, botID, sessionID, query, messages) + if !hasUserQuery && strings.TrimSpace(req.Query) != "" { + fullRound = append(fullRound, ModelMessage{ + Role: "user", + Content: NewTextContent(req.Query), + }) + } + fullRound = append(fullRound, messages...) + + r.storeMessages(ctx, req, fullRound) + r.storeMemory(ctx, req.BotID, req.ChatID, req.ChannelIdentityID, req.Query, fullRound) return nil } -func (r *Resolver) storeHistory(ctx context.Context, botID, sessionID, query string, messages []ModelMessage, skills []string) error { - if r.historyService == nil { - return fmt.Errorf("history service not configured") +func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) { + if r.chatService == nil { + return } - if strings.TrimSpace(botID) == "" || strings.TrimSpace(sessionID) == "" { - return fmt.Errorf("bot id and session id are required") + if strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" { + return } - if strings.TrimSpace(query) == "" && len(messages) == 0 { - return nil + // Build route-level metadata for traceability. + var meta map[string]any + if strings.TrimSpace(req.RouteID) != "" || strings.TrimSpace(req.CurrentChannel) != "" { + meta = map[string]any{} + if strings.TrimSpace(req.RouteID) != "" { + meta["route_id"] = req.RouteID + } + if strings.TrimSpace(req.CurrentChannel) != "" { + meta["platform"] = req.CurrentChannel + } } - // Convert typed messages to []map[string]any for the history service. - raw, err := json.Marshal(messages) - if err != nil { - return err + for _, msg := range messages { + content, err := json.Marshal(msg) + if err != nil { + continue + } + senderID := "" + externalMessageID := "" + if msg.Role == "user" { + senderID = req.ChannelIdentityID + externalMessageID = req.ExternalMessageID + } + if _, err := r.chatService.PersistMessage( + ctx, + req.ChatID, + req.BotID, + req.RouteID, + "", + senderID, + req.CurrentChannel, + externalMessageID, + msg.Role, + content, + meta, + ); err != nil { + r.logger.Warn("persist message failed", slog.Any("error", err)) + } } - var rows []map[string]any - if err := json.Unmarshal(raw, &rows); err != nil { - return err - } - _, err = r.historyService.Create(ctx, botID, strings.TrimSpace(sessionID), history.CreateRequest{ - Messages: rows, - Metadata: map[string]any{"query": strings.TrimSpace(query)}, - Skills: skills, - }) - return err } -func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query string, messages []ModelMessage) { +func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdentityID, query string, messages []ModelMessage) { if r.memoryService == nil { return } - if strings.TrimSpace(botID) == "" || strings.TrimSpace(sessionID) == "" { + if strings.TrimSpace(botID) == "" || strings.TrimSpace(chatID) == "" { return } memMsgs := make([]memory.Message, 0, len(messages)) @@ -736,27 +851,68 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query stri if len(memMsgs) == 0 { return } + + // Load chat settings to determine which namespaces to write to. + var cs Settings + if r.chatService != nil { + settings, err := r.chatService.GetSettings(ctx, chatID) + if err != nil { + r.logger.Warn("load chat settings for memory write failed", slog.Any("error", err)) + } else { + cs = settings + r.enforceGroupMemoryPolicy(ctx, chatID, &cs) + } + } + + // Always write to chat namespace if enabled (default true). + if cs.EnableChatMemory { + r.addMemory(ctx, botID, memMsgs, "chat", chatID) + } + + // Write to private namespace if enabled and channel identity is known. + if cs.EnablePrivateMemory && strings.TrimSpace(channelIdentityID) != "" { + r.addMemory(ctx, botID, memMsgs, "private", channelIdentityID) + } + + // Write to public namespace if enabled. + if cs.EnablePublicMemory { + r.addMemory(ctx, botID, memMsgs, "public", botID) + } +} + +func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + "botId": botID, + } if _, err := r.memoryService.Add(ctx, memory.AddRequest{ - Messages: memMsgs, - BotID: botID, - SessionID: strings.TrimSpace(sessionID), + Messages: msgs, + BotID: botID, + Filters: filters, }); err != nil { - r.logger.Warn("store memory failed", slog.Any("error", err)) + r.logger.Warn("store memory failed", + slog.String("namespace", namespace), + slog.String("scope_id", scopeID), + slog.Any("error", err), + ) } } // --- model selection --- -func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings) (models.GetResponse, sqlc.LlmProvider, error) { +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) { if r.modelsService == nil { return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") } modelID := strings.TrimSpace(req.Model) providerFilter := strings.TrimSpace(req.Provider) - // Priority: request model > bot settings > user settings. + // Priority: request model > chat settings > bot settings > user settings. if modelID == "" && providerFilter == "" { - if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { + if value := strings.TrimSpace(cs.ModelID); value != "" { + modelID = value + } else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" { modelID = value } else if value := strings.TrimSpace(us.ChatModelID); value != "" { modelID = value @@ -922,7 +1078,7 @@ func nonNilStrings(s []string) []string { return s } -func nonNilMessages(m []ModelMessage) []ModelMessage { +func nonNilModelMessages(m []ModelMessage) []ModelMessage { if m == nil { return []ModelMessage{} } @@ -936,6 +1092,14 @@ func truncate(s string, n int) string { return s[:n] + "..." } +func truncateMemorySnippet(s string, n int) string { + trimmed := strings.TrimSpace(s) + if len(trimmed) <= n { + return trimmed + } + return strings.TrimSpace(trimmed[:n]) + "..." +} + func parseUUID(id string) (pgtype.UUID, error) { trimmed := strings.TrimSpace(id) if trimmed == "" { @@ -947,3 +1111,16 @@ func parseUUID(id string) (pgtype.UUID, error) { } return pgID, nil } + +func (r *Resolver) enforceGroupMemoryPolicy(ctx context.Context, chatID string, settings *Settings) { + if r == nil || r.chatService == nil || settings == nil { + return + } + chatObj, err := r.chatService.Get(ctx, chatID) + if err != nil { + return + } + if strings.EqualFold(strings.TrimSpace(chatObj.Kind), KindGroup) { + settings.EnablePrivateMemory = false + } +} diff --git a/internal/chat/resolver_test.go b/internal/chat/resolver_test.go index 6fc40224..74d8a329 100644 --- a/internal/chat/resolver_test.go +++ b/internal/chat/resolver_test.go @@ -47,12 +47,11 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) { Messages: []ModelMessage{}, Skills: []string{}, Identity: gatewayIdentity{ - BotID: "bot-123", - SessionID: "schedule:sched-1", - ContainerID: "mcp-bot-123", - ContactID: "bot-123", - ContactName: "Scheduler", - UserID: "owner-user-1", + BotID: "bot-123", + SessionID: "schedule:sched-1", + ContainerID: "mcp-bot-123", + ChannelIdentityID: "owner-user-1", + DisplayName: "Scheduler", }, Attachments: []any{}, Schedule: gatewaySchedule{ diff --git a/internal/chat/service.go b/internal/chat/service.go new file mode 100644 index 00000000..11ebbf8b --- /dev/null +++ b/internal/chat/service.go @@ -0,0 +1,864 @@ +package chat + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db/sqlc" +) + +var ( + ErrChatNotFound = errors.New("chat not found") + ErrNotParticipant = errors.New("not a participant") + ErrPermissionDenied = errors.New("permission denied") +) + +// Service manages chat lifecycle, participants, settings, and routes. +type Service struct { + queries *sqlc.Queries + logger *slog.Logger +} + +// NewService creates a chat service. +func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + queries: queries, + logger: log.With(slog.String("service", "chat")), + } +} + +// --- Chat CRUD --- + +// Create creates a new chat and adds the creator as owner. +func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateChatRequest) (Chat, error) { + kind := strings.TrimSpace(req.Kind) + if kind == "" { + kind = KindDirect + } + if kind != KindDirect && kind != KindGroup && kind != KindThread { + return Chat{}, fmt.Errorf("invalid chat kind: %s", kind) + } + + pgBotID, err := parseUUID(botID) + if err != nil { + return Chat{}, fmt.Errorf("invalid bot id: %w", err) + } + pgChannelIdentityID := pgtype.UUID{} + if strings.TrimSpace(channelIdentityID) != "" { + pgChannelIdentityID, err = parseUUID(channelIdentityID) + if err != nil { + return Chat{}, fmt.Errorf("invalid user id: %w", err) + } + } + + var pgParent pgtype.UUID + if kind == KindThread && strings.TrimSpace(req.ParentChatID) != "" { + pgParent, err = parseUUID(req.ParentChatID) + if err != nil { + return Chat{}, fmt.Errorf("invalid parent chat id: %w", err) + } + } + + metadata, _ := json.Marshal(nonNilMap(req.Metadata)) + + row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{ + BotID: pgBotID, + Kind: kind, + ParentChatID: pgParent, + Title: toPgText(req.Title), + CreatedByUserID: pgChannelIdentityID, + Metadata: metadata, + }) + if err != nil { + return Chat{}, fmt.Errorf("create chat: %w", err) + } + + // Add creator as owner when user identity is available. + if pgChannelIdentityID.Valid { + if _, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: row.ID, + UserID: pgChannelIdentityID, + Role: RoleOwner, + }); err != nil { + return Chat{}, fmt.Errorf("add owner participant: %w", err) + } + } + + // Create default settings based on kind. + enablePrivate := kind != KindGroup + if _, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ + ID: row.ID, + EnableChatMemory: true, + EnablePrivateMemory: enablePrivate, + EnablePublicMemory: false, + SettingsMetadata: []byte("{}"), + }); err != nil { + return Chat{}, fmt.Errorf("create default settings: %w", err) + } + + // For threads, copy participants from parent. + if kind == KindThread && pgParent.Valid { + if err := s.queries.CopyParticipantsToChat(ctx, sqlc.CopyParticipantsToChatParams{ + ChatID: pgParent, + ChatID_2: row.ID, + }); err != nil { + s.logger.Warn("copy parent participants failed", slog.Any("error", err)) + } + } + + return toChat(row), nil +} + +// Get returns a chat by ID. +func (s *Service) Get(ctx context.Context, chatID string) (Chat, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return Chat{}, ErrChatNotFound + } + row, err := s.queries.GetChatByID(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Chat{}, ErrChatNotFound + } + return Chat{}, err + } + return toChat(row), nil +} + +// GetReadAccess resolves whether a user can read a chat. +func (s *Service) GetReadAccess(ctx context.Context, chatID, channelIdentityID string) (ChatReadAccess, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return ChatReadAccess{}, ErrPermissionDenied + } + row, err := s.queries.GetChatReadAccessByUser(ctx, sqlc.GetChatReadAccessByUserParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ChatReadAccess{}, ErrPermissionDenied + } + return ChatReadAccess{}, err + } + return ChatReadAccess{ + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + }, nil +} + +// ListByBotAndChannelIdentity returns all chats visible to the user for a bot. +func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ChatListItem, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return nil, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListVisibleChatsByBotAndUser(ctx, sqlc.ListVisibleChatsByBotAndUserParams{ + BotID: pgBotID, + UserID: pgChannelIdentityID, + }) + if err != nil { + return nil, err + } + chats := make([]ChatListItem, 0, len(rows)) + for _, row := range rows { + chats = append(chats, toChatListItem(row)) + } + return chats, nil +} + +// ListThreads returns threads for a parent chat. +func (s *Service) ListThreads(ctx context.Context, parentChatID string) ([]Chat, error) { + pgID, err := parseUUID(parentChatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListThreadsByParent(ctx, pgID) + if err != nil { + return nil, err + } + chats := make([]Chat, 0, len(rows)) + for _, row := range rows { + chats = append(chats, toChat(row)) + } + return chats, nil +} + +// Delete deletes a chat (cascade deletes messages, routes, participants, settings). +func (s *Service) Delete(ctx context.Context, chatID string) error { + pgID, err := parseUUID(chatID) + if err != nil { + return ErrChatNotFound + } + return s.queries.DeleteChat(ctx, pgID) +} + +// --- Participants --- + +// AddParticipant adds a user identity to a chat. +func (s *Service) AddParticipant(ctx context.Context, chatID, channelIdentityID, role string) (Participant, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return Participant{}, err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, err + } + if role == "" { + role = RoleMember + } + row, err := s.queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + Role: role, + }) + if err != nil { + return Participant{}, err + } + return toParticipant(row), nil +} + +// GetParticipant returns a participant record. +func (s *Service) GetParticipant(ctx context.Context, chatID, channelIdentityID string) (Participant, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return Participant{}, ErrNotParticipant + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return Participant{}, ErrNotParticipant + } + row, err := s.queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Participant{}, ErrNotParticipant + } + return Participant{}, err + } + return toParticipant(row), nil +} + +// IsParticipant checks whether a user identity is a participant in a chat. +func (s *Service) IsParticipant(ctx context.Context, chatID, channelIdentityID string) (bool, error) { + _, err := s.GetParticipant(ctx, chatID, channelIdentityID) + if errors.Is(err, ErrNotParticipant) { + return false, nil + } + return err == nil, err +} + +// ListParticipants returns all participants for a chat. +func (s *Service) ListParticipants(ctx context.Context, chatID string) ([]Participant, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatParticipants(ctx, pgID) + if err != nil { + return nil, err + } + participants := make([]Participant, 0, len(rows)) + for _, row := range rows { + participants = append(participants, toParticipant(row)) + } + return participants, nil +} + +// RemoveParticipant removes a user identity from a chat. +func (s *Service) RemoveParticipant(ctx context.Context, chatID, channelIdentityID string) error { + pgChatID, err := parseUUID(chatID) + if err != nil { + return err + } + pgChannelIdentityID, err := parseUUID(channelIdentityID) + if err != nil { + return err + } + return s.queries.RemoveChatParticipant(ctx, sqlc.RemoveChatParticipantParams{ + ChatID: pgChatID, + UserID: pgChannelIdentityID, + }) +} + +// --- Settings --- + +// GetSettings returns settings for a chat. Returns defaults if not found. +func (s *Service) GetSettings(ctx context.Context, chatID string) (Settings, error) { + pgID, err := parseUUID(chatID) + var current Settings + if err != nil { + current = defaultSettings(chatID) + return current, nil + } + row, err := s.queries.GetChatSettings(ctx, pgID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + current = defaultSettings(chatID) + if s.isGroupChat(ctx, chatID) { + current.EnablePrivateMemory = false + } + return current, nil + } + return Settings{}, err + } + current = toSettingsFromRead(row) + if s.isGroupChat(ctx, chatID) { + current.EnablePrivateMemory = false + } + return current, nil +} + +// UpdateSettings updates chat settings. +func (s *Service) UpdateSettings(ctx context.Context, chatID string, req UpdateSettingsRequest) (Settings, error) { + current, err := s.GetSettings(ctx, chatID) + if err != nil { + return Settings{}, err + } + isGroup := s.isGroupChat(ctx, chatID) + if req.EnableChatMemory != nil { + current.EnableChatMemory = *req.EnableChatMemory + } + if req.EnablePrivateMemory != nil { + current.EnablePrivateMemory = *req.EnablePrivateMemory + } + if req.EnablePublicMemory != nil { + current.EnablePublicMemory = *req.EnablePublicMemory + } + if req.ModelID != nil { + current.ModelID = *req.ModelID + } + if isGroup { + // Group chats are shared contexts, so private memory stays disabled. + current.EnablePrivateMemory = false + } + + pgID, err := parseUUID(chatID) + if err != nil { + return Settings{}, err + } + row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ + ID: pgID, + EnableChatMemory: current.EnableChatMemory, + EnablePrivateMemory: current.EnablePrivateMemory, + EnablePublicMemory: current.EnablePublicMemory, + ModelID: toPgText(current.ModelID), + SettingsMetadata: []byte("{}"), + }) + if err != nil { + return Settings{}, err + } + return toSettingsFromUpsert(row), nil +} + +// --- Routes --- + +// CreateRoute creates a new chat route. +func (s *Service) CreateRoute(ctx context.Context, chatID string, r Route) (Route, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return Route{}, err + } + pgBotID, err := parseUUID(r.BotID) + if err != nil { + return Route{}, err + } + var pgConfigID pgtype.UUID + if strings.TrimSpace(r.ChannelConfigID) != "" { + pgConfigID, err = parseUUID(r.ChannelConfigID) + if err != nil { + return Route{}, err + } + } + metadata, _ := json.Marshal(nonNilMap(r.Metadata)) + row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{ + ChatID: pgChatID, + BotID: pgBotID, + Platform: r.Platform, + ChannelConfigID: pgConfigID, + ConversationID: r.ConversationID, + ThreadID: toPgText(r.ThreadID), + ReplyTarget: toPgText(r.ReplyTarget), + Metadata: metadata, + }) + if err != nil { + return Route{}, fmt.Errorf("create route: %w", err) + } + return toRoute(row), nil +} + +// FindRoute looks up a route by (bot_id, platform, conversation_id, thread_id). +func (s *Service) FindRoute(ctx context.Context, botID, platform, conversationID, threadID string) (Route, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return Route{}, err + } + row, err := s.queries.FindChatRoute(ctx, sqlc.FindChatRouteParams{ + BotID: pgBotID, + Platform: platform, + ConversationID: conversationID, + ThreadID: toPgText(threadID), + }) + if err != nil { + return Route{}, err + } + return toRoute(row), nil +} + +// GetRouteByID returns a single route by its ID. +func (s *Service) GetRouteByID(ctx context.Context, routeID string) (Route, error) { + pgID, err := parseUUID(routeID) + if err != nil { + return Route{}, err + } + row, err := s.queries.GetChatRouteByID(ctx, pgID) + if err != nil { + return Route{}, err + } + return toRoute(row), nil +} + +// ListRoutes lists all routes for a chat. +func (s *Service) ListRoutes(ctx context.Context, chatID string) ([]Route, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatRoutes(ctx, pgID) + if err != nil { + return nil, err + } + routes := make([]Route, 0, len(rows)) + for _, row := range rows { + routes = append(routes, toRoute(row)) + } + return routes, nil +} + +// DeleteRoute deletes a route. +func (s *Service) DeleteRoute(ctx context.Context, routeID string) error { + pgID, err := parseUUID(routeID) + if err != nil { + return err + } + return s.queries.DeleteChatRoute(ctx, pgID) +} + +// UpdateRouteReplyTarget updates the reply target for a route. +func (s *Service) UpdateRouteReplyTarget(ctx context.Context, routeID, replyTarget string) error { + pgID, err := parseUUID(routeID) + if err != nil { + return err + } + return s.queries.UpdateChatRouteReplyTarget(ctx, sqlc.UpdateChatRouteReplyTargetParams{ + ID: pgID, + ReplyTarget: toPgText(replyTarget), + }) +} + +// --- ResolveChat --- + +// ResolveChat finds or creates a chat for a channel inbound message. +func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversationID, threadID, conversationType, channelIdentityID, channelConfigID, replyTarget string) (ResolveChatResult, error) { + // Look up existing route. + route, err := s.FindRoute(ctx, botID, platform, conversationID, threadID) + if err == nil { + // Route found, ensure the sender identity is a participant. + if strings.TrimSpace(channelIdentityID) != "" { + ok, _ := s.IsParticipant(ctx, route.ChatID, channelIdentityID) + if !ok { + if _, err := s.AddParticipant(ctx, route.ChatID, channelIdentityID, RoleMember); err != nil { + s.logger.Warn("auto-add participant failed", slog.Any("error", err)) + } + } + } + // Update reply target if changed. + if strings.TrimSpace(replyTarget) != "" && replyTarget != route.ReplyTarget { + _ = s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget) + } + _ = s.queries.TouchChat(ctx, mustParseUUID(route.ChatID)) + return ResolveChatResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil + } + + // Route not found, create chat + route + participant. + kind := determineChatKind(threadID, conversationType) + creatorChannelIdentityID := s.resolveChatCreatorChannelIdentityID(ctx, botID, channelIdentityID, kind) + + var parentChatID string + if kind == KindThread { + parentRoute, parentErr := s.FindRoute(ctx, botID, platform, conversationID, "") + if parentErr == nil { + parentChatID = parentRoute.ChatID + } + } + + c, err := s.Create(ctx, botID, creatorChannelIdentityID, CreateChatRequest{ + Kind: kind, + ParentChatID: parentChatID, + }) + if err != nil { + return ResolveChatResult{}, fmt.Errorf("create chat: %w", err) + } + if strings.TrimSpace(channelIdentityID) != "" && strings.TrimSpace(channelIdentityID) != strings.TrimSpace(creatorChannelIdentityID) { + if _, err := s.AddParticipant(ctx, c.ID, channelIdentityID, RoleMember); err != nil { + s.logger.Warn("auto-add creator participant failed", slog.Any("error", err)) + } + } + + newRoute, err := s.CreateRoute(ctx, c.ID, Route{ + BotID: botID, + Platform: platform, + ChannelConfigID: channelConfigID, + ConversationID: conversationID, + ThreadID: threadID, + ReplyTarget: replyTarget, + }) + if err != nil { + return ResolveChatResult{}, fmt.Errorf("create route: %w", err) + } + + return ResolveChatResult{ChatID: c.ID, RouteID: newRoute.ID, Created: true}, nil +} + +// --- Messages --- + +// PersistMessage writes a single message to chat_messages. +func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, role string, content json.RawMessage, metadata map[string]any) (Message, error) { + pgChatID, err := parseUUID(chatID) + if err != nil { + return Message{}, err + } + pgBotID, err := parseUUID(botID) + if err != nil { + return Message{}, err + } + var pgRouteID pgtype.UUID + if strings.TrimSpace(routeID) != "" { + pgRouteID, err = parseUUID(routeID) + if err != nil { + return Message{}, err + } + } + var pgSender pgtype.UUID + if strings.TrimSpace(senderChannelIdentityID) != "" { + pgSender, _ = parseUUID(senderChannelIdentityID) + } + var pgSenderUser pgtype.UUID + if strings.TrimSpace(senderUserID) != "" { + pgSenderUser, _ = parseUUID(senderUserID) + } + metaBytes, _ := json.Marshal(nonNilMap(metadata)) + if len(content) == 0 { + content = []byte("{}") + } + + row, err := s.queries.CreateChatMessage(ctx, sqlc.CreateChatMessageParams{ + ChatID: pgChatID, + BotID: pgBotID, + RouteID: pgRouteID, + SenderChannelIdentityID: pgSender, + SenderUserID: pgSenderUser, + Platform: toPgText(platform), + ExternalMessageID: toPgText(externalMessageID), + Role: role, + Content: content, + Metadata: metaBytes, + }) + if err != nil { + return Message{}, err + } + if pgSender.Valid { + if err := s.queries.UpsertChatChannelIdentityPresence(ctx, sqlc.UpsertChatChannelIdentityPresenceParams{ + ChatID: pgChatID, + ChannelIdentityID: pgSender, + }); err != nil && s.logger != nil { + // Presence is a derived cache. Keep message persistence successful even if cache update fails. + s.logger.Warn("upsert chat channel identity presence failed", slog.Any("error", err)) + } + } + return toMessage(row), nil +} + +// ListMessages returns all messages for a chat. +func (s *Service) ListMessages(ctx context.Context, chatID string) ([]Message, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatMessages(ctx, pgID) + if err != nil { + return nil, err + } + return toMessages(rows), nil +} + +// ListMessagesSince returns messages since a given time. +func (s *Service) ListMessagesSince(ctx context.Context, chatID string, since time.Time) ([]Message, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatMessagesSince(ctx, sqlc.ListChatMessagesSinceParams{ + ChatID: pgID, + CreatedAt: pgtype.Timestamptz{Time: since, Valid: true}, + }) + if err != nil { + return nil, err + } + return toMessages(rows), nil +} + +// ListMessagesLatest returns the latest N messages (most recent first). +func (s *Service) ListMessagesLatest(ctx context.Context, chatID string, limit int32) ([]Message, error) { + pgID, err := parseUUID(chatID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChatMessagesLatest(ctx, sqlc.ListChatMessagesLatestParams{ + ChatID: pgID, + Limit: limit, + }) + if err != nil { + return nil, err + } + return toMessages(rows), nil +} + +// DeleteMessages deletes all messages for a chat. +func (s *Service) DeleteMessages(ctx context.Context, chatID string) error { + pgID, err := parseUUID(chatID) + if err != nil { + return err + } + return s.queries.DeleteChatMessagesByChat(ctx, pgID) +} + +// --- conversion helpers --- + +func toChat(row sqlc.Chat) Chat { + return Chat{ + ID: uuidString(row.ID), + BotID: uuidString(row.BotID), + Kind: row.Kind, + ParentChatID: uuidString(row.ParentChatID), + Title: pgTextString(row.Title), + CreatedBy: uuidString(row.CreatedByUserID), + Metadata: parseJSONMap(row.Metadata), + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + } +} + +func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem { + return ChatListItem{ + ID: uuidString(row.ID), + BotID: uuidString(row.BotID), + Kind: row.Kind, + ParentChatID: uuidString(row.ParentChatID), + Title: pgTextString(row.Title), + CreatedBy: uuidString(row.CreatedByUserID), + Metadata: parseJSONMap(row.Metadata), + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + AccessMode: row.AccessMode, + ParticipantRole: strings.TrimSpace(row.ParticipantRole), + LastObservedAt: pgTimePtr(row.LastObservedAt), + } +} + +func toParticipant(row sqlc.ChatParticipant) Participant { + return Participant{ + ChatID: uuidString(row.ChatID), + UserID: uuidString(row.UserID), + Role: row.Role, + JoinedAt: row.JoinedAt.Time, + } +} + +func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings { + return Settings{ + ChatID: uuidString(row.ChatID), + EnableChatMemory: row.EnableChatMemory, + EnablePrivateMemory: row.EnablePrivateMemory, + EnablePublicMemory: row.EnablePublicMemory, + ModelID: pgTextString(row.ModelID), + Metadata: parseJSONMap(row.Metadata), + } +} + +func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings { + return Settings{ + ChatID: uuidString(row.ChatID), + EnableChatMemory: row.EnableChatMemory, + EnablePrivateMemory: row.EnablePrivateMemory, + EnablePublicMemory: row.EnablePublicMemory, + ModelID: pgTextString(row.ModelID), + Metadata: parseJSONMap(row.Metadata), + } +} + +func toRoute(row sqlc.ChatRoute) Route { + return Route{ + ID: uuidString(row.ID), + ChatID: uuidString(row.ChatID), + BotID: uuidString(row.BotID), + Platform: row.Platform, + ChannelConfigID: uuidString(row.ChannelConfigID), + ConversationID: row.ConversationID, + ThreadID: pgTextString(row.ThreadID), + ReplyTarget: pgTextString(row.ReplyTarget), + Metadata: parseJSONMap(row.Metadata), + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + } +} + +func toMessage(row sqlc.ChatMessage) Message { + return Message{ + ID: uuidString(row.ID), + ChatID: uuidString(row.ChatID), + BotID: uuidString(row.BotID), + RouteID: uuidString(row.RouteID), + SenderChannelIdentityID: uuidString(row.SenderChannelIdentityID), + SenderUserID: uuidString(row.SenderUserID), + Platform: pgTextString(row.Platform), + ExternalMessageID: pgTextString(row.ExternalMessageID), + Role: row.Role, + Content: json.RawMessage(row.Content), + Metadata: parseJSONMap(row.Metadata), + CreatedAt: row.CreatedAt.Time, + } +} + +func toMessages(rows []sqlc.ChatMessage) []Message { + msgs := make([]Message, 0, len(rows)) + for _, row := range rows { + msgs = append(msgs, toMessage(row)) + } + return msgs +} + +func defaultSettings(chatID string) Settings { + return Settings{ + ChatID: chatID, + EnableChatMemory: true, + EnablePrivateMemory: true, + EnablePublicMemory: false, + } +} + +func determineChatKind(threadID, conversationType string) string { + if strings.TrimSpace(threadID) != "" { + return KindThread + } + ct := strings.ToLower(strings.TrimSpace(conversationType)) + if ct == "p2p" || ct == "private" || ct == "" { + return KindDirect + } + return KindGroup +} + +func uuidString(id pgtype.UUID) string { + if !id.Valid { + return "" + } + b := id.Bytes + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) +} + +func pgTextString(t pgtype.Text) string { + if !t.Valid { + return "" + } + return t.String +} + +func toPgText(s string) pgtype.Text { + s = strings.TrimSpace(s) + if s == "" { + return pgtype.Text{} + } + return pgtype.Text{String: s, Valid: true} +} + +func pgTimePtr(ts pgtype.Timestamptz) *time.Time { + if !ts.Valid { + return nil + } + value := ts.Time + return &value +} + +func mustParseUUID(id string) pgtype.UUID { + pgID, _ := parseUUID(id) + return pgID +} + +func nonNilMap(m map[string]any) map[string]any { + if m == nil { + return map[string]any{} + } + return m +} + +func parseJSONMap(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + _ = json.Unmarshal(data, &m) + return m +} + +func (s *Service) resolveChatCreatorChannelIdentityID(ctx context.Context, botID, fallbackChannelIdentityID, kind string) string { + fallback := strings.TrimSpace(fallbackChannelIdentityID) + if kind != KindGroup || s.queries == nil { + return fallback + } + pgBotID, err := parseUUID(botID) + if err != nil { + return fallback + } + row, err := s.queries.GetBotByID(ctx, pgBotID) + if err != nil { + s.logger.Warn("resolve bot owner for group chat failed", slog.Any("error", err)) + return fallback + } + ownerChannelIdentityID := uuidString(row.OwnerUserID) + if strings.TrimSpace(ownerChannelIdentityID) == "" { + return fallback + } + return ownerChannelIdentityID +} + +func (s *Service) isGroupChat(ctx context.Context, chatID string) bool { + chatObj, err := s.Get(ctx, chatID) + if err != nil { + return false + } + return strings.EqualFold(strings.TrimSpace(chatObj.Kind), KindGroup) +} diff --git a/internal/chat/service_presence_integration_test.go b/internal/chat/service_presence_integration_test.go new file mode 100644 index 00000000..c8856b54 --- /dev/null +++ b/internal/chat/service_presence_integration_test.go @@ -0,0 +1,269 @@ +package chat_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/chat" + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" +) + +type chatPresenceFixture struct { + chatSvc *chat.Service + channelIdentitySvc *channelidentities.Service + queries *sqlc.Queries + cleanup func() +} + +func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture { + t.Helper() + + dsn := os.Getenv("TEST_POSTGRES_DSN") + if dsn == "" { + t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") + } + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + t.Skipf("skip integration test: cannot connect to database: %v", err) + } + if err := pool.Ping(ctx); err != nil { + pool.Close() + t.Skipf("skip integration test: database ping failed: %v", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + queries := sqlc.New(pool) + + return chatPresenceFixture{ + chatSvc: chat.NewService(logger, queries), + channelIdentitySvc: channelidentities.NewService(logger, queries), + queries: queries, + cleanup: func() { pool.Close() }, + } +} + +func isLegacyChatPresenceSchemaError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "relation \"chat_channelIdentity_presence\" does not exist") || + strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") || + strings.Contains(msg, "column \"sender_user_id\" of relation \"chat_messages\" does not exist") || + strings.Contains(msg, "column \"created_by_user_id\" of relation \"chats\" does not exist") +} + +func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) { + row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ + IsActive: true, + Metadata: []byte("{}"), + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { + pgOwnerID, err := db.ParseUUID(ownerUserID) + if err != nil { + return "", err + } + meta, _ := json.Marshal(map[string]any{"source": "chat-presence-integration-test"}) + row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ + OwnerUserID: pgOwnerID, + Type: "personal", + DisplayName: pgtype.Text{String: "presence-test-bot", Valid: true}, + IsActive: true, + Metadata: meta, + }) + if err != nil { + return "", err + } + return db.UUIDToString(row.ID), nil +} + +func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, string, string, string) { + t.Helper() + + fixture := setupChatPresenceIntegrationTest(t) + ctx := context.Background() + + ownerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + if isLegacyChatPresenceSchemaError(err) { + fixture.cleanup() + t.Skipf("skip integration test on legacy schema: %v", err) + } + fixture.cleanup() + t.Fatalf("create owner user failed: %v", err) + } + observerUserID, err := createUserForChatPresence(ctx, fixture.queries) + if err != nil { + fixture.cleanup() + t.Fatalf("create observer user failed: %v", err) + } + botID, err := createBotForChatPresence(ctx, fixture.queries, ownerUserID) + if err != nil { + fixture.cleanup() + t.Fatalf("create bot failed: %v", err) + } + + createdChat, err := fixture.chatSvc.Create(ctx, botID, ownerUserID, chat.CreateChatRequest{ + Kind: chat.KindGroup, + Title: "presence-observed", + }) + if err != nil { + if isLegacyChatPresenceSchemaError(err) { + fixture.cleanup() + t.Skipf("skip integration test on legacy schema: %v", err) + } + fixture.cleanup() + t.Fatalf("create chat failed: %v", err) + } + + observedChannelIdentity, err := fixture.channelIdentitySvc.ResolveByChannelIdentity( + ctx, + "feishu", + fmt.Sprintf("presence-channelIdentity-%d", time.Now().UnixNano()), + "presence-observer", + ) + if err != nil { + if isLegacyChatPresenceSchemaError(err) { + fixture.cleanup() + t.Skipf("skip integration test on legacy schema: %v", err) + } + fixture.cleanup() + t.Fatalf("resolve channelIdentity failed: %v", err) + } + + _, err = fixture.chatSvc.PersistMessage( + ctx, + createdChat.ID, + botID, + "", + observedChannelIdentity.ID, + "", + "feishu", + fmt.Sprintf("ext-msg-%d", time.Now().UnixNano()), + "user", + []byte(`{"content":"hello from observed channelIdentity"}`), + nil, + ) + if err != nil { + if isLegacyChatPresenceSchemaError(err) { + fixture.cleanup() + t.Skipf("skip integration test on legacy schema: %v", err) + } + fixture.cleanup() + t.Fatalf("persist message failed: %v", err) + } + + return fixture, botID, createdChat.ID, observerUserID, observedChannelIdentity.ID +} + +func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + beforeBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats before bind failed: %v", err) + } + if len(beforeBind) != 0 { + t.Fatalf("expected no visible chats before bind, got %d", len(beforeBind)) + } + + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + afterBind, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list chats after bind failed: %v", err) + } + if len(afterBind) == 0 { + t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind)) + } + + var target *chat.ChatListItem + for i := range afterBind { + if afterBind[i].ID == chatID { + target = &afterBind[i] + break + } + } + if target == nil { + t.Fatalf("expected chat %s in visible list after bind", chatID) + } + if target.AccessMode != chat.AccessModeChannelIdentityObserved { + t.Fatalf("expected access_mode=%s, got %s", chat.AccessModeChannelIdentityObserved, target.AccessMode) + } + if target.ParticipantRole != "" { + t.Fatalf("expected empty participant_role for observed chat, got %s", target.ParticipantRole) + } + if target.LastObservedAt == nil { + t.Fatal("expected last_observed_at to be set for observed chat") + } +} + +func TestObservedAccessReadableButNotParticipant(t *testing.T) { + fixture, botID, chatID, observerUserID, observedChannelIdentityID := setupObservedChatScenario(t) + defer fixture.cleanup() + + ctx := context.Background() + if err := fixture.channelIdentitySvc.LinkChannelIdentityToUser(ctx, observedChannelIdentityID, observerUserID); err != nil { + t.Fatalf("link channelIdentity to user failed: %v", err) + } + + access, err := fixture.chatSvc.GetReadAccess(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("get read access failed: %v", err) + } + if access.AccessMode != chat.AccessModeChannelIdentityObserved { + t.Fatalf("expected read access %s, got %s", chat.AccessModeChannelIdentityObserved, access.AccessMode) + } + + messages, err := fixture.chatSvc.ListMessages(ctx, chatID) + if err != nil { + t.Fatalf("list messages failed: %v", err) + } + if len(messages) == 0 { + t.Fatal("expected observed user can read chat messages") + } + + _, err = fixture.chatSvc.GetParticipant(ctx, chatID, observerUserID) + if !errors.Is(err, chat.ErrNotParticipant) { + t.Fatalf("expected ErrNotParticipant for observed user, got %v", err) + } + ok, err := fixture.chatSvc.IsParticipant(ctx, chatID, observerUserID) + if err != nil { + t.Fatalf("check participant failed: %v", err) + } + if ok { + t.Fatal("expected observed user to remain non-participant") + } + + visibleChats, err := fixture.chatSvc.ListByBotAndChannelIdentity(ctx, botID, observerUserID) + if err != nil { + t.Fatalf("list visible chats failed: %v", err) + } + if len(visibleChats) == 0 || visibleChats[0].AccessMode != chat.AccessModeChannelIdentityObserved { + t.Fatal("expected observed list entry with channel_identity_observed access mode") + } +} diff --git a/internal/chat/types.go b/internal/chat/types.go index 51bb60fe..77cda9be 100644 --- a/internal/chat/types.go +++ b/internal/chat/types.go @@ -1,12 +1,141 @@ // Package chat orchestrates conversations with the agent gateway, including -// synchronous and streaming chat, scheduled triggers, history, and memory storage. +// synchronous and streaming chat, scheduled triggers, messages, and memory storage. package chat import ( "encoding/json" "strings" + "time" ) +// Chat kind constants. +const ( + KindDirect = "direct" + KindGroup = "group" + KindThread = "thread" +) + +// Participant role constants. +const ( + RoleOwner = "owner" + RoleAdmin = "admin" + RoleMember = "member" +) + +// Chat list access mode constants. +const ( + AccessModeParticipant = "participant" + AccessModeChannelIdentityObserved = "channel_identity_observed" +) + +// Chat is the first-class conversation container. +type Chat struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ChatListItem is a chat entry with access context for list rendering. +type ChatListItem struct { + ID string `json:"id"` + BotID string `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Title string `json:"title,omitempty"` + CreatedBy string `json:"created_by"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role,omitempty"` + LastObservedAt *time.Time `json:"last_observed_at,omitempty"` +} + +// ChatReadAccess is the resolved access context for reading chat content. +type ChatReadAccess struct { + AccessMode string + ParticipantRole string + LastObservedAt *time.Time +} + +// Participant represents a chat member. +type Participant struct { + ChatID string `json:"chat_id"` + UserID string `json:"user_id"` + Role string `json:"role"` + JoinedAt time.Time `json:"joined_at"` +} + +// Settings holds per-chat configuration. +type Settings struct { + ChatID string `json:"chat_id"` + EnableChatMemory bool `json:"enable_chat_memory"` + EnablePrivateMemory bool `json:"enable_private_memory"` + EnablePublicMemory bool `json:"enable_public_memory"` + ModelID string `json:"model_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// Route maps external channel conversations to a chat. +type Route struct { + ID string `json:"id"` + ChatID string `json:"chat_id"` + BotID string `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID string `json:"channel_config_id,omitempty"` + ConversationID string `json:"conversation_id"` + ThreadID string `json:"thread_id,omitempty"` + ReplyTarget string `json:"reply_target,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// Message represents a single persisted chat message. +type Message struct { + ID string `json:"id"` + ChatID string `json:"chat_id"` + BotID string `json:"bot_id"` + RouteID string `json:"route_id,omitempty"` + SenderChannelIdentityID string `json:"sender_channel_identity_id,omitempty"` + SenderUserID string `json:"sender_user_id,omitempty"` + Platform string `json:"platform,omitempty"` + ExternalMessageID string `json:"external_message_id,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// CreateChatRequest is the input for creating a chat. +type CreateChatRequest struct { + Kind string `json:"kind"` + Title string `json:"title,omitempty"` + ParentChatID string `json:"parent_chat_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// UpdateSettingsRequest is the input for updating chat settings. +type UpdateSettingsRequest struct { + EnableChatMemory *bool `json:"enable_chat_memory,omitempty"` + EnablePrivateMemory *bool `json:"enable_private_memory,omitempty"` + EnablePublicMemory *bool `json:"enable_public_memory,omitempty"` + ModelID *string `json:"model_id,omitempty"` +} + +// ResolveChatResult is returned by ResolveChat. +type ResolveChatResult struct { + ChatID string + RouteID string + Created bool +} + // ModelMessage is the canonical message format exchanged with the agent gateway. // Aligned with Vercel AI SDK ModelMessage structure. type ModelMessage struct { @@ -73,14 +202,14 @@ func NewTextContent(text string) json.RawMessage { // ContentPart represents one element of a multi-part message content. type ContentPart struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - URL string `json:"url,omitempty"` - Styles []string `json:"styles,omitempty"` - Language string `json:"language,omitempty"` - UserID string `json:"user_id,omitempty"` - Emoji string `json:"emoji,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []string `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + ChannelIdentityID string `json:"channel_identity_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // HasValue reports whether the content part carries a meaningful value. @@ -105,16 +234,15 @@ type ToolCallFunction struct { // ChatRequest is the input for Chat and StreamChat. type ChatRequest struct { - BotID string `json:"-"` - SessionID string `json:"-"` - Token string `json:"-"` - UserID string `json:"-"` - ContainerID string `json:"-"` - ContactID string `json:"-"` - ContactName string `json:"-"` - ContactAlias string `json:"-"` - ReplyTarget string `json:"-"` - SessionToken string `json:"-"` + BotID string `json:"-"` + ChatID string `json:"-"` + Token string `json:"-"` + ChannelIdentityID string `json:"-"` + ContainerID string `json:"-"` + DisplayName string `json:"-"` + RouteID string `json:"-"` + ChatToken string `json:"-"` + ExternalMessageID string `json:"-"` Query string `json:"query"` Model string `json:"model,omitempty"` diff --git a/internal/contacts/service.go b/internal/contacts/service.go deleted file mode 100644 index 1a8a40c3..00000000 --- a/internal/contacts/service.go +++ /dev/null @@ -1,410 +0,0 @@ -package contacts - -import ( - "context" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgtype" - - "github.com/memohai/memoh/internal/db/sqlc" -) - -type Service struct { - queries *sqlc.Queries -} - -func NewService(queries *sqlc.Queries) *Service { - return &Service{queries: queries} -} - -func (s *Service) GetByID(ctx context.Context, contactID string) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgID, err := parseUUID(contactID) - if err != nil { - return Contact{}, err - } - row, err := s.queries.GetContactByID(ctx, pgID) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) GetByUserID(ctx context.Context, botID, userID string) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return Contact{}, err - } - pgUserID, err := parseUUID(userID) - if err != nil { - return Contact{}, err - } - row, err := s.queries.GetContactByUserID(ctx, sqlc.GetContactByUserIDParams{ - BotID: pgBotID, - UserID: pgUserID, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (ContactChannel, error) { - if s.queries == nil { - return ContactChannel{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return ContactChannel{}, err - } - row, err := s.queries.GetContactChannelByIdentity(ctx, sqlc.GetContactChannelByIdentityParams{ - BotID: pgBotID, - Platform: platform, - ExternalID: externalID, - }) - if err != nil { - return ContactChannel{}, err - } - return normalizeContactChannel(row) -} - -func (s *Service) ListChannelsByContact(ctx context.Context, contactID string) ([]ContactChannel, error) { - if s.queries == nil { - return nil, fmt.Errorf("contacts queries not configured") - } - pgContactID, err := parseUUID(contactID) - if err != nil { - return nil, err - } - rows, err := s.queries.ListContactChannelsByContact(ctx, pgContactID) - if err != nil { - return nil, err - } - items := make([]ContactChannel, 0, len(rows)) - for _, row := range rows { - item, err := normalizeContactChannel(row) - if err != nil { - return nil, err - } - items = append(items, item) - } - return items, nil -} - -func (s *Service) ListByBot(ctx context.Context, botID string) ([]Contact, error) { - if s.queries == nil { - return nil, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return nil, err - } - rows, err := s.queries.ListContactsByBot(ctx, pgBotID) - if err != nil { - return nil, err - } - items := make([]Contact, 0, len(rows)) - for _, row := range rows { - contact, err := normalizeContact(row) - if err != nil { - return nil, err - } - items = append(items, contact) - } - return items, nil -} - -func (s *Service) Search(ctx context.Context, botID, query string) ([]Contact, error) { - if s.queries == nil { - return nil, fmt.Errorf("contacts queries not configured") - } - trimmed := strings.TrimSpace(query) - if trimmed == "" { - return s.ListByBot(ctx, botID) - } - pgBotID, err := parseUUID(botID) - if err != nil { - return nil, err - } - search := "%" + trimmed + "%" - rows, err := s.queries.SearchContacts(ctx, sqlc.SearchContactsParams{ - BotID: pgBotID, - Query: pgtype.Text{String: search, Valid: true}, - }) - if err != nil { - return nil, err - } - items := make([]Contact, 0, len(rows)) - for _, row := range rows { - contact, err := normalizeContact(row) - if err != nil { - return nil, err - } - items = append(items, contact) - } - return items, nil -} - -func (s *Service) Create(ctx context.Context, req CreateRequest) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(req.BotID) - if err != nil { - return Contact{}, err - } - pgUserID := pgtype.UUID{Valid: false} - if strings.TrimSpace(req.UserID) != "" { - parsed, err := parseUUID(req.UserID) - if err != nil { - return Contact{}, err - } - pgUserID = parsed - } - payload, err := json.Marshal(defaultMetadata(req.Metadata)) - if err != nil { - return Contact{}, err - } - row, err := s.queries.CreateContact(ctx, sqlc.CreateContactParams{ - BotID: pgBotID, - UserID: pgUserID, - DisplayName: pgtype.Text{String: strings.TrimSpace(req.DisplayName), Valid: strings.TrimSpace(req.DisplayName) != ""}, - Alias: pgtype.Text{String: strings.TrimSpace(req.Alias), Valid: strings.TrimSpace(req.Alias) != ""}, - Tags: normalizeTags(req.Tags), - Status: normalizeStatus(req.Status), - Metadata: payload, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) CreateGuest(ctx context.Context, botID, displayName string) (Contact, error) { - return s.Create(ctx, CreateRequest{ - BotID: botID, - DisplayName: displayName, - Status: "active", - }) -} - -func (s *Service) Update(ctx context.Context, contactID string, req UpdateRequest) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgID, err := parseUUID(contactID) - if err != nil { - return Contact{}, err - } - var displayName pgtype.Text - if req.DisplayName != nil { - displayName = pgtype.Text{String: strings.TrimSpace(*req.DisplayName), Valid: strings.TrimSpace(*req.DisplayName) != ""} - } - var alias pgtype.Text - if req.Alias != nil { - alias = pgtype.Text{String: strings.TrimSpace(*req.Alias), Valid: strings.TrimSpace(*req.Alias) != ""} - } - var tags []string - if req.Tags != nil { - tags = normalizeTags(*req.Tags) - } - status := "" - if req.Status != nil { - status = normalizeStatus(*req.Status) - } - var metadata []byte - if req.Metadata != nil { - encoded, err := json.Marshal(defaultMetadata(req.Metadata)) - if err != nil { - return Contact{}, err - } - metadata = encoded - } - row, err := s.queries.UpdateContact(ctx, sqlc.UpdateContactParams{ - ID: pgID, - DisplayName: displayName, - Alias: alias, - Tags: tags, - Status: status, - Metadata: metadata, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) BindUser(ctx context.Context, contactID, userID string) (Contact, error) { - if s.queries == nil { - return Contact{}, fmt.Errorf("contacts queries not configured") - } - pgContactID, err := parseUUID(contactID) - if err != nil { - return Contact{}, err - } - pgUserID, err := parseUUID(userID) - if err != nil { - return Contact{}, err - } - row, err := s.queries.UpdateContactUser(ctx, sqlc.UpdateContactUserParams{ - ID: pgContactID, - UserID: pgUserID, - }) - if err != nil { - return Contact{}, err - } - return normalizeContact(row) -} - -func (s *Service) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (ContactChannel, error) { - if s.queries == nil { - return ContactChannel{}, fmt.Errorf("contacts queries not configured") - } - pgBotID, err := parseUUID(botID) - if err != nil { - return ContactChannel{}, err - } - pgContactID, err := parseUUID(contactID) - if err != nil { - return ContactChannel{}, err - } - payload, err := json.Marshal(defaultMetadata(metadata)) - if err != nil { - return ContactChannel{}, err - } - row, err := s.queries.UpsertContactChannel(ctx, sqlc.UpsertContactChannelParams{ - BotID: pgBotID, - ContactID: pgContactID, - Platform: strings.TrimSpace(platform), - ExternalID: strings.TrimSpace(externalID), - Metadata: payload, - }) - if err != nil { - return ContactChannel{}, err - } - return normalizeContactChannel(row) -} - -func normalizeContact(row sqlc.Contact) (Contact, error) { - metadata, err := decodeMetadata(row.Metadata) - if err != nil { - return Contact{}, err - } - return Contact{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), - UserID: toUUIDString(row.UserID), - DisplayName: strings.TrimSpace(row.DisplayName.String), - Alias: strings.TrimSpace(row.Alias.String), - Tags: normalizeTags(row.Tags), - Status: strings.TrimSpace(row.Status), - Metadata: metadata, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), - }, nil -} - -func normalizeContactChannel(row sqlc.ContactChannel) (ContactChannel, error) { - metadata, err := decodeMetadata(row.Metadata) - if err != nil { - return ContactChannel{}, err - } - return ContactChannel{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), - ContactID: toUUIDString(row.ContactID), - Platform: strings.TrimSpace(row.Platform), - ExternalID: strings.TrimSpace(row.ExternalID), - Metadata: metadata, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), - }, nil -} - -func decodeMetadata(raw []byte) (map[string]any, error) { - if len(raw) == 0 { - return map[string]any{}, nil - } - var payload map[string]any - if err := json.Unmarshal(raw, &payload); err != nil { - return nil, err - } - if payload == nil { - payload = map[string]any{} - } - return payload, nil -} - -func defaultMetadata(value map[string]any) map[string]any { - if value == nil { - return map[string]any{} - } - return value -} - -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} - -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} - -func timeFromPg(value pgtype.Timestamptz) time.Time { - if value.Valid { - return value.Time - } - return time.Time{} -} - -func normalizeTags(tags []string) []string { - seen := map[string]struct{}{} - normalized := make([]string, 0, len(tags)) - for _, tag := range tags { - trimmed := strings.TrimSpace(tag) - if trimmed == "" { - continue - } - if _, ok := seen[trimmed]; ok { - continue - } - seen[trimmed] = struct{}{} - normalized = append(normalized, trimmed) - } - return normalized -} - -func normalizeStatus(status string) string { - trimmed := strings.ToLower(strings.TrimSpace(status)) - switch trimmed { - case "active", "blocked", "pending": - return trimmed - case "": - return "active" - default: - return "active" - } -} diff --git a/internal/contacts/types.go b/internal/contacts/types.go deleted file mode 100644 index f39ff1e3..00000000 --- a/internal/contacts/types.go +++ /dev/null @@ -1,45 +0,0 @@ -package contacts - -import "time" - -type Contact struct { - ID string - BotID string - UserID string - DisplayName string - Alias string - Tags []string - Status string - Metadata map[string]any - CreatedAt time.Time - UpdatedAt time.Time -} - -type ContactChannel struct { - ID string - BotID string - ContactID string - Platform string - ExternalID string - Metadata map[string]any - CreatedAt time.Time - UpdatedAt time.Time -} - -type CreateRequest struct { - BotID string - UserID string - DisplayName string - Alias string - Tags []string - Status string - Metadata map[string]any -} - -type UpdateRequest struct { - DisplayName *string - Alias *string - Tags *[]string - Status *string - Metadata map[string]any -} diff --git a/internal/db/sqlc/bind.sql.go b/internal/db/sqlc/bind.sql.go new file mode 100644 index 00000000..c4df9b24 --- /dev/null +++ b/internal/db/sqlc/bind.sql.go @@ -0,0 +1,120 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: bind.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createBindCode = `-- name: CreateBindCode :one +INSERT INTO channel_identity_bind_codes (token, issued_by_user_id, platform, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +` + +type CreateBindCodeParams struct { + Token string `json:"token"` + IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` + Platform pgtype.Text `json:"platform"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` +} + +func (q *Queries) CreateBindCode(ctx context.Context, arg CreateBindCodeParams) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, createBindCode, + arg.Token, + arg.IssuedByUserID, + arg.Platform, + arg.ExpiresAt, + ) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.Platform, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} + +const getBindCode = `-- name: GetBindCode :one +SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1 +` + +func (q *Queries) GetBindCode(ctx context.Context, token string) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, getBindCode, token) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.Platform, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} + +const getBindCodeForUpdate = `-- name: GetBindCodeForUpdate :one +SELECT id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +FROM channel_identity_bind_codes +WHERE token = $1 +FOR UPDATE +` + +func (q *Queries) GetBindCodeForUpdate(ctx context.Context, token string) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, getBindCodeForUpdate, token) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.Platform, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} + +const markBindCodeUsed = `-- name: MarkBindCodeUsed :one +UPDATE channel_identity_bind_codes +SET used_at = now(), used_by_channel_identity_id = $2 +WHERE id = $1 + AND used_at IS NULL +RETURNING id, token, issued_by_user_id, platform, expires_at, used_at, used_by_channel_identity_id, created_at +` + +type MarkBindCodeUsedParams struct { + ID pgtype.UUID `json:"id"` + UsedByChannelIdentityID pgtype.UUID `json:"used_by_channel_identity_id"` +} + +func (q *Queries) MarkBindCodeUsed(ctx context.Context, arg MarkBindCodeUsedParams) (ChannelIdentityBindCode, error) { + row := q.db.QueryRow(ctx, markBindCodeUsed, arg.ID, arg.UsedByChannelIdentityID) + var i ChannelIdentityBindCode + err := row.Scan( + &i.ID, + &i.Token, + &i.IssuedByUserID, + &i.Platform, + &i.ExpiresAt, + &i.UsedAt, + &i.UsedByChannelIdentityID, + &i.CreatedAt, + ) + return i, err +} diff --git a/internal/db/sqlc/bots.sql.go b/internal/db/sqlc/bots.sql.go index 69f4b4ed..3fafd6d9 100644 --- a/internal/db/sqlc/bots.sql.go +++ b/internal/db/sqlc/bots.sql.go @@ -14,7 +14,7 @@ import ( const createBot = `-- name: CreateBot :one INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata) VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type CreateBotParams struct { @@ -43,6 +43,12 @@ func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (Bot, erro &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -74,7 +80,7 @@ func (q *Queries) DeleteBotMember(ctx context.Context, arg DeleteBotMemberParams } const getBotByID = `-- name: GetBotByID :one -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE id = $1 ` @@ -89,6 +95,12 @@ func (q *Queries) GetBotByID(ctx context.Context, id pgtype.UUID) (Bot, error) { &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -153,7 +165,7 @@ func (q *Queries) ListBotMembers(ctx context.Context, botID pgtype.UUID) ([]BotM } const listBotsByMember = `-- name: ListBotsByMember :many -SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.metadata, b.created_at, b.updated_at +SELECT b.id, b.owner_user_id, b.type, b.display_name, b.avatar_url, b.is_active, b.max_context_load_time, b.language, b.allow_guest, b.chat_model_id, b.memory_model_id, b.embedding_model_id, b.metadata, b.created_at, b.updated_at FROM bots b JOIN bot_members m ON m.bot_id = b.id WHERE m.user_id = $1 @@ -176,6 +188,12 @@ func (q *Queries) ListBotsByMember(ctx context.Context, userID pgtype.UUID) ([]B &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -191,7 +209,7 @@ func (q *Queries) ListBotsByMember(ctx context.Context, userID pgtype.UUID) ([]B } const listBotsByOwner = `-- name: ListBotsByOwner :many -SELECT id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +SELECT id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at FROM bots WHERE owner_user_id = $1 ORDER BY created_at DESC @@ -213,6 +231,12 @@ func (q *Queries) ListBotsByOwner(ctx context.Context, ownerUserID pgtype.UUID) &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -232,7 +256,7 @@ UPDATE bots SET owner_user_id = $2, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type UpdateBotOwnerParams struct { @@ -250,6 +274,12 @@ func (q *Queries) UpdateBotOwner(ctx context.Context, arg UpdateBotOwnerParams) &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, @@ -265,7 +295,7 @@ SET display_name = $2, metadata = $5, updated_at = now() WHERE id = $1 -RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, metadata, created_at, updated_at +RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at ` type UpdateBotProfileParams struct { @@ -292,6 +322,12 @@ func (q *Queries) UpdateBotProfile(ctx context.Context, arg UpdateBotProfilePara &i.DisplayName, &i.AvatarUrl, &i.IsActive, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, &i.Metadata, &i.CreatedAt, &i.UpdatedAt, diff --git a/internal/db/sqlc/channel_identities.sql.go b/internal/db/sqlc/channel_identities.sql.go new file mode 100644 index 00000000..280706b2 --- /dev/null +++ b/internal/db/sqlc/channel_identities.sql.go @@ -0,0 +1,249 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: channel_identities.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const clearChannelIdentityLinkedUser = `-- name: ClearChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = NULL, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +` + +func (q *Queries) ClearChannelIdentityLinkedUser(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, clearChannelIdentityLinkedUser, id) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const createChannelIdentity = `-- name: CreateChannelIdentity :one +INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +` + +type CreateChannelIdentityParams struct { + UserID pgtype.UUID `json:"user_id"` + Channel string `json:"channel"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName pgtype.Text `json:"display_name"` + Metadata []byte `json:"metadata"` +} + +func (q *Queries) CreateChannelIdentity(ctx context.Context, arg CreateChannelIdentityParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, createChannelIdentity, + arg.UserID, + arg.Channel, + arg.ChannelSubjectID, + arg.DisplayName, + arg.Metadata, + ) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChannelIdentityByChannelSubject = `-- name: GetChannelIdentityByChannelSubject :one +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE channel = $1 AND channel_subject_id = $2 +` + +type GetChannelIdentityByChannelSubjectParams struct { + Channel string `json:"channel"` + ChannelSubjectID string `json:"channel_subject_id"` +} + +func (q *Queries) GetChannelIdentityByChannelSubject(ctx context.Context, arg GetChannelIdentityByChannelSubjectParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, getChannelIdentityByChannelSubject, arg.Channel, arg.ChannelSubjectID) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChannelIdentityByID = `-- name: GetChannelIdentityByID :one +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1 +` + +func (q *Queries) GetChannelIdentityByID(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, getChannelIdentityByID, id) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChannelIdentityByIDForUpdate = `-- name: GetChannelIdentityByIDForUpdate :one +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE id = $1 +FOR UPDATE +` + +func (q *Queries) GetChannelIdentityByIDForUpdate(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, getChannelIdentityByIDForUpdate, id) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listChannelIdentitiesByUserID = `-- name: ListChannelIdentitiesByUserID :many +SELECT id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +FROM channel_identities +WHERE user_id = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListChannelIdentitiesByUserID(ctx context.Context, userID pgtype.UUID) ([]ChannelIdentity, error) { + rows, err := q.db.Query(ctx, listChannelIdentitiesByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChannelIdentity + for rows.Next() { + var i ChannelIdentity + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setChannelIdentityLinkedUser = `-- name: SetChannelIdentityLinkedUser :one +UPDATE channel_identities +SET user_id = $2, updated_at = now() +WHERE id = $1 +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +` + +type SetChannelIdentityLinkedUserParams struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) SetChannelIdentityLinkedUser(ctx context.Context, arg SetChannelIdentityLinkedUserParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, setChannelIdentityLinkedUser, arg.ID, arg.UserID) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertChannelIdentityByChannelSubject = `-- name: UpsertChannelIdentityByChannelSubject :one +INSERT INTO channel_identities (user_id, channel, channel_subject_id, display_name, metadata) +VALUES ($1, $2, $3, $4, $5) +ON CONFLICT (channel, channel_subject_id) +DO UPDATE SET + display_name = EXCLUDED.display_name, + metadata = EXCLUDED.metadata, + user_id = COALESCE(channel_identities.user_id, EXCLUDED.user_id), + updated_at = now() +RETURNING id, user_id, channel, channel_subject_id, display_name, metadata, created_at, updated_at +` + +type UpsertChannelIdentityByChannelSubjectParams struct { + UserID pgtype.UUID `json:"user_id"` + Channel string `json:"channel"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName pgtype.Text `json:"display_name"` + Metadata []byte `json:"metadata"` +} + +func (q *Queries) UpsertChannelIdentityByChannelSubject(ctx context.Context, arg UpsertChannelIdentityByChannelSubjectParams) (ChannelIdentity, error) { + row := q.db.QueryRow(ctx, upsertChannelIdentityByChannelSubject, + arg.UserID, + arg.Channel, + arg.ChannelSubjectID, + arg.DisplayName, + arg.Metadata, + ) + var i ChannelIdentity + err := row.Scan( + &i.ID, + &i.UserID, + &i.Channel, + &i.ChannelSubjectID, + &i.DisplayName, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/db/sqlc/channels.sql.go b/internal/db/sqlc/channels.sql.go index c38738d8..2bd0a488 100644 --- a/internal/db/sqlc/channels.sql.go +++ b/internal/db/sqlc/channels.sql.go @@ -11,16 +11,6 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -const deleteChannelSession = `-- name: DeleteChannelSession :exec -DELETE FROM channel_sessions -WHERE session_id = $1 -` - -func (q *Queries) DeleteChannelSession(ctx context.Context, sessionID string) error { - _, err := q.db.Exec(ctx, deleteChannelSession, sessionID) - return err -} - const getBotChannelConfig = `-- name: GetBotChannelConfig :one SELECT id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, status, verified_at, created_at, updated_at FROM bot_channel_configs @@ -85,51 +75,25 @@ func (q *Queries) GetBotChannelConfigByExternalIdentity(ctx context.Context, arg return i, err } -const getChannelSessionByID = `-- name: GetChannelSessionByID :one -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE session_id = $1 -LIMIT 1 -` - -func (q *Queries) GetChannelSessionByID(ctx context.Context, sessionID string) (ChannelSession, error) { - row := q.db.QueryRow(ctx, getChannelSessionByID, sessionID) - var i ChannelSession - err := row.Scan( - &i.SessionID, - &i.BotID, - &i.ChannelConfigID, - &i.UserID, - &i.ContactID, - &i.Platform, - &i.ReplyTarget, - &i.ThreadID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - const getUserChannelBinding = `-- name: GetUserChannelBinding :one -SELECT id, user_id, channel_type, config, created_at, updated_at +SELECT id, user_id, platform, config, created_at, updated_at FROM user_channel_bindings -WHERE user_id = $1 AND channel_type = $2 +WHERE user_id = $1 AND platform = $2 LIMIT 1 ` type GetUserChannelBindingParams struct { - UserID pgtype.UUID `json:"user_id"` - ChannelType string `json:"channel_type"` + UserID pgtype.UUID `json:"user_id"` + Platform string `json:"platform"` } func (q *Queries) GetUserChannelBinding(ctx context.Context, arg GetUserChannelBindingParams) (UserChannelBinding, error) { - row := q.db.QueryRow(ctx, getUserChannelBinding, arg.UserID, arg.ChannelType) + row := q.db.QueryRow(ctx, getUserChannelBinding, arg.UserID, arg.Platform) var i UserChannelBinding err := row.Scan( &i.ID, &i.UserID, - &i.ChannelType, + &i.Platform, &i.Config, &i.CreatedAt, &i.UpdatedAt, @@ -177,59 +141,15 @@ func (q *Queries) ListBotChannelConfigsByType(ctx context.Context, channelType s return items, nil } -const listChannelSessionsByBotPlatform = `-- name: ListChannelSessionsByBotPlatform :many -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -FROM channel_sessions -WHERE bot_id = $1 AND platform = $2 -ORDER BY updated_at DESC -` - -type ListChannelSessionsByBotPlatformParams struct { - BotID pgtype.UUID `json:"bot_id"` - Platform string `json:"platform"` -} - -func (q *Queries) ListChannelSessionsByBotPlatform(ctx context.Context, arg ListChannelSessionsByBotPlatformParams) ([]ChannelSession, error) { - rows, err := q.db.Query(ctx, listChannelSessionsByBotPlatform, arg.BotID, arg.Platform) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChannelSession - for rows.Next() { - var i ChannelSession - if err := rows.Scan( - &i.SessionID, - &i.BotID, - &i.ChannelConfigID, - &i.UserID, - &i.ContactID, - &i.Platform, - &i.ReplyTarget, - &i.ThreadID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listUserChannelBindingsByType = `-- name: ListUserChannelBindingsByType :many -SELECT id, user_id, channel_type, config, created_at, updated_at +const listUserChannelBindingsByPlatform = `-- name: ListUserChannelBindingsByPlatform :many +SELECT id, user_id, platform, config, created_at, updated_at FROM user_channel_bindings -WHERE channel_type = $1 +WHERE platform = $1 ORDER BY created_at DESC ` -func (q *Queries) ListUserChannelBindingsByType(ctx context.Context, channelType string) ([]UserChannelBinding, error) { - rows, err := q.db.Query(ctx, listUserChannelBindingsByType, channelType) +func (q *Queries) ListUserChannelBindingsByPlatform(ctx context.Context, platform string) ([]UserChannelBinding, error) { + rows, err := q.db.Query(ctx, listUserChannelBindingsByPlatform, platform) if err != nil { return nil, err } @@ -240,7 +160,7 @@ func (q *Queries) ListUserChannelBindingsByType(ctx context.Context, channelType if err := rows.Scan( &i.ID, &i.UserID, - &i.ChannelType, + &i.Platform, &i.Config, &i.CreatedAt, &i.UpdatedAt, @@ -315,87 +235,29 @@ func (q *Queries) UpsertBotChannelConfig(ctx context.Context, arg UpsertBotChann return i, err } -const upsertChannelSession = `-- name: UpsertChannelSession :one -INSERT INTO channel_sessions (session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) -ON CONFLICT (session_id) -DO UPDATE SET - bot_id = EXCLUDED.bot_id, - channel_config_id = EXCLUDED.channel_config_id, - user_id = EXCLUDED.user_id, - contact_id = EXCLUDED.contact_id, - platform = EXCLUDED.platform, - reply_target = EXCLUDED.reply_target, - thread_id = EXCLUDED.thread_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at -` - -type UpsertChannelSessionParams struct { - SessionID string `json:"session_id"` - BotID pgtype.UUID `json:"bot_id"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - UserID pgtype.UUID `json:"user_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - ReplyTarget pgtype.Text `json:"reply_target"` - ThreadID pgtype.Text `json:"thread_id"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) UpsertChannelSession(ctx context.Context, arg UpsertChannelSessionParams) (ChannelSession, error) { - row := q.db.QueryRow(ctx, upsertChannelSession, - arg.SessionID, - arg.BotID, - arg.ChannelConfigID, - arg.UserID, - arg.ContactID, - arg.Platform, - arg.ReplyTarget, - arg.ThreadID, - arg.Metadata, - ) - var i ChannelSession - err := row.Scan( - &i.SessionID, - &i.BotID, - &i.ChannelConfigID, - &i.UserID, - &i.ContactID, - &i.Platform, - &i.ReplyTarget, - &i.ThreadID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - const upsertUserChannelBinding = `-- name: UpsertUserChannelBinding :one -INSERT INTO user_channel_bindings (user_id, channel_type, config) +INSERT INTO user_channel_bindings (user_id, platform, config) VALUES ($1, $2, $3) -ON CONFLICT (user_id, channel_type) +ON CONFLICT (user_id, platform) DO UPDATE SET config = EXCLUDED.config, updated_at = now() -RETURNING id, user_id, channel_type, config, created_at, updated_at +RETURNING id, user_id, platform, config, created_at, updated_at ` type UpsertUserChannelBindingParams struct { - UserID pgtype.UUID `json:"user_id"` - ChannelType string `json:"channel_type"` - Config []byte `json:"config"` + UserID pgtype.UUID `json:"user_id"` + Platform string `json:"platform"` + Config []byte `json:"config"` } func (q *Queries) UpsertUserChannelBinding(ctx context.Context, arg UpsertUserChannelBindingParams) (UserChannelBinding, error) { - row := q.db.QueryRow(ctx, upsertUserChannelBinding, arg.UserID, arg.ChannelType, arg.Config) + row := q.db.QueryRow(ctx, upsertUserChannelBinding, arg.UserID, arg.Platform, arg.Config) var i UserChannelBinding err := row.Scan( &i.ID, &i.UserID, - &i.ChannelType, + &i.Platform, &i.Config, &i.CreatedAt, &i.UpdatedAt, diff --git a/internal/db/sqlc/chats.sql.go b/internal/db/sqlc/chats.sql.go new file mode 100644 index 00000000..0dce7be4 --- /dev/null +++ b/internal/db/sqlc/chats.sql.go @@ -0,0 +1,988 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: chats.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const addChatParticipant = `-- name: AddChatParticipant :one + +INSERT INTO chat_participants (chat_id, user_id, role) +VALUES ($1, $2, $3) +ON CONFLICT (chat_id, user_id) DO UPDATE SET role = EXCLUDED.role +RETURNING chat_id, user_id, role, joined_at +` + +type AddChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` +} + +// chat_participants +func (q *Queries) AddChatParticipant(ctx context.Context, arg AddChatParticipantParams) (ChatParticipant, error) { + row := q.db.QueryRow(ctx, addChatParticipant, arg.ChatID, arg.UserID, arg.Role) + var i ChatParticipant + err := row.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ) + return i, err +} + +const copyParticipantsToChat = `-- name: CopyParticipantsToChat :exec +INSERT INTO chat_participants (chat_id, user_id, role) +SELECT $2, cp.user_id, cp.role FROM chat_participants cp WHERE cp.chat_id = $1 +ON CONFLICT (chat_id, user_id) DO NOTHING +` + +type CopyParticipantsToChatParams struct { + ChatID pgtype.UUID `json:"chat_id"` + ChatID_2 pgtype.UUID `json:"chat_id_2"` +} + +func (q *Queries) CopyParticipantsToChat(ctx context.Context, arg CopyParticipantsToChatParams) error { + _, err := q.db.Exec(ctx, copyParticipantsToChat, arg.ChatID, arg.ChatID_2) + return err +} + +const createChat = `-- name: CreateChat :one +INSERT INTO chats (bot_id, kind, parent_chat_id, title, created_by_user_id, metadata) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at +` + +type CreateChatParams struct { + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` +} + +func (q *Queries) CreateChat(ctx context.Context, arg CreateChatParams) (Chat, error) { + row := q.db.QueryRow(ctx, createChat, + arg.BotID, + arg.Kind, + arg.ParentChatID, + arg.Title, + arg.CreatedByUserID, + arg.Metadata, + ) + var i Chat + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.EnableChatMemory, + &i.EnablePrivateMemory, + &i.EnablePublicMemory, + &i.ModelID, + &i.SettingsMetadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const createChatMessage = `-- name: CreateChatMessage :one + +INSERT INTO chat_messages (chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) +RETURNING id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +` + +type CreateChatMessageParams struct { + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` +} + +// chat_messages +func (q *Queries) CreateChatMessage(ctx context.Context, arg CreateChatMessageParams) (ChatMessage, error) { + row := q.db.QueryRow(ctx, createChatMessage, + arg.ChatID, + arg.BotID, + arg.RouteID, + arg.SenderChannelIdentityID, + arg.SenderUserID, + arg.Platform, + arg.ExternalMessageID, + arg.Role, + arg.Content, + arg.Metadata, + ) + var i ChatMessage + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ) + return i, err +} + +const createChatRoute = `-- name: CreateChatRoute :one + +INSERT INTO chat_routes (chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +RETURNING id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at +` + +type CreateChatRouteParams struct { + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` + Metadata []byte `json:"metadata"` +} + +// chat_routes +func (q *Queries) CreateChatRoute(ctx context.Context, arg CreateChatRouteParams) (ChatRoute, error) { + row := q.db.QueryRow(ctx, createChatRoute, + arg.ChatID, + arg.BotID, + arg.Platform, + arg.ChannelConfigID, + arg.ConversationID, + arg.ThreadID, + arg.ReplyTarget, + arg.Metadata, + ) + var i ChatRoute + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteChat = `-- name: DeleteChat :exec +DELETE FROM chats WHERE id = $1 +` + +func (q *Queries) DeleteChat(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteChat, id) + return err +} + +const deleteChatMessagesByChat = `-- name: DeleteChatMessagesByChat :exec +DELETE FROM chat_messages WHERE chat_id = $1 +` + +func (q *Queries) DeleteChatMessagesByChat(ctx context.Context, chatID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteChatMessagesByChat, chatID) + return err +} + +const deleteChatRoute = `-- name: DeleteChatRoute :exec +DELETE FROM chat_routes WHERE id = $1 +` + +func (q *Queries) DeleteChatRoute(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteChatRoute, id) + return err +} + +const findChatRoute = `-- name: FindChatRoute :one +SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at +FROM chat_routes +WHERE bot_id = $1 AND platform = $2 AND conversation_id = $3 + AND COALESCE(thread_id, '') = COALESCE($4, '') +LIMIT 1 +` + +type FindChatRouteParams struct { + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` + ConversationID string `json:"conversation_id"` + ThreadID pgtype.Text `json:"thread_id"` +} + +func (q *Queries) FindChatRoute(ctx context.Context, arg FindChatRouteParams) (ChatRoute, error) { + row := q.db.QueryRow(ctx, findChatRoute, + arg.BotID, + arg.Platform, + arg.ConversationID, + arg.ThreadID, + ) + var i ChatRoute + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChatByID = `-- name: GetChatByID :one +SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at +FROM chats +WHERE id = $1 +` + +func (q *Queries) GetChatByID(ctx context.Context, id pgtype.UUID) (Chat, error) { + row := q.db.QueryRow(ctx, getChatByID, id) + var i Chat + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.EnableChatMemory, + &i.EnablePrivateMemory, + &i.EnablePublicMemory, + &i.ModelID, + &i.SettingsMetadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChatParticipant = `-- name: GetChatParticipant :one +SELECT chat_id, user_id, role, joined_at +FROM chat_participants +WHERE chat_id = $1 AND user_id = $2 +` + +type GetChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) GetChatParticipant(ctx context.Context, arg GetChatParticipantParams) (ChatParticipant, error) { + row := q.db.QueryRow(ctx, getChatParticipant, arg.ChatID, arg.UserID) + var i ChatParticipant + err := row.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ) + return i, err +} + +const getChatReadAccessByUser = `-- name: GetChatReadAccessByUser :one +WITH participant_access AS ( + SELECT 'participant'::text AS access_mode, + cp.role AS participant_role, + NULL::timestamptz AS last_observed_at + FROM chat_participants cp + WHERE cp.chat_id = $1 AND cp.user_id = $2 +), +observed_access AS ( + SELECT 'channel_identity_observed'::text AS access_mode, + ''::text AS participant_role, + MAX(cap.last_seen_at) AS last_observed_at + FROM chat_channel_identity_presence cap + JOIN channel_identities ci ON ci.id = cap.channel_identity_id + WHERE cap.chat_id = $1 AND ci.user_id = $2 + GROUP BY cap.chat_id +), +all_access AS ( + SELECT access_mode, participant_role, last_observed_at FROM participant_access + UNION ALL + SELECT access_mode, participant_role, last_observed_at FROM observed_access +) +SELECT access_mode, participant_role, last_observed_at +FROM all_access +ORDER BY CASE WHEN access_mode = 'participant' THEN 0 ELSE 1 END, last_observed_at DESC NULLS LAST +LIMIT 1 +` + +type GetChatReadAccessByUserParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` +} + +type GetChatReadAccessByUserRow struct { + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role"` + LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` +} + +func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAccessByUserParams) (GetChatReadAccessByUserRow, error) { + row := q.db.QueryRow(ctx, getChatReadAccessByUser, arg.ChatID, arg.UserID) + var i GetChatReadAccessByUserRow + err := row.Scan(&i.AccessMode, &i.ParticipantRole, &i.LastObservedAt) + return i, err +} + +const getChatRouteByID = `-- name: GetChatRouteByID :one +SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at +FROM chat_routes +WHERE id = $1 +` + +func (q *Queries) GetChatRouteByID(ctx context.Context, id pgtype.UUID) (ChatRoute, error) { + row := q.db.QueryRow(ctx, getChatRouteByID, id) + var i ChatRoute + err := row.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getChatSettings = `-- name: GetChatSettings :one +SELECT id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at +FROM chats +WHERE id = $1 +` + +type GetChatSettingsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + EnableChatMemory bool `json:"enable_chat_memory"` + EnablePrivateMemory bool `json:"enable_private_memory"` + EnablePublicMemory bool `json:"enable_public_memory"` + ModelID pgtype.Text `json:"model_id"` + Metadata []byte `json:"metadata"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) GetChatSettings(ctx context.Context, id pgtype.UUID) (GetChatSettingsRow, error) { + row := q.db.QueryRow(ctx, getChatSettings, id) + var i GetChatSettingsRow + err := row.Scan( + &i.ChatID, + &i.EnableChatMemory, + &i.EnablePrivateMemory, + &i.EnablePublicMemory, + &i.ModelID, + &i.Metadata, + &i.UpdatedAt, + ) + return i, err +} + +const listChatMessages = `-- name: ListChatMessages :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 +ORDER BY created_at ASC +` + +func (q *Queries) ListChatMessages(ctx context.Context, chatID pgtype.UUID) ([]ChatMessage, error) { + rows, err := q.db.Query(ctx, listChatMessages, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatMessage + for rows.Next() { + var i ChatMessage + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatMessagesBefore = `-- name: ListChatMessagesBefore :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 AND created_at < $2 +ORDER BY created_at DESC +LIMIT $3 +` + +type ListChatMessagesBeforeParams struct { + ChatID pgtype.UUID `json:"chat_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + Limit int32 `json:"limit"` +} + +func (q *Queries) ListChatMessagesBefore(ctx context.Context, arg ListChatMessagesBeforeParams) ([]ChatMessage, error) { + rows, err := q.db.Query(ctx, listChatMessagesBefore, arg.ChatID, arg.CreatedAt, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatMessage + for rows.Next() { + var i ChatMessage + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatMessagesLatest = `-- name: ListChatMessagesLatest :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 +ORDER BY created_at DESC +LIMIT $2 +` + +type ListChatMessagesLatestParams struct { + ChatID pgtype.UUID `json:"chat_id"` + Limit int32 `json:"limit"` +} + +func (q *Queries) ListChatMessagesLatest(ctx context.Context, arg ListChatMessagesLatestParams) ([]ChatMessage, error) { + rows, err := q.db.Query(ctx, listChatMessagesLatest, arg.ChatID, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatMessage + for rows.Next() { + var i ChatMessage + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatMessagesSince = `-- name: ListChatMessagesSince :many +SELECT id, chat_id, bot_id, route_id, sender_channel_identity_id, sender_user_id, platform, external_message_id, role, content, metadata, created_at +FROM chat_messages +WHERE chat_id = $1 AND created_at >= $2 +ORDER BY created_at ASC +` + +type ListChatMessagesSinceParams struct { + ChatID pgtype.UUID `json:"chat_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +func (q *Queries) ListChatMessagesSince(ctx context.Context, arg ListChatMessagesSinceParams) ([]ChatMessage, error) { + rows, err := q.db.Query(ctx, listChatMessagesSince, arg.ChatID, arg.CreatedAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatMessage + for rows.Next() { + var i ChatMessage + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.RouteID, + &i.SenderChannelIdentityID, + &i.SenderUserID, + &i.Platform, + &i.ExternalMessageID, + &i.Role, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatParticipants = `-- name: ListChatParticipants :many +SELECT chat_id, user_id, role, joined_at +FROM chat_participants +WHERE chat_id = $1 +ORDER BY joined_at ASC +` + +func (q *Queries) ListChatParticipants(ctx context.Context, chatID pgtype.UUID) ([]ChatParticipant, error) { + rows, err := q.db.Query(ctx, listChatParticipants, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatParticipant + for rows.Next() { + var i ChatParticipant + if err := rows.Scan( + &i.ChatID, + &i.UserID, + &i.Role, + &i.JoinedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatRoutes = `-- name: ListChatRoutes :many +SELECT id, chat_id, bot_id, platform, channel_config_id, conversation_id, thread_id, reply_target, metadata, created_at, updated_at +FROM chat_routes +WHERE chat_id = $1 +ORDER BY created_at ASC +` + +func (q *Queries) ListChatRoutes(ctx context.Context, chatID pgtype.UUID) ([]ChatRoute, error) { + rows, err := q.db.Query(ctx, listChatRoutes, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatRoute + for rows.Next() { + var i ChatRoute + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.BotID, + &i.Platform, + &i.ChannelConfigID, + &i.ConversationID, + &i.ThreadID, + &i.ReplyTarget, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChatsByBotAndUser = `-- name: ListChatsByBotAndUser :many +SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.enable_chat_memory, c.enable_private_memory, c.enable_public_memory, c.model_id, c.settings_metadata, c.created_at, c.updated_at +FROM chats c +JOIN chat_participants cp ON cp.chat_id = c.id +WHERE c.bot_id = $1 AND cp.user_id = $2 +ORDER BY c.updated_at DESC +` + +type ListChatsByBotAndUserParams struct { + BotID pgtype.UUID `json:"bot_id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) ListChatsByBotAndUser(ctx context.Context, arg ListChatsByBotAndUserParams) ([]Chat, error) { + rows, err := q.db.Query(ctx, listChatsByBotAndUser, arg.BotID, arg.UserID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Chat + for rows.Next() { + var i Chat + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.EnableChatMemory, + &i.EnablePrivateMemory, + &i.EnablePublicMemory, + &i.ModelID, + &i.SettingsMetadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listThreadsByParent = `-- name: ListThreadsByParent :many +SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at +FROM chats +WHERE parent_chat_id = $1 AND kind = 'thread' +ORDER BY created_at DESC +` + +func (q *Queries) ListThreadsByParent(ctx context.Context, parentChatID pgtype.UUID) ([]Chat, error) { + rows, err := q.db.Query(ctx, listThreadsByParent, parentChatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Chat + for rows.Next() { + var i Chat + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.EnableChatMemory, + &i.EnablePrivateMemory, + &i.EnablePublicMemory, + &i.ModelID, + &i.SettingsMetadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listVisibleChatsByBotAndUser = `-- name: ListVisibleChatsByBotAndUser :many +WITH participant_chats AS ( + SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, + 'participant'::text AS access_mode, + cp.role AS participant_role, + NULL::timestamptz AS last_observed_at + FROM chats c + JOIN chat_participants cp ON cp.chat_id = c.id + WHERE c.bot_id = $1 AND cp.user_id = $2 +), +observed_chats AS ( + SELECT c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at, + 'channel_identity_observed'::text AS access_mode, + ''::text AS participant_role, + MAX(cap.last_seen_at) AS last_observed_at + FROM chats c + JOIN chat_channel_identity_presence cap ON cap.chat_id = c.id + JOIN channel_identities ci ON ci.id = cap.channel_identity_id + WHERE c.bot_id = $1 + AND ci.user_id = $2 + AND NOT EXISTS ( + SELECT 1 FROM chat_participants cp + WHERE cp.chat_id = c.id AND cp.user_id = $2 + ) + GROUP BY c.id, c.bot_id, c.kind, c.parent_chat_id, c.title, c.created_by_user_id, c.metadata, c.created_at, c.updated_at +) +SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, + access_mode, participant_role, last_observed_at +FROM ( + SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, access_mode, participant_role, last_observed_at FROM participant_chats + UNION ALL + SELECT id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, created_at, updated_at, access_mode, participant_role, last_observed_at FROM observed_chats +) v +ORDER BY v.updated_at DESC, v.last_observed_at DESC NULLS LAST +` + +type ListVisibleChatsByBotAndUserParams struct { + BotID pgtype.UUID `json:"bot_id"` + UserID pgtype.UUID `json:"user_id"` +} + +type ListVisibleChatsByBotAndUserRow struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + AccessMode string `json:"access_mode"` + ParticipantRole string `json:"participant_role"` + LastObservedAt pgtype.Timestamptz `json:"last_observed_at"` +} + +func (q *Queries) ListVisibleChatsByBotAndUser(ctx context.Context, arg ListVisibleChatsByBotAndUserParams) ([]ListVisibleChatsByBotAndUserRow, error) { + rows, err := q.db.Query(ctx, listVisibleChatsByBotAndUser, arg.BotID, arg.UserID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListVisibleChatsByBotAndUserRow + for rows.Next() { + var i ListVisibleChatsByBotAndUserRow + if err := rows.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + &i.AccessMode, + &i.ParticipantRole, + &i.LastObservedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const removeChatParticipant = `-- name: RemoveChatParticipant :exec +DELETE FROM chat_participants WHERE chat_id = $1 AND user_id = $2 +` + +type RemoveChatParticipantParams struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) RemoveChatParticipant(ctx context.Context, arg RemoveChatParticipantParams) error { + _, err := q.db.Exec(ctx, removeChatParticipant, arg.ChatID, arg.UserID) + return err +} + +const touchChat = `-- name: TouchChat :exec +UPDATE chats SET updated_at = now() WHERE id = $1 +` + +func (q *Queries) TouchChat(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, touchChat, id) + return err +} + +const updateChatRouteReplyTarget = `-- name: UpdateChatRouteReplyTarget :exec +UPDATE chat_routes SET reply_target = $2, updated_at = now() WHERE id = $1 +` + +type UpdateChatRouteReplyTargetParams struct { + ID pgtype.UUID `json:"id"` + ReplyTarget pgtype.Text `json:"reply_target"` +} + +func (q *Queries) UpdateChatRouteReplyTarget(ctx context.Context, arg UpdateChatRouteReplyTargetParams) error { + _, err := q.db.Exec(ctx, updateChatRouteReplyTarget, arg.ID, arg.ReplyTarget) + return err +} + +const updateChatTitle = `-- name: UpdateChatTitle :one +UPDATE chats SET title = $2, updated_at = now() +WHERE id = $1 +RETURNING id, bot_id, kind, parent_chat_id, title, created_by_user_id, metadata, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata, created_at, updated_at +` + +type UpdateChatTitleParams struct { + ID pgtype.UUID `json:"id"` + Title pgtype.Text `json:"title"` +} + +func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams) (Chat, error) { + row := q.db.QueryRow(ctx, updateChatTitle, arg.ID, arg.Title) + var i Chat + err := row.Scan( + &i.ID, + &i.BotID, + &i.Kind, + &i.ParentChatID, + &i.Title, + &i.CreatedByUserID, + &i.Metadata, + &i.EnableChatMemory, + &i.EnablePrivateMemory, + &i.EnablePublicMemory, + &i.ModelID, + &i.SettingsMetadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertChatChannelIdentityPresence = `-- name: UpsertChatChannelIdentityPresence :exec +INSERT INTO chat_channel_identity_presence (chat_id, channel_identity_id, first_seen_at, last_seen_at, message_count) +VALUES ($1, $2, now(), now(), 1) +ON CONFLICT (chat_id, channel_identity_id) +DO UPDATE SET + last_seen_at = now(), + message_count = chat_channel_identity_presence.message_count + 1 +` + +type UpsertChatChannelIdentityPresenceParams struct { + ChatID pgtype.UUID `json:"chat_id"` + ChannelIdentityID pgtype.UUID `json:"channel_identity_id"` +} + +func (q *Queries) UpsertChatChannelIdentityPresence(ctx context.Context, arg UpsertChatChannelIdentityPresenceParams) error { + _, err := q.db.Exec(ctx, upsertChatChannelIdentityPresence, arg.ChatID, arg.ChannelIdentityID) + return err +} + +const upsertChatSettings = `-- name: UpsertChatSettings :one + +UPDATE chats +SET enable_chat_memory = $2, + enable_private_memory = $3, + enable_public_memory = $4, + model_id = $5, + settings_metadata = $6 +WHERE id = $1 +RETURNING id AS chat_id, enable_chat_memory, enable_private_memory, enable_public_memory, model_id, settings_metadata AS metadata, updated_at +` + +type UpsertChatSettingsParams struct { + ID pgtype.UUID `json:"id"` + EnableChatMemory bool `json:"enable_chat_memory"` + EnablePrivateMemory bool `json:"enable_private_memory"` + EnablePublicMemory bool `json:"enable_public_memory"` + ModelID pgtype.Text `json:"model_id"` + SettingsMetadata []byte `json:"settings_metadata"` +} + +type UpsertChatSettingsRow struct { + ChatID pgtype.UUID `json:"chat_id"` + EnableChatMemory bool `json:"enable_chat_memory"` + EnablePrivateMemory bool `json:"enable_private_memory"` + EnablePublicMemory bool `json:"enable_public_memory"` + ModelID pgtype.Text `json:"model_id"` + Metadata []byte `json:"metadata"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +// chat_settings +func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) { + row := q.db.QueryRow(ctx, upsertChatSettings, + arg.ID, + arg.EnableChatMemory, + arg.EnablePrivateMemory, + arg.EnablePublicMemory, + arg.ModelID, + arg.SettingsMetadata, + ) + var i UpsertChatSettingsRow + err := row.Scan( + &i.ChatID, + &i.EnableChatMemory, + &i.EnablePrivateMemory, + &i.EnablePublicMemory, + &i.ModelID, + &i.Metadata, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/db/sqlc/contacts.sql.go b/internal/db/sqlc/contacts.sql.go deleted file mode 100644 index 3cf19028..00000000 --- a/internal/db/sqlc/contacts.sql.go +++ /dev/null @@ -1,380 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 -// source: contacts.sql - -package sqlc - -import ( - "context" - - "github.com/jackc/pgx/v5/pgtype" -) - -const createContact = `-- name: CreateContact :one -INSERT INTO contacts (bot_id, user_id, display_name, alias, tags, status, metadata) -VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -` - -type CreateContactParams struct { - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` - DisplayName pgtype.Text `json:"display_name"` - Alias pgtype.Text `json:"alias"` - Tags []string `json:"tags"` - Status string `json:"status"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) CreateContact(ctx context.Context, arg CreateContactParams) (Contact, error) { - row := q.db.QueryRow(ctx, createContact, - arg.BotID, - arg.UserID, - arg.DisplayName, - arg.Alias, - arg.Tags, - arg.Status, - arg.Metadata, - ) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getContactByID = `-- name: GetContactByID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE id = $1 -LIMIT 1 -` - -func (q *Queries) GetContactByID(ctx context.Context, id pgtype.UUID) (Contact, error) { - row := q.db.QueryRow(ctx, getContactByID, id) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getContactByUserID = `-- name: GetContactByUserID :one -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 AND user_id = $2 -LIMIT 1 -` - -type GetContactByUserIDParams struct { - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` -} - -func (q *Queries) GetContactByUserID(ctx context.Context, arg GetContactByUserIDParams) (Contact, error) { - row := q.db.QueryRow(ctx, getContactByUserID, arg.BotID, arg.UserID) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getContactChannelByIdentity = `-- name: GetContactChannelByIdentity :one -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE bot_id = $1 AND platform = $2 AND external_id = $3 -LIMIT 1 -` - -type GetContactChannelByIdentityParams struct { - BotID pgtype.UUID `json:"bot_id"` - Platform string `json:"platform"` - ExternalID string `json:"external_id"` -} - -func (q *Queries) GetContactChannelByIdentity(ctx context.Context, arg GetContactChannelByIdentityParams) (ContactChannel, error) { - row := q.db.QueryRow(ctx, getContactChannelByIdentity, arg.BotID, arg.Platform, arg.ExternalID) - var i ContactChannel - err := row.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Platform, - &i.ExternalID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const listContactChannelsByContact = `-- name: ListContactChannelsByContact :many -SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -FROM contact_channels -WHERE contact_id = $1 -ORDER BY created_at DESC -` - -func (q *Queries) ListContactChannelsByContact(ctx context.Context, contactID pgtype.UUID) ([]ContactChannel, error) { - rows, err := q.db.Query(ctx, listContactChannelsByContact, contactID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ContactChannel - for rows.Next() { - var i ContactChannel - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Platform, - &i.ExternalID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listContactsByBot = `-- name: ListContactsByBot :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 -ORDER BY created_at DESC -` - -func (q *Queries) ListContactsByBot(ctx context.Context, botID pgtype.UUID) ([]Contact, error) { - rows, err := q.db.Query(ctx, listContactsByBot, botID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Contact - for rows.Next() { - var i Contact - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const searchContacts = `-- name: SearchContacts :many -SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -FROM contacts -WHERE bot_id = $1 - AND ( - display_name ILIKE $2 - OR alias ILIKE $2 - OR EXISTS ( - SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE $2 - ) - ) -ORDER BY created_at DESC -` - -type SearchContactsParams struct { - BotID pgtype.UUID `json:"bot_id"` - Query pgtype.Text `json:"query"` -} - -func (q *Queries) SearchContacts(ctx context.Context, arg SearchContactsParams) ([]Contact, error) { - rows, err := q.db.Query(ctx, searchContacts, arg.BotID, arg.Query) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Contact - for rows.Next() { - var i Contact - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const updateContact = `-- name: UpdateContact :one -UPDATE contacts -SET display_name = COALESCE($1, display_name), - alias = COALESCE($2, alias), - tags = COALESCE($3, tags), - status = COALESCE(NULLIF($4::text, ''), status), - metadata = COALESCE($5, metadata), - updated_at = now() -WHERE id = $6 -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -` - -type UpdateContactParams struct { - DisplayName pgtype.Text `json:"display_name"` - Alias pgtype.Text `json:"alias"` - Tags []string `json:"tags"` - Status string `json:"status"` - Metadata []byte `json:"metadata"` - ID pgtype.UUID `json:"id"` -} - -func (q *Queries) UpdateContact(ctx context.Context, arg UpdateContactParams) (Contact, error) { - row := q.db.QueryRow(ctx, updateContact, - arg.DisplayName, - arg.Alias, - arg.Tags, - arg.Status, - arg.Metadata, - arg.ID, - ) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const updateContactUser = `-- name: UpdateContactUser :one -UPDATE contacts -SET user_id = $2, - updated_at = now() -WHERE id = $1 -RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at -` - -type UpdateContactUserParams struct { - ID pgtype.UUID `json:"id"` - UserID pgtype.UUID `json:"user_id"` -} - -func (q *Queries) UpdateContactUser(ctx context.Context, arg UpdateContactUserParams) (Contact, error) { - row := q.db.QueryRow(ctx, updateContactUser, arg.ID, arg.UserID) - var i Contact - err := row.Scan( - &i.ID, - &i.BotID, - &i.UserID, - &i.DisplayName, - &i.Alias, - &i.Tags, - &i.Status, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const upsertContactChannel = `-- name: UpsertContactChannel :one -INSERT INTO contact_channels (bot_id, contact_id, platform, external_id, metadata) -VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (bot_id, platform, external_id) -DO UPDATE SET - contact_id = EXCLUDED.contact_id, - metadata = EXCLUDED.metadata, - updated_at = now() -RETURNING id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at -` - -type UpsertContactChannelParams struct { - BotID pgtype.UUID `json:"bot_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - ExternalID string `json:"external_id"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) UpsertContactChannel(ctx context.Context, arg UpsertContactChannelParams) (ContactChannel, error) { - row := q.db.QueryRow(ctx, upsertContactChannel, - arg.BotID, - arg.ContactID, - arg.Platform, - arg.ExternalID, - arg.Metadata, - ) - var i ContactChannel - err := row.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Platform, - &i.ExternalID, - &i.Metadata, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} diff --git a/internal/db/sqlc/history.sql.go b/internal/db/sqlc/history.sql.go deleted file mode 100644 index 0fc2033c..00000000 --- a/internal/db/sqlc/history.sql.go +++ /dev/null @@ -1,178 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 -// source: history.sql - -package sqlc - -import ( - "context" - - "github.com/jackc/pgx/v5/pgtype" -) - -const createHistory = `-- name: CreateHistory :one -INSERT INTO history (bot_id, session_id, messages, metadata, skills, timestamp) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, bot_id, session_id, messages, metadata, skills, timestamp -` - -type CreateHistoryParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Messages []byte `json:"messages"` - Metadata []byte `json:"metadata"` - Skills []string `json:"skills"` - Timestamp pgtype.Timestamptz `json:"timestamp"` -} - -func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (History, error) { - row := q.db.QueryRow(ctx, createHistory, - arg.BotID, - arg.SessionID, - arg.Messages, - arg.Metadata, - arg.Skills, - arg.Timestamp, - ) - var i History - err := row.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ) - return i, err -} - -const deleteHistoryByBotSession = `-- name: DeleteHistoryByBotSession :exec -DELETE FROM history -WHERE bot_id = $1 AND session_id = $2 -` - -type DeleteHistoryByBotSessionParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` -} - -func (q *Queries) DeleteHistoryByBotSession(ctx context.Context, arg DeleteHistoryByBotSessionParams) error { - _, err := q.db.Exec(ctx, deleteHistoryByBotSession, arg.BotID, arg.SessionID) - return err -} - -const deleteHistoryByID = `-- name: DeleteHistoryByID :exec -DELETE FROM history -WHERE id = $1 -` - -func (q *Queries) DeleteHistoryByID(ctx context.Context, id pgtype.UUID) error { - _, err := q.db.Exec(ctx, deleteHistoryByID, id) - return err -} - -const getHistoryByID = `-- name: GetHistoryByID :one -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE id = $1 -` - -func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History, error) { - row := q.db.QueryRow(ctx, getHistoryByID, id) - var i History - err := row.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ) - return i, err -} - -const listHistoryByBotSession = `-- name: ListHistoryByBotSession :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 -ORDER BY timestamp DESC -LIMIT $3 -` - -type ListHistoryByBotSessionParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Limit int32 `json:"limit"` -} - -func (q *Queries) ListHistoryByBotSession(ctx context.Context, arg ListHistoryByBotSessionParams) ([]History, error) { - rows, err := q.db.Query(ctx, listHistoryByBotSession, arg.BotID, arg.SessionID, arg.Limit) - if err != nil { - return nil, err - } - defer rows.Close() - var items []History - for rows.Next() { - var i History - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listHistoryByBotSessionSince = `-- name: ListHistoryByBotSessionSince :many -SELECT id, bot_id, session_id, messages, metadata, skills, timestamp -FROM history -WHERE bot_id = $1 AND session_id = $2 AND timestamp >= $3 -ORDER BY timestamp ASC -` - -type ListHistoryByBotSessionSinceParams struct { - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Timestamp pgtype.Timestamptz `json:"timestamp"` -} - -func (q *Queries) ListHistoryByBotSessionSince(ctx context.Context, arg ListHistoryByBotSessionSinceParams) ([]History, error) { - rows, err := q.db.Query(ctx, listHistoryByBotSessionSince, arg.BotID, arg.SessionID, arg.Timestamp) - if err != nil { - return nil, err - } - defer rows.Close() - var items []History - for rows.Next() { - var i History - if err := rows.Scan( - &i.ID, - &i.BotID, - &i.SessionID, - &i.Messages, - &i.Metadata, - &i.Skills, - &i.Timestamp, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 0341bf03..3b43891a 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -9,15 +9,21 @@ import ( ) type Bot struct { - ID pgtype.UUID `json:"id"` - OwnerUserID pgtype.UUID `json:"owner_user_id"` - Type string `json:"type"` - DisplayName pgtype.Text `json:"display_name"` - AvatarUrl pgtype.Text `json:"avatar_url"` - IsActive bool `json:"is_active"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` + ID pgtype.UUID `json:"id"` + OwnerUserID pgtype.UUID `json:"owner_user_id"` + Type string `json:"type"` + DisplayName pgtype.Text `json:"display_name"` + AvatarUrl pgtype.Text `json:"avatar_url"` + IsActive bool `json:"is_active"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + MemoryModelID pgtype.UUID `json:"memory_model_id"` + EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } type BotChannelConfig struct { @@ -42,13 +48,6 @@ type BotMember struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } -type BotModelConfig struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` - MemoryModelID pgtype.UUID `json:"memory_model_id"` -} - type BotPreauthKey struct { ID pgtype.UUID `json:"id"` BotID pgtype.UUID `json:"bot_id"` @@ -59,51 +58,89 @@ type BotPreauthKey struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } -type BotSetting struct { - BotID pgtype.UUID `json:"bot_id"` - MaxContextLoadTime int32 `json:"max_context_load_time"` - Language string `json:"language"` - AllowGuest bool `json:"allow_guest"` +type ChannelIdentity struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + Channel string `json:"channel"` + ChannelSubjectID string `json:"channel_subject_id"` + DisplayName pgtype.Text `json:"display_name"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } -type ChannelSession struct { - SessionID string `json:"session_id"` +type ChannelIdentityBindCode struct { + ID pgtype.UUID `json:"id"` + Token string `json:"token"` + IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` + Platform pgtype.Text `json:"platform"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` + UsedAt pgtype.Timestamptz `json:"used_at"` + UsedByChannelIdentityID pgtype.UUID `json:"used_by_channel_identity_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +type Chat struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Kind string `json:"kind"` + ParentChatID pgtype.UUID `json:"parent_chat_id"` + Title pgtype.Text `json:"title"` + CreatedByUserID pgtype.UUID `json:"created_by_user_id"` + Metadata []byte `json:"metadata"` + EnableChatMemory bool `json:"enable_chat_memory"` + EnablePrivateMemory bool `json:"enable_private_memory"` + EnablePublicMemory bool `json:"enable_public_memory"` + ModelID pgtype.Text `json:"model_id"` + SettingsMetadata []byte `json:"settings_metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +type ChatChannelIdentityPresence struct { + ChatID pgtype.UUID `json:"chat_id"` + ChannelIdentityID pgtype.UUID `json:"channel_identity_id"` + FirstSeenAt pgtype.Timestamptz `json:"first_seen_at"` + LastSeenAt pgtype.Timestamptz `json:"last_seen_at"` + MessageCount int64 `json:"message_count"` +} + +type ChatMessage struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` + BotID pgtype.UUID `json:"bot_id"` + RouteID pgtype.UUID `json:"route_id"` + SenderChannelIdentityID pgtype.UUID `json:"sender_channel_identity_id"` + SenderUserID pgtype.UUID `json:"sender_user_id"` + Platform pgtype.Text `json:"platform"` + ExternalMessageID pgtype.Text `json:"external_message_id"` + Role string `json:"role"` + Content []byte `json:"content"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + +type ChatParticipant struct { + ChatID pgtype.UUID `json:"chat_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt pgtype.Timestamptz `json:"joined_at"` +} + +type ChatRoute struct { + ID pgtype.UUID `json:"id"` + ChatID pgtype.UUID `json:"chat_id"` BotID pgtype.UUID `json:"bot_id"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - UserID pgtype.UUID `json:"user_id"` - ContactID pgtype.UUID `json:"contact_id"` Platform string `json:"platform"` - ReplyTarget pgtype.Text `json:"reply_target"` + ChannelConfigID pgtype.UUID `json:"channel_config_id"` + ConversationID string `json:"conversation_id"` ThreadID pgtype.Text `json:"thread_id"` + ReplyTarget pgtype.Text `json:"reply_target"` Metadata []byte `json:"metadata"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } -type Contact struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - UserID pgtype.UUID `json:"user_id"` - DisplayName pgtype.Text `json:"display_name"` - Alias pgtype.Text `json:"alias"` - Tags []string `json:"tags"` - Status string `json:"status"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -type ContactChannel struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - ExternalID string `json:"external_id"` - Metadata []byte `json:"metadata"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - type Container struct { ID pgtype.UUID `json:"id"` BotID pgtype.UUID `json:"bot_id"` @@ -129,27 +166,6 @@ type ContainerVersion struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } -type Conversation struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - ChannelType string `json:"channel_type"` - ChatID pgtype.Text `json:"chat_id"` - SenderID pgtype.Text `json:"sender_id"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -type History struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - SessionID string `json:"session_id"` - Messages []byte `json:"messages"` - Metadata []byte `json:"metadata"` - Skills []string `json:"skills"` - Timestamp pgtype.Timestamptz `json:"timestamp"` -} - type LifecycleEvent struct { ID string `json:"id"` ContainerID string `json:"container_id"` @@ -240,34 +256,31 @@ type Subagent struct { } type User struct { - ID pgtype.UUID `json:"id"` - Username string `json:"username"` - Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` - Role string `json:"role"` - DisplayName pgtype.Text `json:"display_name"` - AvatarUrl pgtype.Text `json:"avatar_url"` - IsActive bool `json:"is_active"` - DataRoot pgtype.Text `json:"data_root"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` - LastLoginAt pgtype.Timestamptz `json:"last_login_at"` + ID pgtype.UUID `json:"id"` + Username pgtype.Text `json:"username"` + Email pgtype.Text `json:"email"` + PasswordHash pgtype.Text `json:"password_hash"` + Role string `json:"role"` + DisplayName pgtype.Text `json:"display_name"` + AvatarUrl pgtype.Text `json:"avatar_url"` + DataRoot pgtype.Text `json:"data_root"` + LastLoginAt pgtype.Timestamptz `json:"last_login_at"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + IsActive bool `json:"is_active"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } type UserChannelBinding struct { - ID pgtype.UUID `json:"id"` - UserID pgtype.UUID `json:"user_id"` - ChannelType string `json:"channel_type"` - Config []byte `json:"config"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -type UserSetting struct { - UserID pgtype.UUID `json:"user_id"` - ChatModelID pgtype.Text `json:"chat_model_id"` - MemoryModelID pgtype.Text `json:"memory_model_id"` - EmbeddingModelID pgtype.Text `json:"embedding_model_id"` - MaxContextLoadTime int32 `json:"max_context_load_time"` - Language string `json:"language"` + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + Platform string `json:"platform"` + Config []byte `json:"config"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } diff --git a/internal/db/sqlc/settings.sql.go b/internal/db/sqlc/settings.sql.go index 25d96c07..f931844d 100644 --- a/internal/db/sqlc/settings.sql.go +++ b/internal/db/sqlc/settings.sql.go @@ -12,173 +12,67 @@ import ( ) const deleteSettingsByBotID = `-- name: DeleteSettingsByBotID :exec -DELETE FROM bot_settings -WHERE bot_id = $1 +UPDATE bots +SET max_context_load_time = 1440, + language = 'auto', + allow_guest = false, + updated_at = now() +WHERE id = $1 ` -func (q *Queries) DeleteSettingsByBotID(ctx context.Context, botID pgtype.UUID) error { - _, err := q.db.Exec(ctx, deleteSettingsByBotID, botID) +func (q *Queries) DeleteSettingsByBotID(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteSettingsByBotID, id) return err } -const getBotModelConfigByBotID = `-- name: GetBotModelConfigByBotID :one +const getSettingsByBotID = `-- name: GetSettingsByBotID :one SELECT - bot_model_configs.bot_id, + bots.id AS bot_id, + bots.max_context_load_time, + bots.language, + bots.allow_guest, chat_models.model_id AS chat_model_id, memory_models.model_id AS memory_model_id, embedding_models.model_id AS embedding_model_id -FROM bot_model_configs -LEFT JOIN models AS chat_models ON chat_models.id = bot_model_configs.chat_model_id -LEFT JOIN models AS memory_models ON memory_models.id = bot_model_configs.memory_model_id -LEFT JOIN models AS embedding_models ON embedding_models.id = bot_model_configs.embedding_model_id -WHERE bot_model_configs.bot_id = $1 +FROM bots +LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = bots.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = bots.embedding_model_id +WHERE bots.id = $1 ` -type GetBotModelConfigByBotIDRow struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.Text `json:"chat_model_id"` - MemoryModelID pgtype.Text `json:"memory_model_id"` - EmbeddingModelID pgtype.Text `json:"embedding_model_id"` +type GetSettingsByBotIDRow struct { + BotID pgtype.UUID `json:"bot_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` } -func (q *Queries) GetBotModelConfigByBotID(ctx context.Context, botID pgtype.UUID) (GetBotModelConfigByBotIDRow, error) { - row := q.db.QueryRow(ctx, getBotModelConfigByBotID, botID) - var i GetBotModelConfigByBotIDRow - err := row.Scan( - &i.BotID, - &i.ChatModelID, - &i.MemoryModelID, - &i.EmbeddingModelID, - ) - return i, err -} - -const getSettingsByBotID = `-- name: GetSettingsByBotID :one -SELECT bot_id, max_context_load_time, language, allow_guest -FROM bot_settings -WHERE bot_id = $1 -` - -func (q *Queries) GetSettingsByBotID(ctx context.Context, botID pgtype.UUID) (BotSetting, error) { - row := q.db.QueryRow(ctx, getSettingsByBotID, botID) - var i BotSetting +func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSettingsByBotIDRow, error) { + row := q.db.QueryRow(ctx, getSettingsByBotID, id) + var i GetSettingsByBotIDRow err := row.Scan( &i.BotID, &i.MaxContextLoadTime, &i.Language, &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, ) return i, err } const getSettingsByUserID = `-- name: GetSettingsByUserID :one -SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language -FROM user_settings -WHERE user_id = $1 +SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language +FROM users +WHERE id = $1 ` -func (q *Queries) GetSettingsByUserID(ctx context.Context, userID pgtype.UUID) (UserSetting, error) { - row := q.db.QueryRow(ctx, getSettingsByUserID, userID) - var i UserSetting - err := row.Scan( - &i.UserID, - &i.ChatModelID, - &i.MemoryModelID, - &i.EmbeddingModelID, - &i.MaxContextLoadTime, - &i.Language, - ) - return i, err -} - -const upsertBotModelConfig = `-- name: UpsertBotModelConfig :one -INSERT INTO bot_model_configs (bot_id, chat_model_id, memory_model_id, embedding_model_id) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - chat_model_id = COALESCE(EXCLUDED.chat_model_id, bot_model_configs.chat_model_id), - memory_model_id = COALESCE(EXCLUDED.memory_model_id, bot_model_configs.memory_model_id), - embedding_model_id = COALESCE(EXCLUDED.embedding_model_id, bot_model_configs.embedding_model_id) -RETURNING bot_id, chat_model_id, memory_model_id, embedding_model_id -` - -type UpsertBotModelConfigParams struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - MemoryModelID pgtype.UUID `json:"memory_model_id"` - EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` -} - -type UpsertBotModelConfigRow struct { - BotID pgtype.UUID `json:"bot_id"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - MemoryModelID pgtype.UUID `json:"memory_model_id"` - EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` -} - -func (q *Queries) UpsertBotModelConfig(ctx context.Context, arg UpsertBotModelConfigParams) (UpsertBotModelConfigRow, error) { - row := q.db.QueryRow(ctx, upsertBotModelConfig, - arg.BotID, - arg.ChatModelID, - arg.MemoryModelID, - arg.EmbeddingModelID, - ) - var i UpsertBotModelConfigRow - err := row.Scan( - &i.BotID, - &i.ChatModelID, - &i.MemoryModelID, - &i.EmbeddingModelID, - ) - return i, err -} - -const upsertBotSettings = `-- name: UpsertBotSettings :one -INSERT INTO bot_settings (bot_id, max_context_load_time, language, allow_guest) -VALUES ($1, $2, $3, $4) -ON CONFLICT (bot_id) DO UPDATE SET - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language, - allow_guest = EXCLUDED.allow_guest -RETURNING bot_id, max_context_load_time, language, allow_guest -` - -type UpsertBotSettingsParams struct { - BotID pgtype.UUID `json:"bot_id"` - MaxContextLoadTime int32 `json:"max_context_load_time"` - Language string `json:"language"` - AllowGuest bool `json:"allow_guest"` -} - -func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsParams) (BotSetting, error) { - row := q.db.QueryRow(ctx, upsertBotSettings, - arg.BotID, - arg.MaxContextLoadTime, - arg.Language, - arg.AllowGuest, - ) - var i BotSetting - err := row.Scan( - &i.BotID, - &i.MaxContextLoadTime, - &i.Language, - &i.AllowGuest, - ) - return i, err -} - -const upsertUserSettings = `-- name: UpsertUserSettings :one -INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language) -VALUES ($1, $2, $3, $4, $5, $6) -ON CONFLICT (user_id) DO UPDATE SET - chat_model_id = EXCLUDED.chat_model_id, - memory_model_id = EXCLUDED.memory_model_id, - embedding_model_id = EXCLUDED.embedding_model_id, - max_context_load_time = EXCLUDED.max_context_load_time, - language = EXCLUDED.language -RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language -` - -type UpsertUserSettingsParams struct { +type GetSettingsByUserIDRow struct { UserID pgtype.UUID `json:"user_id"` ChatModelID pgtype.Text `json:"chat_model_id"` MemoryModelID pgtype.Text `json:"memory_model_id"` @@ -187,16 +81,130 @@ type UpsertUserSettingsParams struct { Language string `json:"language"` } -func (q *Queries) UpsertUserSettings(ctx context.Context, arg UpsertUserSettingsParams) (UserSetting, error) { - row := q.db.QueryRow(ctx, upsertUserSettings, - arg.UserID, - arg.ChatModelID, - arg.MemoryModelID, - arg.EmbeddingModelID, - arg.MaxContextLoadTime, - arg.Language, - ) - var i UserSetting +func (q *Queries) GetSettingsByUserID(ctx context.Context, id pgtype.UUID) (GetSettingsByUserIDRow, error) { + row := q.db.QueryRow(ctx, getSettingsByUserID, id) + var i GetSettingsByUserIDRow + err := row.Scan( + &i.UserID, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + ) + return i, err +} + +const upsertBotSettings = `-- name: UpsertBotSettings :one +WITH updated AS ( + UPDATE bots + SET max_context_load_time = $1, + language = $2, + allow_guest = $3, + chat_model_id = COALESCE($4::uuid, bots.chat_model_id), + memory_model_id = COALESCE($5::uuid, bots.memory_model_id), + embedding_model_id = COALESCE($6::uuid, bots.embedding_model_id), + updated_at = now() + WHERE bots.id = $7 + RETURNING bots.id, bots.max_context_load_time, bots.language, bots.allow_guest, bots.chat_model_id, bots.memory_model_id, bots.embedding_model_id +) +SELECT + updated.id AS bot_id, + updated.max_context_load_time, + updated.language, + updated.allow_guest, + chat_models.model_id AS chat_model_id, + memory_models.model_id AS memory_model_id, + embedding_models.model_id AS embedding_model_id +FROM updated +LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id +LEFT JOIN models AS memory_models ON memory_models.id = updated.memory_model_id +LEFT JOIN models AS embedding_models ON embedding_models.id = updated.embedding_model_id +` + +type UpsertBotSettingsParams struct { + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + MemoryModelID pgtype.UUID `json:"memory_model_id"` + EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` + ID pgtype.UUID `json:"id"` +} + +type UpsertBotSettingsRow struct { + BotID pgtype.UUID `json:"bot_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` + AllowGuest bool `json:"allow_guest"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` +} + +func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsParams) (UpsertBotSettingsRow, error) { + row := q.db.QueryRow(ctx, upsertBotSettings, + arg.MaxContextLoadTime, + arg.Language, + arg.AllowGuest, + arg.ChatModelID, + arg.MemoryModelID, + arg.EmbeddingModelID, + arg.ID, + ) + var i UpsertBotSettingsRow + err := row.Scan( + &i.BotID, + &i.MaxContextLoadTime, + &i.Language, + &i.AllowGuest, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + ) + return i, err +} + +const upsertUserSettings = `-- name: UpsertUserSettings :one +UPDATE users +SET chat_model_id = $2, + memory_model_id = $3, + embedding_model_id = $4, + max_context_load_time = $5, + language = $6, + updated_at = now() +WHERE id = $1 +RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language +` + +type UpsertUserSettingsParams struct { + ID pgtype.UUID `json:"id"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` +} + +type UpsertUserSettingsRow struct { + UserID pgtype.UUID `json:"user_id"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + MaxContextLoadTime int32 `json:"max_context_load_time"` + Language string `json:"language"` +} + +func (q *Queries) UpsertUserSettings(ctx context.Context, arg UpsertUserSettingsParams) (UpsertUserSettingsRow, error) { + row := q.db.QueryRow(ctx, upsertUserSettings, + arg.ID, + arg.ChatModelID, + arg.MemoryModelID, + arg.EmbeddingModelID, + arg.MaxContextLoadTime, + arg.Language, + ) + var i UpsertUserSettingsRow err := row.Scan( &i.UserID, &i.ChatModelID, diff --git a/internal/db/sqlc/users.sql.go b/internal/db/sqlc/users.sql.go index d421dbd3..60c2aa4a 100644 --- a/internal/db/sqlc/users.sql.go +++ b/internal/db/sqlc/users.sql.go @@ -11,45 +11,49 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -const countUsers = `-- name: CountUsers :one -SELECT COUNT(*)::bigint AS count FROM users +const countAccounts = `-- name: CountAccounts :one +SELECT COUNT(*)::bigint AS count +FROM users +WHERE username IS NOT NULL + AND password_hash IS NOT NULL ` -func (q *Queries) CountUsers(ctx context.Context) (int64, error) { - row := q.db.QueryRow(ctx, countUsers) +func (q *Queries) CountAccounts(ctx context.Context) (int64, error) { + row := q.db.QueryRow(ctx, countAccounts) var count int64 err := row.Scan(&count) return count, err } -const createUser = `-- name: CreateUser :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - $1, - $2, - $3, - $4::user_role, - $5, - $6, - $7, - $8 -) -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +const createAccount = `-- name: CreateAccount :one +UPDATE users +SET username = $1, + email = $2, + password_hash = $3, + role = $4::user_role, + display_name = $5, + avatar_url = $6, + is_active = $7, + data_root = $8, + updated_at = now() +WHERE id = $9 +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type CreateUserParams struct { - Username string `json:"username"` +type CreateAccountParams struct { + Username pgtype.Text `json:"username"` Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` + PasswordHash pgtype.Text `json:"password_hash"` Role string `json:"role"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` DataRoot pgtype.Text `json:"data_root"` + UserID pgtype.UUID `json:"user_id"` } -func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { - row := q.db.QueryRow(ctx, createUser, +func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (User, error) { + row := q.db.QueryRow(ctx, createAccount, arg.Username, arg.Email, arg.PasswordHash, @@ -58,6 +62,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e arg.AvatarUrl, arg.IsActive, arg.DataRoot, + arg.UserID, ) var i User err := row.Scan( @@ -68,55 +73,34 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const createUserWithID = `-- name: CreateUserWithID :one -INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root) -VALUES ( - $1, - $2, - $3, - $4, - $5::user_role, - $6, - $7, - $8, - $9 -) -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +const createUser = `-- name: CreateUser :one +INSERT INTO users (is_active, metadata) +VALUES ($1, $2) +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type CreateUserWithIDParams struct { - ID pgtype.UUID `json:"id"` - Username string `json:"username"` - Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` - Role string `json:"role"` - DisplayName pgtype.Text `json:"display_name"` - AvatarUrl pgtype.Text `json:"avatar_url"` - IsActive bool `json:"is_active"` - DataRoot pgtype.Text `json:"data_root"` +type CreateUserParams struct { + IsActive bool `json:"is_active"` + Metadata []byte `json:"metadata"` } -func (q *Queries) CreateUserWithID(ctx context.Context, arg CreateUserWithIDParams) (User, error) { - row := q.db.QueryRow(ctx, createUserWithID, - arg.ID, - arg.Username, - arg.Email, - arg.PasswordHash, - arg.Role, - arg.DisplayName, - arg.AvatarUrl, - arg.IsActive, - arg.DataRoot, - ) +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { + row := q.db.QueryRow(ctx, createUser, arg.IsActive, arg.Metadata) var i User err := row.Scan( &i.ID, @@ -126,17 +110,115 @@ func (q *Queries) CreateUserWithID(ctx context.Context, arg CreateUserWithIDPara &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, + ) + return i, err +} + +const getAccountByIdentity = `-- name: GetAccountByIdentity :one +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 OR email = $1 +` + +func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text) (User, error) { + row := q.db.QueryRow(ctx, getAccountByIdentity, identity) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getAccountByUserID = `-- name: GetAccountByUserID :one +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE id = $1 +` + +func (q *Queries) GetAccountByUserID(ctx context.Context, userID pgtype.UUID) (User, error) { + row := q.db.QueryRow(ctx, getAccountByUserID, userID) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getAccountByUsername = `-- name: GetAccountByUsername :one +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 +` + +func (q *Queries) GetAccountByUsername(ctx context.Context, username pgtype.Text) (User, error) { + row := q.db.QueryRow(ctx, getAccountByUsername, username) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, ) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users WHERE id = $1 +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at +FROM users +WHERE id = $1 ` func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) { @@ -150,70 +232,29 @@ func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const getUserByIdentity = `-- name: GetUserByIdentity :one -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users WHERE username = $1 OR email = $1 -` - -func (q *Queries) GetUserByIdentity(ctx context.Context, identity string) (User, error) { - row := q.db.QueryRow(ctx, getUserByIdentity, identity) - var i User - err := row.Scan( - &i.ID, - &i.Username, - &i.Email, - &i.PasswordHash, - &i.Role, - &i.DisplayName, - &i.AvatarUrl, - &i.IsActive, - &i.DataRoot, - &i.CreatedAt, - &i.UpdatedAt, - &i.LastLoginAt, - ) - return i, err -} - -const getUserByUsername = `-- name: GetUserByUsername :one -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users WHERE username = $1 -` - -func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) { - row := q.db.QueryRow(ctx, getUserByUsername, username) - var i User - err := row.Scan( - &i.ID, - &i.Username, - &i.Email, - &i.PasswordHash, - &i.Role, - &i.DisplayName, - &i.AvatarUrl, - &i.IsActive, - &i.DataRoot, - &i.CreatedAt, - &i.UpdatedAt, - &i.LastLoginAt, - ) - return i, err -} - -const listUsers = `-- name: ListUsers :many -SELECT id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at FROM users +const listAccounts = `-- name: ListAccounts :many +SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users +WHERE username IS NOT NULL ORDER BY created_at DESC ` -func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { - rows, err := q.db.Query(ctx, listUsers) +func (q *Queries) ListAccounts(ctx context.Context) ([]User, error) { + rows, err := q.db.Query(ctx, listAccounts) if err != nil { return nil, err } @@ -229,11 +270,17 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ); err != nil { return nil, err } @@ -245,7 +292,7 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { return items, nil } -const updateUserAdmin = `-- name: UpdateUserAdmin :one +const updateAccountAdmin = `-- name: UpdateAccountAdmin :one UPDATE users SET role = $1::user_role, display_name = $2, @@ -253,24 +300,24 @@ SET role = $1::user_role, is_active = $4, updated_at = now() WHERE id = $5 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpdateUserAdminParams struct { +type UpdateAccountAdminParams struct { Role string `json:"role"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` - ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` } -func (q *Queries) UpdateUserAdmin(ctx context.Context, arg UpdateUserAdminParams) (User, error) { - row := q.db.QueryRow(ctx, updateUserAdmin, +func (q *Queries) UpdateAccountAdmin(ctx context.Context, arg UpdateAccountAdminParams) (User, error) { + row := q.db.QueryRow(ctx, updateAccountAdmin, arg.Role, arg.DisplayName, arg.AvatarUrl, arg.IsActive, - arg.ID, + arg.UserID, ) var i User err := row.Scan( @@ -281,25 +328,31 @@ func (q *Queries) UpdateUserAdmin(ctx context.Context, arg UpdateUserAdminParams &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const updateUserLastLogin = `-- name: UpdateUserLastLogin :one +const updateAccountLastLogin = `-- name: UpdateAccountLastLogin :one UPDATE users SET last_login_at = now(), updated_at = now() WHERE id = $1 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -func (q *Queries) UpdateUserLastLogin(ctx context.Context, id pgtype.UUID) (User, error) { - row := q.db.QueryRow(ctx, updateUserLastLogin, id) +func (q *Queries) UpdateAccountLastLogin(ctx context.Context, id pgtype.UUID) (User, error) { + row := q.db.QueryRow(ctx, updateAccountLastLogin, id) var i User err := row.Scan( &i.ID, @@ -309,30 +362,36 @@ func (q *Queries) UpdateUserLastLogin(ctx context.Context, id pgtype.UUID) (User &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const updateUserPassword = `-- name: UpdateUserPassword :one +const updateAccountPassword = `-- name: UpdateAccountPassword :one UPDATE users SET password_hash = $2, updated_at = now() WHERE id = $1 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpdateUserPasswordParams struct { +type UpdateAccountPasswordParams struct { ID pgtype.UUID `json:"id"` - PasswordHash string `json:"password_hash"` + PasswordHash pgtype.Text `json:"password_hash"` } -func (q *Queries) UpdateUserPassword(ctx context.Context, arg UpdateUserPasswordParams) (User, error) { - row := q.db.QueryRow(ctx, updateUserPassword, arg.ID, arg.PasswordHash) +func (q *Queries) UpdateAccountPassword(ctx context.Context, arg UpdateAccountPasswordParams) (User, error) { + row := q.db.QueryRow(ctx, updateAccountPassword, arg.ID, arg.PasswordHash) var i User err := row.Scan( &i.ID, @@ -342,34 +401,40 @@ func (q *Queries) UpdateUserPassword(ctx context.Context, arg UpdateUserPassword &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const updateUserProfile = `-- name: UpdateUserProfile :one +const updateAccountProfile = `-- name: UpdateAccountProfile :one UPDATE users SET display_name = $2, avatar_url = $3, is_active = $4, updated_at = now() WHERE id = $1 -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpdateUserProfileParams struct { +type UpdateAccountProfileParams struct { ID pgtype.UUID `json:"id"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` IsActive bool `json:"is_active"` } -func (q *Queries) UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error) { - row := q.db.QueryRow(ctx, updateUserProfile, +func (q *Queries) UpdateAccountProfile(ctx context.Context, arg UpdateAccountProfileParams) (User, error) { + row := q.db.QueryRow(ctx, updateAccountProfile, arg.ID, arg.DisplayName, arg.AvatarUrl, @@ -384,26 +449,73 @@ func (q *Queries) UpdateUserProfile(ctx context.Context, arg UpdateUserProfilePa &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } -const upsertUserByUsername = `-- name: UpsertUserByUsername :one -INSERT INTO users (username, email, password_hash, role, display_name, avatar_url, is_active, data_root) +const updateUserStatus = `-- name: UpdateUserStatus :one +UPDATE users +SET is_active = $2, + updated_at = now() +WHERE id = $1 +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at +` + +type UpdateUserStatusParams struct { + ID pgtype.UUID `json:"id"` + IsActive bool `json:"is_active"` +} + +func (q *Queries) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) { + row := q.db.QueryRow(ctx, updateUserStatus, arg.ID, arg.IsActive) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.Email, + &i.PasswordHash, + &i.Role, + &i.DisplayName, + &i.AvatarUrl, + &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertAccountByUsername = `-- name: UpsertAccountByUsername :one +INSERT INTO users (id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, metadata) VALUES ( $1, $2, $3, - $4::user_role, - $5, + $4, + $5::user_role, $6, $7, - $8 + $8, + $9, + '{}'::jsonb ) ON CONFLICT (username) DO UPDATE SET email = EXCLUDED.email, @@ -414,13 +526,14 @@ ON CONFLICT (username) DO UPDATE SET is_active = EXCLUDED.is_active, data_root = EXCLUDED.data_root, updated_at = now() -RETURNING id, username, email, password_hash, role, display_name, avatar_url, is_active, data_root, created_at, updated_at, last_login_at +RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at ` -type UpsertUserByUsernameParams struct { - Username string `json:"username"` +type UpsertAccountByUsernameParams struct { + UserID pgtype.UUID `json:"user_id"` + Username pgtype.Text `json:"username"` Email pgtype.Text `json:"email"` - PasswordHash string `json:"password_hash"` + PasswordHash pgtype.Text `json:"password_hash"` Role string `json:"role"` DisplayName pgtype.Text `json:"display_name"` AvatarUrl pgtype.Text `json:"avatar_url"` @@ -428,8 +541,9 @@ type UpsertUserByUsernameParams struct { DataRoot pgtype.Text `json:"data_root"` } -func (q *Queries) UpsertUserByUsername(ctx context.Context, arg UpsertUserByUsernameParams) (User, error) { - row := q.db.QueryRow(ctx, upsertUserByUsername, +func (q *Queries) UpsertAccountByUsername(ctx context.Context, arg UpsertAccountByUsernameParams) (User, error) { + row := q.db.QueryRow(ctx, upsertAccountByUsername, + arg.UserID, arg.Username, arg.Email, arg.PasswordHash, @@ -448,11 +562,17 @@ func (q *Queries) UpsertUserByUsername(ctx context.Context, arg UpsertUserByUser &i.Role, &i.DisplayName, &i.AvatarUrl, - &i.IsActive, &i.DataRoot, + &i.LastLoginAt, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + &i.IsActive, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, - &i.LastLoginAt, ) return i, err } diff --git a/internal/directory/service.go b/internal/directory/service.go deleted file mode 100644 index 158039e1..00000000 --- a/internal/directory/service.go +++ /dev/null @@ -1,226 +0,0 @@ -package directory - -import ( - "context" - "errors" - "fmt" - "log/slog" - "strings" - - "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/contacts" -) - -var ( - ErrNotFound = errors.New("directory entry not found") - ErrAmbiguous = errors.New("directory entry ambiguous") - ErrUnsupported = errors.New("directory operation unsupported") -) - -type ContactReader interface { - Search(ctx context.Context, botID, query string) ([]contacts.Contact, error) - ListByBot(ctx context.Context, botID string) ([]contacts.Contact, error) - ListChannelsByContact(ctx context.Context, contactID string) ([]contacts.ContactChannel, error) -} - -type ChannelSessionStore interface { - ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]channel.ChannelSession, error) -} - -type LocalService struct { - contacts ContactReader - sessions ChannelSessionStore - logger *slog.Logger -} - -func NewLocalService(log *slog.Logger, contacts ContactReader, sessions ChannelSessionStore) *LocalService { - if log == nil { - log = slog.Default() - } - return &LocalService{ - contacts: contacts, - sessions: sessions, - logger: log.With(slog.String("service", "directory")), - } -} - -func (s *LocalService) ListPeers(ctx context.Context, botID, platform, query string, limit int) ([]channel.DirectoryEntry, error) { - if s.contacts == nil { - return nil, fmt.Errorf("contacts service not configured") - } - trimmed := strings.TrimSpace(query) - var items []contacts.Contact - var err error - if trimmed == "" { - items, err = s.contacts.ListByBot(ctx, botID) - } else { - items, err = s.contacts.Search(ctx, botID, trimmed) - } - if err != nil { - return nil, err - } - results := make([]channel.DirectoryEntry, 0, len(items)) - for _, contact := range items { - channels, err := s.contacts.ListChannelsByContact(ctx, contact.ID) - if err != nil { - if s.logger != nil { - s.logger.Warn("list contact channels failed", slog.String("contact_id", contact.ID), slog.Any("error", err)) - } - continue - } - for _, ch := range channels { - if platform != "" && ch.Platform != platform { - continue - } - entry := channel.DirectoryEntry{ - Kind: channel.DirectoryEntryUser, - ID: strings.TrimSpace(ch.ExternalID), - Name: chooseContactName(contact, ch), - Handle: strings.TrimSpace(contact.Alias), - Metadata: map[string]any{}, - } - if entry.ID == "" { - continue - } - entry.Metadata["contact_id"] = contact.ID - if contact.UserID != "" { - entry.Metadata["user_id"] = contact.UserID - } - entry.Metadata["platform"] = ch.Platform - results = append(results, entry) - if limit > 0 && len(results) >= limit { - return results, nil - } - } - } - return results, nil -} - -func (s *LocalService) ListGroups(ctx context.Context, botID, platform, query string, limit int) ([]channel.DirectoryEntry, error) { - if s.sessions == nil { - return nil, fmt.Errorf("channel session store not configured") - } - platform = strings.TrimSpace(platform) - if platform == "" { - return nil, fmt.Errorf("platform is required") - } - sessions, err := s.sessions.ListSessionsByBotPlatform(ctx, botID, platform) - if err != nil { - return nil, err - } - trimmed := strings.TrimSpace(query) - results := make([]channel.DirectoryEntry, 0, len(sessions)) - for _, session := range sessions { - if !isGroupSession(session) { - continue - } - name := channel.ReadString(session.Metadata, "conversation_name", "name") - entryID := strings.TrimSpace(session.ReplyTarget) - if entryID == "" { - entryID = strings.TrimSpace(session.SessionID) - } - if entryID == "" { - continue - } - if trimmed != "" && !matchesQuery(trimmed, entryID, name) { - continue - } - results = append(results, channel.DirectoryEntry{ - Kind: channel.DirectoryEntryGroup, - ID: entryID, - Name: strings.TrimSpace(name), - Metadata: session.Metadata, - }) - if limit > 0 && len(results) >= limit { - return results, nil - } - } - return results, nil -} - -func (s *LocalService) ListGroupMembers(ctx context.Context, botID, platform, groupID string, limit int) ([]channel.DirectoryEntry, error) { - return nil, ErrUnsupported -} - -func (s *LocalService) ResolveTarget(ctx context.Context, botID, platform, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { - trimmed := strings.TrimSpace(input) - if trimmed == "" { - return channel.DirectoryEntry{}, ErrNotFound - } - switch kind { - case channel.DirectoryEntryGroup: - items, err := s.ListGroups(ctx, botID, platform, trimmed, 5) - if err != nil { - return channel.DirectoryEntry{}, err - } - return pickSingleMatch(items, trimmed) - default: - items, err := s.ListPeers(ctx, botID, platform, trimmed, 5) - if err != nil { - return channel.DirectoryEntry{}, err - } - return pickSingleMatch(items, trimmed) - } -} - -func pickSingleMatch(items []channel.DirectoryEntry, input string) (channel.DirectoryEntry, error) { - if len(items) == 0 { - return channel.DirectoryEntry{}, ErrNotFound - } - if len(items) == 1 { - return items[0], nil - } - lower := strings.ToLower(strings.TrimSpace(input)) - var exact *channel.DirectoryEntry - for i := range items { - if strings.ToLower(strings.TrimSpace(items[i].ID)) == lower { - exact = &items[i] - break - } - if strings.ToLower(strings.TrimSpace(items[i].Name)) == lower { - exact = &items[i] - break - } - } - if exact != nil { - return *exact, nil - } - return channel.DirectoryEntry{}, ErrAmbiguous -} - -func chooseContactName(contact contacts.Contact, ch contacts.ContactChannel) string { - if strings.TrimSpace(contact.DisplayName) != "" { - return strings.TrimSpace(contact.DisplayName) - } - if strings.TrimSpace(contact.Alias) != "" { - return strings.TrimSpace(contact.Alias) - } - if strings.TrimSpace(ch.ExternalID) != "" { - return strings.TrimSpace(ch.ExternalID) - } - return "" -} - -func isGroupSession(session channel.ChannelSession) bool { - value := strings.ToLower(strings.TrimSpace(channel.ReadString(session.Metadata, "conversation_type", "chat_type", "type"))) - if value == "" { - return false - } - if strings.Contains(value, "group") { - return true - } - return false -} - -func matchesQuery(query string, fields ...string) bool { - needle := strings.ToLower(strings.TrimSpace(query)) - if needle == "" { - return true - } - for _, field := range fields { - if strings.Contains(strings.ToLower(strings.TrimSpace(field)), needle) { - return true - } - } - return false -} diff --git a/internal/embeddings/dashscope.go b/internal/embeddings/dashscope.go index c15fc797..f16dc424 100644 --- a/internal/embeddings/dashscope.go +++ b/internal/embeddings/dashscope.go @@ -33,7 +33,7 @@ type DashScopeUsage struct { } type dashScopeRequest struct { - Model string `json:"model"` + Model string `json:"model"` Input dashScopeRequestInput `json:"input"` } diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go index 48d9d98c..eb8c123b 100644 --- a/internal/embeddings/resolver.go +++ b/internal/embeddings/resolver.go @@ -26,12 +26,12 @@ const ( ) type Request struct { - Type string - Provider string - Model string - Dimensions int - Input Input - UserID string + Type string + Provider string + Model string + Dimensions int + Input Input + ChannelIdentityID string } type Input struct { @@ -180,8 +180,8 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model } // If no model specified and no provider specified, try to get per-user embedding model. - if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.UserID) != "" { - modelID, err := r.loadUserEmbeddingModelID(ctx, req.UserID) + if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.ChannelIdentityID) != "" { + modelID, err := r.loadChannelIdentityEmbeddingModelID(ctx, req.ChannelIdentityID) if err != nil { return models.GetResponse{}, err } @@ -257,15 +257,15 @@ func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.L return r.queries.GetLlmProviderByID(ctx, pgID) } -func (r *Resolver) loadUserEmbeddingModelID(ctx context.Context, userID string) (string, error) { +func (r *Resolver) loadChannelIdentityEmbeddingModelID(ctx context.Context, channelIdentityID string) (string, error) { if r.queries == nil { return "", nil } - pgUserID, err := parseUUID(userID) + pgChannelIdentityID, err := parseUUID(channelIdentityID) if err != nil { return "", err } - row, err := r.queries.GetSettingsByUserID(ctx, pgUserID) + row, err := r.queries.GetSettingsByUserID(ctx, pgChannelIdentityID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", nil diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index b5f2c439..dfb4503d 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -9,15 +9,15 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/users" ) type AuthHandler struct { - userService *users.Service - jwtSecret string - expiresIn time.Duration - logger *slog.Logger + accountService *accounts.Service + jwtSecret string + expiresIn time.Duration + logger *slog.Logger } type LoginRequest struct { @@ -35,12 +35,12 @@ type LoginResponse struct { Username string `json:"username"` } -func NewAuthHandler(log *slog.Logger, userService *users.Service, jwtSecret string, expiresIn time.Duration) *AuthHandler { +func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecret string, expiresIn time.Duration) *AuthHandler { return &AuthHandler{ - userService: userService, - jwtSecret: jwtSecret, - expiresIn: expiresIn, - logger: log.With(slog.String("handler", "auth")), + accountService: accountService, + jwtSecret: jwtSecret, + expiresIn: expiresIn, + logger: log.With(slog.String("handler", "auth")), } } @@ -59,7 +59,7 @@ func (h *AuthHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /auth/login [post] func (h *AuthHandler) Login(c echo.Context) error { - if h.userService == nil { + if h.accountService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "user service not configured") } if strings.TrimSpace(h.jwtSecret) == "" { @@ -78,17 +78,17 @@ func (h *AuthHandler) Login(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "username and password are required") } - user, err := h.userService.Login(c.Request().Context(), req.Username, req.Password) + account, err := h.accountService.Login(c.Request().Context(), req.Username, req.Password) if err != nil { - if errors.Is(err, users.ErrInvalidCredentials) { + if errors.Is(err, accounts.ErrInvalidCredentials) { return echo.NewHTTPError(http.StatusUnauthorized, "invalid credentials") } - if errors.Is(err, users.ErrInactiveUser) { + if errors.Is(err, accounts.ErrInactiveAccount) { return echo.NewHTTPError(http.StatusUnauthorized, "user is inactive") } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - token, expiresAt, err := auth.GenerateToken(user.ID, h.jwtSecret, h.expiresIn) + token, expiresAt, err := auth.GenerateToken(account.ID, h.jwtSecret, h.expiresIn) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -97,9 +97,9 @@ func (h *AuthHandler) Login(c echo.Context) error { AccessToken: token, TokenType: "Bearer", ExpiresAt: expiresAt.Format(time.RFC3339), - UserID: user.ID, - Username: user.Username, - Role: user.Role, - DisplayName: user.DisplayName, + UserID: account.ID, + Username: account.Username, + Role: account.Role, + DisplayName: account.DisplayName, }) } diff --git a/internal/handlers/bind.go b/internal/handlers/bind.go new file mode 100644 index 00000000..00fc9933 --- /dev/null +++ b/internal/handlers/bind.go @@ -0,0 +1,91 @@ +package handlers + +import ( + "errors" + "io" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/identity" +) + +// BindHandler manages channel identity bind code issuance via REST API. +type BindHandler struct { + service *bind.Service + logger *slog.Logger +} + +// NewBindHandler creates a BindHandler. +func NewBindHandler(log *slog.Logger, service *bind.Service) *BindHandler { + if log == nil { + log = slog.Default() + } + return &BindHandler{ + service: service, + logger: log.With(slog.String("handler", "bind")), + } +} + +// Register registers bind code routes. +func (h *BindHandler) Register(e *echo.Echo) { + e.POST("/users/me/bind_codes", h.Issue) +} + +type bindIssueRequest struct { + Platform string `json:"platform,omitempty"` + TTLSeconds int `json:"ttl_seconds,omitempty"` +} + +type bindIssueResponse struct { + Token string `json:"token"` + Platform string `json:"platform,omitempty"` + ExpiresAt time.Time `json:"expires_at"` +} + +// Issue creates a new bind code for the current user. +func (h *BindHandler) Issue(c echo.Context) error { + if h.service == nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, "bind service not available") + } + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + + var req bindIssueRequest + if err := c.Bind(&req); err != nil && !errors.Is(err, io.EOF) { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ttl := 24 * time.Hour + if req.TTLSeconds > 0 { + ttl = time.Duration(req.TTLSeconds) * time.Second + } + + code, err := h.service.Issue(c.Request().Context(), channelIdentityID, strings.TrimSpace(req.Platform), ttl) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, bindIssueResponse{ + Token: code.Token, + Platform: code.Platform, + ExpiresAt: code.ExpiresAt, + }) +} + +func (h *BindHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + if err != nil { + return "", err + } + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { + return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return channelIdentityID, nil +} diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index 3dd64f00..74655dd1 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -23,26 +23,26 @@ func NewChannelHandler(service *channel.Service, registry *channel.Registry) *Ch func (h *ChannelHandler) Register(e *echo.Echo) { group := e.Group("/users/me/channels") - group.GET("/:platform", h.GetUserConfig) - group.PUT("/:platform", h.UpsertUserConfig) + group.GET("/:platform", h.GetChannelIdentityConfig) + group.PUT("/:platform", h.UpsertChannelIdentityConfig) metaGroup := e.Group("/channels") metaGroup.GET("", h.ListChannels) metaGroup.GET("/:platform", h.GetChannel) } -// GetUserConfig godoc +// GetChannelIdentityConfig godoc // @Summary Get channel user config // @Description Get channel binding configuration for current user // @Tags channel // @Param platform path string true "Channel platform" -// @Success 200 {object} channel.ChannelUserBinding +// @Success 200 {object} channel.ChannelIdentityBinding // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me/channels/{platform} [get] -func (h *ChannelHandler) GetUserConfig(c echo.Context) error { - userID, err := h.requireUserID(c) +func (h *ChannelHandler) GetChannelIdentityConfig(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -50,7 +50,7 @@ func (h *ChannelHandler) GetUserConfig(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.GetUserConfig(c.Request().Context(), userID, channelType) + resp, err := h.service.GetChannelIdentityConfig(c.Request().Context(), channelIdentityID, channelType) if err != nil { if strings.Contains(err.Error(), "not found") { return echo.NewHTTPError(http.StatusNotFound, err.Error()) @@ -60,18 +60,18 @@ func (h *ChannelHandler) GetUserConfig(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -// UpsertUserConfig godoc +// UpsertChannelIdentityConfig godoc // @Summary Update channel user config // @Description Update channel binding configuration for current user // @Tags channel // @Param platform path string true "Channel platform" -// @Param payload body channel.UpsertUserConfigRequest true "Channel user config payload" -// @Success 200 {object} channel.ChannelUserBinding +// @Param payload body channel.UpsertChannelIdentityConfigRequest true "Channel user config payload" +// @Success 200 {object} channel.ChannelIdentityBinding // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me/channels/{platform} [put] -func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { - userID, err := h.requireUserID(c) +func (h *ChannelHandler) UpsertChannelIdentityConfig(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -79,14 +79,14 @@ func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - var req channel.UpsertUserConfigRequest + var req channel.UpsertChannelIdentityConfigRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } if req.Config == nil { req.Config = map[string]any{} } - resp, err := h.service.UpsertUserConfig(c.Request().Context(), userID, channelType, req) + resp, err := h.service.UpsertChannelIdentityConfig(c.Request().Context(), channelIdentityID, channelType, req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -160,13 +160,13 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -func (h *ChannelHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go index b43e7f40..b8fc91aa 100644 --- a/internal/handlers/chat.go +++ b/internal/handlers/chat.go @@ -12,48 +12,73 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" ) +// ChatHandler handles chat CRUD, messaging, participants, settings, and routes. type ChatHandler struct { - resolver *chat.Resolver - botService *bots.Service - userService *users.Service - logger *slog.Logger + resolver *chat.Resolver + chatService *chat.Service + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewChatHandler(log *slog.Logger, resolver *chat.Resolver, botService *bots.Service, userService *users.Service) *ChatHandler { +// NewChatHandler creates a ChatHandler. +func NewChatHandler(log *slog.Logger, resolver *chat.Resolver, chatService *chat.Service, botService *bots.Service, accountService *accounts.Service) *ChatHandler { return &ChatHandler{ - resolver: resolver, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "chat")), + resolver: resolver, + chatService: chatService, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "chat")), } } +// Register registers all chat routes. func (h *ChatHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/chat") - group.POST("", h.Chat) - group.POST("/stream", h.StreamChat) + // Chat lifecycle (under bot). + botGroup := e.Group("/bots/:bot_id/chats") + botGroup.POST("", h.CreateChat) + botGroup.GET("", h.ListChats) + + // Chat operations. + chatGroup := e.Group("/chats/:chat_id") + chatGroup.GET("", h.GetChat) + chatGroup.DELETE("", h.DeleteChat) + + // Messages. + chatGroup.POST("/messages", h.SendMessage) + chatGroup.POST("/messages/stream", h.StreamMessage) + chatGroup.GET("/messages", h.ListMessages) + + // Participants. + chatGroup.GET("/participants", h.ListParticipants) + chatGroup.POST("/participants", h.AddParticipant) + chatGroup.DELETE("/participants/:user_id", h.RemoveParticipant) + + // Settings. + chatGroup.GET("/settings", h.GetSettings) + chatGroup.PUT("/settings", h.UpdateSettings) + + // Routes. + chatGroup.GET("/routes", h.ListRoutes) + chatGroup.POST("/routes", h.CreateRoute) + chatGroup.DELETE("/routes/:route_id", h.DeleteRoute) + + // Threads. + chatGroup.GET("/threads", h.ListThreads) } -// Chat godoc -// @Summary Chat with AI -// @Description Send a chat message and get a response. The system will automatically select an appropriate chat model from the database. -// @Tags chat -// @Accept json -// @Produce json -// @Param request body chat.ChatRequest true "Chat request" -// @Success 200 {object} chat.ChatResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/chat [post] -func (h *ChatHandler) Chat(c echo.Context) error { - userID, err := h.requireUserID(c) +// --- Chat Lifecycle --- + +// CreateChat creates a new chat for a bot. +func (h *ChatHandler) CreateChat(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -61,137 +86,177 @@ func (h *ChatHandler) Chat(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } + var req chat.CreateChatRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + result, err := h.chatService.Create(c.Request().Context(), botID, channelIdentityID, req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusCreated, result) +} + +// ListChats lists chats for a bot where the user has participant or observed access. +func (h *ChatHandler) ListChats(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + return err + } + + chats, err := h.chatService.ListByBotAndChannelIdentity(c.Request().Context(), botID, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, map[string]any{"items": chats}) +} + +// GetChat returns a chat by ID. +func (h *ChatHandler) GetChat(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireReadable(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + result, err := h.chatService.Get(c.Request().Context(), chatID) + if err != nil { + if errors.Is(err, chat.ErrChatNotFound) { + return echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, result) +} + +// DeleteChat deletes a chat (owner only). +func (h *ChatHandler) DeleteChat(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleOwner); err != nil { + return err + } + + if err := h.chatService.Delete(c.Request().Context(), chatID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// --- Messages --- + +// SendMessage sends a synchronous chat message. +func (h *ChatHandler) SendMessage(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + chatObj, err := h.chatService.Get(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + var req chat.ChatRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if req.Query == "" { return echo.NewHTTPError(http.StatusBadRequest, "query is required") } - req.BotID = botID - req.SessionID = sessionID + req.BotID = chatObj.BotID + req.ChatID = chatID req.Token = c.Request().Header.Get("Authorization") - req.UserID = userID - if strings.TrimSpace(req.ContactID) == "" { - req.ContactID = userID - } - if strings.TrimSpace(req.ContactName) == "" { - req.ContactName = "User" - } + req.ChannelIdentityID = channelIdentityID resp, err := h.resolver.Chat(c.Request().Context(), req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.JSON(http.StatusOK, resp) } -// StreamChat godoc -// @Summary Stream chat with AI -// @Description Send a chat message and get a streaming response. The system will automatically select an appropriate chat model from the database. -// @Tags chat -// @Accept json -// @Produce text/event-stream -// @Param request body chat.ChatRequest true "Chat request" -// @Success 200 {string} string -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/chat/stream [post] -func (h *ChatHandler) StreamChat(c echo.Context) error { - userID, err := h.requireUserID(c) +// StreamMessage sends a streaming chat message. +func (h *ChatHandler) StreamMessage(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - h.logger.Info("chat stream request received", - slog.String("bot_id", botID), - slog.String("session_id", c.QueryParam("session_id")), - slog.String("user_id", userID), - ) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { return err } + chatObj, err := h.chatService.Get(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + var req chat.ChatRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if req.Query == "" { return echo.NewHTTPError(http.StatusBadRequest, "query is required") } - req.BotID = botID - req.SessionID = sessionID + req.BotID = chatObj.BotID + req.ChatID = chatID req.Token = c.Request().Header.Get("Authorization") - req.UserID = userID - if strings.TrimSpace(req.ContactID) == "" { - req.ContactID = userID - } - if strings.TrimSpace(req.ContactName) == "" { - req.ContactName = "User" - } + req.ChannelIdentityID = channelIdentityID - // Set headers for SSE c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") c.Response().Header().Set(echo.HeaderConnection, "keep-alive") c.Response().WriteHeader(http.StatusOK) - // Get streaming channels chunkChan, errChan := h.resolver.StreamChat(c.Request().Context(), req) - - // Create a flusher flusher, ok := c.Response().Writer.(http.Flusher) if !ok { return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported") } - writer := bufio.NewWriter(c.Response().Writer) - // Stream chunks for { select { case chunk, ok := <-chunkChan: if !ok { - // Channel closed, send done message writer.WriteString("data: [DONE]\n\n") writer.Flush() flusher.Flush() return nil } - - // Marshal chunk to JSON data, err := json.Marshal(chunk) if err != nil { continue } - - // Write SSE format writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) writer.Flush() flusher.Flush() - case err := <-errChan: if err != nil { h.logger.Error("chat stream failed", slog.Any("error", err)) - // Send error as SSE event errData := map[string]string{"error": err.Error()} data, _ := json.Marshal(errData) writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) @@ -203,26 +268,257 @@ func (h *ChatHandler) StreamChat(c echo.Context) error { } } -func (h *ChatHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +// ListMessages lists messages for a chat. +func (h *ChatHandler) ListMessages(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireReadable(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + messages, err := h.chatService.ListMessages(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, map[string]any{"items": messages}) +} + +// --- Participants --- + +// ListParticipants lists participants for a chat. +func (h *ChatHandler) ListParticipants(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + participants, err := h.chatService.ListParticipants(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, map[string]any{"items": participants}) +} + +// AddParticipant adds a participant to a chat (owner/admin only). +func (h *ChatHandler) AddParticipant(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { + return err + } + + var body struct { + UserID string `json:"user_id"` + Role string `json:"role"` + } + if err := c.Bind(&body); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if strings.TrimSpace(body.UserID) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "user_id is required") + } + + p, err := h.chatService.AddParticipant(c.Request().Context(), chatID, body.UserID, body.Role) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, p) +} + +// RemoveParticipant removes a participant from a chat. +func (h *ChatHandler) RemoveParticipant(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { + return err + } + + targetUserID := strings.TrimSpace(c.Param("user_id")) + if targetUserID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "user_id is required") + } + + if err := h.chatService.RemoveParticipant(c.Request().Context(), chatID, targetUserID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// --- Settings --- + +// GetSettings returns settings for a chat. +func (h *ChatHandler) GetSettings(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + settings, err := h.chatService.GetSettings(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, settings) +} + +// UpdateSettings updates settings for a chat (owner/admin only). +func (h *ChatHandler) UpdateSettings(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + chatObj, err := h.chatService.Get(c.Request().Context(), chatID) + if err != nil { + if errors.Is(err, chat.ErrChatNotFound) { + return echo.NewHTTPError(http.StatusNotFound, "chat not found") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if chatObj.Kind == chat.KindGroup { + if _, err := h.authorizeBotManage(c.Request().Context(), channelIdentityID, chatObj.BotID); err != nil { + return err + } + } else { + if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { + return err + } + } + + var req chat.UpdateSettingsRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + settings, err := h.chatService.UpdateSettings(c.Request().Context(), chatID, req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, settings) +} + +// --- Routes --- + +// ListRoutes lists routes for a chat. +func (h *ChatHandler) ListRoutes(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + routes, err := h.chatService.ListRoutes(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, map[string]any{"items": routes}) +} + +// CreateRoute creates a new route for a chat (cross-channel). +func (h *ChatHandler) CreateRoute(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { + return err + } + + var route chat.Route + if err := c.Bind(&route); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + result, err := h.chatService.CreateRoute(c.Request().Context(), chatID, route) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusCreated, result) +} + +// DeleteRoute deletes a route. +func (h *ChatHandler) DeleteRoute(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireRole(c.Request().Context(), chatID, channelIdentityID, chat.RoleAdmin); err != nil { + return err + } + + routeID := strings.TrimSpace(c.Param("route_id")) + if routeID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "route_id is required") + } + if err := h.chatService.DeleteRoute(c.Request().Context(), routeID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// --- Threads --- + +// ListThreads lists threads for a parent chat. +func (h *ChatHandler) ListThreads(c echo.Context) error { + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if err := h.requireParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + threads, err := h.chatService.ListThreads(c.Request().Context(), chatID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, map[string]any{"items": threads}) +} + +// --- helpers --- + +func (h *ChatHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *ChatHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *ChatHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -234,3 +530,99 @@ func (h *ChatHandler) authorizeBotAccess(ctx context.Context, actorID, botID str } return bot, nil } + +func (h *ChatHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot management access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} + +func (h *ChatHandler) requireParticipant(ctx context.Context, chatID, channelIdentityID string) error { + if h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") + } + // Admin bypass. + if h.accountService != nil { + isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + if isAdmin { + return nil + } + } + ok, err := h.chatService.IsParticipant(ctx, chatID, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "not a participant") + } + return nil +} + +func (h *ChatHandler) requireReadable(ctx context.Context, chatID, channelIdentityID string) error { + if h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") + } + // Admin bypass. + if h.accountService != nil { + isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + if isAdmin { + return nil + } + } + _, err := h.chatService.GetReadAccess(ctx, chatID, channelIdentityID) + if err != nil { + if errors.Is(err, chat.ErrPermissionDenied) { + return echo.NewHTTPError(http.StatusForbidden, "not allowed to read chat") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return nil +} + +func (h *ChatHandler) requireRole(ctx context.Context, chatID, channelIdentityID, minRole string) error { + if h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") + } + // Admin bypass. + if h.accountService != nil { + isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + if isAdmin { + return nil + } + } + p, err := h.chatService.GetParticipant(ctx, chatID, channelIdentityID) + if err != nil { + if errors.Is(err, chat.ErrNotParticipant) { + return echo.NewHTTPError(http.StatusForbidden, "not a participant") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if !roleAtLeast(p.Role, minRole) { + return echo.NewHTTPError(http.StatusForbidden, "insufficient permissions") + } + return nil +} + +func roleAtLeast(actual, required string) bool { + roleLevel := map[string]int{ + chat.RoleOwner: 3, + chat.RoleAdmin: 2, + chat.RoleMember: 1, + } + return roleLevel[actual] >= roleLevel[required] +} diff --git a/internal/handlers/contacts.go b/internal/handlers/contacts.go deleted file mode 100644 index bee73a31..00000000 --- a/internal/handlers/contacts.go +++ /dev/null @@ -1,183 +0,0 @@ -package handlers - -import ( - "context" - "errors" - "net/http" - "strings" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" -) - -type ContactsHandler struct { - service *contacts.Service - botService *bots.Service - userService *users.Service -} - -func NewContactsHandler(service *contacts.Service, botService *bots.Service, userService *users.Service) *ContactsHandler { - return &ContactsHandler{ - service: service, - botService: botService, - userService: userService, - } -} - -func (h *ContactsHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/contacts") - group.GET("", h.List) - group.GET("/:id", h.Get) - group.POST("", h.Create) - group.PATCH("/:id", h.Update) -} - -func (h *ContactsHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - query := strings.TrimSpace(c.QueryParam("q")) - items, err := h.service.Search(c.Request().Context(), botID, query) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, map[string]any{"items": items}) -} - -func (h *ContactsHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "contact id is required") - } - item, err := h.service.GetByID(c.Request().Context(), id) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) -} - -func (h *ContactsHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - var req contacts.CreateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - req.BotID = botID - item, err := h.service.Create(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) -} - -func (h *ContactsHandler) Update(c echo.Context) error { - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "contact id is required") - } - var req contacts.UpdateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - userID, err := h.requireUserID(c) - if err == nil { - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - item, err := h.service.Update(c.Request().Context(), id, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) - } - - sessionToken, tokenErr := auth.SessionTokenFromContext(c) - if tokenErr != nil { - return err - } - if sessionToken.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "session token mismatch") - } - if strings.TrimSpace(sessionToken.ContactID) == "" || sessionToken.ContactID != id { - return echo.NewHTTPError(http.StatusForbidden, "contact mismatch") - } - if req.Tags != nil || req.Status != nil { - return echo.NewHTTPError(http.StatusForbidden, "session token cannot update tags or status") - } - item, err := h.service.Update(c.Request().Context(), id, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, item) -} - -func (h *ContactsHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateUserID(userID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return userID, nil -} - -func (h *ContactsHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil -} diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index bd2aaef7..603efa13 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -22,6 +22,7 @@ import ( "github.com/labstack/echo/v4" "github.com/opencontainers/runtime-spec/specs-go" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/config" @@ -29,21 +30,20 @@ import ( dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/mcp" - "github.com/memohai/memoh/internal/users" ) type ContainerdHandler struct { - service ctr.Service - cfg config.MCPConfig - namespace string - logger *slog.Logger - mcpMu sync.Mutex - mcpSess map[string]*mcpSession - mcpStdioMu sync.Mutex - mcpStdioSess map[string]*mcpStdioSession - botService *bots.Service - userService *users.Service - queries *dbsqlc.Queries + service ctr.Service + cfg config.MCPConfig + namespace string + logger *slog.Logger + mcpMu sync.Mutex + mcpSess map[string]*mcpSession + mcpStdioMu sync.Mutex + mcpStdioSess map[string]*mcpStdioSession + botService *bots.Service + accountService *accounts.Service + queries *dbsqlc.Queries } type CreateContainerRequest struct { @@ -95,17 +95,17 @@ type ListSnapshotsResponse struct { Snapshots []SnapshotInfo `json:"snapshots"` } -func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler { +func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, accountService *accounts.Service, queries *dbsqlc.Queries) *ContainerdHandler { return &ContainerdHandler{ - service: service, - cfg: cfg, - namespace: namespace, - logger: log.With(slog.String("handler", "containerd")), - mcpSess: make(map[string]*mcpSession), - mcpStdioSess: make(map[string]*mcpStdioSession), - botService: botService, - userService: userService, - queries: queries, + service: service, + cfg: cfg, + namespace: namespace, + logger: log.With(slog.String("handler", "containerd")), + mcpSess: make(map[string]*mcpSession), + mcpStdioSess: make(map[string]*mcpStdioSession), + botService: botService, + accountService: accountService, + queries: queries, } } @@ -612,7 +612,7 @@ func (h *ContainerdHandler) ListSnapshots(c echo.Context) error { // requireBotAccess extracts bot_id from path, validates user auth, and authorizes bot access. func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return "", err } @@ -620,32 +620,32 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) { if botID == "" { return "", echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return "", err } return botID, nil } -func (h *ContainerdHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -780,8 +780,13 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) // CleanupBotContainer removes the containerd container and DB record for a bot. func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID string) error { + h.logger.Info("CleanupBotContainer starting", slog.String("bot_id", botID)) containerID, err := h.botContainerID(ctx, botID) if err != nil { + h.logger.Warn("CleanupBotContainer: container not found for bot, cleaning up DB only", + slog.String("bot_id", botID), + slog.Any("error", err), + ) if h.queries != nil { if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) @@ -790,26 +795,41 @@ func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID strin return nil } + h.logger.Info("CleanupBotContainer: found container", + slog.String("bot_id", botID), + slog.String("container_id", containerID), + ) + if task, taskErr := h.service.GetTask(ctx, containerID); taskErr == nil { + h.logger.Info("CleanupBotContainer: removing network", slog.String("container_id", containerID)) _ = ctr.RemoveNetwork(ctx, task, containerID) } + h.logger.Info("CleanupBotContainer: stopping task", slog.String("container_id", containerID)) _ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{ Timeout: 5 * time.Second, Force: true, }) + h.logger.Info("CleanupBotContainer: deleting task", slog.String("container_id", containerID)) _ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}) + h.logger.Info("CleanupBotContainer: deleting container", slog.String("container_id", containerID)) if err := h.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{ CleanupSnapshot: true, }); err != nil && !errdefs.IsNotFound(err) { + h.logger.Error("CleanupBotContainer: failed to delete container", + slog.String("container_id", containerID), + slog.Any("error", err), + ) return err } if h.queries != nil { + h.logger.Info("CleanupBotContainer: deleting container record from DB", slog.String("bot_id", botID)) if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) } } + h.logger.Info("CleanupBotContainer finished", slog.String("bot_id", botID)) return nil } diff --git a/internal/handlers/embeddings.go b/internal/handlers/embeddings.go index 4f4d411c..ab3ea517 100644 --- a/internal/handlers/embeddings.go +++ b/internal/handlers/embeddings.go @@ -85,10 +85,10 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL) req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL) - userID := "" + channelIdentityID := "" if c.Get("user") != nil { - if value, err := auth.UserIDFromContext(c); err == nil { - userID = value + if value, err := auth.ChannelIdentityIDFromContext(c); err == nil { + channelIdentityID = value } } result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{ @@ -101,7 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { ImageURL: req.Input.ImageURL, VideoURL: req.Input.VideoURL, }, - UserID: userID, + ChannelIdentityID: channelIdentityID, }) if err != nil { message := err.Error() diff --git a/internal/handlers/history.go b/internal/handlers/history.go deleted file mode 100644 index 5d72c145..00000000 --- a/internal/handlers/history.go +++ /dev/null @@ -1,254 +0,0 @@ -package handlers - -import ( - "context" - "errors" - "fmt" - "log/slog" - "net/http" - "strings" - - "github.com/labstack/echo/v4" - - "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" - "github.com/memohai/memoh/internal/history" - "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" -) - -type HistoryHandler struct { - service *history.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger -} - -func NewHistoryHandler(log *slog.Logger, service *history.Service, botService *bots.Service, userService *users.Service) *HistoryHandler { - return &HistoryHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "history")), - } -} - -func (h *HistoryHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/history") - group.POST("", h.Create) - group.GET("", h.List) - group.GET("/:id", h.Get) - group.DELETE("/:id", h.Delete) - group.DELETE("", h.DeleteAll) -} - -// Create godoc -// @Summary Create history record -// @Description Create a history record for current user -// @Tags history -// @Param payload body history.CreateRequest true "History payload" -// @Success 201 {object} history.Record -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history [post] -func (h *HistoryHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - var req history.CreateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - resp, err := h.service.Create(c.Request().Context(), botID, sessionID, req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusCreated, resp) -} - -// Get godoc -// @Summary Get history record -// @Description Get a history record by ID (must belong to current user) -// @Tags history -// @Param id path string true "History ID" -// @Success 200 {object} history.Record -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history/{id} [get] -func (h *HistoryHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "id is required") - } - record, err := h.service.Get(c.Request().Context(), id) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, err.Error()) - } - if record.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - return c.JSON(http.StatusOK, record) -} - -// List godoc -// @Summary List history records -// @Description List history records for current user -// @Tags history -// @Param limit query int false "Limit" -// @Success 200 {object} history.ListResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history [get] -func (h *HistoryHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - limit := 0 - if raw := c.QueryParam("limit"); raw != "" { - if _, err := fmt.Sscanf(raw, "%d", &limit); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid limit") - } - } - items, err := h.service.List(c.Request().Context(), botID, sessionID, limit) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, history.ListResponse{Items: items}) -} - -// Delete godoc -// @Summary Delete history record -// @Description Delete a history record by ID (must belong to current user) -// @Tags history -// @Param id path string true "History ID" -// @Success 204 "No Content" -// @Failure 400 {object} ErrorResponse -// @Failure 403 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history/{id} [delete] -func (h *HistoryHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "id is required") - } - record, err := h.service.Get(c.Request().Context(), id) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, err.Error()) - } - if record.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - if err := h.service.Delete(c.Request().Context(), id); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.NoContent(http.StatusNoContent) -} - -// DeleteAll godoc -// @Summary Delete all history records -// @Description Delete all history records for current user -// @Tags history -// @Success 204 "No Content" -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/history [delete] -func (h *HistoryHandler) DeleteAll(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - if err := h.service.DeleteBySession(c.Request().Context(), botID, sessionID); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.NoContent(http.StatusNoContent) -} - -func (h *HistoryHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateUserID(userID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return userID, nil -} - -func (h *HistoryHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil -} \ No newline at end of file diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index 9100a014..d80a5ae9 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -10,37 +10,42 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/adapters/local" + "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" ) +// LocalChannelHandler handles local channel (CLI/Web) sessions backed by chats. type LocalChannelHandler struct { channelType channel.ChannelType channelManager *channel.Manager channelService *channel.Service + chatService *chat.Service sessionHub *local.SessionHub botService *bots.Service - userService *users.Service + accountService *accounts.Service } -func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, userService *users.Service) *LocalChannelHandler { +// NewLocalChannelHandler creates a local channel handler. +func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, chatService *chat.Service, sessionHub *local.SessionHub, botService *bots.Service, accountService *accounts.Service) *LocalChannelHandler { return &LocalChannelHandler{ channelType: channelType, channelManager: channelManager, channelService: channelService, + chatService: chatService, sessionHub: sessionHub, botService: botService, - userService: userService, + accountService: accountService, } } +// Register registers the local channel routes. func (h *LocalChannelHandler) Register(e *echo.Echo) { prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String()) group := e.Group(prefix) @@ -51,11 +56,13 @@ func (h *LocalChannelHandler) Register(e *echo.Echo) { type localSessionResponse struct { SessionID string `json:"session_id"` + ChatID string `json:"chat_id"` StreamURL string `json:"stream_url"` } +// CreateSession creates a new local chat session. func (h *LocalChannelHandler) CreateSession(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -63,22 +70,27 @@ func (h *LocalChannelHandler) CreateSession(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if h.channelService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "channel service not configured") - } - sessionID := fmt.Sprintf("%s:%s", h.channelType.String(), uuid.NewString()) - if err := h.channelService.UpsertChannelSession(c.Request().Context(), sessionID, botID, "", userID, "", h.channelType.String(), sessionID, "", nil); err != nil { + + // Create a chat as the underlying container. + chatObj, err := h.chatService.Create(c.Request().Context(), botID, channelIdentityID, chat.CreateChatRequest{ + Kind: chat.KindDirect, + }) + if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + + // Use chat_id as the session_id for the local hub. + sessionID := chatObj.ID streamURL := fmt.Sprintf("/bots/%s/%s/sessions/%s/stream", botID, h.channelType.String(), sessionID) - return c.JSON(http.StatusOK, localSessionResponse{SessionID: sessionID, StreamURL: streamURL}) + return c.JSON(http.StatusOK, localSessionResponse{SessionID: sessionID, ChatID: chatObj.ID, StreamURL: streamURL}) } +// StreamSession streams responses for a local session. func (h *LocalChannelHandler) StreamSession(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -90,10 +102,10 @@ func (h *LocalChannelHandler) StreamSession(c echo.Context) error { if sessionID == "" { return echo.NewHTTPError(http.StatusBadRequest, "session id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if err := h.ensureSessionOwner(c.Request().Context(), botID, sessionID, userID); err != nil { + if err := h.ensureChatParticipant(c.Request().Context(), sessionID, channelIdentityID); err != nil { return err } if h.sessionHub == nil { @@ -141,8 +153,9 @@ type localMessageRequest struct { Message channel.Message `json:"message"` } +// PostMessage sends a message through the local channel. func (h *LocalChannelHandler) PostMessage(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -154,10 +167,10 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { if sessionID == "" { return echo.NewHTTPError(http.StatusBadRequest, "session id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if err := h.ensureSessionOwner(c.Request().Context(), botID, sessionID, userID); err != nil { + if err := h.ensureChatParticipant(c.Request().Context(), sessionID, channelIdentityID); err != nil { return err } if h.channelManager == nil || h.channelService == nil { @@ -182,9 +195,9 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { ReplyTarget: sessionID, SessionKey: sessionID, Sender: channel.Identity{ - ExternalID: userID, + SubjectID: channelIdentityID, Attributes: map[string]string{ - "user_id": userID, + "user_id": channelIdentityID, }, }, Conversation: channel.Conversation{ @@ -200,46 +213,40 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } -func (h *LocalChannelHandler) ensureSessionOwner(ctx context.Context, botID, sessionID, userID string) error { - if h.channelService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "channel service not configured") +func (h *LocalChannelHandler) ensureChatParticipant(ctx context.Context, chatID, channelIdentityID string) error { + if h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } - session, err := h.channelService.GetChannelSession(ctx, sessionID) + ok, err := h.chatService.IsParticipant(ctx, chatID, channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - if strings.TrimSpace(session.SessionID) == "" { - return echo.NewHTTPError(http.StatusNotFound, "session not found") - } - if session.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "session access denied") - } - if session.UserID != userID { - return echo.NewHTTPError(http.StatusForbidden, "session access denied") + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "chat access denied") } return nil } -func (h *LocalChannelHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go index bfc5d11d..0a512db2 100644 --- a/internal/handlers/mcp.go +++ b/internal/handlers/mcp.go @@ -10,26 +10,26 @@ import ( "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/mcp" - "github.com/memohai/memoh/internal/users" ) type MCPHandler struct { - service *mcp.ConnectionService - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *mcp.ConnectionService + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService *bots.Service, userService *users.Service) *MCPHandler { +func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService *bots.Service, accountService *accounts.Service) *MCPHandler { return &MCPHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "mcp")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "mcp")), } } @@ -53,7 +53,7 @@ func (h *MCPHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp [get] func (h *MCPHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -83,7 +83,7 @@ func (h *MCPHandler) List(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp [post] func (h *MCPHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -117,7 +117,7 @@ func (h *MCPHandler) Create(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp/{id} [get] func (h *MCPHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -155,7 +155,7 @@ func (h *MCPHandler) Get(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp/{id} [put] func (h *MCPHandler) Update(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -196,7 +196,7 @@ func (h *MCPHandler) Update(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/mcp/{id} [delete] func (h *MCPHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) + userID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -217,26 +217,26 @@ func (h *MCPHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *MCPHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) { + userID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(userID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } return userID, nil } -func (h *MCPHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *MCPHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index ffa731cc..b63cbebb 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -2,31 +2,32 @@ package handlers import ( "context" - "errors" - "fmt" "log/slog" "net/http" + "sort" "strings" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/bots" + "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/memory" - "github.com/memohai/memoh/internal/users" ) +// MemoryHandler handles memory CRUD operations scoped by chat. type MemoryHandler struct { - service *memory.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *memory.Service + chatService *chat.Service + accountService *accounts.Service + logger *slog.Logger } type memoryAddPayload struct { Message string `json:"message,omitempty"` Messages []memory.Message `json:"messages,omitempty"` + Namespace string `json:"namespace,omitempty"` RunID string `json:"run_id,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` Filters map[string]any `json:"filters,omitempty"` @@ -43,40 +44,29 @@ type memorySearchPayload struct { EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } -type memoryEmbedUpsertPayload struct { - Type string `json:"type"` - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - Input memory.EmbedInput `json:"input"` - Source string `json:"source,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - Filters map[string]any `json:"filters,omitempty"` +// namespaceScope holds namespace + scopeId for a single memory scope. +type namespaceScope struct { + Namespace string + ScopeID string } -type memoryDeleteAllPayload struct { - RunID string `json:"run_id,omitempty"` -} - -func NewMemoryHandler(log *slog.Logger, service *memory.Service, botService *bots.Service, userService *users.Service) *MemoryHandler { +// NewMemoryHandler creates a MemoryHandler. +func NewMemoryHandler(log *slog.Logger, service *memory.Service, chatService *chat.Service, accountService *accounts.Service) *MemoryHandler { return &MemoryHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "memory")), + service: service, + chatService: chatService, + accountService: accountService, + logger: log.With(slog.String("handler", "memory")), } } +// Register registers chat-level memory routes. func (h *MemoryHandler) Register(e *echo.Echo) { - group := e.Group("/bots/:bot_id/memory") - group.POST("/add", h.Add) - group.POST("/embed", h.EmbedUpsert) - group.POST("/search", h.Search) - group.POST("/update", h.Update) - group.GET("/memories/:memoryId", h.Get) - group.GET("/memories", h.GetAll) - group.DELETE("/memories/:memoryId", h.Delete) - group.DELETE("/memories", h.DeleteAll) + chatGroup := e.Group("/chats/:chat_id/memory") + chatGroup.POST("", h.ChatAdd) + chatGroup.POST("/search", h.ChatSearch) + chatGroup.GET("", h.ChatGetAll) + chatGroup.DELETE("", h.ChatDeleteAll) } func (h *MemoryHandler) checkService() error { @@ -86,106 +76,52 @@ func (h *MemoryHandler) checkService() error { return nil } -// EmbedUpsert godoc -// @Summary Embed and upsert memory -// @Description Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param payload body memoryEmbedUpsertPayload true "Embed upsert request" -// @Success 200 {object} memory.EmbedUpsertResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/embed [post] -func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { +// --- Chat-level memory endpoints --- + +// ChatAdd adds memory to a specific namespace (validated against chat_settings). +func (h *MemoryHandler) ChatAdd(c echo.Context) error { if err := h.checkService(); err != nil { return err } - - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + chatID := strings.TrimSpace(c.Param("chat_id")) + if chatID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { return err } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - - var payload memoryEmbedUpsertPayload - if err := c.Bind(&payload); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - req := memory.EmbedUpsertRequest{ - Type: payload.Type, - Provider: payload.Provider, - Model: payload.Model, - Input: payload.Input, - Source: payload.Source, - BotID: botID, - SessionID: sessionID, - RunID: payload.RunID, - Metadata: payload.Metadata, - Filters: payload.Filters, - } - - resp, err := h.service.EmbedUpsert(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - -// Add godoc -// @Summary Add memory -// @Description Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param payload body memoryAddPayload true "Add request" -// @Success 200 {object} memory.SearchResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/add [post] -func (h *MemoryHandler) Add(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } - - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } var payload memoryAddPayload if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } + + namespace := strings.TrimSpace(payload.Namespace) + if namespace == "" { + namespace = "chat" + } + + // Resolve correct scopeId/botId and validate namespace is enabled. + scopeID, botID, err := h.resolveWriteScope(c.Request().Context(), chatID, channelIdentityID, namespace) + if err != nil { + return err + } + + filters := buildNamespaceFilters(namespace, scopeID, payload.Filters) req := memory.AddRequest{ Message: payload.Message, Messages: payload.Messages, BotID: botID, - SessionID: sessionID, RunID: payload.RunID, Metadata: payload.Metadata, - Filters: payload.Filters, + Filters: filters, Infer: payload.Infer, EmbeddingEnabled: payload.EmbeddingEnabled, } - resp, err := h.service.Add(c.Request().Context(), req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -193,318 +129,263 @@ func (h *MemoryHandler) Add(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -// Search godoc -// @Summary Search memories -// @Description Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param payload body memorySearchPayload true "Search request" -// @Success 200 {object} memory.SearchResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/search [post] -func (h *MemoryHandler) Search(c echo.Context) error { +// ChatSearch searches memory across all enabled namespaces per chat_settings. +func (h *MemoryHandler) ChatSearch(c echo.Context) error { if err := h.checkService(); err != nil { return err } - - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + chatID := strings.TrimSpace(c.Param("chat_id")) + if chatID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { return err } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } var payload memorySearchPayload if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - req := memory.SearchRequest{ - Query: payload.Query, - BotID: botID, - SessionID: sessionID, - RunID: payload.RunID, - Limit: payload.Limit, - Filters: payload.Filters, - Sources: payload.Sources, - EmbeddingEnabled: payload.EmbeddingEnabled, - } - resp, err := h.service.Search(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - -// Update godoc -// @Summary Update memory -// @Description Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param payload body memory.UpdateRequest true "Update request" -// @Success 200 {object} memory.MemoryItem -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/update [post] -func (h *MemoryHandler) Update(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } - - userID, err := h.requireUserID(c) + scopes, err := h.resolveEnabledScopes(c.Request().Context(), chatID, channelIdentityID) if err != nil { return err } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - var req memory.UpdateRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if req.MemoryID != "" { - existing, err := h.service.Get(c.Request().Context(), req.MemoryID) + // Search across all enabled namespaces and merge results. + var allResults []memory.MemoryItem + for _, scope := range scopes { + filters := buildNamespaceFilters(scope.Namespace, scope.ScopeID, payload.Filters) + req := memory.SearchRequest{ + Query: payload.Query, + RunID: payload.RunID, + Limit: payload.Limit, + Filters: filters, + Sources: payload.Sources, + EmbeddingEnabled: payload.EmbeddingEnabled, + } + resp, err := h.service.Search(c.Request().Context(), req) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + h.logger.Warn("search namespace failed", slog.String("namespace", scope.Namespace), slog.Any("error", err)) + continue } - if existing.BotID != "" && existing.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") + allResults = append(allResults, resp.Results...) + } + + // Deduplicate by ID and sort by score descending. + allResults = deduplicateMemoryItems(allResults) + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].Score > allResults[j].Score + }) + if payload.Limit > 0 && len(allResults) > payload.Limit { + allResults = allResults[:payload.Limit] + } + + return c.JSON(http.StatusOK, memory.SearchResponse{Results: allResults}) +} + +// ChatGetAll lists all memories across enabled namespaces. +func (h *MemoryHandler) ChatGetAll(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if chatID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") + } + if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + scopes, err := h.resolveEnabledScopes(c.Request().Context(), chatID, channelIdentityID) + if err != nil { + return err + } + + var allResults []memory.MemoryItem + for _, scope := range scopes { + req := memory.GetAllRequest{ + Filters: buildNamespaceFilters(scope.Namespace, scope.ScopeID, nil), + } + resp, err := h.service.GetAll(c.Request().Context(), req) + if err != nil { + h.logger.Warn("getall namespace failed", slog.String("namespace", scope.Namespace), slog.Any("error", err)) + continue + } + allResults = append(allResults, resp.Results...) + } + allResults = deduplicateMemoryItems(allResults) + + return c.JSON(http.StatusOK, memory.SearchResponse{Results: allResults}) +} + +// ChatDeleteAll deletes all memories across enabled namespaces. +func (h *MemoryHandler) ChatDeleteAll(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + channelIdentityID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + chatID := strings.TrimSpace(c.Param("chat_id")) + if chatID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "chat_id is required") + } + if err := h.requireChatParticipant(c.Request().Context(), chatID, channelIdentityID); err != nil { + return err + } + + scopes, err := h.resolveEnabledScopes(c.Request().Context(), chatID, channelIdentityID) + if err != nil { + return err + } + + for _, scope := range scopes { + req := memory.DeleteAllRequest{ + Filters: buildNamespaceFilters(scope.Namespace, scope.ScopeID, nil), + } + if _, err := h.service.DeleteAll(c.Request().Context(), req); err != nil { + h.logger.Warn("deleteall namespace failed", slog.String("namespace", scope.Namespace), slog.Any("error", err)) } } - - resp, err := h.service.Update(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) + return c.JSON(http.StatusOK, memory.DeleteResponse{Message: "Memory deleted successfully!"}) } -// Get godoc -// @Summary Get memory -// @Description Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param memoryId path string true "Memory ID" -// @Success 200 {object} memory.MemoryItem -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories/{memoryId} [get] -func (h *MemoryHandler) Get(c echo.Context) error { - if err := h.checkService(); err != nil { - return err - } +// --- helpers --- - userID, err := h.requireUserID(c) +// resolveEnabledScopes returns all namespace scopes enabled by chat_settings. +func (h *MemoryHandler) resolveEnabledScopes(ctx context.Context, chatID, channelIdentityID string) ([]namespaceScope, error) { + if h.chatService == nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") + } + chatObj, err := h.chatService.Get(ctx, chatID) if err != nil { - return err + return nil, echo.NewHTTPError(http.StatusNotFound, "chat not found") } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - - memoryID := c.Param("memoryId") - if memoryID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "memory ID required") - } - - resp, err := h.service.Get(c.Request().Context(), memoryID) + settings, err := h.chatService.GetSettings(ctx, chatID) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - if resp.BotID != "" && resp.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") + + var scopes []namespaceScope + if settings.EnableChatMemory { + scopes = append(scopes, namespaceScope{Namespace: "chat", ScopeID: chatID}) } - return c.JSON(http.StatusOK, resp) + if settings.EnablePrivateMemory && strings.TrimSpace(channelIdentityID) != "" { + scopes = append(scopes, namespaceScope{Namespace: "private", ScopeID: channelIdentityID}) + } + if settings.EnablePublicMemory { + scopes = append(scopes, namespaceScope{Namespace: "public", ScopeID: chatObj.BotID}) + } + if len(scopes) == 0 { + scopes = append(scopes, namespaceScope{Namespace: "chat", ScopeID: chatID}) + } + return scopes, nil } -// GetAll godoc -// @Summary List memories -// @Description List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param run_id query string false "Run ID" -// @Param limit query int false "Limit" -// @Success 200 {object} memory.SearchResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories [get] -func (h *MemoryHandler) GetAll(c echo.Context) error { - if err := h.checkService(); err != nil { - return err +// resolveWriteScope validates namespace and returns (scopeId, botId). +func (h *MemoryHandler) resolveWriteScope(ctx context.Context, chatID, channelIdentityID, namespace string) (string, string, error) { + if h.chatService == nil { + return "", "", echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } - - userID, err := h.requireUserID(c) + chatObj, err := h.chatService.Get(ctx, chatID) if err != nil { - return err + return "", "", echo.NewHTTPError(http.StatusNotFound, "chat not found") } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") + settings, err := h.chatService.GetSettings(ctx, chatID) + if err != nil { + return "", "", echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - req := memory.GetAllRequest{ - BotID: botID, - SessionID: sessionID, - AgentID: c.QueryParam("agent_id"), - RunID: c.QueryParam("run_id"), + switch namespace { + case "chat": + if !settings.EnableChatMemory { + return "", "", echo.NewHTTPError(http.StatusForbidden, "chat memory is disabled for this chat") + } + return chatID, chatObj.BotID, nil + case "private": + if !settings.EnablePrivateMemory { + return "", "", echo.NewHTTPError(http.StatusForbidden, "private memory is disabled for this chat") + } + if strings.TrimSpace(channelIdentityID) == "" { + return "", "", echo.NewHTTPError(http.StatusBadRequest, "channel_identity_id required for private namespace") + } + return channelIdentityID, chatObj.BotID, nil + case "public": + if !settings.EnablePublicMemory { + return "", "", echo.NewHTTPError(http.StatusForbidden, "public memory is disabled for this chat") + } + return chatObj.BotID, chatObj.BotID, nil + default: + return "", "", echo.NewHTTPError(http.StatusBadRequest, "invalid namespace: "+namespace) } - if limit := c.QueryParam("limit"); limit != "" { - var parsed int - if _, err := fmt.Sscanf(limit, "%d", &parsed); err == nil { - req.Limit = parsed +} + +func buildNamespaceFilters(namespace, scopeID string, extra map[string]any) map[string]any { + filters := map[string]any{ + "namespace": namespace, + "scopeId": scopeID, + } + for k, v := range extra { + if k != "namespace" && k != "scopeId" { + filters[k] = v } } - - resp, err := h.service.GetAll(c.Request().Context(), req) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) + return filters } -// Delete godoc -// @Summary Delete memory -// @Description Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param memoryId path string true "Memory ID" -// @Success 200 {object} memory.DeleteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories/{memoryId} [delete] -func (h *MemoryHandler) Delete(c echo.Context) error { - if err := h.checkService(); err != nil { - return err +func deduplicateMemoryItems(items []memory.MemoryItem) []memory.MemoryItem { + if len(items) == 0 { + return items } - - userID, err := h.requireUserID(c) - if err != nil { - return err + seen := make(map[string]struct{}, len(items)) + result := make([]memory.MemoryItem, 0, len(items)) + for _, item := range items { + if _, ok := seen[item.ID]; ok { + continue + } + seen[item.ID] = struct{}{} + result = append(result, item) } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - - memoryID := c.Param("memoryId") - if memoryID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "memory ID required") - } - - existing, err := h.service.Get(c.Request().Context(), memoryID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if existing.BotID != "" && existing.BotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - - resp, err := h.service.Delete(c.Request().Context(), memoryID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, resp) + return result } -// DeleteAll godoc -// @Summary Delete memories -// @Description Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). -// @Tags memory -// @Param payload body memoryDeleteAllPayload true "Delete all request" -// @Success 200 {object} memory.DeleteResponse -// @Failure 400 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/memories [delete] -func (h *MemoryHandler) DeleteAll(c echo.Context) error { - if err := h.checkService(); err != nil { - return err +func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, channelIdentityID string) error { + if h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } - - userID, err := h.requireUserID(c) - if err != nil { - return err + if h.accountService != nil { + isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + if isAdmin { + return nil + } } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - sessionID := strings.TrimSpace(c.QueryParam("session_id")) - if sessionID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "session_id is required") - } - - var payload memoryDeleteAllPayload - if err := c.Bind(&payload); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - req := memory.DeleteAllRequest{ - BotID: botID, - SessionID: sessionID, - RunID: payload.RunID, - } - - resp, err := h.service.DeleteAll(c.Request().Context(), req) + ok, err := h.chatService.IsParticipant(ctx, chatID, channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.JSON(http.StatusOK, resp) + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "not a chat participant") + } + return nil } -func (h *MemoryHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil -} - -func (h *MemoryHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") - } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) - if err != nil { - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) - if err != nil { - if errors.Is(err, bots.ErrBotNotFound) { - return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") - } - if errors.Is(err, bots.ErrBotAccessDenied) { - return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") - } - return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return bot, nil + return channelIdentityID, nil } diff --git a/internal/handlers/models.go b/internal/handlers/models.go index 0dbd8af4..357845f1 100644 --- a/internal/handlers/models.go +++ b/internal/handlers/models.go @@ -164,7 +164,7 @@ func (h *ModelsHandler) Enable(c echo.Context) error { if h.settingsService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured") } - userID, err := auth.UserIDFromContext(c) + userID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return err } diff --git a/internal/handlers/preauth.go b/internal/handlers/preauth.go index 4b0c965b..5213859f 100644 --- a/internal/handlers/preauth.go +++ b/internal/handlers/preauth.go @@ -9,24 +9,24 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/preauth" - "github.com/memohai/memoh/internal/users" ) type PreauthHandler struct { - service *preauth.Service - botService *bots.Service - userService *users.Service + service *preauth.Service + botService *bots.Service + accountService *accounts.Service } -func NewPreauthHandler(service *preauth.Service, botService *bots.Service, userService *users.Service) *PreauthHandler { +func NewPreauthHandler(service *preauth.Service, botService *bots.Service, accountService *accounts.Service) *PreauthHandler { return &PreauthHandler{ - service: service, - botService: botService, - userService: userService, + service: service, + botService: botService, + accountService: accountService, } } @@ -40,7 +40,7 @@ type preauthIssueRequest struct { } func (h *PreauthHandler) Issue(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -48,7 +48,7 @@ func (h *PreauthHandler) Issue(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req preauthIssueRequest @@ -59,33 +59,33 @@ func (h *PreauthHandler) Issue(c echo.Context) error { if req.TTLSeconds > 0 { ttl = time.Duration(req.TTLSeconds) * time.Second } - key, err := h.service.Issue(c.Request().Context(), botID, userID, ttl) + key, err := h.service.Issue(c.Request().Context(), botID, channelIdentityID, ttl) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, key) } -func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *PreauthHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/schedule.go b/internal/handlers/schedule.go index fa56f520..4784346a 100644 --- a/internal/handlers/schedule.go +++ b/internal/handlers/schedule.go @@ -9,26 +9,26 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/schedule" - "github.com/memohai/memoh/internal/users" ) type ScheduleHandler struct { - service *schedule.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *schedule.Service + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewScheduleHandler(log *slog.Logger, service *schedule.Service, botService *bots.Service, userService *users.Service) *ScheduleHandler { +func NewScheduleHandler(log *slog.Logger, service *schedule.Service, botService *bots.Service, accountService *accounts.Service) *ScheduleHandler { return &ScheduleHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "schedule")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "schedule")), } } @@ -51,7 +51,7 @@ func (h *ScheduleHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule [post] func (h *ScheduleHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -59,7 +59,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req schedule.CreateRequest @@ -82,7 +82,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule [get] func (h *ScheduleHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -90,7 +90,7 @@ func (h *ScheduleHandler) List(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } items, err := h.service.List(c.Request().Context(), botID) @@ -111,7 +111,7 @@ func (h *ScheduleHandler) List(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule/{id} [get] func (h *ScheduleHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -130,7 +130,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } return c.JSON(http.StatusOK, item) @@ -147,7 +147,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule/{id} [put] func (h *ScheduleHandler) Update(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -170,7 +170,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } resp, err := h.service.Update(c.Request().Context(), id, req) @@ -190,7 +190,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule/{id} [delete] func (h *ScheduleHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -209,7 +209,7 @@ func (h *ScheduleHandler) Delete(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if err := h.service.Delete(c.Request().Context(), id); err != nil { @@ -218,26 +218,26 @@ func (h *ScheduleHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *ScheduleHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -248,4 +248,4 @@ func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} \ No newline at end of file +} diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index ad902319..78e7572c 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -9,26 +9,26 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/settings" - "github.com/memohai/memoh/internal/users" ) type SettingsHandler struct { - service *settings.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *settings.Service + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewSettingsHandler(log *slog.Logger, service *settings.Service, botService *bots.Service, userService *users.Service) *SettingsHandler { +func NewSettingsHandler(log *slog.Logger, service *settings.Service, botService *bots.Service, accountService *accounts.Service) *SettingsHandler { return &SettingsHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "settings")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "settings")), } } @@ -49,7 +49,7 @@ func (h *SettingsHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/settings [get] func (h *SettingsHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -57,7 +57,7 @@ func (h *SettingsHandler) Get(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } resp, err := h.service.GetBot(c.Request().Context(), botID) @@ -78,7 +78,7 @@ func (h *SettingsHandler) Get(c echo.Context) error { // @Router /bots/{bot_id}/settings [put] // @Router /bots/{bot_id}/settings [post] func (h *SettingsHandler) Upsert(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -86,7 +86,7 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req settings.UpsertRequest @@ -95,6 +95,9 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { } resp, err := h.service.UpsertBot(c.Request().Context(), botID, req) if err != nil { + if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) { + return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) @@ -109,7 +112,7 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/settings [delete] func (h *SettingsHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -117,7 +120,7 @@ func (h *SettingsHandler) Delete(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if err := h.service.Delete(c.Request().Context(), botID); err != nil { @@ -126,26 +129,26 @@ func (h *SettingsHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *SettingsHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -156,4 +159,4 @@ func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} \ No newline at end of file +} diff --git a/internal/handlers/subagent.go b/internal/handlers/subagent.go index 67f95467..0d67ad95 100644 --- a/internal/handlers/subagent.go +++ b/internal/handlers/subagent.go @@ -9,26 +9,26 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/subagent" - "github.com/memohai/memoh/internal/users" ) type SubagentHandler struct { - service *subagent.Service - botService *bots.Service - userService *users.Service - logger *slog.Logger + service *subagent.Service + botService *bots.Service + accountService *accounts.Service + logger *slog.Logger } -func NewSubagentHandler(log *slog.Logger, service *subagent.Service, botService *bots.Service, userService *users.Service) *SubagentHandler { +func NewSubagentHandler(log *slog.Logger, service *subagent.Service, botService *bots.Service, accountService *accounts.Service) *SubagentHandler { return &SubagentHandler{ - service: service, - botService: botService, - userService: userService, - logger: log.With(slog.String("handler", "subagent")), + service: service, + botService: botService, + accountService: accountService, + logger: log.With(slog.String("handler", "subagent")), } } @@ -56,7 +56,7 @@ func (h *SubagentHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents [post] func (h *SubagentHandler) Create(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -64,7 +64,7 @@ func (h *SubagentHandler) Create(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req subagent.CreateRequest @@ -87,7 +87,7 @@ func (h *SubagentHandler) Create(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents [get] func (h *SubagentHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -95,7 +95,7 @@ func (h *SubagentHandler) List(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } items, err := h.service.List(c.Request().Context(), botID) @@ -116,7 +116,7 @@ func (h *SubagentHandler) List(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id} [get] func (h *SubagentHandler) Get(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -135,7 +135,7 @@ func (h *SubagentHandler) Get(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } return c.JSON(http.StatusOK, item) @@ -153,7 +153,7 @@ func (h *SubagentHandler) Get(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id} [put] func (h *SubagentHandler) Update(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -176,7 +176,7 @@ func (h *SubagentHandler) Update(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } resp, err := h.service.Update(c.Request().Context(), id, req) @@ -197,7 +197,7 @@ func (h *SubagentHandler) Update(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id} [delete] func (h *SubagentHandler) Delete(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -216,7 +216,7 @@ func (h *SubagentHandler) Delete(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if err := h.service.Delete(c.Request().Context(), id); err != nil { @@ -236,7 +236,7 @@ func (h *SubagentHandler) Delete(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/context [get] func (h *SubagentHandler) GetContext(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -255,7 +255,7 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } return c.JSON(http.StatusOK, subagent.ContextResponse{Messages: item.Messages}) @@ -273,7 +273,7 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/context [put] func (h *SubagentHandler) UpdateContext(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -296,7 +296,7 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } updated, err := h.service.UpdateContext(c.Request().Context(), id, req) @@ -317,7 +317,7 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/skills [get] func (h *SubagentHandler) GetSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -336,7 +336,7 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } return c.JSON(http.StatusOK, subagent.SkillsResponse{Skills: item.Skills}) @@ -354,7 +354,7 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/skills [put] func (h *SubagentHandler) UpdateSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -377,7 +377,7 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } updated, err := h.service.UpdateSkills(c.Request().Context(), id, req) @@ -399,7 +399,7 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/subagents/{id}/skills [post] func (h *SubagentHandler) AddSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -422,7 +422,7 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "user mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } updated, err := h.service.AddSkills(c.Request().Context(), id, req) @@ -432,26 +432,26 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error { return c.JSON(http.StatusOK, subagent.SkillsResponse{Skills: updated.Skills}) } -func (h *SubagentHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } -func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - if h.botService == nil || h.userService == nil { +func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.userService.IsAdmin(ctx, actorID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -462,4 +462,4 @@ func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} \ No newline at end of file +} diff --git a/internal/handlers/users.go b/internal/handlers/users.go index f07d69e8..973931b7 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -10,29 +10,42 @@ import ( "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channelidentities" + "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/users" ) +// UsersHandler manages user/account CRUD and bot operations via REST API. type UsersHandler struct { - service *users.Service + service *accounts.Service + channelIdentityService *channelidentities.Service botService *bots.Service + chatService *chat.Service channelService *channel.Service channelManager *channel.Manager registry *channel.Registry logger *slog.Logger } -func NewUsersHandler(log *slog.Logger, service *users.Service, botService *bots.Service, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *UsersHandler { +type listMyIdentitiesResponse struct { + UserID string `json:"user_id"` + Items []channelidentities.ChannelIdentity `json:"items"` +} + +// NewUsersHandler creates a UsersHandler with channel identity support. +func NewUsersHandler(log *slog.Logger, service *accounts.Service, channelIdentityService *channelidentities.Service, botService *bots.Service, chatService *chat.Service, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *UsersHandler { if log == nil { log = slog.Default() } return &UsersHandler{ service: service, + channelIdentityService: channelIdentityService, botService: botService, + chatService: chatService, channelService: channelService, channelManager: channelManager, registry: registry, @@ -43,6 +56,7 @@ func NewUsersHandler(log *slog.Logger, service *users.Service, botService *bots. func (h *UsersHandler) Register(e *echo.Echo) { userGroup := e.Group("/users") userGroup.GET("/me", h.GetMe) + userGroup.GET("/me/identities", h.ListMyIdentities) userGroup.PUT("/me", h.UpdateMe) userGroup.PUT("/me/password", h.UpdateMyPassword) userGroup.GET("", h.ListUsers) @@ -64,48 +78,75 @@ func (h *UsersHandler) Register(e *echo.Echo) { botGroup.GET("/:id/channel/:platform", h.GetBotChannelConfig) botGroup.PUT("/:id/channel/:platform", h.UpsertBotChannelConfig) botGroup.POST("/:id/channel/:platform/send", h.SendBotMessage) - botGroup.POST("/:id/channel/:platform/send_session", h.SendBotMessageSession) + botGroup.POST("/:id/channel/:platform/send_chat", h.SendBotMessageSession) } // GetMe godoc // @Summary Get current user // @Description Get current user profile // @Tags users -// @Success 200 {object} users.User +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me [get] func (h *UsersHandler) GetMe(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - resp, err := h.service.Get(c.Request().Context(), userID) + resp, err := h.service.Get(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) } +// ListMyIdentities godoc +// @Summary List current user's channel identities +// @Description List all channel identities linked to current user +// @Tags users +// @Success 200 {object} listMyIdentitiesResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /users/me/identities [get] +func (h *UsersHandler) ListMyIdentities(c echo.Context) error { + userID, err := h.requireChannelIdentityID(c) + if err != nil { + return err + } + if h.channelIdentityService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "channel identity service not configured") + } + items, err := h.channelIdentityService.ListUserChannelIdentities(c.Request().Context(), userID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, listMyIdentitiesResponse{ + UserID: userID, + Items: items, + }) +} + // UpdateMe godoc // @Summary Update current user profile // @Description Update current user display name or avatar // @Tags users -// @Param payload body users.UpdateProfileRequest true "Profile payload" -// @Success 200 {object} users.User +// @Param payload body accounts.UpdateProfileRequest true "Profile payload" +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me [put] func (h *UsersHandler) UpdateMe(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - var req users.UpdateProfileRequest + var req accounts.UpdateProfileRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.UpdateProfile(c.Request().Context(), userID, req) + resp, err := h.service.UpdateProfile(c.Request().Context(), channelIdentityID, req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -116,22 +157,22 @@ func (h *UsersHandler) UpdateMe(c echo.Context) error { // @Summary Update current user password // @Description Update current user password with current password check // @Tags users -// @Param payload body users.UpdatePasswordRequest true "Password payload" +// @Param payload body accounts.UpdatePasswordRequest true "Password payload" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/me/password [put] func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { - userID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - var req users.UpdatePasswordRequest + var req accounts.UpdatePasswordRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.service.UpdatePassword(c.Request().Context(), userID, req.CurrentPassword, req.NewPassword); err != nil { - if errors.Is(err, users.ErrInvalidPassword) { + if err := h.service.UpdatePassword(c.Request().Context(), channelIdentityID, req.CurrentPassword, req.NewPassword); err != nil { + if errors.Is(err, accounts.ErrInvalidPassword) { return echo.NewHTTPError(http.StatusBadRequest, "current password mismatch") } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -143,17 +184,17 @@ func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { // @Summary List users (admin only) // @Description List users // @Tags users -// @Success 200 {object} users.ListUsersResponse +// @Success 200 {object} accounts.ListAccountsResponse // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users [get] func (h *UsersHandler) ListUsers(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -163,11 +204,11 @@ func (h *UsersHandler) ListUsers(c echo.Context) error { if strings.TrimSpace(c.QueryParam("user_type")) != "" || strings.TrimSpace(c.QueryParam("owner_id")) != "" { return echo.NewHTTPError(http.StatusBadRequest, "user_type and owner_id are not supported") } - items, err := h.service.ListUsers(c.Request().Context()) + items, err := h.service.ListAccounts(c.Request().Context()) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.JSON(http.StatusOK, users.ListUsersResponse{Items: items}) + return c.JSON(http.StatusOK, accounts.ListAccountsResponse{Items: items}) } // GetUser godoc @@ -175,14 +216,14 @@ func (h *UsersHandler) ListUsers(c echo.Context) error { // @Description Get user details (self or admin only) // @Tags users // @Param id path string true "User ID" -// @Success 200 {object} users.User +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/{id} [get] func (h *UsersHandler) GetUser(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -190,8 +231,8 @@ func (h *UsersHandler) GetUser(c echo.Context) error { if targetID == "" { return echo.NewHTTPError(http.StatusBadRequest, "user id is required") } - if targetID != actorID { - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + if targetID != channelIdentityID { + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -214,19 +255,19 @@ func (h *UsersHandler) GetUser(c echo.Context) error { // @Description Update user profile and status // @Tags users // @Param id path string true "User ID" -// @Param payload body users.UpdateUserRequest true "User update payload" -// @Success 200 {object} users.User +// @Param payload body accounts.UpdateAccountRequest true "User update payload" +// @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users/{id} [put] func (h *UsersHandler) UpdateUser(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -244,11 +285,11 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - var req users.UpdateUserRequest + var req accounts.UpdateAccountRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.UpdateUserAdmin(c.Request().Context(), targetID, req) + resp, err := h.service.UpdateAdmin(c.Request().Context(), targetID, req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -260,7 +301,7 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { // @Description Reset a user password // @Tags users // @Param id path string true "User ID" -// @Param payload body users.ResetPasswordRequest true "Password payload" +// @Param payload body accounts.ResetPasswordRequest true "Password payload" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse @@ -268,11 +309,11 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /users/{id}/password [put] func (h *UsersHandler) ResetUserPassword(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -289,7 +330,7 @@ func (h *UsersHandler) ResetUserPassword(c echo.Context) error { } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - var req users.ResetPasswordRequest + var req accounts.ResetPasswordRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -303,29 +344,29 @@ func (h *UsersHandler) ResetUserPassword(c echo.Context) error { // @Summary Create human user (admin only) // @Description Create a new human user account // @Tags users -// @Param payload body users.CreateUserRequest true "User payload" -// @Success 201 {object} users.User +// @Param payload body accounts.CreateAccountRequest true "User payload" +// @Success 201 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /users [post] func (h *UsersHandler) CreateUser(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } if !isAdmin { return echo.NewHTTPError(http.StatusForbidden, "admin role required") } - var req users.CreateUserRequest + var req accounts.CreateAccountRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.CreateHuman(c.Request().Context(), req) + resp, err := h.service.CreateHuman(c.Request().Context(), "", req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -343,7 +384,7 @@ func (h *UsersHandler) CreateUser(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots [post] func (h *UsersHandler) CreateBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -351,19 +392,36 @@ func (h *UsersHandler) CreateBot(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - ownerID := actorID + ownerID := channelIdentityID + ownerFromToken := true if raw := strings.TrimSpace(c.QueryParam("owner_id")); raw != "" { - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } if !isAdmin { return echo.NewHTTPError(http.StatusForbidden, "admin role required for owner override") } + if err := identity.ValidateChannelIdentityID(raw); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } ownerID = raw + ownerFromToken = false + } + if ownerFromToken && h.channelIdentityService != nil { + linkedUserID, err := h.channelIdentityService.GetLinkedUserID(c.Request().Context(), ownerID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if strings.TrimSpace(linkedUserID) != "" { + ownerID = linkedUserID + } } resp, err := h.botService.Create(c.Request().Context(), ownerID, req) if err != nil { + if errors.Is(err, bots.ErrOwnerUserNotFound) { + return echo.NewHTTPError(http.StatusBadRequest, "owner user not found") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusCreated, resp) @@ -380,13 +438,13 @@ func (h *UsersHandler) CreateBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots [get] func (h *UsersHandler) ListBots(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } ownerID := strings.TrimSpace(c.QueryParam("owner_id")) if ownerID != "" { - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -399,7 +457,7 @@ func (h *UsersHandler) ListBots(c echo.Context) error { } return c.JSON(http.StatusOK, bots.ListBotsResponse{Items: items}) } - items, err := h.botService.ListAccessible(c.Request().Context(), actorID) + items, err := h.botService.ListAccessible(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -418,7 +476,7 @@ func (h *UsersHandler) ListBots(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id} [get] func (h *UsersHandler) GetBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -426,7 +484,7 @@ func (h *UsersHandler) GetBot(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - bot, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID) + bot, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID) if err != nil { return err } @@ -446,7 +504,7 @@ func (h *UsersHandler) GetBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id} [put] func (h *UsersHandler) UpdateBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -454,7 +512,7 @@ func (h *UsersHandler) UpdateBot(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req bots.UpdateBotRequest @@ -481,11 +539,11 @@ func (h *UsersHandler) UpdateBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/owner [put] func (h *UsersHandler) TransferBotOwner(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } - isAdmin, err := h.service.IsAdmin(c.Request().Context(), actorID) + isAdmin, err := h.service.IsAdmin(c.Request().Context(), channelIdentityID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -505,6 +563,9 @@ func (h *UsersHandler) TransferBotOwner(c echo.Context) error { if errors.Is(err, pgx.ErrNoRows) { return echo.NewHTTPError(http.StatusNotFound, "bot not found") } + if errors.Is(err, bots.ErrOwnerUserNotFound) { + return echo.NewHTTPError(http.StatusBadRequest, "owner user not found") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) @@ -522,7 +583,7 @@ func (h *UsersHandler) TransferBotOwner(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id} [delete] func (h *UsersHandler) DeleteBot(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -530,7 +591,7 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if err := h.botService.Delete(c.Request().Context(), botID); err != nil { @@ -554,7 +615,7 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/members [get] func (h *UsersHandler) ListBotMembers(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -562,7 +623,7 @@ func (h *UsersHandler) ListBotMembers(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } items, err := h.botService.ListMembers(c.Request().Context(), botID) @@ -585,7 +646,7 @@ func (h *UsersHandler) ListBotMembers(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/members [put] func (h *UsersHandler) UpsertBotMember(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -593,7 +654,7 @@ func (h *UsersHandler) UpsertBotMember(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } var req bots.UpsertMemberRequest @@ -624,7 +685,7 @@ func (h *UsersHandler) UpsertBotMember(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/members/{user_id} [delete] func (h *UsersHandler) DeleteBotMember(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -632,14 +693,14 @@ func (h *UsersHandler) DeleteBotMember(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - userID := strings.TrimSpace(c.Param("user_id")) - if userID == "" { + memberUserID := strings.TrimSpace(c.Param("user_id")) + if memberUserID == "" { return echo.NewHTTPError(http.StatusBadRequest, "user id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } - if err := h.botService.DeleteMember(c.Request().Context(), botID, userID); err != nil { + if err := h.botService.DeleteMember(c.Request().Context(), botID, memberUserID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.NoContent(http.StatusNoContent) @@ -658,7 +719,7 @@ func (h *UsersHandler) DeleteBotMember(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/channel/{platform} [get] func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -666,7 +727,7 @@ func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } channelType, err := h.registry.ParseChannelType(c.Param("platform")) @@ -697,7 +758,7 @@ func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/channel/{platform} [put] func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -705,7 +766,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } channelType, err := h.registry.ParseChannelType(c.Param("platform")) @@ -740,7 +801,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{id}/channel/{platform}/send [post] func (h *UsersHandler) SendBotMessage(c echo.Context) error { - actorID, err := h.requireUserID(c) + channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { return err } @@ -748,7 +809,7 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { return err } if h.channelManager == nil { @@ -783,9 +844,9 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { // @Failure 401 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform}/send_session [post] +// @Router /bots/{id}/channel/{platform}/send_chat [post] func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { - sessionToken, err := auth.SessionTokenFromContext(c) + chatToken, err := auth.ChatTokenFromContext(c) if err != nil { return err } @@ -793,16 +854,24 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - channelType, err := h.registry.ParseChannelType(c.Param("platform")) + if chatToken.BotID != botID { + return echo.NewHTTPError(http.StatusForbidden, "token bot mismatch") + } + if h.channelManager == nil || h.chatService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "services not configured") + } + route, err := h.chatService.GetRouteByID(c.Request().Context(), chatToken.RouteID) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "route not found") + } + if strings.TrimSpace(route.ReplyTarget) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "reply target missing in route") + } + channelType, err := h.registry.ParseChannelType(route.Platform) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if sessionToken.BotID != botID || sessionToken.Platform != channelType.String() { - return echo.NewHTTPError(http.StatusForbidden, "session token mismatch") - } - if h.channelManager == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "channel manager not configured") - } + var req channel.SendRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) @@ -810,11 +879,8 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { if req.Message.IsEmpty() { return echo.NewHTTPError(http.StatusBadRequest, "message is required") } - if strings.TrimSpace(sessionToken.ReplyTarget) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "reply target missing") - } if err := h.channelManager.Send(c.Request().Context(), botID, channelType, channel.SendRequest{ - Target: sessionToken.ReplyTarget, + Target: route.ReplyTarget, Message: req.Message, }); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) @@ -822,12 +888,12 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } -func (h *UsersHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { - isAdmin, err := h.service.IsAdmin(ctx, actorID) +func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { + isAdmin, err := h.service.IsAdmin(ctx, channelIdentityID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") @@ -840,13 +906,13 @@ func (h *UsersHandler) authorizeBotAccess(ctx context.Context, actorID, botID st return bot, nil } -func (h *UsersHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) +func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) { + channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateUserID(userID); err != nil { + if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return userID, nil + return channelIdentityID, nil } diff --git a/internal/history/service.go b/internal/history/service.go deleted file mode 100644 index a407557c..00000000 --- a/internal/history/service.go +++ /dev/null @@ -1,237 +0,0 @@ -package history - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "strings" - "time" - - "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - - "github.com/memohai/memoh/internal/db/sqlc" -) - -const defaultListLimit = 50 - -type Service struct { - queries *sqlc.Queries - logger *slog.Logger -} - -func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { - return &Service{ - queries: queries, - logger: log.With(slog.String("service", "history")), - } -} - -func (s *Service) Create(ctx context.Context, botID, sessionID string, req CreateRequest) (Record, error) { - if len(req.Messages) == 0 { - return Record{}, fmt.Errorf("messages are required") - } - botUUID, err := parseUUID(botID) - if err != nil { - return Record{}, err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return Record{}, fmt.Errorf("session id is required") - } - payload, err := json.Marshal(req.Messages) - if err != nil { - return Record{}, err - } - meta := req.Metadata - if meta == nil { - meta = map[string]any{} - } - metaPayload, err := json.Marshal(meta) - if err != nil { - return Record{}, err - } - row, err := s.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{ - BotID: botUUID, - SessionID: trimmedSession, - Messages: payload, - Metadata: metaPayload, - Skills: normalizeSkills(req.Skills), - Timestamp: pgtype.Timestamptz{ - Time: time.Now().UTC(), - Valid: true, - }, - }) - if err != nil { - return Record{}, err - } - return toRecord(row) -} - -func (s *Service) Get(ctx context.Context, id string) (Record, error) { - pgID, err := parseUUID(id) - if err != nil { - return Record{}, err - } - row, err := s.queries.GetHistoryByID(ctx, pgID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return Record{}, fmt.Errorf("history not found") - } - return Record{}, err - } - return toRecord(row) -} - -func (s *Service) List(ctx context.Context, botID, sessionID string, limit int) ([]Record, error) { - botUUID, err := parseUUID(botID) - if err != nil { - return nil, err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return nil, fmt.Errorf("session id is required") - } - if limit <= 0 { - limit = defaultListLimit - } - rows, err := s.queries.ListHistoryByBotSession(ctx, sqlc.ListHistoryByBotSessionParams{ - BotID: botUUID, - SessionID: trimmedSession, - Limit: int32(limit), - }) - if err != nil { - return nil, err - } - items := make([]Record, 0, len(rows)) - for _, row := range rows { - record, err := toRecord(row) - if err != nil { - return nil, err - } - items = append(items, record) - } - return items, nil -} - -func (s *Service) ListBySessionSince(ctx context.Context, botID, sessionID string, since time.Time) ([]Record, error) { - botUUID, err := parseUUID(botID) - if err != nil { - return nil, err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return nil, fmt.Errorf("session id is required") - } - rows, err := s.queries.ListHistoryByBotSessionSince(ctx, sqlc.ListHistoryByBotSessionSinceParams{ - BotID: botUUID, - SessionID: trimmedSession, - Timestamp: pgtype.Timestamptz{ - Time: since, - Valid: true, - }, - }) - if err != nil { - return nil, err - } - items := make([]Record, 0, len(rows)) - for _, row := range rows { - record, err := toRecord(row) - if err != nil { - return nil, err - } - items = append(items, record) - } - return items, nil -} - -func (s *Service) Delete(ctx context.Context, id string) error { - pgID, err := parseUUID(id) - if err != nil { - return err - } - return s.queries.DeleteHistoryByID(ctx, pgID) -} - -func (s *Service) DeleteBySession(ctx context.Context, botID, sessionID string) error { - botUUID, err := parseUUID(botID) - if err != nil { - return err - } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return fmt.Errorf("session id is required") - } - return s.queries.DeleteHistoryByBotSession(ctx, sqlc.DeleteHistoryByBotSessionParams{ - BotID: botUUID, - SessionID: trimmedSession, - }) -} - -func toRecord(row sqlc.History) (Record, error) { - var messages []map[string]any - if len(row.Messages) > 0 { - if err := json.Unmarshal(row.Messages, &messages); err != nil { - return Record{}, err - } - } - var metadata map[string]any - if len(row.Metadata) > 0 { - if err := json.Unmarshal(row.Metadata, &metadata); err != nil { - return Record{}, err - } - } - record := Record{ - Messages: messages, - Metadata: metadata, - Skills: normalizeSkills(row.Skills), - } - if row.Timestamp.Valid { - record.Timestamp = row.Timestamp.Time - } - if row.ID.Valid { - id, err := uuid.FromBytes(row.ID.Bytes[:]) - if err == nil { - record.ID = id.String() - } - } - if row.BotID.Valid { - uid, err := uuid.FromBytes(row.BotID.Bytes[:]) - if err == nil { - record.BotID = uid.String() - } - } - record.SessionID = row.SessionID - return record, nil -} - -func normalizeSkills(skills []string) []string { - seen := map[string]struct{}{} - normalized := make([]string, 0, len(skills)) - for _, skill := range skills { - trimmed := strings.TrimSpace(skill) - if trimmed == "" { - continue - } - if _, ok := seen[trimmed]; ok { - continue - } - seen[trimmed] = struct{}{} - normalized = append(normalized, trimmed) - } - return normalized -} - -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} diff --git a/internal/history/types.go b/internal/history/types.go deleted file mode 100644 index 088d8cfb..00000000 --- a/internal/history/types.go +++ /dev/null @@ -1,23 +0,0 @@ -package history - -import "time" - -type Record struct { - ID string `json:"id"` - Messages []map[string]any `json:"messages"` - Metadata map[string]any `json:"metadata,omitempty"` - Skills []string `json:"skills"` - Timestamp time.Time `json:"timestamp"` - BotID string `json:"bot_id"` - SessionID string `json:"session_id"` -} - -type CreateRequest struct { - Messages []map[string]any `json:"messages"` - Metadata map[string]any `json:"metadata,omitempty"` - Skills []string `json:"skills,omitempty"` -} - -type ListResponse struct { - Items []Record `json:"items"` -} diff --git a/internal/identity/types.go b/internal/identity/types.go index 32f65acb..125e325a 100644 --- a/internal/identity/types.go +++ b/internal/identity/types.go @@ -3,10 +3,11 @@ package identity import "strings" const ( - UserTypeHuman = "human" - UserTypeBot = "bot" + IdentityTypeHuman = "human" + IdentityTypeBot = "bot" ) -func IsBotUserType(userType string) bool { - return strings.EqualFold(strings.TrimSpace(userType), UserTypeBot) +// IsBotIdentityType checks if the identity type is a bot. +func IsBotIdentityType(identityType string) bool { + return strings.EqualFold(strings.TrimSpace(identityType), IdentityTypeBot) } diff --git a/internal/identity/user.go b/internal/identity/user.go index 3f210c43..6e5b9d41 100644 --- a/internal/identity/user.go +++ b/internal/identity/user.go @@ -6,14 +6,14 @@ import ( ctr "github.com/memohai/memoh/internal/containerd" ) -// ValidateUserID enforces a conservative ID charset for isolation. -func ValidateUserID(userID string) error { - if userID == "" { - return fmt.Errorf("%w: user id required", ctr.ErrInvalidArgument) +// ValidateChannelIdentityID enforces a conservative ID charset for isolation. +func ValidateChannelIdentityID(channelIdentityID string) error { + if channelIdentityID == "" { + return fmt.Errorf("%w: channel identity id required", ctr.ErrInvalidArgument) } - for _, r := range userID { + for _, r := range channelIdentityID { if !(r == '-' || r == '_' || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) { - return fmt.Errorf("%w: invalid user id", ctr.ErrInvalidArgument) + return fmt.Errorf("%w: invalid channel identity id", ctr.ErrInvalidArgument) } } return nil diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index 676cd358..f485ac1c 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -9,7 +9,7 @@ import ( func TestInitAndLogging(t *testing.T) { // 测试 JSON 格式 Init("debug", "json") - + if L.Enabled(context.Background(), slog.LevelDebug) != true { t.Error("expected debug level to be enabled") } @@ -20,15 +20,15 @@ func TestInitAndLogging(t *testing.T) { func TestContextLogger(t *testing.T) { Init("info", "text") - + // 创建一个带特定属性的 logger expectedKey := "request_id" expectedValue := "12345" customLogger := L.With(expectedKey, expectedValue) - + ctx := WithContext(context.Background(), customLogger) extracted := FromContext(ctx) - + // 这里简单验证提取出来的是否是同一个(或者功能一致) if extracted == nil { t.Fatal("extracted logger should not be nil") diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index e2a46218..0877de2e 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -281,5 +281,5 @@ func (m *Manager) imageRef() string { } func validateBotID(botID string) error { - return identity.ValidateUserID(botID) + return identity.ValidateChannelIdentityID(botID) } diff --git a/internal/memory/service.go b/internal/memory/service.go index 343ee1b3..3ff076e2 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -409,12 +409,12 @@ func (s *Service) Get(ctx context.Context, memoryID string) (MemoryItem, error) func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error) { filters := map[string]any{} + for k, v := range req.Filters { + filters[k] = v + } if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -445,12 +445,12 @@ func (s *Service) Delete(ctx context.Context, memoryID string) (DeleteResponse, func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error) { filters := map[string]any{} + for k, v := range req.Filters { + filters[k] = v + } if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -756,9 +756,6 @@ func buildFilters(req AddRequest) map[string]any { if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -773,9 +770,6 @@ func buildSearchFilters(req SearchRequest) map[string]any { if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -790,9 +784,6 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]any { if req.BotID != "" { filters["botId"] = req.BotID } - if req.SessionID != "" { - filters["sessionId"] = req.SessionID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -883,9 +874,6 @@ func payloadToMemoryItem(id string, payload map[string]any) MemoryItem { if v, ok := payload["botId"].(string); ok { item.BotID = v } - if v, ok := payload["sessionId"].(string); ok { - item.SessionID = v - } if v, ok := payload["runId"].(string); ok { item.RunID = v } diff --git a/internal/memory/types.go b/internal/memory/types.go index 22299457..606f102a 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -15,28 +15,26 @@ type Message struct { } type AddRequest struct { - Message string `json:"message,omitempty"` - Messages []Message `json:"messages,omitempty"` - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` + Message string `json:"message,omitempty"` + Messages []Message `json:"messages,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` Filters map[string]any `json:"filters,omitempty"` - Infer *bool `json:"infer,omitempty"` - EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` + Infer *bool `json:"infer,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type SearchRequest struct { - Query string `json:"query"` - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` + Query string `json:"query"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` Filters map[string]any `json:"filters,omitempty"` - Sources []string `json:"sources,omitempty"` - EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` + Sources []string `json:"sources,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type UpdateRequest struct { @@ -46,18 +44,18 @@ type UpdateRequest struct { } type GetAllRequest struct { - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type DeleteAllRequest struct { - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type EmbedInput struct { @@ -67,17 +65,16 @@ type EmbedInput struct { } type EmbedUpsertRequest struct { - Type string `json:"type"` - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - Input EmbedInput `json:"input"` - Source string `json:"source,omitempty"` - BotID string `json:"bot_id,omitempty"` - SessionID string `json:"session_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - Filters map[string]any `json:"filters,omitempty"` + Type string `json:"type"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Input EmbedInput `json:"input"` + Source string `json:"source,omitempty"` + BotID string `json:"bot_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type EmbedUpsertResponse struct { @@ -88,17 +85,16 @@ type EmbedUpsertResponse struct { } type MemoryItem struct { - ID string `json:"id"` - Memory string `json:"memory"` - Hash string `json:"hash,omitempty"` - CreatedAt string `json:"createdAt,omitempty"` - UpdatedAt string `json:"updatedAt,omitempty"` - Score float64 `json:"score,omitempty"` + ID string `json:"id"` + Memory string `json:"memory"` + Hash string `json:"hash,omitempty"` + CreatedAt string `json:"createdAt,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` + Score float64 `json:"score,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` - BotID string `json:"botId,omitempty"` - SessionID string `json:"sessionId,omitempty"` - AgentID string `json:"agentId,omitempty"` - RunID string `json:"runId,omitempty"` + BotID string `json:"botId,omitempty"` + AgentID string `json:"agentId,omitempty"` + RunID string `json:"runId,omitempty"` } type SearchResponse struct { @@ -111,7 +107,7 @@ type DeleteResponse struct { } type ExtractRequest struct { - Messages []Message `json:"messages"` + Messages []Message `json:"messages"` Filters map[string]any `json:"filters,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } @@ -121,16 +117,16 @@ type ExtractResponse struct { } type CandidateMemory struct { - ID string `json:"id"` - Memory string `json:"memory"` + ID string `json:"id"` + Memory string `json:"memory"` Metadata map[string]any `json:"metadata,omitempty"` } type DecideRequest struct { - Facts []string `json:"facts"` - Candidates []CandidateMemory `json:"candidates"` - Filters map[string]any `json:"filters,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Facts []string `json:"facts"` + Candidates []CandidateMemory `json:"candidates"` + Filters map[string]any `json:"filters,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type DecisionAction struct { diff --git a/internal/models/models_test.go b/internal/models/models_test.go index 42573164..35fe4b31 100644 --- a/internal/models/models_test.go +++ b/internal/models/models_test.go @@ -13,7 +13,7 @@ import ( func ExampleService_Create() { // Example usage - in real code, you would initialize with actual database connection // service := models.NewService(queries) - + // ctx := context.Background() // req := models.AddRequest{ // ModelID: "gpt-4", @@ -21,7 +21,7 @@ func ExampleService_Create() { // LlmProviderID: "11111111-1111-1111-1111-111111111111", // Type: models.ModelTypeChat, // } - + // resp, err := service.Create(ctx, req) // if err != nil { // // handle error @@ -32,7 +32,7 @@ func ExampleService_Create() { func ExampleService_GetByModelID() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // resp, err := service.GetByModelID(ctx, "gpt-4") // if err != nil { @@ -44,7 +44,7 @@ func ExampleService_GetByModelID() { func ExampleService_List() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // models, err := service.List(ctx) // if err != nil { @@ -58,7 +58,7 @@ func ExampleService_List() { func ExampleService_ListByType() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // chatModels, err := service.ListByType(ctx, models.ModelTypeChat) // if err != nil { @@ -70,7 +70,7 @@ func ExampleService_ListByType() { func ExampleService_UpdateByModelID() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // req := models.UpdateRequest{ // ModelID: "gpt-4", @@ -78,7 +78,7 @@ func ExampleService_UpdateByModelID() { // LlmProviderID: "11111111-1111-1111-1111-111111111111", // Type: models.ModelTypeChat, // } - + // resp, err := service.UpdateByModelID(ctx, "gpt-4", req) // if err != nil { // // handle error @@ -89,7 +89,7 @@ func ExampleService_UpdateByModelID() { func ExampleService_DeleteByModelID() { // Example usage // service := models.NewService(queries) - + // ctx := context.Background() // err := service.DeleteByModelID(ctx, "gpt-4") // if err != nil { @@ -208,7 +208,7 @@ func TestModelTypes(t *testing.T) { // } // // ctx := context.Background() -// +// // // Setup database connection // pool, err := db.Open(ctx, config.PostgresConfig{ // Host: "localhost", @@ -271,4 +271,3 @@ func TestModelTypes(t *testing.T) { // err = service.DeleteByModelID(ctx, "test-gpt-4") // require.NoError(t, err) // } - diff --git a/internal/policy/service.go b/internal/policy/service.go index 518e21bb..2c476d1e 100644 --- a/internal/policy/service.go +++ b/internal/policy/service.go @@ -33,6 +33,7 @@ func NewService(log *slog.Logger, botsService *bots.Service, settingsService *se } } +// Resolve evaluates the full access policy for a bot. func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) { if s == nil || s.bots == nil || s.settings == nil { return Decision{}, fmt.Errorf("policy service not configured") @@ -59,3 +60,33 @@ func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) { } return decision, nil } + +// AllowGuest checks if the bot allows guest access. Implements router.PolicyService. +func (s *Service) AllowGuest(ctx context.Context, botID string) (bool, error) { + decision, err := s.Resolve(ctx, botID) + if err != nil { + return false, err + } + return decision.AllowGuest, nil +} + +// BotType returns the normalized bot type. Implements router.PolicyService. +func (s *Service) BotType(ctx context.Context, botID string) (string, error) { + decision, err := s.Resolve(ctx, botID) + if err != nil { + return "", err + } + return decision.BotType, nil +} + +// BotOwnerUserID returns bot owner's user id. Implements router.PolicyService. +func (s *Service) BotOwnerUserID(ctx context.Context, botID string) (string, error) { + if s == nil || s.bots == nil { + return "", fmt.Errorf("policy service not configured") + } + bot, err := s.bots.Get(ctx, strings.TrimSpace(botID)) + if err != nil { + return "", err + } + return strings.TrimSpace(bot.OwnerUserID), nil +} diff --git a/internal/preauth/service.go b/internal/preauth/service.go index 7aa8f5b9..e267521b 100644 --- a/internal/preauth/service.go +++ b/internal/preauth/service.go @@ -24,6 +24,7 @@ func NewService(queries *sqlc.Queries) *Service { return &Service{queries: queries} } +// Issue creates a new preauth key for the given bot. func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl time.Duration) (Key, error) { if s.queries == nil { return Key{}, fmt.Errorf("preauth queries not configured") @@ -88,13 +89,13 @@ func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { func normalizeKey(row sqlc.BotPreauthKey) Key { return Key{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), - Token: strings.TrimSpace(row.Token), - IssuedByUserID: toUUIDString(row.IssuedByUserID), - ExpiresAt: timeFromPg(row.ExpiresAt), - UsedAt: timeFromPg(row.UsedAt), - CreatedAt: timeFromPg(row.CreatedAt), + ID: toUUIDString(row.ID), + BotID: toUUIDString(row.BotID), + Token: strings.TrimSpace(row.Token), + IssuedByChannelIdentityID: toUUIDString(row.IssuedByUserID), + ExpiresAt: timeFromPg(row.ExpiresAt), + UsedAt: timeFromPg(row.UsedAt), + CreatedAt: timeFromPg(row.CreatedAt), } } diff --git a/internal/preauth/types.go b/internal/preauth/types.go index cd26b086..2527f38d 100644 --- a/internal/preauth/types.go +++ b/internal/preauth/types.go @@ -2,12 +2,13 @@ package preauth import "time" +// Key represents a bot pre-authorization key. type Key struct { - ID string - BotID string - Token string - IssuedByUserID string - ExpiresAt time.Time - UsedAt time.Time - CreatedAt time.Time + ID string + BotID string + Token string + IssuedByChannelIdentityID string + ExpiresAt time.Time + UsedAt time.Time + CreatedAt time.Time } diff --git a/internal/router/channel.go b/internal/router/channel.go index 3fadd484..e86474ed 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -13,23 +13,13 @@ import ( "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/chat" - "github.com/memohai/memoh/internal/contacts" ) -// ChatGateway 抽象聊天能力,避免路由层直接依赖具体实现。 +// ChatGateway abstracts the chat capability to avoid direct coupling in the router. type ChatGateway interface { Chat(ctx context.Context, req chat.ChatRequest) (chat.ChatResponse, error) } -type ContactService interface { - GetByID(ctx context.Context, contactID string) (contacts.Contact, error) - GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error) - GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) - Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) - CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) - UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) -} - const ( silentReplyToken = "NO_REPLY" minDuplicateTextLength = 10 @@ -39,34 +29,56 @@ var ( whitespacePattern = regexp.MustCompile(`\s+`) ) -// ChannelInboundProcessor 将 channel 入站消息路由到 chat,并返回可发送的回复。 -type ChannelInboundProcessor struct { - chat ChatGateway - registry *channel.Registry - logger *slog.Logger - jwtSecret string - tokenTTL time.Duration - identity *IdentityResolver +// ChatService resolves and manages chats. +type ChatService interface { + ResolveChat(ctx context.Context, botID, platform, conversationID, threadID, conversationType, userID, channelConfigID, replyTarget string) (chat.ResolveChatResult, error) + PersistMessage(ctx context.Context, chatID, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, role string, content json.RawMessage, metadata map[string]any) (chat.Message, error) } -func NewChannelInboundProcessor(log *slog.Logger, registry *channel.Registry, store channel.ConfigStore, chatGateway ChatGateway, contactService ContactService, policyService PolicyService, preauthService PreauthService, jwtSecret string, tokenTTL time.Duration) *ChannelInboundProcessor { +// ChannelInboundProcessor routes channel inbound messages to the chat gateway. +type ChannelInboundProcessor struct { + chat ChatGateway + chatService ChatService + registry *channel.Registry + logger *slog.Logger + jwtSecret string + tokenTTL time.Duration + identity *IdentityResolver +} + +// NewChannelInboundProcessor creates a processor with channel identity-based resolution. +func NewChannelInboundProcessor( + log *slog.Logger, + registry *channel.Registry, + chatService ChatService, + chatGateway ChatGateway, + channelIdentityService ChannelIdentityService, + memberService BotMemberService, + policyService PolicyService, + preauthService PreauthService, + bindService BindService, + jwtSecret string, + tokenTTL time.Duration, +) *ChannelInboundProcessor { if log == nil { log = slog.Default() } if tokenTTL <= 0 { tokenTTL = 5 * time.Minute } - identityResolver := NewIdentityResolver(log, registry, store, contactService, policyService, preauthService, "", "") + identityResolver := NewIdentityResolver(log, registry, channelIdentityService, memberService, policyService, preauthService, bindService, "", "") return &ChannelInboundProcessor{ - chat: chatGateway, - registry: registry, - logger: log.With(slog.String("component", "channel_router")), - jwtSecret: strings.TrimSpace(jwtSecret), - tokenTTL: tokenTTL, - identity: identityResolver, + chat: chatGateway, + chatService: chatService, + registry: registry, + logger: log.With(slog.String("component", "channel_router")), + jwtSecret: strings.TrimSpace(jwtSecret), + tokenTTL: tokenTTL, + identity: identityResolver, } } +// IdentityMiddleware returns the identity resolution middleware. func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware { if p == nil || p.identity == nil { return nil @@ -74,6 +86,7 @@ func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware { return p.identity.Middleware() } +// HandleInbound processes an inbound channel message through identity resolution and chat gateway. func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.ReplySender) error { if p.chat == nil { return fmt.Errorf("channel inbound processor not configured") @@ -101,24 +114,42 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel identity := state.Identity - sessionToken := "" + // Resolve or create the chat via chat_routes. + if p.chatService == nil { + return fmt.Errorf("chat service not configured") + } + resolved, err := p.chatService.ResolveChat(ctx, identity.BotID, + msg.Channel.String(), msg.Conversation.ID, extractThreadID(msg), + msg.Conversation.Type, identity.UserID, identity.ChannelConfigID, + strings.TrimSpace(msg.ReplyTarget)) + if err != nil { + return fmt.Errorf("resolve chat: %w", err) + } + if !shouldTriggerAssistantResponse(msg) && !identity.ForceReply { + p.persistInboundOnly(ctx, resolved, identity, msg, text) + return nil + } + + // Issue chat token for reply routing. + chatToken := "" if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTarget) != "" { - signed, _, err := auth.GenerateSessionToken(auth.SessionToken{ - BotID: identity.BotID, - Platform: msg.Channel.String(), - ReplyTarget: strings.TrimSpace(msg.ReplyTarget), - SessionID: identity.SessionID, - ContactID: identity.ContactID, + signed, _, err := auth.GenerateChatToken(auth.ChatToken{ + BotID: identity.BotID, + ChatID: resolved.ChatID, + RouteID: resolved.RouteID, + UserID: identity.UserID, + ChannelIdentityID: identity.ChannelIdentityID, }, p.jwtSecret, p.tokenTTL) if err != nil { if p.logger != nil { - p.logger.Warn("issue session token failed", slog.Any("error", err)) + p.logger.Warn("issue chat token failed", slog.Any("error", err)) } } else { - sessionToken = signed + chatToken = signed } } + // Issue user JWT for downstream calls. token := "" if identity.UserID != "" && p.jwtSecret != "" { signed, _, err := auth.GenerateToken(identity.UserID, p.jwtSecret, p.tokenTTL) @@ -130,27 +161,33 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel token = "Bearer " + signed } } + var desc channel.Descriptor if p.registry != nil { desc, _ = p.registry.GetDescriptor(msg.Channel) } resp, err := p.chat.Chat(ctx, chat.ChatRequest{ - BotID: identity.BotID, - SessionID: identity.SessionID, - Token: token, - UserID: identity.UserID, - ContactID: identity.ContactID, - ContactName: strings.TrimSpace(identity.Contact.DisplayName), - ContactAlias: strings.TrimSpace(identity.Contact.Alias), - ReplyTarget: strings.TrimSpace(msg.ReplyTarget), - SessionToken: sessionToken, - Query: text, - CurrentChannel: msg.Channel.String(), - Channels: []string{msg.Channel.String()}, + BotID: identity.BotID, + ChatID: resolved.ChatID, + Token: token, + ChannelIdentityID: identity.UserID, + DisplayName: identity.DisplayName, + RouteID: resolved.RouteID, + ChatToken: chatToken, + ExternalMessageID: strings.TrimSpace(msg.Message.ID), + Query: text, + CurrentChannel: msg.Channel.String(), + Channels: []string{msg.Channel.String()}, }) if err != nil { if p.logger != nil { - p.logger.Error("chat gateway failed", slog.String("channel", msg.Channel.String()), slog.String("user_id", identity.UserID), slog.Any("error", err)) + p.logger.Error( + "chat gateway failed", + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", identity.ChannelIdentityID), + slog.String("user_id", identity.UserID), + slog.Any("error", err), + ) } return err } @@ -188,6 +225,141 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel return nil } +func shouldTriggerAssistantResponse(msg channel.InboundMessage) bool { + if isDirectConversationType(msg.Conversation.Type) { + return true + } + if metadataBool(msg.Metadata, "is_mentioned") { + return true + } + if metadataBool(msg.Metadata, "is_reply_to_bot") { + return true + } + return hasCommandPrefix(msg.Message.PlainText(), msg.Metadata) +} + +func isDirectConversationType(conversationType string) bool { + ct := strings.ToLower(strings.TrimSpace(conversationType)) + return ct == "" || ct == "p2p" || ct == "private" || ct == "direct" +} + +func hasCommandPrefix(text string, metadata map[string]any) bool { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return false + } + prefixes := []string{"/"} + if metadata != nil { + if raw, ok := metadata["command_prefix"]; ok { + if value := strings.TrimSpace(fmt.Sprint(raw)); value != "" { + prefixes = []string{value} + } + } + if raw, ok := metadata["command_prefixes"]; ok { + if parsed := parseCommandPrefixes(raw); len(parsed) > 0 { + prefixes = parsed + } + } + } + for _, prefix := range prefixes { + if strings.HasPrefix(trimmed, prefix) { + return true + } + } + return false +} + +func parseCommandPrefixes(raw any) []string { + if items, ok := raw.([]string); ok { + result := make([]string, 0, len(items)) + for _, item := range items { + value := strings.TrimSpace(item) + if value == "" { + continue + } + result = append(result, value) + } + return result + } + items, ok := raw.([]any) + if !ok { + return nil + } + result := make([]string, 0, len(items)) + for _, item := range items { + value := strings.TrimSpace(fmt.Sprint(item)) + if value == "" { + continue + } + result = append(result, value) + } + return result +} + +func metadataBool(metadata map[string]any, key string) bool { + if metadata == nil { + return false + } + raw, ok := metadata[key] + if !ok { + return false + } + switch value := raw.(type) { + case bool: + return value + case string: + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "true", "yes", "on": + return true + default: + return false + } + default: + return false + } +} + +func (p *ChannelInboundProcessor) persistInboundOnly(ctx context.Context, resolved chat.ResolveChatResult, identity InboundIdentity, msg channel.InboundMessage, query string) { + if p.chatService == nil { + return + } + chatID := strings.TrimSpace(resolved.ChatID) + botID := strings.TrimSpace(identity.BotID) + if chatID == "" || botID == "" { + return + } + payload, err := json.Marshal(chat.ModelMessage{ + Role: "user", + Content: chat.NewTextContent(query), + }) + if err != nil { + if p.logger != nil { + p.logger.Warn("marshal passive inbound failed", slog.Any("error", err)) + } + return + } + meta := map[string]any{ + "route_id": resolved.RouteID, + "platform": msg.Channel.String(), + "trigger_mode": "passive_sync", + } + if _, err := p.chatService.PersistMessage( + ctx, + chatID, + botID, + strings.TrimSpace(resolved.RouteID), + strings.TrimSpace(identity.ChannelIdentityID), + strings.TrimSpace(identity.UserID), + msg.Channel.String(), + strings.TrimSpace(msg.Message.ID), + "user", + payload, + meta, + ); err != nil && p.logger != nil { + p.logger.Warn("persist passive inbound failed", slog.Any("error", err)) + } +} + func buildChannelMessage(output chat.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message { msg := channel.Message{} if strings.TrimSpace(output.Content) != "" { @@ -207,13 +379,13 @@ func buildChannelMessage(output chat.AssistantOutput, capabilities channel.Chann } partType := normalizeContentPartType(part.Type) parts = append(parts, channel.MessagePart{ - Type: partType, - Text: part.Text, - URL: part.URL, - Styles: normalizeContentPartStyles(part.Styles), - Language: part.Language, - UserID: part.UserID, - Emoji: part.Emoji, + Type: partType, + Text: part.Text, + URL: part.URL, + Styles: normalizeContentPartStyles(part.Styles), + Language: part.Language, + ChannelIdentityID: part.ChannelIdentityID, + Emoji: part.Emoji, }) } if len(parts) > 0 { @@ -350,11 +522,11 @@ func normalizeContentPartStyles(styles []string) []channel.MessageTextStyle { } type sendMessageToolArgs struct { - Platform string `json:"platform"` - Target string `json:"target"` - UserID string `json:"user_id"` - Text string `json:"text"` - Message *channel.Message `json:"message"` + Platform string `json:"platform"` + Target string `json:"target"` + ChannelIdentityID string `json:"channel_identity_id"` + Text string `json:"text"` + Message *channel.Message `json:"message"` } func collectMessageToolContext(registry *channel.Registry, messages []chat.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { @@ -419,7 +591,7 @@ func shouldSuppressForToolCall(registry *channel.Registry, args sendMessageToolA return false } target := strings.TrimSpace(args.Target) - if target == "" && strings.TrimSpace(args.UserID) == "" { + if target == "" && strings.TrimSpace(args.ChannelIdentityID) == "" { target = replyTarget } if strings.TrimSpace(target) == "" || strings.TrimSpace(replyTarget) == "" { diff --git a/internal/router/channel_test.go b/internal/router/channel_test.go index 1dc061d6..6d8c130b 100644 --- a/internal/router/channel_test.go +++ b/internal/router/channel_test.go @@ -2,60 +2,16 @@ package router import ( "context" - "fmt" + "encoding/json" "log/slog" "strings" "testing" "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channelidentities" "github.com/memohai/memoh/internal/chat" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/policy" ) -type fakeConfigStore struct { - session channel.ChannelSession - boundUserID string -} - -func (f *fakeConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) { - return channel.ChannelConfig{}, nil -} - -func (f *fakeConfigStore) GetUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, fmt.Errorf("not implemented") -} - -func (f *fakeConfigStore) UpsertUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType, req channel.UpsertUserConfigRequest) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, nil -} - -func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType channel.ChannelType) ([]channel.ChannelConfig, error) { - return nil, nil -} - -func (f *fakeConfigStore) ResolveUserBinding(ctx context.Context, channelType channel.ChannelType, criteria channel.BindingCriteria) (string, error) { - if f.boundUserID == "" { - return "", fmt.Errorf("channel user binding not found") - } - return f.boundUserID, nil -} - -func (f *fakeConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]channel.ChannelSession, error) { - return nil, nil -} - -func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID string) (channel.ChannelSession, error) { - if f.session.SessionID == sessionID { - return f.session, nil - } - return channel.ChannelSession{}, nil -} - -func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { - return nil -} - type fakeChatGateway struct { resp chat.ChatResponse err error @@ -67,50 +23,6 @@ func (f *fakeChatGateway) Chat(ctx context.Context, req chat.ChatRequest) (chat. return f.resp, f.err } -type fakeContactService struct { - contactID string -} - -func (f *fakeContactService) GetByID(ctx context.Context, contactID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") -} - -func (f *fakeContactService) GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") -} - -func (f *fakeContactService) GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) { - return contacts.ContactChannel{}, fmt.Errorf("not found") -} - -func (f *fakeContactService) Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) { - return contacts.Contact{ID: "contact-1", BotID: req.BotID, UserID: req.UserID}, nil -} - -func (f *fakeContactService) CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) { - return contacts.Contact{ID: "contact-guest", BotID: botID}, nil -} - -func (f *fakeContactService) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) { - return contacts.ContactChannel{ID: "channel-1", ContactID: contactID}, nil -} - -type fakePolicyService struct { - decision policy.Decision - err error -} - -func (f *fakePolicyService) Resolve(ctx context.Context, botID string) (policy.Decision, error) { - if f.err != nil { - return policy.Decision{}, f.err - } - decision := f.decision - if decision.BotID == "" { - decision.BotID = botID - } - return decision, nil -} - type fakeReplySender struct { sent []channel.OutboundMessage } @@ -120,28 +32,58 @@ func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) return nil } -func TestChannelInboundProcessorBoundUser(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", - }, +type fakeChatService struct { + resolveResult chat.ResolveChatResult + resolveErr error + persisted []chat.Message +} + +func (f *fakeChatService) ResolveChat(ctx context.Context, botID, platform, conversationID, threadID, conversationType, userID, channelConfigID, replyTarget string) (chat.ResolveChatResult, error) { + if f.resolveErr != nil { + return chat.ResolveChatResult{}, f.resolveErr } + return f.resolveResult, nil +} + +func (f *fakeChatService) PersistMessage(ctx context.Context, chatID, botID, routeID, senderChannelIdentityID, senderUserID, platform, externalMessageID, role string, content json.RawMessage, metadata map[string]any) (chat.Message, error) { + msg := chat.Message{ + ChatID: chatID, + BotID: botID, + RouteID: routeID, + SenderChannelIdentityID: senderChannelIdentityID, + SenderUserID: senderUserID, + Platform: platform, + ExternalMessageID: externalMessageID, + Role: role, + Content: content, + Metadata: metadata, + } + f.persisted = append(f.persisted, msg) + return msg, nil +} + +func TestChannelInboundProcessorWithIdentity(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false} + chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-1", RouteID: "route-1"}} gateway := &fakeChatGateway{ resp: chat.ChatResponse{ Messages: []chat.ModelMessage{ - {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, + {Role: "assistant", Content: chat.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, + Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-1", DisplayName: "User1"}, Conversation: channel.Conversation{ ID: "chat-1", Type: "p2p", @@ -150,48 +92,59 @@ func TestChannelInboundProcessorBoundUser(t *testing.T) { err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if gateway.gotReq.Query != "你好" { - t.Errorf("Chat 请求 Query 错误: %s", gateway.gotReq.Query) + if gateway.gotReq.Query != "hello" { + t.Errorf("expected query 'hello', got: %s", gateway.gotReq.Query) } - if gateway.gotReq.SessionID != "feishu:bot-1:chat-1" { - t.Errorf("SessionID 传递错误: %s", gateway.gotReq.SessionID) + if gateway.gotReq.ChannelIdentityID != "channelIdentity-1" { + t.Errorf("expected channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.ChannelIdentityID) } - if len(sender.sent) != 1 || sender.sent[0].Message.PlainText() != "AI回复内容" { - t.Fatalf("应发送 AI 回复,实际: %+v", sender.sent) + if gateway.gotReq.ChatID != "chat-1" { + t.Errorf("expected chat_id 'chat-1', got: %s", gateway.gotReq.ChatID) + } + if len(sender.sent) != 1 || sender.sent[0].Message.PlainText() != "AI reply" { + t.Fatalf("expected AI reply, got: %+v", sender.sent) } } -func TestChannelInboundProcessorUnboundUser(t *testing.T) { - store := &fakeConfigStore{} +func TestChannelInboundProcessorDenied(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-2"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false} + chatSvc := &fakeChatService{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, + Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "stranger-1"}, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(sender.sent) != 1 || !strings.Contains(sender.sent[0].Message.PlainText(), "陌生人") { - t.Fatalf("应发送绑定提示,实际: %+v", sender.sent) + if len(sender.sent) != 1 || !strings.Contains(sender.sent[0].Message.PlainText(), "denied") { + t.Fatalf("expected access denied reply, got: %+v", sender.sent) } if gateway.gotReq.Query != "" { - t.Error("未绑定用户不应触发 Chat 调用") + t.Error("denied user should not trigger chat call") } } func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { - store := &fakeConfigStore{} + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-3"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false} + chatSvc := &fakeChatService{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1"} @@ -199,23 +152,20 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("空消息不应报错: %v", err) + t.Fatalf("empty message should not error: %v", err) } if len(sender.sent) != 0 { - t.Fatalf("空消息不应发送回复: %+v", sender.sent) + t.Fatalf("empty message should not produce reply: %+v", sender.sent) } if gateway.gotReq.Query != "" { - t.Error("空消息不应触发 Chat 调用") + t.Error("empty message should not trigger chat call") } } func TestChannelInboundProcessorSilentReply(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", - }, - } + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-4"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-4", RouteID: "route-4"}} gateway := &fakeChatGateway{ resp: chat.ChatResponse{ Messages: []chat.ModelMessage{ @@ -223,123 +173,200 @@ func TestChannelInboundProcessorSilentReply(t *testing.T) { }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ - Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, - ReplyTarget: "target-id", + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "test"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "user-1"}, Conversation: channel.Conversation{ - ID: "chat-1", + ID: "conv-1", Type: "p2p", }, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } if len(sender.sent) != 0 { - t.Fatalf("NO_REPLY 不应发送回复,实际: %+v", sender.sent) + t.Fatalf("NO_REPLY should suppress output: %+v", sender.sent) } } -func TestChannelInboundProcessorSuppressOnToolSend(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", - }, - } +func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-5"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-5", RouteID: "route-5"}} gateway := &fakeChatGateway{ resp: chat.ChatResponse{ Messages: []chat.ModelMessage{ - { - Role: "assistant", - ToolCalls: []chat.ToolCall{ - { - Type: "function", - Function: chat.ToolCallFunction{ - Name: "send_message", - Arguments: `{"platform":"feishu","target":"target-id","message":{"text":"AI回复内容"}}`, - }, - }, - }, - }, - {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, + {Role: "assistant", Content: chat.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, - ReplyTarget: "target-id", + Message: channel.Message{ID: "msg-1", Text: "hello everyone"}, + ReplyTarget: "chat_id:oc_123", + Sender: channel.Identity{SubjectID: "user-1"}, Conversation: channel.Conversation{ - ID: "chat-1", - Type: "p2p", + ID: "oc_123", + Type: "group", }, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) + } + if gateway.gotReq.Query != "" { + t.Fatalf("group passive sync should not trigger chat call") } if len(sender.sent) != 0 { - t.Fatalf("工具已发送当前会话消息,应抑制普通回复,实际: %+v", sender.sent) + t.Fatalf("group passive sync should not send reply: %+v", sender.sent) + } + if len(chatSvc.persisted) != 1 { + t.Fatalf("expected 1 passive persisted message, got: %d", len(chatSvc.persisted)) + } + if chatSvc.persisted[0].Role != "user" { + t.Fatalf("expected persisted role user, got: %s", chatSvc.persisted[0].Role) } } -func TestChannelInboundProcessorDedupeWithToolSend(t *testing.T) { - store := &fakeConfigStore{ - session: channel.ChannelSession{ - SessionID: "feishu:bot-1:chat-1", - UserID: "user-123", - }, - } +func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-6"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-6", RouteID: "route-6"}} gateway := &fakeChatGateway{ resp: chat.ChatResponse{ Messages: []chat.ModelMessage{ - { - Role: "assistant", - ToolCalls: []chat.ToolCall{ - { - Type: "function", - Function: chat.ToolCallFunction{ - Name: "send_message", - Arguments: `{"platform":"feishu","target":"other-target","message":{"text":"AI回复内容"}}`, - }, - }, - }, - }, - {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, + {Role: "assistant", Content: chat.NewTextContent("AI reply")}, }, }, } - processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ + BotID: "bot-1", Channel: channel.ChannelType("feishu"), - Message: channel.Message{Text: "你好"}, - ReplyTarget: "target-id", + Message: channel.Message{ID: "msg-2", Text: "@bot ping"}, + ReplyTarget: "chat_id:oc_123", + Sender: channel.Identity{SubjectID: "user-1"}, Conversation: channel.Conversation{ - ID: "chat-1", - Type: "p2p", + ID: "oc_123", + Type: "group", + }, + Metadata: map[string]any{ + "is_mentioned": true, }, } err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(sender.sent) != 0 { - t.Fatalf("工具发送文本与普通回复重复,应去重,实际: %+v", sender.sent) + if gateway.gotReq.Query == "" { + t.Fatalf("group mention should trigger chat call") + } + if len(sender.sent) != 1 { + t.Fatalf("expected one outbound reply, got %d", len(sender.sent)) + } + if len(chatSvc.persisted) != 0 { + t.Fatalf("triggered group message should not use passive persistence") + } +} + +func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-member"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-personal-1", RouteID: "route-personal-1"}} + gateway := &fakeChatGateway{ + resp: chat.ChatResponse{ + Messages: []chat.ModelMessage{ + {Role: "assistant", Content: chat.NewTextContent("AI reply")}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "msg-personal-1", Text: "hello"}, + ReplyTarget: "chat_id:oc_personal", + Sender: channel.Identity{SubjectID: "ext-member-1"}, + Conversation: channel.Conversation{ + ID: "oc_personal", + Type: "group", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gateway.gotReq.Query != "" { + t.Fatalf("non-owner should not trigger chat call") + } + if len(sender.sent) != 0 { + t.Fatalf("non-owner should be ignored silently: %+v", sender.sent) + } + if len(chatSvc.persisted) != 0 { + t.Fatalf("ignored message should not persist in passive mode") + } +} + +func TestChannelInboundProcessorPersonalGroupOwnerForceReply(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-owner"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + chatSvc := &fakeChatService{resolveResult: chat.ResolveChatResult{ChatID: "chat-personal-2", RouteID: "route-personal-2"}} + gateway := &fakeChatGateway{ + resp: chat.ChatResponse{ + Messages: []chat.ModelMessage{ + {Role: "assistant", Content: chat.NewTextContent("AI reply")}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{ID: "msg-personal-2", Text: "owner says hi"}, + ReplyTarget: "chat_id:oc_personal", + Sender: channel.Identity{SubjectID: "ext-owner-1"}, + Conversation: channel.Conversation{ + ID: "oc_personal", + Type: "group", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gateway.gotReq.Query == "" { + t.Fatalf("owner should trigger chat call in personal group") + } + if len(sender.sent) != 1 { + t.Fatalf("expected one owner reply, got %d", len(sender.sent)) } } diff --git a/internal/router/identity.go b/internal/router/identity.go index 863a7036..f9e3fcca 100644 --- a/internal/router/identity.go +++ b/internal/router/identity.go @@ -8,27 +8,30 @@ import ( "strings" "time" + "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/channelidentities" "github.com/memohai/memoh/internal/preauth" ) +// IdentityDecision indicates whether the inbound message should be stopped with an optional reply. type IdentityDecision struct { Stop bool Reply channel.Message } +// InboundIdentity carries the resolved channel identity for an inbound message. type InboundIdentity struct { - BotID string - SessionID string - ChannelConfigID string - ExternalID string - UserID string - ContactID string - Contact contacts.Contact + BotID string + ChannelConfigID string + SubjectID string + ChannelIdentityID string + UserID string + DisplayName string + ForceReply bool } +// IdentityState bundles resolved identity with an optional early-exit decision. type IdentityState struct { Identity InboundIdentity Decision *IdentityDecision @@ -36,10 +39,12 @@ type IdentityState struct { type identityContextKey struct{} +// WithIdentityState stores IdentityState in the context. func WithIdentityState(ctx context.Context, state IdentityState) context.Context { return context.WithValue(ctx, identityContextKey{}, state) } +// IdentityStateFromContext retrieves IdentityState from the context. func IdentityStateFromContext(ctx context.Context) (IdentityState, bool) { if ctx == nil { return IdentityState{}, false @@ -52,54 +57,88 @@ func IdentityStateFromContext(ctx context.Context) (IdentityState, bool) { return state, ok } -// IdentityStore is the minimal persistence interface required by IdentityResolver. -type IdentityStore interface { - channel.BindingStore - channel.SessionStore +// ChannelIdentityService is the minimal interface for channel identity resolution. +type ChannelIdentityService interface { + ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string) (channelidentities.ChannelIdentity, error) + Canonicalize(ctx context.Context, channelIdentityID string) (string, error) + GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) + LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error } -type IdentityResolver struct { - registry *channel.Registry - store IdentityStore - contacts ContactService - policy PolicyService - preauth PreauthService - logger *slog.Logger - unboundReply string - preauthReply string +// BotMemberService checks and manages bot membership. +type BotMemberService interface { + IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) + UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error } +// PolicyService resolves access policy for a bot. type PolicyService interface { - Resolve(ctx context.Context, botID string) (policy.Decision, error) + AllowGuest(ctx context.Context, botID string) (bool, error) + BotType(ctx context.Context, botID string) (string, error) + BotOwnerUserID(ctx context.Context, botID string) (string, error) } +// PreauthService handles preauth key validation. type PreauthService interface { Get(ctx context.Context, token string) (preauth.Key, error) MarkUsed(ctx context.Context, id string) (preauth.Key, error) } -func NewIdentityResolver(log *slog.Logger, registry *channel.Registry, store IdentityStore, contacts ContactService, policyService PolicyService, preauthService PreauthService, unboundReply, preauthReply string) *IdentityResolver { +// BindService handles channel identity bind code validation and consumption. +type BindService interface { + Get(ctx context.Context, token string) (bind.Code, error) + Consume(ctx context.Context, code bind.Code, channelIdentityID string) error +} + +// IdentityResolver implements identity resolution with bind code, preauth, and guest fallback. +type IdentityResolver struct { + registry *channel.Registry + channelIdentities ChannelIdentityService + members BotMemberService + policy PolicyService + preauth PreauthService + bind BindService + logger *slog.Logger + unboundReply string + preauthReply string + bindReply string +} + +// NewIdentityResolver creates an IdentityResolver. +func NewIdentityResolver( + log *slog.Logger, + registry *channel.Registry, + channelIdentityService ChannelIdentityService, + memberService BotMemberService, + policyService PolicyService, + preauthService PreauthService, + bindService BindService, + unboundReply, preauthReply string, +) *IdentityResolver { if log == nil { log = slog.Default() } if strings.TrimSpace(unboundReply) == "" { - unboundReply = "当前不允许陌生人访问,请联系管理员。" + unboundReply = "Access denied. Please contact the administrator." } if strings.TrimSpace(preauthReply) == "" { - preauthReply = "授权成功,请继续使用。" + preauthReply = "Authorization successful." } return &IdentityResolver{ - registry: registry, - store: store, - contacts: contacts, - policy: policyService, - preauth: preauthService, - logger: log.With(slog.String("component", "channel_identity")), - unboundReply: unboundReply, - preauthReply: preauthReply, + registry: registry, + channelIdentities: channelIdentityService, + members: memberService, + policy: policyService, + preauth: preauthService, + bind: bindService, + logger: log.With(slog.String("component", "channel_identity")), + unboundReply: unboundReply, + preauthReply: preauthReply, + bindReply: "Binding successful! Your identity has been linked.", } } +// Middleware returns a channel middleware that resolves identity before processing. func (r *IdentityResolver) Middleware() channel.Middleware { return func(next channel.InboundHandler) channel.InboundHandler { return func(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) error { @@ -112,8 +151,11 @@ func (r *IdentityResolver) Middleware() channel.Middleware { } } +// Resolve performs two-phase identity resolution: +// 1. Global identity: (channel, channel_subject_id) -> channel_identity_id (unconditional) +// 2. Authorization: bot membership check with guest/preauth fallback func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { - if r.store == nil || r.contacts == nil || r.policy == nil { + if r.channelIdentities == nil { return IdentityState{}, fmt.Errorf("identity resolver not configured") } @@ -121,111 +163,120 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi if botID == "" { botID = cfg.BotID } - normalizedMsg := msg - normalizedMsg.BotID = botID - sessionID := normalizedMsg.SessionID() channelConfigID := cfg.ID if r.registry != nil && r.registry.IsConfigless(msg.Channel) { channelConfigID = "" } - externalID := extractExternalIdentity(msg) + subjectID := extractSubjectIdentity(msg) + displayName := extractDisplayName(msg) state := IdentityState{ Identity: InboundIdentity{ BotID: botID, - SessionID: sessionID, ChannelConfigID: channelConfigID, - ExternalID: externalID, + SubjectID: subjectID, }, } - session, err := r.store.GetChannelSession(ctx, sessionID) - if err != nil && r.logger != nil { - r.logger.Error("get user by session failed", slog.String("session_id", sessionID), slog.Any("error", err)) - } - userID := strings.TrimSpace(session.UserID) - contactID := strings.TrimSpace(session.ContactID) - - if userID == "" { - userID, err = r.store.ResolveUserBinding(ctx, msg.Channel, channel.BindingCriteriaFromIdentity(msg.Sender)) - if err == nil && userID != "" { - _ = r.store.UpsertChannelSession(ctx, sessionID, botID, channelConfigID, userID, contactID, string(msg.Channel), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) - } + // Phase 1: Global identity resolution (unconditional). + if subjectID == "" { + return state, fmt.Errorf("cannot resolve identity: no channel_subject_id") } - var contact contacts.Contact - if contactID == "" && userID != "" { - contact, err = r.contacts.GetByUserID(ctx, botID, userID) - if err != nil { - displayName := extractDisplayName(msg) - contact, err = r.contacts.Create(ctx, contacts.CreateRequest{ - BotID: botID, - UserID: userID, - DisplayName: displayName, - Status: "active", - }) - } - if err == nil { - contactID = contact.ID - if externalID != "" { - _, _ = r.contacts.UpsertChannel(ctx, botID, contactID, msg.Channel.String(), externalID, nil) - } - } + channelIdentity, err := r.channelIdentities.ResolveByChannelIdentity(ctx, msg.Channel.String(), subjectID, displayName) + if err != nil { + return state, fmt.Errorf("resolve channel identity: %w", err) } - if contactID == "" && externalID != "" { - binding, err := r.contacts.GetByChannelIdentity(ctx, botID, msg.Channel.String(), externalID) - if err == nil { - contactID = binding.ContactID - } + channelIdentityID := strings.TrimSpace(channelIdentity.ID) + state.Identity.ChannelIdentityID = channelIdentityID + linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID) + if err != nil { + return state, err } + state.Identity.UserID = strings.TrimSpace(linkedUserID) + if strings.TrimSpace(state.Identity.UserID) == "" { + state.Identity.UserID = r.tryLinkConfiglessChannelIdentityToUser(ctx, msg, channelIdentityID) + } + state.Identity.DisplayName = displayName - if contactID == "" { - decision, err := r.policy.Resolve(ctx, botID) + // Bind code check runs before membership/guest checks so linking is always reachable. + if handled, decision, newUserID, err := r.tryHandleBindCode(ctx, msg, channelIdentityID, subjectID); handled { + if strings.TrimSpace(newUserID) != "" { + state.Identity.UserID = strings.TrimSpace(newUserID) + } + state.Decision = &decision + return state, err + } + if r.policy != nil && isGroupConversationType(msg.Conversation.Type) { + botType, err := r.policy.BotType(ctx, botID) if err != nil { return state, err } - if decision.AllowGuest { - displayName := extractDisplayName(msg) - contact, err = r.contacts.CreateGuest(ctx, botID, displayName) - if err == nil { - contactID = contact.ID - if externalID != "" { - _, _ = r.contacts.UpsertChannel(ctx, botID, contactID, msg.Channel.String(), externalID, nil) - } - } - } else { - if handled, decision, err := r.tryHandlePreauthKey(ctx, normalizedMsg, externalID); handled { - state.Decision = &decision + if strings.EqualFold(strings.TrimSpace(botType), "personal") { + ownerUserID, err := r.policy.BotOwnerUserID(ctx, botID) + if err != nil { return state, err } - state.Decision = &IdentityDecision{ - Stop: true, - Reply: channel.Message{Text: r.unboundReply}, + if strings.TrimSpace(state.Identity.UserID) == "" || strings.TrimSpace(ownerUserID) != strings.TrimSpace(state.Identity.UserID) { + // Personal bots in group chats only answer owner messages. + state.Decision = &IdentityDecision{Stop: true} + return state, nil + } + // Owner can chat normally in group for personal bots. + state.Identity.ForceReply = true + } + } + + // Phase 2: Authorization (bot membership check). + if r.members != nil { + if strings.TrimSpace(state.Identity.UserID) != "" { + isMember, _ := r.members.IsMember(ctx, botID, state.Identity.UserID) + if isMember { + return state, nil + } + } + } + if r.policy != nil && strings.TrimSpace(state.Identity.UserID) != "" { + ownerUserID, err := r.policy.BotOwnerUserID(ctx, botID) + if err != nil { + return state, err + } + // Bot owner should not depend on bot_members linkage. + if strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID) { + return state, nil + } + } + + // Guest policy check. + if r.policy != nil { + allowed, err := r.policy.AllowGuest(ctx, botID) + if err != nil { + return state, err + } + if allowed { + if r.members != nil && strings.TrimSpace(state.Identity.UserID) != "" { + _ = r.members.UpsertMemberSimple(ctx, botID, state.Identity.UserID, "member") } return state, nil } } - if contactID != "" && contact.ID == "" { - loaded, err := r.contacts.GetByID(ctx, contactID) - if err == nil { - contact = loaded - } + // Preauth key check. + if handled, decision, err := r.tryHandlePreauthKey(ctx, msg, botID, state.Identity.UserID, subjectID); handled { + state.Decision = &decision + return state, err } - if contactID != "" { - _ = r.store.UpsertChannelSession(ctx, sessionID, botID, channelConfigID, userID, contactID, string(msg.Channel), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) + state.Decision = &IdentityDecision{ + Stop: true, + Reply: channel.Message{Text: r.unboundReply}, } - - state.Identity.UserID = userID - state.Identity.ContactID = contactID - state.Identity.Contact = contact return state, nil } -func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.InboundMessage, externalID string) (bool, IdentityDecision, error) { +func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.InboundMessage, botID, userID, subjectID string) (bool, IdentityDecision, error) { tokenText := strings.TrimSpace(msg.Message.PlainText()) if tokenText == "" || r.preauth == nil { return false, IdentityDecision{}, nil @@ -244,33 +295,87 @@ func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel. } } if !key.UsedAt.IsZero() { - return true, reply("预授权码已使用。"), nil + return true, reply("Preauth key already used."), nil } if !key.ExpiresAt.IsZero() && time.Now().UTC().After(key.ExpiresAt) { - return true, reply("预授权码已过期,请重新获取。"), nil + return true, reply("Preauth key expired."), nil } - if key.BotID != msg.BotID { - return true, reply("预授权码不匹配。"), nil + if key.BotID != botID { + return true, reply("Preauth key mismatch."), nil } - if externalID == "" { - return true, reply("无法识别当前账号,授权失败。"), nil + if subjectID == "" { + return true, reply("Cannot identify current account."), nil } - displayName := extractDisplayName(msg) - contact, err := r.contacts.CreateGuest(ctx, msg.BotID, displayName) - if err != nil { - return true, reply("授权失败,请稍后重试。"), nil + + // Grant membership via preauth. + if strings.TrimSpace(userID) == "" { + return true, reply("Current channel account is not linked to a user."), nil } - if _, err := r.contacts.UpsertChannel(ctx, msg.BotID, contact.ID, msg.Channel.String(), externalID, nil); err != nil { - return true, reply("授权失败,请稍后重试。"), nil + if r.members != nil { + _ = r.members.UpsertMemberSimple(ctx, botID, userID, "member") } - _ = r.store.UpsertChannelSession(ctx, msg.SessionID(), msg.BotID, "", "", contact.ID, msg.Channel.String(), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) _, _ = r.preauth.MarkUsed(ctx, key.ID) return true, reply(r.preauthReply), nil } -func extractExternalIdentity(msg channel.InboundMessage) string { - if strings.TrimSpace(msg.Sender.ExternalID) != "" { - return strings.TrimSpace(msg.Sender.ExternalID) +func (r *IdentityResolver) tryHandleBindCode(ctx context.Context, msg channel.InboundMessage, channelIdentityID, subjectID string) (bool, IdentityDecision, string, error) { + tokenText := strings.TrimSpace(msg.Message.PlainText()) + if tokenText == "" || r.bind == nil { + return false, IdentityDecision{}, "", nil + } + code, err := r.bind.Get(ctx, tokenText) + if err != nil { + if errors.Is(err, bind.ErrCodeNotFound) { + return false, IdentityDecision{}, "", nil + } + return true, IdentityDecision{}, "", err + } + reply := func(text string) IdentityDecision { + return IdentityDecision{Stop: true, Reply: channel.Message{Text: text}} + } + if !code.UsedAt.IsZero() { + return true, reply("Bind code already used."), "", nil + } + if !code.ExpiresAt.IsZero() && time.Now().UTC().After(code.ExpiresAt) { + return true, reply("Bind code expired."), "", nil + } + if strings.TrimSpace(code.Platform) != "" && !strings.EqualFold(strings.TrimSpace(code.Platform), msg.Channel.String()) { + return true, reply("Bind code mismatch."), "", nil + } + if subjectID == "" { + return true, reply("Cannot identify current account."), "", nil + } + + // Consume: mark used + link source channel identity to issuer user. + if err := r.bind.Consume(ctx, code, channelIdentityID); err != nil { + switch { + case errors.Is(err, bind.ErrCodeUsed): + return true, reply("Bind code already used."), "", nil + case errors.Is(err, bind.ErrCodeExpired): + return true, reply("Bind code expired."), "", nil + case errors.Is(err, bind.ErrCodeMismatch): + return true, reply("Bind code mismatch."), "", nil + case errors.Is(err, bind.ErrLinkConflict): + return true, reply("Current identity has already been linked to another account."), "", nil + default: + return true, IdentityDecision{}, "", fmt.Errorf("consume bind code: %w", err) + } + } + + // Resolve linked user after binding. + newUserID := code.IssuedByUserID + if r.channelIdentities != nil { + if linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID); err == nil && strings.TrimSpace(linkedUserID) != "" { + newUserID = linkedUserID + } + } + + return true, reply(r.bindReply), newUserID, nil +} + +func extractSubjectIdentity(msg channel.InboundMessage) string { + if strings.TrimSpace(msg.Sender.SubjectID) != "" { + return strings.TrimSpace(msg.Sender.SubjectID) } if value := strings.TrimSpace(msg.Sender.Attribute("open_id")); value != "" { return value @@ -288,8 +393,8 @@ func extractDisplayName(msg channel.InboundMessage) string { if strings.TrimSpace(msg.Sender.DisplayName) != "" { return strings.TrimSpace(msg.Sender.DisplayName) } - if strings.TrimSpace(msg.Sender.ExternalID) != "" { - return strings.TrimSpace(msg.Sender.ExternalID) + if strings.TrimSpace(msg.Sender.SubjectID) != "" { + return strings.TrimSpace(msg.Sender.SubjectID) } if value := strings.TrimSpace(msg.Sender.Attribute("username")); value != "" { return value @@ -313,6 +418,39 @@ func extractThreadID(msg channel.InboundMessage) string { return "" } +func isGroupConversationType(conversationType string) bool { + ct := strings.ToLower(strings.TrimSpace(conversationType)) + if ct == "" { + return false + } + return ct != "p2p" && ct != "private" && ct != "direct" +} + +func (r *IdentityResolver) tryLinkConfiglessChannelIdentityToUser(ctx context.Context, msg channel.InboundMessage, channelIdentityID string) string { + if r.registry == nil || !r.registry.IsConfigless(msg.Channel) { + return "" + } + if r.channelIdentities == nil { + return "" + } + candidateUserID := strings.TrimSpace(msg.Sender.Attribute("user_id")) + if candidateUserID == "" { + return "" + } + if err := r.channelIdentities.LinkChannelIdentityToUser(ctx, channelIdentityID, candidateUserID); err != nil { + if r.logger != nil { + r.logger.Warn("auto link configless channel identity failed", + slog.String("channel", msg.Channel.String()), + slog.String("channel_identity_id", channelIdentityID), + slog.String("user_id", candidateUserID), + slog.Any("error", err), + ) + } + return "" + } + return candidateUserID +} + func buildSessionMetadata(msg channel.InboundMessage) map[string]any { metadata := map[string]any{} if strings.TrimSpace(msg.Source) != "" { diff --git a/internal/router/identity_test.go b/internal/router/identity_test.go index 90c87bd8..1e2b8d17 100644 --- a/internal/router/identity_test.go +++ b/internal/router/identity_test.go @@ -2,105 +2,109 @@ package router import ( "context" - "fmt" "log/slog" "testing" "time" + "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/channelidentities" "github.com/memohai/memoh/internal/preauth" ) -type fakePolicyServiceIdentity struct { - decision policy.Decision - err error +type fakeChannelIdentityService struct { + channelIdentity channelidentities.ChannelIdentity + err error + canonical map[string]string + linked map[string]string + calls int } -func (f *fakePolicyServiceIdentity) Resolve(ctx context.Context, botID string) (policy.Decision, error) { +func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Context, platform, externalID, displayName string) (channelidentities.ChannelIdentity, error) { + f.calls++ if f.err != nil { - return policy.Decision{}, f.err + return channelidentities.ChannelIdentity{}, f.err } - decision := f.decision - if decision.BotID == "" { - decision.BotID = botID + return f.channelIdentity, nil +} + +func (f *fakeChannelIdentityService) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { + if f.canonical != nil { + if value, ok := f.canonical[channelIdentityID]; ok { + return value, nil + } } - return decision, nil + return channelIdentityID, nil } -type fakeIdentityConfigStore struct{} - -func (f *fakeIdentityConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) { - return channel.ChannelConfig{}, nil +func (f *fakeChannelIdentityService) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { + if f.linked != nil { + if value, ok := f.linked[channelIdentityID]; ok { + return value, nil + } + return "", nil + } + // Default to one-to-one mapping for tests that do not set explicit links. + return channelIdentityID, nil } -func (f *fakeIdentityConfigStore) GetUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, fmt.Errorf("not implemented") -} - -func (f *fakeIdentityConfigStore) UpsertUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType, req channel.UpsertUserConfigRequest) (channel.ChannelUserBinding, error) { - return channel.ChannelUserBinding{}, nil -} - -func (f *fakeIdentityConfigStore) ListConfigsByType(ctx context.Context, channelType channel.ChannelType) ([]channel.ChannelConfig, error) { - return nil, nil -} - -func (f *fakeIdentityConfigStore) ResolveUserBinding(ctx context.Context, channelType channel.ChannelType, criteria channel.BindingCriteria) (string, error) { - return "", fmt.Errorf("channel user binding not found") -} - -func (f *fakeIdentityConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]channel.ChannelSession, error) { - return nil, nil -} - -func (f *fakeIdentityConfigStore) GetChannelSession(ctx context.Context, sessionID string) (channel.ChannelSession, error) { - return channel.ChannelSession{}, nil -} - -func (f *fakeIdentityConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { +func (f *fakeChannelIdentityService) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { + if f.linked == nil { + f.linked = map[string]string{} + } + f.linked[channelIdentityID] = userID return nil } -type fakeIdentityContactService struct { - createGuestCalled bool - upsertCalled bool +type fakeMemberService struct { + isMember bool + upsertCalled bool } -func (f *fakeIdentityContactService) GetByID(ctx context.Context, contactID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") +func (f *fakeMemberService) IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) { + return f.isMember, nil } -func (f *fakeIdentityContactService) GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error) { - return contacts.Contact{}, fmt.Errorf("not found") -} - -func (f *fakeIdentityContactService) GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) { - return contacts.ContactChannel{}, fmt.Errorf("not found") -} - -func (f *fakeIdentityContactService) Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) { - return contacts.Contact{ID: "contact-1", BotID: req.BotID}, nil -} - -func (f *fakeIdentityContactService) CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) { - f.createGuestCalled = true - return contacts.Contact{ID: "contact-guest", BotID: botID}, nil -} - -func (f *fakeIdentityContactService) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) { +func (f *fakeMemberService) UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error { f.upsertCalled = true - return contacts.ContactChannel{ID: "channel-1", ContactID: contactID}, nil + return nil } -type fakePreauthService struct { +type fakePolicyService struct { + allow bool + botType string + ownerUserID string + err error +} + +func (f *fakePolicyService) AllowGuest(ctx context.Context, botID string) (bool, error) { + if f.err != nil { + return false, f.err + } + return f.allow, nil +} + +func (f *fakePolicyService) BotType(ctx context.Context, botID string) (string, error) { + if f.err != nil { + return "", f.err + } + return f.botType, nil +} + +func (f *fakePolicyService) BotOwnerUserID(ctx context.Context, botID string) (string, error) { + if f.err != nil { + return "", f.err + } + return f.ownerUserID, nil +} + +type fakePreauthServiceIdentity struct { key preauth.Key err error markUsed bool } -func (f *fakePreauthService) Get(ctx context.Context, token string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) Get(ctx context.Context, token string) (preauth.Key, error) { if f.err != nil { return preauth.Key{}, f.err } @@ -110,41 +114,92 @@ func (f *fakePreauthService) Get(ctx context.Context, token string) (preauth.Key return f.key, nil } -func (f *fakePreauthService) MarkUsed(ctx context.Context, id string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) MarkUsed(ctx context.Context, id string) (preauth.Key, error) { f.markUsed = true return f.key, nil } -func TestIdentityResolverAllowGuestCreatesContact(t *testing.T) { - store := &fakeIdentityConfigStore{} - contactsService := &fakeIdentityContactService{} - policyService := &fakePolicyServiceIdentity{decision: policy.Decision{AllowGuest: true}} - resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, nil, "禁止访问", "授权成功") +type fakeBindService struct { + code bind.Code + getErr error + consumeErr error + consumeCalled bool + onConsume func(channelChannelIdentityID string) +} + +func (f *fakeBindService) Get(ctx context.Context, token string) (bind.Code, error) { + if f.getErr != nil { + return bind.Code{}, f.getErr + } + if f.code.Token == "" || f.code.Token != token { + return bind.Code{}, bind.ErrCodeNotFound + } + return f.code, nil +} + +func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelChannelIdentityID string) error { + f.consumeCalled = true + if f.onConsume != nil { + f.onConsume(channelChannelIdentityID) + } + return f.consumeErr +} + +func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: true, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") msg := channel.InboundMessage{ BotID: "bot-1", Channel: channel.ChannelType("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", - Sender: channel.Identity{ExternalID: "user-1", DisplayName: "访客"}, + Sender: channel.Identity{SubjectID: "ext-1", DisplayName: "Guest"}, } state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } - if state.Identity.ContactID != "contact-guest" { - t.Fatalf("应创建访客联系人,实际: %s", state.Identity.ContactID) + if state.Identity.ChannelIdentityID != "channelIdentity-1" { + t.Fatalf("expected channelIdentity-1, got: %s", state.Identity.ChannelIdentityID) } - if !contactsService.createGuestCalled { - t.Fatalf("应调用 CreateGuest") + if !memberSvc.upsertCalled { + t.Fatal("expected UpsertMemberSimple to be called") + } + if state.Decision != nil { + t.Fatal("expected no decision for allowed guest") } } -func TestIdentityResolverPreauthKeyAllowsGuest(t *testing.T) { - store := &fakeIdentityConfigStore{} - contactsService := &fakeIdentityContactService{} - policyService := &fakePolicyServiceIdentity{} - preauthService := &fakePreauthService{ +func TestIdentityResolverExistingMemberPasses(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-2"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "tg-user-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("existing member should pass without decision") + } +} + +func TestIdentityResolverPreauthKey(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-3"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + preauthSvc := &fakePreauthServiceIdentity{ key: preauth.Key{ ID: "key-1", BotID: "bot-1", @@ -152,35 +207,35 @@ func TestIdentityResolverPreauthKeyAllowsGuest(t *testing.T) { ExpiresAt: time.Now().UTC().Add(1 * time.Hour), }, } - resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, preauthSvc, nil, "", "") msg := channel.InboundMessage{ BotID: "bot-1", Channel: channel.ChannelType("feishu"), Message: channel.Message{Text: "PREAUTH123"}, ReplyTarget: "target-id", - Sender: channel.Identity{ExternalID: "user-1"}, + Sender: channel.Identity{SubjectID: "ext-1"}, } state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } if state.Decision == nil || !state.Decision.Stop { - t.Fatalf("应返回授权确认") + t.Fatal("preauth key should return stop decision") } - if !contactsService.upsertCalled { - t.Fatalf("应执行联系人绑定") + if !preauthSvc.markUsed { + t.Fatal("preauth key should be marked used") } - if !preauthService.markUsed { - t.Fatalf("应标记预授权码已使用") + if !memberSvc.upsertCalled { + t.Fatal("membership should be upserted via preauth") } } func TestIdentityResolverPreauthKeyExpired(t *testing.T) { - store := &fakeIdentityConfigStore{} - contactsService := &fakeIdentityContactService{} - policyService := &fakePolicyServiceIdentity{} - preauthService := &fakePreauthService{ + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-4"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + preauthSvc := &fakePreauthServiceIdentity{ key: preauth.Key{ ID: "key-1", BotID: "bot-1", @@ -188,23 +243,291 @@ func TestIdentityResolverPreauthKeyExpired(t *testing.T) { ExpiresAt: time.Now().UTC().Add(-1 * time.Hour), }, } - resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, preauthSvc, nil, "", "") msg := channel.InboundMessage{ BotID: "bot-1", Channel: channel.ChannelType("feishu"), Message: channel.Message{Text: "PREAUTH123"}, ReplyTarget: "target-id", - Sender: channel.Identity{ExternalID: "user-1"}, + Sender: channel.Identity{SubjectID: "ext-1"}, } state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) if err != nil { - t.Fatalf("不应报错: %v", err) + t.Fatalf("unexpected error: %v", err) } if state.Decision == nil || !state.Decision.Stop { - t.Fatalf("过期预授权码应被拒绝") + t.Fatal("expired preauth key should be rejected") } - if preauthService.markUsed { - t.Fatalf("过期预授权码不应被使用") + if preauthSvc.markUsed { + t.Fatal("expired preauth key should not be marked used") + } +} + +func TestIdentityResolverDenied(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-5"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "public"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "stranger-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("stranger without guest access should be denied") + } +} + +func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-group"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + Sender: channel.Identity{SubjectID: "ext-group-1"}, + Conversation: channel.Conversation{ + ID: "group-1", + Type: "group", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("personal bot should reject group messages") + } + if channelIdentitySvc.calls != 1 { + t.Fatalf("expected channelIdentity resolution once before owner check, got %d", channelIdentitySvc.calls) + } + if !state.Decision.Reply.IsEmpty() { + t.Fatal("non-owner group message should be silently ignored") + } +} + +func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-owner"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from owner"}, + Sender: channel.Identity{SubjectID: "ext-owner-1"}, + Conversation: channel.Conversation{ + ID: "group-1", + Type: "group", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("owner group message should pass") + } + if !state.Identity.ForceReply { + t.Fatal("owner group message should force reply") + } +} + +func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-owner-direct"}} + memberSvc := &fakeMemberService{isMember: false} + policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner-direct"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from owner"}, + Sender: channel.Identity{SubjectID: "ext-owner-direct"}, + Conversation: channel.Conversation{ + ID: "p2p-1", + Type: "p2p", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision != nil { + t.Fatal("owner direct message should pass") + } + if state.Identity.ForceReply { + t.Fatal("owner direct message should not force reply") + } +} + +func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) { + shadowID := "channelIdentity-shadow" + humanID := "channelIdentity-human" + channelIdentitySvc := &fakeChannelIdentityService{ + channelIdentity: channelidentities.ChannelIdentity{ID: shadowID}, + linked: map[string]string{ + shadowID: shadowID, + }, + } + memberSvc := &fakeMemberService{isMember: true} + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-1", + Platform: "feishu", + Token: "BIND123", + IssuedByUserID: humanID, + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + onConsume: func(channelChannelIdentityID string) { + channelIdentitySvc.linked[channelChannelIdentityID] = humanID + }, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "BIND123"}, + ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-bind-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bindSvc.consumeCalled { + t.Fatal("expected bind consume to run before membership shortcut") + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("bind flow should return stop decision") + } + if state.Identity.UserID != humanID { + t.Fatalf("expected linked user to switch to %s, got %s", humanID, state.Identity.UserID) + } + if memberSvc.upsertCalled { + t.Fatal("bind should not upsert bot membership") + } +} + +func TestIdentityResolverBindConsumeErrorHandledAsDecision(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-shadow"}} + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-2", + Platform: "telegram", + Token: "BINDUSED", + IssuedByUserID: "channelIdentity-human", + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + consumeErr: bind.ErrCodeUsed, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, &fakeMemberService{}, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("telegram"), + Message: channel.Message{Text: "BINDUSED"}, + ReplyTarget: "chat-123", + Sender: channel.Identity{SubjectID: "ext-bind-2"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("bind consume errors should be converted into stop decision") + } +} + +func TestIdentityResolverBindCodeNotScopedToCurrentBot(t *testing.T) { + shadowID := "channelIdentity-shadow-any-bot" + humanID := "channelIdentity-human-any-bot" + channelIdentitySvc := &fakeChannelIdentityService{ + channelIdentity: channelidentities.ChannelIdentity{ID: shadowID}, + linked: map[string]string{ + shadowID: shadowID, + }, + } + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-any-bot", + Platform: "feishu", + Token: "BINDANYBOT", + IssuedByUserID: humanID, + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + onConsume: func(channelChannelIdentityID string) { + channelIdentitySvc.linked[channelChannelIdentityID] = humanID + }, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, &fakeMemberService{}, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-2", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "BINDANYBOT"}, + ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-bind-any-bot"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bindSvc.consumeCalled { + t.Fatal("bind consume should run even when message bot differs") + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("bind flow should return stop decision") + } + if state.Identity.UserID != humanID { + t.Fatalf("expected linked user to switch to %s, got %s", humanID, state.Identity.UserID) + } +} + +func TestIdentityResolverBindCodePlatformMismatch(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-platform-mismatch"}} + bindSvc := &fakeBindService{ + code: bind.Code{ + ID: "code-platform", + Platform: "telegram", + Token: "BINDPLATFORM", + IssuedByUserID: "channelIdentity-human-platform", + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + } + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, &fakeMemberService{}, nil, nil, bindSvc, "", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "BINDPLATFORM"}, + ReplyTarget: "target-id", + Sender: channel.Identity{SubjectID: "ext-bind-platform"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if bindSvc.consumeCalled { + t.Fatal("bind consume should not run when platform mismatches") + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("platform mismatch should return stop decision") } } diff --git a/internal/schedule/service_test.go b/internal/schedule/service_test.go index 15ec87fc..aca684b6 100644 --- a/internal/schedule/service_test.go +++ b/internal/schedule/service_test.go @@ -55,8 +55,8 @@ func TestGenerateTriggerToken(t *testing.T) { if sub, _ := claims["sub"].(string); sub != userID { t.Errorf("expected sub=%s, got=%s", userID, sub) } - if uid, _ := claims["user_id"].(string); uid != userID { - t.Errorf("expected user_id=%s, got=%s", userID, uid) + if uid, _ := claims["channel_identity_id"].(string); uid != userID { + t.Errorf("expected channel_identity_id=%s, got=%s", userID, uid) } exp, _ := claims["exp"].(float64) if exp == 0 { diff --git a/internal/schedule/trigger.go b/internal/schedule/trigger.go index e2d3b5ab..1cc01b39 100644 --- a/internal/schedule/trigger.go +++ b/internal/schedule/trigger.go @@ -11,6 +11,7 @@ type TriggerPayload struct { MaxCalls *int Command string OwnerUserID string + ChatID string } // Triggerer 负责触发与聊天相关的调度执行。 diff --git a/internal/schedule/types.go b/internal/schedule/types.go index ffa0cbc7..7c95b1f6 100644 --- a/internal/schedule/types.go +++ b/internal/schedule/types.go @@ -50,21 +50,21 @@ func (n *NullableInt) UnmarshalJSON(data []byte) error { } type CreateRequest struct { - Name string `json:"name"` - Description string `json:"description"` - Pattern string `json:"pattern"` + Name string `json:"name"` + Description string `json:"description"` + Pattern string `json:"pattern"` MaxCalls NullableInt `json:"max_calls,omitempty"` - Command string `json:"command"` - Enabled *bool `json:"enabled,omitempty"` + Command string `json:"command"` + Enabled *bool `json:"enabled,omitempty"` } type UpdateRequest struct { - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Pattern *string `json:"pattern,omitempty"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Pattern *string `json:"pattern,omitempty"` MaxCalls NullableInt `json:"max_calls,omitempty"` - Command *string `json:"command,omitempty"` - Enabled *bool `json:"enabled,omitempty"` + Command *string `json:"command,omitempty"` + Enabled *bool `json:"enabled,omitempty"` } type ListResponse struct { diff --git a/internal/server/server.go b/internal/server/server.go index b451fa90..c2463643 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,7 @@ type Server struct { logger *slog.Logger } -func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, contactsHandler *handlers.ContactsHandler, preauthHandler *handlers.PreauthHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { +func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, preauthHandler *handlers.PreauthHandler, bindHandler *handlers.BindHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { if addr == "" { addr = ":8080" } @@ -72,15 +72,12 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han if settingsHandler != nil { settingsHandler.Register(e) } - if historyHandler != nil { - historyHandler.Register(e) - } - if contactsHandler != nil { - contactsHandler.Register(e) - } if preauthHandler != nil { preauthHandler.Register(e) } + if bindHandler != nil { + bindHandler.Register(e) + } if scheduleHandler != nil { scheduleHandler.Register(e) } diff --git a/internal/settings/service.go b/internal/settings/service.go index 2607c807..47d25e7a 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -19,6 +19,8 @@ type Service struct { logger *slog.Logger } +var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access") + func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -26,6 +28,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } +// Get returns user-level settings. func (s *Service) Get(ctx context.Context, userID string) (Settings, error) { pgID, err := parseUUID(userID) if err != nil { @@ -47,6 +50,7 @@ func (s *Service) Get(ctx context.Context, userID string) (Settings, error) { return normalizeUserSetting(row), nil } +// Upsert creates or updates user-level settings. func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) (Settings, error) { if s.queries == nil { return Settings{}, fmt.Errorf("settings queries not configured") @@ -88,7 +92,7 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) } _, err = s.queries.UpsertUserSettings(ctx, sqlc.UpsertUserSettingsParams{ - UserID: pgID, + ID: pgID, ChatModelID: pgtype.Text{String: current.ChatModelID, Valid: current.ChatModelID != ""}, MemoryModelID: pgtype.Text{String: current.MemoryModelID, Valid: current.MemoryModelID != ""}, EmbeddingModelID: pgtype.Text{String: current.EmbeddingModelID, Valid: current.EmbeddingModelID != ""}, @@ -108,24 +112,9 @@ func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) { } row, err := s.queries.GetSettingsByBotID(ctx, pgID) if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - settings := Settings{ - MaxContextLoadTime: DefaultMaxContextLoadTime, - Language: DefaultLanguage, - AllowGuest: false, - } - if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil { - return Settings{}, err - } - return settings, nil - } return Settings{}, err } - settings := normalizeBotSetting(row) - if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil { - return Settings{}, err - } - return settings, nil + return normalizeBotSettingsReadRow(row), nil } func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) { @@ -136,45 +125,66 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest if err != nil { return Settings{}, err } - - current := Settings{ - MaxContextLoadTime: DefaultMaxContextLoadTime, - Language: DefaultLanguage, - AllowGuest: false, - } - existing, err := s.queries.GetSettingsByBotID(ctx, pgID) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { + botRow, err := s.queries.GetBotByID(ctx, pgID) + if err != nil { return Settings{}, err } - if err == nil { - current = normalizeBotSetting(existing) - } + isPersonalBot := strings.EqualFold(strings.TrimSpace(botRow.Type), "personal") + + current := normalizeBotSetting(botRow.MaxContextLoadTime, botRow.Language, botRow.AllowGuest) if req.MaxContextLoadTime != nil && *req.MaxContextLoadTime > 0 { current.MaxContextLoadTime = *req.MaxContextLoadTime } if strings.TrimSpace(req.Language) != "" { current.Language = strings.TrimSpace(req.Language) } - if req.AllowGuest != nil { + if isPersonalBot { + if req.AllowGuest != nil && *req.AllowGuest { + return Settings{}, ErrPersonalBotGuestAccessUnsupported + } + current.AllowGuest = false + } else if req.AllowGuest != nil { current.AllowGuest = *req.AllowGuest } - _, err = s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{ - BotID: pgID, + chatModelUUID := pgtype.UUID{} + if value := strings.TrimSpace(req.ChatModelID); value != "" { + modelID, err := s.resolveModelUUID(ctx, value) + if err != nil { + return Settings{}, err + } + chatModelUUID = modelID + } + memoryModelUUID := pgtype.UUID{} + if value := strings.TrimSpace(req.MemoryModelID); value != "" { + modelID, err := s.resolveModelUUID(ctx, value) + if err != nil { + return Settings{}, err + } + memoryModelUUID = modelID + } + embeddingModelUUID := pgtype.UUID{} + if value := strings.TrimSpace(req.EmbeddingModelID); value != "" { + modelID, err := s.resolveModelUUID(ctx, value) + if err != nil { + return Settings{}, err + } + embeddingModelUUID = modelID + } + + updated, err := s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{ + ID: pgID, MaxContextLoadTime: int32(current.MaxContextLoadTime), Language: current.Language, AllowGuest: current.AllowGuest, + ChatModelID: chatModelUUID, + MemoryModelID: memoryModelUUID, + EmbeddingModelID: embeddingModelUUID, }) if err != nil { return Settings{}, err } - if err := s.upsertBotModelConfig(ctx, pgID, req); err != nil { - return Settings{}, err - } - if err := s.attachBotModelConfig(ctx, pgID, ¤t); err != nil { - return Settings{}, err - } - return current, nil + return normalizeBotSettingsWriteRow(updated), nil } func (s *Service) Delete(ctx context.Context, botID string) error { @@ -188,7 +198,7 @@ func (s *Service) Delete(ctx context.Context, botID string) error { return s.queries.DeleteSettingsByBotID(ctx, pgID) } -func normalizeUserSetting(row sqlc.UserSetting) Settings { +func normalizeUserSetting(row sqlc.GetSettingsByUserIDRow) Settings { settings := Settings{ ChatModelID: strings.TrimSpace(row.ChatModelID.String), MemoryModelID: strings.TrimSpace(row.MemoryModelID.String), @@ -205,11 +215,11 @@ func normalizeUserSetting(row sqlc.UserSetting) Settings { return settings } -func normalizeBotSetting(row sqlc.BotSetting) Settings { +func normalizeBotSetting(maxContextLoadTime int32, language string, allowGuest bool) Settings { settings := Settings{ - MaxContextLoadTime: int(row.MaxContextLoadTime), - Language: strings.TrimSpace(row.Language), - AllowGuest: row.AllowGuest, + MaxContextLoadTime: int(maxContextLoadTime), + Language: strings.TrimSpace(language), + AllowGuest: allowGuest, } if settings.MaxContextLoadTime <= 0 { settings.MaxContextLoadTime = DefaultMaxContextLoadTime @@ -220,60 +230,41 @@ func normalizeBotSetting(row sqlc.BotSetting) Settings { return settings } -func (s *Service) attachBotModelConfig(ctx context.Context, botID pgtype.UUID, target *Settings) error { - if s.queries == nil || target == nil { - return nil - } - row, err := s.queries.GetBotModelConfigByBotID(ctx, botID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return nil - } - return err - } - target.ChatModelID = strings.TrimSpace(row.ChatModelID.String) - target.MemoryModelID = strings.TrimSpace(row.MemoryModelID.String) - target.EmbeddingModelID = strings.TrimSpace(row.EmbeddingModelID.String) - return nil +func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings { + return normalizeBotSettingsFields( + row.MaxContextLoadTime, + row.Language, + row.AllowGuest, + row.ChatModelID, + row.MemoryModelID, + row.EmbeddingModelID, + ) } -func (s *Service) upsertBotModelConfig(ctx context.Context, botID pgtype.UUID, req UpsertRequest) error { - if s.queries == nil { - return fmt.Errorf("settings queries not configured") - } - params := sqlc.UpsertBotModelConfigParams{ - BotID: botID, - } - hasUpdate := false - if value := strings.TrimSpace(req.ChatModelID); value != "" { - modelID, err := s.resolveModelUUID(ctx, value) - if err != nil { - return err - } - params.ChatModelID = modelID - hasUpdate = true - } - if value := strings.TrimSpace(req.MemoryModelID); value != "" { - modelID, err := s.resolveModelUUID(ctx, value) - if err != nil { - return err - } - params.MemoryModelID = modelID - hasUpdate = true - } - if value := strings.TrimSpace(req.EmbeddingModelID); value != "" { - modelID, err := s.resolveModelUUID(ctx, value) - if err != nil { - return err - } - params.EmbeddingModelID = modelID - hasUpdate = true - } - if !hasUpdate { - return nil - } - _, err := s.queries.UpsertBotModelConfig(ctx, params) - return err +func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings { + return normalizeBotSettingsFields( + row.MaxContextLoadTime, + row.Language, + row.AllowGuest, + row.ChatModelID, + row.MemoryModelID, + row.EmbeddingModelID, + ) +} + +func normalizeBotSettingsFields( + maxContextLoadTime int32, + language string, + allowGuest bool, + chatModelID pgtype.Text, + memoryModelID pgtype.Text, + embeddingModelID pgtype.Text, +) Settings { + settings := normalizeBotSetting(maxContextLoadTime, language, allowGuest) + settings.ChatModelID = strings.TrimSpace(chatModelID.String) + settings.MemoryModelID = strings.TrimSpace(memoryModelID.String) + settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String) + return settings } func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) { diff --git a/internal/subagent/types.go b/internal/subagent/types.go index 77498a12..38207e0d 100644 --- a/internal/subagent/types.go +++ b/internal/subagent/types.go @@ -3,30 +3,30 @@ package subagent import "time" type Subagent struct { - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - BotID string `json:"bot_id"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + BotID string `json:"bot_id"` Messages []map[string]any `json:"messages"` Metadata map[string]any `json:"metadata"` - Skills []string `json:"skills"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - Deleted bool `json:"deleted"` - DeletedAt *time.Time `json:"deleted_at,omitempty"` + Skills []string `json:"skills"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` } type CreateRequest struct { - Name string `json:"name"` - Description string `json:"description"` + Name string `json:"name"` + Description string `json:"description"` Messages []map[string]any `json:"messages,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` - Skills []string `json:"skills,omitempty"` + Skills []string `json:"skills,omitempty"` } type UpdateRequest struct { - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } diff --git a/packages/web/src/components/Sidebar/index.vue b/packages/web/src/components/Sidebar/index.vue index cc2687b7..ca4b3229 100644 --- a/packages/web/src/components/Sidebar/index.vue +++ b/packages/web/src/components/Sidebar/index.vue @@ -1,100 +1,123 @@ + \ No newline at end of file +const settingsRouteNames = new Set(['settings', 'settings-user', 'bots', 'bot-detail', 'models', 'mcp']) +const sidebarListRegistry: Record = { + chat: ChatListMenu, + settings: SettingsListMenu, +} + +const currentListKey = computed(() => ( + settingsRouteNames.has(String(route.name ?? '')) ? 'settings' : 'chat' +)) +const currentListComponent = computed(() => sidebarListRegistry[currentListKey.value]) +const isInSettingsRoute = computed(() => currentListKey.value === 'settings') + +const { userInfo } = useUserStore() +const displayNameLabel = computed(() => userInfo.displayName || userInfo.username || userInfo.id || '-') +const displayTitle = computed(() => userInfo.displayName || userInfo.username || userInfo.id || 'User') +const avatarFallback = computed(() => displayTitle.value.slice(0, 2).toUpperCase() || 'U') + +function onLogoClick() { + if (route.name === 'chat') { + return + } + void router.push({ name: 'chat' }).catch(() => undefined) +} + +function onActionButtonClick() { + if (isInSettingsRoute.value) { + void openChat() + return + } + void openUserSettings() +} + +async function openChat() { + if (route.name === 'chat') { + return + } + await router.push({ name: 'chat' }).catch(() => undefined) +} + +async function openUserSettings() { + if (route.name === 'settings-user') { + return + } + await router.push({ name: 'settings-user' }).catch(() => undefined) +} + diff --git a/packages/web/src/components/Sidebar/lists/chat-list-menu.vue b/packages/web/src/components/Sidebar/lists/chat-list-menu.vue new file mode 100644 index 00000000..9d62cd0b --- /dev/null +++ b/packages/web/src/components/Sidebar/lists/chat-list-menu.vue @@ -0,0 +1,215 @@ + + + + diff --git a/packages/web/src/components/Sidebar/lists/settings-list-menu.vue b/packages/web/src/components/Sidebar/lists/settings-list-menu.vue new file mode 100644 index 00000000..209ce4e7 --- /dev/null +++ b/packages/web/src/components/Sidebar/lists/settings-list-menu.vue @@ -0,0 +1,94 @@ + + + + diff --git a/packages/web/src/components/Sidebar/lists/types.ts b/packages/web/src/components/Sidebar/lists/types.ts new file mode 100644 index 00000000..7612130b --- /dev/null +++ b/packages/web/src/components/Sidebar/lists/types.ts @@ -0,0 +1,4 @@ +export interface SidebarListProps { + collapsible?: boolean +} + diff --git a/packages/web/src/components/add-platform/index.vue b/packages/web/src/components/add-platform/index.vue deleted file mode 100644 index dfb38ac6..00000000 --- a/packages/web/src/components/add-platform/index.vue +++ /dev/null @@ -1,172 +0,0 @@ - - - diff --git a/packages/web/src/components/chat-list/robot-chat/index.vue b/packages/web/src/components/chat-list/robot-chat/index.vue index 4f9f35a3..16c35e21 100644 --- a/packages/web/src/components/chat-list/robot-chat/index.vue +++ b/packages/web/src/components/chat-list/robot-chat/index.vue @@ -1,11 +1,18 @@