feat(platforms): add discord channel support

This commit is contained in:
Fodesu
2026-02-22 14:25:18 +08:00
committed by 晨苒
parent da9d64e508
commit 77ff24c6fd
7 changed files with 720 additions and 0 deletions
+131
View File
@@ -0,0 +1,131 @@
package discord
import (
"fmt"
"strings"
"github.com/memohai/memoh/internal/channel"
)
type Config struct {
BotToken string
}
type UserConfig struct {
UserID string
ChannelID string
GuildID string
Username string
}
func normalizeConfig(raw map[string]any) (map[string]any, error) {
cfg, err := parseConfig(raw)
if err != nil {
return nil, err
}
return map[string]any{"botToken": cfg.BotToken}, nil
}
func normalizeUserConfig(raw map[string]any) (map[string]any, error) {
cfg, err := parseUserConfig(raw)
if err != nil {
return nil, err
}
result := map[string]any{}
if cfg.UserID != "" {
result["user_id"] = cfg.UserID
}
if cfg.ChannelID != "" {
result["channel_id"] = cfg.ChannelID
}
if cfg.GuildID != "" {
result["guild_id"] = cfg.GuildID
}
if cfg.Username != "" {
result["username"] = cfg.Username
}
return result, nil
}
func resolveTarget(raw map[string]any) (string, error) {
cfg, err := parseUserConfig(raw)
if err != nil {
return "", err
}
if cfg.ChannelID != "" {
return cfg.ChannelID, nil
}
if cfg.UserID != "" {
return cfg.UserID, nil
}
return "", fmt.Errorf("discord binding is incomplete")
}
func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool {
cfg, err := parseUserConfig(raw)
if err != nil {
return false
}
if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID {
return true
}
if value := strings.TrimSpace(criteria.Attribute("channel_id")); value != "" && value == cfg.ChannelID {
return true
}
if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) {
return true
}
if criteria.SubjectID != "" {
if criteria.SubjectID == cfg.UserID || criteria.SubjectID == cfg.ChannelID {
return true
}
}
return false
}
func buildUserConfig(identity channel.Identity) map[string]any {
result := map[string]any{}
if value := strings.TrimSpace(identity.Attribute("user_id")); value != "" {
result["user_id"] = value
}
if value := strings.TrimSpace(identity.Attribute("channel_id")); value != "" {
result["channel_id"] = value
}
if value := strings.TrimSpace(identity.Attribute("guild_id")); value != "" {
result["guild_id"] = value
}
if value := strings.TrimSpace(identity.Attribute("username")); value != "" {
result["username"] = value
}
return result
}
func parseConfig(raw map[string]any) (Config, error) {
token := strings.TrimSpace(channel.ReadString(raw, "botToken", "bot_token"))
if token == "" {
return Config{}, fmt.Errorf("discord botToken is required")
}
return Config{BotToken: token}, nil
}
func parseUserConfig(raw map[string]any) (UserConfig, error) {
userID := strings.TrimSpace(channel.ReadString(raw,"userId", "user_id"))
channelID := strings.TrimSpace(channel.ReadString(raw, "channelId", "channel_id"))
guildID := strings.TrimSpace(channel.ReadString(raw, "guildId", "guild_id"))
username := strings.TrimSpace(channel.ReadString(raw, "username"))
if userID == "" && channelID == "" {
return UserConfig{}, fmt.Errorf("discord user config requires user_id or channel_id")
}
return UserConfig{
UserID: userID,
ChannelID: channelID,
GuildID: guildID,
Username: username,
}, nil
}
func normalizeTarget(raw string) string {
return strings.TrimSpace(raw)
}
@@ -0,0 +1,5 @@
package discord
import "github.com/memohai/memoh/internal/channel"
const Type channel.ChannelType = "discord"
@@ -0,0 +1,400 @@
package discord
import (
"context"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/adapters/common"
)
type DiscordAdapter struct {
logger *slog.Logger
mu sync.RWMutex
sessions map[string]*discordgo.Session // keyed by bot token
}
func NewDiscordAdapter(log *slog.Logger) *DiscordAdapter {
if log == nil {
log = slog.Default()
}
return &DiscordAdapter{
logger: log.With(slog.String("adapter", "discord")),
sessions: make(map[string]*discordgo.Session),
}
}
func (a *DiscordAdapter) Type() channel.ChannelType {
return Type
}
func (a *DiscordAdapter) Descriptor() channel.Descriptor {
return channel.Descriptor{
Type: Type,
DisplayName: "Discord",
Capabilities: channel.ChannelCapabilities{
Text: true,
Markdown: true,
Reply: true,
Attachments: true,
Media: true,
Streaming: true,
BlockStreaming: true,
Reactions: true,
},
ConfigSchema: channel.ConfigSchema{
Version: 1,
Fields: map[string]channel.FieldSchema{
"botToken": {
Type: channel.FieldSecret,
Required: true,
Title: "Bot Token",
},
},
},
UserConfigSchema: channel.ConfigSchema{
Version: 1,
Fields: map[string]channel.FieldSchema{
"user_id": {Type: channel.FieldString},
"channel_id": {Type: channel.FieldString},
"guild_id": {Type: channel.FieldString},
"username": {Type: channel.FieldString},
},
},
TargetSpec: channel.TargetSpec{
Format: "channel_id | user_id",
Hints: []channel.TargetHint{
{Label: "Channel ID", Example: "1234567890123456789"},
{Label: "User ID", Example: "1234567890123456789"},
},
},
}
}
func (a *DiscordAdapter) getOrCreateSession(token, configID string) (*discordgo.Session, error) {
a.mu.RLock()
session, ok := a.sessions[token]
a.mu.RUnlock()
if ok {
return session, nil
}
a.mu.Lock()
defer a.mu.Unlock()
if s, ok := a.sessions[token]; ok {
return s, nil
}
session, err := discordgo.New("Bot " + token)
if err != nil {
a.logger.Error("create session failed", slog.String("config_id", configID), slog.Any("error", err))
return nil, err
}
session.Identify.Intents = discordgo.IntentsAll
a.sessions[token] = session
return session, nil
}
func (a *DiscordAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) {
if a.logger != nil {
a.logger.Info("start", slog.String("config_id", cfg.ID))
}
discordCfg, err := parseConfig(cfg.Credentials)
if err != nil {
return nil, err
}
session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID)
if err != nil {
return nil, err
}
session.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
if m.Author != nil && m.Author.Bot {
return
}
if ctx.Err() != nil {
return
}
text := strings.TrimSpace(m.Content)
botId := s.State.User.ID
if text == "" && len(m.Attachments) == 0 {
return
}
attachments := a.collectAttachments(m.Message)
chatType := "direct"
if m.GuildID != "" {
chatType = "guild"
}
isMentioned := a.isBotMentioned(m.Message, botId)
isReplyToBot := m.ReferencedMessage != nil &&
m.ReferencedMessage.Author != nil &&
m.ReferencedMessage.Author.ID == botId
msg := channel.InboundMessage{
Channel: Type,
Message: channel.Message{
ID: m.ID,
Format: channel.MessageFormatPlain,
Text: text,
Attachments: attachments,
},
BotID: cfg.BotID,
ReplyTarget: m.ChannelID,
Sender: channel.Identity{
SubjectID: m.Author.ID,
DisplayName: m.Author.Username,
Attributes: map[string]string{
"user_id": m.Author.ID,
"username": m.Author.Username,
},
},
Conversation: channel.Conversation{
ID: m.ChannelID,
Type: chatType,
},
ReceivedAt: time.Now().UTC(),
Source: "discord",
Metadata: map[string]any{
"guild_id": m.GuildID,
"is_mentioned": isMentioned,
"is_reply_to_bot": isReplyToBot,
},
}
if a.logger != nil {
a.logger.Info("inbound received",
slog.String("config_id", cfg.ID),
slog.String("chat_type", chatType),
slog.String("user_id", m.Author.ID),
slog.String("username", m.Author.Username),
slog.String("text", common.SummarizeText(text)),
)
}
go func() {
if err := handler(ctx, cfg, msg); err != nil && a.logger != nil {
a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err))
}
}()
})
if err := session.Open(); err != nil {
return nil, fmt.Errorf("discord open connection: %w", err)
}
stop := func(stopCtx context.Context) error {
if a.logger != nil {
a.logger.Info("stop", slog.String("config_id", cfg.ID))
}
return session.Close()
}
return channel.NewConnection(cfg, stop), nil
}
func (a *DiscordAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error {
discordCfg, err := parseConfig(cfg.Credentials)
if err != nil {
return err
}
session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID)
if err != nil {
return err
}
channelID := strings.TrimSpace(msg.Target)
if channelID == "" {
return fmt.Errorf("discord target is required")
}
text := strings.TrimSpace(msg.Message.PlainText())
if text == "" && len(msg.Message.Attachments) == 0 {
return fmt.Errorf("message is required")
}
// Discord limit: 2000 characters
const discordMaxLength = 2000
if len(text) > discordMaxLength {
text = text[:discordMaxLength-3] + "..."
}
_, err = session.ChannelMessageSend(channelID, text)
return err
}
func (a *DiscordAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) {
target = strings.TrimSpace(target)
if target == "" {
return nil, fmt.Errorf("discord target is required")
}
discordCfg, err := parseConfig(cfg.Credentials)
if err != nil {
return nil, err
}
session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID)
if err != nil {
return nil, err
}
return &discordOutboundStream{
adapter: a,
cfg: cfg,
target: target,
reply: opts.Reply,
session: session,
}, nil
}
func (a *DiscordAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) {
chatID := strings.TrimSpace(info.ReplyTarget)
if chatID == "" {
return channel.ProcessingStatusHandle{}, nil
}
discordCfg, err := parseConfig(cfg.Credentials)
if err != nil {
return channel.ProcessingStatusHandle{}, err
}
session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID)
if err != nil {
return channel.ProcessingStatusHandle{}, err
}
// Discord typing indicator
err = session.ChannelTyping(chatID)
return channel.ProcessingStatusHandle{}, err
}
func (a *DiscordAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error {
return nil
}
func (a *DiscordAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error {
return nil
}
func (a *DiscordAdapter) React(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error {
discordCfg, err := parseConfig(cfg.Credentials)
if err != nil {
return err
}
session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID)
if err != nil {
return err
}
return session.MessageReactionAdd(target, messageID, emoji)
}
func (a *DiscordAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error {
discordCfg, err := parseConfig(cfg.Credentials)
if err != nil {
return err
}
session, err := a.getOrCreateSession(discordCfg.BotToken, cfg.ID)
if err != nil {
return err
}
return session.MessageReactionRemove(target, messageID, emoji, "@me")
}
func (a *DiscordAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) {
return normalizeConfig(raw)
}
func (a *DiscordAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) {
return normalizeUserConfig(raw)
}
func (a *DiscordAdapter) NormalizeTarget(raw string) string {
return normalizeTarget(raw)
}
func (a *DiscordAdapter) ResolveTarget(userConfig map[string]any) (string, error) {
return resolveTarget(userConfig)
}
func (a *DiscordAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool {
return matchBinding(config, criteria)
}
func (a *DiscordAdapter) BuildUserConfig(identity channel.Identity) map[string]any {
return buildUserConfig(identity)
}
func (a *DiscordAdapter) collectAttachments(msg *discordgo.Message) []channel.Attachment {
if msg == nil || len(msg.Attachments) == 0 {
return nil
}
attachments := make([]channel.Attachment, 0, len(msg.Attachments))
for _, att := range msg.Attachments {
attachment := channel.Attachment{
Type: channel.AttachmentFile,
URL: att.URL,
PlatformKey: att.ID,
SourcePlatform: Type.String(),
Name: att.Filename,
Size: int64(att.Size),
}
if att.ContentType != "" {
switch {
case strings.HasPrefix(att.ContentType, "image/"):
attachment.Type = channel.AttachmentImage
attachment.Width = att.Width
attachment.Height = att.Height
case strings.HasPrefix(att.ContentType, "video/"):
attachment.Type = channel.AttachmentVideo
case strings.HasPrefix(att.ContentType, "audio/"):
attachment.Type = channel.AttachmentAudio
}
}
attachments = append(attachments, attachment)
}
return attachments
}
func (a *DiscordAdapter) isBotMentioned(msg *discordgo.Message, botID string) bool {
if msg == nil {
return false
}
for _, mention := range msg.Mentions {
if mention != nil && mention.ID == botID {
return true
}
}
if msg.MentionEveryone {
return true
}
botMention := "<@" + botID + ">"
botNickMention := "<@!" + botID + ">"
content := strings.ToLower(msg.Content)
return strings.Contains(content, strings.ToLower(botMention)) ||
strings.Contains(content, strings.ToLower(botNickMention))
}
+174
View File
@@ -0,0 +1,174 @@
package discord
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/bwmarrin/discordgo"
"github.com/memohai/memoh/internal/channel"
)
type discordOutboundStream struct {
adapter *DiscordAdapter
cfg channel.ChannelConfig
target string
reply *channel.ReplyRef
session *discordgo.Session
closed atomic.Bool
mu sync.Mutex
msgID string
buffer strings.Builder
lastUpdate time.Time
}
func (s *discordOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error {
if s == nil || s.adapter == nil {
return fmt.Errorf("discord stream not configured")
}
if s.closed.Load() {
return fmt.Errorf("discord stream is closed")
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
switch event.Type {
case channel.StreamEventStatus:
if event.Status == channel.StreamStatusStarted {
return s.ensureMessage(ctx, "Thinking...")
}
return nil
case channel.StreamEventDelta:
if event.Delta == "" {
return nil
}
s.mu.Lock()
s.buffer.WriteString(event.Delta)
s.mu.Unlock()
// Discord has strict rate limits, only update periodically
if time.Since(s.lastUpdate) > 2*time.Second {
return s.updateMessage(ctx)
}
return nil
case channel.StreamEventFinal:
if event.Final != nil && !event.Final.Message.IsEmpty() {
finalText := strings.TrimSpace(event.Final.Message.PlainText())
if finalText != "" {
return s.finalizeMessage(ctx, finalText)
}
}
s.mu.Lock()
finalText := strings.TrimSpace(s.buffer.String())
s.mu.Unlock()
if finalText != "" {
return s.finalizeMessage(ctx, finalText)
}
return nil
case channel.StreamEventError:
errText := strings.TrimSpace(event.Error)
if errText == "" {
return nil
}
return s.finalizeMessage(ctx, "Error: "+errText)
case channel.StreamEventAgentStart, channel.StreamEventAgentEnd, channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd, channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted, channel.StreamEventProcessingFailed, channel.StreamEventToolCallStart, channel.StreamEventToolCallEnd:
// Status events - no action needed for Discord
return nil
default:
return fmt.Errorf("unsupported stream event type: %s", event.Type)
}
}
func (s *discordOutboundStream) Close(ctx context.Context) error {
if s == nil {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
s.closed.Store(true)
return nil
}
func (s *discordOutboundStream) ensureMessage(ctx context.Context, text string) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.msgID != "" {
return nil
}
// Discord limit: 2000 characters
content := text
if len(content) > 2000 {
content = content[:1997] + "..."
}
msg, err := s.session.ChannelMessageSend(s.target, content)
if err != nil {
return err
}
s.msgID = msg.ID
s.lastUpdate = time.Now()
return nil
}
func (s *discordOutboundStream) updateMessage(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.msgID == "" {
return nil
}
content := s.buffer.String()
if content == "" {
return nil
}
// Discord limit
if len(content) > 2000 {
content = content[:1997] + "..."
}
_, err := s.session.ChannelMessageEdit(s.target, s.msgID, content)
if err != nil {
return err
}
s.lastUpdate = time.Now()
return nil
}
func (s *discordOutboundStream) finalizeMessage(ctx context.Context, text string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Discord limit
if len(text) > 2000 {
text = text[:1997] + "..."
}
if s.msgID == "" {
_, err := s.session.ChannelMessageSend(s.target, text)
return err
}
_, err := s.session.ChannelMessageEdit(s.target, s.msgID, text)
return err
}