diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 06c7143d..c441f08d 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -25,6 +25,7 @@ import ( "github.com/memohai/memoh/internal/boot" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channel/adapters/discord" "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/telegram" @@ -394,6 +395,7 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService tgAdapter := telegram.NewTelegramAdapter(log) tgAdapter.SetAssetOpener(mediaService) registry.MustRegister(tgAdapter) + registry.MustRegister(discord.NewDiscordAdapter(log)) registry.MustRegister(feishu.NewFeishuAdapter(log)) registry.MustRegister(local.NewCLIAdapter(hub)) registry.MustRegister(local.NewWebAdapter(hub)) diff --git a/go.mod b/go.mod index bf7ed440..837070a3 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/blevesearch/snowballstem v0.9.0 // indirect github.com/blevesearch/stempel v0.2.0 // indirect github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect + github.com/bwmarrin/discordgo v0.29.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/cgroups/v3 v3.1.2 // indirect github.com/containerd/continuity v0.4.5 // indirect diff --git a/go.sum b/go.sum index cae2737f..6092496a 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/blevesearch/stempel v0.2.0 h1:CYzVPaScODMvgE9o+kf6D4RJ/VRomyi9uHF+PtB github.com/blevesearch/stempel v0.2.0/go.mod h1:wjeTHqQv+nQdbPuJ/YcvOjTInA2EIc6Ks1FoSUzSLvc= github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A= github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= +github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= +github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -161,6 +163,7 @@ github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA= @@ -311,6 +314,7 @@ go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -332,6 +336,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -348,9 +353,11 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/internal/channel/adapters/discord/config.go b/internal/channel/adapters/discord/config.go new file mode 100644 index 00000000..1d0c4b2f --- /dev/null +++ b/internal/channel/adapters/discord/config.go @@ -0,0 +1,131 @@ +package discord + +import ( + "fmt" + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +type Config struct { + BotToken string +} + +type UserConfig struct { + UserID string + ChannelID string + GuildID string + Username string +} + +func normalizeConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseConfig(raw) + if err != nil { + return nil, err + } + return map[string]any{"botToken": cfg.BotToken}, nil +} + +func normalizeUserConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return nil, err + } + result := map[string]any{} + if cfg.UserID != "" { + result["user_id"] = cfg.UserID + } + if cfg.ChannelID != "" { + result["channel_id"] = cfg.ChannelID + } + if cfg.GuildID != "" { + result["guild_id"] = cfg.GuildID + } + if cfg.Username != "" { + result["username"] = cfg.Username + } + return result, nil +} + +func resolveTarget(raw map[string]any) (string, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return "", err + } + if cfg.ChannelID != "" { + return cfg.ChannelID, nil + } + if cfg.UserID != "" { + return cfg.UserID, nil + } + return "", fmt.Errorf("discord binding is incomplete") +} + +func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { + cfg, err := parseUserConfig(raw) + if err != nil { + return false + } + if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("channel_id")); value != "" && value == cfg.ChannelID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) { + return true + } + if criteria.SubjectID != "" { + if criteria.SubjectID == cfg.UserID || criteria.SubjectID == cfg.ChannelID { + return true + } + } + return false +} + +func buildUserConfig(identity channel.Identity) map[string]any { + result := map[string]any{} + if value := strings.TrimSpace(identity.Attribute("user_id")); value != "" { + result["user_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("channel_id")); value != "" { + result["channel_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("guild_id")); value != "" { + result["guild_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("username")); value != "" { + result["username"] = value + } + return result +} + +func parseConfig(raw map[string]any) (Config, error) { + token := strings.TrimSpace(channel.ReadString(raw, "botToken", "bot_token")) + if token == "" { + return Config{}, fmt.Errorf("discord botToken is required") + } + return Config{BotToken: token}, nil +} + +func parseUserConfig(raw map[string]any) (UserConfig, error) { + userID := strings.TrimSpace(channel.ReadString(raw,"userId", "user_id")) + channelID := strings.TrimSpace(channel.ReadString(raw, "channelId", "channel_id")) + guildID := strings.TrimSpace(channel.ReadString(raw, "guildId", "guild_id")) + username := strings.TrimSpace(channel.ReadString(raw, "username")) + + if userID == "" && channelID == "" { + return UserConfig{}, fmt.Errorf("discord user config requires user_id or channel_id") + } + + return UserConfig{ + UserID: userID, + ChannelID: channelID, + GuildID: guildID, + Username: username, + }, nil +} + +func normalizeTarget(raw string) string { + return strings.TrimSpace(raw) +} \ No newline at end of file diff --git a/internal/channel/adapters/discord/descriptor.go b/internal/channel/adapters/discord/descriptor.go new file mode 100644 index 00000000..5e44e12a --- /dev/null +++ b/internal/channel/adapters/discord/descriptor.go @@ -0,0 +1,5 @@ +package discord + +import "github.com/memohai/memoh/internal/channel" + +const Type channel.ChannelType = "discord" \ No newline at end of file diff --git a/internal/channel/adapters/discord/discord.go b/internal/channel/adapters/discord/discord.go new file mode 100644 index 00000000..d5720d32 --- /dev/null +++ b/internal/channel/adapters/discord/discord.go @@ -0,0 +1,400 @@ +package discord + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channel/adapters/common" +) + +type DiscordAdapter struct { + logger *slog.Logger + mu sync.RWMutex + sessions map[string]*discordgo.Session // keyed by bot token +} + +func NewDiscordAdapter(log *slog.Logger) *DiscordAdapter { + if log == nil { + log = slog.Default() + } + return &DiscordAdapter{ + logger: log.With(slog.String("adapter", "discord")), + sessions: make(map[string]*discordgo.Session), + } +} + +func (a *DiscordAdapter) Type() channel.ChannelType { + return Type +} + +func (a *DiscordAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: Type, + DisplayName: "Discord", + Capabilities: channel.ChannelCapabilities{ + Text: true, + Markdown: true, + Reply: true, + Attachments: true, + Media: true, + Streaming: true, + BlockStreaming: true, + Reactions: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "botToken": { + Type: channel.FieldSecret, + Required: true, + Title: "Bot Token", + }, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "user_id": {Type: channel.FieldString}, + "channel_id": {Type: channel.FieldString}, + "guild_id": {Type: channel.FieldString}, + "username": {Type: channel.FieldString}, + }, + }, + TargetSpec: channel.TargetSpec{ + Format: "channel_id | user_id", + Hints: []channel.TargetHint{ + {Label: "Channel ID", Example: "1234567890123456789"}, + {Label: "User ID", Example: "1234567890123456789"}, + }, + }, + } +} + +func (a *DiscordAdapter) getOrCreateSession(token, configID string) (*discordgo.Session, error) { + a.mu.RLock() + session, ok := a.sessions[token] + a.mu.RUnlock() + if ok { + return session, nil + } + + a.mu.Lock() + defer a.mu.Unlock() + if s, ok := a.sessions[token]; ok { + return s, nil + } + + session, err := discordgo.New("Bot " + token) + if err != nil { + a.logger.Error("create session failed", slog.String("config_id", configID), slog.Any("error", err)) + return nil, err + } + + session.Identify.Intents = discordgo.IntentsAll + + a.sessions[token] = session + return session, nil +} + +func (a *DiscordAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { + if a.logger != nil { + a.logger.Info("start", slog.String("config_id", cfg.ID)) + } + + discordCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + + session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID) + if err != nil { + return nil, err + } + + session.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { + if m.Author != nil && m.Author.Bot { + return + } + + if ctx.Err() != nil { + return + } + + text := strings.TrimSpace(m.Content) + botId := s.State.User.ID + if text == "" && len(m.Attachments) == 0 { + return + } + + attachments := a.collectAttachments(m.Message) + chatType := "direct" + if m.GuildID != "" { + chatType = "guild" + } + + isMentioned := a.isBotMentioned(m.Message, botId) + isReplyToBot := m.ReferencedMessage != nil && + m.ReferencedMessage.Author != nil && + m.ReferencedMessage.Author.ID == botId + + msg := channel.InboundMessage{ + Channel: Type, + Message: channel.Message{ + ID: m.ID, + Format: channel.MessageFormatPlain, + Text: text, + Attachments: attachments, + }, + BotID: cfg.BotID, + ReplyTarget: m.ChannelID, + Sender: channel.Identity{ + SubjectID: m.Author.ID, + DisplayName: m.Author.Username, + Attributes: map[string]string{ + "user_id": m.Author.ID, + "username": m.Author.Username, + }, + }, + Conversation: channel.Conversation{ + ID: m.ChannelID, + Type: chatType, + }, + ReceivedAt: time.Now().UTC(), + Source: "discord", + Metadata: map[string]any{ + "guild_id": m.GuildID, + "is_mentioned": isMentioned, + "is_reply_to_bot": isReplyToBot, + }, + } + + if a.logger != nil { + a.logger.Info("inbound received", + slog.String("config_id", cfg.ID), + slog.String("chat_type", chatType), + slog.String("user_id", m.Author.ID), + slog.String("username", m.Author.Username), + slog.String("text", common.SummarizeText(text)), + ) + } + + go func() { + if err := handler(ctx, cfg, msg); err != nil && a.logger != nil { + a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + }() + }) + + if err := session.Open(); err != nil { + return nil, fmt.Errorf("discord open connection: %w", err) + } + + stop := func(stopCtx context.Context) error { + if a.logger != nil { + a.logger.Info("stop", slog.String("config_id", cfg.ID)) + } + return session.Close() + } + + return channel.NewConnection(cfg, stop), nil +} + +func (a *DiscordAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { + discordCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + + session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID) + if err != nil { + return err + } + + channelID := strings.TrimSpace(msg.Target) + if channelID == "" { + return fmt.Errorf("discord target is required") + } + + text := strings.TrimSpace(msg.Message.PlainText()) + if text == "" && len(msg.Message.Attachments) == 0 { + return fmt.Errorf("message is required") + } + + // Discord limit: 2000 characters + const discordMaxLength = 2000 + if len(text) > discordMaxLength { + text = text[:discordMaxLength-3] + "..." + } + + _, err = session.ChannelMessageSend(channelID, text) + return err +} + +func (a *DiscordAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("discord target is required") + } + + discordCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + + session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID) + if err != nil { + return nil, err + } + + return &discordOutboundStream{ + adapter: a, + cfg: cfg, + target: target, + reply: opts.Reply, + session: session, + }, nil +} + +func (a *DiscordAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + chatID := strings.TrimSpace(info.ReplyTarget) + if chatID == "" { + return channel.ProcessingStatusHandle{}, nil + } + + discordCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return channel.ProcessingStatusHandle{}, err + } + + session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID) + if err != nil { + return channel.ProcessingStatusHandle{}, err + } + + // Discord typing indicator + err = session.ChannelTyping(chatID) + return channel.ProcessingStatusHandle{}, err +} + +func (a *DiscordAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { + return nil +} + +func (a *DiscordAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { + return nil +} + +func (a *DiscordAdapter) React(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { + discordCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + + session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID) + if err != nil { + return err + } + + return session.MessageReactionAdd(target, messageID, emoji) +} + +func (a *DiscordAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { + discordCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + + session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID) + if err != nil { + return err + } + + return session.MessageReactionRemove(target, messageID, emoji, "@me") +} + +func (a *DiscordAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { + return normalizeConfig(raw) +} + +func (a *DiscordAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + return normalizeUserConfig(raw) +} + +func (a *DiscordAdapter) NormalizeTarget(raw string) string { + return normalizeTarget(raw) +} + +func (a *DiscordAdapter) ResolveTarget(userConfig map[string]any) (string, error) { + return resolveTarget(userConfig) +} + +func (a *DiscordAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { + return matchBinding(config, criteria) +} + +func (a *DiscordAdapter) BuildUserConfig(identity channel.Identity) map[string]any { + return buildUserConfig(identity) +} + +func (a *DiscordAdapter) collectAttachments(msg *discordgo.Message) []channel.Attachment { + if msg == nil || len(msg.Attachments) == 0 { + return nil + } + + attachments := make([]channel.Attachment, 0, len(msg.Attachments)) + for _, att := range msg.Attachments { + attachment := channel.Attachment{ + Type: channel.AttachmentFile, + URL: att.URL, + PlatformKey: att.ID, + SourcePlatform: Type.String(), + Name: att.Filename, + Size: int64(att.Size), + } + + if att.ContentType != "" { + switch { + case strings.HasPrefix(att.ContentType, "image/"): + attachment.Type = channel.AttachmentImage + attachment.Width = att.Width + attachment.Height = att.Height + case strings.HasPrefix(att.ContentType, "video/"): + attachment.Type = channel.AttachmentVideo + case strings.HasPrefix(att.ContentType, "audio/"): + attachment.Type = channel.AttachmentAudio + } + } + + attachments = append(attachments, attachment) + } + + return attachments +} + +func (a *DiscordAdapter) isBotMentioned(msg *discordgo.Message, botID string) bool { + if msg == nil { + return false + } + + for _, mention := range msg.Mentions { + if mention != nil && mention.ID == botID { + return true + } + } + + if msg.MentionEveryone { + return true + } + + botMention := "<@" + botID + ">" + botNickMention := "<@!" + botID + ">" + content := strings.ToLower(msg.Content) + return strings.Contains(content, strings.ToLower(botMention)) || + strings.Contains(content, strings.ToLower(botNickMention)) +} \ No newline at end of file diff --git a/internal/channel/adapters/discord/stream.go b/internal/channel/adapters/discord/stream.go new file mode 100644 index 00000000..5db2775c --- /dev/null +++ b/internal/channel/adapters/discord/stream.go @@ -0,0 +1,174 @@ +package discord + +import ( + "context" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/memohai/memoh/internal/channel" +) + +type discordOutboundStream struct { + adapter *DiscordAdapter + cfg channel.ChannelConfig + target string + reply *channel.ReplyRef + session *discordgo.Session + closed atomic.Bool + mu sync.Mutex + msgID string + buffer strings.Builder + lastUpdate time.Time +} + +func (s *discordOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.adapter == nil { + return fmt.Errorf("discord stream not configured") + } + if s.closed.Load() { + return fmt.Errorf("discord stream is closed") + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + switch event.Type { + case channel.StreamEventStatus: + if event.Status == channel.StreamStatusStarted { + return s.ensureMessage(ctx, "Thinking...") + } + return nil + + case channel.StreamEventDelta: + if event.Delta == "" { + return nil + } + s.mu.Lock() + s.buffer.WriteString(event.Delta) + s.mu.Unlock() + + // Discord has strict rate limits, only update periodically + if time.Since(s.lastUpdate) > 2*time.Second { + return s.updateMessage(ctx) + } + return nil + + case channel.StreamEventFinal: + if event.Final != nil && !event.Final.Message.IsEmpty() { + finalText := strings.TrimSpace(event.Final.Message.PlainText()) + if finalText != "" { + return s.finalizeMessage(ctx, finalText) + } + } + s.mu.Lock() + finalText := strings.TrimSpace(s.buffer.String()) + s.mu.Unlock() + if finalText != "" { + return s.finalizeMessage(ctx, finalText) + } + return nil + + case channel.StreamEventError: + errText := strings.TrimSpace(event.Error) + if errText == "" { + return nil + } + return s.finalizeMessage(ctx, "Error: "+errText) + + case channel.StreamEventAgentStart, channel.StreamEventAgentEnd, channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd, channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted, channel.StreamEventProcessingFailed, channel.StreamEventToolCallStart, channel.StreamEventToolCallEnd: + // Status events - no action needed for Discord + return nil + + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } +} + +func (s *discordOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} + +func (s *discordOutboundStream) ensureMessage(ctx context.Context, text string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.msgID != "" { + return nil + } + + // Discord limit: 2000 characters + content := text + if len(content) > 2000 { + content = content[:1997] + "..." + } + + msg, err := s.session.ChannelMessageSend(s.target, content) + if err != nil { + return err + } + + s.msgID = msg.ID + s.lastUpdate = time.Now() + return nil +} + +func (s *discordOutboundStream) updateMessage(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.msgID == "" { + return nil + } + + content := s.buffer.String() + if content == "" { + return nil + } + + // Discord limit + if len(content) > 2000 { + content = content[:1997] + "..." + } + + _, err := s.session.ChannelMessageEdit(s.target, s.msgID, content) + if err != nil { + return err + } + + s.lastUpdate = time.Now() + return nil +} + +func (s *discordOutboundStream) finalizeMessage(ctx context.Context, text string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Discord limit + if len(text) > 2000 { + text = text[:1997] + "..." + } + + if s.msgID == "" { + _, err := s.session.ChannelMessageSend(s.target, text) + return err + } + + _, err := s.session.ChannelMessageEdit(s.target, s.msgID, text) + return err +} \ No newline at end of file