feat: matrix support (part 1) (#242)

* feat(channel): add Matrix adapter support

* fix(channel): prevent reasoning leaks in Matrix replies

* fix(channel): persist Matrix sync cursors

* fix(channel): improve Matrix markdown rendering

* fix(channel): support Matrix attachments and multimodal history

* fix(channel): expand Matrix reply media context

* fix(handlers): allow media downloads for chat-access bots

* fix(channel): classify Matrix DMs as direct chats

* fix(channel): auto-join Matrix room invites

* fix(channel): resolve Matrix room aliases for outbound send

* fix(web): use Matrix brand icon in channel badges

Replace the generic Matrix hashtag badge with the official brand asset so channel badges feel recognizable and fit the circular mask cleanly.

* fix(channel): add Matrix room whitelist controls

Let Matrix bots decide whether to auto-join invites and restrict inbound activity to allowed rooms or aliases. Expose the new controls in the web settings UI with line-based whitelist input so access rules stay explicit.

* fix(channel): stabilize Matrix multimodal follow-ups and settings

* fix(flow): avoid gosec panic on byte decoding

* fix: fix golangci-lint

* fix(channel): remove Matrix built-in ACL

* fix(channel): preserve Matrix image captions

* fix(channel): validate Matrix homeserver and sync access

Fail Matrix connections early when the homeserver, access token, or /sync capability is misconfigured so bot health checks surface actionable errors.

* fix(channel): preserve optional toggles and relax Matrix startup validation

* fix(channel): tighten Matrix mention fallback parsing

* fix(flow): skip structured assistant tool-call outputs

* fix(flow): resolve merged resolver duplication

Keep the internal agent resolver implementation after merging main so split helper files do not redeclare flow symbols. Restore user message normalization in sanitize and persistence paths to keep flow tests and command packages building.

* fix(flow): remove unused merged resolver helper

Drop the leftover truncate helper and import from the resolver merge fix so golangci-lint passes again without affecting flow behavior.

---------

Co-authored-by: Acbox Liu <acbox0328@gmail.com>
This commit is contained in:
AlexMa233
2026-03-22 21:55:34 +08:00
committed by GitHub
parent a4473d252a
commit 609ca49cf5
32 changed files with 4394 additions and 51 deletions
+213
View File
@@ -0,0 +1,213 @@
package matrix
import (
"errors"
"strconv"
"strings"
"github.com/memohai/memoh/internal/channel"
)
type Config struct {
HomeserverURL string
AccessToken string //nolint:gosec // intentional: operator-supplied Matrix access token in channel config
UserID string
SyncTimeoutSeconds int
AutoJoinInvites bool
}
type UserConfig struct {
RoomID string
UserID string
}
func normalizeConfig(raw map[string]any) (map[string]any, error) {
cfg, err := parseConfig(raw)
if err != nil {
return nil, err
}
out := map[string]any{
"homeserverUrl": cfg.HomeserverURL,
"accessToken": cfg.AccessToken,
"userId": cfg.UserID,
"syncTimeoutSeconds": cfg.SyncTimeoutSeconds,
"autoJoinInvites": cfg.AutoJoinInvites,
}
return out, nil
}
func normalizeUserConfig(raw map[string]any) (map[string]any, error) {
cfg, err := parseUserConfig(raw)
if err != nil {
return nil, err
}
out := map[string]any{}
if cfg.RoomID != "" {
out["room_id"] = cfg.RoomID
}
if cfg.UserID != "" {
out["user_id"] = cfg.UserID
}
return out, nil
}
func resolveTarget(raw map[string]any) (string, error) {
cfg, err := parseUserConfig(raw)
if err != nil {
return "", err
}
if cfg.RoomID != "" {
return cfg.RoomID, nil
}
if cfg.UserID != "" {
return cfg.UserID, nil
}
return "", errors.New("matrix user config requires room_id or user_id")
}
func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool {
cfg, err := parseUserConfig(raw)
if err != nil {
return false
}
if cfg.UserID != "" && strings.EqualFold(strings.TrimSpace(criteria.SubjectID), cfg.UserID) {
return true
}
return false
}
func buildUserConfig(identity channel.Identity) map[string]any {
userID := strings.TrimSpace(identity.Attribute("user_id"))
if userID == "" {
userID = strings.TrimSpace(identity.SubjectID)
}
if userID == "" {
return map[string]any{}
}
return map[string]any{"user_id": userID}
}
func parseConfig(raw map[string]any) (Config, error) {
homeserverURL := normalizeHomeserverURL(channel.ReadString(raw, "homeserverUrl", "homeserver_url", "homeserver"))
accessToken := strings.TrimSpace(channel.ReadString(raw, "accessToken", "access_token"))
userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id"))
if homeserverURL == "" {
return Config{}, errors.New("matrix homeserverUrl is required")
}
if accessToken == "" {
return Config{}, errors.New("matrix accessToken is required")
}
if userID == "" {
return Config{}, errors.New("matrix userId is required")
}
timeout := readInt(raw, 30, "syncTimeoutSeconds", "sync_timeout_seconds")
if timeout < 0 {
timeout = 0
}
autoJoinInvites := readBool(raw, true, "autoJoinInvites", "auto_join_invites")
return Config{
HomeserverURL: homeserverURL,
AccessToken: accessToken,
UserID: userID,
SyncTimeoutSeconds: timeout,
AutoJoinInvites: autoJoinInvites,
}, nil
}
func parseUserConfig(raw map[string]any) (UserConfig, error) {
roomID := normalizeTarget(channel.ReadString(raw, "roomId", "room_id"))
userID := normalizeTarget(channel.ReadString(raw, "userId", "user_id"))
if roomID == "" && userID == "" {
return UserConfig{}, errors.New("matrix user config requires room_id or user_id")
}
if roomID != "" && !strings.HasPrefix(roomID, "!") && !strings.HasPrefix(roomID, "#") {
return UserConfig{}, errors.New("matrix room_id must start with ! or #")
}
if userID != "" && !strings.HasPrefix(userID, "@") {
return UserConfig{}, errors.New("matrix user_id must start with @")
}
return UserConfig{RoomID: roomID, UserID: userID}, nil
}
func normalizeTarget(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
return ""
}
for _, prefix := range []string{"matrix:", "room:", "user:"} {
if strings.HasPrefix(strings.ToLower(value), prefix) {
value = strings.TrimSpace(value[len(prefix):])
break
}
}
return value
}
func normalizeHomeserverURL(raw string) string {
value := strings.TrimSpace(raw)
return strings.TrimRight(value, "/")
}
func readInt(raw map[string]any, fallback int, keys ...string) int {
for _, key := range keys {
value, ok := raw[key]
if !ok {
continue
}
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case string:
parsed, err := strconv.Atoi(strings.TrimSpace(v))
if err == nil {
return parsed
}
}
}
return fallback
}
func readBool(raw map[string]any, fallback bool, keys ...string) bool {
for _, key := range keys {
value, ok := raw[key]
if !ok {
continue
}
switch v := value.(type) {
case bool:
return v
case string:
switch strings.ToLower(strings.TrimSpace(v)) {
case "true", "1", "yes", "on":
return true
case "false", "0", "no", "off":
return false
}
}
}
return fallback
}
func targetKind(target string) string {
value := normalizeTarget(target)
switch {
case strings.HasPrefix(value, "!") || strings.HasPrefix(value, "#"):
return "room"
case strings.HasPrefix(value, "@"):
return "user"
default:
return ""
}
}
func validateTarget(target string) error {
kind := targetKind(target)
if kind == "" {
return errors.New("matrix target must be a room id/alias or user id")
}
return nil
}
@@ -0,0 +1,61 @@
package matrix
import "testing"
func TestParseConfig(t *testing.T) {
cfg, err := parseConfig(map[string]any{
"homeserverUrl": "https://matrix.example.com/",
"accessToken": "tok",
"userId": "@memoh:example.com",
"syncTimeoutSeconds": 15,
"autoJoinInvites": false,
})
if err != nil {
t.Fatalf("parseConfig returned error: %v", err)
}
if cfg.HomeserverURL != "https://matrix.example.com" {
t.Fatalf("unexpected homeserver url: %q", cfg.HomeserverURL)
}
if cfg.UserID != "@memoh:example.com" {
t.Fatalf("unexpected user id: %q", cfg.UserID)
}
if cfg.SyncTimeoutSeconds != 15 {
t.Fatalf("unexpected sync timeout: %d", cfg.SyncTimeoutSeconds)
}
if cfg.AutoJoinInvites {
t.Fatal("expected autoJoinInvites to be false")
}
}
func TestParseConfigDefaultsAutoJoinInvites(t *testing.T) {
cfg, err := parseConfig(map[string]any{
"homeserverUrl": "https://matrix.example.com",
"accessToken": "tok",
"userId": "@memoh:example.com",
})
if err != nil {
t.Fatalf("parseConfig returned error: %v", err)
}
if !cfg.AutoJoinInvites {
t.Fatal("expected autoJoinInvites default to true")
}
}
func TestParseUserConfigRequiresTarget(t *testing.T) {
if _, err := parseUserConfig(map[string]any{}); err == nil {
t.Fatal("expected parseUserConfig to fail")
}
}
func TestResolveTargetPrefersRoomID(t *testing.T) {
target, err := resolveTarget(map[string]any{
"room_id": "!room:example.com",
"user_id": "@alice:example.com",
})
if err != nil {
t.Fatalf("resolveTarget returned error: %v", err)
}
if target != "!room:example.com" {
t.Fatalf("unexpected target: %q", target)
}
}
@@ -0,0 +1,147 @@
package matrix
import (
"bytes"
"regexp"
"strings"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/renderer/html"
"github.com/memohai/memoh/internal/channel"
)
const matrixHTMLFormat = "org.matrix.custom.html"
var matrixMarkdownRenderer = goldmark.New(
goldmark.WithExtensions(extension.GFM),
goldmark.WithRendererOptions(
html.WithHardWraps(),
),
)
type matrixFormattedMessage struct {
Body string
FormattedBody string
HasHTML bool
}
var (
matrixTaskListPattern = regexp.MustCompile(`^(\s*(?:[-*+]\s+|\d+\.\s+))\[( |x|X)\]\s+(.*)$`)
matrixTableAlignCell = regexp.MustCompile(`^:?-{3,}:?$`)
)
func formatMatrixMessage(msg channel.Message) matrixFormattedMessage {
body := strings.TrimSpace(msg.PlainText())
formatted := matrixFormattedMessage{Body: body}
if msg.Format != channel.MessageFormatMarkdown || body == "" {
return formatted
}
body = normalizeMatrixMarkdown(body)
formatted.Body = body
htmlBody, err := renderMatrixMarkdown(body)
if err != nil || strings.TrimSpace(htmlBody) == "" {
return formatted
}
formatted.FormattedBody = htmlBody
formatted.HasHTML = true
return formatted
}
func renderMatrixMarkdown(text string) (string, error) {
text = strings.TrimSpace(text)
if text == "" {
return "", nil
}
var buf bytes.Buffer
if err := matrixMarkdownRenderer.Convert([]byte(text), &buf); err != nil {
return "", err
}
return strings.TrimSpace(buf.String()), nil
}
func normalizeMatrixMarkdown(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
lines := strings.Split(text, "\n")
result := make([]string, 0, len(lines))
inFence := false
for i := 0; i < len(lines); i++ {
line := lines[i]
trimmed := strings.TrimSpace(line)
if isFenceLine(trimmed) {
inFence = !inFence
result = append(result, line)
continue
}
if !inFence && i+1 < len(lines) && isMarkdownTableHeader(line, lines[i+1]) {
block := []string{line, lines[i+1]}
i += 2
for i < len(lines) && isMarkdownTableRow(lines[i]) {
block = append(block, lines[i])
i++
}
i--
result = append(result, "```text")
result = append(result, block...)
result = append(result, "```")
continue
}
if !inFence {
line = normalizeMatrixTaskListLine(line)
}
result = append(result, line)
}
return strings.TrimSpace(strings.Join(result, "\n"))
}
func normalizeMatrixTaskListLine(line string) string {
matches := matrixTaskListPattern.FindStringSubmatch(line)
if len(matches) != 4 {
return line
}
box := "☐"
if strings.EqualFold(matches[2], "x") {
box = "☑"
}
return matches[1] + box + " " + matches[3]
}
func isFenceLine(line string) bool {
return strings.HasPrefix(line, "```") || strings.HasPrefix(line, "~~~")
}
func isMarkdownTableHeader(headerLine, delimiterLine string) bool {
if !strings.Contains(headerLine, "|") {
return false
}
return isMarkdownTableDelimiter(delimiterLine)
}
func isMarkdownTableDelimiter(line string) bool {
trimmed := strings.TrimSpace(line)
if !strings.Contains(trimmed, "|") {
return false
}
parts := strings.Split(trimmed, "|")
validCells := 0
for _, part := range parts {
cell := strings.TrimSpace(part)
if cell == "" {
continue
}
if !matrixTableAlignCell.MatchString(cell) {
return false
}
validCells++
}
return validCells >= 1
}
func isMarkdownTableRow(line string) bool {
trimmed := strings.TrimSpace(line)
return trimmed != "" && strings.Contains(trimmed, "|")
}
@@ -0,0 +1,40 @@
package matrix
import (
"strings"
"testing"
"github.com/memohai/memoh/internal/channel"
)
func TestNormalizeMatrixMarkdownTaskList(t *testing.T) {
input := "- [ ] todo\n- [x] done"
got := normalizeMatrixMarkdown(input)
if got != "- ☐ todo\n- ☑ done" {
t.Fatalf("unexpected normalized markdown: %q", got)
}
}
func TestNormalizeMatrixMarkdownTablesBecomeCodeBlocks(t *testing.T) {
input := "| A | B |\n| --- | --- |\n| 1 | 2 |"
got := normalizeMatrixMarkdown(input)
if !strings.HasPrefix(got, "```text\n") || !strings.Contains(got, "| 1 | 2 |") || !strings.HasSuffix(got, "\n```") {
t.Fatalf("unexpected normalized markdown: %q", got)
}
}
func TestFormatMatrixMessageMarkdownUsesNormalizedBody(t *testing.T) {
formatted := formatMatrixMessage(channel.Message{
Text: "- [x] done\n\n| A |\n| --- |\n| 1 |",
Format: channel.MessageFormatMarkdown,
})
if !strings.Contains(formatted.Body, "☑ done") {
t.Fatalf("expected task list checkbox in body, got %q", formatted.Body)
}
if !strings.Contains(formatted.FormattedBody, "<pre><code") || !strings.Contains(formatted.FormattedBody, "| A |") {
t.Fatalf("expected table fallback code block in formatted body, got %q", formatted.FormattedBody)
}
if strings.Contains(formatted.FormattedBody, "<table") {
t.Fatalf("expected no html table in formatted body, got %q", formatted.FormattedBody)
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+231
View File
@@ -0,0 +1,231 @@
package matrix
import (
"context"
"errors"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/memohai/memoh/internal/channel"
)
type matrixOutboundStream struct {
adapter *MatrixAdapter
cfg Config
target string
reply *channel.ReplyRef
closed atomic.Bool
mu sync.Mutex
roomID string
originalEventID string
rawBuffer strings.Builder
lastText string
lastFormat channel.MessageFormat
lastEditedAt time.Time
}
func (s *matrixOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error {
if s == nil || s.adapter == nil {
return errors.New("matrix stream not configured")
}
if s.closed.Load() {
return errors.New("matrix stream is closed")
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
switch event.Type {
case channel.StreamEventStatus,
channel.StreamEventPhaseStart,
channel.StreamEventToolCallEnd,
channel.StreamEventAgentStart,
channel.StreamEventAgentEnd,
channel.StreamEventProcessingStarted,
channel.StreamEventProcessingCompleted,
channel.StreamEventProcessingFailed:
return nil
case channel.StreamEventPhaseEnd:
if event.Phase != channel.StreamPhaseText {
return nil
}
s.mu.Lock()
text := strings.TrimSpace(s.rawBuffer.String())
s.mu.Unlock()
return s.upsertText(ctx, text, channel.MessageFormatPlain, true)
case channel.StreamEventToolCallStart:
s.resetMessageState()
return nil
case channel.StreamEventDelta:
if event.Phase == channel.StreamPhaseReasoning || event.Delta == "" {
return nil
}
s.mu.Lock()
s.rawBuffer.WriteString(event.Delta)
s.mu.Unlock()
return nil
case channel.StreamEventError:
errText := strings.TrimSpace(event.Error)
if errText == "" {
return nil
}
return s.upsertText(ctx, "Error: "+errText, channel.MessageFormatPlain, true)
case channel.StreamEventAttachment:
return s.pushAttachments(ctx, event.Attachments)
case channel.StreamEventFinal:
if event.Final == nil {
return errors.New("matrix stream final payload is required")
}
text := strings.TrimSpace(event.Final.Message.PlainText())
format := event.Final.Message.Format
if format == "" {
format = channel.MessageFormatPlain
}
if text == "" {
s.mu.Lock()
text = strings.TrimSpace(s.rawBuffer.String())
s.mu.Unlock()
}
if err := s.upsertText(ctx, text, format, true); err != nil {
return err
}
if err := s.pushAttachments(ctx, event.Final.Message.Attachments); err != nil {
return err
}
s.resetMessageState()
return nil
default:
return nil
}
}
func (s *matrixOutboundStream) Close(ctx context.Context) error {
if s == nil {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
s.closed.Store(true)
return nil
}
func (s *matrixOutboundStream) upsertText(ctx context.Context, text string, format channel.MessageFormat, force bool) error {
text = strings.TrimSpace(text)
if text == "" {
return nil
}
if format == "" {
format = channel.MessageFormatPlain
}
s.mu.Lock()
roomID := s.roomID
originalEventID := s.originalEventID
lastText := s.lastText
lastFormat := s.lastFormat
lastEditedAt := s.lastEditedAt
reply := s.reply
s.mu.Unlock()
if roomID == "" {
resolvedRoomID, err := s.adapter.resolveRoomTarget(ctx, s.cfg, s.target)
if err != nil {
return err
}
roomID = resolvedRoomID
s.mu.Lock()
s.roomID = resolvedRoomID
s.mu.Unlock()
}
if originalEventID == "" {
eventID, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(channel.Message{
Text: text,
Format: format,
Reply: reply,
}, false, ""))
if err != nil {
return err
}
s.mu.Lock()
s.originalEventID = eventID
s.lastText = text
s.lastFormat = format
s.lastEditedAt = time.Now()
s.mu.Unlock()
return nil
}
if text == lastText && format == lastFormat {
return nil
}
if !force && time.Since(lastEditedAt) < matrixEditThrottle {
return nil
}
_, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(channel.Message{
Text: text,
Format: format,
}, true, originalEventID))
if err != nil {
return err
}
s.mu.Lock()
s.lastText = text
s.lastFormat = format
s.lastEditedAt = time.Now()
s.mu.Unlock()
return nil
}
func (s *matrixOutboundStream) resetMessageState() {
s.mu.Lock()
s.originalEventID = ""
s.rawBuffer.Reset()
s.lastText = ""
s.lastFormat = ""
s.lastEditedAt = time.Time{}
s.mu.Unlock()
}
func (s *matrixOutboundStream) pushAttachments(ctx context.Context, attachments []channel.Attachment) error {
if len(attachments) == 0 {
return nil
}
s.mu.Lock()
roomID := s.roomID
originalEventID := s.originalEventID
reply := s.reply
s.mu.Unlock()
if roomID == "" {
resolvedRoomID, err := s.adapter.resolveRoomTarget(ctx, s.cfg, s.target)
if err != nil {
return err
}
roomID = resolvedRoomID
s.mu.Lock()
s.roomID = resolvedRoomID
s.mu.Unlock()
}
for idx, att := range attachments {
mediaMsg := channel.Message{}
if idx == 0 && originalEventID == "" {
mediaMsg.Reply = reply
}
if err := s.adapter.sendMediaAttachment(ctx, s.cfg, roomID, "", mediaMsg, att); err != nil {
return err
}
}
return nil
}
@@ -0,0 +1,189 @@
package matrix
import (
"context"
"io"
"net/http"
"strings"
"testing"
"github.com/memohai/memoh/internal/channel"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestMatrixStreamDoesNotSendDeltaBeforeTextPhaseEnds(t *testing.T) {
requests := 0
adapter := NewMatrixAdapter(nil)
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
requests++
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
Header: make(http.Header),
}, nil
})}
stream := &matrixOutboundStream{
adapter: adapter,
cfg: Config{
HomeserverURL: "https://matrix.example.com",
AccessToken: "tok",
},
target: "!room:example.com",
}
ctx := context.Background()
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "draft", Phase: channel.StreamPhaseText}); err != nil {
t.Fatalf("push delta: %v", err)
}
if requests != 0 {
t.Fatalf("expected no request before text phase ends, got %d", requests)
}
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseText}); err != nil {
t.Fatalf("push phase end: %v", err)
}
if requests != 1 {
t.Fatalf("expected one request after text phase end, got %d", requests)
}
}
func TestMatrixStreamDropsBufferedTextWhenToolStarts(t *testing.T) {
requests := 0
adapter := NewMatrixAdapter(nil)
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
requests++
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
Header: make(http.Header),
}, nil
})}
stream := &matrixOutboundStream{
adapter: adapter,
cfg: Config{
HomeserverURL: "https://matrix.example.com",
AccessToken: "tok",
},
target: "!room:example.com",
}
ctx := context.Background()
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "I will inspect first", Phase: channel.StreamPhaseText}); err != nil {
t.Fatalf("push delta: %v", err)
}
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventToolCallStart}); err != nil {
t.Fatalf("push tool call start: %v", err)
}
if requests != 0 {
t.Fatalf("expected no request for discarded pre-tool text, got %d", requests)
}
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "Final answer", Phase: channel.StreamPhaseText}); err != nil {
t.Fatalf("push final delta: %v", err)
}
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{Text: "Final answer"}}}); err != nil {
t.Fatalf("push final: %v", err)
}
if requests != 1 {
t.Fatalf("expected only final visible message to be sent, got %d", requests)
}
}
func TestMatrixStreamFinalMarkdownUpdatesFormattedContent(t *testing.T) {
bodies := make([]string, 0, 2)
adapter := NewMatrixAdapter(nil)
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
payload, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
bodies = append(bodies, string(payload))
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
Header: make(http.Header),
}, nil
})}
stream := &matrixOutboundStream{
adapter: adapter,
cfg: Config{
HomeserverURL: "https://matrix.example.com",
AccessToken: "tok",
},
target: "!room:example.com",
}
ctx := context.Background()
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "**bold**", Phase: channel.StreamPhaseText}); err != nil {
t.Fatalf("push delta: %v", err)
}
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseText}); err != nil {
t.Fatalf("push phase end: %v", err)
}
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{Text: "**bold**", Format: channel.MessageFormatMarkdown}}}); err != nil {
t.Fatalf("push final: %v", err)
}
if len(bodies) != 2 {
t.Fatalf("expected two sends, got %d", len(bodies))
}
if strings.Contains(bodies[0], "formatted_body") {
t.Fatalf("expected plain interim send, got %s", bodies[0])
}
if !strings.Contains(bodies[1], "formatted_body") || !strings.Contains(bodies[1], "org.matrix.custom.html") {
t.Fatalf("expected markdown final edit, got %s", bodies[1])
}
}
func TestMatrixStreamFinalSendsAttachments(t *testing.T) {
bodies := make([]string, 0, 2)
adapter := NewMatrixAdapter(nil)
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
payload, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
bodies = append(bodies, string(payload))
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
Header: make(http.Header),
}, nil
})}
stream := &matrixOutboundStream{
adapter: adapter,
cfg: Config{
HomeserverURL: "https://matrix.example.com",
AccessToken: "tok",
},
target: "!room:example.com",
}
ctx := context.Background()
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{
Text: "done",
Attachments: []channel.Attachment{{
Type: channel.AttachmentImage,
PlatformKey: "mxc://matrix.example.com/media123",
Name: "image.png",
SourcePlatform: Type.String(),
}},
}}}); err != nil {
t.Fatalf("push final: %v", err)
}
if len(bodies) != 2 {
t.Fatalf("expected text and attachment sends, got %d", len(bodies))
}
if !strings.Contains(bodies[0], `"msgtype":"m.notice"`) {
t.Fatalf("expected first payload to be text, got %s", bodies[0])
}
if !strings.Contains(bodies[1], `"msgtype":"m.image"`) || !strings.Contains(bodies[1], `mxc://matrix.example.com/media123`) {
t.Fatalf("expected second payload to be attachment, got %s", bodies[1])
}
}
+13
View File
@@ -0,0 +1,13 @@
package route
import (
"testing"
"github.com/memohai/memoh/internal/conversation"
)
func TestDetermineConversationKindTreatsDirectAsDirect(t *testing.T) {
if got := determineConversationKind("", "direct"); got != conversation.KindDirect {
t.Fatalf("unexpected conversation kind: %q", got)
}
}
+22
View File
@@ -151,6 +151,28 @@ func (s *Store) UpdateConfigDisabled(ctx context.Context, botID string, channelT
return normalizeChannelConfigFromRow(row)
}
// SaveMatrixSyncSinceToken persists the Matrix /sync cursor without mutating channel config updated_at.
func (s *Store) SaveMatrixSyncSinceToken(ctx context.Context, configID string, since string) error {
if s.queries == nil {
return errors.New("channel queries not configured")
}
pgConfigID, err := db.ParseUUID(configID)
if err != nil {
return err
}
rows, err := s.queries.SaveMatrixSyncSinceToken(ctx, sqlc.SaveMatrixSyncSinceTokenParams{
ID: pgConfigID,
SinceToken: strings.TrimSpace(since),
})
if err != nil {
return err
}
if rows == 0 {
return fmt.Errorf("%w", ErrChannelConfigNotFound)
}
return nil
}
// UpsertChannelIdentityConfig creates or updates a channel identity's channel binding.
func (s *Store) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) {
if s.queries == nil {
+1 -1
View File
@@ -49,7 +49,7 @@ const (
// channel abstraction domain: private/group/thread.
func NormalizeConversationType(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case ConversationTypePrivate:
case "", "p2p", "direct", ConversationTypePrivate:
return ConversationTypePrivate
case ConversationTypeThread:
return ConversationTypeThread
+10
View File
@@ -0,0 +1,10 @@
package channel
import "testing"
func TestGenerateRoutingKeyTreatsDirectConversationAsSharedRoute(t *testing.T) {
got := GenerateRoutingKey("matrix", "bot-1", "!room:example.com", "direct", "@alex:example.com")
if got != "matrix:bot-1:!room:example.com" {
t.Fatalf("unexpected routing key: %q", got)
}
}