diff --git a/cmd/agent/main.go b/cmd/agent/main.go index d740ddfd..7ebf8dea 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -159,20 +159,21 @@ func main() { embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries) swaggerHandler := handlers.NewSwaggerHandler(logger.L) chatHandler := handlers.NewChatHandler(logger.L, chatResolver, botService, usersService) - channelService := channel.NewService(queries) - channelRouter := router.NewChannelInboundProcessor(logger.L, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute) - channelManager := channel.NewManager(logger.L, channelService, channelRouter) + channelRegistry := channel.NewRegistry() + sessionHub := local.NewSessionHub() + channelRegistry.MustRegister(telegram.NewTelegramAdapter(logger.L)) + channelRegistry.MustRegister(feishu.NewFeishuAdapter(logger.L)) + channelRegistry.MustRegister(local.NewCLIAdapter(sessionHub)) + channelRegistry.MustRegister(local.NewWebAdapter(sessionHub)) + channelService := channel.NewService(queries, channelRegistry) + channelRouter := router.NewChannelInboundProcessor(logger.L, channelRegistry, channelService, chatResolver, contactsService, policyService, preauthService, cfg.Auth.JWTSecret, 5*time.Minute) + channelManager := channel.NewManager(logger.L, channelRegistry, channelService, channelRouter) if mw := channelRouter.IdentityMiddleware(); mw != nil { channelManager.Use(mw) } - sessionHub := channel.NewSessionHub() - channelManager.RegisterAdapter(telegram.NewTelegramAdapter(logger.L)) - channelManager.RegisterAdapter(feishu.NewFeishuAdapter(logger.L)) - channelManager.RegisterAdapter(local.NewCLIAdapter(sessionHub)) - channelManager.RegisterAdapter(local.NewWebAdapter(sessionHub)) channelManager.Start(ctx) - channelHandler := handlers.NewChannelHandler(channelService) - usersHandler := handlers.NewUsersHandler(logger.L, usersService, botService, channelService, channelManager) + channelHandler := handlers.NewChannelHandler(channelService, channelRegistry) + usersHandler := handlers.NewUsersHandler(logger.L, usersService, botService, channelService, channelManager, channelRegistry) cliHandler := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelService, sessionHub, botService, usersService) webHandler := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelService, sessionHub, botService, usersService) scheduleGateway := chat.NewScheduleGateway(chatResolver) diff --git a/docs/docs.go b/docs/docs.go index 947e88b9..0904ecaf 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -4207,10 +4207,6 @@ const docTemplate = `{ "botID": { "type": "string" }, - "capabilities": { - "type": "object", - "additionalProperties": {} - }, "channelType": { "type": "string" }, @@ -4506,10 +4502,6 @@ const docTemplate = `{ "channel.UpsertConfigRequest": { "type": "object", "properties": { - "capabilities": { - "type": "object", - "additionalProperties": {} - }, "credentials": { "type": "object", "additionalProperties": {} diff --git a/docs/swagger.json b/docs/swagger.json index 88734826..64046fb6 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -4198,10 +4198,6 @@ "botID": { "type": "string" }, - "capabilities": { - "type": "object", - "additionalProperties": {} - }, "channelType": { "type": "string" }, @@ -4497,10 +4493,6 @@ "channel.UpsertConfigRequest": { "type": "object", "properties": { - "capabilities": { - "type": "object", - "additionalProperties": {} - }, "credentials": { "type": "object", "additionalProperties": {} diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 5eddb616..a73f3bec 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -178,9 +178,6 @@ definitions: properties: botID: type: string - capabilities: - additionalProperties: {} - type: object channelType: type: string createdAt: @@ -385,9 +382,6 @@ definitions: type: object channel.UpsertConfigRequest: properties: - capabilities: - additionalProperties: {} - type: object credentials: additionalProperties: {} type: object diff --git a/internal/channel/adapter.go b/internal/channel/adapter.go index 0cb9c0df..9cb5f6b7 100644 --- a/internal/channel/adapter.go +++ b/internal/channel/adapter.go @@ -20,6 +20,38 @@ type ReplySender interface { // Adapter is the base interface every channel adapter must implement. type Adapter interface { Type() ChannelType + Descriptor() Descriptor +} + +// Descriptor holds read-only metadata for a registered channel type. +// It contains no behavior — all behavior is expressed through optional interfaces. +type Descriptor struct { + Type ChannelType + DisplayName string + Configless bool + Capabilities ChannelCapabilities + OutboundPolicy OutboundPolicy + ConfigSchema ConfigSchema + UserConfigSchema ConfigSchema + TargetSpec TargetSpec +} + +// ConfigNormalizer validates and normalizes channel and user-binding configurations. +type ConfigNormalizer interface { + NormalizeConfig(raw map[string]any) (map[string]any, error) + NormalizeUserConfig(raw map[string]any) (map[string]any, error) +} + +// TargetResolver handles delivery target normalization and resolution from user bindings. +type TargetResolver interface { + NormalizeTarget(raw string) string + ResolveTarget(userConfig map[string]any) (string, error) +} + +// BindingMatcher matches user-channel bindings and constructs binding configs from identities. +type BindingMatcher interface { + MatchBinding(config map[string]any, criteria BindingCriteria) bool + BuildUserConfig(identity Identity) map[string]any } // Sender is an adapter capable of sending outbound messages. diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go index fcd9dc1e..d7def25a 100644 --- a/internal/channel/adapters/feishu/config.go +++ b/internal/channel/adapters/feishu/config.go @@ -21,8 +21,7 @@ type UserConfig struct { UserID string } -// NormalizeConfig validates and normalizes a Feishu channel configuration map. -func NormalizeConfig(raw map[string]any) (map[string]any, error) { +func normalizeConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseConfig(raw) if err != nil { return nil, err @@ -40,8 +39,7 @@ func NormalizeConfig(raw map[string]any) (map[string]any, error) { return result, nil } -// NormalizeUserConfig validates and normalizes a Feishu user-binding configuration map. -func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func normalizeUserConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseUserConfig(raw) if err != nil { return nil, err @@ -56,8 +54,7 @@ func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return result, nil } -// ResolveTarget derives a Feishu delivery target from a user-binding configuration. -func ResolveTarget(raw map[string]any) (string, error) { +func resolveTarget(raw map[string]any) (string, error) { cfg, err := parseUserConfig(raw) if err != nil { return "", err @@ -71,8 +68,7 @@ func ResolveTarget(raw map[string]any) (string, error) { return "", fmt.Errorf("feishu binding is incomplete") } -// MatchBinding reports whether a Feishu user binding matches the given criteria. -func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { +func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { cfg, err := parseUserConfig(raw) if err != nil { return false @@ -91,8 +87,7 @@ func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { return false } -// BuildUserConfig constructs a Feishu user-binding config from an Identity. -func BuildUserConfig(identity channel.Identity) map[string]any { +func buildUserConfig(identity channel.Identity) map[string]any { result := map[string]any{} if value := strings.TrimSpace(identity.Attribute("open_id")); value != "" { result["open_id"] = value diff --git a/internal/channel/adapters/feishu/config_test.go b/internal/channel/adapters/feishu/config_test.go index 501a0886..1b1feb60 100644 --- a/internal/channel/adapters/feishu/config_test.go +++ b/internal/channel/adapters/feishu/config_test.go @@ -5,7 +5,7 @@ import "testing" func TestNormalizeConfig(t *testing.T) { t.Parallel() - got, err := NormalizeConfig(map[string]any{ + got, err := normalizeConfig(map[string]any{ "app_id": "app", "app_secret": "secret", "encrypt_key": "enc", @@ -25,7 +25,7 @@ func TestNormalizeConfig(t *testing.T) { func TestNormalizeConfigRequiresApp(t *testing.T) { t.Parallel() - _, err := NormalizeConfig(map[string]any{}) + _, err := normalizeConfig(map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } @@ -34,7 +34,7 @@ func TestNormalizeConfigRequiresApp(t *testing.T) { func TestNormalizeUserConfig(t *testing.T) { t.Parallel() - got, err := NormalizeUserConfig(map[string]any{ + got, err := normalizeUserConfig(map[string]any{ "open_id": "ou_123", }) if err != nil { @@ -48,7 +48,7 @@ func TestNormalizeUserConfig(t *testing.T) { func TestNormalizeUserConfigRequiresBinding(t *testing.T) { t.Parallel() - _, err := NormalizeUserConfig(map[string]any{}) + _, err := normalizeUserConfig(map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } @@ -57,7 +57,7 @@ func TestNormalizeUserConfigRequiresBinding(t *testing.T) { func TestResolveTarget(t *testing.T) { t.Parallel() - target, err := ResolveTarget(map[string]any{ + target, err := resolveTarget(map[string]any{ "open_id": "ou_123", "user_id": "u_123", }) diff --git a/internal/channel/adapters/feishu/descriptor.go b/internal/channel/adapters/feishu/descriptor.go index b97be3c8..b1fd4129 100644 --- a/internal/channel/adapters/feishu/descriptor.go +++ b/internal/channel/adapters/feishu/descriptor.go @@ -5,52 +5,3 @@ import "github.com/memohai/memoh/internal/channel" // Type is the registered ChannelType identifier for Feishu. const Type channel.ChannelType = "feishu" - -func init() { - channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: Type, - DisplayName: "Feishu", - NormalizeConfig: NormalizeConfig, - NormalizeUserConfig: NormalizeUserConfig, - ResolveTarget: ResolveTarget, - MatchBinding: MatchBinding, - BuildUserConfig: BuildUserConfig, - TargetSpec: channel.TargetSpec{ - Format: "open_id:xxx | user_id:xxx | chat_id:xxx", - Hints: []channel.TargetHint{ - {Label: "Open ID", Example: "open_id:ou_xxx"}, - {Label: "User ID", Example: "user_id:ou_xxx"}, - {Label: "Chat ID", Example: "chat_id:oc_xxx"}, - }, - }, - NormalizeTarget: normalizeTarget, - Capabilities: channel.ChannelCapabilities{ - Text: true, - RichText: true, - Attachments: true, - Reply: true, - }, - ConfigSchema: channel.ConfigSchema{ - Version: 1, - Fields: map[string]channel.FieldSchema{ - "appId": {Type: channel.FieldString, Required: true, Title: "App ID"}, - "appSecret": {Type: channel.FieldSecret, Required: true, Title: "App Secret"}, - "encryptKey": { - Type: channel.FieldSecret, - Title: "Encrypt Key", - }, - "verificationToken": { - Type: channel.FieldSecret, - Title: "Verification Token", - }, - }, - }, - UserConfigSchema: channel.ConfigSchema{ - Version: 1, - Fields: map[string]channel.FieldSchema{ - "open_id": {Type: channel.FieldString}, - "user_id": {Type: channel.FieldString}, - }, - }, - }) -} diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index 6d9f1d9c..e1cc6d7f 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -40,6 +40,80 @@ func (a *FeishuAdapter) Type() channel.ChannelType { return Type } +// Descriptor returns the Feishu channel metadata. +func (a *FeishuAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: Type, + DisplayName: "Feishu", + Capabilities: channel.ChannelCapabilities{ + Text: true, + RichText: true, + Attachments: true, + Reply: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "appId": {Type: channel.FieldString, Required: true, Title: "App ID"}, + "appSecret": {Type: channel.FieldSecret, Required: true, Title: "App Secret"}, + "encryptKey": { + Type: channel.FieldSecret, + Title: "Encrypt Key", + }, + "verificationToken": { + Type: channel.FieldSecret, + Title: "Verification Token", + }, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "open_id": {Type: channel.FieldString}, + "user_id": {Type: channel.FieldString}, + }, + }, + TargetSpec: channel.TargetSpec{ + Format: "open_id:xxx | user_id:xxx | chat_id:xxx", + Hints: []channel.TargetHint{ + {Label: "Open ID", Example: "open_id:ou_xxx"}, + {Label: "User ID", Example: "user_id:ou_xxx"}, + {Label: "Chat ID", Example: "chat_id:oc_xxx"}, + }, + }, + } +} + +// NormalizeConfig validates and normalizes a Feishu channel configuration map. +func (a *FeishuAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { + return normalizeConfig(raw) +} + +// NormalizeUserConfig validates and normalizes a Feishu user-binding configuration map. +func (a *FeishuAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + return normalizeUserConfig(raw) +} + +// NormalizeTarget normalizes a Feishu delivery target string. +func (a *FeishuAdapter) NormalizeTarget(raw string) string { + return normalizeTarget(raw) +} + +// ResolveTarget derives a delivery target from a Feishu user-binding configuration. +func (a *FeishuAdapter) ResolveTarget(userConfig map[string]any) (string, error) { + return resolveTarget(userConfig) +} + +// MatchBinding reports whether a Feishu user binding matches the given criteria. +func (a *FeishuAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { + return matchBinding(config, criteria) +} + +// BuildUserConfig constructs a Feishu user-binding config from an Identity. +func (a *FeishuAdapter) BuildUserConfig(identity channel.Identity) map[string]any { + return buildUserConfig(identity) +} + // Connect establishes a WebSocket connection to Feishu and forwards inbound messages to the handler. func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { diff --git a/internal/channel/adapters/local/cli.go b/internal/channel/adapters/local/cli.go index 42a6bdb7..9241b207 100644 --- a/internal/channel/adapters/local/cli.go +++ b/internal/channel/adapters/local/cli.go @@ -10,11 +10,11 @@ import ( // CLIAdapter implements channel.Sender for the local CLI channel. type CLIAdapter struct { - hub *channel.SessionHub + hub *SessionHub } // NewCLIAdapter creates a CLIAdapter backed by the given session hub. -func NewCLIAdapter(hub *channel.SessionHub) *CLIAdapter { +func NewCLIAdapter(hub *SessionHub) *CLIAdapter { return &CLIAdapter{hub: hub} } @@ -23,6 +23,26 @@ func (a *CLIAdapter) Type() channel.ChannelType { return CLIType } +// Descriptor returns the CLI channel metadata. +func (a *CLIAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: CLIType, + DisplayName: "CLI", + Configless: true, + Capabilities: channel.ChannelCapabilities{ + Text: true, + Reply: true, + Attachments: true, + }, + TargetSpec: channel.TargetSpec{ + Format: "session_id", + Hints: []channel.TargetHint{ + {Label: "Session ID", Example: "cli:uuid"}, + }, + }, + } +} + // Send publishes an outbound message to the CLI session hub. func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { diff --git a/internal/channel/adapters/local/descriptor.go b/internal/channel/adapters/local/descriptor.go index bfc99c90..fe262a71 100644 --- a/internal/channel/adapters/local/descriptor.go +++ b/internal/channel/adapters/local/descriptor.go @@ -1,11 +1,7 @@ // Package local implements the CLI and Web channel adapters for local development. package local -import ( - "strings" - - "github.com/memohai/memoh/internal/channel" -) +import "github.com/memohai/memoh/internal/channel" const ( // CLIType is the registered ChannelType for the CLI adapter. @@ -13,58 +9,3 @@ const ( // WebType is the registered ChannelType for the Web adapter. WebType channel.ChannelType = "web" ) - -func init() { - channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: CLIType, - DisplayName: "CLI", - NormalizeConfig: normalizeEmpty, - NormalizeUserConfig: normalizeEmpty, - BuildUserConfig: buildEmpty, - Configless: true, - TargetSpec: channel.TargetSpec{ - Format: "session_id", - Hints: []channel.TargetHint{ - {Label: "Session ID", Example: "cli:uuid"}, - }, - }, - NormalizeTarget: normalizeTarget, - Capabilities: channel.ChannelCapabilities{ - Text: true, - Reply: true, - Attachments: true, - }, - }) - channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: WebType, - DisplayName: "Web", - NormalizeConfig: normalizeEmpty, - NormalizeUserConfig: normalizeEmpty, - BuildUserConfig: buildEmpty, - Configless: true, - TargetSpec: channel.TargetSpec{ - Format: "session_id", - Hints: []channel.TargetHint{ - {Label: "Session ID", Example: "web:uuid"}, - }, - }, - NormalizeTarget: normalizeTarget, - Capabilities: channel.ChannelCapabilities{ - Text: true, - Reply: true, - Attachments: true, - }, - }) -} - -func normalizeTarget(raw string) string { - return strings.TrimSpace(raw) -} - -func normalizeEmpty(map[string]any) (map[string]any, error) { - return map[string]any{}, nil -} - -func buildEmpty(channel.Identity) map[string]any { - return map[string]any{} -} diff --git a/internal/channel/cli_hub.go b/internal/channel/adapters/local/hub.go similarity index 76% rename from internal/channel/cli_hub.go rename to internal/channel/adapters/local/hub.go index 4599ccb8..ddc4b4bd 100644 --- a/internal/channel/cli_hub.go +++ b/internal/channel/adapters/local/hub.go @@ -1,34 +1,36 @@ -package channel +package local import ( "sync" "github.com/google/uuid" + + "github.com/memohai/memoh/internal/channel" ) // SessionHub is a pub/sub hub that routes outbound messages to CLI/Web session subscribers. type SessionHub struct { mu sync.RWMutex - sessions map[string]map[string]chan OutboundMessage + sessions map[string]map[string]chan channel.OutboundMessage } // NewSessionHub creates an empty SessionHub. func NewSessionHub() *SessionHub { return &SessionHub{ - sessions: map[string]map[string]chan OutboundMessage{}, + sessions: map[string]map[string]chan channel.OutboundMessage{}, } } // Subscribe registers a new stream for the given session and returns a stream ID, // a read-only channel for messages, and a cancel function to unsubscribe. -func (h *SessionHub) Subscribe(sessionID string) (string, <-chan OutboundMessage, func()) { +func (h *SessionHub) Subscribe(sessionID string) (string, <-chan channel.OutboundMessage, func()) { streamID := uuid.NewString() - ch := make(chan OutboundMessage, 32) + ch := make(chan channel.OutboundMessage, 32) h.mu.Lock() streams, ok := h.sessions[sessionID] if !ok { - streams = map[string]chan OutboundMessage{} + streams = map[string]chan channel.OutboundMessage{} h.sessions[sessionID] = streams } streams[streamID] = ch @@ -54,7 +56,7 @@ func (h *SessionHub) Subscribe(sessionID string) (string, <-chan OutboundMessage // Publish delivers a message to all subscribers of the given session. // Slow receivers are silently dropped. -func (h *SessionHub) Publish(sessionID string, msg OutboundMessage) { +func (h *SessionHub) Publish(sessionID string, msg channel.OutboundMessage) { h.mu.RLock() defer h.mu.RUnlock() for _, ch := range h.sessions[sessionID] { diff --git a/internal/channel/adapters/local/web.go b/internal/channel/adapters/local/web.go index 37ab217a..1490604b 100644 --- a/internal/channel/adapters/local/web.go +++ b/internal/channel/adapters/local/web.go @@ -10,11 +10,11 @@ import ( // WebAdapter implements channel.Sender for the local Web channel. type WebAdapter struct { - hub *channel.SessionHub + hub *SessionHub } // NewWebAdapter creates a WebAdapter backed by the given session hub. -func NewWebAdapter(hub *channel.SessionHub) *WebAdapter { +func NewWebAdapter(hub *SessionHub) *WebAdapter { return &WebAdapter{hub: hub} } @@ -23,6 +23,26 @@ func (a *WebAdapter) Type() channel.ChannelType { return WebType } +// Descriptor returns the Web channel metadata. +func (a *WebAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: WebType, + DisplayName: "Web", + Configless: true, + Capabilities: channel.ChannelCapabilities{ + Text: true, + Reply: true, + Attachments: true, + }, + TargetSpec: channel.TargetSpec{ + Format: "session_id", + Hints: []channel.TargetHint{ + {Label: "Session ID", Example: "web:uuid"}, + }, + }, + } +} + // Send publishes an outbound message to the Web session hub. func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { diff --git a/internal/channel/adapters/telegram/config.go b/internal/channel/adapters/telegram/config.go index 3bfdc99d..51d2d3a6 100644 --- a/internal/channel/adapters/telegram/config.go +++ b/internal/channel/adapters/telegram/config.go @@ -19,8 +19,7 @@ type UserConfig struct { ChatID string } -// NormalizeConfig validates and normalizes a Telegram channel configuration map. -func NormalizeConfig(raw map[string]any) (map[string]any, error) { +func normalizeConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseConfig(raw) if err != nil { return nil, err @@ -30,8 +29,7 @@ func NormalizeConfig(raw map[string]any) (map[string]any, error) { }, nil } -// NormalizeUserConfig validates and normalizes a Telegram user-binding configuration map. -func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func normalizeUserConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseUserConfig(raw) if err != nil { return nil, err @@ -49,8 +47,7 @@ func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return result, nil } -// ResolveTarget derives a Telegram delivery target from a user-binding configuration. -func ResolveTarget(raw map[string]any) (string, error) { +func resolveTarget(raw map[string]any) (string, error) { cfg, err := parseUserConfig(raw) if err != nil { return "", err @@ -71,8 +68,7 @@ func ResolveTarget(raw map[string]any) (string, error) { return "", fmt.Errorf("telegram binding is incomplete") } -// MatchBinding reports whether a Telegram user binding matches the given criteria. -func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { +func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { cfg, err := parseUserConfig(raw) if err != nil { return false @@ -94,8 +90,7 @@ func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { return false } -// BuildUserConfig constructs a Telegram user-binding config from an Identity. -func BuildUserConfig(identity channel.Identity) map[string]any { +func buildUserConfig(identity channel.Identity) map[string]any { result := map[string]any{} if value := strings.TrimSpace(identity.Attribute("username")); value != "" { result["username"] = value diff --git a/internal/channel/adapters/telegram/config_test.go b/internal/channel/adapters/telegram/config_test.go index c3d28165..3a314592 100644 --- a/internal/channel/adapters/telegram/config_test.go +++ b/internal/channel/adapters/telegram/config_test.go @@ -5,7 +5,7 @@ import "testing" func TestNormalizeConfig(t *testing.T) { t.Parallel() - got, err := NormalizeConfig(map[string]any{ + got, err := normalizeConfig(map[string]any{ "bot_token": "token-123", }) if err != nil { @@ -19,7 +19,7 @@ func TestNormalizeConfig(t *testing.T) { func TestNormalizeConfigRequiresToken(t *testing.T) { t.Parallel() - _, err := NormalizeConfig(map[string]any{}) + _, err := normalizeConfig(map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } @@ -28,7 +28,7 @@ func TestNormalizeConfigRequiresToken(t *testing.T) { func TestNormalizeUserConfig(t *testing.T) { t.Parallel() - got, err := NormalizeUserConfig(map[string]any{ + got, err := normalizeUserConfig(map[string]any{ "username": "alice", }) if err != nil { @@ -42,7 +42,7 @@ func TestNormalizeUserConfig(t *testing.T) { func TestNormalizeUserConfigRequiresBinding(t *testing.T) { t.Parallel() - _, err := NormalizeUserConfig(map[string]any{}) + _, err := normalizeUserConfig(map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } @@ -51,7 +51,7 @@ func TestNormalizeUserConfigRequiresBinding(t *testing.T) { func TestResolveTarget(t *testing.T) { t.Parallel() - target, err := ResolveTarget(map[string]any{ + target, err := resolveTarget(map[string]any{ "chat_id": "123", }) if err != nil { @@ -65,7 +65,7 @@ func TestResolveTarget(t *testing.T) { func TestResolveTargetUsername(t *testing.T) { t.Parallel() - target, err := ResolveTarget(map[string]any{ + target, err := resolveTarget(map[string]any{ "username": "alice", }) if err != nil { diff --git a/internal/channel/adapters/telegram/descriptor.go b/internal/channel/adapters/telegram/descriptor.go index 268bc9d8..4638e9aa 100644 --- a/internal/channel/adapters/telegram/descriptor.go +++ b/internal/channel/adapters/telegram/descriptor.go @@ -5,48 +5,3 @@ import "github.com/memohai/memoh/internal/channel" // Type is the registered ChannelType identifier for Telegram. const Type channel.ChannelType = "telegram" - -func init() { - channel.MustRegisterChannel(channel.ChannelDescriptor{ - Type: Type, - DisplayName: "Telegram", - NormalizeConfig: NormalizeConfig, - NormalizeUserConfig: NormalizeUserConfig, - ResolveTarget: ResolveTarget, - MatchBinding: MatchBinding, - BuildUserConfig: BuildUserConfig, - TargetSpec: channel.TargetSpec{ - Format: "chat_id | @username", - Hints: []channel.TargetHint{ - {Label: "Chat ID", Example: "123456789"}, - {Label: "Username", Example: "@alice"}, - }, - }, - NormalizeTarget: normalizeTarget, - Capabilities: channel.ChannelCapabilities{ - Text: true, - Markdown: true, - Reply: true, - Attachments: true, - Media: true, - }, - ConfigSchema: channel.ConfigSchema{ - Version: 1, - Fields: map[string]channel.FieldSchema{ - "botToken": { - Type: channel.FieldSecret, - Required: true, - Title: "Bot Token", - }, - }, - }, - UserConfigSchema: channel.ConfigSchema{ - Version: 1, - Fields: map[string]channel.FieldSchema{ - "username": {Type: channel.FieldString}, - "user_id": {Type: channel.FieldString}, - "chat_id": {Type: channel.FieldString}, - }, - }, - }) -} diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index d838b139..90e46062 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -61,6 +61,76 @@ func (a *TelegramAdapter) Type() channel.ChannelType { return Type } +// Descriptor returns the Telegram channel metadata. +func (a *TelegramAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: Type, + DisplayName: "Telegram", + Capabilities: channel.ChannelCapabilities{ + Text: true, + Markdown: true, + Reply: true, + Attachments: true, + Media: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "botToken": { + Type: channel.FieldSecret, + Required: true, + Title: "Bot Token", + }, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "username": {Type: channel.FieldString}, + "user_id": {Type: channel.FieldString}, + "chat_id": {Type: channel.FieldString}, + }, + }, + TargetSpec: channel.TargetSpec{ + Format: "chat_id | @username", + Hints: []channel.TargetHint{ + {Label: "Chat ID", Example: "123456789"}, + {Label: "Username", Example: "@alice"}, + }, + }, + } +} + +// NormalizeConfig validates and normalizes a Telegram channel configuration map. +func (a *TelegramAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { + return normalizeConfig(raw) +} + +// NormalizeUserConfig validates and normalizes a Telegram user-binding configuration map. +func (a *TelegramAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + return normalizeUserConfig(raw) +} + +// NormalizeTarget normalizes a Telegram delivery target string. +func (a *TelegramAdapter) NormalizeTarget(raw string) string { + return normalizeTarget(raw) +} + +// ResolveTarget derives a delivery target from a Telegram user-binding configuration. +func (a *TelegramAdapter) ResolveTarget(userConfig map[string]any) (string, error) { + return resolveTarget(userConfig) +} + +// MatchBinding reports whether a Telegram user binding matches the given criteria. +func (a *TelegramAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { + return matchBinding(config, criteria) +} + +// BuildUserConfig constructs a Telegram user-binding config from an Identity. +func (a *TelegramAdapter) BuildUserConfig(identity channel.Identity) map[string]any { + return buildUserConfig(identity) +} + // Connect starts long-polling for Telegram updates and forwards messages to the handler. func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { diff --git a/internal/channel/config.go b/internal/channel/config.go index 6067cd39..9ca278d3 100644 --- a/internal/channel/config.go +++ b/internal/channel/config.go @@ -6,64 +6,6 @@ import ( "strconv" ) -// NormalizeChannelConfig validates and normalizes a channel configuration map -// using the registered descriptor for the given channel type. -func NormalizeChannelConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { - if raw == nil { - raw = map[string]any{} - } - desc, ok := GetChannelDescriptor(channelType) - if !ok { - return nil, fmt.Errorf("unsupported channel type: %s", channelType) - } - if desc.NormalizeConfig == nil { - return raw, nil - } - return desc.NormalizeConfig(raw) -} - -// NormalizeChannelUserConfig validates and normalizes a user-channel binding configuration. -func NormalizeChannelUserConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { - if raw == nil { - raw = map[string]any{} - } - desc, ok := GetChannelDescriptor(channelType) - if !ok { - return nil, fmt.Errorf("unsupported channel type: %s", channelType) - } - if desc.NormalizeUserConfig == nil { - return raw, nil - } - return desc.NormalizeUserConfig(raw) -} - -// ResolveTargetFromUserConfig derives a delivery target string from a user-channel binding. -func ResolveTargetFromUserConfig(channelType ChannelType, config map[string]any) (string, error) { - desc, ok := GetChannelDescriptor(channelType) - if !ok || desc.ResolveTarget == nil { - return "", fmt.Errorf("unsupported channel type: %s", channelType) - } - return desc.ResolveTarget(config) -} - -// MatchUserBinding reports whether the given binding config matches the criteria. -func MatchUserBinding(channelType ChannelType, config map[string]any, criteria BindingCriteria) bool { - desc, ok := GetChannelDescriptor(channelType) - if !ok || desc.MatchBinding == nil { - return false - } - return desc.MatchBinding(config, criteria) -} - -// BuildUserBindingConfig constructs a user-channel binding config from an Identity. -func BuildUserBindingConfig(channelType ChannelType, identity Identity) map[string]any { - desc, ok := GetChannelDescriptor(channelType) - if !ok || desc.BuildUserConfig == nil { - return map[string]any{} - } - return desc.BuildUserConfig(identity) -} - // DecodeConfigMap unmarshals a JSON byte slice into a string-keyed map. func DecodeConfigMap(raw []byte) (map[string]any, error) { if len(raw) == 0 { diff --git a/internal/channel/config_test.go b/internal/channel/config_test.go index 1fae6855..837b3b53 100644 --- a/internal/channel/config_test.go +++ b/internal/channel/config_test.go @@ -2,7 +2,6 @@ package channel_test import ( "fmt" - "sync" "testing" "github.com/memohai/memoh/internal/channel" @@ -10,40 +9,33 @@ import ( const testChannelType = channel.ChannelType("test-config") -var registerTestChannelOnce sync.Once +// testConfigAdapter implements Adapter, ConfigNormalizer, TargetResolver, BindingMatcher for tests. +type testConfigAdapter struct{} -func registerTestChannel() { - registerTestChannelOnce.Do(func() { - if _, ok := channel.GetChannelDescriptor(testChannelType); ok { - return - } - _ = channel.RegisterChannel(channel.ChannelDescriptor{ - Type: testChannelType, - DisplayName: "Test", - NormalizeConfig: normalizeTestConfig, - NormalizeUserConfig: normalizeTestUserConfig, - ResolveTarget: resolveTestTarget, - MatchBinding: matchTestBinding, - Capabilities: channel.ChannelCapabilities{ - Text: true, +func (a *testConfigAdapter) Type() channel.ChannelType { return testChannelType } +func (a *testConfigAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: testChannelType, + DisplayName: "Test", + Capabilities: channel.ChannelCapabilities{ + Text: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "value": {Type: channel.FieldString, Required: true}, }, - ConfigSchema: channel.ConfigSchema{ - Version: 1, - Fields: map[string]channel.FieldSchema{ - "value": {Type: channel.FieldString, Required: true}, - }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "user": {Type: channel.FieldString, Required: true}, }, - UserConfigSchema: channel.ConfigSchema{ - Version: 1, - Fields: map[string]channel.FieldSchema{ - "user": {Type: channel.FieldString, Required: true}, - }, - }, - }) - }) + }, + } } -func normalizeTestConfig(raw map[string]any) (map[string]any, error) { +func (a *testConfigAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { value := channel.ReadString(raw, "value") if value == "" { return nil, fmt.Errorf("value is required") @@ -51,7 +43,7 @@ func normalizeTestConfig(raw map[string]any) (map[string]any, error) { return map[string]any{"value": value}, nil } -func normalizeTestUserConfig(raw map[string]any) (map[string]any, error) { +func (a *testConfigAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { value := channel.ReadString(raw, "user") if value == "" { return nil, fmt.Errorf("user is required") @@ -59,7 +51,9 @@ func normalizeTestUserConfig(raw map[string]any) (map[string]any, error) { return map[string]any{"user": value}, nil } -func resolveTestTarget(raw map[string]any) (string, error) { +func (a *testConfigAdapter) NormalizeTarget(raw string) string { return raw } + +func (a *testConfigAdapter) ResolveTarget(raw map[string]any) (string, error) { value := channel.ReadString(raw, "target") if value == "" { return "", fmt.Errorf("target is required") @@ -67,32 +61,42 @@ func resolveTestTarget(raw map[string]any) (string, error) { return "resolved:" + value, nil } -func matchTestBinding(raw map[string]any, criteria channel.BindingCriteria) bool { +func (a *testConfigAdapter) MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { value := channel.ReadString(raw, "user") return value != "" && value == criteria.ExternalID } +func (a *testConfigAdapter) BuildUserConfig(identity channel.Identity) map[string]any { + return map[string]any{} +} + +func newTestConfigRegistry() *channel.Registry { + reg := channel.NewRegistry() + reg.MustRegister(&testConfigAdapter{}) + return reg +} + func TestParseChannelType(t *testing.T) { t.Parallel() - registerTestChannel() + reg := newTestConfigRegistry() - got, err := channel.ParseChannelType(" test-config ") + got, err := reg.ParseChannelType(" test-config ") if err != nil { t.Fatalf("expected no error, got %v", err) } if got != testChannelType { t.Fatalf("unexpected channel type: %s", got) } - if _, err := channel.ParseChannelType("unknown"); err == nil { + if _, err := reg.ParseChannelType("unknown"); err == nil { t.Fatalf("expected error, got nil") } } func TestNormalizeChannelConfig(t *testing.T) { t.Parallel() - registerTestChannel() + reg := newTestConfigRegistry() - got, err := channel.NormalizeChannelConfig(testChannelType, map[string]any{"value": "ok"}) + got, err := reg.NormalizeConfig(testChannelType, map[string]any{"value": "ok"}) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -103,9 +107,9 @@ func TestNormalizeChannelConfig(t *testing.T) { func TestNormalizeChannelConfigRequiresValue(t *testing.T) { t.Parallel() - registerTestChannel() + reg := newTestConfigRegistry() - _, err := channel.NormalizeChannelConfig(testChannelType, map[string]any{}) + _, err := reg.NormalizeConfig(testChannelType, map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } @@ -113,9 +117,9 @@ func TestNormalizeChannelConfigRequiresValue(t *testing.T) { func TestNormalizeChannelUserConfig(t *testing.T) { t.Parallel() - registerTestChannel() + reg := newTestConfigRegistry() - got, err := channel.NormalizeChannelUserConfig(testChannelType, map[string]any{"user": "alice"}) + got, err := reg.NormalizeUserConfig(testChannelType, map[string]any{"user": "alice"}) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -126,9 +130,9 @@ func TestNormalizeChannelUserConfig(t *testing.T) { func TestNormalizeChannelUserConfigRequiresUser(t *testing.T) { t.Parallel() - registerTestChannel() + reg := newTestConfigRegistry() - _, err := channel.NormalizeChannelUserConfig(testChannelType, map[string]any{}) + _, err := reg.NormalizeUserConfig(testChannelType, map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } diff --git a/internal/channel/connection.go b/internal/channel/connection.go new file mode 100644 index 00000000..ffcf75c1 --- /dev/null +++ b/internal/channel/connection.go @@ -0,0 +1,176 @@ +package channel + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" +) + +type connectionEntry struct { + config ChannelConfig + connection Connection +} + +func (m *Manager) refresh(ctx context.Context) { + if m.service == nil { + return + } + configs := make([]ChannelConfig, 0) + for _, channelType := range m.registry.Types() { + items, err := m.service.ListConfigsByType(ctx, channelType) + if err != nil { + if m.logger != nil { + m.logger.Error("list configs failed", slog.String("channel", channelType.String()), slog.Any("error", err)) + } + continue + } + configs = append(configs, items...) + } + m.reconcile(ctx, configs) +} + +func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { + active := map[string]ChannelConfig{} + for _, cfg := range configs { + if cfg.ID == "" { + continue + } + status := strings.ToLower(strings.TrimSpace(cfg.Status)) + if status != "" && status != "active" && status != "verified" { + continue + } + active[cfg.ID] = cfg + if err := m.ensureConnection(ctx, cfg); err != nil { + if m.logger != nil { + m.logger.Error("adapter start failed", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + } + } + + m.mu.Lock() + defer m.mu.Unlock() + for id, entry := range m.connections { + if _, ok := active[id]; ok { + continue + } + if entry != nil && entry.connection != nil { + if m.logger != nil { + m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) + } + if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { + m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) + } + } + delete(m.connections, id) + } +} + +func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error { + _, ok := m.registry.GetReceiver(cfg.ChannelType) + if !ok { + return nil + } + + m.mu.Lock() + entry := m.connections[cfg.ID] + if entry != nil && !entry.config.UpdatedAt.Before(cfg.UpdatedAt) { + m.mu.Unlock() + return nil + } + if entry != nil { + m.mu.Unlock() + if m.logger != nil { + m.logger.Info("adapter restart", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) + } + if err := entry.connection.Stop(ctx); err != nil { + if errors.Is(err, ErrStopNotSupported) { + if m.logger != nil { + m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) + } + return nil + } + return err + } + m.mu.Lock() + delete(m.connections, cfg.ID) + m.mu.Unlock() + } else { + m.mu.Unlock() + } + + receiver, ok := m.registry.GetReceiver(cfg.ChannelType) + if !ok { + return nil + } + + if m.logger != nil { + m.logger.Info("adapter start", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) + } + handler := m.handleInbound + for i := len(m.middlewares) - 1; i >= 0; i-- { + handler = m.middlewares[i](handler) + } + conn, err := receiver.Connect(ctx, cfg, handler) + if err != nil { + return err + } + m.mu.Lock() + m.connections[cfg.ID] = &connectionEntry{ + config: cfg, + connection: conn, + } + m.mu.Unlock() + return nil +} + +func (m *Manager) stopAll(ctx context.Context) { + m.mu.Lock() + defer m.mu.Unlock() + for id, entry := range m.connections { + if entry != nil && entry.connection != nil { + if m.logger != nil { + m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) + } + if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { + m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) + } + } + delete(m.connections, id) + } +} + +// Stop terminates the connection identified by the given config ID. +func (m *Manager) Stop(ctx context.Context, configID string) error { + configID = strings.TrimSpace(configID) + if configID == "" { + return fmt.Errorf("config id is required") + } + m.mu.Lock() + entry := m.connections[configID] + m.mu.Unlock() + if entry == nil || entry.connection == nil { + return nil + } + return entry.connection.Stop(ctx) +} + +// StopByBot terminates all connections belonging to the given bot. +func (m *Manager) StopByBot(ctx context.Context, botID string) error { + botID = strings.TrimSpace(botID) + if botID == "" { + return fmt.Errorf("bot id is required") + } + m.mu.Lock() + defer m.mu.Unlock() + for id, entry := range m.connections { + if entry != nil && entry.config.BotID == botID { + if entry.connection != nil { + _ = entry.connection.Stop(ctx) + } + delete(m.connections, id) + } + } + return nil +} diff --git a/internal/channel/helpers_test.go b/internal/channel/helpers_test.go index 05730c2d..de50b163 100644 --- a/internal/channel/helpers_test.go +++ b/internal/channel/helpers_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/google/uuid" + + "github.com/memohai/memoh/internal/db" ) func TestDecodeConfigMap(t *testing.T) { @@ -41,10 +43,10 @@ func TestParseUUID(t *testing.T) { t.Parallel() id := uuid.NewString() - if _, err := parseUUID(id); err != nil { + if _, err := db.ParseUUID(id); err != nil { t.Fatalf("expected no error, got %v", err) } - if _, err := parseUUID("invalid"); err == nil { + if _, err := db.ParseUUID("invalid"); err == nil { t.Fatalf("expected error, got nil") } } diff --git a/internal/channel/inbound.go b/internal/channel/inbound.go new file mode 100644 index 00000000..630d7aef --- /dev/null +++ b/internal/channel/inbound.go @@ -0,0 +1,80 @@ +package channel + +import ( + "context" + "fmt" + "log/slog" +) + +type inboundTask struct { + ctx context.Context + cfg ChannelConfig + msg InboundMessage +} + +// HandleInbound enqueues an inbound message for asynchronous processing by the worker pool. +func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { + if m.processor == nil { + return fmt.Errorf("inbound processor not configured") + } + if ctx == nil { + ctx = context.Background() + } + m.startInboundWorkers(ctx) + if m.inboundCtx != nil && m.inboundCtx.Err() != nil { + return fmt.Errorf("inbound dispatcher stopped") + } + task := inboundTask{ + ctx: context.WithoutCancel(ctx), + cfg: cfg, + msg: msg, + } + select { + case m.inboundQueue <- task: + return nil + default: + return fmt.Errorf("inbound queue full") + } +} + +func (m *Manager) handleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { + if m.processor == nil { + return fmt.Errorf("inbound processor not configured") + } + sender := m.newReplySender(cfg, msg.Channel) + if err := m.processor.HandleInbound(ctx, cfg, msg, sender); err != nil { + if m.logger != nil { + m.logger.Error("inbound processing failed", slog.String("channel", msg.Channel.String()), slog.Any("error", err)) + } + return err + } + return nil +} + +func (m *Manager) startInboundWorkers(ctx context.Context) { + m.inboundOnce.Do(func() { + workerCtx := ctx + if workerCtx == nil { + workerCtx = context.Background() + } + m.inboundCtx, m.inboundCancel = context.WithCancel(workerCtx) + for i := 0; i < m.inboundWorkers; i++ { + go m.runInboundWorker(m.inboundCtx) + } + }) +} + +func (m *Manager) runInboundWorker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case task := <-m.inboundQueue: + if err := m.handleInbound(task.ctx, task.cfg, task.msg); err != nil { + if m.logger != nil { + m.logger.Error("inbound processing failed", slog.String("channel", task.msg.Channel.String()), slog.Any("error", err)) + } + } + } + } +} diff --git a/internal/channel/manager.go b/internal/channel/manager.go index d46a61f3..feead996 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -10,28 +10,55 @@ import ( "time" ) -// ConfigStore abstracts the persistence layer used by the Manager. -type ConfigStore interface { +// ConfigLister lists channel configs for periodic refresh. Used by connection lifecycle. +type ConfigLister interface { + ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) +} + +// ConfigResolver resolves effective configs and user bindings. Used for outbound sending. +type ConfigResolver interface { ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) - UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) - ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) +} + +// BindingStore resolves user-channel bindings. Used by identity resolution. +type BindingStore interface { ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) - ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]ChannelSession, error) +} + +// SessionStore manages channel session lifecycle. Used by identity resolution. +type SessionStore interface { GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error + ListSessionsByBotPlatform(ctx context.Context, botID string, platform string) ([]ChannelSession, error) +} + +// ConfigStore is the full persistence interface. Components should depend on smaller +// interfaces above; ConfigStore exists as a convenience for wiring. +type ConfigStore interface { + ConfigLister + ConfigResolver + BindingStore + SessionStore + UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) } // Middleware wraps an InboundHandler to add cross-cutting behavior. type Middleware func(next InboundHandler) InboundHandler +// ManagerStore is the minimal persistence interface required by Manager. +type ManagerStore interface { + ConfigLister + ConfigResolver +} + // Manager coordinates channel adapters, connection lifecycle, and message dispatch. +// Connection lifecycle lives in connection.go, inbound dispatch in inbound.go, +// and outbound pipeline in outbound.go. type Manager struct { - service ConfigStore + registry *Registry + service ManagerStore processor InboundProcessor - adapters map[ChannelType]Adapter - senders map[ChannelType]Sender - receivers map[ChannelType]Receiver refreshInterval time.Duration logger *slog.Logger middlewares []Middleware @@ -41,27 +68,22 @@ type Manager struct { inboundOnce sync.Once inboundCtx context.Context inboundCancel context.CancelFunc - adapterMu sync.RWMutex mu sync.Mutex connections map[string]*connectionEntry } -type connectionEntry struct { - config ChannelConfig - connection Connection -} - -// NewManager creates a Manager with the given logger, config store, and inbound processor. -func NewManager(log *slog.Logger, service ConfigStore, processor InboundProcessor) *Manager { +// NewManager creates a Manager with the given logger, registry, config store, and inbound processor. +func NewManager(log *slog.Logger, registry *Registry, service ManagerStore, processor InboundProcessor) *Manager { if log == nil { log = slog.Default() } + if registry == nil { + registry = NewRegistry() + } return &Manager{ + registry: registry, service: service, processor: processor, - adapters: map[ChannelType]Adapter{}, - senders: map[ChannelType]Sender{}, - receivers: map[ChannelType]Receiver{}, refreshInterval: 30 * time.Second, connections: map[string]*connectionEntry{}, logger: log.With(slog.String("component", "channel")), @@ -71,25 +93,27 @@ func NewManager(log *slog.Logger, service ConfigStore, processor InboundProcesso } } +// Registry returns the adapter registry used by this manager. +func (m *Manager) Registry() *Registry { + return m.registry +} + // Use appends middleware to the inbound processing chain. func (m *Manager) Use(mw ...Middleware) { m.middlewares = append(m.middlewares, mw...) } -// RegisterAdapter adds an adapter and indexes its Sender/Receiver capabilities. +// RegisterAdapter adds an adapter to the registry and logs the registration. func (m *Manager) RegisterAdapter(adapter Adapter) { if adapter == nil { return } - m.adapterMu.Lock() - m.adapters[adapter.Type()] = adapter - if sender, ok := adapter.(Sender); ok { - m.senders[adapter.Type()] = sender + if err := m.registry.Register(adapter); err != nil { + if m.logger != nil { + m.logger.Warn("adapter registration failed", slog.String("channel", adapter.Type().String()), slog.Any("error", err)) + } + return } - if receiver, ok := adapter.(Receiver); ok { - m.receivers[adapter.Type()] = receiver - } - m.adapterMu.Unlock() if m.logger != nil { m.logger.Info("adapter registered", slog.String("channel", adapter.Type().String())) } @@ -108,13 +132,9 @@ func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { if ctx == nil { ctx = context.Background() } - normalized := normalizeChannelType(channelType.String()) - if normalized == "" { - return - } m.mu.Lock() for id, entry := range m.connections { - if entry != nil && entry.config.ChannelType == normalized { + if entry != nil && entry.config.ChannelType == channelType { if entry.connection != nil { if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) @@ -124,12 +144,7 @@ func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { } } m.mu.Unlock() - - m.adapterMu.Lock() - delete(m.adapters, normalized) - delete(m.senders, normalized) - delete(m.receivers, normalized) - m.adapterMu.Unlock() + m.registry.Unregister(channelType) } // Start begins the periodic config refresh loop and inbound worker pool. @@ -162,10 +177,8 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp if m.service == nil { return fmt.Errorf("channel manager not configured") } - m.adapterMu.RLock() - sender := m.senders[channelType] - m.adapterMu.RUnlock() - if sender == nil { + sender, ok := m.registry.GetSender(channelType) + if !ok { return fmt.Errorf("unsupported channel type: %s", channelType) } config, err := m.service.ResolveEffectiveConfig(ctx, botID, channelType) @@ -185,12 +198,12 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp } return fmt.Errorf("channel binding required") } - target, err = ResolveTargetFromUserConfig(channelType, userCfg.Config) + target, err = m.registry.ResolveTargetFromUserConfig(channelType, userCfg.Config) if err != nil { return err } } - if normalized, ok := NormalizeTarget(channelType, target); ok { + if normalized, ok := m.registry.NormalizeTarget(channelType, target); ok { target = normalized } if req.Message.IsEmpty() { @@ -218,205 +231,6 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp return nil } -// HandleInbound enqueues an inbound message for asynchronous processing by the worker pool. -func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { - if m.processor == nil { - return fmt.Errorf("inbound processor not configured") - } - if ctx == nil { - ctx = context.Background() - } - m.startInboundWorkers(ctx) - if m.inboundCtx != nil && m.inboundCtx.Err() != nil { - return fmt.Errorf("inbound dispatcher stopped") - } - task := inboundTask{ - ctx: context.WithoutCancel(ctx), - cfg: cfg, - msg: msg, - } - select { - case m.inboundQueue <- task: - return nil - default: - return fmt.Errorf("inbound queue full") - } -} - -func (m *Manager) refresh(ctx context.Context) { - if m.service == nil { - return - } - configs := make([]ChannelConfig, 0) - channelTypes := m.listAdapterTypes() - for _, channelType := range channelTypes { - items, err := m.service.ListConfigsByType(ctx, channelType) - if err != nil { - if m.logger != nil { - m.logger.Error("list configs failed", slog.String("channel", channelType.String()), slog.Any("error", err)) - } - continue - } - configs = append(configs, items...) - } - m.reconcile(ctx, configs) -} - -func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { - active := map[string]ChannelConfig{} - for _, cfg := range configs { - if cfg.ID == "" { - continue - } - status := strings.ToLower(strings.TrimSpace(cfg.Status)) - if status != "" && status != "active" && status != "verified" { - continue - } - active[cfg.ID] = cfg - if err := m.ensureConnection(ctx, cfg); err != nil { - if m.logger != nil { - m.logger.Error("adapter start failed", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID), slog.Any("error", err)) - } - } - } - - m.mu.Lock() - defer m.mu.Unlock() - for id, entry := range m.connections { - if _, ok := active[id]; ok { - continue - } - if entry != nil && entry.connection != nil { - if m.logger != nil { - m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) - } - if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { - m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) - } - } - delete(m.connections, id) - } -} - -func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error { - m.adapterMu.RLock() - receiver := m.receivers[cfg.ChannelType] - m.adapterMu.RUnlock() - if receiver == nil { - return nil - } - - m.mu.Lock() - entry := m.connections[cfg.ID] - if entry != nil && !entry.config.UpdatedAt.Before(cfg.UpdatedAt) { - m.mu.Unlock() - return nil - } - if entry != nil { - m.mu.Unlock() - if m.logger != nil { - m.logger.Info("adapter restart", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) - } - if err := entry.connection.Stop(ctx); err != nil { - if errors.Is(err, ErrStopNotSupported) { - if m.logger != nil { - m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) - } - return nil - } - return err - } - m.mu.Lock() - delete(m.connections, cfg.ID) - m.mu.Unlock() - } else { - m.mu.Unlock() - } - - if m.logger != nil { - m.logger.Info("adapter start", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) - } - handler := m.handleInbound - for i := len(m.middlewares) - 1; i >= 0; i-- { - handler = m.middlewares[i](handler) - } - conn, err := receiver.Connect(ctx, cfg, handler) - if err != nil { - return err - } - m.mu.Lock() - m.connections[cfg.ID] = &connectionEntry{ - config: cfg, - connection: conn, - } - m.mu.Unlock() - return nil -} - -func (m *Manager) stopAll(ctx context.Context) { - m.mu.Lock() - defer m.mu.Unlock() - for id, entry := range m.connections { - if entry != nil && entry.connection != nil { - if m.logger != nil { - m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) - } - if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { - m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) - } - } - delete(m.connections, id) - } -} - -func (m *Manager) handleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { - if m.processor == nil { - return fmt.Errorf("inbound processor not configured") - } - sender := m.newReplySender(cfg, msg.Channel) - if err := m.processor.HandleInbound(ctx, cfg, msg, sender); err != nil { - if m.logger != nil { - m.logger.Error("inbound processing failed", slog.String("channel", msg.Channel.String()), slog.Any("error", err)) - } - return err - } - return nil -} - -// Stop terminates the connection identified by the given config ID. -func (m *Manager) Stop(ctx context.Context, configID string) error { - configID = strings.TrimSpace(configID) - if configID == "" { - return fmt.Errorf("config id is required") - } - m.mu.Lock() - entry := m.connections[configID] - m.mu.Unlock() - if entry == nil || entry.connection == nil { - return nil - } - return entry.connection.Stop(ctx) -} - -// StopByBot terminates all connections belonging to the given bot. -func (m *Manager) StopByBot(ctx context.Context, botID string) error { - botID = strings.TrimSpace(botID) - if botID == "" { - return fmt.Errorf("bot id is required") - } - m.mu.Lock() - defer m.mu.Unlock() - for id, entry := range m.connections { - if entry != nil && entry.config.BotID == botID { - if entry.connection != nil { - _ = entry.connection.Stop(ctx) - } - delete(m.connections, id) - } - } - return nil -} - // Shutdown cancels the inbound worker pool and stops all active connections. func (m *Manager) Shutdown(ctx context.Context) error { if m.inboundCancel != nil { @@ -425,253 +239,3 @@ func (m *Manager) Shutdown(ctx context.Context) error { m.stopAll(ctx) return nil } - -func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) ReplySender { - m.adapterMu.RLock() - sender := m.senders[channelType] - m.adapterMu.RUnlock() - return &managerReplySender{ - manager: m, - sender: sender, - channelType: channelType, - config: cfg, - } -} - -func (m *Manager) listAdapterTypes() []ChannelType { - m.adapterMu.RLock() - defer m.adapterMu.RUnlock() - items := make([]ChannelType, 0, len(m.adapters)) - for channelType := range m.adapters { - items = append(items, channelType) - } - return items -} - -type inboundTask struct { - ctx context.Context - cfg ChannelConfig - msg InboundMessage -} - -func (m *Manager) startInboundWorkers(ctx context.Context) { - m.inboundOnce.Do(func() { - workerCtx := ctx - if workerCtx == nil { - workerCtx = context.Background() - } - m.inboundCtx, m.inboundCancel = context.WithCancel(workerCtx) - for i := 0; i < m.inboundWorkers; i++ { - go m.runInboundWorker(m.inboundCtx) - } - }) -} - -func (m *Manager) runInboundWorker(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case task := <-m.inboundQueue: - if err := m.handleInbound(task.ctx, task.cfg, task.msg); err != nil { - if m.logger != nil { - m.logger.Error("inbound processing failed", slog.String("channel", task.msg.Channel.String()), slog.Any("error", err)) - } - } - } - } -} - -func (m *Manager) resolveOutboundPolicy(channelType ChannelType) OutboundPolicy { - policy, ok := GetChannelOutboundPolicy(channelType) - if !ok { - policy = OutboundPolicy{} - } - return NormalizeOutboundPolicy(policy) -} - -// buildOutboundMessages splits an outbound message into multiple messages based on the policy. -// The caller must pass an already-normalized policy. -func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]OutboundMessage, error) { - if msg.Message.IsEmpty() { - return nil, fmt.Errorf("message is required") - } - normalized := normalizeOutboundMessage(msg.Message) - chunker := policy.Chunker - if normalized.Format == MessageFormatMarkdown { - chunker = ChunkMarkdownText - } - base := normalized - base.Attachments = nil - textMessages := make([]OutboundMessage, 0) - shouldChunk := policy.TextChunkLimit > 0 && strings.TrimSpace(base.Text) != "" && len(base.Parts) == 0 - if shouldChunk { - chunks := chunker(base.Text, policy.TextChunkLimit) - for idx, chunk := range chunks { - chunk = strings.TrimSpace(chunk) - if chunk == "" { - continue - } - actions := base.Actions - if len(chunks) > 1 && idx < len(chunks)-1 { - actions = nil - } - item := OutboundMessage{ - Target: msg.Target, - Message: Message{ - ID: base.ID, - Format: base.Format, - Text: chunk, - Parts: base.Parts, - Attachments: nil, - Actions: actions, - Thread: base.Thread, - Reply: base.Reply, - Metadata: base.Metadata, - }, - } - textMessages = append(textMessages, item) - } - } else if !base.IsEmpty() { - textMessages = append(textMessages, OutboundMessage{Target: msg.Target, Message: base}) - } - - attachments := normalized.Attachments - attachmentMessages := make([]OutboundMessage, 0) - if len(attachments) > 0 { - media := normalized - media.Format = "" - media.Text = "" - media.Parts = nil - media.Actions = nil - media.Attachments = attachments - attachmentMessages = append(attachmentMessages, OutboundMessage{Target: msg.Target, Message: media}) - } - - if len(textMessages) == 0 && len(attachmentMessages) == 0 { - return nil, fmt.Errorf("message is required") - } - if policy.MediaOrder == OutboundOrderTextFirst { - return append(textMessages, attachmentMessages...), nil - } - return append(attachmentMessages, textMessages...), nil -} - -func normalizeOutboundMessage(msg Message) Message { - if msg.Format == "" { - if len(msg.Parts) > 0 { - msg.Format = MessageFormatRich - } else if strings.TrimSpace(msg.Text) != "" { - msg.Format = MessageFormatPlain - } - } - return msg -} - -func validateMessageCapabilities(channelType ChannelType, msg Message) error { - caps, ok := GetChannelCapabilities(channelType) - if !ok { - return nil - } - switch msg.Format { - case MessageFormatPlain: - if !caps.Text { - return fmt.Errorf("channel does not support plain text") - } - case MessageFormatMarkdown: - if !caps.Markdown && !caps.RichText { - return fmt.Errorf("channel does not support markdown") - } - case MessageFormatRich: - if !caps.RichText { - return fmt.Errorf("channel does not support rich text") - } - } - if len(msg.Parts) > 0 && !caps.RichText { - return fmt.Errorf("channel does not support rich text") - } - if len(msg.Attachments) > 0 && !caps.Attachments { - return fmt.Errorf("channel does not support attachments") - } - if len(msg.Attachments) > 0 && requiresMedia(msg.Attachments) && !caps.Media { - return fmt.Errorf("channel does not support media") - } - if len(msg.Actions) > 0 && !caps.Buttons { - return fmt.Errorf("channel does not support actions") - } - if msg.Thread != nil && !caps.Threads { - return fmt.Errorf("channel does not support threads") - } - if msg.Reply != nil && !caps.Reply { - return fmt.Errorf("channel does not support reply") - } - return nil -} - -func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg ChannelConfig, msg OutboundMessage, policy OutboundPolicy) error { - if sender == nil { - return fmt.Errorf("unsupported channel type: %s", cfg.ChannelType) - } - target := strings.TrimSpace(msg.Target) - if target == "" { - return fmt.Errorf("target is required") - } - if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") - } - if err := validateMessageCapabilities(cfg.ChannelType, msg.Message); err != nil { - return err - } - var lastErr error - for i := 0; i < policy.RetryMax; i++ { - err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: msg.Message}) - if err == nil { - return nil - } - lastErr = err - if m.logger != nil { - m.logger.Warn("send outbound retry", - slog.String("channel", cfg.ChannelType.String()), - slog.Int("attempt", i+1), - slog.Any("error", err)) - } - time.Sleep(time.Duration(i+1) * time.Duration(policy.RetryBackoffMs) * time.Millisecond) - } - return fmt.Errorf("send outbound failed after retries: %w", lastErr) -} - -func requiresMedia(attachments []Attachment) bool { - for _, att := range attachments { - switch att.Type { - case AttachmentAudio, AttachmentVideo, AttachmentVoice, AttachmentGIF: - return true - default: - continue - } - } - return false -} - -type managerReplySender struct { - manager *Manager - sender Sender - channelType ChannelType - config ChannelConfig -} - -func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) error { - if s.manager == nil { - return fmt.Errorf("channel manager not configured") - } - policy := s.manager.resolveOutboundPolicy(s.channelType) - outbound, err := buildOutboundMessages(msg, policy) - if err != nil { - return err - } - for _, item := range outbound { - if err := s.manager.sendWithConfig(ctx, s.sender, s.config, item, policy); err != nil { - return err - } - } - return nil -} diff --git a/internal/channel/manager_core_test.go b/internal/channel/manager_core_test.go index 766b1354..9283bcd6 100644 --- a/internal/channel/manager_core_test.go +++ b/internal/channel/manager_core_test.go @@ -13,6 +13,9 @@ type mockAdapter struct { } func (m *mockAdapter) Type() ChannelType { return ChannelType("test") } +func (m *mockAdapter) Descriptor() Descriptor { + return Descriptor{Type: ChannelType("test"), DisplayName: "Test", Capabilities: ChannelCapabilities{Text: true}} +} func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { m.sentMessages = append(m.sentMessages, msg) return nil @@ -53,7 +56,8 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { }, } - m := NewManager(logger, &fakeConfigStore{}, processor) + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) adapter := &mockAdapter{} m.RegisterAdapter(adapter) @@ -87,7 +91,8 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { t.Run("无回复_不发送", func(t *testing.T) { processor := &fakeInboundProcessor{resp: nil} - m := NewManager(logger, &fakeConfigStore{}, processor) + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) adapter := &mockAdapter{} m.RegisterAdapter(adapter) @@ -110,7 +115,8 @@ func TestManager_HandleInbound_CoreLogic(t *testing.T) { t.Run("处理失败_返回错误", func(t *testing.T) { processor := &fakeInboundProcessor{err: context.Canceled} - m := NewManager(logger, &fakeConfigStore{}, processor) + reg := NewRegistry() + m := NewManager(logger, reg, &fakeConfigStore{}, processor) cfg := ChannelConfig{ID: "bot-1"} msg := InboundMessage{Message: Message{Text: " "}} // 空格消息 diff --git a/internal/channel/manager_integration_test.go b/internal/channel/manager_integration_test.go index 0212d3af..82b910c0 100644 --- a/internal/channel/manager_integration_test.go +++ b/internal/channel/manager_integration_test.go @@ -11,31 +11,6 @@ import ( "time" ) -func init() { - _ = RegisterChannel(ChannelDescriptor{ - Type: ChannelType("test"), - DisplayName: "Test", - NormalizeConfig: normalizeEmpty, - NormalizeUserConfig: normalizeEmpty, - ResolveTarget: resolveTestTarget, - Capabilities: ChannelCapabilities{ - Text: true, - }, - }) -} - -func normalizeEmpty(map[string]any) (map[string]any, error) { - return map[string]any{}, nil -} - -func resolveTestTarget(config map[string]any) (string, error) { - value := strings.TrimSpace(ReadString(config, "target")) - if value == "" { - return "", fmt.Errorf("missing target") - } - return "resolved:" + value, nil -} - type fakeConfigStore struct { effectiveConfig ChannelConfig userConfig ChannelUserBinding @@ -122,6 +97,20 @@ func (f *fakeAdapter) Type() ChannelType { return f.channelType } +func (f *fakeAdapter) Descriptor() Descriptor { + return Descriptor{Type: f.channelType, DisplayName: "Fake", Capabilities: ChannelCapabilities{Text: true}} +} + +func (f *fakeAdapter) ResolveTarget(userConfig map[string]any) (string, error) { + value := strings.TrimSpace(ReadString(userConfig, "target")) + if value == "" { + return "", fmt.Errorf("missing target") + } + return "resolved:" + value, nil +} + +func (f *fakeAdapter) NormalizeTarget(raw string) string { return strings.TrimSpace(raw) } + func (f *fakeAdapter) Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) { f.mu.Lock() f.started = append(f.started, cfg) @@ -161,8 +150,9 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { }, }, } + reg := NewRegistry() adapter := &fakeAdapter{channelType: ChannelType("test")} - manager := NewManager(log, store, processor) + manager := NewManager(log, reg, store, processor) manager.RegisterAdapter(adapter) cfg := ChannelConfig{ @@ -217,8 +207,9 @@ func TestManagerSendUsesBinding(t *testing.T) { Config: map[string]any{"target": "alice"}, }, } + reg := NewRegistry() adapter := &fakeAdapter{channelType: ChannelType("test")} - manager := NewManager(log, store, &fakeInboundProcessorIntegration{}) + manager := NewManager(log, reg, store, &fakeInboundProcessorIntegration{}) manager.RegisterAdapter(adapter) err := manager.Send(context.Background(), "bot-1", ChannelType("test"), SendRequest{ @@ -246,8 +237,9 @@ func TestManagerReconcileStartsAndStops(t *testing.T) { log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) store := &fakeConfigStore{} + reg := NewRegistry() adapter := &fakeAdapter{channelType: ChannelType("test")} - manager := NewManager(log, store, &fakeInboundProcessorIntegration{}) + manager := NewManager(log, reg, store, &fakeInboundProcessorIntegration{}) manager.RegisterAdapter(adapter) cfg := ChannelConfig{ diff --git a/internal/channel/manager_test.go b/internal/channel/manager_test.go index a35c45c2..8c9cfb59 100644 --- a/internal/channel/manager_test.go +++ b/internal/channel/manager_test.go @@ -8,9 +8,9 @@ import ( func TestResolveTargetFromUserConfig(t *testing.T) { t.Parallel() - registerTestChannel() + reg := newTestConfigRegistry() - target, err := channel.ResolveTargetFromUserConfig(testChannelType, map[string]any{ + target, err := reg.ResolveTargetFromUserConfig(testChannelType, map[string]any{ "target": "alice", }) if err != nil { @@ -23,8 +23,9 @@ func TestResolveTargetFromUserConfig(t *testing.T) { func TestResolveTargetFromUserConfigUnsupported(t *testing.T) { t.Parallel() + reg := channel.NewRegistry() - _, err := channel.ResolveTargetFromUserConfig("unknown", map[string]any{}) + _, err := reg.ResolveTargetFromUserConfig("unknown", map[string]any{}) if err == nil { t.Fatalf("expected error, got nil") } diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go index b16644dd..a8d1482d 100644 --- a/internal/channel/outbound.go +++ b/internal/channel/outbound.go @@ -1,6 +1,12 @@ package channel -import "strings" +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" +) // ChunkerMode selects the text chunking strategy. type ChunkerMode string @@ -171,3 +177,208 @@ func splitLongLine(line string, limit int) []string { } return chunks } + +// --- Outbound pipeline methods (used by Manager) --- + +func (m *Manager) resolveOutboundPolicy(channelType ChannelType) OutboundPolicy { + policy, ok := m.registry.GetOutboundPolicy(channelType) + if !ok { + policy = OutboundPolicy{} + } + return NormalizeOutboundPolicy(policy) +} + +// buildOutboundMessages splits an outbound message into multiple messages based on the policy. +func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]OutboundMessage, error) { + if msg.Message.IsEmpty() { + return nil, fmt.Errorf("message is required") + } + normalized := normalizeOutboundMessage(msg.Message) + chunker := policy.Chunker + if normalized.Format == MessageFormatMarkdown { + chunker = ChunkMarkdownText + } + base := normalized + base.Attachments = nil + textMessages := make([]OutboundMessage, 0) + shouldChunk := policy.TextChunkLimit > 0 && strings.TrimSpace(base.Text) != "" && len(base.Parts) == 0 + if shouldChunk { + chunks := chunker(base.Text, policy.TextChunkLimit) + for idx, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + actions := base.Actions + if len(chunks) > 1 && idx < len(chunks)-1 { + actions = nil + } + item := OutboundMessage{ + Target: msg.Target, + Message: Message{ + ID: base.ID, + Format: base.Format, + Text: chunk, + Parts: base.Parts, + Attachments: nil, + Actions: actions, + Thread: base.Thread, + Reply: base.Reply, + Metadata: base.Metadata, + }, + } + textMessages = append(textMessages, item) + } + } else if !base.IsEmpty() { + textMessages = append(textMessages, OutboundMessage{Target: msg.Target, Message: base}) + } + + attachments := normalized.Attachments + attachmentMessages := make([]OutboundMessage, 0) + if len(attachments) > 0 { + media := normalized + media.Format = "" + media.Text = "" + media.Parts = nil + media.Actions = nil + media.Attachments = attachments + attachmentMessages = append(attachmentMessages, OutboundMessage{Target: msg.Target, Message: media}) + } + + if len(textMessages) == 0 && len(attachmentMessages) == 0 { + return nil, fmt.Errorf("message is required") + } + if policy.MediaOrder == OutboundOrderTextFirst { + return append(textMessages, attachmentMessages...), nil + } + return append(attachmentMessages, textMessages...), nil +} + +func normalizeOutboundMessage(msg Message) Message { + if msg.Format == "" { + if len(msg.Parts) > 0 { + msg.Format = MessageFormatRich + } else if strings.TrimSpace(msg.Text) != "" { + msg.Format = MessageFormatPlain + } + } + return msg +} + +func validateMessageCapabilities(registry *Registry, channelType ChannelType, msg Message) error { + caps, ok := registry.GetCapabilities(channelType) + if !ok { + return nil + } + switch msg.Format { + case MessageFormatPlain: + if !caps.Text { + return fmt.Errorf("channel does not support plain text") + } + case MessageFormatMarkdown: + if !caps.Markdown && !caps.RichText { + return fmt.Errorf("channel does not support markdown") + } + case MessageFormatRich: + if !caps.RichText { + return fmt.Errorf("channel does not support rich text") + } + } + if len(msg.Parts) > 0 && !caps.RichText { + return fmt.Errorf("channel does not support rich text") + } + if len(msg.Attachments) > 0 && !caps.Attachments { + return fmt.Errorf("channel does not support attachments") + } + if len(msg.Attachments) > 0 && requiresMedia(msg.Attachments) && !caps.Media { + return fmt.Errorf("channel does not support media") + } + if len(msg.Actions) > 0 && !caps.Buttons { + return fmt.Errorf("channel does not support actions") + } + if msg.Thread != nil && !caps.Threads { + return fmt.Errorf("channel does not support threads") + } + if msg.Reply != nil && !caps.Reply { + return fmt.Errorf("channel does not support reply") + } + return nil +} + +func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg ChannelConfig, msg OutboundMessage, policy OutboundPolicy) error { + if sender == nil { + return fmt.Errorf("unsupported channel type: %s", cfg.ChannelType) + } + target := strings.TrimSpace(msg.Target) + if target == "" { + return fmt.Errorf("target is required") + } + if msg.Message.IsEmpty() { + return fmt.Errorf("message is required") + } + if err := validateMessageCapabilities(m.registry, cfg.ChannelType, msg.Message); err != nil { + return err + } + var lastErr error + for i := 0; i < policy.RetryMax; i++ { + err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: msg.Message}) + if err == nil { + return nil + } + lastErr = err + if m.logger != nil { + m.logger.Warn("send outbound retry", + slog.String("channel", cfg.ChannelType.String()), + slog.Int("attempt", i+1), + slog.Any("error", err)) + } + time.Sleep(time.Duration(i+1) * time.Duration(policy.RetryBackoffMs) * time.Millisecond) + } + return fmt.Errorf("send outbound failed after retries: %w", lastErr) +} + +func requiresMedia(attachments []Attachment) bool { + for _, att := range attachments { + switch att.Type { + case AttachmentAudio, AttachmentVideo, AttachmentVoice, AttachmentGIF: + return true + default: + continue + } + } + return false +} + +func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) ReplySender { + sender, _ := m.registry.GetSender(channelType) + return &managerReplySender{ + manager: m, + sender: sender, + channelType: channelType, + config: cfg, + } +} + +type managerReplySender struct { + manager *Manager + sender Sender + channelType ChannelType + config ChannelConfig +} + +func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) error { + if s.manager == nil { + return fmt.Errorf("channel manager not configured") + } + policy := s.manager.resolveOutboundPolicy(s.channelType) + outbound, err := buildOutboundMessages(msg, policy) + if err != nil { + return err + } + for _, item := range outbound { + if err := s.manager.sendWithConfig(ctx, s.sender, s.config, item, policy); err != nil { + return err + } + } + return nil +} diff --git a/internal/channel/registry.go b/internal/channel/registry.go index f70ed507..b9a629ce 100644 --- a/internal/channel/registry.go +++ b/internal/channel/registry.go @@ -6,124 +6,158 @@ import ( "sync" ) -// ChannelDescriptor holds all metadata and hooks for a registered channel type. -type ChannelDescriptor struct { - Type ChannelType - DisplayName string - NormalizeConfig func(map[string]any) (map[string]any, error) - NormalizeUserConfig func(map[string]any) (map[string]any, error) - ResolveTarget func(map[string]any) (string, error) - MatchBinding func(map[string]any, BindingCriteria) bool - BuildUserConfig func(Identity) map[string]any - Configless bool - Capabilities ChannelCapabilities - OutboundPolicy OutboundPolicy - ConfigSchema ConfigSchema - UserConfigSchema ConfigSchema - TargetSpec TargetSpec - NormalizeTarget func(string) string +// Registry holds all registered channel adapters and provides dispatch methods +// for configuration normalization, target resolution, and binding operations. +// It replaces the former global registry, and must be created via NewRegistry +// and passed explicitly to components that need it. +type Registry struct { + mu sync.RWMutex + adapters map[ChannelType]Adapter } -type channelRegistry struct { - mu sync.RWMutex - items map[ChannelType]ChannelDescriptor +// NewRegistry creates an empty Registry. +func NewRegistry() *Registry { + return &Registry{ + adapters: map[ChannelType]Adapter{}, + } } -var registry = &channelRegistry{ - items: map[ChannelType]ChannelDescriptor{}, -} - -// RegisterChannel adds a channel descriptor to the global registry. -func RegisterChannel(desc ChannelDescriptor) error { - normalized := normalizeChannelType(string(desc.Type)) - if normalized == "" { +// Register adds an adapter to the registry. +func (r *Registry) Register(adapter Adapter) error { + if adapter == nil { + return fmt.Errorf("adapter is nil") + } + ct := normalizeChannelType(adapter.Type().String()) + if ct == "" { return fmt.Errorf("channel type is required") } - desc.Type = normalized - if strings.TrimSpace(desc.DisplayName) == "" { - desc.DisplayName = normalized.String() + r.mu.Lock() + defer r.mu.Unlock() + if _, exists := r.adapters[ct]; exists { + return fmt.Errorf("channel type already registered: %s", ct) } - registry.mu.Lock() - defer registry.mu.Unlock() - if _, exists := registry.items[desc.Type]; exists { - return fmt.Errorf("channel type already registered: %s", desc.Type) - } - registry.items[desc.Type] = desc + r.adapters[ct] = adapter return nil } -// MustRegisterChannel calls RegisterChannel and panics on error. -func MustRegisterChannel(desc ChannelDescriptor) { - if err := RegisterChannel(desc); err != nil { +// MustRegister calls Register and panics on error. +func (r *Registry) MustRegister(adapter Adapter) { + if err := r.Register(adapter); err != nil { panic(err) } } -// UnregisterChannel removes a channel type from the global registry. -func UnregisterChannel(channelType ChannelType) bool { - normalized := normalizeChannelType(channelType.String()) - if normalized == "" { +// Unregister removes a channel type from the registry. +func (r *Registry) Unregister(channelType ChannelType) bool { + ct := normalizeChannelType(channelType.String()) + if ct == "" { return false } - registry.mu.Lock() - defer registry.mu.Unlock() - if _, exists := registry.items[normalized]; !exists { + r.mu.Lock() + defer r.mu.Unlock() + if _, exists := r.adapters[ct]; !exists { return false } - delete(registry.items, normalized) + delete(r.adapters, ct) return true } -// GetChannelDescriptor returns the descriptor for the given channel type. -func GetChannelDescriptor(channelType ChannelType) (ChannelDescriptor, bool) { - normalized := normalizeChannelType(channelType.String()) - registry.mu.RLock() - defer registry.mu.RUnlock() - desc, ok := registry.items[normalized] - return desc, ok +// Get returns the adapter for the given channel type. +func (r *Registry) Get(channelType ChannelType) (Adapter, bool) { + ct := normalizeChannelType(channelType.String()) + r.mu.RLock() + defer r.mu.RUnlock() + adapter, ok := r.adapters[ct] + return adapter, ok } -// ListChannelDescriptors returns all registered channel descriptors. -func ListChannelDescriptors() []ChannelDescriptor { - registry.mu.RLock() - defer registry.mu.RUnlock() - items := make([]ChannelDescriptor, 0, len(registry.items)) - for _, item := range registry.items { - items = append(items, item) +// List returns all registered adapters. +func (r *Registry) List() []Adapter { + r.mu.RLock() + defer r.mu.RUnlock() + items := make([]Adapter, 0, len(r.adapters)) + for _, a := range r.adapters { + items = append(items, a) } return items } -// GetChannelCapabilities returns the capability matrix for the given channel type. -func GetChannelCapabilities(channelType ChannelType) (ChannelCapabilities, bool) { - desc, ok := GetChannelDescriptor(channelType) +// Types returns all registered channel types. +func (r *Registry) Types() []ChannelType { + r.mu.RLock() + defer r.mu.RUnlock() + items := make([]ChannelType, 0, len(r.adapters)) + for ct := range r.adapters { + items = append(items, ct) + } + return items +} + +// --- Descriptor accessors --- + +// GetDescriptor returns the descriptor for the given channel type. +func (r *Registry) GetDescriptor(channelType ChannelType) (Descriptor, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return Descriptor{}, false + } + return adapter.Descriptor(), true +} + +// ListDescriptors returns descriptors for all registered channel types. +func (r *Registry) ListDescriptors() []Descriptor { + adapters := r.List() + items := make([]Descriptor, 0, len(adapters)) + for _, a := range adapters { + items = append(items, a.Descriptor()) + } + return items +} + +// ParseChannelType validates and normalizes a raw string into a registered ChannelType. +func (r *Registry) ParseChannelType(raw string) (ChannelType, error) { + ct := normalizeChannelType(raw) + if ct == "" { + return "", fmt.Errorf("unsupported channel type: %s", raw) + } + if _, ok := r.Get(ct); !ok { + return "", fmt.Errorf("unsupported channel type: %s", raw) + } + return ct, nil +} + +// --- Capability accessors --- + +// GetCapabilities returns the capability matrix for the given channel type. +func (r *Registry) GetCapabilities(channelType ChannelType) (ChannelCapabilities, bool) { + desc, ok := r.GetDescriptor(channelType) if !ok { return ChannelCapabilities{}, false } return desc.Capabilities, true } -// GetChannelOutboundPolicy returns the outbound policy for the given channel type. -func GetChannelOutboundPolicy(channelType ChannelType) (OutboundPolicy, bool) { - desc, ok := GetChannelDescriptor(channelType) +// GetOutboundPolicy returns the outbound policy for the given channel type. +func (r *Registry) GetOutboundPolicy(channelType ChannelType) (OutboundPolicy, bool) { + desc, ok := r.GetDescriptor(channelType) if !ok { return OutboundPolicy{}, false } return desc.OutboundPolicy, true } -// GetChannelConfigSchema returns the configuration schema for the given channel type. -func GetChannelConfigSchema(channelType ChannelType) (ConfigSchema, bool) { - desc, ok := GetChannelDescriptor(channelType) +// GetConfigSchema returns the configuration schema for the given channel type. +func (r *Registry) GetConfigSchema(channelType ChannelType) (ConfigSchema, bool) { + desc, ok := r.GetDescriptor(channelType) if !ok { return ConfigSchema{}, false } return desc.ConfigSchema, true } -// GetChannelUserConfigSchema returns the user-binding configuration schema for the given channel type. -func GetChannelUserConfigSchema(channelType ChannelType) (ConfigSchema, bool) { - desc, ok := GetChannelDescriptor(channelType) +// GetUserConfigSchema returns the user-binding configuration schema. +func (r *Registry) GetUserConfigSchema(channelType ChannelType) (ConfigSchema, bool) { + desc, ok := r.GetDescriptor(channelType) if !ok { return ConfigSchema{}, false } @@ -131,14 +165,120 @@ func GetChannelUserConfigSchema(channelType ChannelType) (ConfigSchema, bool) { } // IsConfigless reports whether the channel type operates without per-bot configuration. -func IsConfigless(channelType ChannelType) bool { - desc, ok := GetChannelDescriptor(channelType) +func (r *Registry) IsConfigless(channelType ChannelType) bool { + desc, ok := r.GetDescriptor(channelType) if !ok { return false } return desc.Configless } +// --- Sender / Receiver accessors --- + +// GetSender returns the Sender for the given channel type, or nil if unsupported. +func (r *Registry) GetSender(channelType ChannelType) (Sender, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + sender, ok := adapter.(Sender) + return sender, ok +} + +// GetReceiver returns the Receiver for the given channel type, or nil if unsupported. +func (r *Registry) GetReceiver(channelType ChannelType) (Receiver, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return nil, false + } + receiver, ok := adapter.(Receiver) + return receiver, ok +} + +// --- Dispatch methods (replace former global functions in config.go / target.go) --- + +// NormalizeConfig validates and normalizes a channel configuration map. +func (r *Registry) NormalizeConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { + if raw == nil { + raw = map[string]any{} + } + adapter, ok := r.Get(channelType) + if !ok { + return nil, fmt.Errorf("unsupported channel type: %s", channelType) + } + if normalizer, ok := adapter.(ConfigNormalizer); ok { + return normalizer.NormalizeConfig(raw) + } + return raw, nil +} + +// NormalizeUserConfig validates and normalizes a user-channel binding configuration. +func (r *Registry) NormalizeUserConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { + if raw == nil { + raw = map[string]any{} + } + adapter, ok := r.Get(channelType) + if !ok { + return nil, fmt.Errorf("unsupported channel type: %s", channelType) + } + if normalizer, ok := adapter.(ConfigNormalizer); ok { + return normalizer.NormalizeUserConfig(raw) + } + return raw, nil +} + +// ResolveTargetFromUserConfig derives a delivery target from a user-channel binding. +func (r *Registry) ResolveTargetFromUserConfig(channelType ChannelType, config map[string]any) (string, error) { + adapter, ok := r.Get(channelType) + if !ok { + return "", fmt.Errorf("unsupported channel type: %s", channelType) + } + if resolver, ok := adapter.(TargetResolver); ok { + return resolver.ResolveTarget(config) + } + return "", fmt.Errorf("channel type %s does not support target resolution", channelType) +} + +// NormalizeTarget applies the channel-specific target normalization. +func (r *Registry) NormalizeTarget(channelType ChannelType, raw string) (string, bool) { + adapter, ok := r.Get(channelType) + if !ok { + return strings.TrimSpace(raw), false + } + if resolver, ok := adapter.(TargetResolver); ok { + normalized := strings.TrimSpace(resolver.NormalizeTarget(raw)) + if normalized == "" { + return "", false + } + return normalized, true + } + return strings.TrimSpace(raw), false +} + +// MatchUserBinding reports whether the given binding config matches the criteria. +func (r *Registry) MatchUserBinding(channelType ChannelType, config map[string]any, criteria BindingCriteria) bool { + adapter, ok := r.Get(channelType) + if !ok { + return false + } + if matcher, ok := adapter.(BindingMatcher); ok { + return matcher.MatchBinding(config, criteria) + } + return false +} + +// BuildUserBindingConfig constructs a user-channel binding config from an Identity. +func (r *Registry) BuildUserBindingConfig(channelType ChannelType, identity Identity) map[string]any { + adapter, ok := r.Get(channelType) + if !ok { + return map[string]any{} + } + if matcher, ok := adapter.(BindingMatcher); ok { + return matcher.BuildUserConfig(identity) + } + return map[string]any{} +} + func normalizeChannelType(raw string) ChannelType { normalized := strings.TrimSpace(strings.ToLower(raw)) if normalized == "" { diff --git a/internal/channel/service.go b/internal/channel/service.go index 5a79ae90..d11437a2 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -8,21 +8,25 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) // Service provides CRUD operations for channel configurations, user bindings, and sessions. type Service struct { - queries *sqlc.Queries + queries *sqlc.Queries + registry *Registry } -// NewService creates a Service backed by the given database queries. -func NewService(queries *sqlc.Queries) *Service { - return &Service{queries: queries} +// NewService creates a Service backed by the given database queries and adapter registry. +func NewService(queries *sqlc.Queries, registry *Registry) *Service { + if registry == nil { + registry = NewRegistry() + } + return &Service{queries: queries, registry: registry} } // UpsertConfig creates or updates a bot's channel configuration. @@ -33,7 +37,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch if channelType == "" { return ChannelConfig{}, fmt.Errorf("channel type is required") } - normalized, err := NormalizeChannelConfig(channelType, req.Credentials) + normalized, err := s.registry.NormalizeConfig(channelType, req.Credentials) if err != nil { return ChannelConfig{}, err } @@ -41,7 +45,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch if err != nil { return ChannelConfig{}, err } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return ChannelConfig{}, err } @@ -98,7 +102,7 @@ func (s *Service) UpsertUserConfig(ctx context.Context, actorUserID string, chan if channelType == "" { return ChannelUserBinding{}, fmt.Errorf("channel type is required") } - normalized, err := NormalizeChannelUserConfig(channelType, req.Config) + normalized, err := s.registry.NormalizeUserConfig(channelType, req.Config) if err != nil { return ChannelUserBinding{}, err } @@ -106,7 +110,7 @@ func (s *Service) UpsertUserConfig(ctx context.Context, actorUserID string, chan if err != nil { return ChannelUserBinding{}, err } - pgUserID, err := parseUUID(actorUserID) + pgUserID, err := db.ParseUUID(actorUserID) if err != nil { return ChannelUserBinding{}, err } @@ -130,14 +134,14 @@ func (s *Service) ResolveEffectiveConfig(ctx context.Context, botID string, chan if channelType == "" { return ChannelConfig{}, fmt.Errorf("channel type is required") } - if IsConfigless(channelType) { + if s.registry.IsConfigless(channelType) { return ChannelConfig{ ID: channelType.String() + ":" + strings.TrimSpace(botID), BotID: strings.TrimSpace(botID), ChannelType: channelType, }, nil } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return ChannelConfig{}, err } @@ -159,7 +163,7 @@ func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType if s.queries == nil { return nil, fmt.Errorf("channel queries not configured") } - if IsConfigless(channelType) { + if s.registry.IsConfigless(channelType) { return []ChannelConfig{}, nil } rows, err := s.queries.ListBotChannelConfigsByType(ctx, channelType.String()) @@ -185,7 +189,7 @@ func (s *Service) GetUserConfig(ctx context.Context, actorUserID string, channel if channelType == "" { return ChannelUserBinding{}, fmt.Errorf("channel type is required") } - pgUserID, err := parseUUID(actorUserID) + pgUserID, err := db.ParseUUID(actorUserID) if err != nil { return ChannelUserBinding{}, err } @@ -204,12 +208,12 @@ func (s *Service) GetUserConfig(ctx context.Context, actorUserID string, channel return ChannelUserBinding{}, err } return ChannelUserBinding{ - ID: toUUIDString(row.ID), + ID: db.UUIDToString(row.ID), ChannelType: ChannelType(row.ChannelType), - UserID: toUUIDString(row.UserID), + UserID: db.UUIDToString(row.UserID), Config: config, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } @@ -261,7 +265,7 @@ func (s *Service) ListSessionsByBotPlatform(ctx context.Context, botID, platform if platform == "" { return nil, fmt.Errorf("platform is required") } - pgBotID, err := parseUUID(botID) + pgBotID, err := db.ParseUUID(botID) if err != nil { return nil, err } @@ -290,26 +294,26 @@ func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, bo } pgUserID := pgtype.UUID{Valid: false} if strings.TrimSpace(userID) != "" { - parsed, err := parseUUID(userID) + parsed, err := db.ParseUUID(userID) if err != nil { return err } pgUserID = parsed } - botUUID, err := parseUUID(botID) + botUUID, err := db.ParseUUID(botID) if err != nil { return err } var channelUUID pgtype.UUID if strings.TrimSpace(channelConfigID) != "" { - channelUUID, err = parseUUID(channelConfigID) + channelUUID, err = db.ParseUUID(channelConfigID) if err != nil { return err } } pgContactID := pgtype.UUID{Valid: false} if strings.TrimSpace(contactID) != "" { - parsed, err := parseUUID(contactID) + parsed, err := db.ParseUUID(contactID) if err != nil { return err } @@ -349,11 +353,11 @@ func (s *Service) ResolveUserBinding(ctx context.Context, channelType ChannelTyp if err != nil { return "", err } - if _, ok := GetChannelDescriptor(channelType); !ok { + if _, ok := s.registry.Get(channelType); !ok { return "", fmt.Errorf("unsupported channel type: %s", channelType) } for _, row := range rows { - if MatchUserBinding(channelType, row.Config, criteria) { + if s.registry.MatchUserBinding(channelType, row.Config, criteria) { return row.UserID, nil } } @@ -382,8 +386,8 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { externalIdentity = strings.TrimSpace(row.ExternalIdentity.String) } return ChannelConfig{ - ID: toUUIDString(row.ID), - BotID: toUUIDString(row.BotID), + ID: db.UUIDToString(row.ID), + BotID: db.UUIDToString(row.BotID), ChannelType: ChannelType(row.ChannelType), Credentials: credentials, ExternalIdentity: externalIdentity, @@ -391,8 +395,8 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { Routing: routing, Status: strings.TrimSpace(row.Status), VerifiedAt: verifiedAt, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } @@ -402,12 +406,12 @@ func normalizeChannelUserBindingRow(row sqlc.UserChannelBinding) (ChannelUserBin return ChannelUserBinding{}, err } return ChannelUserBinding{ - ID: toUUIDString(row.ID), + ID: db.UUIDToString(row.ID), ChannelType: ChannelType(row.ChannelType), - UserID: toUUIDString(row.UserID), + UserID: db.UUIDToString(row.UserID), Config: config, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } @@ -418,49 +422,18 @@ func normalizeChannelSession(row sqlc.ChannelSession) (ChannelSession, error) { } return ChannelSession{ SessionID: row.SessionID, - BotID: toUUIDString(row.BotID), - ChannelConfigID: toUUIDString(row.ChannelConfigID), - UserID: toUUIDString(row.UserID), - ContactID: toUUIDString(row.ContactID), + BotID: db.UUIDToString(row.BotID), + ChannelConfigID: db.UUIDToString(row.ChannelConfigID), + UserID: db.UUIDToString(row.UserID), + ContactID: db.UUIDToString(row.ContactID), Platform: row.Platform, ReplyTarget: strings.TrimSpace(row.ReplyTarget.String), ThreadID: strings.TrimSpace(row.ThreadID.String), Metadata: metadata, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), + CreatedAt: db.TimeFromPg(row.CreatedAt), + UpdatedAt: db.TimeFromPg(row.UpdatedAt), }, nil } -func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := uuid.Parse(strings.TrimSpace(id)) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - var pgID pgtype.UUID - pgID.Valid = true - copy(pgID.Bytes[:], parsed[:]) - return pgID, nil -} -func toUUIDString(value pgtype.UUID) string { - if !value.Valid { - return "" - } - parsed, err := uuid.FromBytes(value.Bytes[:]) - if err != nil { - return "" - } - return parsed.String() -} -func timeFromPg(value pgtype.Timestamptz) time.Time { - if value.Valid { - return value.Time - } - return time.Time{} -} - -// String returns the channel type as a plain string. -func (c ChannelType) String() string { - return string(c) -} diff --git a/internal/channel/target.go b/internal/channel/target.go index 28bf465b..93469912 100644 --- a/internal/channel/target.go +++ b/internal/channel/target.go @@ -1,7 +1,5 @@ package channel -import "strings" - // TargetHint provides a display label and example for a target format. type TargetHint struct { Example string `json:"example,omitempty"` @@ -13,17 +11,3 @@ type TargetSpec struct { Format string `json:"format"` Hints []TargetHint `json:"hints,omitempty"` } - -// NormalizeTarget applies the channel-specific target normalization function. -// It returns the normalized string and true if a normalizer was found, otherwise the trimmed input and false. -func NormalizeTarget(channelType ChannelType, raw string) (string, bool) { - desc, ok := GetChannelDescriptor(channelType) - if !ok || desc.NormalizeTarget == nil { - return strings.TrimSpace(raw), false - } - normalized := strings.TrimSpace(desc.NormalizeTarget(raw)) - if normalized == "" { - return "", false - } - return normalized, true -} diff --git a/internal/channel/types.go b/internal/channel/types.go index 79ab3b53..0bb2147e 100644 --- a/internal/channel/types.go +++ b/internal/channel/types.go @@ -3,7 +3,6 @@ package channel import ( - "fmt" "strings" "time" ) @@ -11,16 +10,9 @@ import ( // ChannelType identifies a messaging platform (e.g., "telegram", "feishu"). type ChannelType string -// ParseChannelType validates and normalizes a raw string into a registered ChannelType. -func ParseChannelType(raw string) (ChannelType, error) { - normalized := normalizeChannelType(raw) - if normalized == "" { - return "", fmt.Errorf("unsupported channel type: %s", raw) - } - if _, ok := GetChannelDescriptor(normalized); !ok { - return "", fmt.Errorf("unsupported channel type: %s", raw) - } - return normalized, nil +// String returns the channel type as a plain string. +func (c ChannelType) String() string { + return string(c) } // Identity represents a sender's identity on a channel. diff --git a/internal/db/uuid.go b/internal/db/uuid.go new file mode 100644 index 00000000..a0b3dfd3 --- /dev/null +++ b/internal/db/uuid.go @@ -0,0 +1,42 @@ +package db + +import ( + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +// ParseUUID converts a string UUID to pgtype.UUID. +func ParseUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(strings.TrimSpace(id)) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) + } + var pgID pgtype.UUID + pgID.Valid = true + copy(pgID.Bytes[:], parsed[:]) + return pgID, nil +} + +// UUIDToString converts a pgtype.UUID to its string representation. +func UUIDToString(value pgtype.UUID) string { + if !value.Valid { + return "" + } + parsed, err := uuid.FromBytes(value.Bytes[:]) + if err != nil { + return "" + } + return parsed.String() +} + +// TimeFromPg converts a pgtype.Timestamptz to time.Time. +func TimeFromPg(value pgtype.Timestamptz) time.Time { + if value.Valid { + return value.Time + } + return time.Time{} +} diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index dc96a4ea..3dd64f00 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -13,11 +13,12 @@ import ( ) type ChannelHandler struct { - service *channel.Service + service *channel.Service + registry *channel.Registry } -func NewChannelHandler(service *channel.Service) *ChannelHandler { - return &ChannelHandler{service: service} +func NewChannelHandler(service *channel.Service, registry *channel.Registry) *ChannelHandler { + return &ChannelHandler{service: service, registry: registry} } func (h *ChannelHandler) Register(e *echo.Echo) { @@ -45,7 +46,7 @@ func (h *ChannelHandler) GetUserConfig(c echo.Context) error { if err != nil { return err } - channelType, err := channel.ParseChannelType(c.Param("platform")) + channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -74,7 +75,7 @@ func (h *ChannelHandler) UpsertUserConfig(c echo.Context) error { if err != nil { return err } - channelType, err := channel.ParseChannelType(c.Param("platform")) + channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -110,7 +111,7 @@ type ChannelMeta struct { // @Failure 500 {object} ErrorResponse // @Router /channels [get] func (h *ChannelHandler) ListChannels(c echo.Context) error { - descs := channel.ListChannelDescriptors() + descs := h.registry.ListDescriptors() items := make([]ChannelMeta, 0, len(descs)) for _, desc := range descs { items = append(items, ChannelMeta{ @@ -139,11 +140,11 @@ func (h *ChannelHandler) ListChannels(c echo.Context) error { // @Failure 404 {object} ErrorResponse // @Router /channels/{platform} [get] func (h *ChannelHandler) GetChannel(c echo.Context) error { - channelType, err := channel.ParseChannelType(c.Param("platform")) + channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - desc, ok := channel.GetChannelDescriptor(channelType) + desc, ok := h.registry.GetDescriptor(channelType) if !ok { return echo.NewHTTPError(http.StatusNotFound, "channel not found") } diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index 204761be..9100a014 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -16,6 +16,7 @@ import ( "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/identity" "github.com/memohai/memoh/internal/users" ) @@ -24,12 +25,12 @@ type LocalChannelHandler struct { channelType channel.ChannelType channelManager *channel.Manager channelService *channel.Service - sessionHub *channel.SessionHub + sessionHub *local.SessionHub botService *bots.Service userService *users.Service } -func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, sessionHub *channel.SessionHub, botService *bots.Service, userService *users.Service) *LocalChannelHandler { +func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, sessionHub *local.SessionHub, botService *bots.Service, userService *users.Service) *LocalChannelHandler { return &LocalChannelHandler{ channelType: channelType, channelManager: channelManager, diff --git a/internal/handlers/users.go b/internal/handlers/users.go index c007b1ca..f07d69e8 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -22,10 +22,11 @@ type UsersHandler struct { botService *bots.Service channelService *channel.Service channelManager *channel.Manager + registry *channel.Registry logger *slog.Logger } -func NewUsersHandler(log *slog.Logger, service *users.Service, botService *bots.Service, channelService *channel.Service, channelManager *channel.Manager) *UsersHandler { +func NewUsersHandler(log *slog.Logger, service *users.Service, botService *bots.Service, channelService *channel.Service, channelManager *channel.Manager, registry *channel.Registry) *UsersHandler { if log == nil { log = slog.Default() } @@ -34,6 +35,7 @@ func NewUsersHandler(log *slog.Logger, service *users.Service, botService *bots. botService: botService, channelService: channelService, channelManager: channelManager, + registry: registry, logger: log.With(slog.String("handler", "users")), } } @@ -667,7 +669,7 @@ func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { return err } - channelType, err := channel.ParseChannelType(c.Param("platform")) + channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -706,7 +708,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { if _, err := h.authorizeBotAccess(c.Request().Context(), actorID, botID); err != nil { return err } - channelType, err := channel.ParseChannelType(c.Param("platform")) + channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -752,7 +754,7 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { if h.channelManager == nil { return echo.NewHTTPError(http.StatusInternalServerError, "channel manager not configured") } - channelType, err := channel.ParseChannelType(c.Param("platform")) + channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -791,7 +793,7 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - channelType, err := channel.ParseChannelType(c.Param("platform")) + channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } diff --git a/internal/router/channel.go b/internal/router/channel.go index 74c97127..c652c05c 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -42,22 +42,24 @@ var ( // ChannelInboundProcessor 将 channel 入站消息路由到 chat,并返回可发送的回复。 type ChannelInboundProcessor struct { chat ChatGateway + registry *channel.Registry logger *slog.Logger jwtSecret string tokenTTL time.Duration identity *IdentityResolver } -func NewChannelInboundProcessor(log *slog.Logger, store channel.ConfigStore, chatGateway ChatGateway, contactService ContactService, policyService PolicyService, preauthService PreauthService, jwtSecret string, tokenTTL time.Duration) *ChannelInboundProcessor { +func NewChannelInboundProcessor(log *slog.Logger, registry *channel.Registry, store channel.ConfigStore, chatGateway ChatGateway, contactService ContactService, policyService PolicyService, preauthService PreauthService, jwtSecret string, tokenTTL time.Duration) *ChannelInboundProcessor { if log == nil { log = slog.Default() } if tokenTTL <= 0 { tokenTTL = 5 * time.Minute } - identityResolver := NewIdentityResolver(log, store, contactService, policyService, preauthService, "", "") + identityResolver := NewIdentityResolver(log, registry, store, contactService, policyService, preauthService, "", "") return &ChannelInboundProcessor{ chat: chatGateway, + registry: registry, logger: log.With(slog.String("component", "channel_router")), jwtSecret: strings.TrimSpace(jwtSecret), tokenTTL: tokenTTL, @@ -128,7 +130,10 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel token = "Bearer " + signed } } - desc, _ := channel.GetChannelDescriptor(msg.Channel) + var desc channel.Descriptor + if p.registry != nil { + desc, _ = p.registry.GetDescriptor(msg.Channel) + } resp, err := p.chat.Chat(ctx, chat.ChatRequest{ BotID: identity.BotID, SessionID: identity.SessionID, @@ -157,7 +162,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel if target == "" { return fmt.Errorf("reply target missing") } - sentTexts, suppressReplies := collectMessageToolContext(resp.Messages, msg.Channel, target) + sentTexts, suppressReplies := collectMessageToolContext(p.registry, resp.Messages, msg.Channel, target) if suppressReplies { return nil } @@ -352,7 +357,7 @@ type sendMessageToolArgs struct { Message *channel.Message `json:"message"` } -func collectMessageToolContext(messages []chat.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { +func collectMessageToolContext(registry *channel.Registry, messages []chat.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { if len(messages) == 0 { return nil, false } @@ -370,7 +375,7 @@ func collectMessageToolContext(messages []chat.ModelMessage, channelType channel if text := strings.TrimSpace(extractSendMessageText(args)); text != "" { sentTexts = append(sentTexts, text) } - if shouldSuppressForToolCall(args, channelType, replyTarget) { + if shouldSuppressForToolCall(registry, args, channelType, replyTarget) { suppressReplies = true } } @@ -405,7 +410,7 @@ func extractSendMessageText(args sendMessageToolArgs) string { return strings.TrimSpace(args.Message.PlainText()) } -func shouldSuppressForToolCall(args sendMessageToolArgs, channelType channel.ChannelType, replyTarget string) bool { +func shouldSuppressForToolCall(registry *channel.Registry, args sendMessageToolArgs, channelType channel.ChannelType, replyTarget string) bool { platform := strings.TrimSpace(args.Platform) if platform == "" { platform = string(channelType) @@ -420,16 +425,19 @@ func shouldSuppressForToolCall(args sendMessageToolArgs, channelType channel.Cha if strings.TrimSpace(target) == "" || strings.TrimSpace(replyTarget) == "" { return false } - normalizedTarget := normalizeReplyTarget(channelType, target) - normalizedReply := normalizeReplyTarget(channelType, replyTarget) + normalizedTarget := normalizeReplyTarget(registry, channelType, target) + normalizedReply := normalizeReplyTarget(registry, channelType, replyTarget) if normalizedTarget == "" || normalizedReply == "" { return false } return normalizedTarget == normalizedReply } -func normalizeReplyTarget(channelType channel.ChannelType, target string) string { - normalized, ok := channel.NormalizeTarget(channelType, target) +func normalizeReplyTarget(registry *channel.Registry, channelType channel.ChannelType, target string) string { + if registry == nil { + return strings.TrimSpace(target) + } + normalized, ok := registry.NormalizeTarget(channelType, target) if ok && strings.TrimSpace(normalized) != "" { return strings.TrimSpace(normalized) } diff --git a/internal/router/channel_test.go b/internal/router/channel_test.go index 28c35631..1dc061d6 100644 --- a/internal/router/channel_test.go +++ b/internal/router/channel_test.go @@ -134,7 +134,7 @@ func TestChannelInboundProcessorBoundUser(t *testing.T) { }, }, } - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} @@ -166,7 +166,7 @@ func TestChannelInboundProcessorBoundUser(t *testing.T) { func TestChannelInboundProcessorUnboundUser(t *testing.T) { store := &fakeConfigStore{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} @@ -191,7 +191,7 @@ func TestChannelInboundProcessorUnboundUser(t *testing.T) { func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { store := &fakeConfigStore{} gateway := &fakeChatGateway{} - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1"} @@ -223,7 +223,7 @@ func TestChannelInboundProcessorSilentReply(t *testing.T) { }, }, } - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} @@ -272,7 +272,7 @@ func TestChannelInboundProcessorSuppressOnToolSend(t *testing.T) { }, }, } - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} @@ -321,7 +321,7 @@ func TestChannelInboundProcessorDedupeWithToolSend(t *testing.T) { }, }, } - processor := NewChannelInboundProcessor(slog.Default(), store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) + processor := NewChannelInboundProcessor(slog.Default(), nil, store, gateway, &fakeContactService{}, &fakePolicyService{}, nil, "", 0) sender := &fakeReplySender{} cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} diff --git a/internal/router/identity.go b/internal/router/identity.go index b0ba5211..863a7036 100644 --- a/internal/router/identity.go +++ b/internal/router/identity.go @@ -52,8 +52,15 @@ func IdentityStateFromContext(ctx context.Context) (IdentityState, bool) { return state, ok } +// IdentityStore is the minimal persistence interface required by IdentityResolver. +type IdentityStore interface { + channel.BindingStore + channel.SessionStore +} + type IdentityResolver struct { - store channel.ConfigStore + registry *channel.Registry + store IdentityStore contacts ContactService policy PolicyService preauth PreauthService @@ -71,7 +78,7 @@ type PreauthService interface { MarkUsed(ctx context.Context, id string) (preauth.Key, error) } -func NewIdentityResolver(log *slog.Logger, store channel.ConfigStore, contacts ContactService, policyService PolicyService, preauthService PreauthService, unboundReply, preauthReply string) *IdentityResolver { +func NewIdentityResolver(log *slog.Logger, registry *channel.Registry, store IdentityStore, contacts ContactService, policyService PolicyService, preauthService PreauthService, unboundReply, preauthReply string) *IdentityResolver { if log == nil { log = slog.Default() } @@ -82,6 +89,7 @@ func NewIdentityResolver(log *slog.Logger, store channel.ConfigStore, contacts C preauthReply = "授权成功,请继续使用。" } return &IdentityResolver{ + registry: registry, store: store, contacts: contacts, policy: policyService, @@ -118,7 +126,7 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi sessionID := normalizedMsg.SessionID() channelConfigID := cfg.ID - if channel.IsConfigless(msg.Channel) { + if r.registry != nil && r.registry.IsConfigless(msg.Channel) { channelConfigID = "" } externalID := extractExternalIdentity(msg) diff --git a/internal/router/identity_test.go b/internal/router/identity_test.go index 3948f16c..90c87bd8 100644 --- a/internal/router/identity_test.go +++ b/internal/router/identity_test.go @@ -119,7 +119,7 @@ func TestIdentityResolverAllowGuestCreatesContact(t *testing.T) { store := &fakeIdentityConfigStore{} contactsService := &fakeIdentityContactService{} policyService := &fakePolicyServiceIdentity{decision: policy.Decision{AllowGuest: true}} - resolver := NewIdentityResolver(slog.Default(), store, contactsService, policyService, nil, "禁止访问", "授权成功") + resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, nil, "禁止访问", "授权成功") msg := channel.InboundMessage{ BotID: "bot-1", @@ -152,7 +152,7 @@ func TestIdentityResolverPreauthKeyAllowsGuest(t *testing.T) { ExpiresAt: time.Now().UTC().Add(1 * time.Hour), }, } - resolver := NewIdentityResolver(slog.Default(), store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, preauthService, "禁止访问", "授权成功") msg := channel.InboundMessage{ BotID: "bot-1", @@ -188,7 +188,7 @@ func TestIdentityResolverPreauthKeyExpired(t *testing.T) { ExpiresAt: time.Now().UTC().Add(-1 * time.Hour), }, } - resolver := NewIdentityResolver(slog.Default(), store, contactsService, policyService, preauthService, "禁止访问", "授权成功") + resolver := NewIdentityResolver(slog.Default(), nil, store, contactsService, policyService, preauthService, "禁止访问", "授权成功") msg := channel.InboundMessage{ BotID: "bot-1",