= {
qq: ['fab', 'qq'],
telegram: ['fab', 'telegram'],
+ matrix: ['fas', 'hashtag'],
feishu: ['fas', 'comment-dots'],
web: ['fas', 'globe'],
slack: ['fab', 'slack'],
diff --git a/cmd/agent/main.go b/cmd/agent/main.go
index 767431ec..5068b92e 100644
--- a/cmd/agent/main.go
+++ b/cmd/agent/main.go
@@ -32,6 +32,7 @@ import (
"github.com/memohai/memoh/internal/channel/adapters/discord"
"github.com/memohai/memoh/internal/channel/adapters/feishu"
"github.com/memohai/memoh/internal/channel/adapters/local"
+ "github.com/memohai/memoh/internal/channel/adapters/matrix"
"github.com/memohai/memoh/internal/channel/adapters/qq"
"github.com/memohai/memoh/internal/channel/adapters/telegram"
"github.com/memohai/memoh/internal/channel/adapters/wecom"
@@ -473,6 +474,9 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService
qqAdapter := qq.NewQQAdapter(log)
qqAdapter.SetAssetOpener(mediaService)
registry.MustRegister(qqAdapter)
+ matrixAdapter := matrix.NewMatrixAdapter(log)
+ matrixAdapter.SetAssetOpener(mediaService)
+ registry.MustRegister(matrixAdapter)
feishuAdapter := feishu.NewFeishuAdapter(log)
feishuAdapter.SetAssetOpener(mediaService)
@@ -553,6 +557,11 @@ func provideChannelRouter(
}
func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelStore *channel.Store, channelRouter *inbound.ChannelInboundProcessor) *channel.Manager {
+ if adapter, ok := registry.Get(matrix.Type); ok {
+ if matrixAdapter, ok := adapter.(*matrix.MatrixAdapter); ok {
+ matrixAdapter.SetSyncStateSaver(channelStore.SaveMatrixSyncSinceToken)
+ }
+ }
mgr := channel.NewManager(log, registry, channelStore, channelRouter)
if mw := channelRouter.IdentityMiddleware(); mw != nil {
mgr.Use(mw)
diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go
index 8bb2024c..595c2379 100644
--- a/cmd/memoh/serve.go
+++ b/cmd/memoh/serve.go
@@ -33,6 +33,7 @@ import (
"github.com/memohai/memoh/internal/channel/adapters/discord"
"github.com/memohai/memoh/internal/channel/adapters/feishu"
"github.com/memohai/memoh/internal/channel/adapters/local"
+ "github.com/memohai/memoh/internal/channel/adapters/matrix"
"github.com/memohai/memoh/internal/channel/adapters/qq"
"github.com/memohai/memoh/internal/channel/adapters/telegram"
"github.com/memohai/memoh/internal/channel/adapters/wecom"
@@ -388,6 +389,9 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService
qqAdapter := qq.NewQQAdapter(log)
qqAdapter.SetAssetOpener(mediaService)
registry.MustRegister(qqAdapter)
+ matrixAdapter := matrix.NewMatrixAdapter(log)
+ matrixAdapter.SetAssetOpener(mediaService)
+ registry.MustRegister(matrixAdapter)
feishuAdapter := feishu.NewFeishuAdapter(log)
feishuAdapter.SetAssetOpener(mediaService)
registry.MustRegister(feishuAdapter)
@@ -436,6 +440,11 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc
}
func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelStore *channel.Store, channelRouter *inbound.ChannelInboundProcessor) *channel.Manager {
+ if adapter, ok := registry.Get(matrix.Type); ok {
+ if matrixAdapter, ok := adapter.(*matrix.MatrixAdapter); ok {
+ matrixAdapter.SetSyncStateSaver(channelStore.SaveMatrixSyncSinceToken)
+ }
+ }
mgr := channel.NewManager(log, registry, channelStore, channelRouter)
if mw := channelRouter.IdentityMiddleware(); mw != nil {
mgr.Use(mw)
diff --git a/db/queries/channels.sql b/db/queries/channels.sql
index d0634aa5..8269370e 100644
--- a/db/queries/channels.sql
+++ b/db/queries/channels.sql
@@ -39,6 +39,14 @@ SET
WHERE bot_id = $1 AND channel_type = $2
RETURNING id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, disabled, verified_at, created_at, updated_at;
+-- name: SaveMatrixSyncSinceToken :execrows
+UPDATE bot_channel_configs
+SET routing = COALESCE(routing, '{}'::jsonb) || jsonb_build_object(
+ '_matrix',
+ COALESCE(routing->'_matrix', '{}'::jsonb) || jsonb_build_object('since_token', sqlc.arg(since_token)::text)
+)
+WHERE id = $1;
+
-- name: ListBotChannelConfigsByType :many
SELECT id, bot_id, channel_type, credentials, external_identity, self_identity, routing, capabilities, disabled, verified_at, created_at, updated_at
FROM bot_channel_configs
@@ -65,4 +73,3 @@ SELECT id, user_id, channel_type, config, created_at, updated_at
FROM user_channel_bindings
WHERE channel_type = $1
ORDER BY created_at DESC;
-
diff --git a/go.mod b/go.mod
index e33e9de3..faa51901 100644
--- a/go.mod
+++ b/go.mod
@@ -123,6 +123,7 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
+ github.com/yuin/goldmark v1.7.13 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
diff --git a/internal/channel/adapters/matrix/config.go b/internal/channel/adapters/matrix/config.go
new file mode 100644
index 00000000..43777c4f
--- /dev/null
+++ b/internal/channel/adapters/matrix/config.go
@@ -0,0 +1,213 @@
+package matrix
+
+import (
+ "errors"
+ "strconv"
+ "strings"
+
+ "github.com/memohai/memoh/internal/channel"
+)
+
+type Config struct {
+ HomeserverURL string
+ AccessToken string //nolint:gosec // intentional: operator-supplied Matrix access token in channel config
+ UserID string
+ SyncTimeoutSeconds int
+ AutoJoinInvites bool
+}
+
+type UserConfig struct {
+ RoomID string
+ UserID string
+}
+
+func normalizeConfig(raw map[string]any) (map[string]any, error) {
+ cfg, err := parseConfig(raw)
+ if err != nil {
+ return nil, err
+ }
+ out := map[string]any{
+ "homeserverUrl": cfg.HomeserverURL,
+ "accessToken": cfg.AccessToken,
+ "userId": cfg.UserID,
+ "syncTimeoutSeconds": cfg.SyncTimeoutSeconds,
+ "autoJoinInvites": cfg.AutoJoinInvites,
+ }
+ return out, nil
+}
+
+func normalizeUserConfig(raw map[string]any) (map[string]any, error) {
+ cfg, err := parseUserConfig(raw)
+ if err != nil {
+ return nil, err
+ }
+ out := map[string]any{}
+ if cfg.RoomID != "" {
+ out["room_id"] = cfg.RoomID
+ }
+ if cfg.UserID != "" {
+ out["user_id"] = cfg.UserID
+ }
+ return out, nil
+}
+
+func resolveTarget(raw map[string]any) (string, error) {
+ cfg, err := parseUserConfig(raw)
+ if err != nil {
+ return "", err
+ }
+ if cfg.RoomID != "" {
+ return cfg.RoomID, nil
+ }
+ if cfg.UserID != "" {
+ return cfg.UserID, nil
+ }
+ return "", errors.New("matrix user config requires room_id or user_id")
+}
+
+func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool {
+ cfg, err := parseUserConfig(raw)
+ if err != nil {
+ return false
+ }
+ if cfg.UserID != "" && strings.EqualFold(strings.TrimSpace(criteria.SubjectID), cfg.UserID) {
+ return true
+ }
+ return false
+}
+
+func buildUserConfig(identity channel.Identity) map[string]any {
+ userID := strings.TrimSpace(identity.Attribute("user_id"))
+ if userID == "" {
+ userID = strings.TrimSpace(identity.SubjectID)
+ }
+ if userID == "" {
+ return map[string]any{}
+ }
+ return map[string]any{"user_id": userID}
+}
+
+func parseConfig(raw map[string]any) (Config, error) {
+ homeserverURL := normalizeHomeserverURL(channel.ReadString(raw, "homeserverUrl", "homeserver_url", "homeserver"))
+ accessToken := strings.TrimSpace(channel.ReadString(raw, "accessToken", "access_token"))
+ userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id"))
+ if homeserverURL == "" {
+ return Config{}, errors.New("matrix homeserverUrl is required")
+ }
+ if accessToken == "" {
+ return Config{}, errors.New("matrix accessToken is required")
+ }
+ if userID == "" {
+ return Config{}, errors.New("matrix userId is required")
+ }
+ timeout := readInt(raw, 30, "syncTimeoutSeconds", "sync_timeout_seconds")
+ if timeout < 0 {
+ timeout = 0
+ }
+ autoJoinInvites := readBool(raw, true, "autoJoinInvites", "auto_join_invites")
+ return Config{
+ HomeserverURL: homeserverURL,
+ AccessToken: accessToken,
+ UserID: userID,
+ SyncTimeoutSeconds: timeout,
+ AutoJoinInvites: autoJoinInvites,
+ }, nil
+}
+
+func parseUserConfig(raw map[string]any) (UserConfig, error) {
+ roomID := normalizeTarget(channel.ReadString(raw, "roomId", "room_id"))
+ userID := normalizeTarget(channel.ReadString(raw, "userId", "user_id"))
+ if roomID == "" && userID == "" {
+ return UserConfig{}, errors.New("matrix user config requires room_id or user_id")
+ }
+ if roomID != "" && !strings.HasPrefix(roomID, "!") && !strings.HasPrefix(roomID, "#") {
+ return UserConfig{}, errors.New("matrix room_id must start with ! or #")
+ }
+ if userID != "" && !strings.HasPrefix(userID, "@") {
+ return UserConfig{}, errors.New("matrix user_id must start with @")
+ }
+ return UserConfig{RoomID: roomID, UserID: userID}, nil
+}
+
+func normalizeTarget(raw string) string {
+ value := strings.TrimSpace(raw)
+ if value == "" {
+ return ""
+ }
+ for _, prefix := range []string{"matrix:", "room:", "user:"} {
+ if strings.HasPrefix(strings.ToLower(value), prefix) {
+ value = strings.TrimSpace(value[len(prefix):])
+ break
+ }
+ }
+ return value
+}
+
+func normalizeHomeserverURL(raw string) string {
+ value := strings.TrimSpace(raw)
+ return strings.TrimRight(value, "/")
+}
+
+func readInt(raw map[string]any, fallback int, keys ...string) int {
+ for _, key := range keys {
+ value, ok := raw[key]
+ if !ok {
+ continue
+ }
+ switch v := value.(type) {
+ case int:
+ return v
+ case int64:
+ return int(v)
+ case float64:
+ return int(v)
+ case string:
+ parsed, err := strconv.Atoi(strings.TrimSpace(v))
+ if err == nil {
+ return parsed
+ }
+ }
+ }
+ return fallback
+}
+
+func readBool(raw map[string]any, fallback bool, keys ...string) bool {
+ for _, key := range keys {
+ value, ok := raw[key]
+ if !ok {
+ continue
+ }
+ switch v := value.(type) {
+ case bool:
+ return v
+ case string:
+ switch strings.ToLower(strings.TrimSpace(v)) {
+ case "true", "1", "yes", "on":
+ return true
+ case "false", "0", "no", "off":
+ return false
+ }
+ }
+ }
+ return fallback
+}
+
+func targetKind(target string) string {
+ value := normalizeTarget(target)
+ switch {
+ case strings.HasPrefix(value, "!") || strings.HasPrefix(value, "#"):
+ return "room"
+ case strings.HasPrefix(value, "@"):
+ return "user"
+ default:
+ return ""
+ }
+}
+
+func validateTarget(target string) error {
+ kind := targetKind(target)
+ if kind == "" {
+ return errors.New("matrix target must be a room id/alias or user id")
+ }
+ return nil
+}
diff --git a/internal/channel/adapters/matrix/config_test.go b/internal/channel/adapters/matrix/config_test.go
new file mode 100644
index 00000000..1c64f2f5
--- /dev/null
+++ b/internal/channel/adapters/matrix/config_test.go
@@ -0,0 +1,61 @@
+package matrix
+
+import "testing"
+
+func TestParseConfig(t *testing.T) {
+ cfg, err := parseConfig(map[string]any{
+ "homeserverUrl": "https://matrix.example.com/",
+ "accessToken": "tok",
+ "userId": "@memoh:example.com",
+ "syncTimeoutSeconds": 15,
+ "autoJoinInvites": false,
+ })
+ if err != nil {
+ t.Fatalf("parseConfig returned error: %v", err)
+ }
+ if cfg.HomeserverURL != "https://matrix.example.com" {
+ t.Fatalf("unexpected homeserver url: %q", cfg.HomeserverURL)
+ }
+ if cfg.UserID != "@memoh:example.com" {
+ t.Fatalf("unexpected user id: %q", cfg.UserID)
+ }
+ if cfg.SyncTimeoutSeconds != 15 {
+ t.Fatalf("unexpected sync timeout: %d", cfg.SyncTimeoutSeconds)
+ }
+ if cfg.AutoJoinInvites {
+ t.Fatal("expected autoJoinInvites to be false")
+ }
+}
+
+func TestParseConfigDefaultsAutoJoinInvites(t *testing.T) {
+ cfg, err := parseConfig(map[string]any{
+ "homeserverUrl": "https://matrix.example.com",
+ "accessToken": "tok",
+ "userId": "@memoh:example.com",
+ })
+ if err != nil {
+ t.Fatalf("parseConfig returned error: %v", err)
+ }
+ if !cfg.AutoJoinInvites {
+ t.Fatal("expected autoJoinInvites default to true")
+ }
+}
+
+func TestParseUserConfigRequiresTarget(t *testing.T) {
+ if _, err := parseUserConfig(map[string]any{}); err == nil {
+ t.Fatal("expected parseUserConfig to fail")
+ }
+}
+
+func TestResolveTargetPrefersRoomID(t *testing.T) {
+ target, err := resolveTarget(map[string]any{
+ "room_id": "!room:example.com",
+ "user_id": "@alice:example.com",
+ })
+ if err != nil {
+ t.Fatalf("resolveTarget returned error: %v", err)
+ }
+ if target != "!room:example.com" {
+ t.Fatalf("unexpected target: %q", target)
+ }
+}
diff --git a/internal/channel/adapters/matrix/markdown.go b/internal/channel/adapters/matrix/markdown.go
new file mode 100644
index 00000000..1e780021
--- /dev/null
+++ b/internal/channel/adapters/matrix/markdown.go
@@ -0,0 +1,147 @@
+package matrix
+
+import (
+ "bytes"
+ "regexp"
+ "strings"
+
+ "github.com/yuin/goldmark"
+ "github.com/yuin/goldmark/extension"
+ "github.com/yuin/goldmark/renderer/html"
+
+ "github.com/memohai/memoh/internal/channel"
+)
+
+const matrixHTMLFormat = "org.matrix.custom.html"
+
+var matrixMarkdownRenderer = goldmark.New(
+ goldmark.WithExtensions(extension.GFM),
+ goldmark.WithRendererOptions(
+ html.WithHardWraps(),
+ ),
+)
+
+type matrixFormattedMessage struct {
+ Body string
+ FormattedBody string
+ HasHTML bool
+}
+
+var (
+ matrixTaskListPattern = regexp.MustCompile(`^(\s*(?:[-*+]\s+|\d+\.\s+))\[( |x|X)\]\s+(.*)$`)
+ matrixTableAlignCell = regexp.MustCompile(`^:?-{3,}:?$`)
+)
+
+func formatMatrixMessage(msg channel.Message) matrixFormattedMessage {
+ body := strings.TrimSpace(msg.PlainText())
+ formatted := matrixFormattedMessage{Body: body}
+ if msg.Format != channel.MessageFormatMarkdown || body == "" {
+ return formatted
+ }
+ body = normalizeMatrixMarkdown(body)
+ formatted.Body = body
+ htmlBody, err := renderMatrixMarkdown(body)
+ if err != nil || strings.TrimSpace(htmlBody) == "" {
+ return formatted
+ }
+ formatted.FormattedBody = htmlBody
+ formatted.HasHTML = true
+ return formatted
+}
+
+func renderMatrixMarkdown(text string) (string, error) {
+ text = strings.TrimSpace(text)
+ if text == "" {
+ return "", nil
+ }
+ var buf bytes.Buffer
+ if err := matrixMarkdownRenderer.Convert([]byte(text), &buf); err != nil {
+ return "", err
+ }
+ return strings.TrimSpace(buf.String()), nil
+}
+
+func normalizeMatrixMarkdown(text string) string {
+ text = strings.TrimSpace(text)
+ if text == "" {
+ return ""
+ }
+ lines := strings.Split(text, "\n")
+ result := make([]string, 0, len(lines))
+ inFence := false
+ for i := 0; i < len(lines); i++ {
+ line := lines[i]
+ trimmed := strings.TrimSpace(line)
+ if isFenceLine(trimmed) {
+ inFence = !inFence
+ result = append(result, line)
+ continue
+ }
+ if !inFence && i+1 < len(lines) && isMarkdownTableHeader(line, lines[i+1]) {
+ block := []string{line, lines[i+1]}
+ i += 2
+ for i < len(lines) && isMarkdownTableRow(lines[i]) {
+ block = append(block, lines[i])
+ i++
+ }
+ i--
+ result = append(result, "```text")
+ result = append(result, block...)
+ result = append(result, "```")
+ continue
+ }
+ if !inFence {
+ line = normalizeMatrixTaskListLine(line)
+ }
+ result = append(result, line)
+ }
+ return strings.TrimSpace(strings.Join(result, "\n"))
+}
+
+func normalizeMatrixTaskListLine(line string) string {
+ matches := matrixTaskListPattern.FindStringSubmatch(line)
+ if len(matches) != 4 {
+ return line
+ }
+ box := "☐"
+ if strings.EqualFold(matches[2], "x") {
+ box = "☑"
+ }
+ return matches[1] + box + " " + matches[3]
+}
+
+func isFenceLine(line string) bool {
+ return strings.HasPrefix(line, "```") || strings.HasPrefix(line, "~~~")
+}
+
+func isMarkdownTableHeader(headerLine, delimiterLine string) bool {
+ if !strings.Contains(headerLine, "|") {
+ return false
+ }
+ return isMarkdownTableDelimiter(delimiterLine)
+}
+
+func isMarkdownTableDelimiter(line string) bool {
+ trimmed := strings.TrimSpace(line)
+ if !strings.Contains(trimmed, "|") {
+ return false
+ }
+ parts := strings.Split(trimmed, "|")
+ validCells := 0
+ for _, part := range parts {
+ cell := strings.TrimSpace(part)
+ if cell == "" {
+ continue
+ }
+ if !matrixTableAlignCell.MatchString(cell) {
+ return false
+ }
+ validCells++
+ }
+ return validCells >= 1
+}
+
+func isMarkdownTableRow(line string) bool {
+ trimmed := strings.TrimSpace(line)
+ return trimmed != "" && strings.Contains(trimmed, "|")
+}
diff --git a/internal/channel/adapters/matrix/markdown_test.go b/internal/channel/adapters/matrix/markdown_test.go
new file mode 100644
index 00000000..9e90255b
--- /dev/null
+++ b/internal/channel/adapters/matrix/markdown_test.go
@@ -0,0 +1,40 @@
+package matrix
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/memohai/memoh/internal/channel"
+)
+
+func TestNormalizeMatrixMarkdownTaskList(t *testing.T) {
+ input := "- [ ] todo\n- [x] done"
+ got := normalizeMatrixMarkdown(input)
+ if got != "- ☐ todo\n- ☑ done" {
+ t.Fatalf("unexpected normalized markdown: %q", got)
+ }
+}
+
+func TestNormalizeMatrixMarkdownTablesBecomeCodeBlocks(t *testing.T) {
+ input := "| A | B |\n| --- | --- |\n| 1 | 2 |"
+ got := normalizeMatrixMarkdown(input)
+ if !strings.HasPrefix(got, "```text\n") || !strings.Contains(got, "| 1 | 2 |") || !strings.HasSuffix(got, "\n```") {
+ t.Fatalf("unexpected normalized markdown: %q", got)
+ }
+}
+
+func TestFormatMatrixMessageMarkdownUsesNormalizedBody(t *testing.T) {
+ formatted := formatMatrixMessage(channel.Message{
+ Text: "- [x] done\n\n| A |\n| --- |\n| 1 |",
+ Format: channel.MessageFormatMarkdown,
+ })
+ if !strings.Contains(formatted.Body, "☑ done") {
+ t.Fatalf("expected task list checkbox in body, got %q", formatted.Body)
+ }
+ if !strings.Contains(formatted.FormattedBody, "= http.StatusMultipleChoices {
+ return fmt.Errorf("matrix homeserver check failed: %s", matrixHTTPErrorSummary(statusCode, data))
+ }
+ var resp matrixVersionsResponse
+ if err := json.Unmarshal(data, &resp); err != nil {
+ return fmt.Errorf("matrix homeserver check failed: invalid /versions response: %w", err)
+ }
+ if len(resp.Versions) == 0 {
+ return errors.New("matrix homeserver check failed: /_matrix/client/versions returned no supported versions")
+ }
+ return nil
+}
+
+func (a *MatrixAdapter) validateAccessToken(ctx context.Context, cfg Config) (matrixWhoAmIResponse, error) {
+ data, _, statusCode, err := a.performRequest(ctx, http.MethodGet, cfg.HomeserverURL+"/_matrix/client/v3/account/whoami", nil, "", cfg.AccessToken)
+ if err != nil {
+ return matrixWhoAmIResponse{}, fmt.Errorf("matrix access token check failed: %w", err)
+ }
+ if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
+ return matrixWhoAmIResponse{}, fmt.Errorf("matrix access token check failed: %s", matrixHTTPErrorSummary(statusCode, data))
+ }
+ var resp matrixWhoAmIResponse
+ if err := json.Unmarshal(data, &resp); err != nil {
+ return matrixWhoAmIResponse{}, fmt.Errorf("matrix access token check failed: invalid /account/whoami response: %w", err)
+ }
+ return resp, nil
+}
+
+func matrixHTTPErrorSummary(statusCode int, data []byte) string {
+ var resp matrixErrorResponse
+ if err := json.Unmarshal(data, &resp); err == nil {
+ message := strings.TrimSpace(resp.Error)
+ errCode := strings.TrimSpace(resp.ErrCode)
+ switch {
+ case message != "" && errCode != "":
+ return fmt.Sprintf("%s (%s, HTTP %d)", message, errCode, statusCode)
+ case message != "":
+ return fmt.Sprintf("%s (HTTP %d)", message, statusCode)
+ case errCode != "":
+ return fmt.Sprintf("%s (HTTP %d)", errCode, statusCode)
+ }
+ }
+ message := strings.TrimSpace(string(data))
+ if message == "" {
+ return fmt.Sprintf("HTTP %d", statusCode)
+ }
+ return fmt.Sprintf("%s (HTTP %d)", textutil.TruncateRunes(message, 300), statusCode)
+}
+
+func (a *MatrixAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error {
+ if msg.Message.IsEmpty() {
+ return errors.New("message is required")
+ }
+ parsed, err := parseConfig(cfg.Credentials)
+ if err != nil {
+ return err
+ }
+ roomID, err := a.resolveRoomTarget(ctx, parsed, msg.Target)
+ if err != nil {
+ return err
+ }
+ text := strings.TrimSpace(msg.Message.PlainText())
+ if text != "" {
+ textMsg := msg.Message
+ textMsg.Attachments = nil
+ textMsg.Text = text
+ textMsg.Parts = nil
+ if _, err := a.sendTextEvent(ctx, parsed, roomID, buildMatrixMessageContent(textMsg, false, "")); err != nil {
+ return err
+ }
+ }
+ for i, att := range msg.Message.Attachments {
+ mediaMsg := channel.Message{}
+ if text == "" && i == 0 {
+ mediaMsg.Reply = msg.Message.Reply
+ }
+ if err := a.sendMediaAttachment(ctx, parsed, roomID, cfg.BotID, mediaMsg, att); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (a *MatrixAdapter) OpenStream(_ context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) {
+ if err := validateTarget(target); err != nil {
+ return nil, err
+ }
+ parsed, err := parseConfig(cfg.Credentials)
+ if err != nil {
+ return nil, err
+ }
+ reply := opts.Reply
+ if reply == nil && strings.TrimSpace(opts.SourceMessageID) != "" {
+ reply = &channel.ReplyRef{Target: normalizeTarget(target), MessageID: strings.TrimSpace(opts.SourceMessageID)}
+ }
+ return &matrixOutboundStream{
+ adapter: a,
+ cfg: parsed,
+ target: normalizeTarget(target),
+ reply: reply,
+ }, nil
+}
+
+func (a *MatrixAdapter) Update(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, msg channel.Message) error {
+ parsed, err := parseConfig(cfg.Credentials)
+ if err != nil {
+ return err
+ }
+ roomID, err := a.resolveRoomTarget(ctx, parsed, target)
+ if err != nil {
+ return err
+ }
+ _, err = a.sendTextEvent(ctx, parsed, roomID, buildMatrixMessageContent(msg, true, strings.TrimSpace(messageID)))
+ return err
+}
+
+func (*MatrixAdapter) Unsend(context.Context, channel.ChannelConfig, string, string) error {
+ return errors.New("matrix unsend not supported")
+}
+
+func (a *MatrixAdapter) runSyncLoop(ctx context.Context, cfg channel.ChannelConfig, parsed Config, handler channel.InboundHandler) {
+ backoffs := []time.Duration{time.Second, 2 * time.Second, 5 * time.Second, 10 * time.Second, 20 * time.Second}
+ attempt := 0
+ since := matrixSinceTokenFromRouting(cfg.Routing)
+ persistedSince := since
+ if strings.TrimSpace(since) == "" {
+ bootstrapSince, err := a.bootstrapSinceToken(ctx, cfg, parsed)
+ if err != nil {
+ if a.logger != nil {
+ a.logger.Warn("matrix sync bootstrap failed", slog.String("config_id", cfg.ID), slog.Any("error", err))
+ }
+ } else if bootstrapSince != "" {
+ since = bootstrapSince
+ persistedSince = bootstrapSince
+ }
+ }
+ for ctx.Err() == nil {
+ nextSince, healthy, err := a.syncOnce(ctx, cfg, parsed, since, handler)
+ if strings.TrimSpace(nextSince) != "" {
+ since = nextSince
+ }
+ if err == nil && strings.TrimSpace(since) != "" && since != persistedSince {
+ if saveErr := a.persistSinceToken(ctx, cfg.ID, since); saveErr != nil {
+ if a.logger != nil {
+ a.logger.Warn("matrix sync cursor persist failed", slog.String("config_id", cfg.ID), slog.Bool("healthy", healthy), slog.Any("error", saveErr))
+ }
+ } else {
+ persistedSince = since
+ }
+ }
+ if err == nil || ctx.Err() != nil {
+ attempt = 0
+ continue
+ }
+ if a.logger != nil {
+ a.logger.Warn("matrix sync reconnect", slog.String("config_id", cfg.ID), slog.Any("error", err))
+ }
+ delay, nextAttempt := nextReconnectDelay(backoffs, attempt, healthy)
+ attempt = nextAttempt
+ if !sleepContext(ctx, delay) {
+ return
+ }
+ }
+}
+
+func (a *MatrixAdapter) bootstrapSinceToken(ctx context.Context, cfg channel.ChannelConfig, parsed Config) (string, error) {
+ var resp matrixSyncResponse
+ if err := a.doJSON(ctx, parsed, http.MethodGet, "/_matrix/client/v3/sync?timeout=0", nil, &resp); err != nil {
+ return "", err
+ }
+ if _, err := a.handleInvites(ctx, cfg, parsed, resp); err != nil {
+ return "", err
+ }
+ a.rememberSyncResponseRoomTypes(cfg.ID, parsed, resp)
+ a.rememberSyncResponseEvents(cfg.ID, resp)
+ since := strings.TrimSpace(resp.NextBatch)
+ if since == "" {
+ return "", nil
+ }
+ if err := a.persistSinceToken(ctx, cfg.ID, since); err != nil {
+ return "", err
+ }
+ if a.logger != nil {
+ a.logger.Info("matrix sync cursor bootstrapped", slog.String("config_id", cfg.ID))
+ }
+ return since, nil
+}
+
+func (a *MatrixAdapter) rememberSyncResponseEvents(configID string, resp matrixSyncResponse) {
+ configID = strings.TrimSpace(configID)
+ if configID == "" {
+ return
+ }
+ for _, joined := range resp.Rooms.Join {
+ for _, evt := range joined.Timeline.Events {
+ a.seenEvent(configID, evt.EventID)
+ }
+ }
+}
+
+func (a *MatrixAdapter) persistSinceToken(ctx context.Context, configID string, since string) error {
+ if a == nil || a.saveSince == nil {
+ return nil
+ }
+ configID = strings.TrimSpace(configID)
+ since = strings.TrimSpace(since)
+ if configID == "" || since == "" {
+ return nil
+ }
+ return a.saveSince(ctx, configID, since)
+}
+
+func (a *MatrixAdapter) syncOnce(ctx context.Context, cfg channel.ChannelConfig, parsed Config, since string, handler channel.InboundHandler) (string, bool, error) {
+ query := url.Values{}
+ query.Set("timeout", strconv.Itoa(parsed.SyncTimeoutSeconds*1000))
+ if strings.TrimSpace(since) != "" {
+ query.Set("since", since)
+ }
+ var resp matrixSyncResponse
+ if err := a.doJSON(ctx, parsed, http.MethodGet, "/_matrix/client/v3/sync?"+query.Encode(), nil, &resp); err != nil {
+ return since, false, err
+ }
+ a.rememberSyncResponseRoomTypes(cfg.ID, parsed, resp)
+ healthy := false
+ joinedInvite, err := a.handleInvites(ctx, cfg, parsed, resp)
+ if err != nil {
+ return resp.NextBatch, healthy, err
+ }
+ healthy = healthy || joinedInvite
+ for roomID, joined := range resp.Rooms.Join {
+ for _, evt := range joined.Timeline.Events {
+ evt.RoomID = roomID
+ delivered, err := a.handleEvent(ctx, cfg, parsed, evt, handler)
+ if err != nil {
+ return resp.NextBatch, healthy, err
+ }
+ healthy = healthy || delivered
+ }
+ }
+ return resp.NextBatch, healthy, nil
+}
+
+func (a *MatrixAdapter) handleInvites(ctx context.Context, cfg channel.ChannelConfig, parsed Config, resp matrixSyncResponse) (bool, error) {
+ joinedAny := false
+ for roomID := range resp.Rooms.Invite {
+ roomID = strings.TrimSpace(roomID)
+ if roomID == "" {
+ continue
+ }
+ if !parsed.AutoJoinInvites {
+ if a.logger != nil {
+ a.logger.Info("matrix invite skipped",
+ slog.String("config_id", cfg.ID),
+ slog.String("room_id", roomID),
+ slog.String("reason", "auto_join_disabled"),
+ )
+ }
+ continue
+ }
+ if err := a.joinRoom(ctx, parsed, roomID); err != nil {
+ return joinedAny, err
+ }
+ joinedAny = true
+ if a.logger != nil {
+ a.logger.Info("matrix room auto-joined",
+ slog.String("config_id", cfg.ID),
+ slog.String("room_id", roomID),
+ )
+ }
+ }
+ return joinedAny, nil
+}
+
+func (a *MatrixAdapter) handleEvent(ctx context.Context, cfg channel.ChannelConfig, parsed Config, evt matrixEvent, handler channel.InboundHandler) (bool, error) {
+ if evt.Type != "m.room.message" {
+ return false, nil
+ }
+ if strings.TrimSpace(evt.Sender) == "" || strings.EqualFold(strings.TrimSpace(evt.Sender), parsed.UserID) {
+ return false, nil
+ }
+ if a.seenEvent(cfg.ID, evt.EventID) {
+ return false, nil
+ }
+ if isMatrixEditEvent(evt.Content) {
+ return false, nil
+ }
+ body, attachments := extractMatrixInboundContent(evt.Content)
+ if body == "" && len(attachments) == 0 {
+ return false, nil
+ }
+ isMentioned := isMatrixBotMentioned(parsed.UserID, evt.Content)
+ replyTo := readReplyToEventID(evt.Content)
+ if replyTo != "" {
+ body = stripMatrixReplyFallback(body)
+ }
+ rawText := body
+ isReplyToBot := false
+ if replyTo != "" {
+ repliedEvent, err := a.fetchRoomEvent(ctx, parsed, evt.RoomID, replyTo)
+ if err != nil {
+ if a.logger != nil {
+ a.logger.Warn("failed to fetch matrix replied event",
+ slog.String("config_id", cfg.ID),
+ slog.String("room_id", evt.RoomID),
+ slog.String("reply_to", replyTo),
+ slog.Any("error", err),
+ )
+ }
+ } else {
+ if quotedText := buildMatrixQuotedText(repliedEvent); quotedText != "" {
+ if body != "" {
+ body = quotedText + "\n" + body
+ } else {
+ body = quotedText
+ }
+ }
+ if quotedAttachments := matrixQuotedAttachments(repliedEvent); len(quotedAttachments) > 0 {
+ attachments = append(attachments, quotedAttachments...)
+ }
+ isReplyToBot = strings.EqualFold(strings.TrimSpace(repliedEvent.Sender), parsed.UserID)
+ }
+ }
+ conversationType := a.resolveConversationType(ctx, cfg.ID, parsed, evt.RoomID)
+ msg := channel.InboundMessage{
+ Channel: Type,
+ BotID: cfg.BotID,
+ ReplyTarget: evt.RoomID,
+ Message: channel.Message{
+ ID: strings.TrimSpace(evt.EventID),
+ Format: channel.MessageFormatPlain,
+ Text: body,
+ Attachments: attachments,
+ },
+ Sender: channel.Identity{
+ SubjectID: strings.TrimSpace(evt.Sender),
+ DisplayName: matrixDisplayName(evt),
+ Attributes: map[string]string{
+ "user_id": strings.TrimSpace(evt.Sender),
+ "room_id": strings.TrimSpace(evt.RoomID),
+ },
+ },
+ Conversation: channel.Conversation{
+ ID: strings.TrimSpace(evt.RoomID),
+ Type: conversationType,
+ Metadata: map[string]any{
+ "room_id": strings.TrimSpace(evt.RoomID),
+ },
+ },
+ ReceivedAt: matrixEventTime(evt.OriginServerTS),
+ Source: "matrix",
+ Metadata: map[string]any{
+ "room_id": strings.TrimSpace(evt.RoomID),
+ "event_id": strings.TrimSpace(evt.EventID),
+ "sender": strings.TrimSpace(evt.Sender),
+ "msgtype": channel.ReadString(evt.Content, "msgtype"),
+ "raw_text": rawText,
+ "attachments": len(attachments),
+ "is_mentioned": isMentioned,
+ "is_reply_to_bot": isReplyToBot,
+ },
+ }
+ if replyTo != "" {
+ msg.Message.Reply = &channel.ReplyRef{Target: evt.RoomID, MessageID: replyTo}
+ }
+ if a.logger != nil {
+ a.logger.Info("inbound received",
+ slog.String("config_id", cfg.ID),
+ slog.String("room_id", evt.RoomID),
+ slog.String("sender", evt.Sender),
+ slog.Bool("is_mentioned", isMentioned),
+ slog.String("text", common.SummarizeText(body)),
+ )
+ }
+ return true, handler(ctx, cfg, msg)
+}
+
+func (a *MatrixAdapter) fetchRoomEvent(ctx context.Context, cfg Config, roomID, eventID string) (matrixEvent, error) {
+ path := fmt.Sprintf("/_matrix/client/v3/rooms/%s/event/%s", url.PathEscape(strings.TrimSpace(roomID)), url.PathEscape(strings.TrimSpace(eventID)))
+ var evt matrixEvent
+ if err := a.doJSON(ctx, cfg, http.MethodGet, path, nil, &evt); err != nil {
+ return matrixEvent{}, err
+ }
+ evt.RoomID = strings.TrimSpace(roomID)
+ return evt, nil
+}
+
+func (a *MatrixAdapter) resolveConversationType(ctx context.Context, configID string, cfg Config, roomID string) string {
+ if conversationType, ok := a.cachedRoomConversationType(configID, roomID); ok {
+ return conversationType
+ }
+ isDirect, err := a.isDirectRoom(ctx, cfg, roomID)
+ if err != nil {
+ if a.logger != nil {
+ a.logger.Warn("failed to resolve matrix room type",
+ slog.String("config_id", configID),
+ slog.String("room_id", strings.TrimSpace(roomID)),
+ slog.Any("error", err),
+ )
+ }
+ return "group"
+ }
+ conversationType := "group"
+ if isDirect {
+ conversationType = "direct"
+ }
+ a.rememberRoomConversationType(configID, roomID, conversationType)
+ return conversationType
+}
+
+func (a *MatrixAdapter) isDirectRoom(ctx context.Context, cfg Config, roomID string) (bool, error) {
+ path := fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", url.PathEscape(strings.TrimSpace(roomID)))
+ var resp matrixJoinedMembersResponse
+ if err := a.doJSON(ctx, cfg, http.MethodGet, path, nil, &resp); err != nil {
+ return false, err
+ }
+ return len(resp.Joined) == 2, nil
+}
+
+func (a *MatrixAdapter) rememberSyncResponseRoomTypes(configID string, cfg Config, resp matrixSyncResponse) {
+ a.rememberSyncDirectRooms(cfg, resp)
+ configID = strings.TrimSpace(configID)
+ if configID == "" {
+ return
+ }
+ directRooms := extractMatrixDirectRoomIDs(resp)
+ for roomID, joined := range resp.Rooms.Join {
+ roomID = strings.TrimSpace(roomID)
+ if roomID == "" {
+ continue
+ }
+ if _, ok := directRooms[roomID]; ok {
+ a.rememberRoomConversationType(configID, roomID, "direct")
+ continue
+ }
+ if conversationType := matrixConversationTypeFromSummary(joined.Summary); conversationType != "" {
+ a.rememberRoomConversationType(configID, roomID, conversationType)
+ }
+ }
+}
+
+func extractMatrixDirectRooms(resp matrixSyncResponse) map[string]string {
+ directRooms := make(map[string]string)
+ for _, evt := range resp.AccountData.Events {
+ if strings.TrimSpace(evt.Type) != "m.direct" {
+ continue
+ }
+ for userID, rawRoomIDs := range evt.Content {
+ userID = strings.TrimSpace(userID)
+ if userID == "" {
+ continue
+ }
+ for _, roomID := range matrixStringList(rawRoomIDs) {
+ roomID = strings.TrimSpace(roomID)
+ if roomID == "" {
+ continue
+ }
+ directRooms[userID] = roomID
+ break
+ }
+ }
+ }
+ return directRooms
+}
+
+func (a *MatrixAdapter) rememberSyncDirectRooms(cfg Config, resp matrixSyncResponse) {
+ for userID, roomID := range extractMatrixDirectRooms(resp) {
+ a.rememberDirectRoomForConfig(cfg, userID, roomID)
+ }
+}
+
+func extractMatrixDirectRoomIDs(resp matrixSyncResponse) map[string]struct{} {
+ directRooms := make(map[string]struct{})
+ for _, evt := range resp.AccountData.Events {
+ if strings.TrimSpace(evt.Type) != "m.direct" {
+ continue
+ }
+ for _, rawRoomIDs := range evt.Content {
+ for _, roomID := range matrixStringList(rawRoomIDs) {
+ roomID = strings.TrimSpace(roomID)
+ if roomID == "" {
+ continue
+ }
+ directRooms[roomID] = struct{}{}
+ }
+ }
+ }
+ return directRooms
+}
+
+func matrixConversationTypeFromSummary(summary matrixRoomSummary) string {
+ totalMembers := summary.JoinedMemberCount + summary.InvitedMemberCount
+ switch {
+ case totalMembers == 2:
+ return "direct"
+ case totalMembers > 2:
+ return "group"
+ default:
+ return ""
+ }
+}
+
+func matrixStringList(raw any) []string {
+ switch value := raw.(type) {
+ case []string:
+ result := make([]string, 0, len(value))
+ for _, item := range value {
+ trimmed := strings.TrimSpace(item)
+ if trimmed != "" {
+ result = append(result, trimmed)
+ }
+ }
+ return result
+ case []any:
+ result := make([]string, 0, len(value))
+ for _, item := range value {
+ text, ok := item.(string)
+ if !ok {
+ continue
+ }
+ trimmed := strings.TrimSpace(text)
+ if trimmed != "" {
+ result = append(result, trimmed)
+ }
+ }
+ return result
+ default:
+ return nil
+ }
+}
+
+func (a *MatrixAdapter) cachedRoomConversationType(configID, roomID string) (string, bool) {
+ a.roomTypeMu.Lock()
+ defer a.roomTypeMu.Unlock()
+ rooms, ok := a.roomTypes[strings.TrimSpace(configID)]
+ if !ok {
+ return "", false
+ }
+ conversationType, ok := rooms[strings.TrimSpace(roomID)]
+ if !ok || strings.TrimSpace(conversationType) == "" {
+ return "", false
+ }
+ return conversationType, true
+}
+
+func (a *MatrixAdapter) rememberRoomConversationType(configID, roomID, conversationType string) {
+ configID = strings.TrimSpace(configID)
+ roomID = strings.TrimSpace(roomID)
+ conversationType = strings.TrimSpace(conversationType)
+ if configID == "" || roomID == "" || conversationType == "" {
+ return
+ }
+ a.roomTypeMu.Lock()
+ defer a.roomTypeMu.Unlock()
+ rooms, ok := a.roomTypes[configID]
+ if !ok {
+ rooms = make(map[string]string)
+ a.roomTypes[configID] = rooms
+ }
+ rooms[roomID] = conversationType
+}
+
+func buildMatrixMessageContent(msg channel.Message, edit bool, originalEventID string) map[string]any {
+ formatted := formatMatrixMessage(msg)
+ body := formatted.Body
+ content := map[string]any{
+ "msgtype": "m.notice",
+ "body": body,
+ }
+ if formatted.HasHTML {
+ content["format"] = matrixHTMLFormat
+ content["formatted_body"] = formatted.FormattedBody
+ }
+ if msg.Reply != nil && strings.TrimSpace(msg.Reply.MessageID) != "" && !edit {
+ content["m.relates_to"] = map[string]any{
+ "m.in_reply_to": map[string]any{
+ "event_id": strings.TrimSpace(msg.Reply.MessageID),
+ },
+ }
+ }
+ if edit && strings.TrimSpace(originalEventID) != "" {
+ newContent := map[string]any{
+ "msgtype": "m.notice",
+ "body": body,
+ }
+ if formatted.HasHTML {
+ newContent["format"] = matrixHTMLFormat
+ newContent["formatted_body"] = formatted.FormattedBody
+ }
+ content["m.new_content"] = newContent
+ content["m.relates_to"] = map[string]any{
+ "rel_type": "m.replace",
+ "event_id": strings.TrimSpace(originalEventID),
+ }
+ content["body"] = "* " + body
+ }
+ return content
+}
+
+func buildMatrixMediaContent(msg channel.Message, att channel.Attachment, contentURI string) map[string]any {
+ body := matrixAttachmentBody(att)
+ content := map[string]any{
+ "msgtype": matrixAttachmentMsgType(att.Type),
+ "body": body,
+ "url": strings.TrimSpace(contentURI),
+ }
+ if filename := strings.TrimSpace(att.Name); filename != "" {
+ content["filename"] = filename
+ }
+ info := matrixAttachmentInfo(att)
+ if len(info) > 0 {
+ content["info"] = info
+ }
+ if msg.Reply != nil && strings.TrimSpace(msg.Reply.MessageID) != "" {
+ content["m.relates_to"] = map[string]any{
+ "m.in_reply_to": map[string]any{
+ "event_id": strings.TrimSpace(msg.Reply.MessageID),
+ },
+ }
+ }
+ return content
+}
+
+func isMatrixEditEvent(content map[string]any) bool {
+ if _, ok := content["m.new_content"]; ok {
+ return true
+ }
+ relatesTo, ok := content["m.relates_to"].(map[string]any)
+ if !ok {
+ return false
+ }
+ return strings.EqualFold(strings.TrimSpace(channel.ReadString(relatesTo, "rel_type")), "m.replace")
+}
+
+func readReplyToEventID(content map[string]any) string {
+ relatesTo, ok := content["m.relates_to"].(map[string]any)
+ if !ok {
+ return ""
+ }
+ inReplyTo, ok := relatesTo["m.in_reply_to"].(map[string]any)
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(channel.ReadString(inReplyTo, "event_id"))
+}
+
+func extractMatrixInboundContent(content map[string]any) (string, []channel.Attachment) {
+ msgType := strings.TrimSpace(channel.ReadString(content, "msgtype"))
+ if !isMatrixAttachmentMsgType(msgType) {
+ return strings.TrimSpace(channel.ReadString(content, "body")), nil
+ }
+ att, ok := matrixAttachmentFromContent(content, msgType)
+ if !ok {
+ return strings.TrimSpace(channel.ReadString(content, "body")), nil
+ }
+ return strings.TrimSpace(att.Caption), []channel.Attachment{att}
+}
+
+func matrixAttachmentFromContent(content map[string]any, msgType string) (channel.Attachment, bool) {
+ contentURI := strings.TrimSpace(channel.ReadString(content, "url"))
+ if contentURI == "" {
+ return channel.Attachment{}, false
+ }
+ info, _ := content["info"].(map[string]any)
+ body := strings.TrimSpace(channel.ReadString(content, "body"))
+ name := strings.TrimSpace(channel.ReadString(content, "filename"))
+ caption := ""
+ if name == "" {
+ name = body
+ } else if body != "" && !strings.EqualFold(body, name) {
+ caption = body
+ }
+ att := channel.Attachment{
+ Type: matrixAttachmentType(msgType),
+ PlatformKey: contentURI,
+ SourcePlatform: Type.String(),
+ Name: name,
+ Caption: caption,
+ Mime: strings.TrimSpace(channel.ReadString(info, "mimetype")),
+ Size: matrixMapInt64(info, "size"),
+ Width: matrixMapInt(info, "w"),
+ Height: matrixMapInt(info, "h"),
+ DurationMs: matrixMapInt64(info, "duration"),
+ }
+ return channel.NormalizeInboundChannelAttachment(att), true
+}
+
+func isMatrixAttachmentMsgType(msgType string) bool {
+ switch strings.TrimSpace(msgType) {
+ case "m.image", "m.file", "m.video", "m.audio":
+ return true
+ default:
+ return false
+ }
+}
+
+func matrixAttachmentType(msgType string) channel.AttachmentType {
+ switch strings.TrimSpace(msgType) {
+ case "m.image":
+ return channel.AttachmentImage
+ case "m.video":
+ return channel.AttachmentVideo
+ case "m.audio":
+ return channel.AttachmentAudio
+ default:
+ return channel.AttachmentFile
+ }
+}
+
+func buildMatrixQuotedText(replyTo matrixEvent) string {
+ senderName := matrixDisplayName(replyTo)
+ text, attachments := extractMatrixInboundContent(replyTo.Content)
+ text = strings.TrimSpace(text)
+ if text == "" && len(attachments) > 0 {
+ types := make([]string, 0, len(attachments))
+ for _, att := range attachments {
+ types = append(types, string(att.Type))
+ }
+ text = "[" + strings.Join(types, ", ") + "]"
+ }
+ if text == "" {
+ text = strings.TrimSpace(channel.ReadString(replyTo.Content, "body"))
+ }
+ if text == "" {
+ return ""
+ }
+ if len([]rune(text)) > matrixQuotedTextMaxLen {
+ text = string([]rune(text)[:matrixQuotedTextMaxLen]) + "..."
+ }
+ if senderName != "" {
+ return fmt.Sprintf("[Reply to %s: %s]", senderName, text)
+ }
+ return fmt.Sprintf("[Reply to: %s]", text)
+}
+
+func matrixQuotedAttachments(replyTo matrixEvent) []channel.Attachment {
+ _, attachments := extractMatrixInboundContent(replyTo.Content)
+ if len(attachments) == 0 {
+ return nil
+ }
+ return attachments
+}
+
+func stripMatrixReplyFallback(body string) string {
+ trimmed := strings.TrimSpace(body)
+ if trimmed == "" {
+ return ""
+ }
+ lines := strings.Split(strings.ReplaceAll(trimmed, "\r\n", "\n"), "\n")
+ idx := 0
+ sawQuote := false
+ for idx < len(lines) {
+ line := lines[idx]
+ if strings.HasPrefix(line, ">") {
+ sawQuote = true
+ idx++
+ continue
+ }
+ if sawQuote && strings.TrimSpace(line) == "" {
+ idx++
+ continue
+ }
+ break
+ }
+ if !sawQuote {
+ return trimmed
+ }
+ return strings.TrimSpace(strings.Join(lines[idx:], "\n"))
+}
+
+func matrixSinceTokenFromRouting(routing map[string]any) string {
+ if len(routing) == 0 {
+ return ""
+ }
+ state, ok := routing[matrixRoutingStateKey]
+ if !ok || state == nil {
+ return strings.TrimSpace(channel.ReadString(routing, "matrix_since_token", "since_token"))
+ }
+ switch value := state.(type) {
+ case map[string]any:
+ return strings.TrimSpace(channel.ReadString(value, "since_token", "sinceToken"))
+ case map[string]string:
+ return strings.TrimSpace(value["since_token"])
+ default:
+ return ""
+ }
+}
+
+func isMatrixBotMentioned(botUserID string, content map[string]any) bool {
+ botUserID = strings.TrimSpace(botUserID)
+ if botUserID == "" {
+ return false
+ }
+ if mentions, ok := content["m.mentions"].(map[string]any); ok {
+ if userIDs, ok := mentions["user_ids"].([]any); ok {
+ for _, item := range userIDs {
+ if strings.EqualFold(strings.TrimSpace(fmt.Sprint(item)), botUserID) {
+ return true
+ }
+ }
+ }
+ }
+ formatted := strings.TrimSpace(channel.ReadString(content, "formatted_body", "formattedBody"))
+ if formatted != "" {
+ matches := matrixMentionHrefPattern.FindAllStringSubmatch(formatted, -1)
+ for _, match := range matches {
+ if len(match) > 1 && strings.EqualFold(strings.TrimSpace(match[1]), botUserID) {
+ return true
+ }
+ }
+ }
+ body := strings.TrimSpace(channel.ReadString(content, "body"))
+ if body == "" {
+ return false
+ }
+ localpart := botUserID
+ if idx := strings.Index(localpart, ":"); idx > 0 {
+ localpart = localpart[:idx]
+ }
+ for _, candidate := range []string{botUserID, localpart} {
+ if matrixHasExactMentionToken(body, candidate) {
+ return true
+ }
+ }
+ return false
+}
+
+func matrixHasExactMentionToken(body, candidate string) bool {
+ body = strings.TrimSpace(body)
+ candidate = strings.TrimSpace(candidate)
+ if body == "" || candidate == "" {
+ return false
+ }
+ lowerBody := strings.ToLower(body)
+ lowerCandidate := strings.ToLower(candidate)
+ searchFrom := 0
+ for searchFrom < len(lowerBody) {
+ idx := strings.Index(lowerBody[searchFrom:], lowerCandidate)
+ if idx < 0 {
+ return false
+ }
+ start := searchFrom + idx
+ end := start + len(lowerCandidate)
+ if matrixMentionBoundaryBefore(body, start) && matrixMentionBoundaryAfter(body, end) {
+ return true
+ }
+ searchFrom = start + len(lowerCandidate)
+ }
+ return false
+}
+
+func matrixMentionBoundaryBefore(body string, idx int) bool {
+ if idx <= 0 {
+ return true
+ }
+ r, _ := utf8.DecodeLastRuneInString(body[:idx])
+ return matrixMentionBoundaryRune(r, true)
+}
+
+func matrixMentionBoundaryAfter(body string, idx int) bool {
+ if idx >= len(body) {
+ return true
+ }
+ r, _ := utf8.DecodeRuneInString(body[idx:])
+ return matrixMentionBoundaryRune(r, false)
+}
+
+func matrixMentionBoundaryRune(r rune, before bool) bool {
+ if unicode.IsSpace(r) {
+ return true
+ }
+ switch r {
+ case '(', '[', '{', '<', '>', ',', ';', '.', '!', '?', '\'', '"', '`':
+ return true
+ case ')', ']', '}':
+ return !before
+ default:
+ return false
+ }
+}
+
+func matrixAttachmentMsgType(attType channel.AttachmentType) string {
+ switch attType {
+ case channel.AttachmentImage, channel.AttachmentGIF:
+ return "m.image"
+ case channel.AttachmentVideo:
+ return "m.video"
+ case channel.AttachmentAudio, channel.AttachmentVoice:
+ return "m.audio"
+ default:
+ return "m.file"
+ }
+}
+
+func matrixAttachmentBody(att channel.Attachment) string {
+ if caption := strings.TrimSpace(att.Caption); caption != "" {
+ return caption
+ }
+ if name := strings.TrimSpace(att.Name); name != "" {
+ return name
+ }
+ switch att.Type {
+ case channel.AttachmentImage, channel.AttachmentGIF:
+ return "image"
+ case channel.AttachmentVideo:
+ return "video"
+ case channel.AttachmentAudio, channel.AttachmentVoice:
+ return "audio"
+ default:
+ return "file"
+ }
+}
+
+func matrixAttachmentInfo(att channel.Attachment) map[string]any {
+ info := map[string]any{}
+ if mime := strings.TrimSpace(att.Mime); mime != "" {
+ info["mimetype"] = mime
+ }
+ if att.Size > 0 {
+ info["size"] = att.Size
+ }
+ if att.Width > 0 {
+ info["w"] = att.Width
+ }
+ if att.Height > 0 {
+ info["h"] = att.Height
+ }
+ if att.DurationMs > 0 {
+ info["duration"] = att.DurationMs
+ }
+ return info
+}
+
+func matrixMapInt64(raw map[string]any, key string) int64 {
+ if raw == nil {
+ return 0
+ }
+ value, ok := raw[key]
+ if !ok {
+ return 0
+ }
+ switch v := value.(type) {
+ case int:
+ return int64(v)
+ case int32:
+ return int64(v)
+ case int64:
+ return v
+ case float64:
+ return int64(v)
+ case json.Number:
+ parsed, err := v.Int64()
+ if err == nil {
+ return parsed
+ }
+ }
+ return 0
+}
+
+func matrixMapInt(raw map[string]any, key string) int {
+ return int(matrixMapInt64(raw, key))
+}
+
+func (a *MatrixAdapter) sendTextEvent(ctx context.Context, cfg Config, roomID string, content map[string]any) (string, error) {
+ txnID := a.nextTxnID()
+ path := fmt.Sprintf("/_matrix/client/v3/rooms/%s/send/m.room.message/%s", url.PathEscape(roomID), url.PathEscape(txnID))
+ var resp matrixSendResponse
+ if err := a.doJSON(ctx, cfg, http.MethodPut, path, content, &resp); err != nil {
+ return "", err
+ }
+ return strings.TrimSpace(resp.EventID), nil
+}
+
+func (a *MatrixAdapter) sendMediaAttachment(ctx context.Context, cfg Config, roomID string, fallbackBotID string, msg channel.Message, att channel.Attachment) error {
+ contentURI, resolved, err := a.resolveMatrixContentURI(ctx, cfg, fallbackBotID, att)
+ if err != nil {
+ return err
+ }
+ _, err = a.sendTextEvent(ctx, cfg, roomID, buildMatrixMediaContent(msg, resolved, contentURI))
+ return err
+}
+
+func (a *MatrixAdapter) resolveMatrixContentURI(ctx context.Context, cfg Config, fallbackBotID string, att channel.Attachment) (string, channel.Attachment, error) {
+ if ref := strings.TrimSpace(att.PlatformKey); isMatrixContentURI(ref) {
+ resolved := att
+ if resolved.SourcePlatform == "" {
+ resolved.SourcePlatform = Type.String()
+ }
+ return ref, resolved, nil
+ }
+ if ref := strings.TrimSpace(att.URL); isMatrixContentURI(ref) {
+ resolved := att
+ if resolved.SourcePlatform == "" {
+ resolved.SourcePlatform = Type.String()
+ }
+ return ref, resolved, nil
+ }
+ payload, resolved, err := a.prepareMatrixUpload(ctx, fallbackBotID, att)
+ if err != nil {
+ return "", channel.Attachment{}, err
+ }
+ contentURI, err := a.uploadMatrixMedia(ctx, cfg, payload.data, payload.mime, payload.name)
+ if err != nil {
+ return "", channel.Attachment{}, err
+ }
+ resolved.PlatformKey = contentURI
+ resolved.SourcePlatform = Type.String()
+ if resolved.Size <= 0 {
+ resolved.Size = int64(len(payload.data))
+ }
+ return contentURI, resolved, nil
+}
+
+type matrixUploadPayload struct {
+ data []byte
+ mime string
+ name string
+}
+
+func (a *MatrixAdapter) prepareMatrixUpload(ctx context.Context, fallbackBotID string, att channel.Attachment) (matrixUploadPayload, channel.Attachment, error) {
+ resolved := att
+ assetID := strings.TrimSpace(att.ContentHash)
+ botID := strings.TrimSpace(fallbackBotID)
+ if att.Metadata != nil {
+ if value, ok := att.Metadata["bot_id"].(string); ok && strings.TrimSpace(value) != "" {
+ botID = strings.TrimSpace(value)
+ }
+ }
+ if assetID != "" && a.assets != nil && botID != "" {
+ reader, asset, err := a.assets.Open(ctx, botID, assetID)
+ if err == nil {
+ defer func() { _ = reader.Close() }()
+ data, readErr := media.ReadAllWithLimit(reader, media.MaxAssetBytes)
+ if readErr != nil {
+ return matrixUploadPayload{}, channel.Attachment{}, readErr
+ }
+ if strings.TrimSpace(resolved.Mime) == "" {
+ resolved.Mime = strings.TrimSpace(asset.Mime)
+ }
+ if resolved.Size <= 0 {
+ resolved.Size = asset.SizeBytes
+ }
+ name := deriveMatrixUploadName(resolved, resolved.Mime, "")
+ return matrixUploadPayload{data: data, mime: strings.TrimSpace(resolved.Mime), name: name}, resolved, nil
+ }
+ }
+
+ rawBase64 := strings.TrimSpace(att.Base64)
+ refURL := strings.TrimSpace(att.URL)
+ if rawBase64 == "" && strings.HasPrefix(strings.ToLower(refURL), "data:") {
+ rawBase64 = refURL
+ }
+ if rawBase64 != "" {
+ decoded, err := attachmentpkg.DecodeBase64(rawBase64, media.MaxAssetBytes)
+ if err != nil {
+ return matrixUploadPayload{}, channel.Attachment{}, fmt.Errorf("decode matrix attachment base64: %w", err)
+ }
+ data, err := media.ReadAllWithLimit(decoded, media.MaxAssetBytes)
+ if err != nil {
+ return matrixUploadPayload{}, channel.Attachment{}, fmt.Errorf("read matrix attachment base64: %w", err)
+ }
+ if strings.TrimSpace(resolved.Mime) == "" {
+ resolved.Mime = strings.TrimSpace(attachmentpkg.MimeFromDataURL(rawBase64))
+ }
+ if resolved.Size <= 0 {
+ resolved.Size = int64(len(data))
+ }
+ name := deriveMatrixUploadName(resolved, resolved.Mime, "")
+ return matrixUploadPayload{data: data, mime: strings.TrimSpace(resolved.Mime), name: name}, resolved, nil
+ }
+
+ if refURL == "" {
+ return matrixUploadPayload{}, channel.Attachment{}, errors.New("matrix attachment requires content_hash, base64, mxc url, or http(s) url")
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, refURL, nil)
+ if err != nil {
+ return matrixUploadPayload{}, channel.Attachment{}, fmt.Errorf("build matrix attachment download request: %w", err)
+ }
+ resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(req) //nolint:gosec // URL is a user-provided or cross-platform attachment reference.
+ if err != nil {
+ return matrixUploadPayload{}, channel.Attachment{}, fmt.Errorf("download matrix attachment: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return matrixUploadPayload{}, channel.Attachment{}, fmt.Errorf("download matrix attachment status: %d", resp.StatusCode)
+ }
+ if resp.ContentLength > media.MaxAssetBytes {
+ return matrixUploadPayload{}, channel.Attachment{}, fmt.Errorf("%w: max %d bytes", media.ErrAssetTooLarge, media.MaxAssetBytes)
+ }
+ data, err := media.ReadAllWithLimit(resp.Body, media.MaxAssetBytes)
+ if err != nil {
+ return matrixUploadPayload{}, channel.Attachment{}, err
+ }
+ if strings.TrimSpace(resolved.Mime) == "" {
+ resolved.Mime = strings.TrimSpace(resp.Header.Get("Content-Type"))
+ resolved.Mime = attachmentpkg.NormalizeMime(resolved.Mime)
+ }
+ if resolved.Size <= 0 {
+ if resp.ContentLength > 0 {
+ resolved.Size = resp.ContentLength
+ } else {
+ resolved.Size = int64(len(data))
+ }
+ }
+ name := deriveMatrixUploadName(resolved, resolved.Mime, refURL)
+ return matrixUploadPayload{data: data, mime: strings.TrimSpace(resolved.Mime), name: name}, resolved, nil
+}
+
+func deriveMatrixUploadName(att channel.Attachment, mime, refURL string) string {
+ if name := strings.TrimSpace(att.Name); name != "" {
+ return name
+ }
+ if refURL != "" {
+ if parsed, err := url.Parse(refURL); err == nil {
+ if base := strings.TrimSpace(pathpkg.Base(parsed.Path)); base != "" && base != "." && base != "/" {
+ return base
+ }
+ }
+ }
+ return matrixAttachmentBody(channel.Attachment{Type: att.Type, Mime: mime, Caption: att.Caption})
+}
+
+func (a *MatrixAdapter) uploadMatrixMedia(ctx context.Context, cfg Config, data []byte, mime, filename string) (string, error) {
+ query := url.Values{}
+ if strings.TrimSpace(filename) != "" {
+ query.Set("filename", strings.TrimSpace(filename))
+ }
+ path := "/_matrix/media/v3/upload"
+ if encoded := query.Encode(); encoded != "" {
+ path += "?" + encoded
+ }
+ body := bytes.NewReader(data)
+ payload, _, err := a.doRequest(ctx, cfg, http.MethodPost, path, body, firstNonEmpty(strings.TrimSpace(mime), "application/octet-stream"))
+ if err != nil {
+ return "", err
+ }
+ var resp matrixUploadResponse
+ if err := json.Unmarshal(payload, &resp); err != nil {
+ return "", err
+ }
+ contentURI := strings.TrimSpace(resp.ContentURI)
+ if contentURI == "" {
+ return "", errors.New("matrix upload returned empty content_uri")
+ }
+ return contentURI, nil
+}
+
+func firstNonEmpty(values ...string) string {
+ for _, value := range values {
+ if strings.TrimSpace(value) != "" {
+ return strings.TrimSpace(value)
+ }
+ }
+ return ""
+}
+
+func isMatrixContentURI(ref string) bool {
+ return strings.HasPrefix(strings.ToLower(strings.TrimSpace(ref)), "mxc://")
+}
+
+func parseMatrixContentURI(ref string) (string, string, bool) {
+ trimmed := strings.TrimSpace(ref)
+ if !isMatrixContentURI(trimmed) {
+ return "", "", false
+ }
+ withoutScheme := strings.TrimPrefix(trimmed, "mxc://")
+ server, mediaID, ok := strings.Cut(withoutScheme, "/")
+ if !ok || strings.TrimSpace(server) == "" || strings.TrimSpace(mediaID) == "" {
+ return "", "", false
+ }
+ return strings.TrimSpace(server), strings.TrimSpace(mediaID), true
+}
+
+func (a *MatrixAdapter) resolveRoomTarget(ctx context.Context, cfg Config, target string) (string, error) {
+ target = normalizeTarget(target)
+ if err := validateTarget(target); err != nil {
+ return "", err
+ }
+ if strings.HasPrefix(target, "@") {
+ return a.ensureDirectRoom(ctx, cfg, target)
+ }
+ if strings.HasPrefix(target, "#") {
+ return a.resolveRoomAlias(ctx, cfg, target)
+ }
+ return target, nil
+}
+
+func (a *MatrixAdapter) resolveRoomAlias(ctx context.Context, cfg Config, roomAlias string) (string, error) {
+ path := fmt.Sprintf("/_matrix/client/v3/directory/room/%s", url.PathEscape(strings.TrimSpace(roomAlias)))
+ var resp matrixRoomAliasResponse
+ if err := a.doJSON(ctx, cfg, http.MethodGet, path, nil, &resp); err != nil {
+ return "", err
+ }
+ if strings.TrimSpace(resp.RoomID) == "" {
+ return "", fmt.Errorf("matrix room alias lookup returned empty room_id: %s", roomAlias)
+ }
+ return strings.TrimSpace(resp.RoomID), nil
+}
+
+func (a *MatrixAdapter) ensureDirectRoom(ctx context.Context, cfg Config, userID string) (string, error) {
+ userID = strings.TrimSpace(userID)
+ if roomID, ok := a.cachedDirectRoom(cfg, userID); ok {
+ return roomID, nil
+ }
+ if roomID, err := a.findExistingDirectRoom(ctx, cfg, userID); err == nil {
+ if roomID != "" {
+ a.rememberDirectRoomForConfig(cfg, userID, roomID)
+ return roomID, nil
+ }
+ } else if a.logger != nil {
+ a.logger.Warn("matrix direct room lookup failed",
+ slog.String("user_id", userID),
+ slog.Any("error", err),
+ )
+ }
+ req := matrixCreateRoomRequest{
+ Invite: []string{userID},
+ IsDirect: true,
+ Preset: "trusted_private_chat",
+ }
+ var resp matrixCreateRoomResponse
+ if err := a.doJSON(ctx, cfg, http.MethodPost, "/_matrix/client/v3/createRoom", req, &resp); err != nil {
+ return "", err
+ }
+ if strings.TrimSpace(resp.RoomID) == "" {
+ return "", errors.New("matrix createRoom returned empty room_id")
+ }
+ roomID := strings.TrimSpace(resp.RoomID)
+ a.rememberDirectRoomForConfig(cfg, userID, roomID)
+ return roomID, nil
+}
+
+func (a *MatrixAdapter) findExistingDirectRoom(ctx context.Context, cfg Config, userID string) (string, error) {
+ var resp matrixJoinedRoomsResponse
+ if err := a.doJSON(ctx, cfg, http.MethodGet, "/_matrix/client/v3/joined_rooms", nil, &resp); err != nil {
+ return "", err
+ }
+ for _, roomID := range resp.JoinedRooms {
+ matched, err := a.isDirectRoomForUser(ctx, cfg, roomID, userID)
+ if err != nil {
+ if a.logger != nil {
+ a.logger.Warn("matrix direct room candidate lookup failed",
+ slog.String("room_id", strings.TrimSpace(roomID)),
+ slog.String("user_id", strings.TrimSpace(userID)),
+ slog.Any("error", err),
+ )
+ }
+ continue
+ }
+ if matched {
+ return strings.TrimSpace(roomID), nil
+ }
+ }
+ return "", nil
+}
+
+func (a *MatrixAdapter) isDirectRoomForUser(ctx context.Context, cfg Config, roomID string, userID string) (bool, error) {
+ path := fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", url.PathEscape(strings.TrimSpace(roomID)))
+ var resp matrixJoinedMembersResponse
+ if err := a.doJSON(ctx, cfg, http.MethodGet, path, nil, &resp); err != nil {
+ return false, err
+ }
+ if len(resp.Joined) != 2 {
+ return false, nil
+ }
+ if _, ok := resp.Joined[strings.TrimSpace(userID)]; !ok {
+ return false, nil
+ }
+ if _, ok := resp.Joined[strings.TrimSpace(cfg.UserID)]; !ok {
+ return false, nil
+ }
+ return true, nil
+}
+
+func directRoomCacheKey(cfg Config) string {
+ return strings.TrimSpace(cfg.HomeserverURL) + "|" + strings.TrimSpace(cfg.UserID)
+}
+
+func (a *MatrixAdapter) cachedDirectRoom(cfg Config, userID string) (string, bool) {
+ if a == nil {
+ return "", false
+ }
+ cacheKey := directRoomCacheKey(cfg)
+ userID = strings.TrimSpace(userID)
+ if cacheKey == "" || userID == "" {
+ return "", false
+ }
+ a.directRoomMu.Lock()
+ defer a.directRoomMu.Unlock()
+ rooms, ok := a.directRooms[cacheKey]
+ if !ok {
+ return "", false
+ }
+ roomID, ok := rooms[userID]
+ if !ok || strings.TrimSpace(roomID) == "" {
+ return "", false
+ }
+ return roomID, true
+}
+
+func (a *MatrixAdapter) rememberDirectRoomForConfig(cfg Config, userID, roomID string) {
+ a.rememberDirectRoom(directRoomCacheKey(cfg), userID, roomID)
+}
+
+func (a *MatrixAdapter) rememberDirectRoom(cacheKey, userID, roomID string) {
+ if a == nil {
+ return
+ }
+ cacheKey = strings.TrimSpace(cacheKey)
+ userID = strings.TrimSpace(userID)
+ roomID = strings.TrimSpace(roomID)
+ if cacheKey == "" || userID == "" || roomID == "" {
+ return
+ }
+ a.directRoomMu.Lock()
+ defer a.directRoomMu.Unlock()
+ rooms, ok := a.directRooms[cacheKey]
+ if !ok {
+ rooms = make(map[string]string)
+ a.directRooms[cacheKey] = rooms
+ }
+ rooms[userID] = roomID
+}
+
+func (a *MatrixAdapter) joinRoom(ctx context.Context, cfg Config, roomID string) error {
+ path := fmt.Sprintf("/_matrix/client/v3/join/%s", url.PathEscape(strings.TrimSpace(roomID)))
+ return a.doJSON(ctx, cfg, http.MethodPost, path, nil, nil)
+}
+
+func (a *MatrixAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) {
+ contentURI := strings.TrimSpace(attachment.PlatformKey)
+ if contentURI == "" {
+ contentURI = strings.TrimSpace(attachment.URL)
+ }
+ if contentURI == "" {
+ return channel.AttachmentPayload{}, errors.New("matrix attachment requires platform_key or url")
+ }
+ if !isMatrixContentURI(contentURI) {
+ return channel.AttachmentPayload{}, errors.New("matrix attachment reference must be mxc://")
+ }
+ parsed, err := parseConfig(cfg.Credentials)
+ if err != nil {
+ return channel.AttachmentPayload{}, err
+ }
+ serverName, mediaID, ok := parseMatrixContentURI(contentURI)
+ if !ok {
+ return channel.AttachmentPayload{}, errors.New("invalid matrix content uri")
+ }
+ body, header, contentLength, err := a.downloadMatrixMedia(ctx, parsed, serverName, mediaID, strings.TrimSpace(attachment.Name))
+ if err != nil {
+ return channel.AttachmentPayload{}, err
+ }
+ mime := strings.TrimSpace(attachment.Mime)
+ if mime == "" {
+ mime = attachmentpkg.NormalizeMime(header.Get("Content-Type"))
+ }
+ size := attachment.Size
+ if size <= 0 && contentLength > 0 {
+ size = contentLength
+ }
+ return channel.AttachmentPayload{
+ Reader: body,
+ Mime: mime,
+ Name: strings.TrimSpace(attachment.Name),
+ Size: size,
+ }, nil
+}
+
+func (a *MatrixAdapter) downloadMatrixMedia(ctx context.Context, cfg Config, serverName, mediaID, fileName string) (io.ReadCloser, http.Header, int64, error) {
+ paths := make([]string, 0, 3)
+ serverName = url.PathEscape(strings.TrimSpace(serverName))
+ mediaID = url.PathEscape(strings.TrimSpace(mediaID))
+ trimmedFileName := strings.TrimSpace(fileName)
+ if trimmedFileName != "" {
+ paths = append(paths, fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s/%s", serverName, mediaID, url.PathEscape(trimmedFileName)))
+ }
+ paths = append(paths,
+ fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s", serverName, mediaID),
+ fmt.Sprintf("/_matrix/media/v3/download/%s/%s", serverName, mediaID),
+ )
+
+ var lastErr error
+ for _, path := range paths {
+ request, err := http.NewRequestWithContext(ctx, http.MethodGet, cfg.HomeserverURL+path, nil)
+ if err != nil {
+ return nil, nil, 0, err
+ }
+ request.Header.Set("Authorization", "Bearer "+cfg.AccessToken)
+ resp, err := a.httpClient.Do(request) //nolint:gosec // G704: URL is derived from operator-configured Matrix homeserver
+ if err != nil {
+ lastErr = fmt.Errorf("download matrix attachment: %w", err)
+ continue
+ }
+ if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices {
+ return resp.Body, resp.Header.Clone(), resp.ContentLength, nil
+ }
+ data, _ := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ message := strings.TrimSpace(string(data))
+ if message == "" {
+ message = resp.Status
+ }
+ lastErr = fmt.Errorf("download matrix attachment failed: %s", textutil.TruncateRunes(message, 300))
+ if resp.StatusCode != http.StatusNotFound {
+ return nil, nil, 0, lastErr
+ }
+ }
+ if lastErr == nil {
+ lastErr = errors.New("download matrix attachment failed")
+ }
+ return nil, nil, 0, lastErr
+}
+
+func (a *MatrixAdapter) doJSON(ctx context.Context, cfg Config, method, path string, reqBody any, respBody any) error {
+ var body io.Reader
+ contentType := ""
+ if reqBody != nil {
+ payload, err := json.Marshal(reqBody)
+ if err != nil {
+ return err
+ }
+ body = bytes.NewReader(payload)
+ contentType = "application/json"
+ }
+ data, _, err := a.doRequest(ctx, cfg, method, path, body, contentType)
+ if err != nil {
+ return err
+ }
+ if respBody == nil || len(data) == 0 {
+ return nil
+ }
+ return json.Unmarshal(data, respBody)
+}
+
+func (a *MatrixAdapter) doRequest(ctx context.Context, cfg Config, method, path string, body io.Reader, contentType string) ([]byte, http.Header, error) {
+ data, header, statusCode, err := a.performRequest(ctx, method, cfg.HomeserverURL+path, body, contentType, cfg.AccessToken)
+ if err != nil {
+ return nil, nil, err
+ }
+ if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
+ return nil, header, fmt.Errorf("matrix %s %s failed: %s", method, path, matrixHTTPErrorSummary(statusCode, data))
+ }
+ return data, header, nil
+}
+
+func (a *MatrixAdapter) performRequest(ctx context.Context, method string, requestURL string, body io.Reader, contentType string, accessToken string) ([]byte, http.Header, int, error) {
+ request, err := http.NewRequestWithContext(ctx, method, requestURL, body)
+ if err != nil {
+ return nil, nil, 0, err
+ }
+ if strings.TrimSpace(accessToken) != "" {
+ request.Header.Set("Authorization", "Bearer "+strings.TrimSpace(accessToken))
+ }
+ if strings.TrimSpace(contentType) != "" {
+ request.Header.Set("Content-Type", strings.TrimSpace(contentType))
+ }
+ resp, err := a.httpClient.Do(request) //nolint:gosec // G704: URL is derived from operator-configured Matrix homeserver
+ if err != nil {
+ return nil, nil, 0, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ data, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, resp.Header.Clone(), resp.StatusCode, err
+ }
+ return data, resp.Header.Clone(), resp.StatusCode, nil
+}
+
+func (a *MatrixAdapter) nextTxnID() string {
+ a.txnMu.Lock()
+ defer a.txnMu.Unlock()
+ a.txnID++
+ rnd, err := cryptorand.Int(cryptorand.Reader, big.NewInt(10000))
+ if err != nil {
+ return fmt.Sprintf("memoh-%d-%d", time.Now().UnixMilli(), a.txnID)
+ }
+ return fmt.Sprintf("memoh-%d-%d-%04d", time.Now().UnixMilli(), a.txnID, rnd.Int64())
+}
+
+func (a *MatrixAdapter) seenEvent(configID, eventID string) bool {
+ configID = strings.TrimSpace(configID)
+ eventID = strings.TrimSpace(eventID)
+ if configID == "" || eventID == "" {
+ return false
+ }
+ now := time.Now()
+ a.seenMu.Lock()
+ defer a.seenMu.Unlock()
+ byConfig := a.seen[configID]
+ if byConfig == nil {
+ byConfig = make(map[string]time.Time)
+ a.seen[configID] = byConfig
+ }
+ for id, seenAt := range byConfig {
+ if now.Sub(seenAt) > 10*time.Minute {
+ delete(byConfig, id)
+ }
+ }
+ if _, ok := byConfig[eventID]; ok {
+ return true
+ }
+ byConfig[eventID] = now
+ return false
+}
+
+func matrixDisplayName(evt matrixEvent) string {
+ unsignedSender, ok := evt.Unsigned["m.relations"].(map[string]any)
+ if ok {
+ _ = unsignedSender
+ }
+ if displayName := strings.TrimSpace(channel.ReadString(evt.Unsigned, "displayname", "sender_display_name")); displayName != "" {
+ return displayName
+ }
+ return strings.TrimSpace(evt.Sender)
+}
+
+func matrixEventTime(ts int64) time.Time {
+ if ts <= 0 {
+ return time.Now().UTC()
+ }
+ return time.UnixMilli(ts).UTC()
+}
+
+func sleepContext(ctx context.Context, delay time.Duration) bool {
+ if delay <= 0 {
+ return ctx.Err() == nil
+ }
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return false
+ case <-timer.C:
+ return true
+ }
+}
+
+func nextReconnectDelay(backoffs []time.Duration, attempt int, healthySession bool) (time.Duration, int) {
+ if healthySession {
+ attempt = 0
+ }
+ if len(backoffs) == 0 {
+ return time.Second, attempt + 1
+ }
+ if attempt < 0 {
+ attempt = 0
+ }
+ if attempt >= len(backoffs) {
+ attempt = len(backoffs) - 1
+ }
+ delay := backoffs[attempt]
+ if attempt < len(backoffs)-1 {
+ attempt++
+ }
+ return delay, attempt
+}
diff --git a/internal/channel/adapters/matrix/matrix_test.go b/internal/channel/adapters/matrix/matrix_test.go
new file mode 100644
index 00000000..76060e65
--- /dev/null
+++ b/internal/channel/adapters/matrix/matrix_test.go
@@ -0,0 +1,1018 @@
+package matrix
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/memohai/memoh/internal/channel"
+)
+
+func TestIsMatrixBotMentionedByMentionsMetadata(t *testing.T) {
+ content := map[string]any{
+ "body": "hi bot",
+ "m.mentions": map[string]any{
+ "user_ids": []any{"@memoh:example.com"},
+ },
+ }
+ if !isMatrixBotMentioned("@memoh:example.com", content) {
+ t.Fatal("expected mention metadata to be detected")
+ }
+}
+
+func TestIsMatrixBotMentionedByFormattedBody(t *testing.T) {
+ content := map[string]any{
+ "body": "hello Memoh",
+ "formatted_body": `Memoh hello`,
+ }
+ if !isMatrixBotMentioned("@memoh:example.com", content) {
+ t.Fatal("expected formatted body mention to be detected")
+ }
+}
+
+func TestIsMatrixBotMentionedByBodyFallback(t *testing.T) {
+ content := map[string]any{
+ "body": "@memoh:example.com ping",
+ }
+ if !isMatrixBotMentioned("@memoh:example.com", content) {
+ t.Fatal("expected body fallback mention to be detected")
+ }
+}
+
+func TestIsMatrixBotMentionedByLocalpartBodyFallback(t *testing.T) {
+ content := map[string]any{
+ "body": "@memoh ping",
+ }
+ if !isMatrixBotMentioned("@memoh:example.com", content) {
+ t.Fatal("expected localpart body fallback mention to be detected")
+ }
+}
+
+func TestIsMatrixBotMentionedDoesNotMatchSubstring(t *testing.T) {
+ content := map[string]any{
+ "body": "@memoh-helper:example.com ping",
+ }
+ if isMatrixBotMentioned("@memoh:example.com", content) {
+ t.Fatal("expected substring match not to count as mention")
+ }
+}
+
+func TestIsMatrixBotMentionedDoesNotMatchPlainMatrixURL(t *testing.T) {
+ content := map[string]any{
+ "body": "see https://matrix.to/#/@memoh:example.com",
+ }
+ if isMatrixBotMentioned("@memoh:example.com", content) {
+ t.Fatal("expected plain Matrix URL not to count as mention")
+ }
+}
+
+func TestMatrixSinceTokenFromRouting(t *testing.T) {
+ routing := map[string]any{
+ matrixRoutingStateKey: map[string]any{"since_token": "s123"},
+ }
+ if got := matrixSinceTokenFromRouting(routing); got != "s123" {
+ t.Fatalf("unexpected since token: %q", got)
+ }
+}
+
+func TestPersistSinceTokenUsesConfiguredSaver(t *testing.T) {
+ var gotConfigID string
+ var gotSince string
+ adapter := NewMatrixAdapter(nil)
+ adapter.SetSyncStateSaver(func(_ context.Context, configID string, since string) error {
+ gotConfigID = configID
+ gotSince = since
+ return nil
+ })
+ if err := adapter.persistSinceToken(context.Background(), "cfg-1", "token-1"); err != nil {
+ t.Fatalf("persistSinceToken returned error: %v", err)
+ }
+ if gotConfigID != "cfg-1" || gotSince != "token-1" {
+ t.Fatalf("unexpected saver args: %q %q", gotConfigID, gotSince)
+ }
+}
+
+func TestBootstrapSinceTokenPersistsLatestCursor(t *testing.T) {
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"next_batch":"s123","rooms":{"join":{"!room:example.com":{"timeline":{"events":[{"event_id":"$evt1"}]}}}}}`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+ var gotConfigID string
+ var gotSince string
+ adapter.SetSyncStateSaver(func(_ context.Context, configID string, since string) error {
+ gotConfigID = configID
+ gotSince = since
+ return nil
+ })
+
+ since, err := adapter.bootstrapSinceToken(context.Background(), channel.ChannelConfig{ID: "cfg-1"}, Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ AutoJoinInvites: true,
+ })
+ if err != nil {
+ t.Fatalf("bootstrapSinceToken returned error: %v", err)
+ }
+ if since != "s123" {
+ t.Fatalf("unexpected since token: %q", since)
+ }
+ if gotConfigID != "cfg-1" || gotSince != "s123" {
+ t.Fatalf("unexpected persisted cursor: %q %q", gotConfigID, gotSince)
+ }
+ if !adapter.seenEvent("cfg-1", "$evt1") {
+ t.Fatal("expected bootstrap event to be remembered as seen")
+ }
+}
+
+func TestBootstrapSinceTokenAutoJoinsInvitedRooms(t *testing.T) {
+ joinRequests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ switch req.URL.Path {
+ case "/_matrix/client/v3/sync":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"next_batch":"s123","rooms":{"invite":{"!room:example.com":{"invite_state":{"events":[{"type":"m.room.member"}]}}}}}`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/join/!room:example.com":
+ joinRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{}`)),
+ Header: make(http.Header),
+ }, nil
+ default:
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ return nil, nil
+ }
+ })}
+
+ since, err := adapter.bootstrapSinceToken(context.Background(), channel.ChannelConfig{ID: "cfg-1"}, Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ AutoJoinInvites: true,
+ })
+ if err != nil {
+ t.Fatalf("bootstrapSinceToken returned error: %v", err)
+ }
+ if since != "s123" {
+ t.Fatalf("unexpected since token: %q", since)
+ }
+ if joinRequests != 1 {
+ t.Fatalf("expected invited room to be auto-joined once, got %d", joinRequests)
+ }
+}
+
+func TestValidateConnectionChecksHomeserverVersions(t *testing.T) {
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ if req.URL.Path != "/_matrix/client/versions" {
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ }
+ return &http.Response{
+ StatusCode: http.StatusNotFound,
+ Body: io.NopCloser(strings.NewReader("not found")),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ err := adapter.validateConnection(context.Background(), Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ UserID: "@memoh:example.com",
+ })
+ if err == nil {
+ t.Fatal("expected homeserver validation to fail")
+ }
+ if !strings.Contains(err.Error(), "homeserver check failed") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestValidateConnectionRejectsTokenUserMismatch(t *testing.T) {
+ requests := make([]string, 0, 2)
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ requests = append(requests, req.URL.Path)
+ switch req.URL.Path {
+ case "/_matrix/client/versions":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"versions":["v1.11"]}`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/account/whoami":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"user_id":"@alice:example.com"}`)),
+ Header: make(http.Header),
+ }, nil
+ default:
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ return nil, nil
+ }
+ })}
+
+ err := adapter.validateConnection(context.Background(), Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ UserID: "@memoh:example.com",
+ })
+ if err == nil {
+ t.Fatal("expected token mismatch validation to fail")
+ }
+ if !strings.Contains(err.Error(), "token belongs to @alice:example.com, expected @memoh:example.com") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(requests) != 2 {
+ t.Fatalf("expected homeserver and whoami checks, got %d requests", len(requests))
+ }
+}
+
+func TestValidateConnectionSkipsSyncProbe(t *testing.T) {
+ requests := make([]string, 0, 3)
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ requests = append(requests, req.URL.RequestURI())
+ switch req.URL.Path {
+ case "/_matrix/client/versions":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"versions":["v1.11"]}`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/account/whoami":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"user_id":"@memoh:example.com"}`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/sync":
+ t.Fatal("did not expect /sync probe during connection validation")
+ return nil, nil
+ default:
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ return nil, nil
+ }
+ })}
+
+ err := adapter.validateConnection(context.Background(), Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ UserID: "@memoh:example.com",
+ })
+ if err != nil {
+ t.Fatalf("validateConnection returned error: %v", err)
+ }
+ if len(requests) != 2 {
+ t.Fatalf("expected homeserver and whoami checks only, got %d requests", len(requests))
+ }
+}
+
+func TestHandleInvitesSkipsWhenAutoJoinDisabled(t *testing.T) {
+ joinRequests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ if req.URL.Path == "/_matrix/client/v3/join/!room:example.com" {
+ joinRequests++
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{}`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ joined, err := adapter.handleInvites(
+ context.Background(),
+ channel.ChannelConfig{ID: "cfg-1"},
+ Config{HomeserverURL: "https://matrix.example.com", AccessToken: "tok", AutoJoinInvites: false},
+ matrixSyncResponse{Rooms: struct {
+ Join map[string]matrixSyncJoinedRoom `json:"join"`
+ Invite map[string]matrixSyncInvitedRoom `json:"invite"`
+ }{Invite: map[string]matrixSyncInvitedRoom{"!room:example.com": {}}}},
+ )
+ if err != nil {
+ t.Fatalf("handleInvites returned error: %v", err)
+ }
+ if joined {
+ t.Fatal("expected no room to be joined")
+ }
+ if joinRequests != 0 {
+ t.Fatalf("expected no join requests, got %d", joinRequests)
+ }
+}
+
+func TestBuildMatrixMessageContentIncludesFormattedHTMLForMarkdown(t *testing.T) {
+ content := buildMatrixMessageContent(channel.Message{
+ Text: "**bold**\n\n- item",
+ Format: channel.MessageFormatMarkdown,
+ }, false, "")
+
+ if got := content["body"]; got != "**bold**\n\n- item" {
+ t.Fatalf("unexpected body: %#v", got)
+ }
+ if got := content["format"]; got != matrixHTMLFormat {
+ t.Fatalf("unexpected format: %#v", got)
+ }
+ html, ok := content["formatted_body"].(string)
+ if !ok || !strings.Contains(html, "bold") || !strings.Contains(html, "") {
+ t.Fatalf("unexpected formatted body: %#v", content["formatted_body"])
+ }
+}
+
+func TestBuildMatrixMessageContentAddsFormattedHTMLToEdits(t *testing.T) {
+ content := buildMatrixMessageContent(channel.Message{
+ Text: "`code`",
+ Format: channel.MessageFormatMarkdown,
+ }, true, "$evt1")
+
+ newContent, ok := content["m.new_content"].(map[string]any)
+ if !ok {
+ t.Fatalf("expected m.new_content map, got %#v", content["m.new_content"])
+ }
+ if got := newContent["format"]; got != matrixHTMLFormat {
+ t.Fatalf("unexpected edit format: %#v", got)
+ }
+ html, ok := newContent["formatted_body"].(string)
+ if !ok || !strings.Contains(html, "code") {
+ t.Fatalf("unexpected edit formatted body: %#v", newContent["formatted_body"])
+ }
+}
+
+func TestStripMatrixReplyFallback(t *testing.T) {
+ body := "> <@memoh:example.com> This looks like Antelope Canyon\n>\nWhere is Antelope Canyon?"
+ if got := stripMatrixReplyFallback(body); got != "Where is Antelope Canyon?" {
+ t.Fatalf("unexpected stripped body: %q", got)
+ }
+}
+
+func TestMatrixHandleEventExpandsRepliedImageContext(t *testing.T) {
+ adapter := NewMatrixAdapter(nil)
+ adapter.rememberRoomConversationType("cfg-1", "!room:example.com", "group")
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ if !strings.Contains(req.URL.Path, "/rooms/!room:example.com/event/$img1") {
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{
+ "event_id":"$img1",
+ "type":"m.room.message",
+ "sender":"@memoh:example.com",
+ "unsigned":{"displayname":"Memoh"},
+ "content":{
+ "msgtype":"m.image",
+ "body":"canyon.jpg",
+ "url":"mxc://matrix.example.com/media123",
+ "info":{"mimetype":"image/jpeg","w":640,"h":480}
+ }
+ }`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ var captured channel.InboundMessage
+ delivered, err := adapter.handleEvent(
+ context.Background(),
+ channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"},
+ Config{HomeserverURL: "https://matrix.example.com", AccessToken: "tok", UserID: "@memoh:example.com"},
+ matrixEvent{
+ EventID: "$evt2",
+ Type: "m.room.message",
+ Sender: "@alex:example.com",
+ RoomID: "!room:example.com",
+ Content: map[string]any{
+ "msgtype": "m.text",
+ "body": "> <@memoh:example.com> photo\n>\nWhere is Antelope Canyon?",
+ "m.relates_to": map[string]any{
+ "m.in_reply_to": map[string]any{"event_id": "$img1"},
+ },
+ },
+ },
+ func(_ context.Context, _ channel.ChannelConfig, msg channel.InboundMessage) error {
+ captured = msg
+ return nil
+ },
+ )
+ if err != nil {
+ t.Fatalf("handleEvent returned error: %v", err)
+ }
+ if !delivered {
+ t.Fatal("expected event to be delivered")
+ }
+ if got := captured.Message.Text; got != "[Reply to Memoh: [image]]\nWhere is Antelope Canyon?" {
+ t.Fatalf("unexpected message text: %q", got)
+ }
+ if len(captured.Message.Attachments) != 1 {
+ t.Fatalf("expected one quoted attachment, got %d", len(captured.Message.Attachments))
+ }
+ if captured.Message.Attachments[0].PlatformKey != "mxc://matrix.example.com/media123" {
+ t.Fatalf("unexpected quoted attachment: %#v", captured.Message.Attachments[0])
+ }
+ isReplyToBot, _ := captured.Metadata["is_reply_to_bot"].(bool)
+ if !isReplyToBot {
+ t.Fatalf("expected is_reply_to_bot metadata to be true")
+ }
+ if rawText, _ := captured.Metadata["raw_text"].(string); rawText != "Where is Antelope Canyon?" {
+ t.Fatalf("unexpected raw_text metadata: %q", rawText)
+ }
+}
+
+func TestMatrixHandleEventUsesImageCaptionAsMessageText(t *testing.T) {
+ adapter := NewMatrixAdapter(nil)
+ adapter.rememberRoomConversationType("cfg-1", "!room:example.com", "group")
+
+ var captured channel.InboundMessage
+ delivered, err := adapter.handleEvent(
+ context.Background(),
+ channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"},
+ Config{HomeserverURL: "https://matrix.example.com", AccessToken: "tok", UserID: "@memoh:example.com"},
+ matrixEvent{
+ EventID: "$evt2",
+ Type: "m.room.message",
+ Sender: "@alex:example.com",
+ RoomID: "!room:example.com",
+ Content: map[string]any{
+ "msgtype": "m.image",
+ "body": "A hand-drawn system architecture diagram",
+ "filename": "diagram.png",
+ "url": "mxc://matrix.example.com/media123",
+ "info": map[string]any{
+ "mimetype": "image/png",
+ },
+ },
+ },
+ func(_ context.Context, _ channel.ChannelConfig, msg channel.InboundMessage) error {
+ captured = msg
+ return nil
+ },
+ )
+ if err != nil {
+ t.Fatalf("handleEvent returned error: %v", err)
+ }
+ if !delivered {
+ t.Fatal("expected event to be delivered")
+ }
+ if got := captured.Message.Text; got != "A hand-drawn system architecture diagram" {
+ t.Fatalf("unexpected message text: %q", got)
+ }
+ if len(captured.Message.Attachments) != 1 {
+ t.Fatalf("expected one attachment, got %d", len(captured.Message.Attachments))
+ }
+ att := captured.Message.Attachments[0]
+ if att.Name != "diagram.png" || att.Caption != "A hand-drawn system architecture diagram" {
+ t.Fatalf("unexpected attachment metadata: %#v", att)
+ }
+ if rawText, _ := captured.Metadata["raw_text"].(string); rawText != "A hand-drawn system architecture diagram" {
+ t.Fatalf("unexpected raw_text metadata: %q", rawText)
+ }
+}
+
+func TestMatrixHandleEventMarksDirectConversationFromJoinedMembers(t *testing.T) {
+ joinedMembersRequests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ switch req.URL.Path {
+ case "/_matrix/client/v3/rooms/!room:example.com/joined_members":
+ joinedMembersRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{
+ "joined": {
+ "@alex:example.com": {"display_name": "Alex"},
+ "@memoh:example.com": {"display_name": "Memoh"}
+ }
+ }`)),
+ Header: make(http.Header),
+ }, nil
+ default:
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ return nil, nil
+ }
+ })}
+
+ var captured []channel.InboundMessage
+ for i := 0; i < 2; i++ {
+ delivered, err := adapter.handleEvent(
+ context.Background(),
+ channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"},
+ Config{HomeserverURL: "https://matrix.example.com", AccessToken: "tok", UserID: "@memoh:example.com"},
+ matrixEvent{
+ EventID: fmt.Sprintf("$evt%d", i+1),
+ Type: "m.room.message",
+ Sender: "@alex:example.com",
+ RoomID: "!room:example.com",
+ Content: map[string]any{
+ "msgtype": "m.text",
+ "body": "ping",
+ },
+ },
+ func(_ context.Context, _ channel.ChannelConfig, msg channel.InboundMessage) error {
+ captured = append(captured, msg)
+ return nil
+ },
+ )
+ if err != nil {
+ t.Fatalf("handleEvent returned error: %v", err)
+ }
+ if !delivered {
+ t.Fatal("expected event to be delivered")
+ }
+ }
+
+ if len(captured) != 2 {
+ t.Fatalf("expected two captured messages, got %d", len(captured))
+ }
+ if captured[0].Conversation.Type != "direct" {
+ t.Fatalf("expected direct conversation type, got %q", captured[0].Conversation.Type)
+ }
+ if joinedMembersRequests != 1 {
+ t.Fatalf("expected joined_members lookup to be cached, got %d requests", joinedMembersRequests)
+ }
+}
+
+func TestMatrixSyncOnceAutoJoinsInvitedRooms(t *testing.T) {
+ joinRequests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.rememberRoomConversationType("cfg-1", "!joined:example.com", "group")
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ switch req.URL.Path {
+ case "/_matrix/client/v3/sync":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{
+ "next_batch":"s124",
+ "rooms":{
+ "invite":{"!invite:example.com":{"invite_state":{"events":[{"type":"m.room.member"}]}}},
+ "join":{"!joined:example.com":{"timeline":{"events":[{"event_id":"$evt1","type":"m.room.message","sender":"@alex:example.com","content":{"msgtype":"m.text","body":"ping"}}]}}}
+ }
+ }`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/join/!invite:example.com":
+ joinRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{}`)),
+ Header: make(http.Header),
+ }, nil
+ default:
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ return nil, nil
+ }
+ })}
+
+ var captured channel.InboundMessage
+ nextSince, healthy, err := adapter.syncOnce(
+ context.Background(),
+ channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"},
+ Config{HomeserverURL: "https://matrix.example.com", AccessToken: "tok", UserID: "@memoh:example.com", SyncTimeoutSeconds: 30, AutoJoinInvites: true},
+ "s123",
+ func(_ context.Context, _ channel.ChannelConfig, msg channel.InboundMessage) error {
+ captured = msg
+ return nil
+ },
+ )
+ if err != nil {
+ t.Fatalf("syncOnce returned error: %v", err)
+ }
+ if nextSince != "s124" {
+ t.Fatalf("unexpected next since token: %q", nextSince)
+ }
+ if !healthy {
+ t.Fatal("expected sync session to be marked healthy")
+ }
+ if joinRequests != 1 {
+ t.Fatalf("expected invited room to be auto-joined once, got %d", joinRequests)
+ }
+ if captured.ReplyTarget != "!joined:example.com" || captured.Message.Text != "ping" {
+ t.Fatalf("unexpected captured message: %#v", captured)
+ }
+}
+
+func TestExtractMatrixDirectRoomIDs(t *testing.T) {
+ roomIDs := extractMatrixDirectRoomIDs(matrixSyncResponse{
+ AccountData: struct {
+ Events []matrixSyncEvent `json:"events"`
+ }{
+ Events: []matrixSyncEvent{{
+ Type: "m.direct",
+ Content: map[string]any{
+ "@alice:example.com": []any{"!dm:example.com", " !dm2:example.com "},
+ },
+ }},
+ },
+ })
+
+ if _, ok := roomIDs["!dm:example.com"]; !ok {
+ t.Fatal("expected first direct room id to be extracted")
+ }
+ if _, ok := roomIDs["!dm2:example.com"]; !ok {
+ t.Fatal("expected second direct room id to be extracted")
+ }
+}
+
+func TestExtractMatrixDirectRooms(t *testing.T) {
+ directRooms := extractMatrixDirectRooms(matrixSyncResponse{
+ AccountData: struct {
+ Events []matrixSyncEvent `json:"events"`
+ }{
+ Events: []matrixSyncEvent{{
+ Type: "m.direct",
+ Content: map[string]any{
+ "@alice:example.com": []any{"!dm:example.com", "!ignored:example.com"},
+ "@bob:example.com": []any{" !bob:example.com "},
+ },
+ }},
+ },
+ })
+
+ if got := directRooms["@alice:example.com"]; got != "!dm:example.com" {
+ t.Fatalf("unexpected Alice direct room: %q", got)
+ }
+ if got := directRooms["@bob:example.com"]; got != "!bob:example.com" {
+ t.Fatalf("unexpected Bob direct room: %q", got)
+ }
+}
+
+func TestEnsureDirectRoomReusesExistingRoom(t *testing.T) {
+ joinedRoomsRequests := 0
+ joinedMembersRequests := 0
+ createRoomRequests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ switch req.URL.Path {
+ case "/_matrix/client/v3/joined_rooms":
+ joinedRoomsRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"joined_rooms":["!dm:example.com"]}`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/rooms/!dm:example.com/joined_members":
+ joinedMembersRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"joined":{"@memoh:example.com":{},"@alice:example.com":{}}}`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/createRoom":
+ createRoomRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"room_id":"!new:example.com"}`)),
+ Header: make(http.Header),
+ }, nil
+ default:
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ return nil, nil
+ }
+ })}
+
+ cfg := Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ UserID: "@memoh:example.com",
+ }
+
+ roomID, err := adapter.ensureDirectRoom(context.Background(), cfg, "@alice:example.com")
+ if err != nil {
+ t.Fatalf("ensureDirectRoom returned error: %v", err)
+ }
+ if roomID != "!dm:example.com" {
+ t.Fatalf("unexpected room id: %q", roomID)
+ }
+ roomID, err = adapter.ensureDirectRoom(context.Background(), cfg, "@alice:example.com")
+ if err != nil {
+ t.Fatalf("ensureDirectRoom second call returned error: %v", err)
+ }
+ if roomID != "!dm:example.com" {
+ t.Fatalf("unexpected cached room id: %q", roomID)
+ }
+ if joinedRoomsRequests != 1 {
+ t.Fatalf("expected joined room lookup once, got %d", joinedRoomsRequests)
+ }
+ if joinedMembersRequests != 1 {
+ t.Fatalf("expected joined members lookup once, got %d", joinedMembersRequests)
+ }
+ if createRoomRequests != 0 {
+ t.Fatalf("expected no createRoom requests, got %d", createRoomRequests)
+ }
+}
+
+func TestEnsureDirectRoomCachesCreatedRoom(t *testing.T) {
+ joinedRoomsRequests := 0
+ createRoomRequests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ switch req.URL.Path {
+ case "/_matrix/client/v3/joined_rooms":
+ joinedRoomsRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"joined_rooms":[]}`)),
+ Header: make(http.Header),
+ }, nil
+ case "/_matrix/client/v3/createRoom":
+ createRoomRequests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"room_id":"!new:example.com"}`)),
+ Header: make(http.Header),
+ }, nil
+ default:
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ return nil, nil
+ }
+ })}
+
+ cfg := Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ UserID: "@memoh:example.com",
+ }
+
+ roomID, err := adapter.ensureDirectRoom(context.Background(), cfg, "@alice:example.com")
+ if err != nil {
+ t.Fatalf("ensureDirectRoom returned error: %v", err)
+ }
+ if roomID != "!new:example.com" {
+ t.Fatalf("unexpected room id: %q", roomID)
+ }
+ roomID, err = adapter.ensureDirectRoom(context.Background(), cfg, "@alice:example.com")
+ if err != nil {
+ t.Fatalf("ensureDirectRoom second call returned error: %v", err)
+ }
+ if roomID != "!new:example.com" {
+ t.Fatalf("unexpected cached room id: %q", roomID)
+ }
+ if joinedRoomsRequests != 1 {
+ t.Fatalf("expected joined room lookup once, got %d", joinedRoomsRequests)
+ }
+ if createRoomRequests != 1 {
+ t.Fatalf("expected createRoom once, got %d", createRoomRequests)
+ }
+}
+
+func TestExtractMatrixInboundContentParsesImageAttachment(t *testing.T) {
+ text, attachments := extractMatrixInboundContent(map[string]any{
+ "msgtype": "m.image",
+ "body": "diagram.png",
+ "url": "mxc://matrix.example.com/media123",
+ "info": map[string]any{
+ "mimetype": "image/png",
+ "size": 42,
+ "w": 640,
+ "h": 480,
+ },
+ })
+ if text != "" {
+ t.Fatalf("expected empty text for attachment message, got %q", text)
+ }
+ if len(attachments) != 1 {
+ t.Fatalf("expected 1 attachment, got %d", len(attachments))
+ }
+ att := attachments[0]
+ if att.Type != channel.AttachmentImage {
+ t.Fatalf("unexpected attachment type: %s", att.Type)
+ }
+ if att.PlatformKey != "mxc://matrix.example.com/media123" {
+ t.Fatalf("unexpected platform key: %q", att.PlatformKey)
+ }
+ if att.Name != "diagram.png" || att.Mime != "image/png" {
+ t.Fatalf("unexpected attachment metadata: %#v", att)
+ }
+ if att.Width != 640 || att.Height != 480 || att.Size != 42 {
+ t.Fatalf("unexpected attachment dimensions: %#v", att)
+ }
+ if att.Caption != "" {
+ t.Fatalf("expected empty caption, got %#v", att)
+ }
+}
+
+func TestExtractMatrixInboundContentParsesImageCaption(t *testing.T) {
+ text, attachments := extractMatrixInboundContent(map[string]any{
+ "msgtype": "m.image",
+ "body": "System architecture diagram",
+ "filename": "diagram.png",
+ "url": "mxc://matrix.example.com/media123",
+ "info": map[string]any{
+ "mimetype": "image/png",
+ },
+ })
+ if text != "System architecture diagram" {
+ t.Fatalf("expected caption text, got %q", text)
+ }
+ if len(attachments) != 1 {
+ t.Fatalf("expected 1 attachment, got %d", len(attachments))
+ }
+ att := attachments[0]
+ if att.Name != "diagram.png" {
+ t.Fatalf("unexpected attachment name: %#v", att)
+ }
+ if att.Caption != "System architecture diagram" {
+ t.Fatalf("unexpected attachment caption: %#v", att)
+ }
+}
+
+func TestMatrixSendUploadsBase64AttachmentAndSendsMediaEvent(t *testing.T) {
+ requests := make([]string, 0, 2)
+ uploadedContentTypes := make([]string, 0, 1)
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ requests = append(requests, req.URL.Path)
+ if strings.Contains(req.URL.Path, "/_matrix/media/v3/upload") {
+ uploadedContentTypes = append(uploadedContentTypes, req.Header.Get("Content-Type"))
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"content_uri":"mxc://matrix.example.com/uploaded1"}`)),
+ Header: make(http.Header),
+ }, nil
+ }
+ payload, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ var content map[string]any
+ if err := json.Unmarshal(payload, &content); err != nil {
+ return nil, err
+ }
+ if got := content["msgtype"]; got != "m.image" {
+ t.Fatalf("unexpected msgtype: %#v", got)
+ }
+ if got := content["url"]; got != "mxc://matrix.example.com/uploaded1" {
+ t.Fatalf("unexpected uploaded uri: %#v", got)
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ err := adapter.Send(context.Background(), channel.ChannelConfig{
+ BotID: "bot-1",
+ Credentials: map[string]any{
+ "homeserverUrl": "https://matrix.example.com",
+ "userId": "@memoh:example.com",
+ "accessToken": "tok",
+ },
+ }, channel.OutboundMessage{
+ Target: "!room:example.com",
+ Message: channel.Message{
+ Attachments: []channel.Attachment{{
+ Type: channel.AttachmentImage,
+ Name: "chart.png",
+ Mime: "image/png",
+ Base64: "data:image/png;base64,aGVsbG8=",
+ }},
+ },
+ })
+ if err != nil {
+ t.Fatalf("send returned error: %v", err)
+ }
+ if len(requests) != 2 {
+ t.Fatalf("expected upload and send requests, got %d", len(requests))
+ }
+ if len(uploadedContentTypes) != 1 || uploadedContentTypes[0] != "image/png" {
+ t.Fatalf("unexpected upload content type: %#v", uploadedContentTypes)
+ }
+}
+
+func TestMatrixSendResolvesRoomAlias(t *testing.T) {
+ requests := make([]string, 0, 2)
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ requests = append(requests, req.URL.Path)
+ switch req.URL.Path {
+ case "/_matrix/client/v3/directory/room/#ops:example.com":
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"room_id":"!resolved:example.com"}`)),
+ Header: make(http.Header),
+ }, nil
+ default:
+ if !strings.Contains(req.URL.Path, "/_matrix/client/v3/rooms/!resolved:example.com/send/m.room.message/") {
+ t.Fatalf("unexpected request path: %s", req.URL.Path)
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
+ Header: make(http.Header),
+ }, nil
+ }
+ })}
+
+ err := adapter.Send(context.Background(), channel.ChannelConfig{
+ Credentials: map[string]any{
+ "homeserverUrl": "https://matrix.example.com",
+ "userId": "@memoh:example.com",
+ "accessToken": "tok",
+ },
+ }, channel.OutboundMessage{
+ Target: "#ops:example.com",
+ Message: channel.Message{
+ Text: "ping",
+ },
+ })
+ if err != nil {
+ t.Fatalf("send returned error: %v", err)
+ }
+ if len(requests) != 2 {
+ t.Fatalf("expected alias lookup and send requests, got %d", len(requests))
+ }
+}
+
+func TestMatrixResolveAttachmentDownloadsMXC(t *testing.T) {
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ if !strings.Contains(req.URL.Path, "/_matrix/client/v1/media/download/matrix.example.com/media123/image.png") {
+ t.Fatalf("unexpected download path: %s", req.URL.Path)
+ }
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader("file-bytes")),
+ Header: make(http.Header),
+ }
+ resp.Header.Set("Content-Type", "image/png")
+ resp.ContentLength = int64(len("file-bytes"))
+ return resp, nil
+ })}
+
+ payload, err := adapter.ResolveAttachment(context.Background(), channel.ChannelConfig{
+ Credentials: map[string]any{
+ "homeserverUrl": "https://matrix.example.com",
+ "userId": "@memoh:example.com",
+ "accessToken": "tok",
+ },
+ }, channel.Attachment{
+ PlatformKey: "mxc://matrix.example.com/media123",
+ Name: "image.png",
+ })
+ if err != nil {
+ t.Fatalf("ResolveAttachment returned error: %v", err)
+ }
+ defer func() { _ = payload.Reader.Close() }()
+ data, err := io.ReadAll(payload.Reader)
+ if err != nil {
+ t.Fatalf("read payload: %v", err)
+ }
+ if string(data) != "file-bytes" {
+ t.Fatalf("unexpected payload: %q", string(data))
+ }
+ if payload.Mime != "image/png" || payload.Name != "image.png" || payload.Size != int64(len("file-bytes")) {
+ t.Fatalf("unexpected payload metadata: %#v", payload)
+ }
+}
+
+func TestMatrixResolveAttachmentFallsBackToLegacyMediaDownload(t *testing.T) {
+ paths := make([]string, 0, 2)
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ paths = append(paths, req.URL.Path)
+ if strings.Contains(req.URL.Path, "/_matrix/client/v1/media/download/") {
+ return &http.Response{
+ StatusCode: http.StatusNotFound,
+ Body: io.NopCloser(strings.NewReader(`{"errcode":"M_NOT_FOUND"}`)),
+ Header: make(http.Header),
+ }, nil
+ }
+ if !strings.Contains(req.URL.Path, "/_matrix/media/v3/download/matrix.example.com/media123") {
+ t.Fatalf("unexpected fallback path: %s", req.URL.Path)
+ }
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader("legacy-file")),
+ Header: make(http.Header),
+ }
+ resp.Header.Set("Content-Type", "application/octet-stream")
+ return resp, nil
+ })}
+
+ payload, err := adapter.ResolveAttachment(context.Background(), channel.ChannelConfig{
+ Credentials: map[string]any{
+ "homeserverUrl": "https://matrix.example.com",
+ "userId": "@memoh:example.com",
+ "accessToken": "tok",
+ },
+ }, channel.Attachment{
+ PlatformKey: "mxc://matrix.example.com/media123",
+ })
+ if err != nil {
+ t.Fatalf("ResolveAttachment returned error: %v", err)
+ }
+ defer func() { _ = payload.Reader.Close() }()
+ if len(paths) != 2 {
+ t.Fatalf("expected authenticated and legacy download attempts, got %d", len(paths))
+ }
+}
diff --git a/internal/channel/adapters/matrix/stream.go b/internal/channel/adapters/matrix/stream.go
new file mode 100644
index 00000000..f090066d
--- /dev/null
+++ b/internal/channel/adapters/matrix/stream.go
@@ -0,0 +1,231 @@
+package matrix
+
+import (
+ "context"
+ "errors"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/memohai/memoh/internal/channel"
+)
+
+type matrixOutboundStream struct {
+ adapter *MatrixAdapter
+ cfg Config
+ target string
+ reply *channel.ReplyRef
+
+ closed atomic.Bool
+ mu sync.Mutex
+
+ roomID string
+ originalEventID string
+ rawBuffer strings.Builder
+ lastText string
+ lastFormat channel.MessageFormat
+ lastEditedAt time.Time
+}
+
+func (s *matrixOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error {
+ if s == nil || s.adapter == nil {
+ return errors.New("matrix stream not configured")
+ }
+ if s.closed.Load() {
+ return errors.New("matrix stream is closed")
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
+ switch event.Type {
+ case channel.StreamEventStatus,
+ channel.StreamEventPhaseStart,
+ channel.StreamEventToolCallEnd,
+ channel.StreamEventAgentStart,
+ channel.StreamEventAgentEnd,
+ channel.StreamEventProcessingStarted,
+ channel.StreamEventProcessingCompleted,
+ channel.StreamEventProcessingFailed:
+ return nil
+ case channel.StreamEventPhaseEnd:
+ if event.Phase != channel.StreamPhaseText {
+ return nil
+ }
+ s.mu.Lock()
+ text := strings.TrimSpace(s.rawBuffer.String())
+ s.mu.Unlock()
+ return s.upsertText(ctx, text, channel.MessageFormatPlain, true)
+ case channel.StreamEventToolCallStart:
+ s.resetMessageState()
+ return nil
+ case channel.StreamEventDelta:
+ if event.Phase == channel.StreamPhaseReasoning || event.Delta == "" {
+ return nil
+ }
+ s.mu.Lock()
+ s.rawBuffer.WriteString(event.Delta)
+ s.mu.Unlock()
+ return nil
+ case channel.StreamEventError:
+ errText := strings.TrimSpace(event.Error)
+ if errText == "" {
+ return nil
+ }
+ return s.upsertText(ctx, "Error: "+errText, channel.MessageFormatPlain, true)
+ case channel.StreamEventAttachment:
+ return s.pushAttachments(ctx, event.Attachments)
+ case channel.StreamEventFinal:
+ if event.Final == nil {
+ return errors.New("matrix stream final payload is required")
+ }
+ text := strings.TrimSpace(event.Final.Message.PlainText())
+ format := event.Final.Message.Format
+ if format == "" {
+ format = channel.MessageFormatPlain
+ }
+ if text == "" {
+ s.mu.Lock()
+ text = strings.TrimSpace(s.rawBuffer.String())
+ s.mu.Unlock()
+ }
+ if err := s.upsertText(ctx, text, format, true); err != nil {
+ return err
+ }
+ if err := s.pushAttachments(ctx, event.Final.Message.Attachments); err != nil {
+ return err
+ }
+ s.resetMessageState()
+ return nil
+ default:
+ return nil
+ }
+}
+
+func (s *matrixOutboundStream) Close(ctx context.Context) error {
+ if s == nil {
+ return nil
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+ s.closed.Store(true)
+ return nil
+}
+
+func (s *matrixOutboundStream) upsertText(ctx context.Context, text string, format channel.MessageFormat, force bool) error {
+ text = strings.TrimSpace(text)
+ if text == "" {
+ return nil
+ }
+ if format == "" {
+ format = channel.MessageFormatPlain
+ }
+
+ s.mu.Lock()
+ roomID := s.roomID
+ originalEventID := s.originalEventID
+ lastText := s.lastText
+ lastFormat := s.lastFormat
+ lastEditedAt := s.lastEditedAt
+ reply := s.reply
+ s.mu.Unlock()
+
+ if roomID == "" {
+ resolvedRoomID, err := s.adapter.resolveRoomTarget(ctx, s.cfg, s.target)
+ if err != nil {
+ return err
+ }
+ roomID = resolvedRoomID
+ s.mu.Lock()
+ s.roomID = resolvedRoomID
+ s.mu.Unlock()
+ }
+
+ if originalEventID == "" {
+ eventID, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(channel.Message{
+ Text: text,
+ Format: format,
+ Reply: reply,
+ }, false, ""))
+ if err != nil {
+ return err
+ }
+ s.mu.Lock()
+ s.originalEventID = eventID
+ s.lastText = text
+ s.lastFormat = format
+ s.lastEditedAt = time.Now()
+ s.mu.Unlock()
+ return nil
+ }
+
+ if text == lastText && format == lastFormat {
+ return nil
+ }
+ if !force && time.Since(lastEditedAt) < matrixEditThrottle {
+ return nil
+ }
+ _, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(channel.Message{
+ Text: text,
+ Format: format,
+ }, true, originalEventID))
+ if err != nil {
+ return err
+ }
+ s.mu.Lock()
+ s.lastText = text
+ s.lastFormat = format
+ s.lastEditedAt = time.Now()
+ s.mu.Unlock()
+ return nil
+}
+
+func (s *matrixOutboundStream) resetMessageState() {
+ s.mu.Lock()
+ s.originalEventID = ""
+ s.rawBuffer.Reset()
+ s.lastText = ""
+ s.lastFormat = ""
+ s.lastEditedAt = time.Time{}
+ s.mu.Unlock()
+}
+
+func (s *matrixOutboundStream) pushAttachments(ctx context.Context, attachments []channel.Attachment) error {
+ if len(attachments) == 0 {
+ return nil
+ }
+
+ s.mu.Lock()
+ roomID := s.roomID
+ originalEventID := s.originalEventID
+ reply := s.reply
+ s.mu.Unlock()
+
+ if roomID == "" {
+ resolvedRoomID, err := s.adapter.resolveRoomTarget(ctx, s.cfg, s.target)
+ if err != nil {
+ return err
+ }
+ roomID = resolvedRoomID
+ s.mu.Lock()
+ s.roomID = resolvedRoomID
+ s.mu.Unlock()
+ }
+
+ for idx, att := range attachments {
+ mediaMsg := channel.Message{}
+ if idx == 0 && originalEventID == "" {
+ mediaMsg.Reply = reply
+ }
+ if err := s.adapter.sendMediaAttachment(ctx, s.cfg, roomID, "", mediaMsg, att); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/internal/channel/adapters/matrix/stream_test.go b/internal/channel/adapters/matrix/stream_test.go
new file mode 100644
index 00000000..0d8247a2
--- /dev/null
+++ b/internal/channel/adapters/matrix/stream_test.go
@@ -0,0 +1,189 @@
+package matrix
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/memohai/memoh/internal/channel"
+)
+
+type roundTripFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return f(req)
+}
+
+func TestMatrixStreamDoesNotSendDeltaBeforeTextPhaseEnds(t *testing.T) {
+ requests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
+ requests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ stream := &matrixOutboundStream{
+ adapter: adapter,
+ cfg: Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ },
+ target: "!room:example.com",
+ }
+
+ ctx := context.Background()
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "draft", Phase: channel.StreamPhaseText}); err != nil {
+ t.Fatalf("push delta: %v", err)
+ }
+ if requests != 0 {
+ t.Fatalf("expected no request before text phase ends, got %d", requests)
+ }
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseText}); err != nil {
+ t.Fatalf("push phase end: %v", err)
+ }
+ if requests != 1 {
+ t.Fatalf("expected one request after text phase end, got %d", requests)
+ }
+}
+
+func TestMatrixStreamDropsBufferedTextWhenToolStarts(t *testing.T) {
+ requests := 0
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
+ requests++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ stream := &matrixOutboundStream{
+ adapter: adapter,
+ cfg: Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ },
+ target: "!room:example.com",
+ }
+
+ ctx := context.Background()
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "I will inspect first", Phase: channel.StreamPhaseText}); err != nil {
+ t.Fatalf("push delta: %v", err)
+ }
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventToolCallStart}); err != nil {
+ t.Fatalf("push tool call start: %v", err)
+ }
+ if requests != 0 {
+ t.Fatalf("expected no request for discarded pre-tool text, got %d", requests)
+ }
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "Final answer", Phase: channel.StreamPhaseText}); err != nil {
+ t.Fatalf("push final delta: %v", err)
+ }
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{Text: "Final answer"}}}); err != nil {
+ t.Fatalf("push final: %v", err)
+ }
+ if requests != 1 {
+ t.Fatalf("expected only final visible message to be sent, got %d", requests)
+ }
+}
+
+func TestMatrixStreamFinalMarkdownUpdatesFormattedContent(t *testing.T) {
+ bodies := make([]string, 0, 2)
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ payload, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ bodies = append(bodies, string(payload))
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ stream := &matrixOutboundStream{
+ adapter: adapter,
+ cfg: Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ },
+ target: "!room:example.com",
+ }
+
+ ctx := context.Background()
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "**bold**", Phase: channel.StreamPhaseText}); err != nil {
+ t.Fatalf("push delta: %v", err)
+ }
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseText}); err != nil {
+ t.Fatalf("push phase end: %v", err)
+ }
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{Text: "**bold**", Format: channel.MessageFormatMarkdown}}}); err != nil {
+ t.Fatalf("push final: %v", err)
+ }
+ if len(bodies) != 2 {
+ t.Fatalf("expected two sends, got %d", len(bodies))
+ }
+ if strings.Contains(bodies[0], "formatted_body") {
+ t.Fatalf("expected plain interim send, got %s", bodies[0])
+ }
+ if !strings.Contains(bodies[1], "formatted_body") || !strings.Contains(bodies[1], "org.matrix.custom.html") {
+ t.Fatalf("expected markdown final edit, got %s", bodies[1])
+ }
+}
+
+func TestMatrixStreamFinalSendsAttachments(t *testing.T) {
+ bodies := make([]string, 0, 2)
+ adapter := NewMatrixAdapter(nil)
+ adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ payload, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ bodies = append(bodies, string(payload))
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
+ Header: make(http.Header),
+ }, nil
+ })}
+
+ stream := &matrixOutboundStream{
+ adapter: adapter,
+ cfg: Config{
+ HomeserverURL: "https://matrix.example.com",
+ AccessToken: "tok",
+ },
+ target: "!room:example.com",
+ }
+
+ ctx := context.Background()
+ if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{
+ Text: "done",
+ Attachments: []channel.Attachment{{
+ Type: channel.AttachmentImage,
+ PlatformKey: "mxc://matrix.example.com/media123",
+ Name: "image.png",
+ SourcePlatform: Type.String(),
+ }},
+ }}}); err != nil {
+ t.Fatalf("push final: %v", err)
+ }
+ if len(bodies) != 2 {
+ t.Fatalf("expected text and attachment sends, got %d", len(bodies))
+ }
+ if !strings.Contains(bodies[0], `"msgtype":"m.notice"`) {
+ t.Fatalf("expected first payload to be text, got %s", bodies[0])
+ }
+ if !strings.Contains(bodies[1], `"msgtype":"m.image"`) || !strings.Contains(bodies[1], `mxc://matrix.example.com/media123`) {
+ t.Fatalf("expected second payload to be attachment, got %s", bodies[1])
+ }
+}
diff --git a/internal/channel/route/service_test.go b/internal/channel/route/service_test.go
new file mode 100644
index 00000000..9ab17d03
--- /dev/null
+++ b/internal/channel/route/service_test.go
@@ -0,0 +1,13 @@
+package route
+
+import (
+ "testing"
+
+ "github.com/memohai/memoh/internal/conversation"
+)
+
+func TestDetermineConversationKindTreatsDirectAsDirect(t *testing.T) {
+ if got := determineConversationKind("", "direct"); got != conversation.KindDirect {
+ t.Fatalf("unexpected conversation kind: %q", got)
+ }
+}
diff --git a/internal/channel/service.go b/internal/channel/service.go
index c8d298b3..1ea8248f 100644
--- a/internal/channel/service.go
+++ b/internal/channel/service.go
@@ -151,6 +151,28 @@ func (s *Store) UpdateConfigDisabled(ctx context.Context, botID string, channelT
return normalizeChannelConfigFromRow(row)
}
+// SaveMatrixSyncSinceToken persists the Matrix /sync cursor without mutating channel config updated_at.
+func (s *Store) SaveMatrixSyncSinceToken(ctx context.Context, configID string, since string) error {
+ if s.queries == nil {
+ return errors.New("channel queries not configured")
+ }
+ pgConfigID, err := db.ParseUUID(configID)
+ if err != nil {
+ return err
+ }
+ rows, err := s.queries.SaveMatrixSyncSinceToken(ctx, sqlc.SaveMatrixSyncSinceTokenParams{
+ ID: pgConfigID,
+ SinceToken: strings.TrimSpace(since),
+ })
+ if err != nil {
+ return err
+ }
+ if rows == 0 {
+ return fmt.Errorf("%w", ErrChannelConfigNotFound)
+ }
+ return nil
+}
+
// UpsertChannelIdentityConfig creates or updates a channel identity's channel binding.
func (s *Store) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) {
if s.queries == nil {
diff --git a/internal/channel/types.go b/internal/channel/types.go
index d129485f..f2a368ea 100644
--- a/internal/channel/types.go
+++ b/internal/channel/types.go
@@ -49,7 +49,7 @@ const (
// channel abstraction domain: private/group/thread.
func NormalizeConversationType(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
- case ConversationTypePrivate:
+ case "", "p2p", "direct", ConversationTypePrivate:
return ConversationTypePrivate
case ConversationTypeThread:
return ConversationTypeThread
diff --git a/internal/channel/types_test.go b/internal/channel/types_test.go
new file mode 100644
index 00000000..40ac2f7a
--- /dev/null
+++ b/internal/channel/types_test.go
@@ -0,0 +1,10 @@
+package channel
+
+import "testing"
+
+func TestGenerateRoutingKeyTreatsDirectConversationAsSharedRoute(t *testing.T) {
+ got := GenerateRoutingKey("matrix", "bot-1", "!room:example.com", "direct", "@alex:example.com")
+ if got != "matrix:bot-1:!room:example.com" {
+ t.Fatalf("unexpected routing key: %q", got)
+ }
+}
diff --git a/internal/conversation/flow/assistant_output.go b/internal/conversation/flow/assistant_output.go
index ed18ef18..cce974d2 100644
--- a/internal/conversation/flow/assistant_output.go
+++ b/internal/conversation/flow/assistant_output.go
@@ -16,12 +16,15 @@ func ExtractAssistantOutputs(messages []conversation.ModelMessage) []conversatio
if msg.Role != "assistant" {
continue
}
- // skip tool call tips content.
if hasToolCallContent(msg) {
continue
}
- content := strings.TrimSpace(msg.TextContent())
- parts := filterContentParts(msg.ContentParts())
+ rawParts := msg.ContentParts()
+ parts := filterVisibleContentParts(rawParts)
+ content := visibleContentText(parts)
+ if len(rawParts) == 0 {
+ content = strings.TrimSpace(msg.TextContent())
+ }
if content == "" && len(parts) == 0 {
continue
}
@@ -42,19 +45,55 @@ func hasToolCallContent(msg conversation.ModelMessage) bool {
return false
}
-func filterContentParts(parts []conversation.ContentPart) []conversation.ContentPart {
+func filterVisibleContentParts(parts []conversation.ContentPart) []conversation.ContentPart {
if len(parts) == 0 {
return nil
}
filtered := make([]conversation.ContentPart, 0, len(parts))
for _, p := range parts {
- // Ignore Reasoning parts
- if p.Type == "reasoning" {
- continue
- }
- if p.HasValue() {
+ if isVisibleContentPart(p) {
filtered = append(filtered, p)
}
}
return filtered
}
+
+func isVisibleContentPart(part conversation.ContentPart) bool {
+ if !part.HasValue() {
+ return false
+ }
+ switch strings.ToLower(strings.TrimSpace(part.Type)) {
+ case "reasoning", "tool-call", "tool-result":
+ return false
+ default:
+ return true
+ }
+}
+
+func visibleContentText(parts []conversation.ContentPart) string {
+ if len(parts) == 0 {
+ return ""
+ }
+ texts := make([]string, 0, len(parts))
+ for _, part := range parts {
+ text := strings.TrimSpace(visibleContentPartText(part))
+ if text == "" {
+ continue
+ }
+ texts = append(texts, text)
+ }
+ return strings.TrimSpace(strings.Join(texts, "\n"))
+}
+
+func visibleContentPartText(part conversation.ContentPart) string {
+ if strings.TrimSpace(part.Text) != "" {
+ return part.Text
+ }
+ if strings.TrimSpace(part.URL) != "" {
+ return part.URL
+ }
+ if strings.TrimSpace(part.Emoji) != "" {
+ return part.Emoji
+ }
+ return ""
+}
diff --git a/internal/conversation/flow/assistant_output_test.go b/internal/conversation/flow/assistant_output_test.go
new file mode 100644
index 00000000..fb436f48
--- /dev/null
+++ b/internal/conversation/flow/assistant_output_test.go
@@ -0,0 +1,98 @@
+package flow
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/memohai/memoh/internal/conversation"
+)
+
+func TestExtractAssistantOutputsSkipsAssistantToolCallMessages(t *testing.T) {
+ outputs := ExtractAssistantOutputs([]conversation.ModelMessage{
+ {
+ Role: "assistant",
+ Content: conversation.NewTextContent("I will inspect the file first."),
+ ToolCalls: []conversation.ToolCall{{
+ Type: "function",
+ Function: conversation.ToolCallFunction{
+ Name: "read_file",
+ Arguments: `{"path":"/tmp/a.txt"}`,
+ },
+ }},
+ },
+ {
+ Role: "assistant",
+ Content: conversation.NewTextContent("Done. Here is the final answer."),
+ },
+ })
+
+ if len(outputs) != 1 {
+ t.Fatalf("expected one assistant output, got %d", len(outputs))
+ }
+ if outputs[0].Content != "Done. Here is the final answer." {
+ t.Fatalf("unexpected assistant output: %q", outputs[0].Content)
+ }
+}
+
+func TestExtractAssistantOutputsExcludesReasoningParts(t *testing.T) {
+ content, err := json.Marshal([]conversation.ContentPart{
+ {Type: "reasoning", Text: "I should inspect the file first."},
+ {Type: "text", Text: "Here is the file summary."},
+ })
+ if err != nil {
+ t.Fatalf("marshal content: %v", err)
+ }
+
+ outputs := ExtractAssistantOutputs([]conversation.ModelMessage{{
+ Role: "assistant",
+ Content: content,
+ }})
+
+ if len(outputs) != 1 {
+ t.Fatalf("expected one assistant output, got %d", len(outputs))
+ }
+ if outputs[0].Content != "Here is the file summary." {
+ t.Fatalf("unexpected visible assistant output: %q", outputs[0].Content)
+ }
+ if len(outputs[0].Parts) != 1 || outputs[0].Parts[0].Type != "text" {
+ t.Fatalf("unexpected visible parts: %#v", outputs[0].Parts)
+ }
+}
+
+func TestExtractAssistantOutputsSkipsReasoningOnlyStructuredMessage(t *testing.T) {
+ content, err := json.Marshal([]map[string]any{
+ {"type": "reasoning", "text": "I should inspect the file first."},
+ {"type": "tool-call", "toolName": "read", "toolCallId": "call_1", "input": map[string]any{"path": "/tmp/a.txt"}},
+ })
+ if err != nil {
+ t.Fatalf("marshal content: %v", err)
+ }
+
+ outputs := ExtractAssistantOutputs([]conversation.ModelMessage{{
+ Role: "assistant",
+ Content: content,
+ }})
+
+ if len(outputs) != 0 {
+ t.Fatalf("expected no visible assistant outputs, got %#v", outputs)
+ }
+}
+
+func TestExtractAssistantOutputsSkipsStructuredToolCallMessageWithVisibleText(t *testing.T) {
+ content, err := json.Marshal([]map[string]any{
+ {"type": "text", "text": "I will inspect the file first."},
+ {"type": "tool-call", "toolName": "read", "toolCallId": "call_1", "input": map[string]any{"path": "/tmp/a.txt"}},
+ })
+ if err != nil {
+ t.Fatalf("marshal content: %v", err)
+ }
+
+ outputs := ExtractAssistantOutputs([]conversation.ModelMessage{{
+ Role: "assistant",
+ Content: content,
+ }})
+
+ if len(outputs) != 0 {
+ t.Fatalf("expected no visible assistant outputs, got %#v", outputs)
+ }
+}
diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go
index 72012f29..b12c1b18 100644
--- a/internal/conversation/flow/resolver.go
+++ b/internal/conversation/flow/resolver.go
@@ -2,9 +2,14 @@ package flow
import (
"context"
+ "encoding/base64"
+ "encoding/json"
"errors"
"io"
"log/slog"
+ "math"
+ "sort"
+ "strconv"
"strings"
"time"
@@ -118,8 +123,6 @@ type usageInfo struct {
OutputTokens *int `json:"outputTokens"`
}
-// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
-
type resolvedContext struct {
runConfig agentpkg.RunConfig
model models.GetResponse
@@ -146,7 +149,6 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
}
loopDetectionEnabled := r.loadBotLoopDetectionEnabled(ctx, req.BotID)
- // Check chat-level model override.
var chatSettings conversation.Settings
if r.conversationSvc != nil {
chatSettings, err = r.conversationSvc.GetSettings(ctx, req.ChatID)
@@ -164,8 +166,6 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
maxTokens := botSettings.MaxContextTokens
- // Build non-history parts first so we can reserve their token cost before
- // trimming history messages.
memoryMsg := r.loadMemoryContextMessage(ctx, req)
reqMessages := pruneMessagesForGateway(nonNilModelMessages(req.Messages))
if memoryMsg != nil {
@@ -179,8 +179,6 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
for _, m := range reqMessages {
overhead += estimateMessageTokens(m)
}
- // Reserve space for the system prompt built by the agent gateway
- // (IDENTITY.md, SOUL.md, TOOLS.md, skills, boilerplate, user prompt, etc.).
const systemPromptReserve = 4096
overhead += systemPromptReserve
@@ -228,15 +226,11 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
} else {
agentSkills = make([]agentpkg.SkillEntry, 0, len(entries))
for _, e := range entries {
- if strings.TrimSpace(e.Name) == "" {
+ skill, ok := normalizeGatewaySkill(e)
+ if !ok {
continue
}
- agentSkills = append(agentSkills, agentpkg.SkillEntry{
- Name: e.Name,
- Description: e.Description,
- Content: e.Content,
- Metadata: e.Metadata,
- })
+ agentSkills = append(agentSkills, skill)
}
}
}
@@ -245,7 +239,6 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
}
displayName := r.resolveDisplayName(ctx, req)
-
headerifiedQuery := FormatUserHeader(
strings.TrimSpace(req.ExternalMessageID),
strings.TrimSpace(req.SourceChannelIdentityID),
@@ -253,7 +246,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
req.CurrentChannel,
strings.TrimSpace(req.ConversationType),
strings.TrimSpace(req.ConversationName),
- nil, // attachments paths handled separately
+ extractFileRefPaths(r.routeAndMergeAttachments(ctx, chatModel, req)),
req.Query,
)
@@ -360,3 +353,171 @@ func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig)
return cfg
}
+
+func normalizeGatewaySkill(entry SkillEntry) (agentpkg.SkillEntry, bool) {
+ name := strings.TrimSpace(entry.Name)
+ if name == "" {
+ return agentpkg.SkillEntry{}, false
+ }
+ description := strings.TrimSpace(entry.Description)
+ if description == "" {
+ description = name
+ }
+ content := strings.TrimSpace(entry.Content)
+ if content == "" {
+ content = description
+ }
+ return agentpkg.SkillEntry{
+ Name: name,
+ Description: description,
+ Content: content,
+ Metadata: entry.Metadata,
+ }, true
+}
+
+func normalizeUserMessageContent(msg conversation.ModelMessage) conversation.ModelMessage {
+ if !strings.EqualFold(strings.TrimSpace(msg.Role), "user") {
+ return msg
+ }
+ normalized, changed := normalizeUserContentParts(msg.Content)
+ if !changed {
+ return msg
+ }
+ msg.Content = normalized
+ return msg
+}
+
+func normalizeUserContentParts(content json.RawMessage) (json.RawMessage, bool) {
+ if len(content) == 0 {
+ return nil, false
+ }
+ var parts []map[string]any
+ if err := json.Unmarshal(content, &parts); err != nil || len(parts) == 0 {
+ return nil, false
+ }
+
+ changed := false
+ rebuilt := make([]map[string]any, 0, len(parts))
+ for _, part := range parts {
+ partType := strings.TrimSpace(strings.ToLower(readAnyString(part["type"])))
+ switch partType {
+ case "image":
+ normalized, ok, didChange := normalizeUserImagePart(part)
+ if didChange {
+ changed = true
+ }
+ if ok {
+ rebuilt = append(rebuilt, normalized)
+ }
+ default:
+ rebuilt = append(rebuilt, part)
+ }
+ }
+ if !changed {
+ return nil, false
+ }
+ if len(rebuilt) == 0 {
+ rebuilt = append(rebuilt, map[string]any{
+ "type": "text",
+ "text": "[User sent an attachment]",
+ })
+ }
+ data, err := json.Marshal(rebuilt)
+ if err != nil {
+ return nil, false
+ }
+ return data, true
+}
+
+func normalizeUserImagePart(part map[string]any) (map[string]any, bool, bool) {
+ raw, ok := part["image"]
+ if !ok {
+ return nil, false, true
+ }
+ if image, ok := raw.(string); ok && strings.TrimSpace(image) != "" {
+ return part, true, false
+ }
+ bytes, ok := anyIndexedByteObject(raw)
+ if !ok {
+ return nil, false, true
+ }
+ cloned := cloneAnyMap(part)
+ mediaType := strings.TrimSpace(readAnyString(cloned["mediaType"]))
+ encoded := base64.StdEncoding.EncodeToString(bytes)
+ if mediaType != "" {
+ cloned["image"] = "data:" + mediaType + ";base64," + encoded
+ } else {
+ cloned["image"] = encoded
+ }
+ return cloned, true, true
+}
+
+func cloneAnyMap(input map[string]any) map[string]any {
+ cloned := make(map[string]any, len(input))
+ for key, value := range input {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func readAnyString(value any) string {
+ text, _ := value.(string)
+ return text
+}
+
+func anyIndexedByteObject(value any) ([]byte, bool) {
+ obj, ok := value.(map[string]any)
+ if !ok || len(obj) == 0 {
+ return nil, false
+ }
+ indexes := make([]int, 0, len(obj))
+ values := make(map[int]byte, len(obj))
+ for key, raw := range obj {
+ idx, err := strconv.Atoi(strings.TrimSpace(key))
+ if err != nil || idx < 0 {
+ return nil, false
+ }
+ byteValue, ok := anyNumberToByte(raw)
+ if !ok {
+ return nil, false
+ }
+ indexes = append(indexes, idx)
+ values[idx] = byteValue
+ }
+ sort.Ints(indexes)
+ if indexes[len(indexes)-1]+1 != len(indexes) {
+ return nil, false
+ }
+ bytes := make([]byte, len(indexes))
+ for _, idx := range indexes {
+ bytes[idx] = values[idx]
+ }
+ return bytes, true
+}
+
+func anyNumberToByte(value any) (byte, bool) {
+ floatValue, ok := value.(float64)
+ if !ok || math.IsNaN(floatValue) || math.IsInf(floatValue, 0) {
+ return 0, false
+ }
+ if floatValue < 0 || floatValue > 255 || math.Trunc(floatValue) != floatValue {
+ return 0, false
+ }
+ parsed, err := strconv.ParseUint(strconv.FormatFloat(floatValue, 'f', 0, 64), 10, 8)
+ if err != nil {
+ return 0, false
+ }
+ return byte(parsed), true
+}
+
+// extractFileRefPaths collects container file paths from gateway attachments
+// that use the tool_file_ref transport.
+func extractFileRefPaths(attachments []any) []string {
+ var paths []string
+ for _, att := range attachments {
+ if ga, ok := att.(gatewayAttachment); ok && ga.Transport == gatewayTransportToolFileRef && strings.TrimSpace(ga.Payload) != "" {
+ paths = append(paths, ga.Payload)
+ }
+ }
+ return paths
+}
diff --git a/internal/conversation/flow/resolver_store.go b/internal/conversation/flow/resolver_store.go
index 124d551a..12067a54 100644
--- a/internal/conversation/flow/resolver_store.go
+++ b/internal/conversation/flow/resolver_store.go
@@ -60,6 +60,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatReque
}
for i, msg := range messages {
+ msg = normalizeUserMessageContent(msg)
content, err := json.Marshal(msg)
if err != nil {
r.logger.Warn("storeMessages: marshal failed", slog.Any("error", err))
diff --git a/internal/conversation/flow/resolver_test.go b/internal/conversation/flow/resolver_test.go
index f9c2cb06..41aaa69c 100644
--- a/internal/conversation/flow/resolver_test.go
+++ b/internal/conversation/flow/resolver_test.go
@@ -274,6 +274,64 @@ func TestOutboundAssetRefsToMessageRefs_Empty(t *testing.T) {
}
}
+func TestSanitizeMessagesNormalizesUserMultipartImageBytes(t *testing.T) {
+ t.Parallel()
+ content, err := json.Marshal([]map[string]any{
+ {"type": "text", "text": "> quoted reply\n\nWhere is Antelope Canyon?"},
+ {"type": "image", "image": map[string]any{"0": 137, "1": 80}, "mediaType": "image/png"},
+ })
+ if err != nil {
+ t.Fatalf("marshal content: %v", err)
+ }
+
+ cleaned := sanitizeMessages([]conversation.ModelMessage{{
+ Role: "user",
+ Content: content,
+ }})
+ if len(cleaned) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(cleaned))
+ }
+ if bytes.Equal(cleaned[0].Content, content) {
+ t.Fatalf("expected user multipart content to be normalized")
+ }
+ var parts []map[string]any
+ if err := json.Unmarshal(cleaned[0].Content, &parts); err != nil {
+ t.Fatalf("unmarshal normalized content: %v", err)
+ }
+ if len(parts) != 2 {
+ t.Fatalf("expected 2 parts after normalization, got %d", len(parts))
+ }
+ if got := parts[0]["text"]; got != "> quoted reply\n\nWhere is Antelope Canyon?" {
+ t.Fatalf("unexpected text part: %#v", got)
+ }
+ image, _ := parts[1]["image"].(string)
+ if !strings.HasPrefix(image, "data:image/png;base64,") {
+ t.Fatalf("expected data URL image payload, got %#v", parts[1]["image"])
+ }
+}
+
+func TestSanitizeMessagesKeepsAssistantMultipartMessages(t *testing.T) {
+ t.Parallel()
+ content, err := json.Marshal([]map[string]any{
+ {"type": "text", "text": "answer"},
+ {"type": "image", "image": "data:image/png;base64,aGVsbG8="},
+ })
+ if err != nil {
+ t.Fatalf("marshal content: %v", err)
+ }
+
+ cleaned := sanitizeMessages([]conversation.ModelMessage{{
+ Role: "assistant",
+ Content: content,
+ }})
+ if len(cleaned) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(cleaned))
+ }
+ if !bytes.Equal(cleaned[0].Content, content) {
+ t.Fatalf("assistant multipart content should remain unchanged")
+ }
+}
+
func TestNormalizeImagePartsToDataURL_ConvertsIndexedObject(t *testing.T) {
msg := conversation.ModelMessage{
Role: "user",
diff --git a/internal/conversation/flow/resolver_util.go b/internal/conversation/flow/resolver_util.go
index fd5f74f9..314a4d60 100644
--- a/internal/conversation/flow/resolver_util.go
+++ b/internal/conversation/flow/resolver_util.go
@@ -17,6 +17,7 @@ import (
func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage {
cleaned := make([]conversation.ModelMessage, 0, len(messages))
for _, msg := range messages {
+ msg = normalizeUserMessageContent(msg)
if normalized, ok := normalizeImagePartsToDataURL(msg); ok {
msg = normalized
}
diff --git a/internal/db/sqlc/channels.sql.go b/internal/db/sqlc/channels.sql.go
index 7728e640..45a8836f 100644
--- a/internal/db/sqlc/channels.sql.go
+++ b/internal/db/sqlc/channels.sql.go
@@ -190,6 +190,28 @@ func (q *Queries) ListUserChannelBindingsByPlatform(ctx context.Context, channel
return items, nil
}
+const saveMatrixSyncSinceToken = `-- name: SaveMatrixSyncSinceToken :execrows
+UPDATE bot_channel_configs
+SET routing = COALESCE(routing, '{}'::jsonb) || jsonb_build_object(
+ '_matrix',
+ COALESCE(routing->'_matrix', '{}'::jsonb) || jsonb_build_object('since_token', $2::text)
+)
+WHERE id = $1
+`
+
+type SaveMatrixSyncSinceTokenParams struct {
+ ID pgtype.UUID `json:"id"`
+ SinceToken string `json:"since_token"`
+}
+
+func (q *Queries) SaveMatrixSyncSinceToken(ctx context.Context, arg SaveMatrixSyncSinceTokenParams) (int64, error) {
+ result, err := q.db.Exec(ctx, saveMatrixSyncSinceToken, arg.ID, arg.SinceToken)
+ if err != nil {
+ return 0, err
+ }
+ return result.RowsAffected(), nil
+}
+
const updateBotChannelConfigDisabled = `-- name: UpdateBotChannelConfigDisabled :one
UPDATE bot_channel_configs
SET
diff --git a/internal/handlers/filemanager.go b/internal/handlers/filemanager.go
index b40cdbe7..687c1b15 100644
--- a/internal/handlers/filemanager.go
+++ b/internal/handlers/filemanager.go
@@ -15,6 +15,8 @@ import (
"github.com/memohai/memoh/internal/workspace/bridge"
)
+const mediaContainerRoot = "/data/media"
+
// ---------- request / response types ----------
type FSFileInfo struct {
@@ -83,6 +85,11 @@ func resolveContainerPath(rawPath string) (string, error) {
return cleaned, nil
}
+func isContainerMediaPath(containerPath string) bool {
+ cleaned := filepath.Clean("/" + strings.TrimSpace(containerPath))
+ return cleaned == mediaContainerRoot || strings.HasPrefix(cleaned, mediaContainerRoot+"/")
+}
+
// getGRPCClient returns the gRPC client for the bot's container.
func (h *ContainerdHandler) getGRPCClient(ctx context.Context, botID string) (*bridge.Client, error) {
return h.manager.MCPClient(ctx, botID)
@@ -287,10 +294,6 @@ func (h *ContainerdHandler) FSRead(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/container/fs/download [get].
func (h *ContainerdHandler) FSDownload(c echo.Context) error {
- botID, err := h.requireBotAccess(c)
- if err != nil {
- return err
- }
rawPath := c.QueryParam("path")
if strings.TrimSpace(rawPath) == "" {
return echo.NewHTTPError(http.StatusBadRequest, "path is required")
@@ -301,6 +304,15 @@ func (h *ContainerdHandler) FSDownload(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
+ requireAccess := h.requireBotAccess
+ if isContainerMediaPath(containerPath) {
+ requireAccess = h.requireBotAccessWithGuest
+ }
+ botID, err := requireAccess(c)
+ if err != nil {
+ return err
+ }
+
ctx := c.Request().Context()
client, err := h.getGRPCClient(ctx, botID)
if err != nil {
diff --git a/internal/handlers/filemanager_test.go b/internal/handlers/filemanager_test.go
new file mode 100644
index 00000000..62ed64db
--- /dev/null
+++ b/internal/handlers/filemanager_test.go
@@ -0,0 +1,22 @@
+package handlers
+
+import "testing"
+
+func TestIsContainerMediaPath(t *testing.T) {
+ tests := []struct {
+ path string
+ want bool
+ }{
+ {path: "/data/media", want: true},
+ {path: "/data/media/0f/demo.jpg", want: true},
+ {path: "data/media/0f/demo.jpg", want: true},
+ {path: "/data/mediakit/demo.jpg", want: false},
+ {path: "/etc/passwd", want: false},
+ }
+
+ for _, tt := range tests {
+ if got := isContainerMediaPath(tt.path); got != tt.want {
+ t.Fatalf("isContainerMediaPath(%q) = %v, want %v", tt.path, got, tt.want)
+ }
+ }
+}