From 5a35ef34acb82da5d4f65478e5a771459e759260 Mon Sep 17 00:00:00 2001 From: BBQ Date: Fri, 6 Feb 2026 14:41:54 +0800 Subject: [PATCH] feat: channel gateway implementation and multi-bot refactor - Refactor channel manager with support for Sender/Receiver interfaces and hot-swappable adapters. - Implement identity routing and pre-authentication logic for inbound messages. - Update database schema to support bot pre-auth keys and extended channel session metadata. - Add Telegram and Feishu channel configuration and adapter enhancements. - Update Swagger documentation and internal handlers for channel management. Co-authored-by: Cursor --- agent/src/model.ts | 29 +- cmd/agent/main.go | 20 +- db/migrations/0001_init.down.sql | 2 +- db/migrations/0001_init.up.sql | 42 +- db/queries/channels.sql | 87 ++ db/queries/contacts.sql | 76 ++ db/queries/containers.sql | 3 + db/queries/history.sql | 12 +- db/queries/preauth.sql | 16 + docs/docs.go | 917 +++++++++++++----- docs/swagger.json | 917 +++++++++++++----- docs/swagger.yaml | 631 ++++++++---- internal/bots/service.go | 10 +- internal/bots/types.go | 36 +- internal/channel/adapter.go | 120 ++- internal/channel/adapters/feishu/config.go | 142 +++ .../channel/adapters/feishu/config_test.go | 81 ++ .../channel/adapters/feishu/descriptor.go | 48 +- internal/channel/adapters/feishu/feishu.go | 333 +++++-- .../feishu/feishu_integration_test.go | 27 +- .../channel/adapters/feishu/feishu_logger.go | 10 +- .../channel/adapters/feishu/feishu_test.go | 16 +- internal/channel/adapters/local/cli.go | 11 +- internal/channel/adapters/local/descriptor.go | 55 +- internal/channel/adapters/local/web.go | 11 +- internal/channel/adapters/telegram/config.go | 158 +++ .../channel/adapters/telegram/config_test.go | 88 ++ .../channel/adapters/telegram/descriptor.go | 44 +- .../channel/adapters/telegram/telegram.go | 491 ++++++++-- .../adapters/telegram/telegram_test.go | 17 +- internal/channel/capabilities.go | 22 + internal/channel/config.go | 187 +--- internal/channel/config_test.go | 165 ++-- internal/channel/directory.go | 32 + internal/channel/helpers_test.go | 90 +- internal/channel/manager.go | 534 +++++++--- internal/channel/manager_core_test.go | 56 +- internal/channel/manager_integration_test.go | 115 ++- internal/channel/manager_test.go | 46 +- internal/channel/outbound.go | 165 ++++ internal/channel/processor.go | 4 +- internal/channel/registry.go | 68 +- internal/channel/registry_test.go | 28 - internal/channel/schema.go | 27 + internal/channel/service.go | 166 ++-- internal/channel/target.go | 25 + internal/channel/types.go | 254 ++++- internal/chat/assistant_output.go | 52 + internal/chat/normalize.go | 109 ++- internal/chat/resolver.go | 666 +++++++------ internal/chat/types.go | 46 +- internal/contacts/service.go | 118 +-- internal/contacts/types.go | 21 +- internal/db/sqlc/channels.sql.go | 97 +- internal/db/sqlc/contacts.sql.go | 92 -- internal/db/sqlc/containers.sql.go | 26 + internal/db/sqlc/history.sql.go | 18 +- internal/db/sqlc/models.go | 29 +- internal/db/sqlc/preauth.sql.go | 89 ++ internal/directory/service.go | 226 +++++ internal/handlers/channel.go | 74 +- internal/handlers/chat.go | 12 + internal/handlers/contacts.go | 189 +--- internal/handlers/containerd.go | 183 +++- internal/handlers/fs.go | 33 +- internal/handlers/history.go | 2 +- internal/handlers/local_channel.go | 42 +- internal/handlers/memory.go | 42 +- internal/handlers/preauth.go | 99 ++ internal/handlers/schedule.go | 2 +- internal/handlers/settings.go | 2 +- internal/handlers/skills.go | 18 +- internal/handlers/subagent.go | 2 +- internal/handlers/users.go | 8 +- internal/history/service.go | 18 +- internal/history/types.go | 18 +- internal/mcp/manager.go | 57 +- internal/mcp/versioning.go | 12 +- internal/memory/llm_client.go | 30 +- internal/memory/prompts.go | 2 +- internal/memory/qdrant_store.go | 38 +- internal/memory/qdrant_store_test.go | 4 +- internal/memory/service.go | 60 +- internal/memory/service_test.go | 4 +- internal/memory/types.go | 22 +- internal/models/models.go | 9 + internal/models/types.go | 6 + internal/policy/service.go | 61 ++ internal/preauth/service.go | 128 +++ internal/preauth/types.go | 13 + internal/providers/service.go | 3 +- internal/providers/types.go | 37 +- internal/router/channel.go | 812 ++++++++++------ internal/router/channel_test.go | 219 ++++- internal/router/identity.go | 326 +++++++ internal/router/identity_test.go | 210 ++++ internal/server/server.go | 5 +- internal/settings/types.go | 2 +- internal/subagent/service.go | 24 +- internal/subagent/types.go | 14 +- internal/users/service.go | 8 +- packages/shared/src/model.ts | 1 + packages/ui/src/index.ts | 78 +- scripts/db-drop.sh | 0 scripts/db-up.sh | 0 sqlc.yaml | 2 +- 106 files changed, 7910 insertions(+), 3044 deletions(-) create mode 100644 db/queries/channels.sql create mode 100644 db/queries/contacts.sql create mode 100644 db/queries/preauth.sql create mode 100644 internal/channel/adapters/feishu/config.go create mode 100644 internal/channel/adapters/feishu/config_test.go create mode 100644 internal/channel/adapters/telegram/config.go create mode 100644 internal/channel/adapters/telegram/config_test.go create mode 100644 internal/channel/capabilities.go create mode 100644 internal/channel/directory.go create mode 100644 internal/channel/outbound.go delete mode 100644 internal/channel/registry_test.go create mode 100644 internal/channel/schema.go create mode 100644 internal/channel/target.go create mode 100644 internal/chat/assistant_output.go create mode 100644 internal/db/sqlc/preauth.sql.go create mode 100644 internal/directory/service.go create mode 100644 internal/handlers/preauth.go create mode 100644 internal/policy/service.go create mode 100644 internal/preauth/service.go create mode 100644 internal/preauth/types.go create mode 100644 internal/router/identity.go create mode 100644 internal/router/identity_test.go mode change 100755 => 100644 scripts/db-drop.sh mode change 100755 => 100644 scripts/db-up.sh diff --git a/agent/src/model.ts b/agent/src/model.ts index 0be14f7d..38ebf601 100644 --- a/agent/src/model.ts +++ b/agent/src/model.ts @@ -5,17 +5,22 @@ import { createGoogleGenerativeAI } from '@ai-sdk/google' import { ClientType, ModelConfig } from './types' export const createModel = (model: ModelConfig) => { - const apiKey = model.apiKey.toLowerCase().trim() - const baseURL = model.baseUrl.toLowerCase().trim() - const modelId = model.modelId.toLowerCase().trim() - const clients = { - [ClientType.OpenAI]: createOpenAI, - [ClientType.OpenAICompatible]: createOpenAI, - [ClientType.Anthropic]: createAnthropic, - [ClientType.Google]: createGoogleGenerativeAI, + const apiKey = model.apiKey.trim() + const baseURL = model.baseUrl.trim() + const modelId = model.modelId.trim() + + switch (model.clientType) { + case ClientType.OpenAI: + case ClientType.OpenAICompatible: { + const provider = createOpenAI({ apiKey, baseURL }) + // Use .chat() to call /chat/completions (not /responses which only OpenAI supports) + return provider.chat(modelId) + } + case ClientType.Anthropic: + return createAnthropic({ apiKey, baseURL })(modelId) + case ClientType.Google: + return createGoogleGenerativeAI({ apiKey, baseURL })(modelId) + default: + return createAiGateway({ apiKey, baseURL })(modelId) } - return (clients[model.clientType] ?? createAiGateway)({ - apiKey, - baseURL, - })(modelId) } \ No newline at end of file diff --git a/cmd/agent/main.go b/cmd/agent/main.go index c963d683..adaef571 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -26,6 +26,8 @@ import ( "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/preauth" "github.com/memohai/memoh/internal/providers" "github.com/memohai/memoh/internal/router" "github.com/memohai/memoh/internal/schedule" @@ -82,7 +84,7 @@ func main() { manager := mcp.NewManager(logger.L, service, cfg.MCP) pingHandler := handlers.NewPingHandler(logger.L) - containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace) + // containerdHandler is created later after DB services are initialized conn, err := db.Open(ctx, cfg.Postgres) if err != nil { @@ -96,6 +98,8 @@ func main() { botService := bots.NewService(logger.L, queries) usersService := users.NewService(logger.L, queries) + containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, usersService, queries) + if err := ensureAdminUser(ctx, logger.L, queries, cfg); err != nil { logger.Error("ensure admin user", slog.Any("error", err)) os.Exit(1) @@ -142,18 +146,24 @@ func main() { settingsService := settings.NewService(logger.L, queries) settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService, botService, usersService) modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService) + policyService := policy.NewService(logger.L, botService, settingsService) historyService := history.NewService(logger.L, queries) historyHandler := handlers.NewHistoryHandler(logger.L, historyService, botService, usersService) contactsService := contacts.NewService(queries) contactsHandler := handlers.NewContactsHandler(contactsService, botService, usersService) + preauthService := preauth.NewService(queries) + preauthHandler := handlers.NewPreauthHandler(preauthService, botService, usersService) chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, historyService, settingsService, cfg.AgentGateway.BaseURL(), 120*time.Second) embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries) swaggerHandler := handlers.NewSwaggerHandler(logger.L) chatHandler := handlers.NewChatHandler(logger.L, chatResolver, botService, usersService) channelService := channel.NewService(queries) - channelRouter := router.NewChannelInboundProcessor(logger.L, channelService, chatResolver, contactsService, settingsService, cfg.Auth.JWTSecret, 5*time.Minute) + channelRouter := router.NewChannelInboundProcessor(logger.L, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute) channelManager := channel.NewManager(logger.L, channelService, channelRouter) + if mw := channelRouter.IdentityMiddleware(); mw != nil { + channelManager.Use(mw) + } sessionHub := channel.NewSessionHub() channelManager.RegisterAdapter(telegram.NewTelegramAdapter(logger.L)) channelManager.RegisterAdapter(feishu.NewFeishuAdapter(logger.L)) @@ -162,8 +172,8 @@ func main() { channelManager.Start(ctx) channelHandler := handlers.NewChannelHandler(channelService) usersHandler := handlers.NewUsersHandler(logger.L, usersService, botService, channelService, channelManager) - cliHandler := handlers.NewLocalChannelHandler(channel.ChannelCLI, channelManager, channelService, sessionHub, botService, usersService) - webHandler := handlers.NewLocalChannelHandler(channel.ChannelWeb, channelManager, channelService, sessionHub, botService, usersService) + cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService) + webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService) scheduleGateway := chat.NewScheduleGateway(chatResolver) scheduleService := schedule.NewService(logger.L, queries, scheduleGateway, cfg.Auth.JWTSecret) if err := scheduleService.Bootstrap(ctx); err != nil { @@ -173,7 +183,7 @@ func main() { scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService, botService, usersService) subagentService := subagent.NewService(logger.L, queries) subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService, botService, usersService) - srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, contactsHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, cliHandler, webHandler) + srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, contactsHandler, preauthHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler, usersHandler, cliHandler, webHandler) if err := srv.Start(); err != nil { logger.Error("server failed", slog.Any("error", err)) diff --git a/db/migrations/0001_init.down.sql b/db/migrations/0001_init.down.sql index 742a2d74..7ce16393 100644 --- a/db/migrations/0001_init.down.sql +++ b/db/migrations/0001_init.down.sql @@ -6,8 +6,8 @@ DROP TABLE IF EXISTS container_versions; DROP TABLE IF EXISTS snapshots; DROP TABLE IF EXISTS containers; DROP TABLE IF EXISTS channel_sessions; -DROP TABLE IF EXISTS contact_bind_tokens; DROP TABLE IF EXISTS contact_channels; +DROP TABLE IF EXISTS bot_preauth_keys; DROP TABLE IF EXISTS contacts; DROP TABLE IF EXISTS bot_channel_configs; DROP TABLE IF EXISTS user_channel_bindings; diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 6efab809..8e0b7b08 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -95,7 +95,7 @@ CREATE INDEX IF NOT EXISTS idx_bot_members_user_id ON bot_members(user_id); CREATE TABLE IF NOT EXISTS bot_settings ( bot_id UUID PRIMARY KEY REFERENCES bots(id) ON DELETE CASCADE, max_context_load_time INTEGER NOT NULL DEFAULT 1440, - language TEXT NOT NULL DEFAULT 'Same as user input', + language TEXT NOT NULL DEFAULT 'auto', allow_guest BOOLEAN NOT NULL DEFAULT false ); @@ -125,6 +125,7 @@ CREATE TABLE IF NOT EXISTS history ( bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, session_id TEXT NOT NULL, messages JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, skills TEXT[] NOT NULL DEFAULT '{}'::text[], timestamp TIMESTAMPTZ NOT NULL ); @@ -187,6 +188,20 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_contacts_bot_user_unique CREATE INDEX IF NOT EXISTS idx_contacts_bot_id ON contacts(bot_id); +CREATE TABLE IF NOT EXISTS bot_preauth_keys ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, + token TEXT NOT NULL, + issued_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL, + expires_at TIMESTAMPTZ, + used_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT bot_preauth_keys_unique UNIQUE (token) +); + +CREATE INDEX IF NOT EXISTS idx_bot_preauth_keys_bot_id ON bot_preauth_keys(bot_id); +CREATE INDEX IF NOT EXISTS idx_bot_preauth_keys_expires ON bot_preauth_keys(expires_at); + CREATE TABLE IF NOT EXISTS contact_channels ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, @@ -202,23 +217,6 @@ CREATE TABLE IF NOT EXISTS contact_channels ( CREATE INDEX IF NOT EXISTS idx_contact_channels_contact_id ON contact_channels(contact_id); CREATE INDEX IF NOT EXISTS idx_contact_channels_platform_external ON contact_channels(platform, external_id); -CREATE TABLE IF NOT EXISTS contact_bind_tokens ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, - contact_id UUID NOT NULL REFERENCES contacts(id) ON DELETE CASCADE, - token TEXT NOT NULL, - target_platform TEXT, - target_external_id TEXT, - issued_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL, - expires_at TIMESTAMPTZ NOT NULL, - used_at TIMESTAMPTZ, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT contact_bind_tokens_unique UNIQUE (token) -); - -CREATE INDEX IF NOT EXISTS idx_contact_bind_tokens_contact_id ON contact_bind_tokens(contact_id); -CREATE INDEX IF NOT EXISTS idx_contact_bind_tokens_expires ON contact_bind_tokens(expires_at); - CREATE TABLE IF NOT EXISTS channel_sessions ( session_id TEXT PRIMARY KEY, bot_id UUID NOT NULL REFERENCES bots(id) ON DELETE CASCADE, @@ -226,6 +224,9 @@ CREATE TABLE IF NOT EXISTS channel_sessions ( user_id UUID REFERENCES users(id) ON DELETE CASCADE, contact_id UUID REFERENCES contacts(id) ON DELETE SET NULL, platform TEXT NOT NULL, + reply_target TEXT, + thread_id TEXT, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now() ); @@ -325,6 +326,9 @@ CREATE INDEX IF NOT EXISTS idx_subagents_deleted ON subagents(deleted); CREATE TABLE IF NOT EXISTS user_settings ( user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, + chat_model_id TEXT, + memory_model_id TEXT, + embedding_model_id TEXT, max_context_load_time INTEGER NOT NULL DEFAULT 1440, - language TEXT NOT NULL DEFAULT 'Same as user input' + language TEXT NOT NULL DEFAULT 'auto' ); diff --git a/db/queries/channels.sql b/db/queries/channels.sql new file mode 100644 index 00000000..9323a26f --- /dev/null +++ b/db/queries/channels.sql @@ -0,0 +1,87 @@ +-- name: GetBotChannelConfig :one +SELECT id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, status, verified_at, created_at, updated_at +FROM bot_channel_configs +WHERE bot_id = $1 AND channel_type = $2 +LIMIT 1; + +-- name: GetBotChannelConfigByExternalIdentity :one +SELECT id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, status, verified_at, created_at, updated_at +FROM bot_channel_configs +WHERE channel_type = $1 AND external_identity = $2 +LIMIT 1; + +-- name: UpsertBotChannelConfig :one +INSERT INTO bot_channel_configs ( + bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, status, verified_at +) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +ON CONFLICT (bot_id, channel_type) +DO UPDATE SET + credentials = EXCLUDED.credentials, + external_identity = EXCLUDED.external_identity, + self_identity = EXCLUDED.self_identity, + routing = EXCLUDED.routing, + capabilities = EXCLUDED.capabilities, + status = EXCLUDED.status, + verified_at = EXCLUDED.verified_at, + updated_at = now() +RETURNING id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, status, verified_at, created_at, updated_at; + +-- name: ListBotChannelConfigsByType :many +SELECT id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, status, verified_at, created_at, updated_at +FROM bot_channel_configs +WHERE channel_type = $1 +ORDER BY created_at DESC; + +-- name: GetUserChannelBinding :one +SELECT id, user_id, channel_type, config, created_at, updated_at +FROM user_channel_bindings +WHERE user_id = $1 AND channel_type = $2 +LIMIT 1; + +-- name: UpsertUserChannelBinding :one +INSERT INTO user_channel_bindings (user_id, channel_type, config) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, channel_type) +DO UPDATE SET + config = EXCLUDED.config, + updated_at = now() +RETURNING id, user_id, channel_type, config, created_at, updated_at; + +-- name: ListUserChannelBindingsByType :many +SELECT id, user_id, channel_type, config, created_at, updated_at +FROM user_channel_bindings +WHERE channel_type = $1 +ORDER BY created_at DESC; + +-- name: GetChannelSessionByID :one +SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at +FROM channel_sessions +WHERE session_id = $1 +LIMIT 1; + +-- name: ListChannelSessionsByBotPlatform :many +SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at +FROM channel_sessions +WHERE bot_id = $1 AND platform = $2 +ORDER BY updated_at DESC; + +-- name: UpsertChannelSession :one +INSERT INTO channel_sessions (session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +ON CONFLICT (session_id) +DO UPDATE SET + bot_id = EXCLUDED.bot_id, + channel_config_id = EXCLUDED.channel_config_id, + user_id = EXCLUDED.user_id, + contact_id = EXCLUDED.contact_id, + platform = EXCLUDED.platform, + reply_target = EXCLUDED.reply_target, + thread_id = EXCLUDED.thread_id, + metadata = EXCLUDED.metadata, + updated_at = now() +RETURNING session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at; + +-- name: DeleteChannelSession :exec +DELETE FROM channel_sessions +WHERE session_id = $1; diff --git a/db/queries/contacts.sql b/db/queries/contacts.sql new file mode 100644 index 00000000..7f5d9fe8 --- /dev/null +++ b/db/queries/contacts.sql @@ -0,0 +1,76 @@ +-- name: CreateContact :one +INSERT INTO contacts (bot_id, user_id, display_name, alias, tags, status, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7) +RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; + +-- name: GetContactByID :one +SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at +FROM contacts +WHERE id = $1 +LIMIT 1; + +-- name: GetContactByUserID :one +SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at +FROM contacts +WHERE bot_id = $1 AND user_id = $2 +LIMIT 1; + +-- name: ListContactsByBot :many +SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at +FROM contacts +WHERE bot_id = $1 +ORDER BY created_at DESC; + +-- name: SearchContacts :many +SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at +FROM contacts +WHERE bot_id = $1 + AND ( + display_name ILIKE sqlc.arg(query) + OR alias ILIKE sqlc.arg(query) + OR EXISTS ( + SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE sqlc.arg(query) + ) + ) +ORDER BY created_at DESC; + +-- name: UpdateContact :one +UPDATE contacts +SET display_name = COALESCE(sqlc.narg(display_name), display_name), + alias = COALESCE(sqlc.narg(alias), alias), + tags = COALESCE(sqlc.narg(tags), tags), + status = COALESCE(NULLIF(sqlc.arg(status)::text, ''), status), + metadata = COALESCE(sqlc.narg(metadata), metadata), + updated_at = now() +WHERE id = sqlc.arg(id) +RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; + +-- name: UpdateContactUser :one +UPDATE contacts +SET user_id = $2, + updated_at = now() +WHERE id = $1 +RETURNING id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at; + +-- name: UpsertContactChannel :one +INSERT INTO contact_channels (bot_id, contact_id, platform, external_id, metadata) +VALUES ($1, $2, $3, $4, $5) +ON CONFLICT (bot_id, platform, external_id) +DO UPDATE SET + contact_id = EXCLUDED.contact_id, + metadata = EXCLUDED.metadata, + updated_at = now() +RETURNING id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at; + +-- name: GetContactChannelByIdentity :one +SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at +FROM contact_channels +WHERE bot_id = $1 AND platform = $2 AND external_id = $3 +LIMIT 1; + +-- name: ListContactChannelsByContact :many +SELECT id, bot_id, contact_id, platform, external_id, metadata, created_at, updated_at +FROM contact_channels +WHERE contact_id = $1 +ORDER BY created_at DESC; + diff --git a/db/queries/containers.sql b/db/queries/containers.sql index 60173d9f..8d446b9f 100644 --- a/db/queries/containers.sql +++ b/db/queries/containers.sql @@ -31,3 +31,6 @@ ON CONFLICT (container_id) DO UPDATE SET -- name: GetContainerByContainerID :one SELECT * FROM containers WHERE container_id = sqlc.arg(container_id); + +-- name: GetContainerByBotID :one +SELECT * FROM containers WHERE bot_id = sqlc.arg(bot_id) ORDER BY updated_at DESC LIMIT 1; diff --git a/db/queries/history.sql b/db/queries/history.sql index 0b4cab3f..7e95576c 100644 --- a/db/queries/history.sql +++ b/db/queries/history.sql @@ -1,21 +1,21 @@ -- name: CreateHistory :one -INSERT INTO history (bot_id, session_id, messages, skills, timestamp) -VALUES ($1, $2, $3, $4, $5) -RETURNING id, bot_id, session_id, messages, skills, timestamp; +INSERT INTO history (bot_id, session_id, messages, metadata, skills, timestamp) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id, bot_id, session_id, messages, metadata, skills, timestamp; -- name: ListHistoryByBotSessionSince :many -SELECT id, bot_id, session_id, messages, skills, timestamp +SELECT id, bot_id, session_id, messages, metadata, skills, timestamp FROM history WHERE bot_id = $1 AND session_id = $2 AND timestamp >= $3 ORDER BY timestamp ASC; -- name: GetHistoryByID :one -SELECT id, bot_id, session_id, messages, skills, timestamp +SELECT id, bot_id, session_id, messages, metadata, skills, timestamp FROM history WHERE id = $1; -- name: ListHistoryByBotSession :many -SELECT id, bot_id, session_id, messages, skills, timestamp +SELECT id, bot_id, session_id, messages, metadata, skills, timestamp FROM history WHERE bot_id = $1 AND session_id = $2 ORDER BY timestamp DESC diff --git a/db/queries/preauth.sql b/db/queries/preauth.sql new file mode 100644 index 00000000..86aa3ffe --- /dev/null +++ b/db/queries/preauth.sql @@ -0,0 +1,16 @@ +-- name: CreateBotPreauthKey :one +INSERT INTO bot_preauth_keys (bot_id, token, issued_by_user_id, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at; + +-- name: GetBotPreauthKey :one +SELECT id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at +FROM bot_preauth_keys +WHERE token = $1 +LIMIT 1; + +-- name: MarkBotPreauthKeyUsed :one +UPDATE bot_preauth_keys +SET used_at = now() +WHERE id = $1 +RETURNING id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at; diff --git a/docs/docs.go b/docs/docs.go index feea2bb5..e47b06dc 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -240,6 +240,193 @@ const docTemplate = `{ } } }, + "/bots/{bot_id}/container": { + "post": { + "tags": [ + "containerd" + ], + "summary": "Create and start MCP container for bot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "Create container payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.CreateContainerRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.CreateContainerResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/list": { + "get": { + "tags": [ + "containerd" + ], + "summary": "List containers for bot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.ListContainersResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/snapshots": { + "get": { + "tags": [ + "containerd" + ], + "summary": "List snapshots", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Snapshotter name", + "name": "snapshotter", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.ListSnapshotsResponse" + } + } + } + }, + "post": { + "tags": [ + "containerd" + ], + "summary": "Create container snapshot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "Create snapshot payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.CreateSnapshotRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.CreateSnapshotResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/{id}": { + "delete": { + "tags": [ + "containerd" + ], + "summary": "Delete MCP container", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Container ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/bots/{bot_id}/history": { "get": { "description": "List history records for current user", @@ -2128,28 +2315,53 @@ const docTemplate = `{ } } }, - "/container": { - "post": { + "/channels": { + "get": { + "description": "List channel meta information including capabilities and schemas", "tags": [ - "containerd" + "channel" ], - "summary": "Create and start MCP container", + "summary": "List channel capabilities and schemas", + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/handlers.ChannelMeta" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/channels/{platform}": { + "get": { + "description": "Get channel meta information including capabilities and schemas", + "tags": [ + "channel" + ], + "summary": "Get channel capabilities and schemas", "parameters": [ { - "description": "Create container payload", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.CreateContainerRequest" - } + "type": "string", + "description": "Channel platform", + "name": "platform", + "in": "path", + "required": true } ], "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/handlers.CreateContainerResponse" + "$ref": "#/definitions/handlers.ChannelMeta" } }, "400": { @@ -2158,8 +2370,8 @@ const docTemplate = `{ "$ref": "#/definitions/handlers.ErrorResponse" } }, - "500": { - "description": "Internal Server Error", + "404": { + "description": "Not Found", "schema": { "$ref": "#/definitions/handlers.ErrorResponse" } @@ -2227,28 +2439,6 @@ const docTemplate = `{ } } }, - "/container/list": { - "get": { - "tags": [ - "containerd" - ], - "summary": "List containers", - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ListContainersResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/container/skills": { "get": { "tags": [ @@ -2369,119 +2559,6 @@ const docTemplate = `{ } } }, - "/container/snapshots": { - "get": { - "tags": [ - "containerd" - ], - "summary": "List snapshots", - "parameters": [ - { - "type": "string", - "description": "Snapshotter name", - "name": "snapshotter", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ListSnapshotsResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "post": { - "tags": [ - "containerd" - ], - "summary": "Create container snapshot", - "parameters": [ - { - "description": "Create snapshot payload", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.CreateSnapshotRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.CreateSnapshotResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/container/{id}": { - "delete": { - "tags": [ - "containerd" - ], - "summary": "Delete MCP container", - "parameters": [ - { - "type": "string", - "description": "Container ID", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/embeddings": { "post": { "description": "Create text or multimodal embeddings", @@ -3808,7 +3885,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "owner_user_id": { "type": "string" @@ -3852,7 +3929,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "type": { "type": "string" @@ -3903,7 +3980,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } } }, @@ -3918,6 +3995,137 @@ const docTemplate = `{ } } }, + "channel.Action": { + "type": "object", + "properties": { + "label": { + "type": "string" + }, + "type": { + "type": "string" + }, + "url": { + "type": "string" + }, + "value": { + "type": "string" + } + } + }, + "channel.Attachment": { + "type": "object", + "properties": { + "caption": { + "type": "string" + }, + "duration_ms": { + "type": "integer" + }, + "height": { + "type": "integer" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "mime": { + "type": "string" + }, + "name": { + "type": "string" + }, + "size": { + "type": "integer" + }, + "thumbnail_url": { + "type": "string" + }, + "type": { + "$ref": "#/definitions/channel.AttachmentType" + }, + "url": { + "type": "string" + }, + "width": { + "type": "integer" + } + } + }, + "channel.AttachmentType": { + "type": "string", + "enum": [ + "image", + "audio", + "video", + "voice", + "file", + "gif" + ], + "x-enum-varnames": [ + "AttachmentImage", + "AttachmentAudio", + "AttachmentVideo", + "AttachmentVoice", + "AttachmentFile", + "AttachmentGIF" + ] + }, + "channel.ChannelCapabilities": { + "type": "object", + "properties": { + "attachments": { + "type": "boolean" + }, + "block_streaming": { + "type": "boolean" + }, + "buttons": { + "type": "boolean" + }, + "chat_types": { + "type": "array", + "items": { + "type": "string" + } + }, + "edit": { + "type": "boolean" + }, + "markdown": { + "type": "boolean" + }, + "media": { + "type": "boolean" + }, + "native_commands": { + "type": "boolean" + }, + "polls": { + "type": "boolean" + }, + "reactions": { + "type": "boolean" + }, + "reply": { + "type": "boolean" + }, + "rich_text": { + "type": "boolean" + }, + "streaming": { + "type": "boolean" + }, + "text": { + "type": "boolean" + }, + "threads": { + "type": "boolean" + }, + "unsend": { + "type": "boolean" + } + } + }, "channel.ChannelConfig": { "type": "object", "properties": { @@ -3926,17 +4134,17 @@ const docTemplate = `{ }, "capabilities": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "channelType": { - "$ref": "#/definitions/channel.ChannelType" + "type": "string" }, "createdAt": { "type": "string" }, "credentials": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "externalIdentity": { "type": "string" @@ -3946,11 +4154,11 @@ const docTemplate = `{ }, "routing": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "selfIdentity": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "status": { "type": "string" @@ -3963,30 +4171,15 @@ const docTemplate = `{ } } }, - "channel.ChannelType": { - "type": "string", - "enum": [ - "telegram", - "feishu", - "cli", - "web" - ], - "x-enum-varnames": [ - "ChannelTelegram", - "ChannelFeishu", - "ChannelCLI", - "ChannelWeb" - ] - }, "channel.ChannelUserBinding": { "type": "object", "properties": { "channelType": { - "$ref": "#/definitions/channel.ChannelType" + "type": "string" }, "config": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "createdAt": { "type": "string" @@ -4002,16 +4195,235 @@ const docTemplate = `{ } } }, + "channel.ConfigSchema": { + "type": "object", + "properties": { + "fields": { + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/channel.FieldSchema" + } + }, + "version": { + "type": "integer" + } + } + }, + "channel.FieldSchema": { + "type": "object", + "properties": { + "description": { + "type": "string" + }, + "enum": { + "type": "array", + "items": { + "type": "string" + } + }, + "example": {}, + "required": { + "type": "boolean" + }, + "title": { + "type": "string" + }, + "type": { + "$ref": "#/definitions/channel.FieldType" + } + } + }, + "channel.FieldType": { + "type": "string", + "enum": [ + "string", + "secret", + "bool", + "number", + "enum" + ], + "x-enum-varnames": [ + "FieldString", + "FieldSecret", + "FieldBool", + "FieldNumber", + "FieldEnum" + ] + }, + "channel.Message": { + "type": "object", + "properties": { + "actions": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.Action" + } + }, + "attachments": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.Attachment" + } + }, + "format": { + "$ref": "#/definitions/channel.MessageFormat" + }, + "id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "parts": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.MessagePart" + } + }, + "reply": { + "$ref": "#/definitions/channel.ReplyRef" + }, + "text": { + "type": "string" + }, + "thread": { + "$ref": "#/definitions/channel.ThreadRef" + } + } + }, + "channel.MessageFormat": { + "type": "string", + "enum": [ + "plain", + "markdown", + "rich" + ], + "x-enum-varnames": [ + "MessageFormatPlain", + "MessageFormatMarkdown", + "MessageFormatRich" + ] + }, + "channel.MessagePart": { + "type": "object", + "properties": { + "emoji": { + "type": "string" + }, + "language": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "styles": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.MessageTextStyle" + } + }, + "text": { + "type": "string" + }, + "type": { + "$ref": "#/definitions/channel.MessagePartType" + }, + "url": { + "type": "string" + }, + "user_id": { + "type": "string" + } + } + }, + "channel.MessagePartType": { + "type": "string", + "enum": [ + "text", + "link", + "code_block", + "mention", + "emoji" + ], + "x-enum-varnames": [ + "MessagePartText", + "MessagePartLink", + "MessagePartCodeBlock", + "MessagePartMention", + "MessagePartEmoji" + ] + }, + "channel.MessageTextStyle": { + "type": "string", + "enum": [ + "bold", + "italic", + "strikethrough", + "code" + ], + "x-enum-varnames": [ + "MessageStyleBold", + "MessageStyleItalic", + "MessageStyleStrikethrough", + "MessageStyleCode" + ] + }, + "channel.ReplyRef": { + "type": "object", + "properties": { + "message_id": { + "type": "string" + }, + "target": { + "type": "string" + } + } + }, "channel.SendRequest": { "type": "object", "properties": { "message": { + "$ref": "#/definitions/channel.Message" + }, + "target": { "type": "string" }, - "to": { + "user_id": { + "type": "string" + } + } + }, + "channel.TargetHint": { + "type": "object", + "properties": { + "example": { "type": "string" }, - "to_user_id": { + "label": { + "type": "string" + } + } + }, + "channel.TargetSpec": { + "type": "object", + "properties": { + "format": { + "type": "string" + }, + "hints": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.TargetHint" + } + } + } + }, + "channel.ThreadRef": { + "type": "object", + "properties": { + "id": { "type": "string" } } @@ -4021,22 +4433,22 @@ const docTemplate = `{ "properties": { "capabilities": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "credentials": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "external_identity": { "type": "string" }, "routing": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "self_identity": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "status": { "type": "string" @@ -4051,42 +4463,28 @@ const docTemplate = `{ "properties": { "config": { "type": "object", - "additionalProperties": true - } - } - }, - "chat.AgentSkill": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "description": { - "type": "string" - }, - "name": { - "type": "string" + "additionalProperties": {} } } }, "chat.ChatRequest": { "type": "object", "properties": { + "allowed_actions": { + "type": "array", + "items": { + "type": "string" + } + }, "current_platform": { "type": "string" }, "language": { "type": "string" }, - "locale": { - "type": "string" - }, "max_context_load_time": { "type": "integer" }, - "max_steps": { - "type": "integer" - }, "messages": { "type": "array", "items": { @@ -4109,19 +4507,6 @@ const docTemplate = `{ "type": "string" }, "skills": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.AgentSkill" - } - }, - "toolChoice": { - "type": "object", - "additionalProperties": {} - }, - "toolContext": { - "$ref": "#/definitions/chat.ToolContext" - }, - "use_skills": { "type": "array", "items": { "type": "string" @@ -4154,37 +4539,31 @@ const docTemplate = `{ }, "chat.GatewayMessage": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, - "chat.ToolContext": { + "handlers.ChannelMeta": { "type": "object", "properties": { - "botId": { + "capabilities": { + "$ref": "#/definitions/channel.ChannelCapabilities" + }, + "config_schema": { + "$ref": "#/definitions/channel.ConfigSchema" + }, + "configless": { + "type": "boolean" + }, + "display_name": { "type": "string" }, - "contactAlias": { + "target_spec": { + "$ref": "#/definitions/channel.TargetSpec" + }, + "type": { "type": "string" }, - "contactId": { - "type": "string" - }, - "contactName": { - "type": "string" - }, - "currentPlatform": { - "type": "string" - }, - "replyTarget": { - "type": "string" - }, - "sessionId": { - "type": "string" - }, - "sessionToken": { - "type": "string" - }, - "userId": { - "type": "string" + "user_config_schema": { + "$ref": "#/definitions/channel.ConfigSchema" } } }, @@ -4515,7 +4894,7 @@ const docTemplate = `{ }, "filters": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "infer": { "type": "boolean" @@ -4531,7 +4910,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "run_id": { "type": "string" @@ -4551,14 +4930,14 @@ const docTemplate = `{ "properties": { "filters": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "input": { "$ref": "#/definitions/memory.EmbedInput" }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "model": { "type": "string" @@ -4585,7 +4964,7 @@ const docTemplate = `{ }, "filters": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "limit": { "type": "integer" @@ -4619,9 +4998,13 @@ const docTemplate = `{ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, "skills": { "type": "array", "items": { @@ -4654,9 +5037,13 @@ const docTemplate = `{ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, "session_id": { "type": "string" }, @@ -4733,7 +5120,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "runId": { "type": "string" @@ -4795,6 +5182,12 @@ const docTemplate = `{ "dimensions": { "type": "integer" }, + "input": { + "type": "array", + "items": { + "type": "string" + } + }, "is_multimodal": { "type": "boolean" }, @@ -4837,6 +5230,12 @@ const docTemplate = `{ "dimensions": { "type": "integer" }, + "input": { + "type": "array", + "items": { + "type": "string" + } + }, "is_multimodal": { "type": "boolean" }, @@ -4871,6 +5270,12 @@ const docTemplate = `{ "dimensions": { "type": "integer" }, + "input": { + "type": "array", + "items": { + "type": "string" + } + }, "is_multimodal": { "type": "boolean" }, @@ -4932,7 +5337,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -4960,7 +5365,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -4984,7 +5389,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -5161,7 +5566,7 @@ const docTemplate = `{ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } } } @@ -5176,12 +5581,12 @@ const docTemplate = `{ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -5241,12 +5646,12 @@ const docTemplate = `{ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -5269,7 +5674,7 @@ const docTemplate = `{ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } } } @@ -5282,7 +5687,7 @@ const docTemplate = `{ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" diff --git a/docs/swagger.json b/docs/swagger.json index 9dd101da..c513dacc 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -231,6 +231,193 @@ } } }, + "/bots/{bot_id}/container": { + "post": { + "tags": [ + "containerd" + ], + "summary": "Create and start MCP container for bot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "Create container payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.CreateContainerRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.CreateContainerResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/list": { + "get": { + "tags": [ + "containerd" + ], + "summary": "List containers for bot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.ListContainersResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/snapshots": { + "get": { + "tags": [ + "containerd" + ], + "summary": "List snapshots", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Snapshotter name", + "name": "snapshotter", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.ListSnapshotsResponse" + } + } + } + }, + "post": { + "tags": [ + "containerd" + ], + "summary": "Create container snapshot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "description": "Create snapshot payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.CreateSnapshotRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.CreateSnapshotResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/{id}": { + "delete": { + "tags": [ + "containerd" + ], + "summary": "Delete MCP container", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Container ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/bots/{bot_id}/history": { "get": { "description": "List history records for current user", @@ -2119,28 +2306,53 @@ } } }, - "/container": { - "post": { + "/channels": { + "get": { + "description": "List channel meta information including capabilities and schemas", "tags": [ - "containerd" + "channel" ], - "summary": "Create and start MCP container", + "summary": "List channel capabilities and schemas", + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/handlers.ChannelMeta" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/channels/{platform}": { + "get": { + "description": "Get channel meta information including capabilities and schemas", + "tags": [ + "channel" + ], + "summary": "Get channel capabilities and schemas", "parameters": [ { - "description": "Create container payload", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.CreateContainerRequest" - } + "type": "string", + "description": "Channel platform", + "name": "platform", + "in": "path", + "required": true } ], "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/handlers.CreateContainerResponse" + "$ref": "#/definitions/handlers.ChannelMeta" } }, "400": { @@ -2149,8 +2361,8 @@ "$ref": "#/definitions/handlers.ErrorResponse" } }, - "500": { - "description": "Internal Server Error", + "404": { + "description": "Not Found", "schema": { "$ref": "#/definitions/handlers.ErrorResponse" } @@ -2218,28 +2430,6 @@ } } }, - "/container/list": { - "get": { - "tags": [ - "containerd" - ], - "summary": "List containers", - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ListContainersResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/container/skills": { "get": { "tags": [ @@ -2360,119 +2550,6 @@ } } }, - "/container/snapshots": { - "get": { - "tags": [ - "containerd" - ], - "summary": "List snapshots", - "parameters": [ - { - "type": "string", - "description": "Snapshotter name", - "name": "snapshotter", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ListSnapshotsResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - }, - "post": { - "tags": [ - "containerd" - ], - "summary": "Create container snapshot", - "parameters": [ - { - "description": "Create snapshot payload", - "name": "payload", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/handlers.CreateSnapshotRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.CreateSnapshotResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/container/{id}": { - "delete": { - "tags": [ - "containerd" - ], - "summary": "Delete MCP container", - "parameters": [ - { - "type": "string", - "description": "Container ID", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, "/embeddings": { "post": { "description": "Create text or multimodal embeddings", @@ -3799,7 +3876,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "owner_user_id": { "type": "string" @@ -3843,7 +3920,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "type": { "type": "string" @@ -3894,7 +3971,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } } }, @@ -3909,6 +3986,137 @@ } } }, + "channel.Action": { + "type": "object", + "properties": { + "label": { + "type": "string" + }, + "type": { + "type": "string" + }, + "url": { + "type": "string" + }, + "value": { + "type": "string" + } + } + }, + "channel.Attachment": { + "type": "object", + "properties": { + "caption": { + "type": "string" + }, + "duration_ms": { + "type": "integer" + }, + "height": { + "type": "integer" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "mime": { + "type": "string" + }, + "name": { + "type": "string" + }, + "size": { + "type": "integer" + }, + "thumbnail_url": { + "type": "string" + }, + "type": { + "$ref": "#/definitions/channel.AttachmentType" + }, + "url": { + "type": "string" + }, + "width": { + "type": "integer" + } + } + }, + "channel.AttachmentType": { + "type": "string", + "enum": [ + "image", + "audio", + "video", + "voice", + "file", + "gif" + ], + "x-enum-varnames": [ + "AttachmentImage", + "AttachmentAudio", + "AttachmentVideo", + "AttachmentVoice", + "AttachmentFile", + "AttachmentGIF" + ] + }, + "channel.ChannelCapabilities": { + "type": "object", + "properties": { + "attachments": { + "type": "boolean" + }, + "block_streaming": { + "type": "boolean" + }, + "buttons": { + "type": "boolean" + }, + "chat_types": { + "type": "array", + "items": { + "type": "string" + } + }, + "edit": { + "type": "boolean" + }, + "markdown": { + "type": "boolean" + }, + "media": { + "type": "boolean" + }, + "native_commands": { + "type": "boolean" + }, + "polls": { + "type": "boolean" + }, + "reactions": { + "type": "boolean" + }, + "reply": { + "type": "boolean" + }, + "rich_text": { + "type": "boolean" + }, + "streaming": { + "type": "boolean" + }, + "text": { + "type": "boolean" + }, + "threads": { + "type": "boolean" + }, + "unsend": { + "type": "boolean" + } + } + }, "channel.ChannelConfig": { "type": "object", "properties": { @@ -3917,17 +4125,17 @@ }, "capabilities": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "channelType": { - "$ref": "#/definitions/channel.ChannelType" + "type": "string" }, "createdAt": { "type": "string" }, "credentials": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "externalIdentity": { "type": "string" @@ -3937,11 +4145,11 @@ }, "routing": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "selfIdentity": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "status": { "type": "string" @@ -3954,30 +4162,15 @@ } } }, - "channel.ChannelType": { - "type": "string", - "enum": [ - "telegram", - "feishu", - "cli", - "web" - ], - "x-enum-varnames": [ - "ChannelTelegram", - "ChannelFeishu", - "ChannelCLI", - "ChannelWeb" - ] - }, "channel.ChannelUserBinding": { "type": "object", "properties": { "channelType": { - "$ref": "#/definitions/channel.ChannelType" + "type": "string" }, "config": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "createdAt": { "type": "string" @@ -3993,16 +4186,235 @@ } } }, + "channel.ConfigSchema": { + "type": "object", + "properties": { + "fields": { + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/channel.FieldSchema" + } + }, + "version": { + "type": "integer" + } + } + }, + "channel.FieldSchema": { + "type": "object", + "properties": { + "description": { + "type": "string" + }, + "enum": { + "type": "array", + "items": { + "type": "string" + } + }, + "example": {}, + "required": { + "type": "boolean" + }, + "title": { + "type": "string" + }, + "type": { + "$ref": "#/definitions/channel.FieldType" + } + } + }, + "channel.FieldType": { + "type": "string", + "enum": [ + "string", + "secret", + "bool", + "number", + "enum" + ], + "x-enum-varnames": [ + "FieldString", + "FieldSecret", + "FieldBool", + "FieldNumber", + "FieldEnum" + ] + }, + "channel.Message": { + "type": "object", + "properties": { + "actions": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.Action" + } + }, + "attachments": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.Attachment" + } + }, + "format": { + "$ref": "#/definitions/channel.MessageFormat" + }, + "id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "parts": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.MessagePart" + } + }, + "reply": { + "$ref": "#/definitions/channel.ReplyRef" + }, + "text": { + "type": "string" + }, + "thread": { + "$ref": "#/definitions/channel.ThreadRef" + } + } + }, + "channel.MessageFormat": { + "type": "string", + "enum": [ + "plain", + "markdown", + "rich" + ], + "x-enum-varnames": [ + "MessageFormatPlain", + "MessageFormatMarkdown", + "MessageFormatRich" + ] + }, + "channel.MessagePart": { + "type": "object", + "properties": { + "emoji": { + "type": "string" + }, + "language": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "styles": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.MessageTextStyle" + } + }, + "text": { + "type": "string" + }, + "type": { + "$ref": "#/definitions/channel.MessagePartType" + }, + "url": { + "type": "string" + }, + "user_id": { + "type": "string" + } + } + }, + "channel.MessagePartType": { + "type": "string", + "enum": [ + "text", + "link", + "code_block", + "mention", + "emoji" + ], + "x-enum-varnames": [ + "MessagePartText", + "MessagePartLink", + "MessagePartCodeBlock", + "MessagePartMention", + "MessagePartEmoji" + ] + }, + "channel.MessageTextStyle": { + "type": "string", + "enum": [ + "bold", + "italic", + "strikethrough", + "code" + ], + "x-enum-varnames": [ + "MessageStyleBold", + "MessageStyleItalic", + "MessageStyleStrikethrough", + "MessageStyleCode" + ] + }, + "channel.ReplyRef": { + "type": "object", + "properties": { + "message_id": { + "type": "string" + }, + "target": { + "type": "string" + } + } + }, "channel.SendRequest": { "type": "object", "properties": { "message": { + "$ref": "#/definitions/channel.Message" + }, + "target": { "type": "string" }, - "to": { + "user_id": { + "type": "string" + } + } + }, + "channel.TargetHint": { + "type": "object", + "properties": { + "example": { "type": "string" }, - "to_user_id": { + "label": { + "type": "string" + } + } + }, + "channel.TargetSpec": { + "type": "object", + "properties": { + "format": { + "type": "string" + }, + "hints": { + "type": "array", + "items": { + "$ref": "#/definitions/channel.TargetHint" + } + } + } + }, + "channel.ThreadRef": { + "type": "object", + "properties": { + "id": { "type": "string" } } @@ -4012,22 +4424,22 @@ "properties": { "capabilities": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "credentials": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "external_identity": { "type": "string" }, "routing": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "self_identity": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "status": { "type": "string" @@ -4042,42 +4454,28 @@ "properties": { "config": { "type": "object", - "additionalProperties": true - } - } - }, - "chat.AgentSkill": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "description": { - "type": "string" - }, - "name": { - "type": "string" + "additionalProperties": {} } } }, "chat.ChatRequest": { "type": "object", "properties": { + "allowed_actions": { + "type": "array", + "items": { + "type": "string" + } + }, "current_platform": { "type": "string" }, "language": { "type": "string" }, - "locale": { - "type": "string" - }, "max_context_load_time": { "type": "integer" }, - "max_steps": { - "type": "integer" - }, "messages": { "type": "array", "items": { @@ -4100,19 +4498,6 @@ "type": "string" }, "skills": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.AgentSkill" - } - }, - "toolChoice": { - "type": "object", - "additionalProperties": {} - }, - "toolContext": { - "$ref": "#/definitions/chat.ToolContext" - }, - "use_skills": { "type": "array", "items": { "type": "string" @@ -4145,37 +4530,31 @@ }, "chat.GatewayMessage": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, - "chat.ToolContext": { + "handlers.ChannelMeta": { "type": "object", "properties": { - "botId": { + "capabilities": { + "$ref": "#/definitions/channel.ChannelCapabilities" + }, + "config_schema": { + "$ref": "#/definitions/channel.ConfigSchema" + }, + "configless": { + "type": "boolean" + }, + "display_name": { "type": "string" }, - "contactAlias": { + "target_spec": { + "$ref": "#/definitions/channel.TargetSpec" + }, + "type": { "type": "string" }, - "contactId": { - "type": "string" - }, - "contactName": { - "type": "string" - }, - "currentPlatform": { - "type": "string" - }, - "replyTarget": { - "type": "string" - }, - "sessionId": { - "type": "string" - }, - "sessionToken": { - "type": "string" - }, - "userId": { - "type": "string" + "user_config_schema": { + "$ref": "#/definitions/channel.ConfigSchema" } } }, @@ -4506,7 +4885,7 @@ }, "filters": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "infer": { "type": "boolean" @@ -4522,7 +4901,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "run_id": { "type": "string" @@ -4542,14 +4921,14 @@ "properties": { "filters": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "input": { "$ref": "#/definitions/memory.EmbedInput" }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "model": { "type": "string" @@ -4576,7 +4955,7 @@ }, "filters": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "limit": { "type": "integer" @@ -4610,9 +4989,13 @@ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, "skills": { "type": "array", "items": { @@ -4645,9 +5028,13 @@ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, "session_id": { "type": "string" }, @@ -4724,7 +5111,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "runId": { "type": "string" @@ -4786,6 +5173,12 @@ "dimensions": { "type": "integer" }, + "input": { + "type": "array", + "items": { + "type": "string" + } + }, "is_multimodal": { "type": "boolean" }, @@ -4828,6 +5221,12 @@ "dimensions": { "type": "integer" }, + "input": { + "type": "array", + "items": { + "type": "string" + } + }, "is_multimodal": { "type": "boolean" }, @@ -4862,6 +5261,12 @@ "dimensions": { "type": "integer" }, + "input": { + "type": "array", + "items": { + "type": "string" + } + }, "is_multimodal": { "type": "boolean" }, @@ -4923,7 +5328,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -4951,7 +5356,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -4975,7 +5380,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -5152,7 +5557,7 @@ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } } } @@ -5167,12 +5572,12 @@ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -5232,12 +5637,12 @@ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" @@ -5260,7 +5665,7 @@ "type": "array", "items": { "type": "object", - "additionalProperties": true + "additionalProperties": {} } } } @@ -5273,7 +5678,7 @@ }, "metadata": { "type": "object", - "additionalProperties": true + "additionalProperties": {} }, "name": { "type": "string" diff --git a/docs/swagger.yaml b/docs/swagger.yaml index d55634f3..35f0c8f6 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -12,7 +12,7 @@ definitions: is_active: type: boolean metadata: - additionalProperties: true + additionalProperties: {} type: object owner_user_id: type: string @@ -41,7 +41,7 @@ definitions: is_active: type: boolean metadata: - additionalProperties: true + additionalProperties: {} type: object type: type: string @@ -74,7 +74,7 @@ definitions: is_active: type: boolean metadata: - additionalProperties: true + additionalProperties: {} type: object type: object bots.UpsertMemberRequest: @@ -84,29 +84,119 @@ definitions: user_id: type: string type: object + channel.Action: + properties: + label: + type: string + type: + type: string + url: + type: string + value: + type: string + type: object + channel.Attachment: + properties: + caption: + type: string + duration_ms: + type: integer + height: + type: integer + metadata: + additionalProperties: {} + type: object + mime: + type: string + name: + type: string + size: + type: integer + thumbnail_url: + type: string + type: + $ref: '#/definitions/channel.AttachmentType' + url: + type: string + width: + type: integer + type: object + channel.AttachmentType: + enum: + - image + - audio + - video + - voice + - file + - gif + type: string + x-enum-varnames: + - AttachmentImage + - AttachmentAudio + - AttachmentVideo + - AttachmentVoice + - AttachmentFile + - AttachmentGIF + channel.ChannelCapabilities: + properties: + attachments: + type: boolean + block_streaming: + type: boolean + buttons: + type: boolean + chat_types: + items: + type: string + type: array + edit: + type: boolean + markdown: + type: boolean + media: + type: boolean + native_commands: + type: boolean + polls: + type: boolean + reactions: + type: boolean + reply: + type: boolean + rich_text: + type: boolean + streaming: + type: boolean + text: + type: boolean + threads: + type: boolean + unsend: + type: boolean + type: object channel.ChannelConfig: properties: botID: type: string capabilities: - additionalProperties: true + additionalProperties: {} type: object channelType: - $ref: '#/definitions/channel.ChannelType' + type: string createdAt: type: string credentials: - additionalProperties: true + additionalProperties: {} type: object externalIdentity: type: string id: type: string routing: - additionalProperties: true + additionalProperties: {} type: object selfIdentity: - additionalProperties: true + additionalProperties: {} type: object status: type: string @@ -115,24 +205,12 @@ definitions: verifiedAt: type: string type: object - channel.ChannelType: - enum: - - telegram - - feishu - - cli - - web - type: string - x-enum-varnames: - - ChannelTelegram - - ChannelFeishu - - ChannelCLI - - ChannelWeb channel.ChannelUserBinding: properties: channelType: - $ref: '#/definitions/channel.ChannelType' + type: string config: - additionalProperties: true + additionalProperties: {} type: object createdAt: type: string @@ -143,30 +221,183 @@ definitions: userID: type: string type: object + channel.ConfigSchema: + properties: + fields: + additionalProperties: + $ref: '#/definitions/channel.FieldSchema' + type: object + version: + type: integer + type: object + channel.FieldSchema: + properties: + description: + type: string + enum: + items: + type: string + type: array + example: {} + required: + type: boolean + title: + type: string + type: + $ref: '#/definitions/channel.FieldType' + type: object + channel.FieldType: + enum: + - string + - secret + - bool + - number + - enum + type: string + x-enum-varnames: + - FieldString + - FieldSecret + - FieldBool + - FieldNumber + - FieldEnum + channel.Message: + properties: + actions: + items: + $ref: '#/definitions/channel.Action' + type: array + attachments: + items: + $ref: '#/definitions/channel.Attachment' + type: array + format: + $ref: '#/definitions/channel.MessageFormat' + id: + type: string + metadata: + additionalProperties: {} + type: object + parts: + items: + $ref: '#/definitions/channel.MessagePart' + type: array + reply: + $ref: '#/definitions/channel.ReplyRef' + text: + type: string + thread: + $ref: '#/definitions/channel.ThreadRef' + type: object + channel.MessageFormat: + enum: + - plain + - markdown + - rich + type: string + x-enum-varnames: + - MessageFormatPlain + - MessageFormatMarkdown + - MessageFormatRich + channel.MessagePart: + properties: + emoji: + type: string + language: + type: string + metadata: + additionalProperties: {} + type: object + styles: + items: + $ref: '#/definitions/channel.MessageTextStyle' + type: array + text: + type: string + type: + $ref: '#/definitions/channel.MessagePartType' + url: + type: string + user_id: + type: string + type: object + channel.MessagePartType: + enum: + - text + - link + - code_block + - mention + - emoji + type: string + x-enum-varnames: + - MessagePartText + - MessagePartLink + - MessagePartCodeBlock + - MessagePartMention + - MessagePartEmoji + channel.MessageTextStyle: + enum: + - bold + - italic + - strikethrough + - code + type: string + x-enum-varnames: + - MessageStyleBold + - MessageStyleItalic + - MessageStyleStrikethrough + - MessageStyleCode + channel.ReplyRef: + properties: + message_id: + type: string + target: + type: string + type: object channel.SendRequest: properties: message: + $ref: '#/definitions/channel.Message' + target: type: string - to: + user_id: type: string - to_user_id: + type: object + channel.TargetHint: + properties: + example: + type: string + label: + type: string + type: object + channel.TargetSpec: + properties: + format: + type: string + hints: + items: + $ref: '#/definitions/channel.TargetHint' + type: array + type: object + channel.ThreadRef: + properties: + id: type: string type: object channel.UpsertConfigRequest: properties: capabilities: - additionalProperties: true + additionalProperties: {} type: object credentials: - additionalProperties: true + additionalProperties: {} type: object external_identity: type: string routing: - additionalProperties: true + additionalProperties: {} type: object self_identity: - additionalProperties: true + additionalProperties: {} type: object status: type: string @@ -176,30 +407,21 @@ definitions: channel.UpsertUserConfigRequest: properties: config: - additionalProperties: true + additionalProperties: {} type: object type: object - chat.AgentSkill: - properties: - content: - type: string - description: - type: string - name: - type: string - type: object chat.ChatRequest: properties: + allowed_actions: + items: + type: string + type: array current_platform: type: string language: type: string - locale: - type: string max_context_load_time: type: integer - max_steps: - type: integer messages: items: $ref: '#/definitions/chat.GatewayMessage' @@ -215,15 +437,6 @@ definitions: query: type: string skills: - items: - $ref: '#/definitions/chat.AgentSkill' - type: array - toolChoice: - additionalProperties: {} - type: object - toolContext: - $ref: '#/definitions/chat.ToolContext' - use_skills: items: type: string type: array @@ -244,28 +457,24 @@ definitions: type: array type: object chat.GatewayMessage: - additionalProperties: true + additionalProperties: {} type: object - chat.ToolContext: + handlers.ChannelMeta: properties: - botId: + capabilities: + $ref: '#/definitions/channel.ChannelCapabilities' + config_schema: + $ref: '#/definitions/channel.ConfigSchema' + configless: + type: boolean + display_name: type: string - contactAlias: - type: string - contactId: - type: string - contactName: - type: string - currentPlatform: - type: string - replyTarget: - type: string - sessionId: - type: string - sessionToken: - type: string - userId: + target_spec: + $ref: '#/definitions/channel.TargetSpec' + type: type: string + user_config_schema: + $ref: '#/definitions/channel.ConfigSchema' type: object handlers.ContainerInfo: properties: @@ -478,7 +687,7 @@ definitions: embedding_enabled: type: boolean filters: - additionalProperties: true + additionalProperties: {} type: object infer: type: boolean @@ -489,7 +698,7 @@ definitions: $ref: '#/definitions/memory.Message' type: array metadata: - additionalProperties: true + additionalProperties: {} type: object run_id: type: string @@ -502,12 +711,12 @@ definitions: handlers.memoryEmbedUpsertPayload: properties: filters: - additionalProperties: true + additionalProperties: {} type: object input: $ref: '#/definitions/memory.EmbedInput' metadata: - additionalProperties: true + additionalProperties: {} type: object model: type: string @@ -525,7 +734,7 @@ definitions: embedding_enabled: type: boolean filters: - additionalProperties: true + additionalProperties: {} type: object limit: type: integer @@ -547,9 +756,12 @@ definitions: properties: messages: items: - additionalProperties: true + additionalProperties: {} type: object type: array + metadata: + additionalProperties: {} + type: object skills: items: type: string @@ -570,9 +782,12 @@ definitions: type: string messages: items: - additionalProperties: true + additionalProperties: {} type: object type: array + metadata: + additionalProperties: {} + type: object session_id: type: string skills: @@ -622,7 +837,7 @@ definitions: memory: type: string metadata: - additionalProperties: true + additionalProperties: {} type: object runId: type: string @@ -663,6 +878,10 @@ definitions: properties: dimensions: type: integer + input: + items: + type: string + type: array is_multimodal: type: boolean llm_provider_id: @@ -690,6 +909,10 @@ definitions: properties: dimensions: type: integer + input: + items: + type: string + type: array is_multimodal: type: boolean llm_provider_id: @@ -713,6 +936,10 @@ definitions: properties: dimensions: type: integer + input: + items: + type: string + type: array is_multimodal: type: boolean llm_provider_id: @@ -752,7 +979,7 @@ definitions: client_type: $ref: '#/definitions/providers.ClientType' metadata: - additionalProperties: true + additionalProperties: {} type: object name: type: string @@ -775,7 +1002,7 @@ definitions: id: type: string metadata: - additionalProperties: true + additionalProperties: {} type: object name: type: string @@ -791,7 +1018,7 @@ definitions: client_type: $ref: '#/definitions/providers.ClientType' metadata: - additionalProperties: true + additionalProperties: {} type: object name: type: string @@ -906,7 +1133,7 @@ definitions: properties: messages: items: - additionalProperties: true + additionalProperties: {} type: object type: array type: object @@ -916,11 +1143,11 @@ definitions: type: string messages: items: - additionalProperties: true + additionalProperties: {} type: object type: array metadata: - additionalProperties: true + additionalProperties: {} type: object name: type: string @@ -959,11 +1186,11 @@ definitions: type: string messages: items: - additionalProperties: true + additionalProperties: {} type: object type: array metadata: - additionalProperties: true + additionalProperties: {} type: object name: type: string @@ -978,7 +1205,7 @@ definitions: properties: messages: items: - additionalProperties: true + additionalProperties: {} type: object type: array type: object @@ -987,7 +1214,7 @@ definitions: description: type: string metadata: - additionalProperties: true + additionalProperties: {} type: object name: type: string @@ -1230,6 +1457,128 @@ paths: summary: Stream chat with AI tags: - chat + /bots/{bot_id}/container: + post: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + - description: Create container payload + in: body + name: payload + required: true + schema: + $ref: '#/definitions/handlers.CreateContainerRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.CreateContainerResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Create and start MCP container for bot + tags: + - containerd + /bots/{bot_id}/container/{id}: + delete: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + - description: Container ID + in: path + name: id + required: true + type: string + responses: + "204": + description: No Content + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Delete MCP container + tags: + - containerd + /bots/{bot_id}/container/list: + get: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.ListContainersResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: List containers for bot + tags: + - containerd + /bots/{bot_id}/container/snapshots: + get: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + - description: Snapshotter name + in: query + name: snapshotter + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.ListSnapshotsResponse' + summary: List snapshots + tags: + - containerd + post: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + - description: Create snapshot payload + in: body + name: payload + required: true + schema: + $ref: '#/definitions/handlers.CreateSnapshotRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.CreateSnapshotResponse' + summary: Create container snapshot + tags: + - containerd /bots/{bot_id}/history: delete: description: Delete all history records for current user @@ -2487,42 +2836,37 @@ paths: summary: Transfer bot owner (admin only) tags: - bots - /container: - post: - parameters: - - description: Create container payload - in: body - name: payload - required: true - schema: - $ref: '#/definitions/handlers.CreateContainerRequest' + /channels: + get: + description: List channel meta information including capabilities and schemas responses: "200": description: OK schema: - $ref: '#/definitions/handlers.CreateContainerResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' + items: + $ref: '#/definitions/handlers.ChannelMeta' + type: array "500": description: Internal Server Error schema: $ref: '#/definitions/handlers.ErrorResponse' - summary: Create and start MCP container + summary: List channel capabilities and schemas tags: - - containerd - /container/{id}: - delete: + - channel + /channels/{platform}: + get: + description: Get channel meta information including capabilities and schemas parameters: - - description: Container ID + - description: Channel platform in: path - name: id + name: platform required: true type: string responses: - "204": - description: No Content + "200": + description: OK + schema: + $ref: '#/definitions/handlers.ChannelMeta' "400": description: Bad Request schema: @@ -2531,13 +2875,9 @@ paths: description: Not Found schema: $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Delete MCP container + summary: Get channel capabilities and schemas tags: - - containerd + - channel /container/fs/{id}: post: description: |- @@ -2591,20 +2931,6 @@ paths: summary: MCP filesystem tools (JSON-RPC) tags: - containerd - /container/list: - get: - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.ListContainersResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: List containers - tags: - - containerd /container/skills: delete: parameters: @@ -2683,53 +3009,6 @@ paths: summary: Upload skills into container tags: - containerd - /container/snapshots: - get: - parameters: - - description: Snapshotter name - in: query - name: snapshotter - type: string - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.ListSnapshotsResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: List snapshots - tags: - - containerd - post: - parameters: - - description: Create snapshot payload - in: body - name: payload - required: true - schema: - $ref: '#/definitions/handlers.CreateSnapshotRequest' - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.CreateSnapshotResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "404": - description: Not Found - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Create container snapshot - tags: - - containerd /embeddings: post: description: Create text or multimodal embeddings diff --git a/internal/bots/service.go b/internal/bots/service.go index 94082699..d5526aaa 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -89,7 +89,7 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR } metadata := req.Metadata if metadata == nil { - metadata = map[string]interface{}{} + metadata = map[string]any{} } payload, err := json.Marshal(metadata) if err != nil { @@ -438,16 +438,16 @@ func toBotMember(row sqlc.BotMember) BotMember { } } -func decodeMetadata(payload []byte) (map[string]interface{}, error) { +func decodeMetadata(payload []byte) (map[string]any, error) { if len(payload) == 0 { - return map[string]interface{}{}, nil + return map[string]any{}, nil } - var data map[string]interface{} + var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { return nil, err } if data == nil { - data = map[string]interface{}{} + data = map[string]any{} } return data, nil } diff --git a/internal/bots/types.go b/internal/bots/types.go index b9e02e30..288e7f64 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -3,15 +3,15 @@ package bots import "time" type Bot struct { - ID string `json:"id"` - OwnerUserID string `json:"owner_user_id"` - Type string `json:"type"` - DisplayName string `json:"display_name"` - AvatarURL string `json:"avatar_url,omitempty"` - IsActive bool `json:"is_active"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + OwnerUserID string `json:"owner_user_id"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url,omitempty"` + IsActive bool `json:"is_active"` + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } type BotMember struct { @@ -22,18 +22,18 @@ type BotMember struct { } type CreateBotRequest struct { - Type string `json:"type"` - DisplayName string `json:"display_name,omitempty"` - AvatarURL string `json:"avatar_url,omitempty"` - IsActive *bool `json:"is_active,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Type string `json:"type"` + DisplayName string `json:"display_name,omitempty"` + AvatarURL string `json:"avatar_url,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type UpdateBotRequest struct { - DisplayName *string `json:"display_name,omitempty"` - AvatarURL *string `json:"avatar_url,omitempty"` - IsActive *bool `json:"is_active,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + DisplayName *string `json:"display_name,omitempty"` + AvatarURL *string `json:"avatar_url,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type TransferBotRequest struct { diff --git a/internal/channel/adapter.go b/internal/channel/adapter.go index a09cf9c3..4d566b42 100644 --- a/internal/channel/adapter.go +++ b/internal/channel/adapter.go @@ -2,64 +2,80 @@ package channel import ( "context" - "strings" + "errors" + "sync/atomic" ) -type InboundMessage struct { - Channel ChannelType - Text string - Username string - UserID string - OpenID string - ChatID string - ChatType string - ReplyTo string - BotID string // 增加 BotID 以支持多 Bot 隔离 - SessionKey string -} - -// SessionID 结构: platform:bot_id:chat_id[:sender_id] -func (m InboundMessage) SessionID() string { - if strings.TrimSpace(m.SessionKey) != "" { - return strings.TrimSpace(m.SessionKey) - } - return GenerateSessionID(string(m.Channel), m.BotID, m.ChatID, m.ChatType, m.OpenID, m.UserID, m.Username) -} - -// GenerateSessionID 统一生成 SessionID 的逻辑 -func GenerateSessionID(platform, botID, chatID, chatType, openID, userID, username string) string { - parts := []string{platform, botID, chatID} - // 如果是群聊,增加发送者 ID 以支持个人上下文 - ct := strings.ToLower(strings.TrimSpace(chatType)) - if ct != "" && ct != "p2p" && ct != "private" { - senderID := strings.TrimSpace(openID) - if senderID == "" { - senderID = strings.TrimSpace(userID) - } - if senderID == "" { - senderID = strings.TrimSpace(username) - } - if senderID != "" { - parts = append(parts, senderID) - } - } - return strings.Join(parts, ":") -} - -type OutboundMessage struct { - To string - Text string -} - -type AdapterRunner struct { - Stop func() - SupportsStop bool -} +var ErrStopNotSupported = errors.New("channel connection stop not supported") type InboundHandler func(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error +type ReplySender interface { + Send(ctx context.Context, msg OutboundMessage) error +} + type Adapter interface { Type() ChannelType - Start(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (AdapterRunner, error) +} + +type Sender interface { Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error } + +type Receiver interface { + Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) +} + +type Connection interface { + ConfigID() string + BotID() string + ChannelType() ChannelType + Stop(ctx context.Context) error + Running() bool +} + +type BaseConnection struct { + configID string + botID string + channelType ChannelType + stop func(ctx context.Context) error + running atomic.Bool +} + +func NewConnection(cfg ChannelConfig, stop func(ctx context.Context) error) *BaseConnection { + conn := &BaseConnection{ + configID: cfg.ID, + botID: cfg.BotID, + channelType: cfg.ChannelType, + stop: stop, + } + conn.running.Store(true) + return conn +} + +func (c *BaseConnection) ConfigID() string { + return c.configID +} + +func (c *BaseConnection) BotID() string { + return c.botID +} + +func (c *BaseConnection) ChannelType() ChannelType { + return c.channelType +} + +func (c *BaseConnection) Stop(ctx context.Context) error { + if c.stop == nil { + return ErrStopNotSupported + } + err := c.stop(ctx) + if err == nil { + c.running.Store(false) + } + return err +} + +func (c *BaseConnection) Running() bool { + return c.running.Load() +} diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go new file mode 100644 index 00000000..da302857 --- /dev/null +++ b/internal/channel/adapters/feishu/config.go @@ -0,0 +1,142 @@ +package feishu + +import ( + "fmt" + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +type Config struct { + AppID string + AppSecret string + EncryptKey string + VerificationToken string +} + +type UserConfig struct { + OpenID string + UserID string +} + +func NormalizeConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseConfig(raw) + if err != nil { + return nil, err + } + result := map[string]any{ + "appId": cfg.AppID, + "appSecret": cfg.AppSecret, + } + if cfg.EncryptKey != "" { + result["encryptKey"] = cfg.EncryptKey + } + if cfg.VerificationToken != "" { + result["verificationToken"] = cfg.VerificationToken + } + return result, nil +} + +func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return nil, err + } + result := map[string]any{} + if cfg.OpenID != "" { + result["open_id"] = cfg.OpenID + } + if cfg.UserID != "" { + result["user_id"] = cfg.UserID + } + return result, nil +} + +func ResolveTarget(raw map[string]any) (string, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return "", err + } + if cfg.OpenID != "" { + return "open_id:" + cfg.OpenID, nil + } + if cfg.UserID != "" { + return "user_id:" + cfg.UserID, nil + } + return "", fmt.Errorf("feishu binding is incomplete") +} + +func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { + cfg, err := parseUserConfig(raw) + if err != nil { + return false + } + if value := strings.TrimSpace(criteria.Attribute("open_id")); value != "" && value == cfg.OpenID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { + return true + } + if criteria.ExternalID != "" { + if criteria.ExternalID == cfg.OpenID || criteria.ExternalID == cfg.UserID { + return true + } + } + return false +} + +func BuildUserConfig(identity channel.Identity) map[string]any { + result := map[string]any{} + if value := strings.TrimSpace(identity.Attribute("open_id")); value != "" { + result["open_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("user_id")); value != "" { + result["user_id"] = value + } + return result +} + +func parseConfig(raw map[string]any) (Config, error) { + appID := strings.TrimSpace(channel.ReadString(raw, "appId", "app_id")) + appSecret := strings.TrimSpace(channel.ReadString(raw, "appSecret", "app_secret")) + encryptKey := strings.TrimSpace(channel.ReadString(raw, "encryptKey", "encrypt_key")) + verificationToken := strings.TrimSpace(channel.ReadString(raw, "verificationToken", "verification_token")) + if appID == "" || appSecret == "" { + return Config{}, fmt.Errorf("feishu appId and appSecret are required") + } + return Config{ + AppID: appID, + AppSecret: appSecret, + EncryptKey: encryptKey, + VerificationToken: verificationToken, + }, nil +} + +func parseUserConfig(raw map[string]any) (UserConfig, error) { + openID := strings.TrimSpace(channel.ReadString(raw, "openId", "open_id")) + userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) + if openID == "" && userID == "" { + return UserConfig{}, fmt.Errorf("feishu user config requires open_id or user_id") + } + return UserConfig{OpenID: openID, UserID: userID}, nil +} + +func normalizeTarget(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + if strings.HasPrefix(value, "open_id:") || strings.HasPrefix(value, "user_id:") || strings.HasPrefix(value, "chat_id:") { + return value + } + if strings.HasPrefix(value, "ou_") { + return "open_id:" + value + } + if strings.HasPrefix(value, "oc_") { + return "chat_id:" + value + } + if strings.HasPrefix(value, "user_id:") { + return value + } + return "open_id:" + value +} diff --git a/internal/channel/adapters/feishu/config_test.go b/internal/channel/adapters/feishu/config_test.go new file mode 100644 index 00000000..501a0886 --- /dev/null +++ b/internal/channel/adapters/feishu/config_test.go @@ -0,0 +1,81 @@ +package feishu + +import "testing" + +func TestNormalizeConfig(t *testing.T) { + t.Parallel() + + got, err := NormalizeConfig(map[string]any{ + "app_id": "app", + "app_secret": "secret", + "encrypt_key": "enc", + "verification_token": "verify", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got["appId"] != "app" || got["appSecret"] != "secret" { + t.Fatalf("unexpected feishu config: %#v", got) + } + if got["encryptKey"] != "enc" || got["verificationToken"] != "verify" { + t.Fatalf("unexpected feishu security config: %#v", got) + } +} + +func TestNormalizeConfigRequiresApp(t *testing.T) { + t.Parallel() + + _, err := NormalizeConfig(map[string]any{}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestNormalizeUserConfig(t *testing.T) { + t.Parallel() + + got, err := NormalizeUserConfig(map[string]any{ + "open_id": "ou_123", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got["open_id"] != "ou_123" { + t.Fatalf("unexpected open_id: %#v", got["open_id"]) + } +} + +func TestNormalizeUserConfigRequiresBinding(t *testing.T) { + t.Parallel() + + _, err := NormalizeUserConfig(map[string]any{}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestResolveTarget(t *testing.T) { + t.Parallel() + + target, err := ResolveTarget(map[string]any{ + "open_id": "ou_123", + "user_id": "u_123", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if target != "open_id:ou_123" { + t.Fatalf("unexpected target: %s", target) + } +} + +func TestNormalizeTarget(t *testing.T) { + t.Parallel() + + if got := normalizeTarget("ou_123"); got != "open_id:ou_123" { + t.Fatalf("unexpected normalized target: %s", got) + } + if got := normalizeTarget("chat_id:oc_123"); got != "chat_id:oc_123" { + t.Fatalf("unexpected normalized target: %s", got) + } +} diff --git a/internal/channel/adapters/feishu/descriptor.go b/internal/channel/adapters/feishu/descriptor.go index dd9c887a..fdb9a5e5 100644 --- a/internal/channel/adapters/feishu/descriptor.go +++ b/internal/channel/adapters/feishu/descriptor.go @@ -2,11 +2,53 @@ package feishu import "github.com/memohai/memoh/internal/channel" +const Type channel.ChannelType = "feishu" + func init() { channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: channel.ChannelFeishu, + Type: Type, DisplayName: "Feishu", - NormalizeConfig: channel.NormalizeFeishuConfig, - NormalizeUserConfig: channel.NormalizeFeishuUserConfig, + NormalizeConfig: NormalizeConfig, + NormalizeUserConfig: NormalizeUserConfig, + ResolveTarget: ResolveTarget, + MatchBinding: MatchBinding, + BuildUserConfig: BuildUserConfig, + TargetSpec: channel.TargetSpec{ + Format: "open_id:xxx | user_id:xxx | chat_id:xxx", + Hints: []channel.TargetHint{ + {Label: "Open ID", Example: "open_id:ou_xxx"}, + {Label: "User ID", Example: "user_id:ou_xxx"}, + {Label: "Chat ID", Example: "chat_id:oc_xxx"}, + }, + }, + NormalizeTarget: normalizeTarget, + Capabilities: channel.ChannelCapabilities{ + Text: true, + RichText: true, + Attachments: true, + Reply: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "appId": {Type: channel.FieldString, Required: true, Title: "App ID"}, + "appSecret": {Type: channel.FieldSecret, Required: true, Title: "App Secret"}, + "encryptKey": { + Type: channel.FieldSecret, + Title: "Encrypt Key", + }, + "verificationToken": { + Type: channel.FieldSecret, + Title: "Verification Token", + }, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "open_id": {Type: channel.FieldString}, + "user_id": {Type: channel.FieldString}, + }, + }, }) } diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index 4d1e7a9d..dddfcb3d 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -5,7 +5,9 @@ import ( "encoding/json" "fmt" "log/slog" + "net/http" "strings" + "time" "github.com/google/uuid" lark "github.com/larksuite/oapi-sdk-go/v3" @@ -32,27 +34,29 @@ func NewFeishuAdapter(log *slog.Logger) *FeishuAdapter { } func (a *FeishuAdapter) Type() channel.ChannelType { - return channel.ChannelFeishu + return Type } -func (a *FeishuAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.AdapterRunner, error) { +func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { a.logger.Info("start", slog.String("config_id", cfg.ID)) } - feishuCfg, err := decodeFeishuConfig(cfg.Credentials) + feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { if a.logger != nil { a.logger.Error("decode config failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } - return channel.AdapterRunner{}, err + return nil, err } + connCtx, cancel := context.WithCancel(ctx) eventDispatcher := dispatcher.NewEventDispatcher( feishuCfg.VerificationToken, feishuCfg.EncryptKey, ) eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error { msg := extractFeishuInbound(event) - if msg.Text == "" { + text := msg.Message.PlainText() + if text == "" && len(msg.Message.Attachments) == 0 { return nil } msg.BotID = cfg.BotID @@ -61,12 +65,12 @@ func (a *FeishuAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, ha "inbound received", slog.String("config_id", cfg.ID), slog.String("session_id", msg.SessionID()), - slog.String("chat_type", msg.ChatType), - slog.String("text", common.SummarizeText(msg.Text)), + slog.String("chat_type", msg.Conversation.Type), + slog.String("text", common.SummarizeText(text)), ) } go func() { - if err := handler(ctx, cfg, msg); err != nil && a.logger != nil { + if err := handler(connCtx, cfg, msg); err != nil && a.logger != nil { a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } }() @@ -85,87 +89,281 @@ func (a *FeishuAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, ha ) go func() { - if err := client.Start(ctx); err != nil && a.logger != nil { + if err := client.Start(connCtx); err != nil && a.logger != nil { a.logger.Error("client start failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } }() - return channel.AdapterRunner{ - Stop: func() {}, - SupportsStop: false, - }, nil + stop := func(context.Context) error { + cancel() + return nil + } + return channel.NewConnection(cfg, stop), nil } func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { - feishuCfg, err := decodeFeishuConfig(cfg.Credentials) + feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { if a.logger != nil { a.logger.Error("decode config failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } return err } - text := strings.TrimSpace(msg.Text) - if text == "" { - return fmt.Errorf("message is required") - } - receiveID, receiveType, err := resolveFeishuReceiveID(strings.TrimSpace(msg.To)) - if err != nil { - return err - } - contentPayload, err := json.Marshal(map[string]string{"text": text}) + + receiveID, receiveType, err := resolveFeishuReceiveID(strings.TrimSpace(msg.Target)) if err != nil { return err } + client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) - body := larkim.NewCreateMessageReqBodyBuilder(). + + // 1. 处理附件 + if len(msg.Message.Attachments) > 0 { + for _, att := range msg.Message.Attachments { + if err := a.sendAttachment(ctx, client, receiveID, receiveType, att, msg.Message.Text); err != nil { + return err + } + } + return nil + } + + // 2. 处理富文本或普通文本 + var msgType string + var content string + + if len(msg.Message.Parts) > 1 { + msgType = larkim.MsgTypePost + content, err = a.buildPostContent(msg.Message) + } else { + msgType = larkim.MsgTypeText + text := strings.TrimSpace(msg.Message.PlainText()) + if text == "" { + return fmt.Errorf("message is required") + } + payload, _ := json.Marshal(map[string]string{"text": text}) + content = string(payload) + } + + if err != nil { + return err + } + + reqBuilder := larkim.NewCreateMessageReqBodyBuilder(). ReceiveId(receiveID). - MsgType(larkim.MsgTypeText). - Content(string(contentPayload)). - Uuid(uuid.NewString()). - Build() + MsgType(msgType). + Content(content). + Uuid(uuid.NewString()) + req := larkim.NewCreateMessageReqBuilder(). ReceiveIdType(receiveType). - Body(body). + Body(reqBuilder.Build()). Build() + + // 处理回复 + if msg.Message.Reply != nil && msg.Message.Reply.MessageID != "" { + replyReq := larkim.NewReplyMessageReqBuilder(). + MessageId(msg.Message.Reply.MessageID). + Body(larkim.NewReplyMessageReqBodyBuilder(). + Content(content). + MsgType(msgType). + Uuid(uuid.NewString()). + Build()). + Build() + resp, err := client.Im.V1.Message.Reply(ctx, replyReq) + return a.handleReplyResponse(cfg.ID, resp, err) + } + resp, err := client.Im.V1.Message.Create(ctx, req) + return a.handleResponse(cfg.ID, resp, err) +} + +func (a *FeishuAdapter) handleReplyResponse(configID string, resp *larkim.ReplyMessageResp, err error) error { if err != nil { if a.logger != nil { - a.logger.Error("send failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + a.logger.Error("reply failed", slog.String("config_id", configID), slog.Any("error", err)) } return err } if resp == nil || !resp.Success() { - if a.logger != nil { - code := 0 - msg := "" - if resp != nil { - code = resp.Code - msg = resp.Msg - } - a.logger.Error("send failed", slog.String("config_id", cfg.ID), slog.Int("code", code), slog.String("msg", msg)) + code := 0 + msg := "" + if resp != nil { + code = resp.Code + msg = resp.Msg } - return fmt.Errorf("feishu send failed") + if a.logger != nil { + a.logger.Error("reply failed", slog.String("config_id", configID), slog.Int("code", code), slog.String("msg", msg)) + } + return fmt.Errorf("feishu reply failed: %s (code: %d)", msg, code) } if a.logger != nil { - a.logger.Info("send success", slog.String("config_id", cfg.ID)) + a.logger.Info("reply success", slog.String("config_id", configID)) } return nil } +func (a *FeishuAdapter) handleResponse(configID string, resp *larkim.CreateMessageResp, err error) error { + if err != nil { + if a.logger != nil { + a.logger.Error("send failed", slog.String("config_id", configID), slog.Any("error", err)) + } + return err + } + if resp == nil || !resp.Success() { + code := 0 + msg := "" + if resp != nil { + code = resp.Code + msg = resp.Msg + } + if a.logger != nil { + a.logger.Error("send failed", slog.String("config_id", configID), slog.Int("code", code), slog.String("msg", msg)) + } + return fmt.Errorf("feishu send failed: %s (code: %d)", msg, code) + } + if a.logger != nil { + a.logger.Info("send success", slog.String("config_id", configID)) + } + return nil +} + +func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, receiveID, receiveType string, att channel.Attachment, text string) error { + // 下载文件 + resp, err := http.Get(att.URL) + if err != nil { + return fmt.Errorf("failed to download attachment: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download attachment, status: %d", resp.StatusCode) + } + + var msgType string + var contentMap map[string]string + + if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { + // 上传图片 + uploadReq := larkim.NewCreateImageReqBuilder(). + Body(larkim.NewCreateImageReqBodyBuilder(). + ImageType(larkim.ImageTypeMessage). + Image(resp.Body). + Build()). + Build() + uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) + if err != nil || !uploadResp.Success() { + return fmt.Errorf("failed to upload image: %w", err) + } + msgType = larkim.MsgTypeImage + contentMap = map[string]string{"image_key": *uploadResp.Data.ImageKey} + } else { + // 上传文件 + uploadReq := larkim.NewCreateFileReqBuilder(). + Body(larkim.NewCreateFileReqBodyBuilder(). + FileType(larkim.FileTypePdf). // 默认为 pdf,飞书支持 mp4, doc, xls, ppt, pdf, zip + FileName(att.Name). + File(resp.Body). + Build()). + Build() + uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) + if err != nil || !uploadResp.Success() { + return fmt.Errorf("failed to upload file: %w", err) + } + msgType = larkim.MsgTypeFile + contentMap = map[string]string{"file_key": *uploadResp.Data.FileKey} + } + + content, _ := json.Marshal(contentMap) + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(receiveType). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(receiveID). + MsgType(msgType). + Content(string(content)). + Uuid(uuid.NewString()). + Build()). + Build() + + _, err = client.Im.V1.Message.Create(ctx, req) + return err +} + +func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { + // 简单的 Post 构建逻辑 + type postContent struct { + ZhCn struct { + Title string `json:"title"` + Content [][]any `json:"content"` + } `json:"zh_cn"` + } + + pc := postContent{} + pc.ZhCn.Title = "" // 暂时不设标题 + + line := []any{} + for _, part := range msg.Parts { + if part.Type == channel.MessagePartText { + line = append(line, map[string]any{ + "tag": "text", + "text": part.Text, + }) + } + } + pc.ZhCn.Content = [][]any{line} + + payload, err := json.Marshal(pc) + return string(payload), err +} + func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessage { if event == nil || event.Event == nil || event.Event.Message == nil { - return channel.InboundMessage{Channel: channel.ChannelFeishu} + return channel.InboundMessage{Channel: Type} } message := event.Event.Message - if message.MessageType == nil || *message.MessageType != larkim.MsgTypeText { - return channel.InboundMessage{Channel: channel.ChannelFeishu} - } - var payload struct { - Text string `json:"text"` + + var msg channel.Message + if message.MessageId != nil { + msg.ID = *message.MessageId } + + // 解析内容 + var contentMap map[string]any if message.Content != nil { - _ = json.Unmarshal([]byte(*message.Content), &payload) + _ = json.Unmarshal([]byte(*message.Content), &contentMap) } + + if message.MessageType != nil { + switch *message.MessageType { + case larkim.MsgTypeText: + if txt, ok := contentMap["text"].(string); ok { + msg.Text = txt + } + case larkim.MsgTypeImage: + if key, ok := contentMap["image_key"].(string); ok { + msg.Attachments = append(msg.Attachments, channel.Attachment{ + Type: channel.AttachmentImage, + URL: key, // 飞书内部 key,上层需注意 + }) + } + case larkim.MsgTypeFile, larkim.MsgTypeAudio: + if key, ok := contentMap["file_key"].(string); ok { + name, _ := contentMap["file_name"].(string) + msg.Attachments = append(msg.Attachments, channel.Attachment{ + Type: channel.AttachmentType(*message.MessageType), + URL: key, + Name: name, + }) + } + } + } + + // 处理回复引用 + if message.ParentId != nil && *message.ParentId != "" { + msg.Reply = &channel.ReplyRef{ + MessageID: *message.ParentId, + } + } + senderID, senderOpenID := "", "" if event.Event.Sender != nil && event.Event.Sender.SenderId != nil { if event.Event.Sender.SenderId.UserId != nil { @@ -190,14 +388,33 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa if chatType != "" && chatType != "p2p" && chatID != "" { replyTo = "chat_id:" + chatID } + attrs := map[string]string{} + if senderID != "" { + attrs["user_id"] = senderID + } + if senderOpenID != "" { + attrs["open_id"] = senderOpenID + } + externalID := senderOpenID + if externalID == "" { + externalID = senderID + } + return channel.InboundMessage{ - Channel: channel.ChannelFeishu, - Text: strings.TrimSpace(payload.Text), - UserID: senderID, - OpenID: senderOpenID, - ChatID: chatID, - ChatType: chatType, - ReplyTo: replyTo, + Channel: Type, + Message: msg, + ReplyTarget: replyTo, + Sender: channel.Identity{ + ExternalID: externalID, + DisplayName: senderOpenID, + Attributes: attrs, + }, + Conversation: channel.Conversation{ + ID: chatID, + Type: chatType, + }, + ReceivedAt: time.Now().UTC(), + Source: "feishu", } } @@ -216,11 +433,3 @@ func resolveFeishuReceiveID(raw string) (string, string, error) { } return raw, larkim.ReceiveIdTypeOpenId, nil } - -func decodeFeishuConfig(raw map[string]interface{}) (channel.FeishuConfig, error) { - payload, err := json.Marshal(raw) - if err != nil { - return channel.FeishuConfig{}, err - } - return channel.DecodeFeishuConfig(payload) -} diff --git a/internal/channel/adapters/feishu/feishu_integration_test.go b/internal/channel/adapters/feishu/feishu_integration_test.go index da477341..80daea33 100644 --- a/internal/channel/adapters/feishu/feishu_integration_test.go +++ b/internal/channel/adapters/feishu/feishu_integration_test.go @@ -37,7 +37,7 @@ func TestFeishuGateway_Integration(t *testing.T) { // 构造测试配置 cfg := channel.ChannelConfig{ ID: "integration-test-bot", - Credentials: map[string]interface{}{ + Credentials: map[string]any{ "app_id": appID, "app_secret": appSecret, "encrypt_key": encryptKey, @@ -54,9 +54,10 @@ func TestFeishuGateway_Integration(t *testing.T) { // 模拟 InboundHandler handler := func(ctx context.Context, c channel.ChannelConfig, msg channel.InboundMessage) error { + plainText := msg.Message.PlainText() logger.Info("测试收到消息", - slog.String("text", msg.Text), - slog.String("user_id", msg.UserID), + slog.String("text", plainText), + slog.String("user_id", msg.Sender.Attribute("user_id")), slog.String("session_id", msg.SessionID())) // 将消息放入通道,供主测试逻辑验证 @@ -67,8 +68,10 @@ func TestFeishuGateway_Integration(t *testing.T) { // 自动回复测试 (验证下行链路) reply := channel.OutboundMessage{ - To: msg.ReplyTo, - Text: fmt.Sprintf("【Memoh 集成测试】已收到消息: %s\n测试时间: %s", msg.Text, time.Now().Format("15:04:05")), + Target: msg.ReplyTarget, + Message: channel.Message{ + Text: fmt.Sprintf("【Memoh 集成测试】已收到消息: %s\n测试时间: %s", plainText, time.Now().Format("15:04:05")), + }, } if err := adapter.Send(ctx, c, reply); err != nil { @@ -79,8 +82,10 @@ func TestFeishuGateway_Integration(t *testing.T) { go func() { time.Sleep(1 * time.Second) pushMsg := channel.OutboundMessage{ - To: msg.ReplyTo, - Text: "【Memoh 集成测试】主动推送验证成功。", + Target: msg.ReplyTarget, + Message: channel.Message{ + Text: "【Memoh 集成测试】主动推送验证成功。", + }, } _ = adapter.Send(context.Background(), c, pushMsg) }() @@ -90,11 +95,13 @@ func TestFeishuGateway_Integration(t *testing.T) { // 启动适配器 logger.Info("正在启动飞书适配器...", slog.String("app_id", appID)) - runner, err := adapter.Start(ctx, cfg, handler) + runner, err := adapter.Connect(ctx, cfg, handler) if err != nil { t.Fatalf("适配器启动失败: %v", err) } - defer runner.Stop() + defer func() { + _ = runner.Stop(context.Background()) + }() fmt.Println("==================================================================") fmt.Println("🚀 飞书集成测试已就绪!") @@ -105,7 +112,7 @@ func TestFeishuGateway_Integration(t *testing.T) { // 等待测试结果 select { case msg := <-receivedChan: - logger.Info("集成测试验证成功!", slog.String("received_text", msg.Text)) + logger.Info("集成测试验证成功!", slog.String("received_text", msg.Message.PlainText())) // 给一点时间让异步推送完成 time.Sleep(2 * time.Second) case <-ctx.Done(): diff --git a/internal/channel/adapters/feishu/feishu_logger.go b/internal/channel/adapters/feishu/feishu_logger.go index 97ec9bd4..0143f2ee 100644 --- a/internal/channel/adapters/feishu/feishu_logger.go +++ b/internal/channel/adapters/feishu/feishu_logger.go @@ -19,23 +19,23 @@ func newLarkSlogLogger(logger *slog.Logger) larkcore.Logger { return &larkSlogLogger{logger: logger} } -func (l *larkSlogLogger) Debug(ctx context.Context, args ...interface{}) { +func (l *larkSlogLogger) Debug(ctx context.Context, args ...any) { l.log(ctx, slog.LevelDebug, args...) } -func (l *larkSlogLogger) Info(ctx context.Context, args ...interface{}) { +func (l *larkSlogLogger) Info(ctx context.Context, args ...any) { l.log(ctx, slog.LevelInfo, args...) } -func (l *larkSlogLogger) Warn(ctx context.Context, args ...interface{}) { +func (l *larkSlogLogger) Warn(ctx context.Context, args ...any) { l.log(ctx, slog.LevelWarn, args...) } -func (l *larkSlogLogger) Error(ctx context.Context, args ...interface{}) { +func (l *larkSlogLogger) Error(ctx context.Context, args ...any) { l.log(ctx, slog.LevelError, args...) } -func (l *larkSlogLogger) log(ctx context.Context, level slog.Level, args ...interface{}) { +func (l *larkSlogLogger) log(ctx context.Context, level slog.Level, args ...any) { if l.logger == nil { return } diff --git a/internal/channel/adapters/feishu/feishu_test.go b/internal/channel/adapters/feishu/feishu_test.go index 2190dbb6..5c6c47dc 100644 --- a/internal/channel/adapters/feishu/feishu_test.go +++ b/internal/channel/adapters/feishu/feishu_test.go @@ -64,11 +64,11 @@ func TestExtractFeishuInboundP2P(t *testing.T) { }, } got := extractFeishuInbound(event) - if got.Text != "hi" { - t.Fatalf("unexpected text: %s", got.Text) + if got.Message.PlainText() != "hi" { + t.Fatalf("unexpected text: %s", got.Message.PlainText()) } - if got.ReplyTo != "ou_1" { - t.Fatalf("unexpected reply target: %s", got.ReplyTo) + if got.ReplyTarget != "ou_1" { + t.Fatalf("unexpected reply target: %s", got.ReplyTarget) } } @@ -98,8 +98,8 @@ func TestExtractFeishuInboundGroup(t *testing.T) { }, } got := extractFeishuInbound(event) - if got.ReplyTo != "chat_id:oc_2" { - t.Fatalf("unexpected reply target: %s", got.ReplyTo) + if got.ReplyTarget != "chat_id:oc_2" { + t.Fatalf("unexpected reply target: %s", got.ReplyTarget) } } @@ -115,7 +115,7 @@ func TestExtractFeishuInboundNonText(t *testing.T) { }, } got := extractFeishuInbound(event) - if got.Text != "" { - t.Fatalf("expected empty text, got %s", got.Text) + if got.Message.PlainText() != "" { + t.Fatalf("expected empty text, got %s", got.Message.PlainText()) } } diff --git a/internal/channel/adapters/local/cli.go b/internal/channel/adapters/local/cli.go index 3e3744e1..e2675207 100644 --- a/internal/channel/adapters/local/cli.go +++ b/internal/channel/adapters/local/cli.go @@ -17,23 +17,18 @@ func NewCLIAdapter(hub *channel.SessionHub) *CLIAdapter { } func (a *CLIAdapter) Type() channel.ChannelType { - return channel.ChannelCLI -} - -func (a *CLIAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.AdapterRunner, error) { - return channel.AdapterRunner{SupportsStop: false}, nil + return CLIType } func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("cli hub not configured") } - target := strings.TrimSpace(msg.To) + target := strings.TrimSpace(msg.Target) if target == "" { return fmt.Errorf("cli target is required") } - text := strings.TrimSpace(msg.Text) - if text == "" { + if msg.Message.IsEmpty() { return fmt.Errorf("message is required") } a.hub.Publish(target, msg) diff --git a/internal/channel/adapters/local/descriptor.go b/internal/channel/adapters/local/descriptor.go index 20c1f83b..abf051d5 100644 --- a/internal/channel/adapters/local/descriptor.go +++ b/internal/channel/adapters/local/descriptor.go @@ -1,22 +1,67 @@ package local -import "github.com/memohai/memoh/internal/channel" +import ( + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + CLIType channel.ChannelType = "cli" + WebType channel.ChannelType = "web" +) func init() { channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: channel.ChannelCLI, + Type: CLIType, DisplayName: "CLI", NormalizeConfig: normalizeEmpty, NormalizeUserConfig: normalizeEmpty, + BuildUserConfig: buildEmpty, + Configless: true, + TargetSpec: channel.TargetSpec{ + Format: "session_id", + Hints: []channel.TargetHint{ + {Label: "Session ID", Example: "cli:uuid"}, + }, + }, + NormalizeTarget: normalizeTarget, + Capabilities: channel.ChannelCapabilities{ + Text: true, + Reply: true, + Attachments: true, + }, }) channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: channel.ChannelWeb, + Type: WebType, DisplayName: "Web", NormalizeConfig: normalizeEmpty, NormalizeUserConfig: normalizeEmpty, + BuildUserConfig: buildEmpty, + Configless: true, + TargetSpec: channel.TargetSpec{ + Format: "session_id", + Hints: []channel.TargetHint{ + {Label: "Session ID", Example: "web:uuid"}, + }, + }, + NormalizeTarget: normalizeTarget, + Capabilities: channel.ChannelCapabilities{ + Text: true, + Reply: true, + Attachments: true, + }, }) } -func normalizeEmpty(map[string]interface{}) (map[string]interface{}, error) { - return map[string]interface{}{}, nil +func normalizeTarget(raw string) string { + return strings.TrimSpace(raw) +} + +func normalizeEmpty(map[string]any) (map[string]any, error) { + return map[string]any{}, nil +} + +func buildEmpty(channel.Identity) map[string]any { + return map[string]any{} } diff --git a/internal/channel/adapters/local/web.go b/internal/channel/adapters/local/web.go index 6db4dea8..d0b24682 100644 --- a/internal/channel/adapters/local/web.go +++ b/internal/channel/adapters/local/web.go @@ -17,23 +17,18 @@ func NewWebAdapter(hub *channel.SessionHub) *WebAdapter { } func (a *WebAdapter) Type() channel.ChannelType { - return channel.ChannelWeb -} - -func (a *WebAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.AdapterRunner, error) { - return channel.AdapterRunner{SupportsStop: false}, nil + return WebType } func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("web hub not configured") } - target := strings.TrimSpace(msg.To) + target := strings.TrimSpace(msg.Target) if target == "" { return fmt.Errorf("web target is required") } - text := strings.TrimSpace(msg.Text) - if text == "" { + if msg.Message.IsEmpty() { return fmt.Errorf("message is required") } a.hub.Publish(target, msg) diff --git a/internal/channel/adapters/telegram/config.go b/internal/channel/adapters/telegram/config.go new file mode 100644 index 00000000..dd972838 --- /dev/null +++ b/internal/channel/adapters/telegram/config.go @@ -0,0 +1,158 @@ +package telegram + +import ( + "fmt" + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +type Config struct { + BotToken string +} + +type UserConfig struct { + Username string + UserID string + ChatID string +} + +func NormalizeConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseConfig(raw) + if err != nil { + return nil, err + } + return map[string]any{ + "botToken": cfg.BotToken, + }, nil +} + +func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return nil, err + } + result := map[string]any{} + if cfg.Username != "" { + result["username"] = cfg.Username + } + if cfg.UserID != "" { + result["user_id"] = cfg.UserID + } + if cfg.ChatID != "" { + result["chat_id"] = cfg.ChatID + } + return result, nil +} + +func ResolveTarget(raw map[string]any) (string, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return "", err + } + if cfg.ChatID != "" { + return cfg.ChatID, nil + } + if cfg.UserID != "" { + return cfg.UserID, nil + } + if cfg.Username != "" { + name := cfg.Username + if !strings.HasPrefix(name, "@") { + name = "@" + name + } + return name, nil + } + return "", fmt.Errorf("telegram binding is incomplete") +} + +func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { + cfg, err := parseUserConfig(raw) + if err != nil { + return false + } + if value := strings.TrimSpace(criteria.Attribute("chat_id")); value != "" && value == cfg.ChatID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) { + return true + } + if criteria.ExternalID != "" { + if criteria.ExternalID == cfg.ChatID || criteria.ExternalID == cfg.UserID || strings.EqualFold(criteria.ExternalID, cfg.Username) { + return true + } + } + return false +} + +func BuildUserConfig(identity channel.Identity) map[string]any { + result := map[string]any{} + if value := strings.TrimSpace(identity.Attribute("username")); value != "" { + result["username"] = value + } + if value := strings.TrimSpace(identity.Attribute("user_id")); value != "" { + result["user_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("chat_id")); value != "" { + result["chat_id"] = value + } + return result +} + +func parseConfig(raw map[string]any) (Config, error) { + token := strings.TrimSpace(channel.ReadString(raw, "botToken", "bot_token")) + if token == "" { + return Config{}, fmt.Errorf("telegram botToken is required") + } + return Config{BotToken: token}, nil +} + +func parseUserConfig(raw map[string]any) (UserConfig, error) { + username := strings.TrimSpace(channel.ReadString(raw, "username")) + userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) + chatID := strings.TrimSpace(channel.ReadString(raw, "chatId", "chat_id")) + if username == "" && userID == "" && chatID == "" { + return UserConfig{}, fmt.Errorf("telegram user config requires username, user_id, or chat_id") + } + return UserConfig{ + Username: username, + UserID: userID, + ChatID: chatID, + }, nil +} + +func normalizeTarget(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + if strings.HasPrefix(value, "@") { + return value + } + value = strings.TrimPrefix(value, "tg:") + value = strings.TrimPrefix(value, "telegram:") + value = strings.TrimPrefix(value, "t.me/") + value = strings.TrimPrefix(value, "https://t.me/") + value = strings.TrimPrefix(value, "http://t.me/") + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if strings.HasPrefix(value, "@") { + return value + } + isNumeric := true + for _, r := range value { + if r < '0' || r > '9' { + isNumeric = false + break + } + } + if isNumeric { + return value + } + return "@" + value +} diff --git a/internal/channel/adapters/telegram/config_test.go b/internal/channel/adapters/telegram/config_test.go new file mode 100644 index 00000000..c3d28165 --- /dev/null +++ b/internal/channel/adapters/telegram/config_test.go @@ -0,0 +1,88 @@ +package telegram + +import "testing" + +func TestNormalizeConfig(t *testing.T) { + t.Parallel() + + got, err := NormalizeConfig(map[string]any{ + "bot_token": "token-123", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got["botToken"] != "token-123" { + t.Fatalf("unexpected botToken: %#v", got["botToken"]) + } +} + +func TestNormalizeConfigRequiresToken(t *testing.T) { + t.Parallel() + + _, err := NormalizeConfig(map[string]any{}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestNormalizeUserConfig(t *testing.T) { + t.Parallel() + + got, err := NormalizeUserConfig(map[string]any{ + "username": "alice", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got["username"] != "alice" { + t.Fatalf("unexpected username: %#v", got["username"]) + } +} + +func TestNormalizeUserConfigRequiresBinding(t *testing.T) { + t.Parallel() + + _, err := NormalizeUserConfig(map[string]any{}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestResolveTarget(t *testing.T) { + t.Parallel() + + target, err := ResolveTarget(map[string]any{ + "chat_id": "123", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if target != "123" { + t.Fatalf("unexpected target: %s", target) + } +} + +func TestResolveTargetUsername(t *testing.T) { + t.Parallel() + + target, err := ResolveTarget(map[string]any{ + "username": "alice", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if target != "@alice" { + t.Fatalf("unexpected target: %s", target) + } +} + +func TestNormalizeTarget(t *testing.T) { + t.Parallel() + + if got := normalizeTarget("https://t.me/alice"); got != "@alice" { + t.Fatalf("unexpected normalized target: %s", got) + } + if got := normalizeTarget("@alice"); got != "@alice" { + t.Fatalf("unexpected normalized target: %s", got) + } +} diff --git a/internal/channel/adapters/telegram/descriptor.go b/internal/channel/adapters/telegram/descriptor.go index 8cbf014b..2d8dd4bb 100644 --- a/internal/channel/adapters/telegram/descriptor.go +++ b/internal/channel/adapters/telegram/descriptor.go @@ -2,11 +2,49 @@ package telegram import "github.com/memohai/memoh/internal/channel" +const Type channel.ChannelType = "telegram" + func init() { channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: channel.ChannelTelegram, + Type: Type, DisplayName: "Telegram", - NormalizeConfig: channel.NormalizeTelegramConfig, - NormalizeUserConfig: channel.NormalizeTelegramUserConfig, + NormalizeConfig: NormalizeConfig, + NormalizeUserConfig: NormalizeUserConfig, + ResolveTarget: ResolveTarget, + MatchBinding: MatchBinding, + BuildUserConfig: BuildUserConfig, + TargetSpec: channel.TargetSpec{ + Format: "chat_id | @username", + Hints: []channel.TargetHint{ + {Label: "Chat ID", Example: "123456789"}, + {Label: "Username", Example: "@alice"}, + }, + }, + NormalizeTarget: normalizeTarget, + Capabilities: channel.ChannelCapabilities{ + Text: true, + Markdown: true, + Reply: true, + Attachments: true, + Media: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "botToken": { + Type: channel.FieldSecret, + Required: true, + Title: "Bot Token", + }, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "username": {Type: channel.FieldString}, + "user_id": {Type: channel.FieldString}, + "chat_id": {Type: channel.FieldString}, + }, + }, }) } diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index ae059685..9595bd0f 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -2,11 +2,11 @@ package telegram import ( "context" - "encoding/json" "fmt" "log/slog" "strconv" "strings" + "time" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" @@ -28,35 +28,36 @@ func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { } func (a *TelegramAdapter) Type() channel.ChannelType { - return channel.ChannelTelegram + return Type } -func (a *TelegramAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.AdapterRunner, error) { +func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { a.logger.Info("start", slog.String("config_id", cfg.ID)) } - telegramCfg, err := decodeTelegramConfig(cfg.Credentials) + telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { if a.logger != nil { a.logger.Error("decode config failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } - return channel.AdapterRunner{}, err + return nil, err } bot, err := tgbotapi.NewBotAPI(telegramCfg.BotToken) if err != nil { if a.logger != nil { a.logger.Error("create bot failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } - return channel.AdapterRunner{}, err + return nil, err } updateConfig := tgbotapi.NewUpdate(0) updateConfig.Timeout = 30 updates := bot.GetUpdatesChan(updateConfig) + connCtx, cancel := context.WithCancel(ctx) go func() { for { select { - case <-ctx.Done(): + case <-connCtx.Done(): if a.logger != nil { a.logger.Info("stop", slog.String("config_id", cfg.ID)) } @@ -73,34 +74,61 @@ func (a *TelegramAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, continue } text := strings.TrimSpace(update.Message.Text) - if text == "" { + caption := strings.TrimSpace(update.Message.Caption) + if text == "" && caption != "" { + text = caption + } + attachments := a.collectTelegramAttachments(bot, update.Message) + if text == "" && len(attachments) == 0 { continue } - userID, username := resolveTelegramSender(update.Message.From) - chatID := strconv.FormatInt(update.Message.Chat.ID, 10) + externalID, displayName, attrs := resolveTelegramSender(update.Message) + chatID := "" + chatType := "" + chatName := "" + if update.Message.Chat != nil { + chatID = strconv.FormatInt(update.Message.Chat.ID, 10) + chatType = strings.TrimSpace(update.Message.Chat.Type) + chatName = strings.TrimSpace(update.Message.Chat.Title) + } + replyRef := buildTelegramReplyRef(update.Message, chatID) msg := channel.InboundMessage{ - Channel: channel.ChannelTelegram, - Text: text, - Username: username, - UserID: userID, - ChatID: chatID, - ChatType: update.Message.Chat.Type, - ReplyTo: chatID, - BotID: cfg.BotID, + Channel: Type, + Message: channel.Message{ + ID: strconv.Itoa(update.Message.MessageID), + Format: channel.MessageFormatPlain, + Text: text, + Attachments: attachments, + Reply: replyRef, + }, + BotID: cfg.BotID, + ReplyTarget: chatID, + Sender: channel.Identity{ + ExternalID: externalID, + DisplayName: displayName, + Attributes: attrs, + }, + Conversation: channel.Conversation{ + ID: chatID, + Type: chatType, + Name: chatName, + }, + ReceivedAt: time.Unix(int64(update.Message.Date), 0).UTC(), + Source: "telegram", } if a.logger != nil { a.logger.Info( "inbound received", slog.String("config_id", cfg.ID), - slog.String("chat_type", msg.ChatType), - slog.String("chat_id", msg.ChatID), - slog.String("user_id", msg.UserID), - slog.String("username", msg.Username), - slog.String("text", common.SummarizeText(msg.Text)), + slog.String("chat_type", msg.Conversation.Type), + slog.String("chat_id", msg.Conversation.ID), + slog.String("user_id", attrs["user_id"]), + slog.String("username", attrs["username"]), + slog.String("text", common.SummarizeText(text)), ) } go func() { - if err := handler(ctx, cfg, msg); err != nil && a.logger != nil { + if err := handler(connCtx, cfg, msg); err != nil && a.logger != nil { a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } }() @@ -108,26 +136,26 @@ func (a *TelegramAdapter) Start(ctx context.Context, cfg channel.ChannelConfig, } }() - return channel.AdapterRunner{ - Stop: func() { - if a.logger != nil { - a.logger.Info("stop", slog.String("config_id", cfg.ID)) - } - bot.StopReceivingUpdates() - }, - SupportsStop: true, - }, nil + stop := func(context.Context) error { + if a.logger != nil { + a.logger.Info("stop", slog.String("config_id", cfg.ID)) + } + cancel() + bot.StopReceivingUpdates() + return nil + } + return channel.NewConnection(cfg, stop), nil } func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { - telegramCfg, err := decodeTelegramConfig(cfg.Credentials) + telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { if a.logger != nil { a.logger.Error("decode config failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) } return err } - to := strings.TrimSpace(msg.To) + to := strings.TrimSpace(msg.Target) if to == "" { return fmt.Errorf("telegram target is required") } @@ -138,41 +166,394 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m } return err } - text := strings.TrimSpace(msg.Text) - if text == "" { + if msg.Message.IsEmpty() { return fmt.Errorf("message is required") } - if strings.HasPrefix(to, "@") { - message := tgbotapi.NewMessageToChannel(to, text) - _, err = bot.Send(message) - if err != nil && a.logger != nil { - a.logger.Error("send failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + text := strings.TrimSpace(msg.Message.PlainText()) + parseMode := resolveTelegramParseMode(msg.Message.Format) + replyTo := parseReplyToMessageID(msg.Message.Reply) + if len(msg.Message.Attachments) > 0 { + usedCaption := false + for i, att := range msg.Message.Attachments { + caption := "" + if !usedCaption && text != "" { + caption = text + usedCaption = true + } + applyReply := replyTo + if i > 0 { + applyReply = 0 + } + if err := sendTelegramAttachment(bot, to, att, caption, applyReply, parseMode); err != nil { + if a.logger != nil { + a.logger.Error("send attachment failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + return err + } } + if text != "" && !usedCaption { + return sendTelegramText(bot, to, text, replyTo, parseMode) + } + return nil + } + return sendTelegramText(bot, to, text, replyTo, parseMode) +} + +func resolveTelegramSender(msg *tgbotapi.Message) (string, string, map[string]string) { + attrs := map[string]string{} + if msg == nil { + return "", "", attrs + } + if msg.Chat != nil { + attrs["chat_id"] = strconv.FormatInt(msg.Chat.ID, 10) + } + if msg.From != nil { + userID := strconv.FormatInt(msg.From.ID, 10) + username := strings.TrimSpace(msg.From.UserName) + if userID != "" { + attrs["user_id"] = userID + } + if username != "" { + attrs["username"] = username + } + displayName := strings.TrimSpace(msg.From.UserName) + if displayName == "" { + displayName = strings.TrimSpace(strings.TrimSpace(msg.From.FirstName + " " + msg.From.LastName)) + } + externalID := userID + if externalID == "" { + externalID = username + } + return externalID, displayName, attrs + } + if msg.SenderChat != nil { + senderChatID := strconv.FormatInt(msg.SenderChat.ID, 10) + if senderChatID != "" { + attrs["sender_chat_id"] = senderChatID + } + if msg.SenderChat.UserName != "" { + attrs["sender_chat_username"] = strings.TrimSpace(msg.SenderChat.UserName) + } + if msg.SenderChat.Title != "" { + attrs["sender_chat_title"] = strings.TrimSpace(msg.SenderChat.Title) + } + displayName := strings.TrimSpace(msg.SenderChat.Title) + if displayName == "" { + displayName = strings.TrimSpace(msg.SenderChat.UserName) + } + externalID := senderChatID + if externalID == "" { + externalID = attrs["sender_chat_username"] + } + if externalID == "" { + externalID = attrs["chat_id"] + } + return externalID, displayName, attrs + } + return "", "", attrs +} + +func parseReplyToMessageID(reply *channel.ReplyRef) int { + if reply == nil { + return 0 + } + raw := strings.TrimSpace(reply.MessageID) + if raw == "" { + return 0 + } + value, err := strconv.Atoi(raw) + if err != nil { + return 0 + } + return value +} + +func sendTelegramText(bot *tgbotapi.BotAPI, target string, text string, replyTo int, parseMode string) error { + if strings.HasPrefix(target, "@") { + message := tgbotapi.NewMessageToChannel(target, text) + message.ParseMode = parseMode + if replyTo > 0 { + message.ReplyToMessageID = replyTo + } + _, err := bot.Send(message) return err } - chatID, err := strconv.ParseInt(to, 10, 64) + chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { return fmt.Errorf("telegram target must be @username or chat_id") } message := tgbotapi.NewMessage(chatID, text) - _, err = bot.Send(message) - if err != nil && a.logger != nil { - a.logger.Error("send failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + message.ParseMode = parseMode + if replyTo > 0 { + message.ReplyToMessageID = replyTo } + _, err = bot.Send(message) return err } -func resolveTelegramSender(user *tgbotapi.User) (string, string) { - if user == nil { - return "", "" +func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Attachment, caption string, replyTo int, parseMode string) error { + if strings.TrimSpace(att.URL) == "" { + return fmt.Errorf("attachment url is required") + } + if strings.TrimSpace(caption) == "" && strings.TrimSpace(att.Caption) != "" { + caption = strings.TrimSpace(att.Caption) + } + file := tgbotapi.FileURL(att.URL) + isChannel := strings.HasPrefix(target, "@") + switch att.Type { + case channel.AttachmentImage: + var photo tgbotapi.PhotoConfig + if isChannel { + photo = tgbotapi.NewPhotoToChannel(target, file) + } else { + chatID, err := strconv.ParseInt(target, 10, 64) + if err != nil { + return fmt.Errorf("telegram target must be @username or chat_id") + } + photo = tgbotapi.NewPhoto(chatID, file) + } + photo.Caption = caption + photo.ParseMode = parseMode + if replyTo > 0 { + photo.ReplyToMessageID = replyTo + } + _, err := bot.Send(photo) + return err + case channel.AttachmentFile, "": + chatID, err := strconv.ParseInt(target, 10, 64) + if err != nil && !isChannel { + return fmt.Errorf("telegram target must be @username or chat_id") + } + document := tgbotapi.NewDocument(chatID, file) + if isChannel { + document.ChatID = 0 + document.ChannelUsername = target + } + document.Caption = caption + document.ParseMode = parseMode + if replyTo > 0 { + document.ReplyToMessageID = replyTo + } + _, err = bot.Send(document) + return err + case channel.AttachmentAudio: + audio, err := buildTelegramAudio(target, file) + if err != nil { + return err + } + audio.Caption = caption + audio.ParseMode = parseMode + if replyTo > 0 { + audio.ReplyToMessageID = replyTo + } + _, err = bot.Send(audio) + return err + case channel.AttachmentVoice: + voice, err := buildTelegramVoice(target, file) + if err != nil { + return err + } + voice.Caption = caption + voice.ParseMode = parseMode + if replyTo > 0 { + voice.ReplyToMessageID = replyTo + } + _, err = bot.Send(voice) + return err + case channel.AttachmentVideo: + video, err := buildTelegramVideo(target, file) + if err != nil { + return err + } + video.Caption = caption + video.ParseMode = parseMode + if replyTo > 0 { + video.ReplyToMessageID = replyTo + } + _, err = bot.Send(video) + return err + case channel.AttachmentGIF: + animation, err := buildTelegramAnimation(target, file) + if err != nil { + return err + } + animation.Caption = caption + animation.ParseMode = parseMode + if replyTo > 0 { + animation.ReplyToMessageID = replyTo + } + _, err = bot.Send(animation) + return err + default: + return fmt.Errorf("unsupported attachment type: %s", att.Type) } - return strconv.FormatInt(user.ID, 10), strings.TrimSpace(user.UserName) } -func decodeTelegramConfig(raw map[string]interface{}) (channel.TelegramConfig, error) { - payload, err := json.Marshal(raw) - if err != nil { - return channel.TelegramConfig{}, err +func buildTelegramReplyRef(msg *tgbotapi.Message, chatID string) *channel.ReplyRef { + if msg == nil || msg.ReplyToMessage == nil { + return nil + } + return &channel.ReplyRef{ + MessageID: strconv.Itoa(msg.ReplyToMessage.MessageID), + Target: strings.TrimSpace(chatID), } - return channel.DecodeTelegramConfig(payload) +} + +func buildTelegramAudio(target string, file tgbotapi.RequestFileData) (tgbotapi.AudioConfig, error) { + if strings.HasPrefix(target, "@") { + audio := tgbotapi.NewAudio(0, file) + audio.ChannelUsername = target + return audio, nil + } + chatID, err := strconv.ParseInt(target, 10, 64) + if err != nil { + return tgbotapi.AudioConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + } + return tgbotapi.NewAudio(chatID, file), nil +} + +func buildTelegramVoice(target string, file tgbotapi.RequestFileData) (tgbotapi.VoiceConfig, error) { + if strings.HasPrefix(target, "@") { + voice := tgbotapi.NewVoice(0, file) + voice.ChannelUsername = target + return voice, nil + } + chatID, err := strconv.ParseInt(target, 10, 64) + if err != nil { + return tgbotapi.VoiceConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + } + return tgbotapi.NewVoice(chatID, file), nil +} + +func buildTelegramVideo(target string, file tgbotapi.RequestFileData) (tgbotapi.VideoConfig, error) { + if strings.HasPrefix(target, "@") { + video := tgbotapi.NewVideo(0, file) + video.ChannelUsername = target + return video, nil + } + chatID, err := strconv.ParseInt(target, 10, 64) + if err != nil { + return tgbotapi.VideoConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + } + return tgbotapi.NewVideo(chatID, file), nil +} + +func buildTelegramAnimation(target string, file tgbotapi.RequestFileData) (tgbotapi.AnimationConfig, error) { + if strings.HasPrefix(target, "@") { + animation := tgbotapi.NewAnimation(0, file) + animation.ChannelUsername = target + return animation, nil + } + chatID, err := strconv.ParseInt(target, 10, 64) + if err != nil { + return tgbotapi.AnimationConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + } + return tgbotapi.NewAnimation(chatID, file), nil +} + +func resolveTelegramParseMode(format channel.MessageFormat) string { + switch format { + case channel.MessageFormatMarkdown: + return tgbotapi.ModeMarkdown + default: + return "" + } +} + +func (a *TelegramAdapter) collectTelegramAttachments(bot *tgbotapi.BotAPI, msg *tgbotapi.Message) []channel.Attachment { + if msg == nil { + return nil + } + attachments := make([]channel.Attachment, 0, 1) + if len(msg.Photo) > 0 { + photo := pickTelegramPhoto(msg.Photo) + att := a.buildTelegramAttachment(bot, channel.AttachmentImage, photo.FileID, "", "", int64(photo.FileSize)) + att.Width = photo.Width + att.Height = photo.Height + attachments = append(attachments, att) + } + if msg.Document != nil { + att := a.buildTelegramAttachment(bot, channel.AttachmentFile, msg.Document.FileID, msg.Document.FileName, msg.Document.MimeType, int64(msg.Document.FileSize)) + attachments = append(attachments, att) + } + if msg.Audio != nil { + att := a.buildTelegramAttachment(bot, channel.AttachmentAudio, msg.Audio.FileID, msg.Audio.FileName, msg.Audio.MimeType, int64(msg.Audio.FileSize)) + att.DurationMs = int64(msg.Audio.Duration) * 1000 + attachments = append(attachments, att) + } + if msg.Voice != nil { + att := a.buildTelegramAttachment(bot, channel.AttachmentVoice, msg.Voice.FileID, "", msg.Voice.MimeType, int64(msg.Voice.FileSize)) + att.DurationMs = int64(msg.Voice.Duration) * 1000 + attachments = append(attachments, att) + } + if msg.Video != nil { + att := a.buildTelegramAttachment(bot, channel.AttachmentVideo, msg.Video.FileID, msg.Video.FileName, msg.Video.MimeType, int64(msg.Video.FileSize)) + att.Width = msg.Video.Width + att.Height = msg.Video.Height + att.DurationMs = int64(msg.Video.Duration) * 1000 + attachments = append(attachments, att) + } + if msg.Animation != nil { + att := a.buildTelegramAttachment(bot, channel.AttachmentGIF, msg.Animation.FileID, msg.Animation.FileName, msg.Animation.MimeType, int64(msg.Animation.FileSize)) + att.Width = msg.Animation.Width + att.Height = msg.Animation.Height + att.DurationMs = int64(msg.Animation.Duration) * 1000 + attachments = append(attachments, att) + } + if msg.Sticker != nil { + att := a.buildTelegramAttachment(bot, channel.AttachmentImage, msg.Sticker.FileID, "", "", int64(msg.Sticker.FileSize)) + att.Width = msg.Sticker.Width + att.Height = msg.Sticker.Height + attachments = append(attachments, att) + } + caption := strings.TrimSpace(msg.Caption) + if caption != "" { + for i := range attachments { + attachments[i].Caption = caption + } + } + return attachments +} + +func (a *TelegramAdapter) buildTelegramAttachment(bot *tgbotapi.BotAPI, attType channel.AttachmentType, fileID, name, mime string, size int64) channel.Attachment { + url := "" + if bot != nil && strings.TrimSpace(fileID) != "" { + value, err := bot.GetFileDirectURL(fileID) + if err != nil { + if a.logger != nil { + a.logger.Warn("resolve file url failed", slog.Any("error", err)) + } + } else { + url = value + } + } + att := channel.Attachment{ + Type: attType, + URL: strings.TrimSpace(url), + Name: strings.TrimSpace(name), + Mime: strings.TrimSpace(mime), + Size: size, + Metadata: map[string]any{}, + } + if fileID != "" { + att.Metadata["file_id"] = fileID + } + return att +} + +func pickTelegramPhoto(items []tgbotapi.PhotoSize) tgbotapi.PhotoSize { + if len(items) == 0 { + return tgbotapi.PhotoSize{} + } + best := items[0] + for _, item := range items[1:] { + if item.FileSize > best.FileSize { + best = item + continue + } + if item.Width*item.Height > best.Width*best.Height { + best = item + } + } + return best } diff --git a/internal/channel/adapters/telegram/telegram_test.go b/internal/channel/adapters/telegram/telegram_test.go index f848275b..6d3a5834 100644 --- a/internal/channel/adapters/telegram/telegram_test.go +++ b/internal/channel/adapters/telegram/telegram_test.go @@ -9,13 +9,18 @@ import ( func TestResolveTelegramSender(t *testing.T) { t.Parallel() - id, name := resolveTelegramSender(nil) - if id != "" || name != "" { + externalID, displayName, attrs := resolveTelegramSender(nil) + if externalID != "" || displayName != "" || len(attrs) != 0 { t.Fatalf("expected empty sender") } - user := &tgbotapi.User{ID: 123, UserName: "alice"} - id, name = resolveTelegramSender(user) - if id != "123" || name != "alice" { - t.Fatalf("unexpected sender: %s %s", id, name) + msg := &tgbotapi.Message{ + From: &tgbotapi.User{ID: 123, UserName: "alice"}, + } + externalID, displayName, attrs = resolveTelegramSender(msg) + if externalID != "123" || displayName != "alice" { + t.Fatalf("unexpected sender: %s %s", externalID, displayName) + } + if attrs["user_id"] != "123" || attrs["username"] != "alice" { + t.Fatalf("unexpected attrs: %#v", attrs) } } diff --git a/internal/channel/capabilities.go b/internal/channel/capabilities.go new file mode 100644 index 00000000..b72a7af2 --- /dev/null +++ b/internal/channel/capabilities.go @@ -0,0 +1,22 @@ +package channel + +// ChannelCapabilities 描述通道在功能层面的能力矩阵。 +// 该结构用于上层自适应逻辑,不依赖具体适配器实现。 +type ChannelCapabilities struct { + Text bool `json:"text"` + Markdown bool `json:"markdown"` + RichText bool `json:"rich_text"` + Attachments bool `json:"attachments"` + Media bool `json:"media"` + Reactions bool `json:"reactions"` + Buttons bool `json:"buttons"` + Reply bool `json:"reply"` + Threads bool `json:"threads"` + Streaming bool `json:"streaming"` + Polls bool `json:"polls"` + Edit bool `json:"edit"` + Unsend bool `json:"unsend"` + NativeCommands bool `json:"native_commands"` + BlockStreaming bool `json:"block_streaming"` + ChatTypes []string `json:"chat_types,omitempty"` +} diff --git a/internal/channel/config.go b/internal/channel/config.go index 759d199d..9f8014bc 100644 --- a/internal/channel/config.go +++ b/internal/channel/config.go @@ -6,31 +6,9 @@ import ( "strings" ) -type TelegramConfig struct { - BotToken string -} - -type TelegramUserConfig struct { - Username string - UserID string - ChatID string -} - -type FeishuConfig struct { - AppID string - AppSecret string - EncryptKey string - VerificationToken string -} - -type FeishuUserConfig struct { - OpenID string - UserID string -} - -func NormalizeChannelConfig(channelType ChannelType, raw map[string]interface{}) (map[string]interface{}, error) { +func NormalizeChannelConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { if raw == nil { - raw = map[string]interface{}{} + raw = map[string]any{} } desc, ok := GetChannelDescriptor(channelType) if !ok { @@ -42,9 +20,9 @@ func NormalizeChannelConfig(channelType ChannelType, raw map[string]interface{}) return desc.NormalizeConfig(raw) } -func NormalizeChannelUserConfig(channelType ChannelType, raw map[string]interface{}) (map[string]interface{}, error) { +func NormalizeChannelUserConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { if raw == nil { - raw = map[string]interface{}{} + raw = map[string]any{} } desc, ok := GetChannelDescriptor(channelType) if !ok { @@ -56,162 +34,45 @@ func NormalizeChannelUserConfig(channelType ChannelType, raw map[string]interfac return desc.NormalizeUserConfig(raw) } -func NormalizeTelegramConfig(raw map[string]interface{}) (map[string]interface{}, error) { - cfg, err := parseTelegramConfig(raw) - if err != nil { - return nil, err +func ResolveTargetFromUserConfig(channelType ChannelType, config map[string]any) (string, error) { + desc, ok := GetChannelDescriptor(channelType) + if !ok || desc.ResolveTarget == nil { + return "", fmt.Errorf("unsupported channel type: %s", channelType) } - return map[string]interface{}{ - "botToken": cfg.BotToken, - }, nil + return desc.ResolveTarget(config) } -func NormalizeTelegramUserConfig(raw map[string]interface{}) (map[string]interface{}, error) { - cfg, err := parseTelegramUserConfig(raw) - if err != nil { - return nil, err +func MatchUserBinding(channelType ChannelType, config map[string]any, criteria BindingCriteria) bool { + desc, ok := GetChannelDescriptor(channelType) + if !ok || desc.MatchBinding == nil { + return false } - result := map[string]interface{}{} - if cfg.Username != "" { - result["username"] = cfg.Username - } - if cfg.UserID != "" { - result["user_id"] = cfg.UserID - } - if cfg.ChatID != "" { - result["chat_id"] = cfg.ChatID - } - return result, nil + return desc.MatchBinding(config, criteria) } -func NormalizeFeishuConfig(raw map[string]interface{}) (map[string]interface{}, error) { - cfg, err := parseFeishuConfig(raw) - if err != nil { - return nil, err +func BuildUserBindingConfig(channelType ChannelType, identity Identity) map[string]any { + desc, ok := GetChannelDescriptor(channelType) + if !ok || desc.BuildUserConfig == nil { + return map[string]any{} } - result := map[string]interface{}{ - "appId": cfg.AppID, - "appSecret": cfg.AppSecret, - } - if cfg.EncryptKey != "" { - result["encryptKey"] = cfg.EncryptKey - } - if cfg.VerificationToken != "" { - result["verificationToken"] = cfg.VerificationToken - } - return result, nil + return desc.BuildUserConfig(identity) } -func NormalizeFeishuUserConfig(raw map[string]interface{}) (map[string]interface{}, error) { - cfg, err := parseFeishuUserConfig(raw) - if err != nil { - return nil, err - } - result := map[string]interface{}{} - if cfg.OpenID != "" { - result["open_id"] = cfg.OpenID - } - if cfg.UserID != "" { - result["user_id"] = cfg.UserID - } - return result, nil -} - -func DecodeTelegramConfig(raw []byte) (TelegramConfig, error) { - payload, err := decodeConfigMap(raw) - if err != nil { - return TelegramConfig{}, err - } - return parseTelegramConfig(payload) -} - -func DecodeTelegramUserConfig(raw []byte) (TelegramUserConfig, error) { - payload, err := decodeConfigMap(raw) - if err != nil { - return TelegramUserConfig{}, err - } - return parseTelegramUserConfig(payload) -} - -func DecodeFeishuConfig(raw []byte) (FeishuConfig, error) { - payload, err := decodeConfigMap(raw) - if err != nil { - return FeishuConfig{}, err - } - return parseFeishuConfig(payload) -} - -func DecodeFeishuUserConfig(raw []byte) (FeishuUserConfig, error) { - payload, err := decodeConfigMap(raw) - if err != nil { - return FeishuUserConfig{}, err - } - return parseFeishuUserConfig(payload) -} - -func parseTelegramConfig(raw map[string]interface{}) (TelegramConfig, error) { - token := readString(raw, "botToken", "bot_token") - token = strings.TrimSpace(token) - if token == "" { - return TelegramConfig{}, fmt.Errorf("telegram botToken is required") - } - return TelegramConfig{BotToken: token}, nil -} - -func parseTelegramUserConfig(raw map[string]interface{}) (TelegramUserConfig, error) { - username := strings.TrimSpace(readString(raw, "username")) - userID := strings.TrimSpace(readString(raw, "userId", "user_id")) - chatID := strings.TrimSpace(readString(raw, "chatId", "chat_id")) - if username == "" && userID == "" && chatID == "" { - return TelegramUserConfig{}, fmt.Errorf("telegram user config requires username, user_id, or chat_id") - } - return TelegramUserConfig{ - Username: username, - UserID: userID, - ChatID: chatID, - }, nil -} - -func parseFeishuConfig(raw map[string]interface{}) (FeishuConfig, error) { - appID := strings.TrimSpace(readString(raw, "appId", "app_id")) - appSecret := strings.TrimSpace(readString(raw, "appSecret", "app_secret")) - encryptKey := strings.TrimSpace(readString(raw, "encryptKey", "encrypt_key")) - verificationToken := strings.TrimSpace(readString(raw, "verificationToken", "verification_token")) - if appID == "" || appSecret == "" { - return FeishuConfig{}, fmt.Errorf("feishu appId and appSecret are required") - } - return FeishuConfig{ - AppID: appID, - AppSecret: appSecret, - EncryptKey: encryptKey, - VerificationToken: verificationToken, - }, nil -} - -func parseFeishuUserConfig(raw map[string]interface{}) (FeishuUserConfig, error) { - openID := strings.TrimSpace(readString(raw, "openId", "open_id")) - userID := strings.TrimSpace(readString(raw, "userId", "user_id")) - if openID == "" && userID == "" { - return FeishuUserConfig{}, fmt.Errorf("feishu user config requires open_id or user_id") - } - return FeishuUserConfig{OpenID: openID, UserID: userID}, nil -} - -func decodeConfigMap(raw []byte) (map[string]interface{}, error) { +func DecodeConfigMap(raw []byte) (map[string]any, error) { if len(raw) == 0 { - return map[string]interface{}{}, nil + return map[string]any{}, nil } - var payload map[string]interface{} + var payload map[string]any if err := json.Unmarshal(raw, &payload); err != nil { return nil, err } if payload == nil { - payload = map[string]interface{}{} + payload = map[string]any{} } return payload, nil } -func readString(raw map[string]interface{}, keys ...string) string { +func ReadString(raw map[string]any, keys ...string) string { for _, key := range keys { if value, ok := raw[key]; ok { switch v := value.(type) { diff --git a/internal/channel/config_test.go b/internal/channel/config_test.go index 7775b6b3..1fae6855 100644 --- a/internal/channel/config_test.go +++ b/internal/channel/config_test.go @@ -1,91 +1,134 @@ -package channel +package channel_test -import "testing" +import ( + "fmt" + "sync" + "testing" -func TestNormalizeChannelConfigTelegram(t *testing.T) { - t.Parallel() + "github.com/memohai/memoh/internal/channel" +) - got, err := NormalizeChannelConfig(ChannelTelegram, map[string]interface{}{ - "bot_token": "token-123", +const testChannelType = channel.ChannelType("test-config") + +var registerTestChannelOnce sync.Once + +func registerTestChannel() { + registerTestChannelOnce.Do(func() { + if _, ok := channel.GetChannelDescriptor(testChannelType); ok { + return + } + _ = channel.RegisterChannel(channel.ChannelDescriptor{ + Type: testChannelType, + DisplayName: "Test", + NormalizeConfig: normalizeTestConfig, + NormalizeUserConfig: normalizeTestUserConfig, + ResolveTarget: resolveTestTarget, + MatchBinding: matchTestBinding, + Capabilities: channel.ChannelCapabilities{ + Text: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "value": {Type: channel.FieldString, Required: true}, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "user": {Type: channel.FieldString, Required: true}, + }, + }, + }) }) +} + +func normalizeTestConfig(raw map[string]any) (map[string]any, error) { + value := channel.ReadString(raw, "value") + if value == "" { + return nil, fmt.Errorf("value is required") + } + return map[string]any{"value": value}, nil +} + +func normalizeTestUserConfig(raw map[string]any) (map[string]any, error) { + value := channel.ReadString(raw, "user") + if value == "" { + return nil, fmt.Errorf("user is required") + } + return map[string]any{"user": value}, nil +} + +func resolveTestTarget(raw map[string]any) (string, error) { + value := channel.ReadString(raw, "target") + if value == "" { + return "", fmt.Errorf("target is required") + } + return "resolved:" + value, nil +} + +func matchTestBinding(raw map[string]any, criteria channel.BindingCriteria) bool { + value := channel.ReadString(raw, "user") + return value != "" && value == criteria.ExternalID +} + +func TestParseChannelType(t *testing.T) { + t.Parallel() + registerTestChannel() + + got, err := channel.ParseChannelType(" test-config ") if err != nil { t.Fatalf("expected no error, got %v", err) } - if got["botToken"] != "token-123" { - t.Fatalf("unexpected botToken: %#v", got["botToken"]) + if got != testChannelType { + t.Fatalf("unexpected channel type: %s", got) + } + if _, err := channel.ParseChannelType("unknown"); err == nil { + t.Fatalf("expected error, got nil") } } -func TestNormalizeChannelConfigTelegramRequiresToken(t *testing.T) { +func TestNormalizeChannelConfig(t *testing.T) { t.Parallel() + registerTestChannel() - _, err := NormalizeChannelConfig(ChannelTelegram, map[string]interface{}{}) + got, err := channel.NormalizeChannelConfig(testChannelType, map[string]any{"value": "ok"}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got["value"] != "ok" { + t.Fatalf("unexpected value: %#v", got["value"]) + } +} + +func TestNormalizeChannelConfigRequiresValue(t *testing.T) { + t.Parallel() + registerTestChannel() + + _, err := channel.NormalizeChannelConfig(testChannelType, map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } } -func TestNormalizeChannelConfigFeishu(t *testing.T) { +func TestNormalizeChannelUserConfig(t *testing.T) { t.Parallel() + registerTestChannel() - got, err := NormalizeChannelConfig(ChannelFeishu, map[string]interface{}{ - "app_id": "app", - "app_secret": "secret", - "encrypt_key": "enc", - "verification_token": "verify", - }) + got, err := channel.NormalizeChannelUserConfig(testChannelType, map[string]any{"user": "alice"}) if err != nil { t.Fatalf("expected no error, got %v", err) } - if got["appId"] != "app" || got["appSecret"] != "secret" { - t.Fatalf("unexpected feishu config: %#v", got) - } - if got["encryptKey"] != "enc" || got["verificationToken"] != "verify" { - t.Fatalf("unexpected feishu security config: %#v", got) + if got["user"] != "alice" { + t.Fatalf("unexpected user: %#v", got["user"]) } } -func TestNormalizeChannelUserConfigTelegram(t *testing.T) { +func TestNormalizeChannelUserConfigRequiresUser(t *testing.T) { t.Parallel() + registerTestChannel() - got, err := NormalizeChannelUserConfig(ChannelTelegram, map[string]interface{}{ - "username": "alice", - }) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if got["username"] != "alice" { - t.Fatalf("unexpected username: %#v", got["username"]) - } -} - -func TestNormalizeChannelUserConfigTelegramRequiresBinding(t *testing.T) { - t.Parallel() - - _, err := NormalizeChannelUserConfig(ChannelTelegram, map[string]interface{}{}) - if err == nil { - t.Fatalf("expected error, got nil") - } -} - -func TestNormalizeChannelUserConfigFeishu(t *testing.T) { - t.Parallel() - - got, err := NormalizeChannelUserConfig(ChannelFeishu, map[string]interface{}{ - "open_id": "ou_123", - }) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if got["open_id"] != "ou_123" { - t.Fatalf("unexpected open_id: %#v", got["open_id"]) - } -} - -func TestNormalizeChannelUserConfigFeishuRequiresBinding(t *testing.T) { - t.Parallel() - - _, err := NormalizeChannelUserConfig(ChannelFeishu, map[string]interface{}{}) + _, err := channel.NormalizeChannelUserConfig(testChannelType, map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } diff --git a/internal/channel/directory.go b/internal/channel/directory.go new file mode 100644 index 00000000..dc629dbd --- /dev/null +++ b/internal/channel/directory.go @@ -0,0 +1,32 @@ +package channel + +import "context" + +type DirectoryEntryKind string + +const ( + DirectoryEntryUser DirectoryEntryKind = "user" + DirectoryEntryGroup DirectoryEntryKind = "group" +) + +type DirectoryEntry struct { + Kind DirectoryEntryKind `json:"kind"` + ID string `json:"id"` + Name string `json:"name,omitempty"` + Handle string `json:"handle,omitempty"` + AvatarURL string `json:"avatar_url,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type DirectoryQuery struct { + Query string `json:"query,omitempty"` + Limit int `json:"limit,omitempty"` + Kind DirectoryEntryKind `json:"kind,omitempty"` +} + +type ChannelDirectoryAdapter interface { + ListPeers(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) + ListGroups(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) + ListGroupMembers(ctx context.Context, cfg ChannelConfig, groupID string, query DirectoryQuery) ([]DirectoryEntry, error) + ResolveTarget(ctx context.Context, cfg ChannelConfig, input string, kind DirectoryEntryKind) (DirectoryEntry, error) +} diff --git a/internal/channel/helpers_test.go b/internal/channel/helpers_test.go index 61dcad6c..05730c2d 100644 --- a/internal/channel/helpers_test.go +++ b/internal/channel/helpers_test.go @@ -6,73 +6,17 @@ import ( "github.com/google/uuid" ) -func TestParseChannelType(t *testing.T) { - t.Parallel() - - got, err := ParseChannelType(" Telegram ") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if got != ChannelTelegram { - t.Fatalf("unexpected channel type: %s", got) - } - - if _, err := ParseChannelType("unknown"); err == nil { - t.Fatalf("expected error, got nil") - } -} - -func TestMatchTelegramBinding(t *testing.T) { - t.Parallel() - - cfg := TelegramUserConfig{ - Username: "Alice", - UserID: "u1", - ChatID: "c1", - } - if !matchTelegramBinding(cfg, BindingCriteria{ChatID: "c1"}) { - t.Fatalf("expected chat id match") - } - if !matchTelegramBinding(cfg, BindingCriteria{UserID: "u1"}) { - t.Fatalf("expected user id match") - } - if !matchTelegramBinding(cfg, BindingCriteria{Username: "alice"}) { - t.Fatalf("expected username match") - } - if matchTelegramBinding(cfg, BindingCriteria{Username: "bob"}) { - t.Fatalf("expected no match") - } -} - -func TestMatchFeishuBinding(t *testing.T) { - t.Parallel() - - cfg := FeishuUserConfig{ - OpenID: "ou_1", - UserID: "u_1", - } - if !matchFeishuBinding(cfg, BindingCriteria{OpenID: "ou_1"}) { - t.Fatalf("expected open_id match") - } - if !matchFeishuBinding(cfg, BindingCriteria{UserID: "u_1"}) { - t.Fatalf("expected user_id match") - } - if matchFeishuBinding(cfg, BindingCriteria{UserID: "u_2"}) { - t.Fatalf("expected no match") - } -} - func TestDecodeConfigMap(t *testing.T) { t.Parallel() - cfg, err := decodeConfigMap([]byte(`{"a":1}`)) + cfg, err := DecodeConfigMap([]byte(`{"a":1}`)) if err != nil { t.Fatalf("expected no error, got %v", err) } if cfg["a"] == nil { t.Fatalf("expected key in map") } - cfg, err = decodeConfigMap([]byte(`null`)) + cfg, err = DecodeConfigMap([]byte(`null`)) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -84,10 +28,10 @@ func TestDecodeConfigMap(t *testing.T) { func TestReadString(t *testing.T) { t.Parallel() - raw := map[string]interface{}{ + raw := map[string]any{ "bot_token": 123, } - got := readString(raw, "bot_token") + got := ReadString(raw, "bot_token") if got != "123" { t.Fatalf("unexpected value: %s", got) } @@ -105,27 +49,17 @@ func TestParseUUID(t *testing.T) { } } -func TestParseTelegramUserConfigTrims(t *testing.T) { +func TestBindingCriteriaFromIdentity(t *testing.T) { t.Parallel() - cfg, err := parseTelegramUserConfig(map[string]interface{}{ - "username": " alice ", + criteria := BindingCriteriaFromIdentity(Identity{ + ExternalID: "u1", + Attributes: map[string]string{"username": "alice"}, }) - if err != nil { - t.Fatalf("expected no error, got %v", err) + if criteria.ExternalID != "u1" { + t.Fatalf("unexpected external id: %s", criteria.ExternalID) } - if cfg.Username != "alice" { - t.Fatalf("unexpected username: %s", cfg.Username) - } -} - -func TestResolveTargetFromUserConfigMissing(t *testing.T) { - t.Parallel() - - if _, err := resolveTargetFromUserConfig(ChannelTelegram, map[string]interface{}{}); err == nil { - t.Fatalf("expected error, got nil") - } - if _, err := resolveTargetFromUserConfig(ChannelFeishu, map[string]interface{}{}); err == nil { - t.Fatalf("expected error, got nil") + if criteria.Attribute("username") != "alice" { + t.Fatalf("unexpected username: %s", criteria.Attribute("username")) } } diff --git a/internal/channel/manager.go b/internal/channel/manager.go index e24d4a90..3f88263d 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -2,6 +2,7 @@ package channel import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -15,8 +16,9 @@ type ConfigStore interface { UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) + ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]ChannelSession, error) GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) - UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string) error + UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error } // Middleware 消息处理中间件定义 @@ -26,19 +28,25 @@ type Manager struct { service ConfigStore processor InboundProcessor adapters map[ChannelType]Adapter + senders map[ChannelType]Sender + receivers map[ChannelType]Receiver refreshInterval time.Duration logger *slog.Logger middlewares []Middleware - mu sync.Mutex - runners map[string]*runningAdapter + inboundQueue chan inboundTask + inboundWorkers int + inboundOnce sync.Once + inboundCtx context.Context + inboundCancel context.CancelFunc + adapterMu sync.RWMutex + mu sync.Mutex + connections map[string]*connectionEntry } -type runningAdapter struct { - adapter Adapter - config ChannelConfig - stop func() - supportsStop bool +type connectionEntry struct { + config ChannelConfig + connection Connection } func NewManager(log *slog.Logger, service ConfigStore, processor InboundProcessor) *Manager { @@ -49,10 +57,14 @@ func NewManager(log *slog.Logger, service ConfigStore, processor InboundProcesso service: service, processor: processor, adapters: map[ChannelType]Adapter{}, + senders: map[ChannelType]Sender{}, + receivers: map[ChannelType]Receiver{}, refreshInterval: 30 * time.Second, - runners: map[string]*runningAdapter{}, + connections: map[string]*connectionEntry{}, logger: log.With(slog.String("component", "channel")), middlewares: []Middleware{}, + inboundQueue: make(chan inboundTask, 256), + inboundWorkers: 4, } } @@ -65,16 +77,60 @@ func (m *Manager) RegisterAdapter(adapter Adapter) { if adapter == nil { return } + m.adapterMu.Lock() m.adapters[adapter.Type()] = adapter + if sender, ok := adapter.(Sender); ok { + m.senders[adapter.Type()] = sender + } + if receiver, ok := adapter.(Receiver); ok { + m.receivers[adapter.Type()] = receiver + } + m.adapterMu.Unlock() if m.logger != nil { m.logger.Info("adapter registered", slog.String("channel", adapter.Type().String())) } } +// AddAdapter 注册适配器并触发一次刷新(便于热插拔)。 +func (m *Manager) AddAdapter(ctx context.Context, adapter Adapter) { + m.RegisterAdapter(adapter) + if ctx != nil { + m.refresh(ctx) + } +} + +// RemoveAdapter 移除适配器并停止其连接(便于热插拔)。 +func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { + if ctx == nil { + ctx = context.Background() + } + normalized := normalizeChannelType(channelType.String()) + if normalized == "" { + return + } + m.mu.Lock() + for id, entry := range m.connections { + if entry != nil && entry.config.ChannelType == normalized { + if entry.connection != nil { + _ = entry.connection.Stop(ctx) + } + delete(m.connections, id) + } + } + m.mu.Unlock() + + m.adapterMu.Lock() + delete(m.adapters, normalized) + delete(m.senders, normalized) + delete(m.receivers, normalized) + m.adapterMu.Unlock() +} + func (m *Manager) Start(ctx context.Context) { if m.logger != nil { m.logger.Info("manager start") } + m.startInboundWorkers(ctx) go func() { m.refresh(ctx) ticker := time.NewTicker(m.refreshInterval) @@ -85,7 +141,7 @@ func (m *Manager) Start(ctx context.Context) { if m.logger != nil { m.logger.Info("manager stop") } - m.stopAll() + m.stopAll(ctx) return case <-ticker.C: m.refresh(ctx) @@ -98,19 +154,21 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp if m.service == nil { return fmt.Errorf("channel manager not configured") } - adapter := m.adapters[channelType] - if adapter == nil { + m.adapterMu.RLock() + sender := m.senders[channelType] + m.adapterMu.RUnlock() + if sender == nil { return fmt.Errorf("unsupported channel type: %s", channelType) } config, err := m.service.ResolveEffectiveConfig(ctx, botID, channelType) if err != nil { return err } - target := strings.TrimSpace(req.To) + target := strings.TrimSpace(req.Target) if target == "" { - targetUserID := strings.TrimSpace(req.ToUserID) + targetUserID := strings.TrimSpace(req.UserID) if targetUserID == "" { - return fmt.Errorf("target user_id is required") + return fmt.Errorf("target or user_id is required") } userCfg, err := m.service.GetUserConfig(ctx, targetUserID, channelType) if err != nil { @@ -119,30 +177,62 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp } return fmt.Errorf("channel binding required") } - target, err = resolveTargetFromUserConfig(channelType, userCfg.Config) + target, err = ResolveTargetFromUserConfig(channelType, userCfg.Config) if err != nil { return err } } - text := strings.TrimSpace(req.Message) - if text == "" { + if normalized, ok := NormalizeTarget(channelType, target); ok { + target = normalized + } + if req.Message.IsEmpty() { return fmt.Errorf("message is required") } if m.logger != nil { m.logger.Info("send outbound", slog.String("channel", channelType.String()), slog.String("bot_id", botID)) } - err = adapter.Send(ctx, config, OutboundMessage{ - To: target, - Text: text, - }) - if err != nil && m.logger != nil { - m.logger.Error("send outbound failed", slog.String("channel", channelType.String()), slog.String("bot_id", botID), slog.Any("error", err)) + policy := m.resolveOutboundPolicy(channelType) + outbound, err := buildOutboundMessages(OutboundMessage{ + Target: target, + Message: req.Message, + }, policy) + if err != nil { + return err } - return err + for _, item := range outbound { + if err := m.sendWithConfig(ctx, sender, config, item, policy); err != nil { + if m.logger != nil { + m.logger.Error("send outbound failed", slog.String("channel", channelType.String()), slog.String("bot_id", botID), slog.Any("error", err)) + } + return err + } + } + return nil } func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { - return m.handleInbound(ctx, cfg, msg) + if m.processor == nil { + return fmt.Errorf("inbound processor not configured") + } + m.startInboundWorkers(ctx) + if m.inboundCtx != nil && m.inboundCtx.Err() != nil { + return fmt.Errorf("inbound dispatcher stopped") + } + taskCtx := ctx + if ctx != nil { + taskCtx = context.WithoutCancel(ctx) + } + task := inboundTask{ + ctx: taskCtx, + cfg: cfg, + msg: msg, + } + select { + case m.inboundQueue <- task: + return nil + default: + return fmt.Errorf("inbound queue full") + } } func (m *Manager) refresh(ctx context.Context) { @@ -150,7 +240,8 @@ func (m *Manager) refresh(ctx context.Context) { return } configs := make([]ChannelConfig, 0) - for channelType := range m.adapters { + channelTypes := m.listAdapterTypes() + for _, channelType := range channelTypes { items, err := m.service.ListConfigsByType(ctx, channelType) if err != nil { if m.logger != nil { @@ -174,7 +265,7 @@ func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { continue } active[cfg.ID] = cfg - if err := m.ensureRunner(ctx, cfg); err != nil { + if err := m.ensureConnection(ctx, cfg); err != nil { if m.logger != nil { m.logger.Error("adapter start failed", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID), slog.Any("error", err)) } @@ -183,48 +274,51 @@ func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { m.mu.Lock() defer m.mu.Unlock() - for id, runner := range m.runners { + for id, entry := range m.connections { if _, ok := active[id]; ok { continue } - if runner.supportsStop && runner.stop != nil { + if entry != nil && entry.connection != nil { if m.logger != nil { - m.logger.Info("adapter stop", slog.String("channel", runner.config.ChannelType.String()), slog.String("config_id", id)) + m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) } - runner.stop() + _ = entry.connection.Stop(ctx) } - delete(m.runners, id) + delete(m.connections, id) } } -func (m *Manager) ensureRunner(ctx context.Context, cfg ChannelConfig) error { +func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error { + m.adapterMu.RLock() + receiver := m.receivers[cfg.ChannelType] + m.adapterMu.RUnlock() + if receiver == nil { + return nil + } m.mu.Lock() - runner := m.runners[cfg.ID] + entry := m.connections[cfg.ID] m.mu.Unlock() - if runner != nil { - if runner.config.UpdatedAt.Equal(cfg.UpdatedAt) { - return nil - } - if !runner.supportsStop || runner.stop == nil { - if m.logger != nil { - m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) - } + if entry != nil { + if entry.config.UpdatedAt.Equal(cfg.UpdatedAt) { return nil } if m.logger != nil { m.logger.Info("adapter restart", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } - runner.stop() + if err := entry.connection.Stop(ctx); err != nil { + if errors.Is(err, ErrStopNotSupported) { + if m.logger != nil { + m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) + } + return nil + } + return err + } m.mu.Lock() - delete(m.runners, cfg.ID) + delete(m.connections, cfg.ID) m.mu.Unlock() } - - adapter := m.adapters[cfg.ChannelType] - if adapter == nil { - return fmt.Errorf("unsupported channel type: %s", cfg.ChannelType) - } if m.logger != nil { m.logger.Info("adapter start", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } @@ -235,33 +329,30 @@ func (m *Manager) ensureRunner(ctx context.Context, cfg ChannelConfig) error { handler = m.middlewares[i](handler) } - started, err := adapter.Start(ctx, cfg, handler) + conn, err := receiver.Connect(ctx, cfg, handler) if err != nil { return err } - entry := &runningAdapter{ - adapter: adapter, - config: cfg, - stop: started.Stop, - supportsStop: started.SupportsStop, - } m.mu.Lock() - m.runners[cfg.ID] = entry + m.connections[cfg.ID] = &connectionEntry{ + config: cfg, + connection: conn, + } m.mu.Unlock() return nil } -func (m *Manager) stopAll() { +func (m *Manager) stopAll(ctx context.Context) { m.mu.Lock() defer m.mu.Unlock() - for id, runner := range m.runners { - if runner.supportsStop && runner.stop != nil { + for id, entry := range m.connections { + if entry != nil && entry.connection != nil { if m.logger != nil { - m.logger.Info("adapter stop", slog.String("channel", runner.config.ChannelType.String()), slog.String("config_id", id)) + m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) } - runner.stop() + _ = entry.connection.Stop(ctx) } - delete(m.runners, id) + delete(m.connections, id) } } @@ -269,85 +360,288 @@ func (m *Manager) handleInbound(ctx context.Context, cfg ChannelConfig, msg Inbo if m.processor == nil { return fmt.Errorf("inbound processor not configured") } - reply, err := m.processor.HandleInbound(ctx, cfg, msg) - if err != nil { + sender := m.newReplySender(cfg, msg.Channel) + if err := m.processor.HandleInbound(ctx, cfg, msg, sender); err != nil { if m.logger != nil { m.logger.Error("inbound processing failed", slog.String("channel", msg.Channel.String()), slog.Any("error", err)) } return err } - if reply == nil || strings.TrimSpace(reply.Text) == "" { + return nil +} + +func (m *Manager) Stop(ctx context.Context, configID string) error { + configID = strings.TrimSpace(configID) + if configID == "" { + return fmt.Errorf("config id is required") + } + m.mu.Lock() + entry := m.connections[configID] + m.mu.Unlock() + if entry == nil || entry.connection == nil { return nil } - adapter := m.adapters[msg.Channel] - if adapter == nil { - return fmt.Errorf("unsupported channel type: %s", msg.Channel) + return entry.connection.Stop(ctx) +} + +func (m *Manager) StopByBot(ctx context.Context, botID string) error { + botID = strings.TrimSpace(botID) + if botID == "" { + return fmt.Errorf("bot id is required") } - target := strings.TrimSpace(reply.To) - if target == "" { - return fmt.Errorf("reply target missing") + m.mu.Lock() + defer m.mu.Unlock() + for id, entry := range m.connections { + if entry != nil && entry.config.BotID == botID { + if entry.connection != nil { + _ = entry.connection.Stop(ctx) + } + delete(m.connections, id) + } } - if m.logger != nil { - m.logger.Info("send reply", slog.String("channel", msg.Channel.String())) + return nil +} + +func (m *Manager) Shutdown(ctx context.Context) error { + if m.inboundCancel != nil { + m.inboundCancel() + } + m.stopAll(ctx) + return nil +} + +func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) ReplySender { + m.adapterMu.RLock() + sender := m.senders[channelType] + m.adapterMu.RUnlock() + return &managerReplySender{ + manager: m, + sender: sender, + channelType: channelType, + config: cfg, + } +} + +func (m *Manager) listAdapterTypes() []ChannelType { + m.adapterMu.RLock() + defer m.adapterMu.RUnlock() + items := make([]ChannelType, 0, len(m.adapters)) + for channelType := range m.adapters { + items = append(items, channelType) + } + return items +} + +type inboundTask struct { + ctx context.Context + cfg ChannelConfig + msg InboundMessage +} + +func (m *Manager) startInboundWorkers(ctx context.Context) { + m.inboundOnce.Do(func() { + workerCtx := ctx + if workerCtx == nil { + workerCtx = context.Background() + } + m.inboundCtx, m.inboundCancel = context.WithCancel(workerCtx) + for i := 0; i < m.inboundWorkers; i++ { + go m.runInboundWorker(m.inboundCtx) + } + }) +} + +func (m *Manager) runInboundWorker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case task := <-m.inboundQueue: + if err := m.handleInbound(task.ctx, task.cfg, task.msg); err != nil { + if m.logger != nil { + m.logger.Error("inbound processing failed", slog.String("channel", task.msg.Channel.String()), slog.Any("error", err)) + } + } + } + } +} + +func (m *Manager) resolveOutboundPolicy(channelType ChannelType) OutboundPolicy { + policy, ok := GetChannelOutboundPolicy(channelType) + if !ok { + policy = OutboundPolicy{} + } + return NormalizeOutboundPolicy(policy) +} + +func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]OutboundMessage, error) { + policy = NormalizeOutboundPolicy(policy) + if msg.Message.IsEmpty() { + return nil, fmt.Errorf("message is required") + } + normalized := normalizeOutboundMessage(msg.Message) + chunker := policy.Chunker + if normalized.Format == MessageFormatMarkdown { + chunker = ChunkMarkdownText + } + base := normalized + base.Attachments = nil + textMessages := make([]OutboundMessage, 0) + shouldChunk := policy.TextChunkLimit > 0 && strings.TrimSpace(base.Text) != "" && len(base.Parts) == 0 + if shouldChunk { + chunks := chunker(base.Text, policy.TextChunkLimit) + for idx, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + actions := base.Actions + if len(chunks) > 1 && idx < len(chunks)-1 { + actions = nil + } + item := OutboundMessage{ + Target: msg.Target, + Message: Message{ + ID: base.ID, + Format: base.Format, + Text: chunk, + Parts: base.Parts, + Attachments: nil, + Actions: actions, + Thread: base.Thread, + Reply: base.Reply, + Metadata: base.Metadata, + }, + } + textMessages = append(textMessages, item) + } + } else if !base.IsEmpty() { + textMessages = append(textMessages, OutboundMessage{Target: msg.Target, Message: base}) } - // 增加简单的重试逻辑 + attachments := normalized.Attachments + attachmentMessages := make([]OutboundMessage, 0) + if len(attachments) > 0 { + media := normalized + media.Format = "" + media.Text = "" + media.Parts = nil + media.Actions = nil + media.Attachments = attachments + attachmentMessages = append(attachmentMessages, OutboundMessage{Target: msg.Target, Message: media}) + } + + if len(textMessages) == 0 && len(attachmentMessages) == 0 { + return nil, fmt.Errorf("message is required") + } + if policy.MediaOrder == OutboundOrderTextFirst { + return append(textMessages, attachmentMessages...), nil + } + return append(attachmentMessages, textMessages...), nil +} + +func normalizeOutboundMessage(msg Message) Message { + if msg.Format == "" { + if len(msg.Parts) > 0 { + msg.Format = MessageFormatRich + } else if strings.TrimSpace(msg.Text) != "" { + msg.Format = MessageFormatPlain + } + } + return msg +} + +func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg ChannelConfig, msg OutboundMessage, policy OutboundPolicy) error { + if sender == nil { + return fmt.Errorf("unsupported channel type: %s", cfg.ChannelType) + } + target := strings.TrimSpace(msg.Target) + if target == "" { + return fmt.Errorf("target is required") + } + if msg.Message.IsEmpty() { + return fmt.Errorf("message is required") + } + if caps, ok := GetChannelCapabilities(cfg.ChannelType); ok { + if msg.Message.Format == MessageFormatPlain && !caps.Text { + return fmt.Errorf("channel does not support plain text") + } + if msg.Message.Format == MessageFormatMarkdown && !(caps.Markdown || caps.RichText) { + return fmt.Errorf("channel does not support markdown") + } + if msg.Message.Format == MessageFormatRich && !caps.RichText { + return fmt.Errorf("channel does not support rich text") + } + if len(msg.Message.Parts) > 0 && !caps.RichText { + return fmt.Errorf("channel does not support rich text") + } + if len(msg.Message.Attachments) > 0 && !caps.Attachments { + return fmt.Errorf("channel does not support attachments") + } + if len(msg.Message.Attachments) > 0 && requiresMedia(msg.Message.Attachments) && !caps.Media { + return fmt.Errorf("channel does not support media") + } + if len(msg.Message.Actions) > 0 && !caps.Buttons { + return fmt.Errorf("channel does not support actions") + } + if msg.Message.Thread != nil && !caps.Threads { + return fmt.Errorf("channel does not support threads") + } + if msg.Message.Reply != nil && !caps.Reply { + return fmt.Errorf("channel does not support reply") + } + } + policy = NormalizeOutboundPolicy(policy) var lastErr error - for i := 0; i < 3; i++ { - err = adapter.Send(ctx, cfg, OutboundMessage{ - To: target, - Text: reply.Text, - }) + for i := 0; i < policy.RetryMax; i++ { + err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: msg.Message}) if err == nil { return nil } lastErr = err if m.logger != nil { - m.logger.Warn("send reply retry", - slog.String("channel", msg.Channel.String()), + m.logger.Warn("send outbound retry", + slog.String("channel", cfg.ChannelType.String()), slog.Int("attempt", i+1), slog.Any("error", err)) } - time.Sleep(time.Duration(i+1) * 500 * time.Millisecond) // 指数退避 + time.Sleep(time.Duration(i+1) * time.Duration(policy.RetryBackoffMs) * time.Millisecond) } - - return fmt.Errorf("send reply failed after retries: %w", lastErr) + return fmt.Errorf("send outbound failed after retries: %w", lastErr) } -func resolveTargetFromUserConfig(channelType ChannelType, config map[string]interface{}) (string, error) { - switch channelType { - case ChannelTelegram: - userCfg, err := parseTelegramUserConfig(config) - if err != nil { - return "", err +func requiresMedia(attachments []Attachment) bool { + for _, att := range attachments { + switch att.Type { + case AttachmentAudio, AttachmentVideo, AttachmentVoice, AttachmentGIF: + return true + default: + continue } - if userCfg.ChatID != "" { - return userCfg.ChatID, nil - } - if userCfg.UserID != "" { - return userCfg.UserID, nil - } - if userCfg.Username != "" { - name := userCfg.Username - if !strings.HasPrefix(name, "@") { - name = "@" + name - } - return name, nil - } - return "", fmt.Errorf("telegram binding is incomplete") - case ChannelFeishu: - userCfg, err := parseFeishuUserConfig(config) - if err != nil { - return "", err - } - if userCfg.OpenID != "" { - return "open_id:" + userCfg.OpenID, nil - } - if userCfg.UserID != "" { - return "user_id:" + userCfg.UserID, nil - } - return "", fmt.Errorf("feishu binding is incomplete") - default: - return "", fmt.Errorf("unsupported channel type: %s", channelType) } + return false +} + +type managerReplySender struct { + manager *Manager + sender Sender + channelType ChannelType + config ChannelConfig +} + +func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) error { + if s.manager == nil { + return fmt.Errorf("channel manager not configured") + } + policy := s.manager.resolveOutboundPolicy(s.channelType) + outbound, err := buildOutboundMessages(msg, policy) + if err != nil { + return err + } + for _, item := range outbound { + if err := s.manager.sendWithConfig(ctx, s.sender, s.config, item, policy); err != nil { + return err + } + } + return nil } diff --git a/internal/channel/manager_core_test.go b/internal/channel/manager_core_test.go index 5e4b37f9..766b1354 100644 --- a/internal/channel/manager_core_test.go +++ b/internal/channel/manager_core_test.go @@ -2,6 +2,7 @@ package channel import ( "context" + "fmt" "log/slog" "testing" ) @@ -11,10 +12,7 @@ type mockAdapter struct { sentMessages []OutboundMessage } -func (m *mockAdapter) Type() ChannelType { return ChannelFeishu } -func (m *mockAdapter) Start(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (AdapterRunner, error) { - return AdapterRunner{}, nil -} +func (m *mockAdapter) Type() ChannelType { return ChannelType("test") } func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { m.sentMessages = append(m.sentMessages, msg) return nil @@ -27,10 +25,19 @@ type fakeInboundProcessor struct { gotMsg InboundMessage } -func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) (*OutboundMessage, error) { +func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error { f.gotCfg = cfg f.gotMsg = msg - return f.resp, f.err + if f.err != nil { + return f.err + } + if f.resp == nil { + return nil + } + if sender == nil { + return fmt.Errorf("sender missing") + } + return sender.Send(ctx, *f.resp) } func TestManager_HandleInbound_CoreLogic(t *testing.T) { @@ -39,8 +46,10 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { t.Run("返回回复_发送成功", func(t *testing.T) { processor := &fakeInboundProcessor{ resp: &OutboundMessage{ - To: "target-id", - Text: "AI回复内容", + Target: "target-id", + Message: Message{ + Text: "AI回复内容", + }, }, } @@ -48,12 +57,15 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { adapter := &mockAdapter{} m.RegisterAdapter(adapter) - cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelFeishu} + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} msg := InboundMessage{ - Channel: ChannelFeishu, - Text: "你好", - ChatID: "chat-1", - ReplyTo: "target-id", + Channel: ChannelType("test"), + Message: Message{Text: "你好"}, + ReplyTarget: "target-id", + Conversation: Conversation{ + ID: "chat-1", + Type: "p2p", + }, } err := m.handleInbound(context.Background(), cfg, msg) @@ -65,11 +77,11 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { if len(adapter.sentMessages) != 1 { t.Fatalf("应该发送 1 条回复,实际发送: %d", len(adapter.sentMessages)) } - if adapter.sentMessages[0].Text != "AI回复内容" { - t.Errorf("回复内容错误: %s", adapter.sentMessages[0].Text) + if adapter.sentMessages[0].Message.PlainText() != "AI回复内容" { + t.Errorf("回复内容错误: %s", adapter.sentMessages[0].Message.PlainText()) } - if adapter.sentMessages[0].To != "target-id" { - t.Errorf("回复目标错误: %s", adapter.sentMessages[0].To) + if adapter.sentMessages[0].Target != "target-id" { + t.Errorf("回复目标错误: %s", adapter.sentMessages[0].Target) } }) @@ -79,11 +91,11 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { adapter := &mockAdapter{} m.RegisterAdapter(adapter) - cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelFeishu} + cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} msg := InboundMessage{ - Channel: ChannelFeishu, - Text: "你好", - ReplyTo: "target-id", + Channel: ChannelType("test"), + Message: Message{Text: "你好"}, + ReplyTarget: "target-id", } err := m.handleInbound(context.Background(), cfg, msg) @@ -100,7 +112,7 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { processor := &fakeInboundProcessor{err: context.Canceled} m := NewManager(logger, &fakeConfigStore{}, processor) cfg := ChannelConfig{ID: "bot-1"} - msg := InboundMessage{Text: " "} // 空格消息 + msg := InboundMessage{Message: Message{Text: " "}} // 空格消息 err := m.handleInbound(context.Background(), cfg, msg) if err == nil { diff --git a/internal/channel/manager_integration_test.go b/internal/channel/manager_integration_test.go index a797f8cb..0212d3af 100644 --- a/internal/channel/manager_integration_test.go +++ b/internal/channel/manager_integration_test.go @@ -5,11 +5,37 @@ import ( "fmt" "io" "log/slog" + "strings" "sync" "testing" "time" ) +func init() { + _ = RegisterChannel(ChannelDescriptor{ + Type: ChannelType("test"), + DisplayName: "Test", + NormalizeConfig: normalizeEmpty, + NormalizeUserConfig: normalizeEmpty, + ResolveTarget: resolveTestTarget, + Capabilities: ChannelCapabilities{ + Text: true, + }, + }) +} + +func normalizeEmpty(map[string]any) (map[string]any, error) { + return map[string]any{}, nil +} + +func resolveTestTarget(config map[string]any) (string, error) { + value := strings.TrimSpace(ReadString(config, "target")) + if value == "" { + return "", fmt.Errorf("missing target") + } + return "resolved:" + value, nil +} + type fakeConfigStore struct { effectiveConfig ChannelConfig userConfig ChannelUserBinding @@ -47,6 +73,10 @@ func (f *fakeConfigStore) ResolveUserBinding(ctx context.Context, channelType Ch return f.boundUserID, nil } +func (f *fakeConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]ChannelSession, error) { + return nil, nil +} + func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) { if f.session.SessionID == sessionID { return f.session, nil @@ -54,7 +84,7 @@ func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID strin return ChannelSession{}, nil } -func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string) error { +func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { return nil } @@ -65,10 +95,19 @@ type fakeInboundProcessorIntegration struct { gotMsg InboundMessage } -func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) (*OutboundMessage, error) { +func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error { f.gotCfg = cfg f.gotMsg = msg - return f.resp, f.err + if f.err != nil { + return f.err + } + if f.resp == nil { + return nil + } + if sender == nil { + return fmt.Errorf("sender missing") + } + return sender.Send(ctx, *f.resp) } type fakeAdapter struct { @@ -83,18 +122,17 @@ func (f *fakeAdapter) Type() ChannelType { return f.channelType } -func (f *fakeAdapter) Start(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (AdapterRunner, error) { +func (f *fakeAdapter) Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) { f.mu.Lock() f.started = append(f.started, cfg) f.mu.Unlock() - return AdapterRunner{ - Stop: func() { - f.mu.Lock() - f.stops++ - f.mu.Unlock() - }, - SupportsStop: true, - }, nil + stop := func(context.Context) error { + f.mu.Lock() + f.stops++ + f.mu.Unlock() + return nil + } + return NewConnection(cfg, stop), nil } func (f *fakeAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { @@ -117,33 +155,38 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { } processor := &fakeInboundProcessorIntegration{ resp: &OutboundMessage{ - To: "123", - Text: "ok", + Target: "123", + Message: Message{ + Text: "ok", + }, }, } - adapter := &fakeAdapter{channelType: ChannelTelegram} + adapter := &fakeAdapter{channelType: ChannelType("test")} manager := NewManager(log, store, processor) manager.RegisterAdapter(adapter) cfg := ChannelConfig{ ID: "cfg-1", BotID: "bot-1", - ChannelType: ChannelTelegram, - Credentials: map[string]interface{}{"botToken": "token"}, + ChannelType: ChannelType("test"), + Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), } err := manager.handleInbound(context.Background(), cfg, InboundMessage{ - Channel: ChannelTelegram, - Text: "hi", - ChatID: "chat-1", - BotID: "bot-1", - ReplyTo: "123", + Channel: ChannelType("test"), + Message: Message{Text: "hi"}, + BotID: "bot-1", + ReplyTarget: "123", + Conversation: Conversation{ + ID: "chat-1", + Type: "p2p", + }, }) if err != nil { t.Fatalf("expected no error, got %v", err) } - if processor.gotMsg.ChatID != "chat-1" || processor.gotMsg.Text != "hi" || processor.gotMsg.BotID != "bot-1" { + if processor.gotMsg.Conversation.ID != "chat-1" || processor.gotMsg.Message.PlainText() != "hi" || processor.gotMsg.BotID != "bot-1" { t.Fatalf("unexpected inbound message: %+v", processor.gotMsg) } @@ -152,7 +195,7 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { if len(adapter.sent) != 1 { t.Fatalf("expected 1 send, got %d", len(adapter.sent)) } - if adapter.sent[0].To != "123" || adapter.sent[0].Text != "ok" { + if adapter.sent[0].Target != "123" || adapter.sent[0].Message.PlainText() != "ok" { t.Fatalf("unexpected outbound message: %+v", adapter.sent[0]) } } @@ -165,22 +208,24 @@ func TestManagerSendUsesBinding(t *testing.T) { effectiveConfig: ChannelConfig{ ID: "cfg-1", BotID: "bot-1", - ChannelType: ChannelTelegram, - Credentials: map[string]interface{}{"botToken": "token"}, + ChannelType: ChannelType("test"), + Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), }, userConfig: ChannelUserBinding{ ID: "binding-1", - Config: map[string]interface{}{"username": "alice"}, + Config: map[string]any{"target": "alice"}, }, } - adapter := &fakeAdapter{channelType: ChannelTelegram} + adapter := &fakeAdapter{channelType: ChannelType("test")} manager := NewManager(log, store, &fakeInboundProcessorIntegration{}) manager.RegisterAdapter(adapter) - err := manager.Send(context.Background(), "bot-1", ChannelTelegram, SendRequest{ - ToUserID: "user-1", - Message: "hello", + err := manager.Send(context.Background(), "bot-1", ChannelType("test"), SendRequest{ + UserID: "user-1", + Message: Message{ + Text: "hello", + }, }) if err != nil { t.Fatalf("expected no error, got %v", err) @@ -191,7 +236,7 @@ func TestManagerSendUsesBinding(t *testing.T) { if len(adapter.sent) != 1 { t.Fatalf("expected 1 send, got %d", len(adapter.sent)) } - if adapter.sent[0].To != "@alice" || adapter.sent[0].Text != "hello" { + if adapter.sent[0].Target != "resolved:alice" || adapter.sent[0].Message.PlainText() != "hello" { t.Fatalf("unexpected outbound message: %+v", adapter.sent[0]) } } @@ -201,15 +246,15 @@ func TestManagerReconcileStartsAndStops(t *testing.T) { log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) store := &fakeConfigStore{} - adapter := &fakeAdapter{channelType: ChannelTelegram} + adapter := &fakeAdapter{channelType: ChannelType("test")} manager := NewManager(log, store, &fakeInboundProcessorIntegration{}) manager.RegisterAdapter(adapter) cfg := ChannelConfig{ ID: "cfg-1", BotID: "bot-1", - ChannelType: ChannelTelegram, - Credentials: map[string]interface{}{"botToken": "token"}, + ChannelType: ChannelType("test"), + Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), } manager.reconcile(context.Background(), []ChannelConfig{cfg}) diff --git a/internal/channel/manager_test.go b/internal/channel/manager_test.go index 6701b82d..a35c45c2 100644 --- a/internal/channel/manager_test.go +++ b/internal/channel/manager_test.go @@ -1,50 +1,22 @@ -package channel +package channel_test import ( "testing" + + "github.com/memohai/memoh/internal/channel" ) -func TestResolveTargetFromUserConfigTelegram(t *testing.T) { +func TestResolveTargetFromUserConfig(t *testing.T) { t.Parallel() + registerTestChannel() - target, err := resolveTargetFromUserConfig(ChannelTelegram, map[string]interface{}{ - "chat_id": "123", - "user_id": "456", - "username": "alice", + target, err := channel.ResolveTargetFromUserConfig(testChannelType, map[string]any{ + "target": "alice", }) if err != nil { t.Fatalf("expected no error, got %v", err) } - if target != "123" { - t.Fatalf("unexpected target: %s", target) - } -} - -func TestResolveTargetFromUserConfigTelegramUsername(t *testing.T) { - t.Parallel() - - target, err := resolveTargetFromUserConfig(ChannelTelegram, map[string]interface{}{ - "username": "alice", - }) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if target != "@alice" { - t.Fatalf("unexpected target: %s", target) - } -} - -func TestResolveTargetFromUserConfigFeishu(t *testing.T) { - t.Parallel() - - target, err := resolveTargetFromUserConfig(ChannelFeishu, map[string]interface{}{ - "open_id": "ou_123", - "user_id": "u_123", - }) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if target != "open_id:ou_123" { + if target != "resolved:alice" { t.Fatalf("unexpected target: %s", target) } } @@ -52,7 +24,7 @@ func TestResolveTargetFromUserConfigFeishu(t *testing.T) { func TestResolveTargetFromUserConfigUnsupported(t *testing.T) { t.Parallel() - _, err := resolveTargetFromUserConfig("unknown", map[string]interface{}{}) + _, err := channel.ResolveTargetFromUserConfig("unknown", map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go new file mode 100644 index 00000000..79653593 --- /dev/null +++ b/internal/channel/outbound.go @@ -0,0 +1,165 @@ +package channel + +import "strings" + +type ChunkerMode string + +const ( + ChunkerModeText ChunkerMode = "text" + ChunkerModeMarkdown ChunkerMode = "markdown" +) + +type OutboundOrder string + +const ( + OutboundOrderMediaFirst OutboundOrder = "media_first" + OutboundOrderTextFirst OutboundOrder = "text_first" +) + +type Chunker func(text string, limit int) []string + +type OutboundPolicy struct { + TextChunkLimit int `json:"text_chunk_limit,omitempty"` + ChunkerMode ChunkerMode `json:"chunker_mode,omitempty"` + Chunker Chunker `json:"-"` + MediaOrder OutboundOrder `json:"media_order,omitempty"` + RetryMax int `json:"retry_max,omitempty"` + RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` +} + +func NormalizeOutboundPolicy(policy OutboundPolicy) OutboundPolicy { + if policy.TextChunkLimit <= 0 { + policy.TextChunkLimit = 2000 + } + if policy.MediaOrder == "" { + policy.MediaOrder = OutboundOrderMediaFirst + } + if policy.ChunkerMode == "" { + policy.ChunkerMode = ChunkerModeText + } + if policy.RetryMax <= 0 { + policy.RetryMax = 3 + } + if policy.RetryBackoffMs <= 0 { + policy.RetryBackoffMs = 500 + } + if policy.Chunker == nil { + policy.Chunker = DefaultChunker(policy.ChunkerMode) + } + return policy +} + +func DefaultChunker(mode ChunkerMode) Chunker { + switch mode { + case ChunkerModeMarkdown: + return ChunkMarkdownText + default: + return ChunkText + } +} + +func ChunkText(text string, limit int) []string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return nil + } + if limit <= 0 || runeLen(trimmed) <= limit { + return []string{trimmed} + } + lines := strings.Split(trimmed, "\n") + chunks := make([]string, 0) + buf := make([]string, 0, len(lines)) + bufLen := 0 + for _, line := range lines { + lineLen := runeLen(line) + sepLen := 0 + if len(buf) > 0 { + sepLen = 1 + } + if bufLen+sepLen+lineLen <= limit { + buf = append(buf, line) + bufLen += sepLen + lineLen + continue + } + if len(buf) > 0 { + chunks = append(chunks, strings.Join(buf, "\n")) + buf = buf[:0] + bufLen = 0 + } + if lineLen <= limit { + buf = append(buf, line) + bufLen = lineLen + continue + } + chunks = append(chunks, splitLongLine(line, limit)...) + } + if len(buf) > 0 { + chunks = append(chunks, strings.Join(buf, "\n")) + } + return chunks +} + +func ChunkMarkdownText(text string, limit int) []string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return nil + } + if limit <= 0 || runeLen(trimmed) <= limit { + return []string{trimmed} + } + paragraphs := strings.Split(trimmed, "\n\n") + chunks := make([]string, 0) + buf := make([]string, 0, len(paragraphs)) + bufLen := 0 + for _, para := range paragraphs { + paraLen := runeLen(para) + sepLen := 0 + if len(buf) > 0 { + sepLen = 2 + } + if bufLen+sepLen+paraLen <= limit { + buf = append(buf, para) + bufLen += sepLen + paraLen + continue + } + if len(buf) > 0 { + chunks = append(chunks, strings.Join(buf, "\n\n")) + buf = buf[:0] + bufLen = 0 + } + if paraLen <= limit { + buf = append(buf, para) + bufLen = paraLen + continue + } + chunks = append(chunks, ChunkText(para, limit)...) + } + if len(buf) > 0 { + chunks = append(chunks, strings.Join(buf, "\n\n")) + } + return chunks +} + +func runeLen(value string) int { + return len([]rune(value)) +} + +func splitLongLine(line string, limit int) []string { + if limit <= 0 { + return []string{line} + } + runes := []rune(line) + chunks := make([]string, 0) + for start := 0; start < len(runes); start += limit { + end := start + limit + if end > len(runes) { + end = len(runes) + } + segment := strings.TrimSpace(string(runes[start:end])) + if segment == "" { + continue + } + chunks = append(chunks, segment) + } + return chunks +} diff --git a/internal/channel/processor.go b/internal/channel/processor.go index e9f0b905..9d6be6e7 100644 --- a/internal/channel/processor.go +++ b/internal/channel/processor.go @@ -2,7 +2,7 @@ package channel import "context" -// InboundProcessor 负责处理入站消息并产出可发送的响应。 +// InboundProcessor 负责处理入站消息并通过 sender 回传响应。 type InboundProcessor interface { - HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) (*OutboundMessage, error) + HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error } diff --git a/internal/channel/registry.go b/internal/channel/registry.go index ae949bea..6826092e 100644 --- a/internal/channel/registry.go +++ b/internal/channel/registry.go @@ -9,8 +9,18 @@ import ( type ChannelDescriptor struct { Type ChannelType DisplayName string - NormalizeConfig func(map[string]interface{}) (map[string]interface{}, error) - NormalizeUserConfig func(map[string]interface{}) (map[string]interface{}, error) + NormalizeConfig func(map[string]any) (map[string]any, error) + NormalizeUserConfig func(map[string]any) (map[string]any, error) + ResolveTarget func(map[string]any) (string, error) + MatchBinding func(map[string]any, BindingCriteria) bool + BuildUserConfig func(Identity) map[string]any + Configless bool + Capabilities ChannelCapabilities + OutboundPolicy OutboundPolicy + ConfigSchema ConfigSchema + UserConfigSchema ConfigSchema + TargetSpec TargetSpec + NormalizeTarget func(string) string } type channelRegistry struct { @@ -46,6 +56,20 @@ func MustRegisterChannel(desc ChannelDescriptor) { } } +func UnregisterChannel(channelType ChannelType) bool { + normalized := normalizeChannelType(channelType.String()) + if normalized == "" { + return false + } + registry.mu.Lock() + defer registry.mu.Unlock() + if _, exists := registry.items[normalized]; !exists { + return false + } + delete(registry.items, normalized) + return true +} + func GetChannelDescriptor(channelType ChannelType) (ChannelDescriptor, bool) { normalized := normalizeChannelType(channelType.String()) registry.mu.RLock() @@ -64,6 +88,46 @@ func ListChannelDescriptors() []ChannelDescriptor { return items } +func GetChannelCapabilities(channelType ChannelType) (ChannelCapabilities, bool) { + desc, ok := GetChannelDescriptor(channelType) + if !ok { + return ChannelCapabilities{}, false + } + return desc.Capabilities, true +} + +func GetChannelOutboundPolicy(channelType ChannelType) (OutboundPolicy, bool) { + desc, ok := GetChannelDescriptor(channelType) + if !ok { + return OutboundPolicy{}, false + } + return desc.OutboundPolicy, true +} + +func GetChannelConfigSchema(channelType ChannelType) (ConfigSchema, bool) { + desc, ok := GetChannelDescriptor(channelType) + if !ok { + return ConfigSchema{}, false + } + return desc.ConfigSchema, true +} + +func GetChannelUserConfigSchema(channelType ChannelType) (ConfigSchema, bool) { + desc, ok := GetChannelDescriptor(channelType) + if !ok { + return ConfigSchema{}, false + } + return desc.UserConfigSchema, true +} + +func IsConfigless(channelType ChannelType) bool { + desc, ok := GetChannelDescriptor(channelType) + if !ok { + return false + } + return desc.Configless +} + func normalizeChannelType(raw string) ChannelType { normalized := strings.TrimSpace(strings.ToLower(raw)) if normalized == "" { diff --git a/internal/channel/registry_test.go b/internal/channel/registry_test.go deleted file mode 100644 index 6305c31f..00000000 --- a/internal/channel/registry_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package channel - -func init() { - MustRegisterChannel(ChannelDescriptor{ - Type: ChannelTelegram, - DisplayName: "Telegram", - NormalizeConfig: NormalizeTelegramConfig, - NormalizeUserConfig: NormalizeTelegramUserConfig, - }) - MustRegisterChannel(ChannelDescriptor{ - Type: ChannelFeishu, - DisplayName: "Feishu", - NormalizeConfig: NormalizeFeishuConfig, - NormalizeUserConfig: NormalizeFeishuUserConfig, - }) - MustRegisterChannel(ChannelDescriptor{ - Type: ChannelCLI, - DisplayName: "CLI", - NormalizeConfig: func(map[string]interface{}) (map[string]interface{}, error) { return map[string]interface{}{}, nil }, - NormalizeUserConfig: func(map[string]interface{}) (map[string]interface{}, error) { return map[string]interface{}{}, nil }, - }) - MustRegisterChannel(ChannelDescriptor{ - Type: ChannelWeb, - DisplayName: "Web", - NormalizeConfig: func(map[string]interface{}) (map[string]interface{}, error) { return map[string]interface{}{}, nil }, - NormalizeUserConfig: func(map[string]interface{}) (map[string]interface{}, error) { return map[string]interface{}{}, nil }, - }) -} diff --git a/internal/channel/schema.go b/internal/channel/schema.go new file mode 100644 index 00000000..cb9e7db7 --- /dev/null +++ b/internal/channel/schema.go @@ -0,0 +1,27 @@ +package channel + +type FieldType string + +const ( + FieldString FieldType = "string" + FieldSecret FieldType = "secret" + FieldBool FieldType = "bool" + FieldNumber FieldType = "number" + FieldEnum FieldType = "enum" +) + +// FieldSchema 定义单个配置字段的结构化描述。 +type FieldSchema struct { + Type FieldType `json:"type"` + Required bool `json:"required"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Enum []string `json:"enum,omitempty"` + Example any `json:"example,omitempty"` +} + +// ConfigSchema 描述通道配置或用户绑定的结构。 +type ConfigSchema struct { + Version int `json:"version"` + Fields map[string]FieldSchema `json:"fields"` +} diff --git a/internal/channel/service.go b/internal/channel/service.go index 25b6e88c..f3cdd25e 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -44,7 +44,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch } selfIdentity := req.SelfIdentity if selfIdentity == nil { - selfIdentity = map[string]interface{}{} + selfIdentity = map[string]any{} } selfPayload, err := json.Marshal(selfIdentity) if err != nil { @@ -52,7 +52,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch } routing := req.Routing if routing == nil { - routing = map[string]interface{}{} + routing = map[string]any{} } routingPayload, err := json.Marshal(routing) if err != nil { @@ -60,7 +60,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch } capabilities := req.Capabilities if capabilities == nil { - capabilities = map[string]interface{}{} + capabilities = map[string]any{} } capabilitiesPayload, err := json.Marshal(capabilities) if err != nil { @@ -132,7 +132,7 @@ func (s *Service) ResolveEffectiveConfig(ctx context.Context, botID string, chan if channelType == "" { return ChannelConfig{}, fmt.Errorf("channel type is required") } - if channelType == ChannelCLI || channelType == ChannelWeb { + if IsConfigless(channelType) { return ChannelConfig{ ID: channelType.String() + ":" + strings.TrimSpace(botID), BotID: strings.TrimSpace(botID), @@ -160,7 +160,7 @@ func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType if s.queries == nil { return nil, fmt.Errorf("channel queries not configured") } - if channelType == ChannelCLI || channelType == ChannelWeb { + if IsConfigless(channelType) { return []ChannelConfig{}, nil } rows, err := s.queries.ListBotChannelConfigsByType(ctx, channelType.String()) @@ -199,7 +199,7 @@ func (s *Service) GetUserConfig(ctx context.Context, actorUserID string, channel } return ChannelUserBinding{}, err } - config, err := decodeConfigMap(row.Config) + config, err := DecodeConfigMap(row.Config) if err != nil { return ChannelUserBinding{}, err } @@ -243,19 +243,44 @@ func (s *Service) GetChannelSession(ctx context.Context, sessionID string) (Chan } return ChannelSession{}, err } - return ChannelSession{ - SessionID: row.SessionID, - BotID: toUUIDString(row.BotID), - ChannelConfigID: toUUIDString(row.ChannelConfigID), - UserID: toUUIDString(row.UserID), - ContactID: toUUIDString(row.ContactID), - Platform: row.Platform, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), - }, nil + return normalizeChannelSession(row) } -func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string) error { +func (s *Service) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]ChannelSession, error) { + if s.queries == nil { + return nil, fmt.Errorf("channel queries not configured") + } + botID = strings.TrimSpace(botID) + platform = strings.TrimSpace(platform) + if botID == "" { + return nil, fmt.Errorf("bot id is required") + } + if platform == "" { + return nil, fmt.Errorf("platform is required") + } + pgBotID, err := parseUUID(botID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListChannelSessionsByBotPlatform(ctx, sqlc.ListChannelSessionsByBotPlatformParams{ + BotID: pgBotID, + Platform: platform, + }) + if err != nil { + return nil, err + } + items := make([]ChannelSession, 0, len(rows)) + for _, row := range rows { + item, err := normalizeChannelSession(row) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, nil +} + +func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { if s.queries == nil { return fmt.Errorf("channel queries not configured") } @@ -286,6 +311,14 @@ func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, bo } pgContactID = parsed } + payload := metadata + if payload == nil { + payload = map[string]any{} + } + metaBytes, err := json.Marshal(payload) + if err != nil { + return err + } _, err = s.queries.UpsertChannelSession(ctx, sqlc.UpsertChannelSessionParams{ SessionID: sessionID, BotID: botUUID, @@ -293,6 +326,15 @@ func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, bo UserID: pgUserID, ContactID: pgContactID, Platform: platform, + ReplyTarget: pgtype.Text{ + String: strings.TrimSpace(replyTarget), + Valid: strings.TrimSpace(replyTarget) != "", + }, + ThreadID: pgtype.Text{ + String: strings.TrimSpace(threadID), + Valid: strings.TrimSpace(threadID) != "", + }, + Metadata: metaBytes, }) return err } @@ -302,77 +344,31 @@ func (s *Service) ResolveUserBinding(ctx context.Context, channelType ChannelTyp if err != nil { return "", err } - switch channelType { - case ChannelTelegram: - for _, row := range rows { - cfg, err := parseTelegramUserConfig(row.Config) - if err != nil { - continue - } - if matchTelegramBinding(cfg, criteria) { - return row.UserID, nil - } - } - case ChannelFeishu: - for _, row := range rows { - cfg, err := parseFeishuUserConfig(row.Config) - if err != nil { - continue - } - if matchFeishuBinding(cfg, criteria) { - return row.UserID, nil - } - } - default: + if _, ok := GetChannelDescriptor(channelType); !ok { return "", fmt.Errorf("unsupported channel type: %s", channelType) } + for _, row := range rows { + if MatchUserBinding(channelType, row.Config, criteria) { + return row.UserID, nil + } + } return "", fmt.Errorf("channel user binding not found") } -type BindingCriteria struct { - Username string - UserID string - ChatID string - OpenID string -} - -func matchTelegramBinding(cfg TelegramUserConfig, criteria BindingCriteria) bool { - if criteria.ChatID != "" && cfg.ChatID != "" && criteria.ChatID == cfg.ChatID { - return true - } - if criteria.UserID != "" && cfg.UserID != "" && criteria.UserID == cfg.UserID { - return true - } - if criteria.Username != "" && cfg.Username != "" && strings.EqualFold(criteria.Username, cfg.Username) { - return true - } - return false -} - -func matchFeishuBinding(cfg FeishuUserConfig, criteria BindingCriteria) bool { - if criteria.OpenID != "" && cfg.OpenID != "" && criteria.OpenID == cfg.OpenID { - return true - } - if criteria.UserID != "" && cfg.UserID != "" && criteria.UserID == cfg.UserID { - return true - } - return false -} - func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { - credentials, err := decodeConfigMap(row.Credentials) + credentials, err := DecodeConfigMap(row.Credentials) if err != nil { return ChannelConfig{}, err } - selfIdentity, err := decodeConfigMap(row.SelfIdentity) + selfIdentity, err := DecodeConfigMap(row.SelfIdentity) if err != nil { return ChannelConfig{}, err } - routing, err := decodeConfigMap(row.Routing) + routing, err := DecodeConfigMap(row.Routing) if err != nil { return ChannelConfig{}, err } - capabilities, err := decodeConfigMap(row.Capabilities) + capabilities, err := DecodeConfigMap(row.Capabilities) if err != nil { return ChannelConfig{}, err } @@ -401,7 +397,7 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { } func normalizeChannelUserBindingRow(row sqlc.UserChannelBinding) (ChannelUserBinding, error) { - config, err := decodeConfigMap(row.Config) + config, err := DecodeConfigMap(row.Config) if err != nil { return ChannelUserBinding{}, err } @@ -416,7 +412,7 @@ func normalizeChannelUserBindingRow(row sqlc.UserChannelBinding) (ChannelUserBin } func normalizeChannelUserBindingListRow(row sqlc.UserChannelBinding) (ChannelUserBinding, error) { - config, err := decodeConfigMap(row.Config) + config, err := DecodeConfigMap(row.Config) if err != nil { return ChannelUserBinding{}, err } @@ -430,6 +426,26 @@ func normalizeChannelUserBindingListRow(row sqlc.UserChannelBinding) (ChannelUse }, nil } +func normalizeChannelSession(row sqlc.ChannelSession) (ChannelSession, error) { + metadata, err := DecodeConfigMap(row.Metadata) + if err != nil { + return ChannelSession{}, err + } + return ChannelSession{ + SessionID: row.SessionID, + BotID: toUUIDString(row.BotID), + ChannelConfigID: toUUIDString(row.ChannelConfigID), + UserID: toUUIDString(row.UserID), + ContactID: toUUIDString(row.ContactID), + Platform: row.Platform, + ReplyTarget: strings.TrimSpace(row.ReplyTarget.String), + ThreadID: strings.TrimSpace(row.ThreadID.String), + Metadata: metadata, + CreatedAt: timeFromPg(row.CreatedAt), + UpdatedAt: timeFromPg(row.UpdatedAt), + }, nil +} + func parseUUID(id string) (pgtype.UUID, error) { parsed, err := uuid.Parse(strings.TrimSpace(id)) if err != nil { diff --git a/internal/channel/target.go b/internal/channel/target.go new file mode 100644 index 00000000..67a88548 --- /dev/null +++ b/internal/channel/target.go @@ -0,0 +1,25 @@ +package channel + +import "strings" + +type TargetHint struct { + Example string `json:"example,omitempty"` + Label string `json:"label,omitempty"` +} + +type TargetSpec struct { + Format string `json:"format"` + Hints []TargetHint `json:"hints,omitempty"` +} + +func NormalizeTarget(channelType ChannelType, raw string) (string, bool) { + desc, ok := GetChannelDescriptor(channelType) + if !ok || desc.NormalizeTarget == nil { + return strings.TrimSpace(raw), false + } + normalized := strings.TrimSpace(desc.NormalizeTarget(raw)) + if normalized == "" { + return "", false + } + return normalized, true +} diff --git a/internal/channel/types.go b/internal/channel/types.go index 0669e86d..3b66934c 100644 --- a/internal/channel/types.go +++ b/internal/channel/types.go @@ -2,18 +2,12 @@ package channel import ( "fmt" + "strings" "time" ) type ChannelType string -const ( - ChannelTelegram ChannelType = "telegram" - ChannelFeishu ChannelType = "feishu" - ChannelCLI ChannelType = "cli" - ChannelWeb ChannelType = "web" -) - func ParseChannelType(raw string) (ChannelType, error) { normalized := normalizeChannelType(raw) if normalized == "" { @@ -25,15 +19,226 @@ func ParseChannelType(raw string) (ChannelType, error) { return normalized, nil } +type Identity struct { + ExternalID string + DisplayName string + Attributes map[string]string +} + +func (i Identity) Attribute(key string) string { + if i.Attributes == nil { + return "" + } + return strings.TrimSpace(i.Attributes[key]) +} + +type Conversation struct { + ID string + Type string + Name string + ThreadID string + Metadata map[string]any +} + +type InboundMessage struct { + Channel ChannelType + Message Message + BotID string + ReplyTarget string + SessionKey string + Sender Identity + Conversation Conversation + ReceivedAt time.Time + Source string + Metadata map[string]any +} + +// SessionID 结构: platform:bot_id:conversation_id[:sender_id] +func (m InboundMessage) SessionID() string { + if strings.TrimSpace(m.SessionKey) != "" { + return strings.TrimSpace(m.SessionKey) + } + senderID := strings.TrimSpace(m.Sender.ExternalID) + if senderID == "" { + senderID = strings.TrimSpace(m.Sender.DisplayName) + } + return GenerateSessionID(string(m.Channel), m.BotID, m.Conversation.ID, m.Conversation.Type, senderID) +} + +// GenerateSessionID 统一生成 SessionID 的逻辑 +func GenerateSessionID(platform, botID, conversationID, conversationType, senderID string) string { + parts := []string{platform, botID, conversationID} + // 如果是群聊,增加发送者 ID 以支持个人上下文 + ct := strings.ToLower(strings.TrimSpace(conversationType)) + if ct != "" && ct != "p2p" && ct != "private" { + senderID = strings.TrimSpace(senderID) + if senderID != "" { + parts = append(parts, senderID) + } + } + return strings.Join(parts, ":") +} + +type OutboundMessage struct { + Target string `json:"target"` + Message Message `json:"message"` +} + +type MessageFormat string + +const ( + MessageFormatPlain MessageFormat = "plain" + MessageFormatMarkdown MessageFormat = "markdown" + MessageFormatRich MessageFormat = "rich" +) + +type MessagePartType string + +const ( + MessagePartText MessagePartType = "text" + MessagePartLink MessagePartType = "link" + MessagePartCodeBlock MessagePartType = "code_block" + MessagePartMention MessagePartType = "mention" + MessagePartEmoji MessagePartType = "emoji" +) + +type MessageTextStyle string + +const ( + MessageStyleBold MessageTextStyle = "bold" + MessageStyleItalic MessageTextStyle = "italic" + MessageStyleStrikethrough MessageTextStyle = "strikethrough" + MessageStyleCode MessageTextStyle = "code" +) + +type MessagePart struct { + Type MessagePartType `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []MessageTextStyle `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + UserID string `json:"user_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type AttachmentType string + +const ( + AttachmentImage AttachmentType = "image" + AttachmentAudio AttachmentType = "audio" + AttachmentVideo AttachmentType = "video" + AttachmentVoice AttachmentType = "voice" + AttachmentFile AttachmentType = "file" + AttachmentGIF AttachmentType = "gif" +) + +type Attachment struct { + Type AttachmentType `json:"type"` + URL string `json:"url,omitempty"` + Name string `json:"name,omitempty"` + Size int64 `json:"size,omitempty"` + Mime string `json:"mime,omitempty"` + DurationMs int64 `json:"duration_ms,omitempty"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + ThumbnailURL string `json:"thumbnail_url,omitempty"` + Caption string `json:"caption,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type Action struct { + Type string `json:"type"` + Label string `json:"label,omitempty"` + Value string `json:"value,omitempty"` + URL string `json:"url,omitempty"` +} + +type ThreadRef struct { + ID string `json:"id"` +} + +type ReplyRef struct { + Target string `json:"target,omitempty"` + MessageID string `json:"message_id,omitempty"` +} + +type Message struct { + ID string `json:"id,omitempty"` + Format MessageFormat `json:"format,omitempty"` + Text string `json:"text,omitempty"` + Parts []MessagePart `json:"parts,omitempty"` + Attachments []Attachment `json:"attachments,omitempty"` + Actions []Action `json:"actions,omitempty"` + Thread *ThreadRef `json:"thread,omitempty"` + Reply *ReplyRef `json:"reply,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +func (m Message) IsEmpty() bool { + return strings.TrimSpace(m.Text) == "" && + len(m.Parts) == 0 && + len(m.Attachments) == 0 && + len(m.Actions) == 0 +} + +func (m Message) PlainText() string { + if strings.TrimSpace(m.Text) != "" { + return strings.TrimSpace(m.Text) + } + if len(m.Parts) == 0 { + return "" + } + lines := make([]string, 0, len(m.Parts)) + for _, part := range m.Parts { + switch part.Type { + case MessagePartText, MessagePartLink, MessagePartCodeBlock, MessagePartMention, MessagePartEmoji: + value := strings.TrimSpace(part.Text) + if value == "" && part.Type == MessagePartLink { + value = strings.TrimSpace(part.URL) + } + if value == "" && part.Type == MessagePartEmoji { + value = strings.TrimSpace(part.Emoji) + } + if value == "" { + continue + } + lines = append(lines, value) + default: + continue + } + } + return strings.Join(lines, "\n") +} + +type BindingCriteria struct { + ExternalID string + Attributes map[string]string +} + +func (c BindingCriteria) Attribute(key string) string { + if c.Attributes == nil { + return "" + } + return strings.TrimSpace(c.Attributes[key]) +} + +func BindingCriteriaFromIdentity(identity Identity) BindingCriteria { + return BindingCriteria{ + ExternalID: strings.TrimSpace(identity.ExternalID), + Attributes: identity.Attributes, + } +} + type ChannelConfig struct { ID string BotID string ChannelType ChannelType - Credentials map[string]interface{} + Credentials map[string]any ExternalIdentity string - SelfIdentity map[string]interface{} - Routing map[string]interface{} - Capabilities map[string]interface{} + SelfIdentity map[string]any + Routing map[string]any + Capabilities map[string]any Status string VerifiedAt time.Time CreatedAt time.Time @@ -44,23 +249,23 @@ type ChannelUserBinding struct { ID string ChannelType ChannelType UserID string - Config map[string]interface{} + Config map[string]any CreatedAt time.Time UpdatedAt time.Time } type UpsertConfigRequest struct { - Credentials map[string]interface{} `json:"credentials"` - ExternalIdentity string `json:"external_identity,omitempty"` - SelfIdentity map[string]interface{} `json:"self_identity,omitempty"` - Routing map[string]interface{} `json:"routing,omitempty"` - Capabilities map[string]interface{} `json:"capabilities,omitempty"` - Status string `json:"status,omitempty"` - VerifiedAt *time.Time `json:"verified_at,omitempty"` + Credentials map[string]any `json:"credentials"` + ExternalIdentity string `json:"external_identity,omitempty"` + SelfIdentity map[string]any `json:"self_identity,omitempty"` + Routing map[string]any `json:"routing,omitempty"` + Capabilities map[string]any `json:"capabilities,omitempty"` + Status string `json:"status,omitempty"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` } type UpsertUserConfigRequest struct { - Config map[string]interface{} `json:"config"` + Config map[string]any `json:"config"` } type ChannelSession struct { @@ -70,12 +275,15 @@ type ChannelSession struct { UserID string ContactID string Platform string + ReplyTarget string + ThreadID string + Metadata map[string]any CreatedAt time.Time UpdatedAt time.Time } type SendRequest struct { - To string `json:"to"` - ToUserID string `json:"to_user_id"` - Message string `json:"message"` + Target string `json:"target,omitempty"` + UserID string `json:"user_id,omitempty"` + Message Message `json:"message"` } diff --git a/internal/chat/assistant_output.go b/internal/chat/assistant_output.go new file mode 100644 index 00000000..a36f15cb --- /dev/null +++ b/internal/chat/assistant_output.go @@ -0,0 +1,52 @@ +package chat + +import "strings" + +type AssistantOutput struct { + Content string + Parts []ContentPart +} + +func ExtractAssistantOutputs(messages []GatewayMessage) []AssistantOutput { + if len(messages) == 0 { + return nil + } + outputs := make([]AssistantOutput, 0, len(messages)) + for _, msg := range messages { + normalized := normalizeGatewayMessage(msg) + for _, item := range normalized { + if item.Role != "assistant" { + continue + } + content := strings.TrimSpace(item.Content) + parts := make([]ContentPart, 0, len(item.Parts)) + for _, part := range item.Parts { + if !hasContentPartValue(part) { + continue + } + parts = append(parts, part) + } + if content == "" && len(parts) == 0 { + continue + } + outputs = append(outputs, AssistantOutput{ + Content: content, + Parts: parts, + }) + } + } + return outputs +} + +func hasContentPartValue(part ContentPart) bool { + if strings.TrimSpace(part.Text) != "" { + return true + } + if strings.TrimSpace(part.URL) != "" { + return true + } + if strings.TrimSpace(part.Emoji) != "" { + return true + } + return false +} diff --git a/internal/chat/normalize.go b/internal/chat/normalize.go index 958d9c8c..8f3962d9 100644 --- a/internal/chat/normalize.go +++ b/internal/chat/normalize.go @@ -32,7 +32,7 @@ func normalizeGatewayMessage(msg GatewayMessage) []NormalizedMessage { var textParts []ContentPart var toolResults []toolResult - if rawCalls, ok := msg["tool_calls"].([]interface{}); ok { + if rawCalls, ok := msg["tool_calls"].([]any); ok { for _, raw := range rawCalls { if call := normalizeToolCall(raw); call.Function.Name != "" { toolCalls = append(toolCalls, call) @@ -52,16 +52,16 @@ func normalizeGatewayMessage(msg GatewayMessage) []NormalizedMessage { } return appendToolResults([]NormalizedMessage{normalized}, toolResults) } - case []interface{}: + case []any: for _, part := range content { switch p := part.(type) { case string: if strings.TrimSpace(p) != "" { textParts = append(textParts, ContentPart{Type: "text", Text: p}) } - case map[string]interface{}: - if text := normalizeTextPart(p); text != "" { - textParts = append(textParts, ContentPart{Type: "text", Text: text}) + case map[string]any: + if contentPart, ok := normalizeContentPart(p); ok { + textParts = append(textParts, contentPart) continue } if call := normalizeToolCall(p); call.Function.Name != "" { @@ -81,9 +81,9 @@ func normalizeGatewayMessage(msg GatewayMessage) []NormalizedMessage { } } } - case map[string]interface{}: - if text := normalizeTextPart(content); text != "" { - textParts = append(textParts, ContentPart{Type: "text", Text: text}) + case map[string]any: + if contentPart, ok := normalizeContentPart(content); ok { + textParts = append(textParts, contentPart) } else if encoded := toJSONString(content); encoded != "" { textParts = append(textParts, ContentPart{Type: "text", Text: encoded}) } @@ -126,7 +126,7 @@ func appendToolResults(messages []NormalizedMessage, results []toolResult) []Nor return messages } -func normalizeTextPart(part map[string]interface{}) string { +func normalizeTextPart(part map[string]any) string { if part == nil { return "" } @@ -141,9 +141,60 @@ func normalizeTextPart(part map[string]interface{}) string { return "" } -func normalizeToolCall(part interface{}) ToolCall { +func normalizeContentPart(part map[string]any) (ContentPart, bool) { + if part == nil { + return ContentPart{}, false + } + partType := getString(part["type"]) + if partType == "" { + partType = "text" + } + if partType == "tool_use" || partType == "tool-call" || partType == "function_call" || partType == "tool_result" || partType == "tool-result" { + return ContentPart{}, false + } + text := normalizeTextPart(part) + url := getString(part["url"]) + emoji := getString(part["emoji"]) + if strings.TrimSpace(text) == "" && strings.TrimSpace(url) == "" && strings.TrimSpace(emoji) == "" { + return ContentPart{}, false + } + styles := normalizeStringSlice(part["styles"]) + metadata := map[string]any{} + if raw, ok := part["metadata"].(map[string]any); ok && raw != nil { + metadata = raw + } + return ContentPart{ + Type: partType, + Text: text, + URL: url, + Styles: styles, + Language: getString(part["language"]), + UserID: getString(part["user_id"]), + Emoji: emoji, + Metadata: metadata, + }, true +} + +func normalizeStringSlice(raw any) []string { + switch value := raw.(type) { + case []string: + return value + case []any: + items := make([]string, 0, len(value)) + for _, entry := range value { + if str, ok := entry.(string); ok && strings.TrimSpace(str) != "" { + items = append(items, strings.TrimSpace(str)) + } + } + return items + default: + return nil + } +} + +func normalizeToolCall(part any) ToolCall { switch value := part.(type) { - case map[string]interface{}: + case map[string]any: if valueType, _ := value["type"].(string); valueType == "tool_use" || valueType == "tool-call" || valueType == "function_call" { return ToolCall{ ID: getString(value["id"]), @@ -154,7 +205,7 @@ func normalizeToolCall(part interface{}) ToolCall { }, } } - if fc, ok := value["function_call"].(map[string]interface{}); ok { + if fc, ok := value["function_call"].(map[string]any); ok { return ToolCall{ ID: getString(value["id"]), Type: "function", @@ -164,7 +215,7 @@ func normalizeToolCall(part interface{}) ToolCall { }, } } - if fc, ok := value["functionCall"].(map[string]interface{}); ok { + if fc, ok := value["functionCall"].(map[string]any); ok { return ToolCall{ ID: getString(value["id"]), Type: "function", @@ -174,7 +225,7 @@ func normalizeToolCall(part interface{}) ToolCall { }, } } - if fn, ok := value["function"].(map[string]interface{}); ok { + if fn, ok := value["function"].(map[string]any); ok { return ToolCall{ ID: getString(value["id"]), Type: "function", @@ -188,7 +239,7 @@ func normalizeToolCall(part interface{}) ToolCall { return ToolCall{} } -func normalizeToolResult(part map[string]interface{}) toolResult { +func normalizeToolResult(part map[string]any) toolResult { if part == nil { return toolResult{} } @@ -198,13 +249,13 @@ func normalizeToolResult(part map[string]interface{}) toolResult { Content: normalizeToolResultContent(part["content"], part["result"], part["output"]), } } - if raw, ok := part["toolResult"].(map[string]interface{}); ok { + if raw, ok := part["toolResult"].(map[string]any); ok { return toolResult{ ToolCallID: firstString(raw["toolUseId"], raw["tool_call_id"], raw["id"]), Content: normalizeToolResultContent(raw["content"], raw["output"], raw["result"]), } } - if raw, ok := part["functionResponse"].(map[string]interface{}); ok { + if raw, ok := part["functionResponse"].(map[string]any); ok { return toolResult{ ToolCallID: firstString(raw["id"]), Content: normalizeToolResultContent(raw["response"], raw["output"], raw["result"]), @@ -213,7 +264,7 @@ func normalizeToolResult(part map[string]interface{}) toolResult { return toolResult{} } -func normalizeToolResultContent(values ...interface{}) string { +func normalizeToolResultContent(values ...any) string { for _, value := range values { if value == nil { continue @@ -223,7 +274,7 @@ func normalizeToolResultContent(values ...interface{}) string { if strings.TrimSpace(v) != "" { return v } - case []interface{}: + case []any: parts := make([]string, 0, len(v)) for _, item := range v { switch itemValue := item.(type) { @@ -231,7 +282,7 @@ func normalizeToolResultContent(values ...interface{}) string { if strings.TrimSpace(itemValue) != "" { parts = append(parts, itemValue) } - case map[string]interface{}: + case map[string]any: if text := normalizeTextPart(itemValue); text != "" { parts = append(parts, text) } else if encoded := toJSONString(itemValue); encoded != "" { @@ -246,7 +297,7 @@ func normalizeToolResultContent(values ...interface{}) string { if len(parts) > 0 { return strings.Join(parts, "\n") } - case map[string]interface{}: + case map[string]any: if text := normalizeTextPart(v); text != "" { return text } @@ -271,9 +322,9 @@ func toGatewayMessages(messages []NormalizedMessage) []GatewayMessage { if strings.TrimSpace(msg.Content) != "" { item["content"] = msg.Content } else if len(msg.Parts) > 0 { - parts := make([]map[string]interface{}, 0, len(msg.Parts)) + parts := make([]map[string]any, 0, len(msg.Parts)) for _, part := range msg.Parts { - entry := map[string]interface{}{ + entry := map[string]any{ "type": part.Type, } if strings.TrimSpace(part.Text) != "" { @@ -284,14 +335,14 @@ func toGatewayMessages(messages []NormalizedMessage) []GatewayMessage { item["content"] = parts } if len(msg.ToolCalls) > 0 { - payload := make([]map[string]interface{}, 0, len(msg.ToolCalls)) + payload := make([]map[string]any, 0, len(msg.ToolCalls)) for _, call := range msg.ToolCalls { if strings.TrimSpace(call.Function.Name) == "" { continue } - entry := map[string]interface{}{ + entry := map[string]any{ "type": "function", - "function": map[string]interface{}{ + "function": map[string]any{ "name": call.Function.Name, "arguments": call.Function.Arguments, }, @@ -316,14 +367,14 @@ func toGatewayMessages(messages []NormalizedMessage) []GatewayMessage { return converted } -func getString(value interface{}) string { +func getString(value any) string { if raw, ok := value.(string); ok { return raw } return "" } -func firstString(values ...interface{}) string { +func firstString(values ...any) string { for _, value := range values { if raw, ok := value.(string); ok && strings.TrimSpace(raw) != "" { return raw @@ -332,7 +383,7 @@ func firstString(values ...interface{}) string { return "" } -func toJSONString(values ...interface{}) string { +func toJSONString(values ...any) string { for _, value := range values { if value == nil { continue diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index f1fd77e6..b262398b 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -12,7 +12,6 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" "github.com/memohai/memoh/internal/db/sqlc" @@ -24,6 +23,7 @@ import ( const defaultMaxContextMinutes = 24 * 60 +// Resolver orchestrates chat with the agent gateway. type Resolver struct { modelsService *models.Service queries *sqlc.Queries @@ -69,6 +69,48 @@ func NewResolver(log *slog.Logger, modelsService *models.Service, queries *sqlc. } } +// ---------- gateway payload types ---------- + +type gatewayModelConfig struct { + ModelID string `json:"modelId"` + ClientType string `json:"clientType"` + Input []string `json:"input"` + APIKey string `json:"apiKey"` + BaseURL string `json:"baseUrl"` +} + +type gatewayIdentity struct { + BotID string `json:"botId"` + SessionID string `json:"sessionId"` + ContainerID string `json:"containerId"` + ContactID string `json:"contactId"` + ContactName string `json:"contactName"` + ContactAlias string `json:"contactAlias,omitempty"` + UserID string `json:"userId,omitempty"` + CurrentPlatform string `json:"currentPlatform,omitempty"` + ReplyTarget string `json:"replyTarget,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` +} + +type agentGatewayRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Platforms []string `json:"platforms"` + CurrentPlatform string `json:"currentPlatform"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []GatewayMessage `json:"messages"` + Skills []string `json:"skills"` + Query string `json:"query"` + Identity gatewayIdentity `json:"identity"` +} + +type agentGatewayResponse struct { + Messages []GatewayMessage `json:"messages"` + Skills []string `json:"skills"` +} + +// ---------- Chat ---------- + func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { if strings.TrimSpace(req.Query) == "" { return ChatResponse{}, fmt.Errorf("query is required") @@ -122,27 +164,39 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err } messages = sanitizeGatewayMessages(messages) messages = normalizeGatewayMessagesForModel(messages) - useSkills := normalizeSkills(append(historySkills, req.UseSkills...)) + skills := normalizeSkills(append(historySkills, req.Skills...)) + + containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) payload := agentGatewayRequest{ - APIKey: provider.ApiKey, - BaseURL: provider.BaseUrl, - Model: chatModel.ModelID, - ClientType: clientType, - Locale: req.Locale, - Language: req.Language, - MaxSteps: req.MaxSteps, - MaxContextLoadTime: normalizeMaxContextLoad(maxContextLoadTime), - Platforms: req.Platforms, - CurrentPlatform: req.CurrentPlatform, - Messages: messages, - Query: req.Query, - Skills: req.Skills, - UseSkills: useSkills, - ToolContext: req.ToolContext, - ToolChoice: req.ToolChoice, + Model: gatewayModelConfig{ + ModelID: chatModel.ModelID, + ClientType: clientType, + Input: chatModel.Input, + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + }, + ActiveContextTime: normalizeMaxContextLoad(maxContextLoadTime), + Platforms: req.Platforms, + CurrentPlatform: req.CurrentPlatform, + AllowedActions: req.AllowedActions, + Messages: messages, + Skills: skills, + Query: req.Query, + Identity: gatewayIdentity{ + BotID: req.BotID, + SessionID: req.SessionID, + ContainerID: containerID, + ContactID: defaultString(req.ContactID, req.UserID, req.BotID), + ContactName: defaultString(req.ContactName, "User"), + ContactAlias: req.ContactAlias, + UserID: req.UserID, + CurrentPlatform: req.CurrentPlatform, + ReplyTarget: req.ReplyTarget, + SessionToken: req.SessionToken, + }, } - payload.Language = language + _ = language // language is embedded in system prompt by the gateway resp, err := r.postChat(ctx, payload, req.Token) if err != nil { @@ -165,6 +219,8 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err }, nil } +// ---------- TriggerSchedule ---------- + func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, schedule SchedulePayload, token string) error { if strings.TrimSpace(botID) == "" { return fmt.Errorf("bot id is required") @@ -177,8 +233,6 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, schedule S BotID: botID, SessionID: "schedule:" + schedule.ID, Query: schedule.Command, - Locale: "", - Language: "", } settings, err := r.loadUserSettings(ctx, "") if err != nil { @@ -193,7 +247,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, schedule S return err } - maxContextLoadTime, language, err := r.loadBotSettings(ctx, botID) + maxContextLoadTime, _, err := r.loadBotSettings(ctx, botID) if err != nil { return err } @@ -206,26 +260,31 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, schedule S if err != nil { return err } - useSkills := normalizeSkills(historySkills) + skills := normalizeSkills(historySkills) + containerID := r.resolveContainerID(ctx, botID, "") - payload := agentGatewayScheduleRequest{ - APIKey: provider.ApiKey, - BaseURL: provider.BaseUrl, - Model: chatModel.ModelID, - ClientType: clientType, - Locale: "", - Language: language, - MaxSteps: 0, - MaxContextLoadTime: normalizeMaxContextLoad(maxContextLoadTime), - Platforms: nil, - CurrentPlatform: "", - Messages: messages, - Query: schedule.Command, - Schedule: schedule, - UseSkills: useSkills, + payload := agentGatewayRequest{ + Model: gatewayModelConfig{ + ModelID: chatModel.ModelID, + ClientType: clientType, + Input: chatModel.Input, + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + }, + ActiveContextTime: normalizeMaxContextLoad(maxContextLoadTime), + Messages: messages, + Skills: skills, + Query: schedule.Command, + Identity: gatewayIdentity{ + BotID: botID, + SessionID: req.SessionID, + ContainerID: containerID, + ContactID: botID, + ContactName: "Scheduler", + }, } - resp, err := r.postSchedule(ctx, payload, token) + resp, err := r.postChat(ctx, payload, token) if err != nil { return err } @@ -239,6 +298,8 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, schedule S return nil } +// ---------- StreamChat ---------- + func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { chunkChan := make(chan StreamChunk) errChan := make(chan error, 1) @@ -308,27 +369,38 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre } messages = sanitizeGatewayMessages(messages) messages = normalizeGatewayMessagesForModel(messages) - useSkills := normalizeSkills(append(historySkills, req.UseSkills...)) + skills := normalizeSkills(append(historySkills, req.Skills...)) + containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) payload := agentGatewayRequest{ - APIKey: provider.ApiKey, - BaseURL: provider.BaseUrl, - Model: chatModel.ModelID, - ClientType: clientType, - Locale: req.Locale, - Language: req.Language, - MaxSteps: req.MaxSteps, - MaxContextLoadTime: normalizeMaxContextLoad(maxContextLoadTime), - Platforms: req.Platforms, - CurrentPlatform: req.CurrentPlatform, - Messages: messages, - Query: req.Query, - Skills: req.Skills, - UseSkills: useSkills, - ToolContext: req.ToolContext, - ToolChoice: req.ToolChoice, + Model: gatewayModelConfig{ + ModelID: chatModel.ModelID, + ClientType: clientType, + Input: chatModel.Input, + APIKey: provider.ApiKey, + BaseURL: provider.BaseUrl, + }, + ActiveContextTime: normalizeMaxContextLoad(maxContextLoadTime), + Platforms: req.Platforms, + CurrentPlatform: req.CurrentPlatform, + AllowedActions: req.AllowedActions, + Messages: messages, + Skills: skills, + Query: req.Query, + Identity: gatewayIdentity{ + BotID: req.BotID, + SessionID: req.SessionID, + ContainerID: containerID, + ContactID: defaultString(req.ContactID, req.UserID, req.BotID), + ContactName: defaultString(req.ContactName, "User"), + ContactAlias: req.ContactAlias, + UserID: req.UserID, + CurrentPlatform: req.CurrentPlatform, + ReplyTarget: req.ReplyTarget, + SessionToken: req.SessionToken, + }, } - payload.Language = language + _ = language if err := r.streamChat(ctx, payload, req.BotID, req.SessionID, req.Query, req.Token, chunkChan); err != nil { errChan <- err @@ -339,54 +411,15 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre return chunkChan, errChan } -type agentGatewayRequest struct { - APIKey string `json:"apiKey"` - BaseURL string `json:"baseUrl"` - Model string `json:"model"` - ClientType string `json:"clientType"` - Locale string `json:"locale,omitempty"` - Language string `json:"language,omitempty"` - MaxSteps int `json:"maxSteps,omitempty"` - MaxContextLoadTime int `json:"maxContextLoadTime"` - Platforms []string `json:"platforms,omitempty"` - CurrentPlatform string `json:"currentPlatform,omitempty"` - Messages []GatewayMessage `json:"messages"` - Query string `json:"query"` - Skills []AgentSkill `json:"skills,omitempty"` - UseSkills []string `json:"useSkills,omitempty"` - ToolContext *ToolContext `json:"toolContext,omitempty"` - ToolChoice interface{} `json:"toolChoice,omitempty"` -} - -type agentGatewayScheduleRequest struct { - APIKey string `json:"apiKey"` - BaseURL string `json:"baseUrl"` - Model string `json:"model"` - ClientType string `json:"clientType"` - Locale string `json:"locale,omitempty"` - Language string `json:"language,omitempty"` - MaxSteps int `json:"maxSteps,omitempty"` - MaxContextLoadTime int `json:"maxContextLoadTime"` - Platforms []string `json:"platforms,omitempty"` - CurrentPlatform string `json:"currentPlatform,omitempty"` - Messages []GatewayMessage `json:"messages"` - Query string `json:"query"` - Schedule SchedulePayload `json:"schedule"` - Skills []AgentSkill `json:"skills,omitempty"` - UseSkills []string `json:"useSkills,omitempty"` -} - -type agentGatewayResponse struct { - Messages []GatewayMessage `json:"messages"` - Skills []string `json:"skills"` -} +// ---------- HTTP helpers ---------- func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest, token string) (agentGatewayResponse, error) { body, err := json.Marshal(payload) if err != nil { return agentGatewayResponse{}, err } - url := r.gatewayBaseURL + "/chat" + url := r.gatewayBaseURL + "/chat/" + r.logger.Info("gateway request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return agentGatewayResponse{}, err @@ -408,6 +441,11 @@ func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest, to } if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway request failed", + slog.String("url", url), + slog.Int("status", resp.StatusCode), + slog.String("body_prefix", truncate(string(respBody), 300)), + ) return agentGatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) } @@ -419,44 +457,6 @@ func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest, to return parsed, nil } -func (r *Resolver) postSchedule(ctx context.Context, payload agentGatewayScheduleRequest, token string) (agentGatewayResponse, error) { - body, err := json.Marshal(payload) - if err != nil { - return agentGatewayResponse{}, err - } - url := r.gatewayBaseURL + "/chat/schedule" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { - return agentGatewayResponse{}, err - } - req.Header.Set("Content-Type", "application/json") - if strings.TrimSpace(token) != "" { - req.Header.Set("Authorization", token) - } - - resp, err := r.httpClient.Do(req) - if err != nil { - return agentGatewayResponse{}, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return agentGatewayResponse{}, err - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return agentGatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) - } - - parsed, err := parseAgentGatewayResponse(respBody) - if err != nil { - r.logger.Error("failed to parse schedule gateway response", slog.String("body", string(respBody)), slog.Any("error", err)) - return agentGatewayResponse{}, err - } - return parsed, nil -} - func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, botID, sessionID, query, token string, chunkChan chan<- StreamChunk) error { body, err := json.Marshal(payload) if err != nil { @@ -524,6 +524,26 @@ func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, return nil } +// ---------- container resolution ---------- + +func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string { + if strings.TrimSpace(explicit) != "" { + return explicit + } + if r.queries != nil { + pgBotID, err := parseUUID(botID) + if err == nil { + row, err := r.queries.GetContainerByBotID(ctx, pgBotID) + if err == nil && strings.TrimSpace(row.ContainerID) != "" { + return row.ContainerID + } + } + } + return "mcp-" + botID +} + +// ---------- history helpers ---------- + func (r *Resolver) loadHistoryMessages(ctx context.Context, botID, sessionID string, maxContextLoadTime int) ([]GatewayMessage, error) { if r.historyService == nil { return nil, fmt.Errorf("history service not configured") @@ -567,6 +587,8 @@ func (r *Resolver) loadHistorySkills(ctx context.Context, botID, sessionID strin return normalizeSkills(combined), nil } +// ---------- store helpers ---------- + func (r *Resolver) storeHistory(ctx context.Context, botID, sessionID, query string, responseMessages []GatewayMessage, skills []string) error { if r.historyService == nil { return fmt.Errorf("history service not configured") @@ -581,15 +603,19 @@ func (r *Resolver) storeHistory(ctx context.Context, botID, sessionID, query str if strings.TrimSpace(query) == "" && len(responseMessages) == 0 { return nil } - messages := make([]map[string]interface{}, 0, len(responseMessages)) + messages := make([]map[string]any, 0, len(responseMessages)) for _, msg := range responseMessages { if msg == nil { continue } - messages = append(messages, map[string]interface{}(msg)) + messages = append(messages, map[string]any(msg)) + } + metadata := map[string]any{ + "query": strings.TrimSpace(query), } _, err := r.historyService.Create(ctx, botID, trimmedSession, history.CreateRequest{ Messages: messages, + Metadata: metadata, Skills: skills, }) return err @@ -642,12 +668,23 @@ func (r *Resolver) tryStoreFromStreamPayload(ctx context.Context, botID, session } } - // Case 2: data: {"type":"done","data":{messages:[...]}} + // Case 2: data: {"type":"agent_end","messages":[...],"skills":[...]} var envelope struct { - Type string `json:"type"` - Data json.RawMessage `json:"data"` + Type string `json:"type"` + Data json.RawMessage `json:"data"` + Messages json.RawMessage `json:"messages"` + Skills []string `json:"skills"` } if err := json.Unmarshal([]byte(data), &envelope); err == nil { + if envelope.Type == "agent_end" { + // agent_end with inline messages + if len(envelope.Messages) > 0 { + if parsed, ok := parseGatewayResponseFromRaw(envelope.Messages, envelope.Skills); ok { + parsed.Messages = normalizeGatewayMessages(parsed.Messages) + return r.storeRound(ctx, botID, sessionID, query, parsed.Messages, parsed.Skills) + } + } + } if envelope.Type == "done" && len(envelope.Data) > 0 { if parsed, ok := parseGatewayResponse(envelope.Data); ok { parsed.Messages = normalizeGatewayMessages(parsed.Messages) @@ -675,10 +712,27 @@ func parseGatewayResponse(payload []byte) (agentGatewayResponse, bool) { return parsed, true } +func parseGatewayResponseFromRaw(messagesRaw json.RawMessage, skills []string) (agentGatewayResponse, bool) { + var rawMessages []json.RawMessage + if err := json.Unmarshal(messagesRaw, &rawMessages); err != nil { + return agentGatewayResponse{}, false + } + messages := make([]GatewayMessage, 0, len(rawMessages)) + for _, rawMsg := range rawMessages { + var msg map[string]any + if err := json.Unmarshal(rawMsg, &msg); err != nil { + continue + } + messages = append(messages, GatewayMessage(msg)) + } + if len(messages) == 0 { + return agentGatewayResponse{}, false + } + return agentGatewayResponse{Messages: messages, Skills: skills}, true +} + // parseAgentGatewayResponse parses the agent gateway response with flexible message handling. -// It can handle various message formats from different AI SDK versions. func parseAgentGatewayResponse(payload []byte) (agentGatewayResponse, error) { - // Use json.RawMessage to handle flexible message formats var raw struct { Messages []json.RawMessage `json:"messages"` Skills []string `json:"skills"` @@ -689,20 +743,17 @@ func parseAgentGatewayResponse(payload []byte) (agentGatewayResponse, error) { messages := make([]GatewayMessage, 0, len(raw.Messages)) for _, rawMsg := range raw.Messages { - // Try parsing as object - var msg map[string]interface{} + var msg map[string]any if err := json.Unmarshal(rawMsg, &msg); err != nil { - // If it's an array, try to extract messages from it - var arr []interface{} + var arr []any if err := json.Unmarshal(rawMsg, &arr); err == nil { for _, item := range arr { - if m, ok := item.(map[string]interface{}); ok { + if m, ok := item.(map[string]any); ok { messages = append(messages, GatewayMessage(m)) } } continue } - // Skip unparseable messages continue } messages = append(messages, GatewayMessage(msg)) @@ -724,6 +775,154 @@ func (r *Resolver) storeRound(ctx context.Context, botID, sessionID, query strin return true, nil } +// ---------- model selection ---------- + +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, settings userSettings) (models.GetResponse, sqlc.LlmProvider, error) { + if r.modelsService == nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + } + modelID := strings.TrimSpace(req.Model) + providerFilter := strings.TrimSpace(req.Provider) + + if modelID != "" && providerFilter == "" { + model, err := r.modelsService.GetByModelID(ctx, modelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + if model.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + } + provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, provider, nil + } + + if providerFilter == "" && modelID == "" && strings.TrimSpace(settings.ChatModelID) != "" { + selected, err := r.modelsService.GetByModelID(ctx, settings.ChatModelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model not found: %w", err) + } + if selected.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model is not a chat model") + } + provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return selected, provider, nil + } + + var candidates []models.GetResponse + var err error + if providerFilter != "" { + candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) + } else { + candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) + } + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + + filtered := make([]models.GetResponse, 0, len(candidates)) + for _, model := range candidates { + if model.Type != models.ModelTypeChat { + continue + } + filtered = append(filtered, model) + } + if len(filtered) == 0 { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available") + } + + if modelID != "" { + for _, model := range filtered { + if model.ModelID == modelID { + provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return model, provider, nil + } + } + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found") + } + + selected := filtered[0] + provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return selected, provider, nil +} + +// ---------- settings helpers ---------- + +func normalizeMaxContextLoad(value int) int { + if value <= 0 { + return defaultMaxContextMinutes + } + return value +} + +func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (userSettings, error) { + defaults := userSettings{ + MaxContextLoadTime: defaultMaxContextMinutes, + Language: settings.DefaultLanguage, + } + if r.settingsService == nil || strings.TrimSpace(userID) == "" { + return defaults, nil + } + settingsRow, err := r.settingsService.Get(ctx, userID) + if err != nil { + return userSettings{}, err + } + maxLoad := settingsRow.MaxContextLoadTime + if maxLoad <= 0 { + maxLoad = defaultMaxContextMinutes + } + language := strings.TrimSpace(settingsRow.Language) + if language == "" || language == "auto" { + language = settings.DefaultLanguage + } + return userSettings{ + ChatModelID: strings.TrimSpace(settingsRow.ChatModelID), + MemoryModelID: strings.TrimSpace(settingsRow.MemoryModelID), + EmbeddingModelID: strings.TrimSpace(settingsRow.EmbeddingModelID), + MaxContextLoadTime: maxLoad, + Language: language, + }, nil +} + +func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (int, string, error) { + if r.settingsService == nil { + return settings.DefaultMaxContextLoadTime, settings.DefaultLanguage, nil + } + settingsRow, err := r.settingsService.GetBot(ctx, botID) + if err != nil { + return 0, "", err + } + return settingsRow.MaxContextLoadTime, settingsRow.Language, nil +} + +// ---------- utility ---------- + +func normalizeClientType(clientType string) (string, error) { + switch strings.ToLower(strings.TrimSpace(clientType)) { + case "openai": + return "openai", nil + case "openai-compat": + return "openai", nil + case "anthropic": + return "anthropic", nil + case "google": + return "google", nil + default: + return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType) + } +} + func normalizeSkills(skills []string) []string { seen := map[string]struct{}{} normalized := make([]string, 0, len(skills)) @@ -832,13 +1031,13 @@ func isMeaningfulGatewayMessage(msg GatewayMessage) bool { return false } -func isEmptyValue(value interface{}) bool { +func isEmptyValue(value any) bool { switch v := value.(type) { case nil: return true case string: return strings.TrimSpace(v) == "" - case []interface{}: + case []any: if len(v) == 0 { return true } @@ -848,7 +1047,7 @@ func isEmptyValue(value interface{}) bool { } } return true - case map[string]interface{}: + case map[string]any: if len(v) == 0 { return true } @@ -863,155 +1062,38 @@ func isEmptyValue(value interface{}) bool { } } -func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, settings userSettings) (models.GetResponse, sqlc.LlmProvider, error) { - if r.modelsService == nil { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") +func truncate(s string, n int) string { + if len(s) <= n { + return s } - modelID := strings.TrimSpace(req.Model) - providerFilter := strings.TrimSpace(req.Provider) - - if modelID != "" && providerFilter == "" { - model, err := r.modelsService.GetByModelID(ctx, modelID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - if model.Type != models.ModelTypeChat { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") - } - provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return model, provider, nil - } - - if providerFilter == "" && modelID == "" && strings.TrimSpace(settings.ChatModelID) != "" { - selected, err := r.modelsService.GetByModelID(ctx, settings.ChatModelID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model not found: %w", err) - } - if selected.Type != models.ModelTypeChat { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model is not a chat model") - } - provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return selected, provider, nil - } - - var candidates []models.GetResponse - var err error - if providerFilter != "" { - candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) - } else { - candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) - } - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - - filtered := make([]models.GetResponse, 0, len(candidates)) - for _, model := range candidates { - if model.Type != models.ModelTypeChat { - continue - } - filtered = append(filtered, model) - } - if len(filtered) == 0 { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available") - } - - if modelID != "" { - for _, model := range filtered { - if model.ModelID == modelID { - provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return model, provider, nil - } - } - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found") - } - - selected := filtered[0] - provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return selected, provider, nil + return s[:n] + "..." } -func normalizeMaxContextLoad(value int) int { - if value <= 0 { - return defaultMaxContextMinutes - } - return value -} - -func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (userSettings, error) { - defaults := userSettings{ - MaxContextLoadTime: defaultMaxContextMinutes, - Language: "Same as user input", - } - if r.settingsService == nil || strings.TrimSpace(userID) == "" { - return defaults, nil - } - settingsRow, err := r.settingsService.Get(ctx, userID) - if err != nil { - return userSettings{}, err - } - maxLoad := settingsRow.MaxContextLoadTime - if maxLoad <= 0 { - maxLoad = defaultMaxContextMinutes - } - language := strings.TrimSpace(settingsRow.Language) - if language == "" { - language = "Same as user input" - } - return userSettings{ - ChatModelID: strings.TrimSpace(settingsRow.ChatModelID), - MemoryModelID: strings.TrimSpace(settingsRow.MemoryModelID), - EmbeddingModelID: strings.TrimSpace(settingsRow.EmbeddingModelID), - MaxContextLoadTime: maxLoad, - Language: language, - }, nil -} - -func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (int, string, error) { - if r.settingsService == nil { - return settings.DefaultMaxContextLoadTime, settings.DefaultLanguage, nil - } - settingsRow, err := r.settingsService.GetBot(ctx, botID) - if err != nil { - return 0, "", err - } - return settingsRow.MaxContextLoadTime, settingsRow.Language, nil -} - -func normalizeClientType(clientType string) (string, error) { - switch strings.ToLower(strings.TrimSpace(clientType)) { - case "openai": - return "openai", nil - case "openai-compat": - return "openai", nil - case "anthropic": - return "anthropic", nil - case "google": - return "google", nil - default: - return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType) +func defaultString(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } } + return "" } func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(id) + parsed, err := parseUUIDHelper(id) if err != nil { return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) } + return parsed, nil +} + +func parseUUIDHelper(id string) (pgtype.UUID, error) { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return pgtype.UUID{}, fmt.Errorf("empty id") + } var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) + if err := pgID.Scan(trimmed); err != nil { + return pgtype.UUID{}, err + } return pgID, nil } diff --git a/internal/chat/types.go b/internal/chat/types.go index 505fa3f2..295cbf6f 100644 --- a/internal/chat/types.go +++ b/internal/chat/types.go @@ -7,33 +7,29 @@ type Message struct { Content string `json:"content"` } -type GatewayMessage map[string]interface{} - -type AgentSkill struct { - Name string `json:"name"` - Description string `json:"description"` - Content string `json:"content"` -} +type GatewayMessage map[string]any type ChatRequest struct { BotID string `json:"-"` SessionID string `json:"-"` Token string `json:"-"` UserID string `json:"-"` + ContainerID string `json:"-"` + ContactID string `json:"-"` + ContactName string `json:"-"` + ContactAlias string `json:"-"` + ReplyTarget string `json:"-"` + SessionToken string `json:"-"` Query string `json:"query"` Model string `json:"model,omitempty"` Provider string `json:"provider,omitempty"` MaxContextLoadTime int `json:"max_context_load_time,omitempty"` - Locale string `json:"locale,omitempty"` Language string `json:"language,omitempty"` - MaxSteps int `json:"max_steps,omitempty"` Platforms []string `json:"platforms,omitempty"` CurrentPlatform string `json:"current_platform,omitempty"` Messages []GatewayMessage `json:"messages,omitempty"` - Skills []AgentSkill `json:"skills,omitempty"` - UseSkills []string `json:"use_skills,omitempty"` - ToolContext *ToolContext `json:"toolContext,omitempty"` - ToolChoice map[string]any `json:"toolChoice,omitempty"` + Skills []string `json:"skills,omitempty"` + AllowedActions []string `json:"allowed_actions,omitempty"` } type ChatResponse struct { @@ -54,19 +50,7 @@ type SchedulePayload struct { Command string `json:"command"` } -type ToolContext struct { - BotID string `json:"botId,omitempty"` - SessionID string `json:"sessionId,omitempty"` - CurrentPlatform string `json:"currentPlatform,omitempty"` - ReplyTarget string `json:"replyTarget,omitempty"` - SessionToken string `json:"sessionToken,omitempty"` - ContactID string `json:"contactId,omitempty"` - ContactName string `json:"contactName,omitempty"` - ContactAlias string `json:"contactAlias,omitempty"` - UserID string `json:"userId,omitempty"` -} - -// NormalizedMessage 是内部统一后的消息结构,屏蔽厂商差异。 +// NormalizedMessage is the internal unified message structure. type NormalizedMessage struct { Role string `json:"role"` Content string `json:"content,omitempty"` @@ -77,8 +61,14 @@ type NormalizedMessage struct { } type ContentPart struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Styles []string `json:"styles,omitempty"` + Language string `json:"language,omitempty"` + UserID string `json:"user_id,omitempty"` + Emoji string `json:"emoji,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type ToolCall struct { diff --git a/internal/contacts/service.go b/internal/contacts/service.go index e7391a81..1a8a40c3 100644 --- a/internal/contacts/service.go +++ b/internal/contacts/service.go @@ -77,6 +77,29 @@ func (s *Service) GetByChannelIdentity(ctx context.Context, botID, platform, ext return normalizeContactChannel(row) } +func (s *Service) ListChannelsByContact(ctx context.Context, contactID string) ([]ContactChannel, error) { + if s.queries == nil { + return nil, fmt.Errorf("contacts queries not configured") + } + pgContactID, err := parseUUID(contactID) + if err != nil { + return nil, err + } + rows, err := s.queries.ListContactChannelsByContact(ctx, pgContactID) + if err != nil { + return nil, err + } + items := make([]ContactChannel, 0, len(rows)) + for _, row := range rows { + item, err := normalizeContactChannel(row) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, nil +} + func (s *Service) ListByBot(ctx context.Context, botID string) ([]Contact, error) { if s.queries == nil { return nil, fmt.Errorf("contacts queries not configured") @@ -242,7 +265,7 @@ func (s *Service) BindUser(ctx context.Context, contactID, userID string) (Conta return normalizeContact(row) } -func (s *Service) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]interface{}) (ContactChannel, error) { +func (s *Service) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (ContactChannel, error) { if s.queries == nil { return ContactChannel{}, fmt.Errorf("contacts queries not configured") } @@ -271,72 +294,6 @@ func (s *Service) UpsertChannel(ctx context.Context, botID, contactID, platform, return normalizeContactChannel(row) } -func (s *Service) CreateBindToken(ctx context.Context, botID, contactID, targetPlatform, targetExternalID, issuedByUserID string, ttl time.Duration) (BindToken, error) { - if s.queries == nil { - return BindToken{}, fmt.Errorf("contacts queries not configured") - } - if ttl <= 0 { - ttl = 10 * time.Minute - } - pgBotID, err := parseUUID(botID) - if err != nil { - return BindToken{}, err - } - pgContactID, err := parseUUID(contactID) - if err != nil { - return BindToken{}, err - } - pgIssuedBy := pgtype.UUID{Valid: false} - if strings.TrimSpace(issuedByUserID) != "" { - parsed, err := parseUUID(issuedByUserID) - if err != nil { - return BindToken{}, err - } - pgIssuedBy = parsed - } - token := strings.ReplaceAll(uuid.NewString(), "-", "")[:8] - expiresAt := time.Now().UTC().Add(ttl) - row, err := s.queries.CreateContactBindToken(ctx, sqlc.CreateContactBindTokenParams{ - BotID: pgBotID, - ContactID: pgContactID, - Token: token, - TargetPlatform: pgtype.Text{String: strings.TrimSpace(targetPlatform), Valid: strings.TrimSpace(targetPlatform) != ""}, - TargetExternalID: pgtype.Text{String: strings.TrimSpace(targetExternalID), Valid: strings.TrimSpace(targetExternalID) != ""}, - IssuedByUserID: pgIssuedBy, - ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true}, - }) - if err != nil { - return BindToken{}, err - } - return normalizeBindToken(row) -} - -func (s *Service) GetBindToken(ctx context.Context, token string) (BindToken, error) { - if s.queries == nil { - return BindToken{}, fmt.Errorf("contacts queries not configured") - } - row, err := s.queries.GetContactBindToken(ctx, strings.TrimSpace(token)) - if err != nil { - return BindToken{}, err - } - return normalizeBindToken(row) -} - -func (s *Service) MarkBindTokenUsed(ctx context.Context, id string) (BindToken, error) { - if s.queries == nil { - return BindToken{}, fmt.Errorf("contacts queries not configured") - } - pgID, err := parseUUID(id) - if err != nil { - return BindToken{}, err - } - row, err := s.queries.MarkContactBindTokenUsed(ctx, pgID) - if err != nil { - return BindToken{}, err - } - return normalizeBindToken(row) -} - func normalizeContact(row sqlc.Contact) (Contact, error) { metadata, err := decodeMetadata(row.Metadata) if err != nil { @@ -373,38 +330,23 @@ func normalizeContactChannel(row sqlc.ContactChannel) (ContactChannel, error) { }, nil } -func normalizeBindToken(row sqlc.ContactBindToken) (BindToken, error) { - return BindToken{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), - ContactID: toUUIDString(row.ContactID), - Token: strings.TrimSpace(row.Token), - TargetPlatform: strings.TrimSpace(row.TargetPlatform.String), - TargetExternalID: strings.TrimSpace(row.TargetExternalID.String), - IssuedByUserID: toUUIDString(row.IssuedByUserID), - ExpiresAt: timeFromPg(row.ExpiresAt), - UsedAt: timeFromPg(row.UsedAt), - CreatedAt: timeFromPg(row.CreatedAt), - }, nil -} - -func decodeMetadata(raw []byte) (map[string]interface{}, error) { +func decodeMetadata(raw []byte) (map[string]any, error) { if len(raw) == 0 { - return map[string]interface{}{}, nil + return map[string]any{}, nil } - var payload map[string]interface{} + var payload map[string]any if err := json.Unmarshal(raw, &payload); err != nil { return nil, err } if payload == nil { - payload = map[string]interface{}{} + payload = map[string]any{} } return payload, nil } -func defaultMetadata(value map[string]interface{}) map[string]interface{} { +func defaultMetadata(value map[string]any) map[string]any { if value == nil { - return map[string]interface{}{} + return map[string]any{} } return value } diff --git a/internal/contacts/types.go b/internal/contacts/types.go index 2d071acf..f39ff1e3 100644 --- a/internal/contacts/types.go +++ b/internal/contacts/types.go @@ -10,7 +10,7 @@ type Contact struct { Alias string Tags []string Status string - Metadata map[string]interface{} + Metadata map[string]any CreatedAt time.Time UpdatedAt time.Time } @@ -21,24 +21,11 @@ type ContactChannel struct { ContactID string Platform string ExternalID string - Metadata map[string]interface{} + Metadata map[string]any CreatedAt time.Time UpdatedAt time.Time } -type BindToken struct { - ID string - BotID string - ContactID string - Token string - TargetPlatform string - TargetExternalID string - IssuedByUserID string - ExpiresAt time.Time - UsedAt time.Time - CreatedAt time.Time -} - type CreateRequest struct { BotID string UserID string @@ -46,7 +33,7 @@ type CreateRequest struct { Alias string Tags []string Status string - Metadata map[string]interface{} + Metadata map[string]any } type UpdateRequest struct { @@ -54,5 +41,5 @@ type UpdateRequest struct { Alias *string Tags *[]string Status *string - Metadata map[string]interface{} + Metadata map[string]any } diff --git a/internal/db/sqlc/channels.sql.go b/internal/db/sqlc/channels.sql.go index b04aad37..c38738d8 100644 --- a/internal/db/sqlc/channels.sql.go +++ b/internal/db/sqlc/channels.sql.go @@ -86,26 +86,15 @@ func (q *Queries) GetBotChannelConfigByExternalIdentity(ctx context.Context, arg } const getChannelSessionByID = `-- name: GetChannelSessionByID :one -SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, created_at, updated_at +SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at FROM channel_sessions WHERE session_id = $1 LIMIT 1 ` -type GetChannelSessionByIDRow struct { - SessionID string `json:"session_id"` - BotID pgtype.UUID `json:"bot_id"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - UserID pgtype.UUID `json:"user_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -func (q *Queries) GetChannelSessionByID(ctx context.Context, sessionID string) (GetChannelSessionByIDRow, error) { +func (q *Queries) GetChannelSessionByID(ctx context.Context, sessionID string) (ChannelSession, error) { row := q.db.QueryRow(ctx, getChannelSessionByID, sessionID) - var i GetChannelSessionByIDRow + var i ChannelSession err := row.Scan( &i.SessionID, &i.BotID, @@ -113,6 +102,9 @@ func (q *Queries) GetChannelSessionByID(ctx context.Context, sessionID string) ( &i.UserID, &i.ContactID, &i.Platform, + &i.ReplyTarget, + &i.ThreadID, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, ) @@ -185,6 +177,50 @@ func (q *Queries) ListBotChannelConfigsByType(ctx context.Context, channelType s return items, nil } +const listChannelSessionsByBotPlatform = `-- name: ListChannelSessionsByBotPlatform :many +SELECT session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at +FROM channel_sessions +WHERE bot_id = $1 AND platform = $2 +ORDER BY updated_at DESC +` + +type ListChannelSessionsByBotPlatformParams struct { + BotID pgtype.UUID `json:"bot_id"` + Platform string `json:"platform"` +} + +func (q *Queries) ListChannelSessionsByBotPlatform(ctx context.Context, arg ListChannelSessionsByBotPlatformParams) ([]ChannelSession, error) { + rows, err := q.db.Query(ctx, listChannelSessionsByBotPlatform, arg.BotID, arg.Platform) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChannelSession + for rows.Next() { + var i ChannelSession + if err := rows.Scan( + &i.SessionID, + &i.BotID, + &i.ChannelConfigID, + &i.UserID, + &i.ContactID, + &i.Platform, + &i.ReplyTarget, + &i.ThreadID, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listUserChannelBindingsByType = `-- name: ListUserChannelBindingsByType :many SELECT id, user_id, channel_type, config, created_at, updated_at FROM user_channel_bindings @@ -280,8 +316,8 @@ func (q *Queries) UpsertBotChannelConfig(ctx context.Context, arg UpsertBotChann } const upsertChannelSession = `-- name: UpsertChannelSession :one -INSERT INTO channel_sessions (session_id, bot_id, channel_config_id, user_id, contact_id, platform) -VALUES ($1, $2, $3, $4, $5, $6) +INSERT INTO channel_sessions (session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (session_id) DO UPDATE SET bot_id = EXCLUDED.bot_id, @@ -289,8 +325,11 @@ DO UPDATE SET user_id = EXCLUDED.user_id, contact_id = EXCLUDED.contact_id, platform = EXCLUDED.platform, + reply_target = EXCLUDED.reply_target, + thread_id = EXCLUDED.thread_id, + metadata = EXCLUDED.metadata, updated_at = now() -RETURNING session_id, bot_id, channel_config_id, user_id, contact_id, platform, created_at, updated_at +RETURNING session_id, bot_id, channel_config_id, user_id, contact_id, platform, reply_target, thread_id, metadata, created_at, updated_at ` type UpsertChannelSessionParams struct { @@ -300,20 +339,12 @@ type UpsertChannelSessionParams struct { UserID pgtype.UUID `json:"user_id"` ContactID pgtype.UUID `json:"contact_id"` Platform string `json:"platform"` + ReplyTarget pgtype.Text `json:"reply_target"` + ThreadID pgtype.Text `json:"thread_id"` + Metadata []byte `json:"metadata"` } -type UpsertChannelSessionRow struct { - SessionID string `json:"session_id"` - BotID pgtype.UUID `json:"bot_id"` - ChannelConfigID pgtype.UUID `json:"channel_config_id"` - UserID pgtype.UUID `json:"user_id"` - ContactID pgtype.UUID `json:"contact_id"` - Platform string `json:"platform"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` -} - -func (q *Queries) UpsertChannelSession(ctx context.Context, arg UpsertChannelSessionParams) (UpsertChannelSessionRow, error) { +func (q *Queries) UpsertChannelSession(ctx context.Context, arg UpsertChannelSessionParams) (ChannelSession, error) { row := q.db.QueryRow(ctx, upsertChannelSession, arg.SessionID, arg.BotID, @@ -321,8 +352,11 @@ func (q *Queries) UpsertChannelSession(ctx context.Context, arg UpsertChannelSes arg.UserID, arg.ContactID, arg.Platform, + arg.ReplyTarget, + arg.ThreadID, + arg.Metadata, ) - var i UpsertChannelSessionRow + var i ChannelSession err := row.Scan( &i.SessionID, &i.BotID, @@ -330,6 +364,9 @@ func (q *Queries) UpsertChannelSession(ctx context.Context, arg UpsertChannelSes &i.UserID, &i.ContactID, &i.Platform, + &i.ReplyTarget, + &i.ThreadID, + &i.Metadata, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/internal/db/sqlc/contacts.sql.go b/internal/db/sqlc/contacts.sql.go index 64d9739d..3cf19028 100644 --- a/internal/db/sqlc/contacts.sql.go +++ b/internal/db/sqlc/contacts.sql.go @@ -53,73 +53,6 @@ func (q *Queries) CreateContact(ctx context.Context, arg CreateContactParams) (C return i, err } -const createContactBindToken = `-- name: CreateContactBindToken :one -INSERT INTO contact_bind_tokens (bot_id, contact_id, token, target_platform, target_external_id, issued_by_user_id, expires_at) -VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, bot_id, contact_id, token, target_platform, target_external_id, issued_by_user_id, expires_at, used_at, created_at -` - -type CreateContactBindTokenParams struct { - BotID pgtype.UUID `json:"bot_id"` - ContactID pgtype.UUID `json:"contact_id"` - Token string `json:"token"` - TargetPlatform pgtype.Text `json:"target_platform"` - TargetExternalID pgtype.Text `json:"target_external_id"` - IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` - ExpiresAt pgtype.Timestamptz `json:"expires_at"` -} - -func (q *Queries) CreateContactBindToken(ctx context.Context, arg CreateContactBindTokenParams) (ContactBindToken, error) { - row := q.db.QueryRow(ctx, createContactBindToken, - arg.BotID, - arg.ContactID, - arg.Token, - arg.TargetPlatform, - arg.TargetExternalID, - arg.IssuedByUserID, - arg.ExpiresAt, - ) - var i ContactBindToken - err := row.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Token, - &i.TargetPlatform, - &i.TargetExternalID, - &i.IssuedByUserID, - &i.ExpiresAt, - &i.UsedAt, - &i.CreatedAt, - ) - return i, err -} - -const getContactBindToken = `-- name: GetContactBindToken :one -SELECT id, bot_id, contact_id, token, target_platform, target_external_id, issued_by_user_id, expires_at, used_at, created_at -FROM contact_bind_tokens -WHERE token = $1 -LIMIT 1 -` - -func (q *Queries) GetContactBindToken(ctx context.Context, token string) (ContactBindToken, error) { - row := q.db.QueryRow(ctx, getContactBindToken, token) - var i ContactBindToken - err := row.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Token, - &i.TargetPlatform, - &i.TargetExternalID, - &i.IssuedByUserID, - &i.ExpiresAt, - &i.UsedAt, - &i.CreatedAt, - ) - return i, err -} - const getContactByID = `-- name: GetContactByID :one SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at FROM contacts @@ -278,31 +211,6 @@ func (q *Queries) ListContactsByBot(ctx context.Context, botID pgtype.UUID) ([]C return items, nil } -const markContactBindTokenUsed = `-- name: MarkContactBindTokenUsed :one -UPDATE contact_bind_tokens -SET used_at = now() -WHERE id = $1 -RETURNING id, bot_id, contact_id, token, target_platform, target_external_id, issued_by_user_id, expires_at, used_at, created_at -` - -func (q *Queries) MarkContactBindTokenUsed(ctx context.Context, id pgtype.UUID) (ContactBindToken, error) { - row := q.db.QueryRow(ctx, markContactBindTokenUsed, id) - var i ContactBindToken - err := row.Scan( - &i.ID, - &i.BotID, - &i.ContactID, - &i.Token, - &i.TargetPlatform, - &i.TargetExternalID, - &i.IssuedByUserID, - &i.ExpiresAt, - &i.UsedAt, - &i.CreatedAt, - ) - return i, err -} - const searchContacts = `-- name: SearchContacts :many SELECT id, bot_id, user_id, display_name, alias, tags, status, metadata, created_at, updated_at FROM contacts diff --git a/internal/db/sqlc/containers.sql.go b/internal/db/sqlc/containers.sql.go index 3dddad7f..7771f34a 100644 --- a/internal/db/sqlc/containers.sql.go +++ b/internal/db/sqlc/containers.sql.go @@ -11,6 +11,32 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const getContainerByBotID = `-- name: GetContainerByBotID :one +SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE bot_id = $1 ORDER BY updated_at DESC LIMIT 1 +` + +func (q *Queries) GetContainerByBotID(ctx context.Context, botID pgtype.UUID) (Container, error) { + row := q.db.QueryRow(ctx, getContainerByBotID, botID) + var i Container + err := row.Scan( + &i.ID, + &i.BotID, + &i.ContainerID, + &i.ContainerName, + &i.Image, + &i.Status, + &i.Namespace, + &i.AutoStart, + &i.HostPath, + &i.ContainerPath, + &i.CreatedAt, + &i.UpdatedAt, + &i.LastStartedAt, + &i.LastStoppedAt, + ) + return i, err +} + const getContainerByContainerID = `-- name: GetContainerByContainerID :one SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE container_id = $1 ` diff --git a/internal/db/sqlc/history.sql.go b/internal/db/sqlc/history.sql.go index b8f2b038..0fc2033c 100644 --- a/internal/db/sqlc/history.sql.go +++ b/internal/db/sqlc/history.sql.go @@ -12,15 +12,16 @@ import ( ) const createHistory = `-- name: CreateHistory :one -INSERT INTO history (bot_id, session_id, messages, skills, timestamp) -VALUES ($1, $2, $3, $4, $5) -RETURNING id, bot_id, session_id, messages, skills, timestamp +INSERT INTO history (bot_id, session_id, messages, metadata, skills, timestamp) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id, bot_id, session_id, messages, metadata, skills, timestamp ` type CreateHistoryParams struct { BotID pgtype.UUID `json:"bot_id"` SessionID string `json:"session_id"` Messages []byte `json:"messages"` + Metadata []byte `json:"metadata"` Skills []string `json:"skills"` Timestamp pgtype.Timestamptz `json:"timestamp"` } @@ -30,6 +31,7 @@ func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (H arg.BotID, arg.SessionID, arg.Messages, + arg.Metadata, arg.Skills, arg.Timestamp, ) @@ -39,6 +41,7 @@ func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (H &i.BotID, &i.SessionID, &i.Messages, + &i.Metadata, &i.Skills, &i.Timestamp, ) @@ -71,7 +74,7 @@ func (q *Queries) DeleteHistoryByID(ctx context.Context, id pgtype.UUID) error { } const getHistoryByID = `-- name: GetHistoryByID :one -SELECT id, bot_id, session_id, messages, skills, timestamp +SELECT id, bot_id, session_id, messages, metadata, skills, timestamp FROM history WHERE id = $1 ` @@ -84,6 +87,7 @@ func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History, &i.BotID, &i.SessionID, &i.Messages, + &i.Metadata, &i.Skills, &i.Timestamp, ) @@ -91,7 +95,7 @@ func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History, } const listHistoryByBotSession = `-- name: ListHistoryByBotSession :many -SELECT id, bot_id, session_id, messages, skills, timestamp +SELECT id, bot_id, session_id, messages, metadata, skills, timestamp FROM history WHERE bot_id = $1 AND session_id = $2 ORDER BY timestamp DESC @@ -118,6 +122,7 @@ func (q *Queries) ListHistoryByBotSession(ctx context.Context, arg ListHistoryBy &i.BotID, &i.SessionID, &i.Messages, + &i.Metadata, &i.Skills, &i.Timestamp, ); err != nil { @@ -132,7 +137,7 @@ func (q *Queries) ListHistoryByBotSession(ctx context.Context, arg ListHistoryBy } const listHistoryByBotSessionSince = `-- name: ListHistoryByBotSessionSince :many -SELECT id, bot_id, session_id, messages, skills, timestamp +SELECT id, bot_id, session_id, messages, metadata, skills, timestamp FROM history WHERE bot_id = $1 AND session_id = $2 AND timestamp >= $3 ORDER BY timestamp ASC @@ -158,6 +163,7 @@ func (q *Queries) ListHistoryByBotSessionSince(ctx context.Context, arg ListHist &i.BotID, &i.SessionID, &i.Messages, + &i.Metadata, &i.Skills, &i.Timestamp, ); err != nil { diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index cc53d80a..44b30a0a 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -49,6 +49,16 @@ type BotModelConfig struct { MemoryModelID pgtype.UUID `json:"memory_model_id"` } +type BotPreauthKey struct { + ID pgtype.UUID `json:"id"` + BotID pgtype.UUID `json:"bot_id"` + Token string `json:"token"` + IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` + UsedAt pgtype.Timestamptz `json:"used_at"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + type BotSetting struct { BotID pgtype.UUID `json:"bot_id"` MaxContextLoadTime int32 `json:"max_context_load_time"` @@ -61,10 +71,13 @@ type ChannelSession struct { BotID pgtype.UUID `json:"bot_id"` ChannelConfigID pgtype.UUID `json:"channel_config_id"` UserID pgtype.UUID `json:"user_id"` + ContactID pgtype.UUID `json:"contact_id"` Platform string `json:"platform"` + ReplyTarget pgtype.Text `json:"reply_target"` + ThreadID pgtype.Text `json:"thread_id"` + Metadata []byte `json:"metadata"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` - ContactID pgtype.UUID `json:"contact_id"` } type Contact struct { @@ -80,19 +93,6 @@ type Contact struct { UpdatedAt pgtype.Timestamptz `json:"updated_at"` } -type ContactBindToken struct { - ID pgtype.UUID `json:"id"` - BotID pgtype.UUID `json:"bot_id"` - ContactID pgtype.UUID `json:"contact_id"` - Token string `json:"token"` - TargetPlatform pgtype.Text `json:"target_platform"` - TargetExternalID pgtype.Text `json:"target_external_id"` - IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` - ExpiresAt pgtype.Timestamptz `json:"expires_at"` - UsedAt pgtype.Timestamptz `json:"used_at"` - CreatedAt pgtype.Timestamptz `json:"created_at"` -} - type ContactChannel struct { ID pgtype.UUID `json:"id"` BotID pgtype.UUID `json:"bot_id"` @@ -145,6 +145,7 @@ type History struct { BotID pgtype.UUID `json:"bot_id"` SessionID string `json:"session_id"` Messages []byte `json:"messages"` + Metadata []byte `json:"metadata"` Skills []string `json:"skills"` Timestamp pgtype.Timestamptz `json:"timestamp"` } diff --git a/internal/db/sqlc/preauth.sql.go b/internal/db/sqlc/preauth.sql.go new file mode 100644 index 00000000..c6e8c92b --- /dev/null +++ b/internal/db/sqlc/preauth.sql.go @@ -0,0 +1,89 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: preauth.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createBotPreauthKey = `-- name: CreateBotPreauthKey :one +INSERT INTO bot_preauth_keys (bot_id, token, issued_by_user_id, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at +` + +type CreateBotPreauthKeyParams struct { + BotID pgtype.UUID `json:"bot_id"` + Token string `json:"token"` + IssuedByUserID pgtype.UUID `json:"issued_by_user_id"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` +} + +func (q *Queries) CreateBotPreauthKey(ctx context.Context, arg CreateBotPreauthKeyParams) (BotPreauthKey, error) { + row := q.db.QueryRow(ctx, createBotPreauthKey, + arg.BotID, + arg.Token, + arg.IssuedByUserID, + arg.ExpiresAt, + ) + var i BotPreauthKey + err := row.Scan( + &i.ID, + &i.BotID, + &i.Token, + &i.IssuedByUserID, + &i.ExpiresAt, + &i.UsedAt, + &i.CreatedAt, + ) + return i, err +} + +const getBotPreauthKey = `-- name: GetBotPreauthKey :one +SELECT id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at +FROM bot_preauth_keys +WHERE token = $1 +LIMIT 1 +` + +func (q *Queries) GetBotPreauthKey(ctx context.Context, token string) (BotPreauthKey, error) { + row := q.db.QueryRow(ctx, getBotPreauthKey, token) + var i BotPreauthKey + err := row.Scan( + &i.ID, + &i.BotID, + &i.Token, + &i.IssuedByUserID, + &i.ExpiresAt, + &i.UsedAt, + &i.CreatedAt, + ) + return i, err +} + +const markBotPreauthKeyUsed = `-- name: MarkBotPreauthKeyUsed :one +UPDATE bot_preauth_keys +SET used_at = now() +WHERE id = $1 +RETURNING id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at +` + +func (q *Queries) MarkBotPreauthKeyUsed(ctx context.Context, id pgtype.UUID) (BotPreauthKey, error) { + row := q.db.QueryRow(ctx, markBotPreauthKeyUsed, id) + var i BotPreauthKey + err := row.Scan( + &i.ID, + &i.BotID, + &i.Token, + &i.IssuedByUserID, + &i.ExpiresAt, + &i.UsedAt, + &i.CreatedAt, + ) + return i, err +} diff --git a/internal/directory/service.go b/internal/directory/service.go new file mode 100644 index 00000000..158039e1 --- /dev/null +++ b/internal/directory/service.go @@ -0,0 +1,226 @@ +package directory + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/contacts" +) + +var ( + ErrNotFound = errors.New("directory entry not found") + ErrAmbiguous = errors.New("directory entry ambiguous") + ErrUnsupported = errors.New("directory operation unsupported") +) + +type ContactReader interface { + Search(ctx context.Context, botID, query string) ([]contacts.Contact, error) + ListByBot(ctx context.Context, botID string) ([]contacts.Contact, error) + ListChannelsByContact(ctx context.Context, contactID string) ([]contacts.ContactChannel, error) +} + +type ChannelSessionStore interface { + ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]channel.ChannelSession, error) +} + +type LocalService struct { + contacts ContactReader + sessions ChannelSessionStore + logger *slog.Logger +} + +func NewLocalService(log *slog.Logger, contacts ContactReader, sessions ChannelSessionStore) *LocalService { + if log == nil { + log = slog.Default() + } + return &LocalService{ + contacts: contacts, + sessions: sessions, + logger: log.With(slog.String("service", "directory")), + } +} + +func (s *LocalService) ListPeers(ctx context.Context, botID, platform, query string, limit int) ([]channel.DirectoryEntry, error) { + if s.contacts == nil { + return nil, fmt.Errorf("contacts service not configured") + } + trimmed := strings.TrimSpace(query) + var items []contacts.Contact + var err error + if trimmed == "" { + items, err = s.contacts.ListByBot(ctx, botID) + } else { + items, err = s.contacts.Search(ctx, botID, trimmed) + } + if err != nil { + return nil, err + } + results := make([]channel.DirectoryEntry, 0, len(items)) + for _, contact := range items { + channels, err := s.contacts.ListChannelsByContact(ctx, contact.ID) + if err != nil { + if s.logger != nil { + s.logger.Warn("list contact channels failed", slog.String("contact_id", contact.ID), slog.Any("error", err)) + } + continue + } + for _, ch := range channels { + if platform != "" && ch.Platform != platform { + continue + } + entry := channel.DirectoryEntry{ + Kind: channel.DirectoryEntryUser, + ID: strings.TrimSpace(ch.ExternalID), + Name: chooseContactName(contact, ch), + Handle: strings.TrimSpace(contact.Alias), + Metadata: map[string]any{}, + } + if entry.ID == "" { + continue + } + entry.Metadata["contact_id"] = contact.ID + if contact.UserID != "" { + entry.Metadata["user_id"] = contact.UserID + } + entry.Metadata["platform"] = ch.Platform + results = append(results, entry) + if limit > 0 && len(results) >= limit { + return results, nil + } + } + } + return results, nil +} + +func (s *LocalService) ListGroups(ctx context.Context, botID, platform, query string, limit int) ([]channel.DirectoryEntry, error) { + if s.sessions == nil { + return nil, fmt.Errorf("channel session store not configured") + } + platform = strings.TrimSpace(platform) + if platform == "" { + return nil, fmt.Errorf("platform is required") + } + sessions, err := s.sessions.ListSessionsByBotPlatform(ctx, botID, platform) + if err != nil { + return nil, err + } + trimmed := strings.TrimSpace(query) + results := make([]channel.DirectoryEntry, 0, len(sessions)) + for _, session := range sessions { + if !isGroupSession(session) { + continue + } + name := channel.ReadString(session.Metadata, "conversation_name", "name") + entryID := strings.TrimSpace(session.ReplyTarget) + if entryID == "" { + entryID = strings.TrimSpace(session.SessionID) + } + if entryID == "" { + continue + } + if trimmed != "" && !matchesQuery(trimmed, entryID, name) { + continue + } + results = append(results, channel.DirectoryEntry{ + Kind: channel.DirectoryEntryGroup, + ID: entryID, + Name: strings.TrimSpace(name), + Metadata: session.Metadata, + }) + if limit > 0 && len(results) >= limit { + return results, nil + } + } + return results, nil +} + +func (s *LocalService) ListGroupMembers(ctx context.Context, botID, platform, groupID string, limit int) ([]channel.DirectoryEntry, error) { + return nil, ErrUnsupported +} + +func (s *LocalService) ResolveTarget(ctx context.Context, botID, platform, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return channel.DirectoryEntry{}, ErrNotFound + } + switch kind { + case channel.DirectoryEntryGroup: + items, err := s.ListGroups(ctx, botID, platform, trimmed, 5) + if err != nil { + return channel.DirectoryEntry{}, err + } + return pickSingleMatch(items, trimmed) + default: + items, err := s.ListPeers(ctx, botID, platform, trimmed, 5) + if err != nil { + return channel.DirectoryEntry{}, err + } + return pickSingleMatch(items, trimmed) + } +} + +func pickSingleMatch(items []channel.DirectoryEntry, input string) (channel.DirectoryEntry, error) { + if len(items) == 0 { + return channel.DirectoryEntry{}, ErrNotFound + } + if len(items) == 1 { + return items[0], nil + } + lower := strings.ToLower(strings.TrimSpace(input)) + var exact *channel.DirectoryEntry + for i := range items { + if strings.ToLower(strings.TrimSpace(items[i].ID)) == lower { + exact = &items[i] + break + } + if strings.ToLower(strings.TrimSpace(items[i].Name)) == lower { + exact = &items[i] + break + } + } + if exact != nil { + return *exact, nil + } + return channel.DirectoryEntry{}, ErrAmbiguous +} + +func chooseContactName(contact contacts.Contact, ch contacts.ContactChannel) string { + if strings.TrimSpace(contact.DisplayName) != "" { + return strings.TrimSpace(contact.DisplayName) + } + if strings.TrimSpace(contact.Alias) != "" { + return strings.TrimSpace(contact.Alias) + } + if strings.TrimSpace(ch.ExternalID) != "" { + return strings.TrimSpace(ch.ExternalID) + } + return "" +} + +func isGroupSession(session channel.ChannelSession) bool { + value := strings.ToLower(strings.TrimSpace(channel.ReadString(session.Metadata, "conversation_type", "chat_type", "type"))) + if value == "" { + return false + } + if strings.Contains(value, "group") { + return true + } + return false +} + +func matchesQuery(query string, fields ...string) bool { + needle := strings.ToLower(strings.TrimSpace(query)) + if needle == "" { + return true + } + for _, field := range fields { + if strings.Contains(strings.ToLower(strings.TrimSpace(field)), needle) { + return true + } + } + return false +} diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index 03d9eab8..dc96a4ea 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -2,6 +2,7 @@ package handlers import ( "net/http" + "sort" "strings" "github.com/labstack/echo/v4" @@ -23,6 +24,10 @@ func (h *ChannelHandler) Register(e *echo.Echo) { group := e.Group("/users/me/channels") group.GET("/:platform", h.GetUserConfig) group.PUT("/:platform", h.UpsertUserConfig) + + metaGroup := e.Group("/channels") + metaGroup.GET("", h.ListChannels) + metaGroup.GET("/:platform", h.GetChannel) } // GetUserConfig godoc @@ -78,7 +83,7 @@ func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } if req.Config == nil { - req.Config = map[string]interface{}{} + req.Config = map[string]any{} } resp, err := h.service.UpsertUserConfig(c.Request().Context(), userID, channelType, req) if err != nil { @@ -87,6 +92,73 @@ func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { return c.JSON(http.StatusOK, resp) } +type ChannelMeta struct { + Type string `json:"type"` + DisplayName string `json:"display_name"` + Configless bool `json:"configless"` + Capabilities channel.ChannelCapabilities `json:"capabilities"` + ConfigSchema channel.ConfigSchema `json:"config_schema"` + UserConfigSchema channel.ConfigSchema `json:"user_config_schema"` + TargetSpec channel.TargetSpec `json:"target_spec"` +} + +// ListChannels godoc +// @Summary List channel capabilities and schemas +// @Description List channel meta information including capabilities and schemas +// @Tags channel +// @Success 200 {array} ChannelMeta +// @Failure 500 {object} ErrorResponse +// @Router /channels [get] +func (h *ChannelHandler) ListChannels(c echo.Context) error { + descs := channel.ListChannelDescriptors() + items := make([]ChannelMeta, 0, len(descs)) + for _, desc := range descs { + items = append(items, ChannelMeta{ + Type: desc.Type.String(), + DisplayName: desc.DisplayName, + Configless: desc.Configless, + Capabilities: desc.Capabilities, + ConfigSchema: desc.ConfigSchema, + UserConfigSchema: desc.UserConfigSchema, + TargetSpec: desc.TargetSpec, + }) + } + sort.Slice(items, func(i, j int) bool { + return items[i].Type < items[j].Type + }) + return c.JSON(http.StatusOK, items) +} + +// GetChannel godoc +// @Summary Get channel capabilities and schemas +// @Description Get channel meta information including capabilities and schemas +// @Tags channel +// @Param platform path string true "Channel platform" +// @Success 200 {object} ChannelMeta +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Router /channels/{platform} [get] +func (h *ChannelHandler) GetChannel(c echo.Context) error { + channelType, err := channel.ParseChannelType(c.Param("platform")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + desc, ok := channel.GetChannelDescriptor(channelType) + if !ok { + return echo.NewHTTPError(http.StatusNotFound, "channel not found") + } + resp := ChannelMeta{ + Type: desc.Type.String(), + DisplayName: desc.DisplayName, + Configless: desc.Configless, + Capabilities: desc.Capabilities, + ConfigSchema: desc.ConfigSchema, + UserConfigSchema: desc.UserConfigSchema, + TargetSpec: desc.TargetSpec, + } + return c.JSON(http.StatusOK, resp) +} + func (h *ChannelHandler) requireUserID(c echo.Context) (string, error) { userID, err := auth.UserIDFromContext(c) if err != nil { diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go index 3d185868..e93408a2 100644 --- a/internal/handlers/chat.go +++ b/internal/handlers/chat.go @@ -81,6 +81,12 @@ func (h *ChatHandler) Chat(c echo.Context) error { req.SessionID = sessionID req.Token = c.Request().Header.Get("Authorization") req.UserID = userID + if strings.TrimSpace(req.ContactID) == "" { + req.ContactID = userID + } + if strings.TrimSpace(req.ContactName) == "" { + req.ContactName = "User" + } resp, err := h.resolver.Chat(c.Request().Context(), req) if err != nil { @@ -130,6 +136,12 @@ func (h *ChatHandler) StreamChat(c echo.Context) error { req.SessionID = sessionID req.Token = c.Request().Header.Get("Authorization") req.UserID = userID + if strings.TrimSpace(req.ContactID) == "" { + req.ContactID = userID + } + if strings.TrimSpace(req.ContactName) == "" { + req.ContactName = "User" + } // Set headers for SSE c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") diff --git a/internal/handlers/contacts.go b/internal/handlers/contacts.go index 286ce89a..bee73a31 100644 --- a/internal/handlers/contacts.go +++ b/internal/handlers/contacts.go @@ -5,7 +5,6 @@ import ( "errors" "net/http" "strings" - "time" "github.com/labstack/echo/v4" @@ -36,25 +35,6 @@ func (h *ContactsHandler) Register(e *echo.Echo) { group.GET("/:id", h.Get) group.POST("", h.Create) group.PATCH("/:id", h.Update) - group.POST("/:id/bind", h.Bind) - group.POST("/:id/bind_token", h.IssueBindToken) - group.POST("/bind_confirm", h.ConfirmBind) -} - -type contactBindRequest struct { - Platform string `json:"platform"` - ExternalID string `json:"external_id"` - BindToken string `json:"bind_token"` -} - -type contactBindTokenRequest struct { - TargetPlatform string `json:"target_platform"` - TargetExternalID string `json:"target_external_id"` - TTLSeconds int `json:"ttl_seconds"` -} - -type contactBindConfirmRequest struct { - Token string `json:"token"` } func (h *ContactsHandler) List(c echo.Context) error { @@ -74,7 +54,7 @@ func (h *ContactsHandler) List(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.JSON(http.StatusOK, map[string]interface{}{"items": items}) + return c.JSON(http.StatusOK, map[string]any{"items": items}) } func (h *ContactsHandler) Get(c echo.Context) error { @@ -125,17 +105,10 @@ func (h *ContactsHandler) Create(c echo.Context) error { } func (h *ContactsHandler) Update(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } botID := strings.TrimSpace(c.Param("bot_id")) if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } id := strings.TrimSpace(c.Param("id")) if id == "" { return echo.NewHTTPError(http.StatusBadRequest, "contact id is required") @@ -144,6 +117,32 @@ func (h *ContactsHandler) Update(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } + + userID, err := h.requireUserID(c) + if err == nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + item, err := h.service.Update(c.Request().Context(), id, req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, item) + } + + sessionToken, tokenErr := auth.SessionTokenFromContext(c) + if tokenErr != nil { + return err + } + if sessionToken.BotID != botID { + return echo.NewHTTPError(http.StatusForbidden, "session token mismatch") + } + if strings.TrimSpace(sessionToken.ContactID) == "" || sessionToken.ContactID != id { + return echo.NewHTTPError(http.StatusForbidden, "contact mismatch") + } + if req.Tags != nil || req.Status != nil { + return echo.NewHTTPError(http.StatusForbidden, "session token cannot update tags or status") + } item, err := h.service.Update(c.Request().Context(), id, req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -151,140 +150,6 @@ func (h *ContactsHandler) Update(c echo.Context) error { return c.JSON(http.StatusOK, item) } -func (h *ContactsHandler) Bind(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "contact id is required") - } - var req contactBindRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if strings.TrimSpace(req.BindToken) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bind_token is required") - } - token, err := h.service.GetBindToken(c.Request().Context(), req.BindToken) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid bind token") - } - if token.UsedAt.IsZero() == false { - return echo.NewHTTPError(http.StatusBadRequest, "bind token already used") - } - if time.Now().UTC().After(token.ExpiresAt) { - return echo.NewHTTPError(http.StatusBadRequest, "bind token expired") - } - if token.BotID != botID || token.ContactID != id { - return echo.NewHTTPError(http.StatusBadRequest, "bind token mismatch") - } - platform := strings.TrimSpace(req.Platform) - externalID := strings.TrimSpace(req.ExternalID) - if platform == "" || externalID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "platform and external_id are required") - } - if token.TargetPlatform != "" && token.TargetPlatform != platform { - return echo.NewHTTPError(http.StatusBadRequest, "bind token platform mismatch") - } - if token.TargetExternalID != "" && token.TargetExternalID != externalID { - return echo.NewHTTPError(http.StatusBadRequest, "bind token external_id mismatch") - } - bound, err := h.service.UpsertChannel(c.Request().Context(), botID, id, platform, externalID, nil) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - _, _ = h.service.MarkBindTokenUsed(c.Request().Context(), token.ID) - return c.JSON(http.StatusOK, bound) -} - -func (h *ContactsHandler) IssueBindToken(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "contact id is required") - } - var req contactBindTokenRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - ttl := 10 * time.Minute - if req.TTLSeconds > 0 { - ttl = time.Duration(req.TTLSeconds) * time.Second - } - token, err := h.service.CreateBindToken(c.Request().Context(), botID, id, req.TargetPlatform, req.TargetExternalID, userID, ttl) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return c.JSON(http.StatusOK, token) -} - -func (h *ContactsHandler) ConfirmBind(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - botID := strings.TrimSpace(c.Param("bot_id")) - if botID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { - return err - } - var req contactBindConfirmRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - token, err := h.service.GetBindToken(c.Request().Context(), req.Token) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid bind token") - } - if token.UsedAt.IsZero() == false { - return echo.NewHTTPError(http.StatusBadRequest, "bind token already used") - } - if time.Now().UTC().After(token.ExpiresAt) { - return echo.NewHTTPError(http.StatusBadRequest, "bind token expired") - } - if token.BotID != botID { - return echo.NewHTTPError(http.StatusBadRequest, "bind token mismatch") - } - if token.IssuedByUserID != "" && token.IssuedByUserID != userID { - return echo.NewHTTPError(http.StatusBadRequest, "bind token not issued for current user") - } - contact, err := h.service.GetByID(c.Request().Context(), token.ContactID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if contact.UserID != "" && contact.UserID != userID { - return echo.NewHTTPError(http.StatusBadRequest, "contact already bound to another user") - } - if contact.UserID == "" { - if _, err := h.service.BindUser(c.Request().Context(), contact.ID, userID); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - } - _, _ = h.service.MarkBindTokenUsed(c.Request().Context(), token.ID) - return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) -} - func (h *ContactsHandler) requireUserID(c echo.Context) (string, error) { userID, err := auth.UserIDFromContext(c) if err != nil { diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index 0dd759bb..a9e7977b 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "errors" "log/slog" "net/http" "os" @@ -16,21 +17,30 @@ import ( "github.com/containerd/containerd/v2/pkg/oci" "github.com/containerd/errdefs" "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" "github.com/labstack/echo/v4" "github.com/opencontainers/runtime-spec/specs-go" + "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/users" ) type ContainerdHandler struct { - service ctr.Service - cfg config.MCPConfig - namespace string - logger *slog.Logger - mcpMu sync.Mutex - mcpSess map[string]*mcpSession + service ctr.Service + cfg config.MCPConfig + namespace string + logger *slog.Logger + mcpMu sync.Mutex + mcpSess map[string]*mcpSession + botService *bots.Service + userService *users.Service + queries *dbsqlc.Queries } type CreateContainerRequest struct { @@ -86,18 +96,21 @@ type ListSnapshotsResponse struct { Snapshots []SnapshotInfo `json:"snapshots"` } -func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string) *ContainerdHandler { +func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, userService *users.Service, queries *dbsqlc.Queries) *ContainerdHandler { return &ContainerdHandler{ - service: service, - cfg: cfg, - namespace: namespace, - logger: log.With(slog.String("handler", "containerd")), - mcpSess: make(map[string]*mcpSession), + service: service, + cfg: cfg, + namespace: namespace, + logger: log.With(slog.String("handler", "containerd")), + mcpSess: make(map[string]*mcpSession), + botService: botService, + userService: userService, + queries: queries, } } func (h *ContainerdHandler) Register(e *echo.Echo) { - group := e.Group("/container") + group := e.Group("/bots/:bot_id/container") group.POST("", h.CreateContainer) group.GET("/list", h.ListContainers) group.DELETE("/:id", h.DeleteContainer) @@ -110,15 +123,16 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { } // CreateContainer godoc -// @Summary Create and start MCP container +// @Summary Create and start MCP container for bot // @Tags containerd +// @Param bot_id path string true "Bot ID" // @Param payload body CreateContainerRequest true "Create container payload" // @Success 200 {object} CreateContainerResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /container [post] +// @Router /bots/{bot_id}/container [post] func (h *ContainerdHandler) CreateContainer(c echo.Context) error { - userID, err := h.requireUserID(c) + botID, err := h.requireBotAccess(c) if err != nil { return err } @@ -129,7 +143,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } req.ContainerID = strings.TrimSpace(req.ContainerID) if req.ContainerID == "" { - req.ContainerID = uuid.NewString() + req.ContainerID = "mcp-" + botID } image := strings.TrimSpace(req.Image) @@ -156,7 +170,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { if dataMount == "" { dataMount = config.DefaultDataMount } - dataDir := filepath.Join(dataRoot, "bots", userID) + dataDir := filepath.Join(dataRoot, "bots", botID) if err := os.MkdirAll(dataDir, 0o755); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -187,7 +201,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { ImageRef: image, Snapshotter: snapshotter, Labels: map[string]string{ - mcp.BotLabelKey: userID, + mcp.BotLabelKey: botID, }, SpecOpts: specOpts, }) @@ -195,6 +209,28 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, "snapshotter="+snapshotter+" image="+image+" err="+err.Error()) } + // Persist container record in database + if h.queries != nil { + pgBotID, parseErr := parsePgUUID(botID) + if parseErr == nil { + ns := strings.TrimSpace(h.namespace) + if ns == "" { + ns = "default" + } + _ = h.queries.UpsertContainer(c.Request().Context(), dbsqlc.UpsertContainerParams{ + BotID: pgBotID, + ContainerID: req.ContainerID, + ContainerName: req.ContainerID, + Image: image, + Status: "created", + Namespace: ns, + AutoStart: true, + HostPath: pgtype.Text{String: dataDir, Valid: true}, + ContainerPath: dataMount, + }) + } + } + started := false fifoDir, err := h.taskFIFODir() if err != nil { @@ -260,8 +296,19 @@ func (h *ContainerdHandler) ensureTaskRunning(ctx context.Context, containerID s return err } -func (h *ContainerdHandler) userContainerID(ctx context.Context, userID string) (string, error) { - containers, err := h.service.ListContainersByLabel(ctx, mcp.BotLabelKey, userID) +// botContainerID resolves container_id for a bot from the database. +func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (string, error) { + if h.queries != nil { + pgBotID, err := parsePgUUID(botID) + if err == nil { + row, err := h.queries.GetContainerByBotID(ctx, pgBotID) + if err == nil && strings.TrimSpace(row.ContainerID) != "" { + return row.ContainerID, nil + } + } + } + // Fallback: search by containerd label + containers, err := h.service.ListContainersByLabel(ctx, mcp.BotLabelKey, botID) if err != nil { return "", err } @@ -288,14 +335,19 @@ func (h *ContainerdHandler) userContainerID(ctx context.Context, userID string) } // ListContainers godoc -// @Summary List containers +// @Summary List containers for bot // @Tags containerd +// @Param bot_id path string true "Bot ID" // @Success 200 {object} ListContainersResponse // @Failure 500 {object} ErrorResponse -// @Router /container/list [get] +// @Router /bots/{bot_id}/container/list [get] func (h *ContainerdHandler) ListContainers(c echo.Context) error { + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } ctx := c.Request().Context() - containers, err := h.service.ListContainers(ctx) + containers, err := h.service.ListContainersByLabel(ctx, mcp.BotLabelKey, botID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -328,13 +380,17 @@ func (h *ContainerdHandler) ListContainers(c echo.Context) error { // DeleteContainer godoc // @Summary Delete MCP container // @Tags containerd +// @Param bot_id path string true "Bot ID" // @Param id path string true "Container ID" // @Success 204 // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /container/{id} [delete] +// @Router /bots/{bot_id}/container/{id} [delete] func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { + if _, err := h.requireBotAccess(c); err != nil { + return err + } containerID := strings.TrimSpace(c.Param("id")) if containerID == "" { return echo.NewHTTPError(http.StatusBadRequest, "container id is required") @@ -354,13 +410,14 @@ func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { // CreateSnapshot godoc // @Summary Create container snapshot // @Tags containerd +// @Param bot_id path string true "Bot ID" // @Param payload body CreateSnapshotRequest true "Create snapshot payload" // @Success 200 {object} CreateSnapshotResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /container/snapshots [post] +// @Router /bots/{bot_id}/container/snapshots [post] func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { + if _, err := h.requireBotAccess(c); err != nil { + return err + } var req CreateSnapshotRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) @@ -401,11 +458,14 @@ func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { // ListSnapshots godoc // @Summary List snapshots // @Tags containerd +// @Param bot_id path string true "Bot ID" // @Param snapshotter query string false "Snapshotter name" // @Success 200 {object} ListSnapshotsResponse -// @Failure 500 {object} ErrorResponse -// @Router /container/snapshots [get] +// @Router /bots/{bot_id}/container/snapshots [get] func (h *ContainerdHandler) ListSnapshots(c echo.Context) error { + if _, err := h.requireBotAccess(c); err != nil { + return err + } snapshotter := strings.TrimSpace(c.QueryParam("snapshotter")) if snapshotter == "" { snapshotter = strings.TrimSpace(h.cfg.Snapshotter) @@ -440,3 +500,64 @@ func (h *ContainerdHandler) ListSnapshots(c echo.Context) error { Snapshots: items, }) } + +// ---------- auth helpers ---------- + +// requireBotAccess extracts bot_id from path, validates user auth, and authorizes bot access. +func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) { + userID, err := h.requireUserID(c) + if err != nil { + return "", err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return "", echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return "", err + } + return botID, nil +} + +func (h *ContainerdHandler) requireUserID(c echo.Context) (string, error) { + userID, err := auth.UserIDFromContext(c) + if err != nil { + return "", err + } + if err := identity.ValidateUserID(userID); err != nil { + return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return userID, nil +} + +func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { + if h.botService == nil || h.userService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.userService.IsAdmin(ctx, actorID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} + +func parsePgUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(strings.TrimSpace(id)) + if err != nil { + return pgtype.UUID{}, err + } + var pgID pgtype.UUID + pgID.Valid = true + copy(pgID.Bytes[:], parsed[:]) + return pgID, nil +} diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go index 0a49db70..94a0dcb5 100644 --- a/internal/handlers/fs.go +++ b/internal/handlers/fs.go @@ -17,9 +17,7 @@ import ( "github.com/containerd/errdefs" "github.com/labstack/echo/v4" - "github.com/memohai/memoh/internal/auth" ctr "github.com/memohai/memoh/internal/containerd" - "github.com/memohai/memoh/internal/identity" mcptools "github.com/memohai/memoh/internal/mcp" ) @@ -48,6 +46,10 @@ import ( // @Failure 500 {object} ErrorResponse // @Router /container/fs/{id} [post] func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error { + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } containerID := strings.TrimSpace(c.Param("id")) if containerID == "" { return echo.NewHTTPError(http.StatusBadRequest, "container id is required") @@ -65,11 +67,7 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error { }) } - userID, err := h.requireUserID(c) - if err != nil { - return err - } - if err := h.validateMCPContainer(c.Request().Context(), containerID, userID); err != nil { + if err := h.validateMCPContainer(c.Request().Context(), containerID, botID); err != nil { return err } if err := h.ensureTaskRunning(c.Request().Context(), containerID); err != nil { @@ -98,9 +96,9 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error { } } -func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, userID string) error { - if strings.TrimSpace(userID) == "" { - return echo.NewHTTPError(http.StatusUnauthorized, "invalid token") +func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, botID string) error { + if strings.TrimSpace(botID) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } container, err := h.service.GetContainer(ctx, containerID) if err != nil { @@ -118,24 +116,13 @@ func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerI if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - labelUserID := strings.TrimSpace(info.Labels[mcptools.BotLabelKey]) - if labelUserID != "" && labelUserID != userID { + labelBotID := strings.TrimSpace(info.Labels[mcptools.BotLabelKey]) + if labelBotID != "" && labelBotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } return nil } -func (h *ContainerdHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateUserID(userID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return userID, nil -} - func (h *ContainerdHandler) callMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) (map[string]any, error) { session, err := h.getMCPSession(ctx, containerID) if err != nil { diff --git a/internal/handlers/history.go b/internal/handlers/history.go index 7c7f905b..5d72c145 100644 --- a/internal/handlers/history.go +++ b/internal/handlers/history.go @@ -251,4 +251,4 @@ func (h *HistoryHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} +} \ No newline at end of file diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index 98e13a29..204761be 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/google/uuid" "github.com/labstack/echo/v4" @@ -68,7 +69,7 @@ func (h *LocalChannelHandler) CreateSession(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, "channel service not configured") } sessionID := fmt.Sprintf("%s:%s", h.channelType.String(), uuid.NewString()) - if err := h.channelService.UpsertChannelSession(c.Request().Context(), sessionID, botID, "", userID, "", h.channelType.String()); err != nil { + if err := h.channelService.UpsertChannelSession(c.Request().Context(), sessionID, botID, "", userID, "", h.channelType.String(), sessionID, "", nil); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } streamURL := fmt.Sprintf("/bots/%s/%s/sessions/%s/stream", botID, h.channelType.String(), sessionID) @@ -121,8 +122,8 @@ func (h *LocalChannelHandler) StreamSession(c echo.Context) error { return nil } payload := map[string]any{ - "text": msg.Text, - "to": msg.To, + "target": msg.Target, + "message": msg.Message, } data, err := json.Marshal(payload) if err != nil { @@ -136,8 +137,7 @@ func (h *LocalChannelHandler) StreamSession(c echo.Context) error { } type localMessageRequest struct { - Text string `json:"text"` - Message string `json:"message"` + Message channel.Message `json:"message"` } func (h *LocalChannelHandler) PostMessage(c echo.Context) error { @@ -166,26 +166,32 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - text := strings.TrimSpace(req.Text) + text := strings.TrimSpace(req.Message.PlainText()) if text == "" { - text = strings.TrimSpace(req.Message) - } - if text == "" { - return echo.NewHTTPError(http.StatusBadRequest, "text is required") + return echo.NewHTTPError(http.StatusBadRequest, "message is required") } cfg, err := h.channelService.ResolveEffectiveConfig(c.Request().Context(), botID, h.channelType) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } msg := channel.InboundMessage{ - Channel: h.channelType, - Text: text, - ChatID: sessionID, - ChatType: "p2p", - ReplyTo: sessionID, - BotID: botID, - UserID: userID, - SessionKey: sessionID, + Channel: h.channelType, + Message: req.Message, + BotID: botID, + ReplyTarget: sessionID, + SessionKey: sessionID, + Sender: channel.Identity{ + ExternalID: userID, + Attributes: map[string]string{ + "user_id": userID, + }, + }, + Conversation: channel.Conversation{ + ID: sessionID, + Type: "p2p", + }, + ReceivedAt: time.Now().UTC(), + Source: "local", } if err := h.channelManager.HandleInbound(c.Request().Context(), cfg, msg); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index 35643219..ffa731cc 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -25,33 +25,33 @@ type MemoryHandler struct { } type memoryAddPayload struct { - Message string `json:"message,omitempty"` - Messages []memory.Message `json:"messages,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` - Infer *bool `json:"infer,omitempty"` - EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` + Message string `json:"message,omitempty"` + Messages []memory.Message `json:"messages,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` + Infer *bool `json:"infer,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type memorySearchPayload struct { - Query string `json:"query"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` - Sources []string `json:"sources,omitempty"` - EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` + Query string `json:"query"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` + Filters map[string]any `json:"filters,omitempty"` + Sources []string `json:"sources,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type memoryEmbedUpsertPayload struct { - Type string `json:"type"` - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - Input memory.EmbedInput `json:"input"` - Source string `json:"source,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` + Type string `json:"type"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Input memory.EmbedInput `json:"input"` + Source string `json:"source,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type memoryDeleteAllPayload struct { diff --git a/internal/handlers/preauth.go b/internal/handlers/preauth.go new file mode 100644 index 00000000..4b0c965b --- /dev/null +++ b/internal/handlers/preauth.go @@ -0,0 +1,99 @@ +package handlers + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/auth" + "github.com/memohai/memoh/internal/bots" + "github.com/memohai/memoh/internal/identity" + "github.com/memohai/memoh/internal/preauth" + "github.com/memohai/memoh/internal/users" +) + +type PreauthHandler struct { + service *preauth.Service + botService *bots.Service + userService *users.Service +} + +func NewPreauthHandler(service *preauth.Service, botService *bots.Service, userService *users.Service) *PreauthHandler { + return &PreauthHandler{ + service: service, + botService: botService, + userService: userService, + } +} + +func (h *PreauthHandler) Register(e *echo.Echo) { + group := e.Group("/bots/:bot_id/preauth_keys") + group.POST("", h.Issue) +} + +type preauthIssueRequest struct { + TTLSeconds int `json:"ttl_seconds"` +} + +func (h *PreauthHandler) Issue(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + botID := strings.TrimSpace(c.Param("bot_id")) + if botID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") + } + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { + return err + } + var req preauthIssueRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + ttl := 24 * time.Hour + if req.TTLSeconds > 0 { + ttl = time.Duration(req.TTLSeconds) * time.Second + } + key, err := h.service.Issue(c.Request().Context(), botID, userID, ttl) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, key) +} + +func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) { + userID, err := auth.UserIDFromContext(c) + if err != nil { + return "", err + } + if err := identity.ValidateUserID(userID); err != nil { + return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return userID, nil +} + +func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, actorID, botID string) (bots.Bot, error) { + if h.botService == nil || h.userService == nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") + } + isAdmin, err := h.userService.IsAdmin(ctx, actorID) + if err != nil { + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + bot, err := h.botService.AuthorizeAccess(ctx, actorID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + if err != nil { + if errors.Is(err, bots.ErrBotNotFound) { + return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") + } + if errors.Is(err, bots.ErrBotAccessDenied) { + return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied") + } + return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return bot, nil +} diff --git a/internal/handlers/schedule.go b/internal/handlers/schedule.go index a07c4a87..fa56f520 100644 --- a/internal/handlers/schedule.go +++ b/internal/handlers/schedule.go @@ -248,4 +248,4 @@ func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} +} \ No newline at end of file diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index 711cd9aa..ad902319 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -156,4 +156,4 @@ func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} +} \ No newline at end of file diff --git a/internal/handlers/skills.go b/internal/handlers/skills.go index 6f321bcb..1c65f65c 100644 --- a/internal/handlers/skills.go +++ b/internal/handlers/skills.go @@ -47,19 +47,19 @@ type skillsOpResponse struct { // @Failure 500 {object} ErrorResponse // @Router /container/skills [get] func (h *ContainerdHandler) ListSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + botID, err := h.requireBotAccess(c) if err != nil { return err } ctx := c.Request().Context() - containerID, err := h.userContainerID(ctx, userID) + containerID, err := h.botContainerID(ctx, botID) if err != nil { return err } if err := h.ensureTaskRunning(ctx, containerID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - if err := h.ensureSkillsDirHost(userID); err != nil { + if err := h.ensureSkillsDirHost(botID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -105,7 +105,7 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /container/skills [post] func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + botID, err := h.requireBotAccess(c) if err != nil { return err } @@ -118,7 +118,7 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { } ctx := c.Request().Context() - containerID, err := h.userContainerID(ctx, userID) + containerID, err := h.botContainerID(ctx, botID) if err != nil { return err } @@ -156,7 +156,7 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /container/skills [delete] func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { - userID, err := h.requireUserID(c) + botID, err := h.requireBotAccess(c) if err != nil { return err } @@ -169,7 +169,7 @@ func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { } ctx := c.Request().Context() - containerID, err := h.userContainerID(ctx, userID) + containerID, err := h.botContainerID(ctx, botID) if err != nil { return err } @@ -193,12 +193,12 @@ func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { return c.JSON(http.StatusOK, skillsOpResponse{OK: true}) } -func (h *ContainerdHandler) ensureSkillsDirHost(userID string) error { +func (h *ContainerdHandler) ensureSkillsDirHost(botID string) error { dataRoot := strings.TrimSpace(h.cfg.DataRoot) if dataRoot == "" { dataRoot = config.DefaultDataRoot } - skillsDir := path.Join(dataRoot, "bots", userID, ".skills") + skillsDir := path.Join(dataRoot, "bots", botID, ".skills") return os.MkdirAll(skillsDir, 0o755) } diff --git a/internal/handlers/subagent.go b/internal/handlers/subagent.go index e8f70376..67f95467 100644 --- a/internal/handlers/subagent.go +++ b/internal/handlers/subagent.go @@ -462,4 +462,4 @@ func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, actorID, botID return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return bot, nil -} +} \ No newline at end of file diff --git a/internal/handlers/users.go b/internal/handlers/users.go index d7dc8af7..c007b1ca 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -715,7 +715,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } if req.Credentials == nil { - req.Credentials = map[string]interface{}{} + req.Credentials = map[string]any{} } resp, err := h.channelService.UpsertConfig(c.Request().Context(), botID, channelType, req) if err != nil { @@ -760,7 +760,7 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if strings.TrimSpace(req.Message) == "" { + if req.Message.IsEmpty() { return echo.NewHTTPError(http.StatusBadRequest, "message is required") } if err := h.channelManager.Send(c.Request().Context(), botID, channelType, req); err != nil { @@ -805,14 +805,14 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if strings.TrimSpace(req.Message) == "" { + if req.Message.IsEmpty() { return echo.NewHTTPError(http.StatusBadRequest, "message is required") } if strings.TrimSpace(sessionToken.ReplyTarget) == "" { return echo.NewHTTPError(http.StatusBadRequest, "reply target missing") } if err := h.channelManager.Send(c.Request().Context(), botID, channelType, channel.SendRequest{ - To: sessionToken.ReplyTarget, + Target: sessionToken.ReplyTarget, Message: req.Message, }); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) diff --git a/internal/history/service.go b/internal/history/service.go index 8d68d447..a407557c 100644 --- a/internal/history/service.go +++ b/internal/history/service.go @@ -46,10 +46,19 @@ func (s *Service) Create(ctx context.Context, botID, sessionID string, req Creat if err != nil { return Record{}, err } + meta := req.Metadata + if meta == nil { + meta = map[string]any{} + } + metaPayload, err := json.Marshal(meta) + if err != nil { + return Record{}, err + } row, err := s.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{ BotID: botUUID, SessionID: trimmedSession, Messages: payload, + Metadata: metaPayload, Skills: normalizeSkills(req.Skills), Timestamp: pgtype.Timestamptz{ Time: time.Now().UTC(), @@ -163,14 +172,21 @@ func (s *Service) DeleteBySession(ctx context.Context, botID, sessionID string) } func toRecord(row sqlc.History) (Record, error) { - var messages []map[string]interface{} + var messages []map[string]any if len(row.Messages) > 0 { if err := json.Unmarshal(row.Messages, &messages); err != nil { return Record{}, err } } + var metadata map[string]any + if len(row.Metadata) > 0 { + if err := json.Unmarshal(row.Metadata, &metadata); err != nil { + return Record{}, err + } + } record := Record{ Messages: messages, + Metadata: metadata, Skills: normalizeSkills(row.Skills), } if row.Timestamp.Valid { diff --git a/internal/history/types.go b/internal/history/types.go index b1c6c347..088d8cfb 100644 --- a/internal/history/types.go +++ b/internal/history/types.go @@ -3,17 +3,19 @@ package history import "time" type Record struct { - ID string `json:"id"` - Messages []map[string]interface{} `json:"messages"` - Skills []string `json:"skills"` - Timestamp time.Time `json:"timestamp"` - BotID string `json:"bot_id"` - SessionID string `json:"session_id"` + ID string `json:"id"` + Messages []map[string]any `json:"messages"` + Metadata map[string]any `json:"metadata,omitempty"` + Skills []string `json:"skills"` + Timestamp time.Time `json:"timestamp"` + BotID string `json:"bot_id"` + SessionID string `json:"session_id"` } type CreateRequest struct { - Messages []map[string]interface{} `json:"messages"` - Skills []string `json:"skills,omitempty"` + Messages []map[string]any `json:"messages"` + Metadata map[string]any `json:"metadata,omitempty"` + Skills []string `json:"skills,omitempty"` } type ListResponse struct { diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index 59f0f08a..06e7a925 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -77,12 +77,13 @@ func (m *Manager) Init(ctx context.Context) error { return err } -func (m *Manager) EnsureUser(ctx context.Context, userID string) error { - if err := validateUserID(userID); err != nil { +// EnsureBot creates the MCP container for a bot if it does not exist. +func (m *Manager) EnsureBot(ctx context.Context, botID string) error { + if err := validateBotID(botID); err != nil { return err } - dataDir, err := m.ensureUserDir(userID) + dataDir, err := m.ensureBotDir(botID) if err != nil { return err } @@ -115,11 +116,11 @@ func (m *Manager) EnsureUser(ctx context.Context, userID string) error { } _, err = m.service.CreateContainer(ctx, ctr.CreateContainerRequest{ - ID: m.containerID(userID), + ID: m.containerID(botID), ImageRef: image, Snapshotter: m.cfg.Snapshotter, Labels: map[string]string{ - BotLabelKey: userID, + BotLabelKey: botID, }, SpecOpts: specOpts, }) @@ -134,13 +135,14 @@ func (m *Manager) EnsureUser(ctx context.Context, userID string) error { return nil } -func (m *Manager) ListUsers(ctx context.Context) ([]string, error) { +// ListBots returns the bot IDs that have MCP containers. +func (m *Manager) ListBots(ctx context.Context) ([]string, error) { containers, err := m.service.ListContainers(ctx) if err != nil { return nil, err } - users := make([]string, 0, len(containers)) + botIDs := make([]string, 0, len(containers)) for _, container := range containers { info, err := container.Info(ctx) if err != nil { @@ -148,47 +150,47 @@ func (m *Manager) ListUsers(ctx context.Context) ([]string, error) { } if strings.HasPrefix(info.ID, ContainerPrefix) { if botID, ok := info.Labels[BotLabelKey]; ok { - users = append(users, botID) + botIDs = append(botIDs, botID) } } } - return users, nil + return botIDs, nil } -func (m *Manager) Start(ctx context.Context, userID string) error { - if err := m.EnsureUser(ctx, userID); err != nil { +func (m *Manager) Start(ctx context.Context, botID string) error { + if err := m.EnsureBot(ctx, botID); err != nil { return err } - _, err := m.service.StartTask(ctx, m.containerID(userID), &ctr.StartTaskOptions{ + _, err := m.service.StartTask(ctx, m.containerID(botID), &ctr.StartTaskOptions{ UseStdio: false, }) return err } -func (m *Manager) Stop(ctx context.Context, userID string, timeout time.Duration) error { - if err := validateUserID(userID); err != nil { +func (m *Manager) Stop(ctx context.Context, botID string, timeout time.Duration) error { + if err := validateBotID(botID); err != nil { return err } - return m.service.StopTask(ctx, m.containerID(userID), &ctr.StopTaskOptions{ + return m.service.StopTask(ctx, m.containerID(botID), &ctr.StopTaskOptions{ Timeout: timeout, Force: true, }) } -func (m *Manager) Delete(ctx context.Context, userID string) error { - if err := validateUserID(userID); err != nil { +func (m *Manager) Delete(ctx context.Context, botID string) error { + if err := validateBotID(botID); err != nil { return err } - _ = m.service.DeleteTask(ctx, m.containerID(userID), &ctr.DeleteTaskOptions{Force: true}) - return m.service.DeleteContainer(ctx, m.containerID(userID), &ctr.DeleteContainerOptions{ + _ = m.service.DeleteTask(ctx, m.containerID(botID), &ctr.DeleteTaskOptions{Force: true}) + return m.service.DeleteContainer(ctx, m.containerID(botID), &ctr.DeleteContainerOptions{ CleanupSnapshot: true, }) } func (m *Manager) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) { - if err := validateUserID(req.BotID); err != nil { + if err := validateBotID(req.BotID); err != nil { return nil, err } if len(req.Command) == 0 { @@ -227,8 +229,9 @@ func (m *Manager) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error return &ExecResult{ExitCode: result.ExitCode}, nil } -func (m *Manager) DataDir(userID string) (string, error) { - if err := validateUserID(userID); err != nil { +// DataDir returns the host data directory for a bot. +func (m *Manager) DataDir(botID string) (string, error) { + if err := validateBotID(botID); err != nil { return "", err } @@ -236,21 +239,21 @@ func (m *Manager) DataDir(userID string) (string, error) { if root == "" { root = config.DefaultDataRoot } - return filepath.Join(root, "bots", userID), nil + return filepath.Join(root, "bots", botID), nil } -func (m *Manager) ensureUserDir(userID string) (string, error) { +func (m *Manager) ensureBotDir(botID string) (string, error) { root := m.cfg.DataRoot if root == "" { root = config.DefaultDataRoot } - dir := filepath.Join(root, "bots", userID) + dir := filepath.Join(root, "bots", botID) if err := os.MkdirAll(dir, 0o755); err != nil { return "", err } return dir, nil } -func validateUserID(userID string) error { - return identity.ValidateUserID(userID) +func validateBotID(botID string) error { + return identity.ValidateUserID(botID) } diff --git a/internal/mcp/versioning.go b/internal/mcp/versioning.go index 4d2ce99d..a4f40ca5 100644 --- a/internal/mcp/versioning.go +++ b/internal/mcp/versioning.go @@ -30,7 +30,7 @@ func (m *Manager) CreateVersion(ctx context.Context, userID string) (*VersionInf if m.db == nil || m.queries == nil { return nil, fmt.Errorf("db is not configured") } - if err := validateUserID(userID); err != nil { + if err := validateBotID(userID); err != nil { return nil, err } @@ -67,7 +67,7 @@ func (m *Manager) CreateVersion(ctx context.Context, userID string) (*VersionInf return nil, err } - dataDir, err := m.ensureUserDir(userID) + dataDir, err := m.ensureBotDir(userID) if err != nil { return nil, err } @@ -129,7 +129,7 @@ func (m *Manager) ListVersions(ctx context.Context, userID string) ([]VersionInf if m.db == nil || m.queries == nil { return nil, fmt.Errorf("db is not configured") } - if err := validateUserID(userID); err != nil { + if err := validateBotID(userID); err != nil { return nil, err } @@ -159,7 +159,7 @@ func (m *Manager) RollbackVersion(ctx context.Context, userID string, version in if m.db == nil || m.queries == nil { return fmt.Errorf("db is not configured") } - if err := validateUserID(userID); err != nil { + if err := validateBotID(userID); err != nil { return err } @@ -194,7 +194,7 @@ func (m *Manager) RollbackVersion(ctx context.Context, userID string, version in return err } - dataDir, err := m.ensureUserDir(userID) + dataDir, err := m.ensureBotDir(userID) if err != nil { return err } @@ -241,7 +241,7 @@ func (m *Manager) VersionSnapshotID(ctx context.Context, userID string, version if m.db == nil || m.queries == nil { return "", fmt.Errorf("db is not configured") } - if err := validateUserID(userID); err != nil { + if err := validateBotID(userID); err != nil { return "", err } diff --git a/internal/memory/llm_client.go b/internal/memory/llm_client.go index cf04af58..2eb76d2f 100644 --- a/internal/memory/llm_client.go +++ b/internal/memory/llm_client.go @@ -86,15 +86,15 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon } cleaned := removeCodeBlocks(content) - var memoryItems []map[string]interface{} + var memoryItems []map[string]any // Try parsing as object first - var raw map[string]interface{} + var raw map[string]any if err := json.Unmarshal([]byte(cleaned), &raw); err == nil { memoryItems = normalizeMemoryItems(raw["memory"]) } else { // If object parsing fails, try parsing as array directly - var arr []interface{} + var arr []any if err := json.Unmarshal([]byte(cleaned), &arr); err != nil { return DecideResponse{}, fmt.Errorf("failed to parse LLM response: %w", err) } @@ -222,7 +222,7 @@ func formatMessages(messages []Message) []string { return formatted } -func asString(value interface{}) string { +func asString(value any) string { switch typed := value.(type) { case string: return typed @@ -240,7 +240,7 @@ func asString(value interface{}) string { } } -func normalizeID(value interface{}) string { +func normalizeID(value any) string { id := asString(value) if id == "" { return "" @@ -248,31 +248,31 @@ func normalizeID(value interface{}) string { return id } -func normalizeMemoryItems(value interface{}) []map[string]interface{} { +func normalizeMemoryItems(value any) []map[string]any { switch typed := value.(type) { - case []interface{}: - items := make([]map[string]interface{}, 0, len(typed)) + case []any: + items := make([]map[string]any, 0, len(typed)) for _, item := range typed { - if m, ok := item.(map[string]interface{}); ok { + if m, ok := item.(map[string]any); ok { items = append(items, m) } } return items - case map[string]interface{}: + case map[string]any: // If this map looks like a single item, wrap it. if _, hasText := typed["text"]; hasText { - return []map[string]interface{}{typed} + return []map[string]any{typed} } if _, hasFact := typed["fact"]; hasFact { - return []map[string]interface{}{typed} + return []map[string]any{typed} } if _, hasEvent := typed["event"]; hasEvent { - return []map[string]interface{}{typed} + return []map[string]any{typed} } // Otherwise treat as map of items. - items := make([]map[string]interface{}, 0, len(typed)) + items := make([]map[string]any, 0, len(typed)) for _, item := range typed { - if m, ok := item.(map[string]interface{}); ok { + if m, ok := item.(map[string]any); ok { items = append(items, m) } } diff --git a/internal/memory/prompts.go b/internal/memory/prompts.go index b312e673..d3240ddd 100644 --- a/internal/memory/prompts.go +++ b/internal/memory/prompts.go @@ -128,7 +128,7 @@ func removeCodeBlocks(text string) string { return strings.ReplaceAll(strings.ReplaceAll(text, "```json", ""), "```", "") } -func toJSON(value interface{}) string { +func toJSON(value any) string { data, err := json.Marshal(value) if err != nil { return "[]" diff --git a/internal/memory/qdrant_store.go b/internal/memory/qdrant_store.go index ab856ee2..f9c3d376 100644 --- a/internal/memory/qdrant_store.go +++ b/internal/memory/qdrant_store.go @@ -38,7 +38,7 @@ type qdrantPoint struct { SparseIndices []uint32 `json:"sparse_indices,omitempty"` SparseValues []float32 `json:"sparse_values,omitempty"` SparseVectorName string `json:"sparse_vector_name,omitempty"` - Payload map[string]interface{} `json:"payload,omitempty"` + Payload map[string]any `json:"payload,omitempty"` } func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) { @@ -189,7 +189,7 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { return err } -func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, vectorName string) ([]qdrantPoint, []float64, error) { +func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, filters map[string]any, vectorName string) ([]qdrantPoint, []float64, error) { if limit <= 0 { limit = 10 } @@ -222,7 +222,7 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f return points, scores, nil } -func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}) ([]qdrantPoint, []float64, error) { +func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]any) ([]qdrantPoint, []float64, error) { if limit <= 0 { limit = 10 } @@ -257,7 +257,7 @@ func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values return points, scores, nil } -func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) { +func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]any, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) { pointsBySource := make(map[string][]qdrantPoint, len(sources)) scoresBySource := make(map[string][]float64, len(sources)) if len(sources) == 0 { @@ -278,7 +278,7 @@ func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, lim return pointsBySource, scoresBySource, nil } -func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}, sources []string) (map[string][]qdrantPoint, map[string][]float64, error) { +func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]any, sources []string) (map[string][]qdrantPoint, map[string][]float64, error) { pointsBySource := make(map[string][]qdrantPoint, len(sources)) scoresBySource := make(map[string][]float64, len(sources)) if len(sources) == 0 { @@ -327,7 +327,7 @@ func (s *QdrantStore) Delete(ctx context.Context, id string) error { return err } -func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]interface{}) ([]qdrantPoint, error) { +func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]any) ([]qdrantPoint, error) { if limit <= 0 { limit = 100 } @@ -352,7 +352,7 @@ func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]in return result, nil } -func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]interface{}, offset *qdrant.PointId) ([]qdrantPoint, *qdrant.PointId, error) { +func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]any, offset *qdrant.PointId) ([]qdrantPoint, *qdrant.PointId, error) { if limit <= 0 { limit = 100 } @@ -377,7 +377,7 @@ func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string] return result, nextOffset, nil } -func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]interface{}) error { +func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]any) error { filter := buildQdrantFilter(filters) if filter == nil { return fmt.Errorf("delete all requires filters") @@ -542,7 +542,7 @@ func timeoutOrDefault(timeout time.Duration) time.Duration { return timeout } -func buildQdrantFilter(filters map[string]interface{}) *qdrant.Filter { +func buildQdrantFilter(filters map[string]any) *qdrant.Filter { if len(filters) == 0 { return nil } @@ -560,18 +560,18 @@ func buildQdrantFilter(filters map[string]interface{}) *qdrant.Filter { } } -func cloneFilters(filters map[string]interface{}) map[string]interface{} { +func cloneFilters(filters map[string]any) map[string]any { if len(filters) == 0 { - return map[string]interface{}{} + return map[string]any{} } - clone := make(map[string]interface{}, len(filters)) + clone := make(map[string]any, len(filters)) for key, value := range filters { clone[key] = value } return clone } -func buildQdrantCondition(key string, value interface{}) *qdrant.Condition { +func buildQdrantCondition(key string, value any) *qdrant.Condition { switch typed := value.(type) { case string: return qdrant.NewMatch(key, typed) @@ -586,7 +586,7 @@ func buildQdrantCondition(key string, value interface{}) *qdrant.Condition { return qdrant.NewRange(key, &qdrant.Range{Gte: &v, Lte: &v}) case float64: return qdrant.NewRange(key, &qdrant.Range{Gte: &typed, Lte: &typed}) - case map[string]interface{}: + case map[string]any: rangeValue := &qdrant.Range{} for _, op := range []string{"gte", "gt", "lte", "lt"} { if raw, ok := typed[op]; ok { @@ -613,7 +613,7 @@ func buildQdrantCondition(key string, value interface{}) *qdrant.Condition { return qdrant.NewMatch(key, fmt.Sprint(value)) } -func toFloat(value interface{}) (float64, bool) { +func toFloat(value any) (float64, bool) { switch typed := value.(type) { case float32: return float64(typed), true @@ -641,15 +641,15 @@ func pointIDToString(id *qdrant.PointId) string { return "" } -func valueMapToInterface(values map[string]*qdrant.Value) map[string]interface{} { - result := make(map[string]interface{}, len(values)) +func valueMapToInterface(values map[string]*qdrant.Value) map[string]any { + result := make(map[string]any, len(values)) for key, value := range values { result[key] = valueToInterface(value) } return result } -func valueToInterface(value *qdrant.Value) interface{} { +func valueToInterface(value *qdrant.Value) any { if value == nil { return nil } @@ -667,7 +667,7 @@ func valueToInterface(value *qdrant.Value) interface{} { case *qdrant.Value_StructValue: return valueMapToInterface(kind.StructValue.GetFields()) case *qdrant.Value_ListValue: - items := make([]interface{}, 0, len(kind.ListValue.GetValues())) + items := make([]any, 0, len(kind.ListValue.GetValues())) for _, item := range kind.ListValue.GetValues() { items = append(items, valueToInterface(item)) } diff --git a/internal/memory/qdrant_store_test.go b/internal/memory/qdrant_store_test.go index 80a2726a..7753bf25 100644 --- a/internal/memory/qdrant_store_test.go +++ b/internal/memory/qdrant_store_test.go @@ -5,9 +5,9 @@ import "testing" func TestBuildQdrantFilter(t *testing.T) { t.Parallel() - filter := buildQdrantFilter(map[string]interface{}{ + filter := buildQdrantFilter(map[string]any{ "userId": "u1", - "score": map[string]interface{}{"gte": 0.5}, + "score": map[string]any{"gte": 0.5}, }) if filter == nil { t.Fatalf("expected filter") diff --git a/internal/memory/service.go b/internal/memory/service.go index 6e13a08b..343ee1b3 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -103,7 +103,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro if err != nil { return SearchResponse{}, err } - item.Metadata = mergeMetadata(item.Metadata, map[string]interface{}{ + item.Metadata = mergeMetadata(item.Metadata, map[string]any{ "event": "ADD", }) results = append(results, item) @@ -112,7 +112,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro if err != nil { return SearchResponse{}, err } - item.Metadata = mergeMetadata(item.Metadata, map[string]interface{}{ + item.Metadata = mergeMetadata(item.Metadata, map[string]any{ "event": "UPDATE", "previous_memory": action.OldMemory, }) @@ -122,7 +122,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro if err != nil { return SearchResponse{}, err } - item.Metadata = mergeMetadata(item.Metadata, map[string]interface{}{ + item.Metadata = mergeMetadata(item.Metadata, map[string]any{ "event": "DELETE", }) results = append(results, item) @@ -294,7 +294,7 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe id := uuid.NewString() filters := buildEmbedFilters(req) payload := buildEmbeddingPayload(req, filters) - if metadata, ok := payload["metadata"].(map[string]interface{}); ok && result.Model != "" { + if metadata, ok := payload["metadata"].(map[string]any); ok && result.Model != "" { metadata["model_id"] = result.Model } if err := s.store.Upsert(ctx, []qdrantPoint{{ @@ -408,7 +408,7 @@ func (s *Service) Get(ctx context.Context, memoryID string) (MemoryItem, error) } func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error) { - filters := map[string]interface{}{} + filters := map[string]any{} if req.BotID != "" { filters["botId"] = req.BotID } @@ -444,7 +444,7 @@ func (s *Service) Delete(ctx context.Context, memoryID string) (DeleteResponse, } func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error) { - filters := map[string]interface{}{} + filters := map[string]any{} if req.BotID != "" { filters["botId"] = req.BotID } @@ -499,14 +499,14 @@ func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error { return nil } -func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (SearchResponse, error) { +func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]any, metadata map[string]any, embeddingEnabled bool) (SearchResponse, error) { results := make([]MemoryItem, 0, len(messages)) for _, message := range messages { item, err := s.applyAdd(ctx, message.Content, filters, metadata, embeddingEnabled) if err != nil { return SearchResponse{}, err } - item.Metadata = mergeMetadata(item.Metadata, map[string]interface{}{ + item.Metadata = mergeMetadata(item.Metadata, map[string]any{ "event": "ADD", }) results = append(results, item) @@ -514,7 +514,7 @@ func (s *Service) addRawMessages(ctx context.Context, messages []Message, filter return SearchResponse{Results: results}, nil } -func (s *Service) collectCandidates(ctx context.Context, facts []string, filters map[string]interface{}) ([]CandidateMemory, error) { +func (s *Service) collectCandidates(ctx context.Context, facts []string, filters map[string]any) ([]CandidateMemory, error) { unique := map[string]CandidateMemory{} for _, fact := range facts { if s.bm25 == nil { @@ -550,7 +550,7 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters return candidates, nil } -func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) { +func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]any, metadata map[string]any, embeddingEnabled bool) (MemoryItem, error) { if s.store == nil { return MemoryItem{}, fmt.Errorf("qdrant store not configured") } @@ -593,7 +593,7 @@ func (s *Service) applyAdd(ctx context.Context, text string, filters map[string] return payloadToMemoryItem(id, payload), nil } -func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) { +func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]any, metadata map[string]any, embeddingEnabled bool) (MemoryItem, error) { if strings.TrimSpace(id) == "" { return MemoryItem{}, fmt.Errorf("update action missing id") } @@ -748,8 +748,8 @@ func isCJKRune(r rune) bool { return false } -func buildFilters(req AddRequest) map[string]interface{} { - filters := map[string]interface{}{} +func buildFilters(req AddRequest) map[string]any { + filters := map[string]any{} for key, value := range req.Filters { filters[key] = value } @@ -765,8 +765,8 @@ func buildFilters(req AddRequest) map[string]interface{} { return filters } -func buildSearchFilters(req SearchRequest) map[string]interface{} { - filters := map[string]interface{}{} +func buildSearchFilters(req SearchRequest) map[string]any { + filters := map[string]any{} for key, value := range req.Filters { filters[key] = value } @@ -782,8 +782,8 @@ func buildSearchFilters(req SearchRequest) map[string]interface{} { return filters } -func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} { - filters := map[string]interface{}{} +func buildEmbedFilters(req EmbedUpsertRequest) map[string]any { + filters := map[string]any{} for key, value := range req.Filters { filters[key] = value } @@ -799,7 +799,7 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} { return filters } -func buildEmbeddingPayload(req EmbedUpsertRequest, filters map[string]interface{}) map[string]interface{} { +func buildEmbeddingPayload(req EmbedUpsertRequest, filters map[string]any) map[string]any { text := req.Input.Text payload := buildPayload(text, filters, req.Metadata, "") payload["hash"] = hashEmbeddingInput(req.Input.Text, req.Input.ImageURL, req.Input.VideoURL) @@ -813,9 +813,9 @@ func buildEmbeddingPayload(req EmbedUpsertRequest, filters map[string]interface{ payload["modality"] = modality if payload["metadata"] == nil { - payload["metadata"] = map[string]interface{}{} + payload["metadata"] = map[string]any{} } - if metadata, ok := payload["metadata"].(map[string]interface{}); ok { + if metadata, ok := payload["metadata"].(map[string]any); ok { if req.Source != "" { metadata["source"] = req.Source } @@ -844,11 +844,11 @@ func (s *Service) vectorNameForMultimodal() string { return strings.TrimSpace(s.defaultMultimodalModelID) } -func buildPayload(text string, filters map[string]interface{}, metadata map[string]interface{}, createdAt string) map[string]interface{} { +func buildPayload(text string, filters map[string]any, metadata map[string]any, createdAt string) map[string]any { if createdAt == "" { createdAt = time.Now().UTC().Format(time.RFC3339) } - payload := map[string]interface{}{ + payload := map[string]any{ "data": text, "hash": hashMemory(text), "createdAt": createdAt, @@ -860,13 +860,13 @@ func buildPayload(text string, filters map[string]interface{}, metadata map[stri return payload } -func applyFiltersToPayload(payload map[string]interface{}, filters map[string]interface{}) { +func applyFiltersToPayload(payload map[string]any, filters map[string]any) { for key, value := range filters { payload[key] = value } } -func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem { +func payloadToMemoryItem(id string, payload map[string]any) MemoryItem { item := MemoryItem{ ID: id, Memory: fmt.Sprint(payload["data"]), @@ -889,10 +889,10 @@ func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem { if v, ok := payload["runId"].(string); ok { item.RunID = v } - if meta, ok := payload["metadata"].(map[string]interface{}); ok { + if meta, ok := payload["metadata"].(map[string]any); ok { item.Metadata = meta } else if payload["metadata"] == nil { - item.Metadata = map[string]interface{}{} + item.Metadata = map[string]any{} } if item.Metadata != nil { if source, ok := payload["source"].(string); ok && source != "" { @@ -920,9 +920,9 @@ func hashEmbeddingInput(text, imageURL, videoURL string) string { return hex.EncodeToString(sum[:]) } -func mergeMetadata(base interface{}, extra map[string]interface{}) map[string]interface{} { - merged := map[string]interface{}{} - if baseMap, ok := base.(map[string]interface{}); ok { +func mergeMetadata(base any, extra map[string]any) map[string]any { + merged := map[string]any{} + if baseMap, ok := base.(map[string]any); ok { for k, v := range baseMap { merged[k] = v } @@ -935,7 +935,7 @@ func mergeMetadata(base interface{}, extra map[string]interface{}) map[string]in type rerankCandidate struct { ID string - Payload map[string]interface{} + Payload map[string]any Score float64 Source string Rank int diff --git a/internal/memory/service_test.go b/internal/memory/service_test.go index 05027951..92db7d63 100644 --- a/internal/memory/service_test.go +++ b/internal/memory/service_test.go @@ -117,8 +117,8 @@ func TestRankFusion_Logic(t *testing.T) { // 测试 RRF (Reciprocal Rank Fusion) 逻辑 // 验证不同来源的结果是否能被正确合并和排序 - p1 := qdrantPoint{ID: "1", Payload: map[string]interface{}{"data": "result 1"}} - p2 := qdrantPoint{ID: "2", Payload: map[string]interface{}{"data": "result 2"}} + p1 := qdrantPoint{ID: "1", Payload: map[string]any{"data": "result 1"}} + p2 := qdrantPoint{ID: "2", Payload: map[string]any{"data": "result 2"}} // 来源 A: 1 号排第一,2 号排第二 // 来源 B: 2 号排第一,1 号排第二 diff --git a/internal/memory/types.go b/internal/memory/types.go index 1b7f8477..22299457 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -21,8 +21,8 @@ type AddRequest struct { SessionID string `json:"session_id,omitempty"` AgentID string `json:"agent_id,omitempty"` RunID string `json:"run_id,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` Infer *bool `json:"infer,omitempty"` EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } @@ -34,7 +34,7 @@ type SearchRequest struct { AgentID string `json:"agent_id,omitempty"` RunID string `json:"run_id,omitempty"` Limit int `json:"limit,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` + Filters map[string]any `json:"filters,omitempty"` Sources []string `json:"sources,omitempty"` EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } @@ -76,8 +76,8 @@ type EmbedUpsertRequest struct { SessionID string `json:"session_id,omitempty"` AgentID string `json:"agent_id,omitempty"` RunID string `json:"run_id,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` } type EmbedUpsertResponse struct { @@ -94,7 +94,7 @@ type MemoryItem struct { CreatedAt string `json:"createdAt,omitempty"` UpdatedAt string `json:"updatedAt,omitempty"` Score float64 `json:"score,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` BotID string `json:"botId,omitempty"` SessionID string `json:"sessionId,omitempty"` AgentID string `json:"agentId,omitempty"` @@ -112,8 +112,8 @@ type DeleteResponse struct { type ExtractRequest struct { Messages []Message `json:"messages"` - Filters map[string]interface{} `json:"filters,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type ExtractResponse struct { @@ -123,14 +123,14 @@ type ExtractResponse struct { type CandidateMemory struct { ID string `json:"id"` Memory string `json:"memory"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type DecideRequest struct { Facts []string `json:"facts"` Candidates []CandidateMemory `json:"candidates"` - Filters map[string]interface{} `json:"filters,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Filters map[string]any `json:"filters,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type DecisionAction struct { diff --git a/internal/models/models.go b/internal/models/models.go index 7d9f4fa8..45993883 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -330,6 +330,7 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { Model: Model{ ModelID: dbModel.ModelID, IsMultimodal: dbModel.IsMultimodal, + Input: modelInputFromMultimodal(dbModel.IsMultimodal), Type: ModelType(dbModel.Type), }, } @@ -357,6 +358,14 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse { return responses } +// modelInputFromMultimodal builds the input list based on multimodal support. +func modelInputFromMultimodal(isMultimodal bool) []string { + if isMultimodal { + return []string{ModelInputText, ModelInputImage} + } + return []string{ModelInputText} +} + func isValidClientType(clientType ClientType) bool { switch clientType { case ClientTypeOpenAI, diff --git a/internal/models/types.go b/internal/models/types.go index cdd5a66f..c0ef4df2 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -13,6 +13,11 @@ const ( ModelTypeEmbedding ModelType = "embedding" ) +const ( + ModelInputText = "text" + ModelInputImage = "image" +) + type ClientType string const ( @@ -31,6 +36,7 @@ type Model struct { Name string `json:"name"` LlmProviderID string `json:"llm_provider_id"` IsMultimodal bool `json:"is_multimodal"` + Input []string `json:"input"` Type ModelType `json:"type"` Dimensions int `json:"dimensions"` } diff --git a/internal/policy/service.go b/internal/policy/service.go new file mode 100644 index 00000000..518e21bb --- /dev/null +++ b/internal/policy/service.go @@ -0,0 +1,61 @@ +package policy + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/memohai/memoh/internal/bots" + "github.com/memohai/memoh/internal/settings" +) + +type Decision struct { + BotID string + BotType string + AllowGuest bool +} + +type Service struct { + bots *bots.Service + settings *settings.Service + logger *slog.Logger +} + +func NewService(log *slog.Logger, botsService *bots.Service, settingsService *settings.Service) *Service { + if log == nil { + log = slog.Default() + } + return &Service{ + bots: botsService, + settings: settingsService, + logger: log.With(slog.String("service", "policy")), + } +} + +func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) { + if s == nil || s.bots == nil || s.settings == nil { + return Decision{}, fmt.Errorf("policy service not configured") + } + botID = strings.TrimSpace(botID) + if botID == "" { + return Decision{}, fmt.Errorf("bot id is required") + } + bot, err := s.bots.Get(ctx, botID) + if err != nil { + return Decision{}, err + } + botSettings, err := s.settings.GetBot(ctx, botID) + if err != nil { + return Decision{}, err + } + decision := Decision{ + BotID: botID, + BotType: strings.TrimSpace(bot.Type), + AllowGuest: botSettings.AllowGuest, + } + if decision.BotType == bots.BotTypePersonal { + decision.AllowGuest = false + } + return decision, nil +} diff --git a/internal/preauth/service.go b/internal/preauth/service.go new file mode 100644 index 00000000..7aa8f5b9 --- /dev/null +++ b/internal/preauth/service.go @@ -0,0 +1,128 @@ +package preauth + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db/sqlc" +) + +var ErrKeyNotFound = errors.New("preauth key not found") + +type Service struct { + queries *sqlc.Queries +} + +func NewService(queries *sqlc.Queries) *Service { + return &Service{queries: queries} +} + +func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl time.Duration) (Key, error) { + if s.queries == nil { + return Key{}, fmt.Errorf("preauth queries not configured") + } + if ttl <= 0 { + ttl = 24 * time.Hour + } + pgBotID, err := parseUUID(botID) + if err != nil { + return Key{}, err + } + pgIssuedBy := pgtype.UUID{Valid: false} + if strings.TrimSpace(issuedByUserID) != "" { + parsed, err := parseUUID(issuedByUserID) + if err != nil { + return Key{}, err + } + pgIssuedBy = parsed + } + token := strings.ReplaceAll(uuid.NewString(), "-", "")[:8] + expiresAt := time.Now().UTC().Add(ttl) + row, err := s.queries.CreateBotPreauthKey(ctx, sqlc.CreateBotPreauthKeyParams{ + BotID: pgBotID, + Token: token, + IssuedByUserID: pgIssuedBy, + ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true}, + }) + if err != nil { + return Key{}, err + } + return normalizeKey(row), nil +} + +func (s *Service) Get(ctx context.Context, token string) (Key, error) { + if s.queries == nil { + return Key{}, fmt.Errorf("preauth queries not configured") + } + row, err := s.queries.GetBotPreauthKey(ctx, strings.TrimSpace(token)) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return Key{}, ErrKeyNotFound + } + return Key{}, err + } + return normalizeKey(row), nil +} + +func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { + if s.queries == nil { + return Key{}, fmt.Errorf("preauth queries not configured") + } + pgID, err := parseUUID(id) + if err != nil { + return Key{}, err + } + row, err := s.queries.MarkBotPreauthKeyUsed(ctx, pgID) + if err != nil { + return Key{}, err + } + return normalizeKey(row), nil +} + +func normalizeKey(row sqlc.BotPreauthKey) Key { + return Key{ + ID: toUUIDString(row.ID), + BotID: toUUIDString(row.BotID), + Token: strings.TrimSpace(row.Token), + IssuedByUserID: toUUIDString(row.IssuedByUserID), + ExpiresAt: timeFromPg(row.ExpiresAt), + UsedAt: timeFromPg(row.UsedAt), + CreatedAt: timeFromPg(row.CreatedAt), + } +} + +func parseUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(strings.TrimSpace(id)) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) + } + var pgID pgtype.UUID + pgID.Valid = true + copy(pgID.Bytes[:], parsed[:]) + return pgID, nil +} + +func toUUIDString(value pgtype.UUID) string { + if !value.Valid { + return "" + } + parsed, err := uuid.FromBytes(value.Bytes[:]) + if err != nil { + return "" + } + return parsed.String() +} + +func timeFromPg(value pgtype.Timestamptz) time.Time { + if value.Valid { + return value.Time + } + return time.Time{} +} diff --git a/internal/preauth/types.go b/internal/preauth/types.go new file mode 100644 index 00000000..cd26b086 --- /dev/null +++ b/internal/preauth/types.go @@ -0,0 +1,13 @@ +package preauth + +import "time" + +type Key struct { + ID string + BotID string + Token string + IssuedByUserID string + ExpiresAt time.Time + UsedAt time.Time + CreatedAt time.Time +} diff --git a/internal/providers/service.go b/internal/providers/service.go index 2c842e37..d2118f5f 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -211,7 +211,7 @@ func (s *Service) CountByClientType(ctx context.Context, clientType ClientType) // toGetResponse converts a database provider to a response func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { - var metadata map[string]interface{} + var metadata map[string]any if len(provider.Metadata) > 0 { _ = json.Unmarshal(provider.Metadata, &metadata) } @@ -268,4 +268,3 @@ func maskAPIKey(apiKey string) string { } return apiKey[:8] + strings.Repeat("*", len(apiKey)-8) } - diff --git a/internal/providers/types.go b/internal/providers/types.go index 1a8ddd36..4eec664e 100644 --- a/internal/providers/types.go +++ b/internal/providers/types.go @@ -15,32 +15,32 @@ const ( // CreateRequest represents a request to create a new LLM provider type CreateRequest struct { - Name string `json:"name" validate:"required"` - ClientType ClientType `json:"client_type" validate:"required"` - BaseURL string `json:"base_url" validate:"required,url"` - APIKey string `json:"api_key"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Name string `json:"name" validate:"required"` + ClientType ClientType `json:"client_type" validate:"required"` + BaseURL string `json:"base_url" validate:"required,url"` + APIKey string `json:"api_key"` + Metadata map[string]any `json:"metadata,omitempty"` } // UpdateRequest represents a request to update an existing LLM provider type UpdateRequest struct { - Name *string `json:"name,omitempty"` - ClientType *ClientType `json:"client_type,omitempty"` - BaseURL *string `json:"base_url,omitempty"` - APIKey *string `json:"api_key,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Name *string `json:"name,omitempty"` + ClientType *ClientType `json:"client_type,omitempty"` + BaseURL *string `json:"base_url,omitempty"` + APIKey *string `json:"api_key,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // GetResponse represents the response for getting a provider type GetResponse struct { - ID string `json:"id"` - Name string `json:"name"` - ClientType string `json:"client_type"` - BaseURL string `json:"base_url"` - APIKey string `json:"api_key,omitempty"` // masked in response - Metadata map[string]interface{} `json:"metadata,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Name string `json:"name"` + ClientType string `json:"client_type"` + BaseURL string `json:"base_url"` + APIKey string `json:"api_key,omitempty"` // masked in response + Metadata map[string]any `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // ListResponse represents the response for listing providers @@ -68,4 +68,3 @@ type TestResponse struct { Message string `json:"message,omitempty"` Latency int64 `json:"latency_ms,omitempty"` // latency in milliseconds } - diff --git a/internal/router/channel.go b/internal/router/channel.go index 4886c4b1..d146441d 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -2,16 +2,18 @@ package router import ( "context" + "encoding/json" "fmt" "log/slog" + "regexp" "strings" "time" + "unicode" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/contacts" - "github.com/memohai/memoh/internal/settings" ) // ChatGateway 抽象聊天能力,避免路由层直接依赖具体实现。 @@ -25,158 +27,86 @@ type ContactService interface { GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) - UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]interface{}) (contacts.ContactChannel, error) - GetBindToken(ctx context.Context, token string) (contacts.BindToken, error) - MarkBindTokenUsed(ctx context.Context, id string) (contacts.BindToken, error) - BindUser(ctx context.Context, contactID, userID string) (contacts.Contact, error) + UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) } -type SettingsService interface { - GetBot(ctx context.Context, botID string) (settings.Settings, error) -} +const ( + silentReplyToken = "NO_REPLY" + minDuplicateTextLength = 10 +) + +var ( + whitespacePattern = regexp.MustCompile(`\s+`) +) // ChannelInboundProcessor 将 channel 入站消息路由到 chat,并返回可发送的回复。 type ChannelInboundProcessor struct { - store channel.ConfigStore - chat ChatGateway - contacts ContactService - settings SettingsService - logger *slog.Logger - unboundReply string - bindSuccessReply string - jwtSecret string - tokenTTL time.Duration + chat ChatGateway + logger *slog.Logger + jwtSecret string + tokenTTL time.Duration + identity *IdentityResolver } -func NewChannelInboundProcessor(log *slog.Logger, store channel.ConfigStore, chatGateway ChatGateway, contactService ContactService, settingsService SettingsService, jwtSecret string, tokenTTL time.Duration) *ChannelInboundProcessor { +func NewChannelInboundProcessor(log *slog.Logger, store channel.ConfigStore, chatGateway ChatGateway, contactService ContactService, policyService PolicyService, preauthService PreauthService, jwtSecret string, tokenTTL time.Duration) *ChannelInboundProcessor { if log == nil { log = slog.Default() } if tokenTTL <= 0 { tokenTTL = 5 * time.Minute } + identityResolver := NewIdentityResolver(log, store, contactService, policyService, preauthService, "", "") return &ChannelInboundProcessor{ - store: store, - chat: chatGateway, - contacts: contactService, - settings: settingsService, - logger: log.With(slog.String("component", "channel_router")), - unboundReply: "当前不允许陌生人访问,请联系管理员。", - bindSuccessReply: "绑定成功,感谢确认。", - jwtSecret: strings.TrimSpace(jwtSecret), - tokenTTL: tokenTTL, + chat: chatGateway, + logger: log.With(slog.String("component", "channel_router")), + jwtSecret: strings.TrimSpace(jwtSecret), + tokenTTL: tokenTTL, + identity: identityResolver, } } -func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (*channel.OutboundMessage, error) { - if p.store == nil || p.chat == nil || p.contacts == nil { - return nil, fmt.Errorf("channel inbound processor not configured") - } - if strings.TrimSpace(msg.Text) == "" { - return nil, nil - } - if strings.TrimSpace(msg.BotID) == "" { - msg.BotID = cfg.BotID +func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware { + if p == nil || p.identity == nil { + return nil } + return p.identity.Middleware() +} - sessionID := msg.SessionID() - channelConfigID := cfg.ID - if msg.Channel == channel.ChannelCLI || msg.Channel == channel.ChannelWeb { - channelConfigID = "" +func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.ReplySender) error { + if p.chat == nil { + return fmt.Errorf("channel inbound processor not configured") } - - session, err := p.store.GetChannelSession(ctx, sessionID) - if err != nil && p.logger != nil { - p.logger.Error("get user by session failed", slog.String("session_id", sessionID), slog.Any("error", err)) + if sender == nil { + return fmt.Errorf("reply sender not configured") } - userID := strings.TrimSpace(session.UserID) - contactID := strings.TrimSpace(session.ContactID) - externalID := extractExternalIdentity(msg) - - if bindReply, handled := p.tryHandleBindToken(ctx, msg, externalID); handled { - return bindReply, nil + text := buildInboundQuery(msg.Message) + if strings.TrimSpace(text) == "" { + return nil } - - if userID == "" { - userID, err = p.store.ResolveUserBinding(ctx, msg.Channel, channel.BindingCriteria{ - Username: msg.Username, - UserID: msg.UserID, - ChatID: msg.ChatID, - OpenID: msg.OpenID, - }) - if err == nil && userID != "" { - _ = p.store.UpsertChannelSession(ctx, sessionID, msg.BotID, channelConfigID, userID, contactID, string(msg.Channel)) - } + state, err := p.requireIdentity(ctx, cfg, msg) + if err != nil { + return err } - - var contact contacts.Contact - if contactID == "" && userID != "" { - contact, err = p.contacts.GetByUserID(ctx, msg.BotID, userID) - if err != nil { - displayName := extractDisplayName(msg) - contact, err = p.contacts.Create(ctx, contacts.CreateRequest{ - BotID: msg.BotID, - UserID: userID, - DisplayName: displayName, - Status: "active", + if state.Decision != nil && state.Decision.Stop { + if !state.Decision.Reply.IsEmpty() { + return sender.Send(ctx, channel.OutboundMessage{ + Target: strings.TrimSpace(msg.ReplyTarget), + Message: state.Decision.Reply, }) } - if err == nil { - contactID = contact.ID - if externalID != "" { - _, _ = p.contacts.UpsertChannel(ctx, msg.BotID, contactID, msg.Channel.String(), externalID, nil) - } - } + return nil } - if contactID == "" && externalID != "" { - binding, err := p.contacts.GetByChannelIdentity(ctx, msg.BotID, msg.Channel.String(), externalID) - if err == nil { - contactID = binding.ContactID - } - } - - if contactID == "" { - allowGuest := false - if p.settings != nil { - botSettings, err := p.settings.GetBot(ctx, msg.BotID) - if err == nil { - allowGuest = botSettings.AllowGuest - } - } - if allowGuest { - displayName := extractDisplayName(msg) - contact, err = p.contacts.CreateGuest(ctx, msg.BotID, displayName) - if err == nil { - contactID = contact.ID - if externalID != "" { - _, _ = p.contacts.UpsertChannel(ctx, msg.BotID, contactID, msg.Channel.String(), externalID, nil) - } - } - } else { - return p.buildUnboundReply(msg) - } - } - - if contactID != "" && contact.ID == "" { - loaded, err := p.contacts.GetByID(ctx, contactID) - if err == nil { - contact = loaded - } - } - - if contactID != "" { - _ = p.store.UpsertChannelSession(ctx, sessionID, msg.BotID, channelConfigID, userID, contactID, string(msg.Channel)) - } + identity := state.Identity sessionToken := "" - if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTo) != "" { + if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTarget) != "" { signed, _, err := auth.GenerateSessionToken(auth.SessionToken{ - BotID: msg.BotID, + BotID: identity.BotID, Platform: msg.Channel.String(), - ReplyTarget: strings.TrimSpace(msg.ReplyTo), - SessionID: sessionID, - ContactID: contactID, + ReplyTarget: strings.TrimSpace(msg.ReplyTarget), + SessionID: identity.SessionID, + ContactID: identity.ContactID, }, p.jwtSecret, p.tokenTTL) if err != nil { if p.logger != nil { @@ -188,8 +118,8 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel } token := "" - if userID != "" && p.jwtSecret != "" { - signed, _, err := auth.GenerateToken(userID, p.jwtSecret, p.tokenTTL) + if identity.UserID != "" && p.jwtSecret != "" { + signed, _, err := auth.GenerateToken(identity.UserID, p.jwtSecret, p.tokenTTL) if err != nil { if p.logger != nil { p.logger.Warn("issue channel token failed", slog.Any("error", err)) @@ -198,211 +128,469 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel token = "Bearer " + signed } } + desc, _ := channel.GetChannelDescriptor(msg.Channel) resp, err := p.chat.Chat(ctx, chat.ChatRequest{ - BotID: msg.BotID, - SessionID: sessionID, + BotID: identity.BotID, + SessionID: identity.SessionID, Token: token, - UserID: userID, - Query: msg.Text, + UserID: identity.UserID, + ContactID: identity.ContactID, + ContactName: strings.TrimSpace(identity.Contact.DisplayName), + ContactAlias: strings.TrimSpace(identity.Contact.Alias), + ReplyTarget: strings.TrimSpace(msg.ReplyTarget), + SessionToken: sessionToken, + Query: text, CurrentPlatform: msg.Channel.String(), Platforms: []string{msg.Channel.String()}, - ToolContext: &chat.ToolContext{ - BotID: msg.BotID, - SessionID: sessionID, - CurrentPlatform: msg.Channel.String(), - ReplyTarget: strings.TrimSpace(msg.ReplyTo), - SessionToken: sessionToken, - ContactID: contactID, - ContactAlias: strings.TrimSpace(contact.Alias), - ContactName: strings.TrimSpace(contact.DisplayName), - }, }) if err != nil { if p.logger != nil { - p.logger.Error("chat gateway failed", slog.String("channel", msg.Channel.String()), slog.String("user_id", userID), slog.Any("error", err)) + p.logger.Error("chat gateway failed", slog.String("channel", msg.Channel.String()), slog.String("user_id", identity.UserID), slog.Any("error", err)) } - return nil, err + return err } - if len(resp.Messages) == 0 { - return nil, nil + outputs := chat.ExtractAssistantOutputs(resp.Messages) + if len(outputs) == 0 { + return nil } - // Extract assistant text as reply - if reply := extractAssistantReply(resp.Messages); strings.TrimSpace(reply) != "" { - target := strings.TrimSpace(msg.ReplyTo) - if target == "" { - return nil, fmt.Errorf("reply target missing") - } - return &channel.OutboundMessage{ - To: target, - Text: reply, - }, nil + target := strings.TrimSpace(msg.ReplyTarget) + if target == "" { + return fmt.Errorf("reply target missing") } - return nil, nil -} - -// extractAssistantReply extracts text content from the last assistant message with actual text. -// Skips assistant messages that only contain tool_calls without text content. -func extractAssistantReply(messages []chat.GatewayMessage) string { - if len(messages) == 0 { - return "" + sentTexts, suppressReplies := collectMessageToolContext(resp.Messages, msg.Channel, target) + if suppressReplies { + return nil } - reply := "" - for _, msg := range messages { - role, _ := msg["role"].(string) - if role != "" && role != "assistant" { + for _, output := range outputs { + outMessage := buildChannelMessage(output, desc.Capabilities) + if outMessage.IsEmpty() { continue } - // Skip if this message only has tool_calls without text content - if _, hasToolCalls := msg["tool_calls"]; hasToolCalls { - // Check if there's also text content - if msg["content"] == nil { + plainText := strings.TrimSpace(outMessage.PlainText()) + if isSilentReplyText(plainText) { + continue + } + if isMessagingToolDuplicate(plainText, sentTexts) { + continue + } + if err := sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: outMessage, + }); err != nil { + return err + } + } + return nil +} + +func buildChannelMessage(output chat.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message { + msg := channel.Message{} + if strings.TrimSpace(output.Content) != "" { + msg.Text = strings.TrimSpace(output.Content) + if containsMarkdown(msg.Text) && (capabilities.Markdown || capabilities.RichText) { + msg.Format = channel.MessageFormatMarkdown + } + } + if len(output.Parts) == 0 { + return msg + } + if capabilities.RichText { + parts := make([]channel.MessagePart, 0, len(output.Parts)) + for _, part := range output.Parts { + if !contentPartHasValue(part) { continue } - } - if content, ok := msg["content"].(string); ok && strings.TrimSpace(content) != "" { - reply = content - continue - } - parts, ok := msg["content"].([]interface{}) - if !ok { - continue - } - texts := make([]string, 0, len(parts)) - for _, part := range parts { - switch value := part.(type) { - case string: - if strings.TrimSpace(value) != "" { - texts = append(texts, value) - } - case map[string]interface{}: - if text, ok := value["text"].(string); ok && strings.TrimSpace(text) != "" { - texts = append(texts, text) - } - } - } - if len(texts) > 0 { - reply = strings.Join(texts, "\n") - } - } - return reply -} - -func (p *ChannelInboundProcessor) buildUnboundReply(msg channel.InboundMessage) (*channel.OutboundMessage, error) { - target := strings.TrimSpace(msg.ReplyTo) - if target == "" { - return nil, fmt.Errorf("reply target missing") - } - return &channel.OutboundMessage{ - To: target, - Text: p.unboundReply, - }, nil -} - -func extractExternalIdentity(msg channel.InboundMessage) string { - if strings.TrimSpace(msg.OpenID) != "" { - return strings.TrimSpace(msg.OpenID) - } - if strings.TrimSpace(msg.UserID) != "" { - return strings.TrimSpace(msg.UserID) - } - if strings.TrimSpace(msg.Username) != "" { - return strings.TrimSpace(msg.Username) - } - if strings.TrimSpace(msg.ChatID) != "" { - return strings.TrimSpace(msg.ChatID) - } - return "" -} - -func extractDisplayName(msg channel.InboundMessage) string { - if strings.TrimSpace(msg.Username) != "" { - return strings.TrimSpace(msg.Username) - } - if strings.TrimSpace(msg.UserID) != "" { - return strings.TrimSpace(msg.UserID) - } - if strings.TrimSpace(msg.OpenID) != "" { - return strings.TrimSpace(msg.OpenID) - } - if strings.TrimSpace(msg.ChatID) != "" { - return strings.TrimSpace(msg.ChatID) - } - return "" -} - -func buildUserBindingConfig(msg channel.InboundMessage) map[string]interface{} { - config := map[string]interface{}{} - switch msg.Channel { - case channel.ChannelFeishu: - if strings.TrimSpace(msg.OpenID) != "" { - config["open_id"] = strings.TrimSpace(msg.OpenID) - } - if strings.TrimSpace(msg.UserID) != "" { - config["user_id"] = strings.TrimSpace(msg.UserID) - } - case channel.ChannelTelegram: - if strings.TrimSpace(msg.Username) != "" { - config["username"] = strings.TrimSpace(msg.Username) - } - if strings.TrimSpace(msg.UserID) != "" { - config["user_id"] = strings.TrimSpace(msg.UserID) - } - if strings.TrimSpace(msg.ChatID) != "" { - config["chat_id"] = strings.TrimSpace(msg.ChatID) - } - } - return config -} - -func (p *ChannelInboundProcessor) tryHandleBindToken(ctx context.Context, msg channel.InboundMessage, externalID string) (*channel.OutboundMessage, bool) { - tokenText := strings.TrimSpace(msg.Text) - if tokenText == "" { - return nil, false - } - token, err := p.contacts.GetBindToken(ctx, tokenText) - if err != nil { - return nil, false - } - replyTarget := strings.TrimSpace(msg.ReplyTo) - if replyTarget == "" { - return nil, true - } - now := time.Now().UTC() - if !token.UsedAt.IsZero() { - return &channel.OutboundMessage{To: replyTarget, Text: "绑定码已被使用。"}, true - } - if now.After(token.ExpiresAt) { - return &channel.OutboundMessage{To: replyTarget, Text: "绑定码已过期,请重新获取。"}, true - } - if token.BotID != msg.BotID { - return &channel.OutboundMessage{To: replyTarget, Text: "绑定码不匹配。"}, true - } - if token.TargetPlatform != "" && token.TargetPlatform != msg.Channel.String() { - return &channel.OutboundMessage{To: replyTarget, Text: "绑定码平台不匹配。"}, true - } - if token.TargetExternalID != "" && token.TargetExternalID != externalID { - return &channel.OutboundMessage{To: replyTarget, Text: "绑定码目标不匹配。"}, true - } - if externalID == "" { - return &channel.OutboundMessage{To: replyTarget, Text: "无法识别当前账号,绑定失败。"}, true - } - if _, err := p.contacts.UpsertChannel(ctx, msg.BotID, token.ContactID, msg.Channel.String(), externalID, nil); err != nil { - return &channel.OutboundMessage{To: replyTarget, Text: "绑定失败,请稍后重试。"}, true - } - if strings.TrimSpace(token.IssuedByUserID) != "" { - if boundContact, err := p.contacts.GetByID(ctx, token.ContactID); err == nil { - if strings.TrimSpace(boundContact.UserID) != "" && boundContact.UserID != token.IssuedByUserID { - return &channel.OutboundMessage{To: replyTarget, Text: "该绑定码已关联其他账号。"}, true - } - } - _, _ = p.contacts.BindUser(ctx, token.ContactID, token.IssuedByUserID) - if config := buildUserBindingConfig(msg); len(config) > 0 { - _, _ = p.store.UpsertUserConfig(ctx, token.IssuedByUserID, msg.Channel, channel.UpsertUserConfigRequest{ - Config: config, + partType := normalizeContentPartType(part.Type) + parts = append(parts, channel.MessagePart{ + Type: partType, + Text: part.Text, + URL: part.URL, + Styles: normalizeContentPartStyles(part.Styles), + Language: part.Language, + UserID: part.UserID, + Emoji: part.Emoji, }) } - _ = p.store.UpsertChannelSession(ctx, msg.SessionID(), msg.BotID, "", token.IssuedByUserID, token.ContactID, msg.Channel.String()) + if len(parts) > 0 { + msg.Parts = parts + msg.Format = channel.MessageFormatRich + } + return msg } - _, _ = p.contacts.MarkBindTokenUsed(ctx, token.ID) - return &channel.OutboundMessage{To: replyTarget, Text: p.bindSuccessReply}, true + textParts := make([]string, 0, len(output.Parts)) + for _, part := range output.Parts { + if !contentPartHasValue(part) { + continue + } + textParts = append(textParts, strings.TrimSpace(contentPartText(part))) + } + if len(textParts) > 0 { + msg.Text = strings.Join(textParts, "\n") + if msg.Format == "" && containsMarkdown(msg.Text) && (capabilities.Markdown || capabilities.RichText) { + msg.Format = channel.MessageFormatMarkdown + } + } + return msg +} + +func containsMarkdown(text string) bool { + if strings.TrimSpace(text) == "" { + return false + } + patterns := []string{ + `\\*\\*[^*]+\\*\\*`, + `\\*[^*]+\\*`, + `~~[^~]+~~`, + "`[^`]+`", + "```[\\s\\S]*```", + `\\[.+\\]\\(.+\\)`, + `(?m)^#{1,6}\\s`, + `(?m)^[-*]\\s`, + `(?m)^\\d+\\.\\s`, + } + for _, pattern := range patterns { + if matched, _ := regexp.MatchString(pattern, text); matched { + return true + } + } + return false +} + +func contentPartHasValue(part chat.ContentPart) bool { + if strings.TrimSpace(part.Text) != "" { + return true + } + if strings.TrimSpace(part.URL) != "" { + return true + } + if strings.TrimSpace(part.Emoji) != "" { + return true + } + return false +} + +func contentPartText(part chat.ContentPart) string { + if strings.TrimSpace(part.Text) != "" { + return part.Text + } + if strings.TrimSpace(part.URL) != "" { + return part.URL + } + if strings.TrimSpace(part.Emoji) != "" { + return part.Emoji + } + return "" +} + +func buildInboundQuery(message channel.Message) string { + text := strings.TrimSpace(message.PlainText()) + if len(message.Attachments) == 0 { + return text + } + lines := make([]string, 0, len(message.Attachments)+1) + if text != "" { + lines = append(lines, text) + } + for _, att := range message.Attachments { + label := strings.TrimSpace(att.Name) + if label == "" { + label = strings.TrimSpace(att.URL) + } + if label == "" { + label = "unknown" + } + lines = append(lines, fmt.Sprintf("[attachment:%s] %s", att.Type, label)) + } + return strings.Join(lines, "\n") +} + +func normalizeContentPartType(raw string) channel.MessagePartType { + switch strings.TrimSpace(strings.ToLower(raw)) { + case "link": + return channel.MessagePartLink + case "code_block": + return channel.MessagePartCodeBlock + case "mention": + return channel.MessagePartMention + case "emoji": + return channel.MessagePartEmoji + default: + return channel.MessagePartText + } +} + +func normalizeContentPartStyles(styles []string) []channel.MessageTextStyle { + if len(styles) == 0 { + return nil + } + result := make([]channel.MessageTextStyle, 0, len(styles)) + for _, style := range styles { + switch strings.TrimSpace(strings.ToLower(style)) { + case "bold": + result = append(result, channel.MessageStyleBold) + case "italic": + result = append(result, channel.MessageStyleItalic) + case "strikethrough", "lineThrough": + result = append(result, channel.MessageStyleStrikethrough) + case "code": + result = append(result, channel.MessageStyleCode) + default: + continue + } + } + if len(result) == 0 { + return nil + } + return result +} + +type sendMessageToolArgs struct { + Platform string `json:"platform"` + Target string `json:"target"` + UserID string `json:"user_id"` + Text string `json:"text"` + Message *channel.Message `json:"message"` +} + +type toolCall struct { + Name string + Arguments string +} + +func collectMessageToolContext(messages []chat.GatewayMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { + if len(messages) == 0 { + return nil, false + } + sentTexts := make([]string, 0) + suppressReplies := false + for _, msg := range messages { + for _, call := range extractToolCalls(msg) { + if call.Name != "send_message" { + continue + } + var args sendMessageToolArgs + if !parseToolArguments(call.Arguments, &args) { + continue + } + messageText := strings.TrimSpace(extractSendMessageText(args)) + if messageText != "" { + sentTexts = append(sentTexts, messageText) + } + if shouldSuppressForToolCall(args, channelType, replyTarget) { + suppressReplies = true + } + } + } + return sentTexts, suppressReplies +} + +func extractToolCalls(msg chat.GatewayMessage) []toolCall { + calls := make([]toolCall, 0) + if msg == nil { + return calls + } + if rawCalls, ok := msg["tool_calls"].([]any); ok { + for _, raw := range rawCalls { + call, ok := raw.(map[string]any) + if !ok { + continue + } + name, args := parseToolCall(call) + if name == "" { + continue + } + calls = append(calls, toolCall{Name: name, Arguments: args}) + } + } + if fn, ok := msg["function_call"].(map[string]any); ok { + name := readString(fn["name"]) + args := readString(fn["arguments"]) + if name != "" { + calls = append(calls, toolCall{Name: name, Arguments: args}) + } + } + if fn, ok := msg["functionCall"].(map[string]any); ok { + name := readString(fn["name"]) + args := readString(fn["arguments"]) + if name != "" { + calls = append(calls, toolCall{Name: name, Arguments: args}) + } + } + return calls +} + +func parseToolCall(call map[string]any) (string, string) { + if call == nil { + return "", "" + } + name := "" + args := "" + if fn, ok := call["function"].(map[string]any); ok { + name = readString(fn["name"]) + args = readString(fn["arguments"]) + } + if name == "" { + name = readString(call["name"]) + } + if args == "" { + args = readString(call["arguments"]) + } + return name, args +} + +func parseToolArguments(raw string, out any) bool { + if strings.TrimSpace(raw) == "" { + return false + } + if err := json.Unmarshal([]byte(raw), out); err == nil { + return true + } + var decoded string + if err := json.Unmarshal([]byte(raw), &decoded); err != nil { + return false + } + if strings.TrimSpace(decoded) == "" { + return false + } + return json.Unmarshal([]byte(decoded), out) == nil +} + +func extractSendMessageText(args sendMessageToolArgs) string { + if strings.TrimSpace(args.Text) != "" { + return strings.TrimSpace(args.Text) + } + if args.Message == nil { + return "" + } + return strings.TrimSpace(args.Message.PlainText()) +} + +func shouldSuppressForToolCall(args sendMessageToolArgs, channelType channel.ChannelType, replyTarget string) bool { + platform := strings.TrimSpace(args.Platform) + if platform == "" { + platform = string(channelType) + } + if !strings.EqualFold(platform, string(channelType)) { + return false + } + target := strings.TrimSpace(args.Target) + if target == "" && strings.TrimSpace(args.UserID) == "" { + target = replyTarget + } + if strings.TrimSpace(target) == "" || strings.TrimSpace(replyTarget) == "" { + return false + } + normalizedTarget := normalizeReplyTarget(channelType, target) + normalizedReply := normalizeReplyTarget(channelType, replyTarget) + if normalizedTarget == "" || normalizedReply == "" { + return false + } + return normalizedTarget == normalizedReply +} + +func normalizeReplyTarget(channelType channel.ChannelType, target string) string { + normalized, ok := channel.NormalizeTarget(channelType, target) + if ok && strings.TrimSpace(normalized) != "" { + return strings.TrimSpace(normalized) + } + return strings.TrimSpace(target) +} + +func isSilentReplyText(text string) bool { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return false + } + token := []rune(silentReplyToken) + value := []rune(trimmed) + if len(value) < len(token) { + return false + } + if hasTokenPrefix(value, token) { + return true + } + if hasTokenSuffix(value, token) { + return true + } + return false +} + +func hasTokenPrefix(value []rune, token []rune) bool { + if len(value) < len(token) { + return false + } + for i := range token { + if value[i] != token[i] { + return false + } + } + if len(value) == len(token) { + return true + } + return !isWordChar(value[len(token)]) +} + +func hasTokenSuffix(value []rune, token []rune) bool { + if len(value) < len(token) { + return false + } + start := len(value) - len(token) + for i := range token { + if value[start+i] != token[i] { + return false + } + } + if start == 0 { + return true + } + return !isWordChar(value[start-1]) +} + +func isWordChar(value rune) bool { + return value == '_' || unicode.IsLetter(value) || unicode.IsDigit(value) +} + +func normalizeTextForComparison(text string) string { + trimmed := strings.TrimSpace(strings.ToLower(text)) + if trimmed == "" { + return "" + } + return strings.TrimSpace(whitespacePattern.ReplaceAllString(trimmed, " ")) +} + +func isMessagingToolDuplicate(text string, sentTexts []string) bool { + if len(sentTexts) == 0 { + return false + } + normalized := normalizeTextForComparison(text) + if len(normalized) < minDuplicateTextLength { + return false + } + for _, sent := range sentTexts { + sentNormalized := normalizeTextForComparison(sent) + if len(sentNormalized) < minDuplicateTextLength { + continue + } + if strings.Contains(normalized, sentNormalized) || strings.Contains(sentNormalized, normalized) { + return true + } + } + return false +} + +func readString(value any) string { + if raw, ok := value.(string); ok { + return raw + } + return "" +} + +func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { + if state, ok := IdentityStateFromContext(ctx); ok { + return state, nil + } + if p.identity == nil { + return IdentityState{}, fmt.Errorf("identity resolver not configured") + } + return p.identity.Resolve(ctx, cfg, msg) } diff --git a/internal/router/channel_test.go b/internal/router/channel_test.go index b5fcfc23..b65daf9a 100644 --- a/internal/router/channel_test.go +++ b/internal/router/channel_test.go @@ -10,6 +10,7 @@ import ( "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/contacts" + "github.com/memohai/memoh/internal/policy" ) type fakeConfigStore struct { @@ -40,6 +41,10 @@ func (f *fakeConfigStore) ResolveUserBinding(ctx context.Context, channelType ch return f.boundUserID, nil } +func (f *fakeConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]channel.ChannelSession, error) { + return nil, nil +} + func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID string) (channel.ChannelSession, error) { if f.session.SessionID == sessionID { return f.session, nil @@ -47,7 +52,7 @@ func (f *fakeConfigStore) GetChannelSession(ctx context.Context, sessionID strin return channel.ChannelSession{}, nil } -func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string) error { +func (f *fakeConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { return nil } @@ -86,20 +91,33 @@ func (f *fakeContactService) CreateGuest(ctx context.Context, botID, displayName return contacts.Contact{ID: "contact-guest", BotID: botID}, nil } -func (f *fakeContactService) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]interface{}) (contacts.ContactChannel, error) { +func (f *fakeContactService) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) { return contacts.ContactChannel{ID: "channel-1", ContactID: contactID}, nil } -func (f *fakeContactService) GetBindToken(ctx context.Context, token string) (contacts.BindToken, error) { - return contacts.BindToken{}, fmt.Errorf("not found") +type fakePolicyService struct { + decision policy.Decision + err error } -func (f *fakeContactService) MarkBindTokenUsed(ctx context.Context, id string) (contacts.BindToken, error) { - return contacts.BindToken{}, nil +func (f *fakePolicyService) Resolve(ctx context.Context, botID string) (policy.Decision, error) { + if f.err != nil { + return policy.Decision{}, f.err + } + decision := f.decision + if decision.BotID == "" { + decision.BotID = botID + } + return decision, nil } -func (f *fakeContactService) BindUser(ctx context.Context, contactID, userID string) (contacts.Contact, error) { - return contacts.Contact{}, nil +type fakeReplySender struct { + sent []channel.OutboundMessage +} + +func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) error { + s.sent = append(s.sent, msg) + return nil } func TestChannelInboundProcessorBoundUser(t *testing.T) { @@ -116,17 +134,21 @@ func TestChannelInboundProcessorBoundUser(t *testing.T) { }, }, } - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelFeishu} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ - Channel: channel.ChannelFeishu, - Text: "你好", - ChatID: "chat-1", - ReplyTo: "target-id", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "你好"}, + ReplyTarget: "target-id", + Conversation: channel.Conversation{ + ID: "chat-1", + Type: "p2p", + }, } - out, err := processor.HandleInbound(context.Background(), cfg, msg) + err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { t.Fatalf("不应报错: %v", err) } @@ -136,29 +158,30 @@ func TestChannelInboundProcessorBoundUser(t *testing.T) { if gateway.gotReq.SessionID != "feishu:bot-1:chat-1" { t.Errorf("SessionID 传递错误: %s", gateway.gotReq.SessionID) } - if out != nil { - t.Fatalf("不应直接返回回复: %+v", out) + if len(sender.sent) != 1 || sender.sent[0].Message.PlainText() != "AI回复内容" { + t.Fatalf("应发送 AI 回复,实际: %+v", sender.sent) } } func TestChannelInboundProcessorUnboundUser(t *testing.T) { store := &fakeConfigStore{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelFeishu} + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} msg := channel.InboundMessage{ - Channel: channel.ChannelFeishu, - Text: "你好", - ReplyTo: "target-id", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "你好"}, + ReplyTarget: "target-id", } - out, err := processor.HandleInbound(context.Background(), cfg, msg) + err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { t.Fatalf("不应报错: %v", err) } - if out == nil || !strings.Contains(out.Text, "尚未绑定") { - t.Fatalf("应返回绑定提示,实际返回: %+v", out) + if len(sender.sent) != 1 || !strings.Contains(sender.sent[0].Message.PlainText(), "陌生人") { + t.Fatalf("应发送绑定提示,实际: %+v", sender.sent) } if gateway.gotReq.Query != "" { t.Error("未绑定用户不应触发 Chat 调用") @@ -168,19 +191,155 @@ func TestChannelInboundProcessorUnboundUser(t *testing.T) { func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { store := &fakeConfigStore{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1"} - msg := channel.InboundMessage{Text: " "} + msg := channel.InboundMessage{Message: channel.Message{Text: " "}} - out, err := processor.HandleInbound(context.Background(), cfg, msg) + err := processor.HandleInbound(context.Background(), cfg, msg, sender) if err != nil { t.Fatalf("空消息不应报错: %v", err) } - if out != nil { - t.Fatalf("空消息不应返回回复: %+v", out) + if len(sender.sent) != 0 { + t.Fatalf("空消息不应发送回复: %+v", sender.sent) } if gateway.gotReq.Query != "" { t.Error("空消息不应触发 Chat 调用") } } + +func TestChannelInboundProcessorSilentReply(t *testing.T) { + store := &fakeConfigStore{ + session: channel.ChannelSession{ + SessionID: "feishu:bot-1:chat-1", + UserID: "user-123", + }, + } + gateway := &fakeChatGateway{ + resp: chat.ChatResponse{ + Messages: []chat.GatewayMessage{ + {"role": "assistant", "content": "NO_REPLY"}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "你好"}, + ReplyTarget: "target-id", + Conversation: channel.Conversation{ + ID: "chat-1", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("不应报错: %v", err) + } + if len(sender.sent) != 0 { + t.Fatalf("NO_REPLY 不应发送回复,实际: %+v", sender.sent) + } +} + +func TestChannelInboundProcessorSuppressOnToolSend(t *testing.T) { + store := &fakeConfigStore{ + session: channel.ChannelSession{ + SessionID: "feishu:bot-1:chat-1", + UserID: "user-123", + }, + } + gateway := &fakeChatGateway{ + resp: chat.ChatResponse{ + Messages: []chat.GatewayMessage{ + { + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "send_message", + "arguments": `{"platform":"feishu","target":"target-id","message":{"text":"AI回复内容"}}`, + }, + }, + }, + }, + {"role": "assistant", "content": "AI回复内容"}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "你好"}, + ReplyTarget: "target-id", + Conversation: channel.Conversation{ + ID: "chat-1", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("不应报错: %v", err) + } + if len(sender.sent) != 0 { + t.Fatalf("工具已发送当前会话消息,应抑制普通回复,实际: %+v", sender.sent) + } +} + +func TestChannelInboundProcessorDedupeWithToolSend(t *testing.T) { + store := &fakeConfigStore{ + session: channel.ChannelSession{ + SessionID: "feishu:bot-1:chat-1", + UserID: "user-123", + }, + } + gateway := &fakeChatGateway{ + resp: chat.ChatResponse{ + Messages: []chat.GatewayMessage{ + { + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "send_message", + "arguments": `{"platform":"feishu","target":"other-target","message":{"text":"AI回复内容"}}`, + }, + }, + }, + }, + {"role": "assistant", "content": "AI回复内容"}, + }, + }, + } + processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + msg := channel.InboundMessage{ + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "你好"}, + ReplyTarget: "target-id", + Conversation: channel.Conversation{ + ID: "chat-1", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err != nil { + t.Fatalf("不应报错: %v", err) + } + if len(sender.sent) != 0 { + t.Fatalf("工具发送文本与普通回复重复,应去重,实际: %+v", sender.sent) + } +} diff --git a/internal/router/identity.go b/internal/router/identity.go new file mode 100644 index 00000000..b0ba5211 --- /dev/null +++ b/internal/router/identity.go @@ -0,0 +1,326 @@ +package router + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/contacts" + "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/preauth" +) + +type IdentityDecision struct { + Stop bool + Reply channel.Message +} + +type InboundIdentity struct { + BotID string + SessionID string + ChannelConfigID string + ExternalID string + UserID string + ContactID string + Contact contacts.Contact +} + +type IdentityState struct { + Identity InboundIdentity + Decision *IdentityDecision +} + +type identityContextKey struct{} + +func WithIdentityState(ctx context.Context, state IdentityState) context.Context { + return context.WithValue(ctx, identityContextKey{}, state) +} + +func IdentityStateFromContext(ctx context.Context) (IdentityState, bool) { + if ctx == nil { + return IdentityState{}, false + } + raw := ctx.Value(identityContextKey{}) + if raw == nil { + return IdentityState{}, false + } + state, ok := raw.(IdentityState) + return state, ok +} + +type IdentityResolver struct { + store channel.ConfigStore + contacts ContactService + policy PolicyService + preauth PreauthService + logger *slog.Logger + unboundReply string + preauthReply string +} + +type PolicyService interface { + Resolve(ctx context.Context, botID string) (policy.Decision, error) +} + +type PreauthService interface { + Get(ctx context.Context, token string) (preauth.Key, error) + MarkUsed(ctx context.Context, id string) (preauth.Key, error) +} + +func NewIdentityResolver(log *slog.Logger, store channel.ConfigStore, contacts ContactService, policyService PolicyService, preauthService PreauthService, unboundReply, preauthReply string) *IdentityResolver { + if log == nil { + log = slog.Default() + } + if strings.TrimSpace(unboundReply) == "" { + unboundReply = "当前不允许陌生人访问,请联系管理员。" + } + if strings.TrimSpace(preauthReply) == "" { + preauthReply = "授权成功,请继续使用。" + } + return &IdentityResolver{ + store: store, + contacts: contacts, + policy: policyService, + preauth: preauthService, + logger: log.With(slog.String("component", "channel_identity")), + unboundReply: unboundReply, + preauthReply: preauthReply, + } +} + +func (r *IdentityResolver) Middleware() channel.Middleware { + return func(next channel.InboundHandler) channel.InboundHandler { + return func(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) error { + state, err := r.Resolve(ctx, cfg, msg) + if err != nil { + return err + } + return next(WithIdentityState(ctx, state), cfg, msg) + } + } +} + +func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { + if r.store == nil || r.contacts == nil || r.policy == nil { + return IdentityState{}, fmt.Errorf("identity resolver not configured") + } + + botID := strings.TrimSpace(msg.BotID) + if botID == "" { + botID = cfg.BotID + } + normalizedMsg := msg + normalizedMsg.BotID = botID + + sessionID := normalizedMsg.SessionID() + channelConfigID := cfg.ID + if channel.IsConfigless(msg.Channel) { + channelConfigID = "" + } + externalID := extractExternalIdentity(msg) + + state := IdentityState{ + Identity: InboundIdentity{ + BotID: botID, + SessionID: sessionID, + ChannelConfigID: channelConfigID, + ExternalID: externalID, + }, + } + + session, err := r.store.GetChannelSession(ctx, sessionID) + if err != nil && r.logger != nil { + r.logger.Error("get user by session failed", slog.String("session_id", sessionID), slog.Any("error", err)) + } + userID := strings.TrimSpace(session.UserID) + contactID := strings.TrimSpace(session.ContactID) + + if userID == "" { + userID, err = r.store.ResolveUserBinding(ctx, msg.Channel, channel.BindingCriteriaFromIdentity(msg.Sender)) + if err == nil && userID != "" { + _ = r.store.UpsertChannelSession(ctx, sessionID, botID, channelConfigID, userID, contactID, string(msg.Channel), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) + } + } + + var contact contacts.Contact + if contactID == "" && userID != "" { + contact, err = r.contacts.GetByUserID(ctx, botID, userID) + if err != nil { + displayName := extractDisplayName(msg) + contact, err = r.contacts.Create(ctx, contacts.CreateRequest{ + BotID: botID, + UserID: userID, + DisplayName: displayName, + Status: "active", + }) + } + if err == nil { + contactID = contact.ID + if externalID != "" { + _, _ = r.contacts.UpsertChannel(ctx, botID, contactID, msg.Channel.String(), externalID, nil) + } + } + } + + if contactID == "" && externalID != "" { + binding, err := r.contacts.GetByChannelIdentity(ctx, botID, msg.Channel.String(), externalID) + if err == nil { + contactID = binding.ContactID + } + } + + if contactID == "" { + decision, err := r.policy.Resolve(ctx, botID) + if err != nil { + return state, err + } + if decision.AllowGuest { + displayName := extractDisplayName(msg) + contact, err = r.contacts.CreateGuest(ctx, botID, displayName) + if err == nil { + contactID = contact.ID + if externalID != "" { + _, _ = r.contacts.UpsertChannel(ctx, botID, contactID, msg.Channel.String(), externalID, nil) + } + } + } else { + if handled, decision, err := r.tryHandlePreauthKey(ctx, normalizedMsg, externalID); handled { + state.Decision = &decision + return state, err + } + state.Decision = &IdentityDecision{ + Stop: true, + Reply: channel.Message{Text: r.unboundReply}, + } + return state, nil + } + } + + if contactID != "" && contact.ID == "" { + loaded, err := r.contacts.GetByID(ctx, contactID) + if err == nil { + contact = loaded + } + } + + if contactID != "" { + _ = r.store.UpsertChannelSession(ctx, sessionID, botID, channelConfigID, userID, contactID, string(msg.Channel), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) + } + + state.Identity.UserID = userID + state.Identity.ContactID = contactID + state.Identity.Contact = contact + return state, nil +} + +func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.InboundMessage, externalID string) (bool, IdentityDecision, error) { + tokenText := strings.TrimSpace(msg.Message.PlainText()) + if tokenText == "" || r.preauth == nil { + return false, IdentityDecision{}, nil + } + key, err := r.preauth.Get(ctx, tokenText) + if err != nil { + if errors.Is(err, preauth.ErrKeyNotFound) { + return false, IdentityDecision{}, nil + } + return true, IdentityDecision{}, err + } + reply := func(text string) IdentityDecision { + return IdentityDecision{ + Stop: true, + Reply: channel.Message{Text: text}, + } + } + if !key.UsedAt.IsZero() { + return true, reply("预授权码已使用。"), nil + } + if !key.ExpiresAt.IsZero() && time.Now().UTC().After(key.ExpiresAt) { + return true, reply("预授权码已过期,请重新获取。"), nil + } + if key.BotID != msg.BotID { + return true, reply("预授权码不匹配。"), nil + } + if externalID == "" { + return true, reply("无法识别当前账号,授权失败。"), nil + } + displayName := extractDisplayName(msg) + contact, err := r.contacts.CreateGuest(ctx, msg.BotID, displayName) + if err != nil { + return true, reply("授权失败,请稍后重试。"), nil + } + if _, err := r.contacts.UpsertChannel(ctx, msg.BotID, contact.ID, msg.Channel.String(), externalID, nil); err != nil { + return true, reply("授权失败,请稍后重试。"), nil + } + _ = r.store.UpsertChannelSession(ctx, msg.SessionID(), msg.BotID, "", "", contact.ID, msg.Channel.String(), strings.TrimSpace(msg.ReplyTarget), extractThreadID(msg), buildSessionMetadata(msg)) + _, _ = r.preauth.MarkUsed(ctx, key.ID) + return true, reply(r.preauthReply), nil +} + +func extractExternalIdentity(msg channel.InboundMessage) string { + if strings.TrimSpace(msg.Sender.ExternalID) != "" { + return strings.TrimSpace(msg.Sender.ExternalID) + } + if value := strings.TrimSpace(msg.Sender.Attribute("open_id")); value != "" { + return value + } + if value := strings.TrimSpace(msg.Sender.Attribute("user_id")); value != "" { + return value + } + if value := strings.TrimSpace(msg.Sender.Attribute("username")); value != "" { + return value + } + return strings.TrimSpace(msg.Sender.DisplayName) +} + +func extractDisplayName(msg channel.InboundMessage) string { + if strings.TrimSpace(msg.Sender.DisplayName) != "" { + return strings.TrimSpace(msg.Sender.DisplayName) + } + if strings.TrimSpace(msg.Sender.ExternalID) != "" { + return strings.TrimSpace(msg.Sender.ExternalID) + } + if value := strings.TrimSpace(msg.Sender.Attribute("username")); value != "" { + return value + } + if value := strings.TrimSpace(msg.Sender.Attribute("user_id")); value != "" { + return value + } + if value := strings.TrimSpace(msg.Sender.Attribute("open_id")); value != "" { + return value + } + return "" +} + +func extractThreadID(msg channel.InboundMessage) string { + if msg.Message.Thread != nil && strings.TrimSpace(msg.Message.Thread.ID) != "" { + return strings.TrimSpace(msg.Message.Thread.ID) + } + if strings.TrimSpace(msg.Conversation.ThreadID) != "" { + return strings.TrimSpace(msg.Conversation.ThreadID) + } + return "" +} + +func buildSessionMetadata(msg channel.InboundMessage) map[string]any { + metadata := map[string]any{} + if strings.TrimSpace(msg.Source) != "" { + metadata["source"] = strings.TrimSpace(msg.Source) + } + if strings.TrimSpace(msg.Message.ID) != "" { + metadata["message_id"] = strings.TrimSpace(msg.Message.ID) + } + if strings.TrimSpace(msg.Conversation.Type) != "" { + metadata["conversation_type"] = strings.TrimSpace(msg.Conversation.Type) + } + if strings.TrimSpace(msg.Conversation.Name) != "" { + metadata["conversation_name"] = strings.TrimSpace(msg.Conversation.Name) + } + if !msg.ReceivedAt.IsZero() { + metadata["received_at"] = msg.ReceivedAt.UTC().Format(time.RFC3339Nano) + } + return metadata +} diff --git a/internal/router/identity_test.go b/internal/router/identity_test.go new file mode 100644 index 00000000..3948f16c --- /dev/null +++ b/internal/router/identity_test.go @@ -0,0 +1,210 @@ +package router + +import ( + "context" + "fmt" + "log/slog" + "testing" + "time" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/contacts" + "github.com/memohai/memoh/internal/policy" + "github.com/memohai/memoh/internal/preauth" +) + +type fakePolicyServiceIdentity struct { + decision policy.Decision + err error +} + +func (f *fakePolicyServiceIdentity) Resolve(ctx context.Context, botID string) (policy.Decision, error) { + if f.err != nil { + return policy.Decision{}, f.err + } + decision := f.decision + if decision.BotID == "" { + decision.BotID = botID + } + return decision, nil +} + +type fakeIdentityConfigStore struct{} + +func (f *fakeIdentityConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) { + return channel.ChannelConfig{}, nil +} + +func (f *fakeIdentityConfigStore) GetUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType) (channel.ChannelUserBinding, error) { + return channel.ChannelUserBinding{}, fmt.Errorf("not implemented") +} + +func (f *fakeIdentityConfigStore) UpsertUserConfig(ctx context.Context, actorUserID string, channelType channel.ChannelType, req channel.UpsertUserConfigRequest) (channel.ChannelUserBinding, error) { + return channel.ChannelUserBinding{}, nil +} + +func (f *fakeIdentityConfigStore) ListConfigsByType(ctx context.Context, channelType channel.ChannelType) ([]channel.ChannelConfig, error) { + return nil, nil +} + +func (f *fakeIdentityConfigStore) ResolveUserBinding(ctx context.Context, channelType channel.ChannelType, criteria channel.BindingCriteria) (string, error) { + return "", fmt.Errorf("channel user binding not found") +} + +func (f *fakeIdentityConfigStore) ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]channel.ChannelSession, error) { + return nil, nil +} + +func (f *fakeIdentityConfigStore) GetChannelSession(ctx context.Context, sessionID string) (channel.ChannelSession, error) { + return channel.ChannelSession{}, nil +} + +func (f *fakeIdentityConfigStore) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { + return nil +} + +type fakeIdentityContactService struct { + createGuestCalled bool + upsertCalled bool +} + +func (f *fakeIdentityContactService) GetByID(ctx context.Context, contactID string) (contacts.Contact, error) { + return contacts.Contact{}, fmt.Errorf("not found") +} + +func (f *fakeIdentityContactService) GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error) { + return contacts.Contact{}, fmt.Errorf("not found") +} + +func (f *fakeIdentityContactService) GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error) { + return contacts.ContactChannel{}, fmt.Errorf("not found") +} + +func (f *fakeIdentityContactService) Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error) { + return contacts.Contact{ID: "contact-1", BotID: req.BotID}, nil +} + +func (f *fakeIdentityContactService) CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error) { + f.createGuestCalled = true + return contacts.Contact{ID: "contact-guest", BotID: botID}, nil +} + +func (f *fakeIdentityContactService) UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error) { + f.upsertCalled = true + return contacts.ContactChannel{ID: "channel-1", ContactID: contactID}, nil +} + +type fakePreauthService struct { + key preauth.Key + err error + markUsed bool +} + +func (f *fakePreauthService) Get(ctx context.Context, token string) (preauth.Key, error) { + if f.err != nil { + return preauth.Key{}, f.err + } + if f.key.Token == "" || f.key.Token != token { + return preauth.Key{}, preauth.ErrKeyNotFound + } + return f.key, nil +} + +func (f *fakePreauthService) MarkUsed(ctx context.Context, id string) (preauth.Key, error) { + f.markUsed = true + return f.key, nil +} + +func TestIdentityResolverAllowGuestCreatesContact(t *testing.T) { + store := &fakeIdentityConfigStore{} + contactsService := &fakeIdentityContactService{} + policyService := &fakePolicyServiceIdentity{decision: policy.Decision{AllowGuest: true}} + resolver := NewIdentityResolver(slog.Default(), store, contactsService, policyService, nil, "禁止访问", "授权成功") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello"}, + ReplyTarget: "target-id", + Sender: channel.Identity{ExternalID: "user-1", DisplayName: "访客"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("不应报错: %v", err) + } + if state.Identity.ContactID != "contact-guest" { + t.Fatalf("应创建访客联系人,实际: %s", state.Identity.ContactID) + } + if !contactsService.createGuestCalled { + t.Fatalf("应调用 CreateGuest") + } +} + +func TestIdentityResolverPreauthKeyAllowsGuest(t *testing.T) { + store := &fakeIdentityConfigStore{} + contactsService := &fakeIdentityContactService{} + policyService := &fakePolicyServiceIdentity{} + preauthService := &fakePreauthService{ + key: preauth.Key{ + ID: "key-1", + BotID: "bot-1", + Token: "PREAUTH123", + ExpiresAt: time.Now().UTC().Add(1 * time.Hour), + }, + } + resolver := NewIdentityResolver(slog.Default(), store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "PREAUTH123"}, + ReplyTarget: "target-id", + Sender: channel.Identity{ExternalID: "user-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("不应报错: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatalf("应返回授权确认") + } + if !contactsService.upsertCalled { + t.Fatalf("应执行联系人绑定") + } + if !preauthService.markUsed { + t.Fatalf("应标记预授权码已使用") + } +} + +func TestIdentityResolverPreauthKeyExpired(t *testing.T) { + store := &fakeIdentityConfigStore{} + contactsService := &fakeIdentityContactService{} + policyService := &fakePolicyServiceIdentity{} + preauthService := &fakePreauthService{ + key: preauth.Key{ + ID: "key-1", + BotID: "bot-1", + Token: "PREAUTH123", + ExpiresAt: time.Now().UTC().Add(-1 * time.Hour), + }, + } + resolver := NewIdentityResolver(slog.Default(), store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "PREAUTH123"}, + ReplyTarget: "target-id", + Sender: channel.Identity{ExternalID: "user-1"}, + } + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("不应报错: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatalf("过期预授权码应被拒绝") + } + if preauthService.markUsed { + t.Fatalf("过期预授权码不应被使用") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 819d08c8..639cc97d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,7 @@ type Server struct { logger *slog.Logger } -func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, contactsHandler *handlers.ContactsHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { +func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, contactsHandler *handlers.ContactsHandler, preauthHandler *handlers.PreauthHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server { if addr == "" { addr = ":8080" } @@ -78,6 +78,9 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han if contactsHandler != nil { contactsHandler.Register(e) } + if preauthHandler != nil { + preauthHandler.Register(e) + } if scheduleHandler != nil { scheduleHandler.Register(e) } diff --git a/internal/settings/types.go b/internal/settings/types.go index 4db7dce3..f750f8ff 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -2,7 +2,7 @@ package settings const ( DefaultMaxContextLoadTime = 24 * 60 - DefaultLanguage = "Same as user input" + DefaultLanguage = "auto" ) type Settings struct { diff --git a/internal/subagent/service.go b/internal/subagent/service.go index 89e5bf64..c864a4db 100644 --- a/internal/subagent/service.go +++ b/internal/subagent/service.go @@ -253,44 +253,44 @@ func toSubagent(row sqlc.Subagent) (Subagent, error) { return item, nil } -func marshalMessages(messages []map[string]interface{}) ([]byte, error) { +func marshalMessages(messages []map[string]any) ([]byte, error) { if messages == nil { - messages = []map[string]interface{}{} + messages = []map[string]any{} } return json.Marshal(messages) } -func unmarshalMessages(payload []byte) ([]map[string]interface{}, error) { +func unmarshalMessages(payload []byte) ([]map[string]any, error) { if len(payload) == 0 { - return []map[string]interface{}{}, nil + return []map[string]any{}, nil } - var messages []map[string]interface{} + var messages []map[string]any if err := json.Unmarshal(payload, &messages); err != nil { return nil, err } if messages == nil { - messages = []map[string]interface{}{} + messages = []map[string]any{} } return messages, nil } -func marshalMetadata(metadata map[string]interface{}) ([]byte, error) { +func marshalMetadata(metadata map[string]any) ([]byte, error) { if metadata == nil { - metadata = map[string]interface{}{} + metadata = map[string]any{} } return json.Marshal(metadata) } -func unmarshalMetadata(payload []byte) (map[string]interface{}, error) { +func unmarshalMetadata(payload []byte) (map[string]any, error) { if len(payload) == 0 { - return map[string]interface{}{}, nil + return map[string]any{}, nil } - var metadata map[string]interface{} + var metadata map[string]any if err := json.Unmarshal(payload, &metadata); err != nil { return nil, err } if metadata == nil { - metadata = map[string]interface{}{} + metadata = map[string]any{} } return metadata, nil } diff --git a/internal/subagent/types.go b/internal/subagent/types.go index cf412772..77498a12 100644 --- a/internal/subagent/types.go +++ b/internal/subagent/types.go @@ -7,8 +7,8 @@ type Subagent struct { Name string `json:"name"` Description string `json:"description"` BotID string `json:"bot_id"` - Messages []map[string]interface{} `json:"messages"` - Metadata map[string]interface{} `json:"metadata"` + Messages []map[string]any `json:"messages"` + Metadata map[string]any `json:"metadata"` Skills []string `json:"skills"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -19,19 +19,19 @@ type Subagent struct { type CreateRequest struct { Name string `json:"name"` Description string `json:"description"` - Messages []map[string]interface{} `json:"messages,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Messages []map[string]any `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` Skills []string `json:"skills,omitempty"` } type UpdateRequest struct { Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type UpdateContextRequest struct { - Messages []map[string]interface{} `json:"messages"` + Messages []map[string]any `json:"messages"` } type UpdateSkillsRequest struct { @@ -47,7 +47,7 @@ type ListResponse struct { } type ContextResponse struct { - Messages []map[string]interface{} `json:"messages"` + Messages []map[string]any `json:"messages"` } type SkillsResponse struct { diff --git a/internal/users/service.go b/internal/users/service.go index d93a1214..24c0c405 100644 --- a/internal/users/service.go +++ b/internal/users/service.go @@ -22,9 +22,9 @@ type Service struct { } var ( - ErrInvalidPassword = errors.New("invalid password") - ErrInvalidCredentials = errors.New("invalid credentials") - ErrInactiveUser = errors.New("user is inactive") + ErrInvalidPassword = errors.New("invalid password") + ErrInvalidCredentials = errors.New("invalid credentials") + ErrInactiveUser = errors.New("user is inactive") ) func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { @@ -325,7 +325,7 @@ func normalizeRole(raw string) (string, error) { return role, nil } -func isAdminRole(role interface{}) bool { +func isAdminRole(role any) bool { if role == nil { return false } diff --git a/packages/shared/src/model.ts b/packages/shared/src/model.ts index 83aced69..56b0c6fd 100644 --- a/packages/shared/src/model.ts +++ b/packages/shared/src/model.ts @@ -90,6 +90,7 @@ export interface ProviderInfo{ export interface ModelInfo{ dimensions:number is_multimodal:boolean + input?: string[] llm_provider_id:string model_id:string name:string diff --git a/packages/ui/src/index.ts b/packages/ui/src/index.ts index cb74b21f..1920d05b 100644 --- a/packages/ui/src/index.ts +++ b/packages/ui/src/index.ts @@ -1,40 +1,40 @@ -export * from './components/alert/index' -export * from './components/avatar/index' -export * from './components/badge/index' -export * from './components/breadcrumb/index' -export * from './components/button/index' -export * from './components/button-group/index' -export * from './components/card/index' -export * from './components/checkbox/index' -export * from './components/collapsible/index' -export * from './components/combobox/index' -export * from './components/context-menu/index' -export * from './components/dialog/index' -export * from './components/dropdown-menu/index' -export * from './components/empty/index' -export * from './components/form/index' -export * from './components/input/index' -export * from './components/input-group/index' -export * from './components/item/index' -export * from './components/kbd/index' -export * from './components/label/index' -export * from './components/native-select/index' -export * from './components/pagination/index' -export * from './components/popover/index' -export * from './components/radio-group/index' -export * from './components/scroll-area/index' -export * from './components/select/index' -export * from './components/separator/index' -export * from './components/sheet/index' -export * from './components/sidebar/index' -export * from './components/skeleton/index' -export * from './components/slider/index' -export * from './components/sonner/index' -export * from './components/spinner/index' -export * from './components/switch/index' -export * from './components/table/index' -export * from './components/tabs/index' -export * from './components/tags-input/index' -export * from './components/textarea/index' -export * from './components/toggle/index' +export * from './components/alert/index' +export * from './components/avatar/index' +export * from './components/badge/index' +export * from './components/breadcrumb/index' +export * from './components/button/index' +export * from './components/button-group/index' +export * from './components/card/index' +export * from './components/checkbox/index' +export * from './components/collapsible/index' +export * from './components/combobox/index' +export * from './components/context-menu/index' +export * from './components/dialog/index' +export * from './components/dropdown-menu/index' +export * from './components/empty/index' +export * from './components/form/index' +export * from './components/input/index' +export * from './components/input-group/index' +export * from './components/item/index' +export * from './components/kbd/index' +export * from './components/label/index' +export * from './components/native-select/index' +export * from './components/pagination/index' +export * from './components/popover/index' +export * from './components/radio-group/index' +export * from './components/scroll-area/index' +export * from './components/select/index' +export * from './components/separator/index' +export * from './components/sheet/index' +export * from './components/sidebar/index' +export * from './components/skeleton/index' +export * from './components/slider/index' +export * from './components/sonner/index' +export * from './components/spinner/index' +export * from './components/switch/index' +export * from './components/table/index' +export * from './components/tabs/index' +export * from './components/tags-input/index' +export * from './components/textarea/index' +export * from './components/toggle/index' export * from './components/tooltip/index' \ No newline at end of file diff --git a/scripts/db-drop.sh b/scripts/db-drop.sh old mode 100755 new mode 100644 diff --git a/scripts/db-up.sh b/scripts/db-up.sh old mode 100755 new mode 100644 diff --git a/sqlc.yaml b/sqlc.yaml index 6990102c..c24809b6 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -11,4 +11,4 @@ sql: emit_json_tags: true overrides: - db_type: "user_role" - go_type: "string" + go_type: "string" \ No newline at end of file