mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: support discord attacchment file, assetService
This commit is contained in:
+9
-4
@@ -391,13 +391,18 @@ func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *mod
|
||||
|
||||
func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService *media.Service) *channel.Registry {
|
||||
registry := channel.NewRegistry()
|
||||
|
||||
// Telegram
|
||||
tgAdapter := telegram.NewTelegramAdapter(log)
|
||||
tgAdapter.SetAssetOpener(mediaService)
|
||||
registry.MustRegister(tgAdapter)
|
||||
registry.MustRegister(discord.NewDiscordAdapter(log))
|
||||
feishuAdapter := feishu.NewFeishuAdapter(log)
|
||||
feishuAdapter.SetAssetOpener(mediaService)
|
||||
registry.MustRegister(feishuAdapter)
|
||||
|
||||
// Discord
|
||||
discordAdapter := discord.NewDiscordAdapter(log)
|
||||
discordAdapter.SetAssetOpener(mediaService)
|
||||
registry.MustRegister(discordAdapter)
|
||||
|
||||
registry.MustRegister(feishu.NewFeishuAdapter(log))
|
||||
registry.MustRegister(local.NewCLIAdapter(hub))
|
||||
registry.MustRegister(local.NewWebAdapter(hub))
|
||||
return registry
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -11,16 +15,23 @@ import (
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/channel/adapters/common"
|
||||
"github.com/memohai/memoh/internal/media"
|
||||
)
|
||||
|
||||
const inboundDedupTTL = time.Minute
|
||||
|
||||
// assetOpener reads stored asset bytes by content hash.
|
||||
type assetOpener interface {
|
||||
Open(ctx context.Context, botID, contentHash string) (io.ReadCloser, media.Asset, error)
|
||||
}
|
||||
|
||||
type DiscordAdapter struct {
|
||||
logger *slog.Logger
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*discordgo.Session // keyed by bot token
|
||||
handlerRemovers map[string]func() // keyed by bot token
|
||||
seenMessages map[string]time.Time // keyed by token:messageID
|
||||
assets assetOpener
|
||||
}
|
||||
|
||||
func NewDiscordAdapter(log *slog.Logger) *DiscordAdapter {
|
||||
@@ -35,6 +46,13 @@ func NewDiscordAdapter(log *slog.Logger) *DiscordAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
// SetAssetOpener configures the asset opener for reading stored attachments by content hash.
|
||||
func (a *DiscordAdapter) SetAssetOpener(opener assetOpener) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.assets = opener
|
||||
}
|
||||
|
||||
func (a *DiscordAdapter) Type() channel.ChannelType {
|
||||
return Type
|
||||
}
|
||||
@@ -237,24 +255,54 @@ func (a *DiscordAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, ms
|
||||
return fmt.Errorf("discord target is required")
|
||||
}
|
||||
|
||||
err = sendDiscordText(session, channelID, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
func sendDiscordText(session *discordgo.Session, channelID string, message channel.OutboundMessage) error {
|
||||
textTruncated := truncateDiscordText(message.Message.Text)
|
||||
var err error
|
||||
if message.Message.Reply != nil && message.Message.Reply.MessageID != "" {
|
||||
_, err = session.ChannelMessageSendReply(channelID, textTruncated, &discordgo.MessageReference{
|
||||
ChannelID: channelID,
|
||||
MessageID: message.Message.Reply.MessageID,
|
||||
})
|
||||
} else {
|
||||
_, err = session.ChannelMessageSend(channelID, textTruncated)
|
||||
// Get botID from config metadata if available
|
||||
botID := ""
|
||||
if cfg.BotID != "" {
|
||||
botID = cfg.BotID
|
||||
}
|
||||
|
||||
return err
|
||||
return a.sendDiscordMessage(ctx, session, channelID, botID, msg)
|
||||
}
|
||||
|
||||
func (a *DiscordAdapter) sendDiscordMessage(ctx context.Context, session *discordgo.Session, channelID, botID string, msg channel.OutboundMessage) error {
|
||||
content := truncateDiscordText(msg.Message.Text)
|
||||
|
||||
// Build message send parameters
|
||||
messageSend := &discordgo.MessageSend{
|
||||
Content: content,
|
||||
}
|
||||
|
||||
if msg.Message.Reply != nil && msg.Message.Reply.MessageID != "" {
|
||||
messageSend.Reference = &discordgo.MessageReference{
|
||||
ChannelID: channelID,
|
||||
MessageID: msg.Message.Reply.MessageID,
|
||||
}
|
||||
}
|
||||
|
||||
// Add attachments if present
|
||||
if len(msg.Message.Attachments) > 0 {
|
||||
files := make([]*discordgo.File, 0, len(msg.Message.Attachments))
|
||||
for _, att := range msg.Message.Attachments {
|
||||
file := discordAttachmentToFile(ctx, att, a.assets)
|
||||
if file != nil {
|
||||
files = append(files, file)
|
||||
}
|
||||
}
|
||||
messageSend.Files = files
|
||||
|
||||
// Discord requires non-empty content when sending files only
|
||||
if messageSend.Content == "" && len(messageSend.Files) > 0 {
|
||||
messageSend.Content = "\u200b"
|
||||
}
|
||||
}
|
||||
|
||||
// Validate: must have content or files
|
||||
if messageSend.Content == "" && len(messageSend.Files) == 0 {
|
||||
return fmt.Errorf("cannot send empty message: no content and no valid attachments")
|
||||
}
|
||||
|
||||
_, err := session.ChannelMessageSendComplex(channelID, messageSend)
|
||||
return err
|
||||
}
|
||||
|
||||
func truncateDiscordText(text string) string {
|
||||
@@ -265,6 +313,106 @@ func truncateDiscordText(text string) string {
|
||||
return text
|
||||
}
|
||||
|
||||
// discordAttachmentToFile converts a channel attachment to discordgo.File
|
||||
func discordAttachmentToFile(ctx context.Context, att channel.Attachment, opener assetOpener) *discordgo.File {
|
||||
// Get file name
|
||||
name := att.Name
|
||||
if name == "" {
|
||||
name = "attachment"
|
||||
ext := mimeExtension(att.Mime)
|
||||
if ext != "" {
|
||||
name += ext
|
||||
}
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
// Prefer bot_id from attachment metadata (allows cross-bot file forwarding)
|
||||
var botID string
|
||||
if att.Metadata != nil {
|
||||
if bid, ok := att.Metadata["bot_id"].(string); ok && bid != "" {
|
||||
botID = bid
|
||||
}
|
||||
}
|
||||
|
||||
// Try asset opener first (for ContentHash from media store)
|
||||
if att.ContentHash != "" && botID != "" && opener != nil {
|
||||
if rc, _, err := opener.Open(ctx, botID, att.ContentHash); err == nil {
|
||||
data, _ := io.ReadAll(rc)
|
||||
rc.Close()
|
||||
if len(data) > 0 {
|
||||
reader = bytes.NewReader(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to Base64
|
||||
if reader == nil && att.Base64 != "" {
|
||||
data, err := base64DataURLToBytes(att.Base64)
|
||||
if err == nil {
|
||||
reader = bytes.NewReader(data)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to URL
|
||||
if reader == nil && att.URL != "" {
|
||||
resp, err := http.Get(att.URL)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
reader = bytes.NewReader(data)
|
||||
}
|
||||
}
|
||||
|
||||
if reader == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &discordgo.File{
|
||||
Name: name,
|
||||
Reader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
// base64DataURLToBytes decodes a base64 data URL to bytes
|
||||
func base64DataURLToBytes(dataURL string) ([]byte, error) {
|
||||
parts := strings.SplitN(dataURL, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid data URL")
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(parts[1])
|
||||
}
|
||||
|
||||
// mimeExtension returns file extension for common mime types
|
||||
func mimeExtension(mime string) string {
|
||||
switch mime {
|
||||
case "image/jpeg", "image/jpg":
|
||||
return ".jpg"
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "video/mp4":
|
||||
return ".mp4"
|
||||
case "video/webm":
|
||||
return ".webm"
|
||||
case "audio/mpeg", "audio/mp3":
|
||||
return ".mp3"
|
||||
case "audio/ogg":
|
||||
return ".ogg"
|
||||
case "audio/wav":
|
||||
return ".wav"
|
||||
case "application/pdf":
|
||||
return ".pdf"
|
||||
case "text/plain":
|
||||
return ".txt"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (a *DiscordAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) {
|
||||
target = strings.TrimSpace(target)
|
||||
if target == "" {
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
||||
func TestMimeExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
mime string
|
||||
want string
|
||||
}{
|
||||
{"image/png", ".png"},
|
||||
{"image/jpeg", ".jpg"},
|
||||
{"image/gif", ".gif"},
|
||||
{"video/mp4", ".mp4"},
|
||||
{"audio/mpeg", ".mp3"},
|
||||
{"application/pdf", ".pdf"},
|
||||
{"unknown/type", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.mime, func(t *testing.T) {
|
||||
got := mimeExtension(tt.mime)
|
||||
if got != tt.want {
|
||||
t.Errorf("mimeExtension(%q) = %q, want %q", tt.mime, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase64DataURLToBytes(t *testing.T) {
|
||||
// Test valid data URL
|
||||
data, err := base64DataURLToBytes("data:text/plain;base64,SGVsbG8=")
|
||||
if err != nil {
|
||||
t.Errorf("base64DataURLToBytes() error = %v", err)
|
||||
}
|
||||
if string(data) != "Hello" {
|
||||
t.Errorf("base64DataURLToBytes() = %q, want %q", string(data), "Hello")
|
||||
}
|
||||
|
||||
// Test invalid data URL
|
||||
_, err = base64DataURLToBytes("invalid")
|
||||
if err == nil {
|
||||
t.Error("base64DataURLToBytes() expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
@@ -82,6 +82,27 @@ func (s *discordOutboundStream) Push(ctx context.Context, event channel.StreamEv
|
||||
}
|
||||
return s.finalizeMessage("Error: " + errText)
|
||||
|
||||
case channel.StreamEventAttachment:
|
||||
if len(event.Attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Finalize current text message before sending attachments
|
||||
s.mu.Lock()
|
||||
finalText := strings.TrimSpace(s.buffer.String())
|
||||
s.mu.Unlock()
|
||||
if finalText != "" {
|
||||
if err := s.finalizeMessage(finalText); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Send attachments
|
||||
for _, att := range event.Attachments {
|
||||
if err := s.sendAttachment(att); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case channel.StreamEventAgentStart, channel.StreamEventAgentEnd, channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd, channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted, channel.StreamEventProcessingFailed, channel.StreamEventToolCallStart, channel.StreamEventToolCallEnd:
|
||||
// Status events - no action needed for Discord
|
||||
return nil
|
||||
@@ -184,4 +205,27 @@ func (s *discordOutboundStream) finalizeMessage(text string) error {
|
||||
|
||||
_, err := s.session.ChannelMessageEdit(s.target, s.msgID, text)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *discordOutboundStream) sendAttachment(att channel.Attachment) error {
|
||||
ctx := context.Background()
|
||||
file := discordAttachmentToFile(ctx, att, s.adapter.assets)
|
||||
if file == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
messageSend := &discordgo.MessageSend{
|
||||
Files: []*discordgo.File{file},
|
||||
}
|
||||
|
||||
// Add reply reference if this is the first message and we have a reply target
|
||||
if s.reply != nil && s.reply.MessageID != "" {
|
||||
messageSend.Reference = &discordgo.MessageReference{
|
||||
ChannelID: s.target,
|
||||
MessageID: s.reply.MessageID,
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.session.ChannelMessageSendComplex(s.target, messageSend)
|
||||
return err
|
||||
}
|
||||
Reference in New Issue
Block a user