mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: matrix support (part 1) (#242)
* feat(channel): add Matrix adapter support * fix(channel): prevent reasoning leaks in Matrix replies * fix(channel): persist Matrix sync cursors * fix(channel): improve Matrix markdown rendering * fix(channel): support Matrix attachments and multimodal history * fix(channel): expand Matrix reply media context * fix(handlers): allow media downloads for chat-access bots * fix(channel): classify Matrix DMs as direct chats * fix(channel): auto-join Matrix room invites * fix(channel): resolve Matrix room aliases for outbound send * fix(web): use Matrix brand icon in channel badges Replace the generic Matrix hashtag badge with the official brand asset so channel badges feel recognizable and fit the circular mask cleanly. * fix(channel): add Matrix room whitelist controls Let Matrix bots decide whether to auto-join invites and restrict inbound activity to allowed rooms or aliases. Expose the new controls in the web settings UI with line-based whitelist input so access rules stay explicit. * fix(channel): stabilize Matrix multimodal follow-ups and settings * fix(flow): avoid gosec panic on byte decoding * fix: fix golangci-lint * fix(channel): remove Matrix built-in ACL * fix(channel): preserve Matrix image captions * fix(channel): validate Matrix homeserver and sync access Fail Matrix connections early when the homeserver, access token, or /sync capability is misconfigured so bot health checks surface actionable errors. * fix(channel): preserve optional toggles and relax Matrix startup validation * fix(channel): tighten Matrix mention fallback parsing * fix(flow): skip structured assistant tool-call outputs * fix(flow): resolve merged resolver duplication Keep the internal agent resolver implementation after merging main so split helper files do not redeclare flow symbols. Restore user message normalization in sanitize and persistence paths to keep flow tests and command packages building. * fix(flow): remove unused merged resolver helper Drop the leftover truncate helper and import from the resolver merge fix so golangci-lint passes again without affecting flow behavior. --------- Co-authored-by: Acbox Liu <acbox0328@gmail.com>
This commit is contained in:
@@ -0,0 +1,213 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
HomeserverURL string
|
||||
AccessToken string //nolint:gosec // intentional: operator-supplied Matrix access token in channel config
|
||||
UserID string
|
||||
SyncTimeoutSeconds int
|
||||
AutoJoinInvites bool
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
RoomID string
|
||||
UserID string
|
||||
}
|
||||
|
||||
func normalizeConfig(raw map[string]any) (map[string]any, error) {
|
||||
cfg, err := parseConfig(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := map[string]any{
|
||||
"homeserverUrl": cfg.HomeserverURL,
|
||||
"accessToken": cfg.AccessToken,
|
||||
"userId": cfg.UserID,
|
||||
"syncTimeoutSeconds": cfg.SyncTimeoutSeconds,
|
||||
"autoJoinInvites": cfg.AutoJoinInvites,
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func normalizeUserConfig(raw map[string]any) (map[string]any, error) {
|
||||
cfg, err := parseUserConfig(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := map[string]any{}
|
||||
if cfg.RoomID != "" {
|
||||
out["room_id"] = cfg.RoomID
|
||||
}
|
||||
if cfg.UserID != "" {
|
||||
out["user_id"] = cfg.UserID
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func resolveTarget(raw map[string]any) (string, error) {
|
||||
cfg, err := parseUserConfig(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if cfg.RoomID != "" {
|
||||
return cfg.RoomID, nil
|
||||
}
|
||||
if cfg.UserID != "" {
|
||||
return cfg.UserID, nil
|
||||
}
|
||||
return "", errors.New("matrix user config requires room_id or user_id")
|
||||
}
|
||||
|
||||
func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool {
|
||||
cfg, err := parseUserConfig(raw)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cfg.UserID != "" && strings.EqualFold(strings.TrimSpace(criteria.SubjectID), cfg.UserID) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildUserConfig(identity channel.Identity) map[string]any {
|
||||
userID := strings.TrimSpace(identity.Attribute("user_id"))
|
||||
if userID == "" {
|
||||
userID = strings.TrimSpace(identity.SubjectID)
|
||||
}
|
||||
if userID == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
return map[string]any{"user_id": userID}
|
||||
}
|
||||
|
||||
func parseConfig(raw map[string]any) (Config, error) {
|
||||
homeserverURL := normalizeHomeserverURL(channel.ReadString(raw, "homeserverUrl", "homeserver_url", "homeserver"))
|
||||
accessToken := strings.TrimSpace(channel.ReadString(raw, "accessToken", "access_token"))
|
||||
userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id"))
|
||||
if homeserverURL == "" {
|
||||
return Config{}, errors.New("matrix homeserverUrl is required")
|
||||
}
|
||||
if accessToken == "" {
|
||||
return Config{}, errors.New("matrix accessToken is required")
|
||||
}
|
||||
if userID == "" {
|
||||
return Config{}, errors.New("matrix userId is required")
|
||||
}
|
||||
timeout := readInt(raw, 30, "syncTimeoutSeconds", "sync_timeout_seconds")
|
||||
if timeout < 0 {
|
||||
timeout = 0
|
||||
}
|
||||
autoJoinInvites := readBool(raw, true, "autoJoinInvites", "auto_join_invites")
|
||||
return Config{
|
||||
HomeserverURL: homeserverURL,
|
||||
AccessToken: accessToken,
|
||||
UserID: userID,
|
||||
SyncTimeoutSeconds: timeout,
|
||||
AutoJoinInvites: autoJoinInvites,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseUserConfig(raw map[string]any) (UserConfig, error) {
|
||||
roomID := normalizeTarget(channel.ReadString(raw, "roomId", "room_id"))
|
||||
userID := normalizeTarget(channel.ReadString(raw, "userId", "user_id"))
|
||||
if roomID == "" && userID == "" {
|
||||
return UserConfig{}, errors.New("matrix user config requires room_id or user_id")
|
||||
}
|
||||
if roomID != "" && !strings.HasPrefix(roomID, "!") && !strings.HasPrefix(roomID, "#") {
|
||||
return UserConfig{}, errors.New("matrix room_id must start with ! or #")
|
||||
}
|
||||
if userID != "" && !strings.HasPrefix(userID, "@") {
|
||||
return UserConfig{}, errors.New("matrix user_id must start with @")
|
||||
}
|
||||
return UserConfig{RoomID: roomID, UserID: userID}, nil
|
||||
}
|
||||
|
||||
func normalizeTarget(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
for _, prefix := range []string{"matrix:", "room:", "user:"} {
|
||||
if strings.HasPrefix(strings.ToLower(value), prefix) {
|
||||
value = strings.TrimSpace(value[len(prefix):])
|
||||
break
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func normalizeHomeserverURL(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
return strings.TrimRight(value, "/")
|
||||
}
|
||||
|
||||
func readInt(raw map[string]any, fallback int, keys ...string) int {
|
||||
for _, key := range keys {
|
||||
value, ok := raw[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case string:
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
if err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func readBool(raw map[string]any, fallback bool, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
value, ok := raw[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "true", "1", "yes", "on":
|
||||
return true
|
||||
case "false", "0", "no", "off":
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func targetKind(target string) string {
|
||||
value := normalizeTarget(target)
|
||||
switch {
|
||||
case strings.HasPrefix(value, "!") || strings.HasPrefix(value, "#"):
|
||||
return "room"
|
||||
case strings.HasPrefix(value, "@"):
|
||||
return "user"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func validateTarget(target string) error {
|
||||
kind := targetKind(target)
|
||||
if kind == "" {
|
||||
return errors.New("matrix target must be a room id/alias or user id")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package matrix
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
cfg, err := parseConfig(map[string]any{
|
||||
"homeserverUrl": "https://matrix.example.com/",
|
||||
"accessToken": "tok",
|
||||
"userId": "@memoh:example.com",
|
||||
"syncTimeoutSeconds": 15,
|
||||
"autoJoinInvites": false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig returned error: %v", err)
|
||||
}
|
||||
if cfg.HomeserverURL != "https://matrix.example.com" {
|
||||
t.Fatalf("unexpected homeserver url: %q", cfg.HomeserverURL)
|
||||
}
|
||||
if cfg.UserID != "@memoh:example.com" {
|
||||
t.Fatalf("unexpected user id: %q", cfg.UserID)
|
||||
}
|
||||
if cfg.SyncTimeoutSeconds != 15 {
|
||||
t.Fatalf("unexpected sync timeout: %d", cfg.SyncTimeoutSeconds)
|
||||
}
|
||||
if cfg.AutoJoinInvites {
|
||||
t.Fatal("expected autoJoinInvites to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigDefaultsAutoJoinInvites(t *testing.T) {
|
||||
cfg, err := parseConfig(map[string]any{
|
||||
"homeserverUrl": "https://matrix.example.com",
|
||||
"accessToken": "tok",
|
||||
"userId": "@memoh:example.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig returned error: %v", err)
|
||||
}
|
||||
if !cfg.AutoJoinInvites {
|
||||
t.Fatal("expected autoJoinInvites default to true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUserConfigRequiresTarget(t *testing.T) {
|
||||
if _, err := parseUserConfig(map[string]any{}); err == nil {
|
||||
t.Fatal("expected parseUserConfig to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTargetPrefersRoomID(t *testing.T) {
|
||||
target, err := resolveTarget(map[string]any{
|
||||
"room_id": "!room:example.com",
|
||||
"user_id": "@alice:example.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTarget returned error: %v", err)
|
||||
}
|
||||
if target != "!room:example.com" {
|
||||
t.Fatalf("unexpected target: %q", target)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/yuin/goldmark"
|
||||
"github.com/yuin/goldmark/extension"
|
||||
"github.com/yuin/goldmark/renderer/html"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
const matrixHTMLFormat = "org.matrix.custom.html"
|
||||
|
||||
var matrixMarkdownRenderer = goldmark.New(
|
||||
goldmark.WithExtensions(extension.GFM),
|
||||
goldmark.WithRendererOptions(
|
||||
html.WithHardWraps(),
|
||||
),
|
||||
)
|
||||
|
||||
type matrixFormattedMessage struct {
|
||||
Body string
|
||||
FormattedBody string
|
||||
HasHTML bool
|
||||
}
|
||||
|
||||
var (
|
||||
matrixTaskListPattern = regexp.MustCompile(`^(\s*(?:[-*+]\s+|\d+\.\s+))\[( |x|X)\]\s+(.*)$`)
|
||||
matrixTableAlignCell = regexp.MustCompile(`^:?-{3,}:?$`)
|
||||
)
|
||||
|
||||
func formatMatrixMessage(msg channel.Message) matrixFormattedMessage {
|
||||
body := strings.TrimSpace(msg.PlainText())
|
||||
formatted := matrixFormattedMessage{Body: body}
|
||||
if msg.Format != channel.MessageFormatMarkdown || body == "" {
|
||||
return formatted
|
||||
}
|
||||
body = normalizeMatrixMarkdown(body)
|
||||
formatted.Body = body
|
||||
htmlBody, err := renderMatrixMarkdown(body)
|
||||
if err != nil || strings.TrimSpace(htmlBody) == "" {
|
||||
return formatted
|
||||
}
|
||||
formatted.FormattedBody = htmlBody
|
||||
formatted.HasHTML = true
|
||||
return formatted
|
||||
}
|
||||
|
||||
func renderMatrixMarkdown(text string) (string, error) {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return "", nil
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := matrixMarkdownRenderer.Convert([]byte(text), &buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(buf.String()), nil
|
||||
}
|
||||
|
||||
func normalizeMatrixMarkdown(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
lines := strings.Split(text, "\n")
|
||||
result := make([]string, 0, len(lines))
|
||||
inFence := false
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if isFenceLine(trimmed) {
|
||||
inFence = !inFence
|
||||
result = append(result, line)
|
||||
continue
|
||||
}
|
||||
if !inFence && i+1 < len(lines) && isMarkdownTableHeader(line, lines[i+1]) {
|
||||
block := []string{line, lines[i+1]}
|
||||
i += 2
|
||||
for i < len(lines) && isMarkdownTableRow(lines[i]) {
|
||||
block = append(block, lines[i])
|
||||
i++
|
||||
}
|
||||
i--
|
||||
result = append(result, "```text")
|
||||
result = append(result, block...)
|
||||
result = append(result, "```")
|
||||
continue
|
||||
}
|
||||
if !inFence {
|
||||
line = normalizeMatrixTaskListLine(line)
|
||||
}
|
||||
result = append(result, line)
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(result, "\n"))
|
||||
}
|
||||
|
||||
func normalizeMatrixTaskListLine(line string) string {
|
||||
matches := matrixTaskListPattern.FindStringSubmatch(line)
|
||||
if len(matches) != 4 {
|
||||
return line
|
||||
}
|
||||
box := "☐"
|
||||
if strings.EqualFold(matches[2], "x") {
|
||||
box = "☑"
|
||||
}
|
||||
return matches[1] + box + " " + matches[3]
|
||||
}
|
||||
|
||||
func isFenceLine(line string) bool {
|
||||
return strings.HasPrefix(line, "```") || strings.HasPrefix(line, "~~~")
|
||||
}
|
||||
|
||||
func isMarkdownTableHeader(headerLine, delimiterLine string) bool {
|
||||
if !strings.Contains(headerLine, "|") {
|
||||
return false
|
||||
}
|
||||
return isMarkdownTableDelimiter(delimiterLine)
|
||||
}
|
||||
|
||||
func isMarkdownTableDelimiter(line string) bool {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if !strings.Contains(trimmed, "|") {
|
||||
return false
|
||||
}
|
||||
parts := strings.Split(trimmed, "|")
|
||||
validCells := 0
|
||||
for _, part := range parts {
|
||||
cell := strings.TrimSpace(part)
|
||||
if cell == "" {
|
||||
continue
|
||||
}
|
||||
if !matrixTableAlignCell.MatchString(cell) {
|
||||
return false
|
||||
}
|
||||
validCells++
|
||||
}
|
||||
return validCells >= 1
|
||||
}
|
||||
|
||||
func isMarkdownTableRow(line string) bool {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
return trimmed != "" && strings.Contains(trimmed, "|")
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
func TestNormalizeMatrixMarkdownTaskList(t *testing.T) {
|
||||
input := "- [ ] todo\n- [x] done"
|
||||
got := normalizeMatrixMarkdown(input)
|
||||
if got != "- ☐ todo\n- ☑ done" {
|
||||
t.Fatalf("unexpected normalized markdown: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMatrixMarkdownTablesBecomeCodeBlocks(t *testing.T) {
|
||||
input := "| A | B |\n| --- | --- |\n| 1 | 2 |"
|
||||
got := normalizeMatrixMarkdown(input)
|
||||
if !strings.HasPrefix(got, "```text\n") || !strings.Contains(got, "| 1 | 2 |") || !strings.HasSuffix(got, "\n```") {
|
||||
t.Fatalf("unexpected normalized markdown: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatMatrixMessageMarkdownUsesNormalizedBody(t *testing.T) {
|
||||
formatted := formatMatrixMessage(channel.Message{
|
||||
Text: "- [x] done\n\n| A |\n| --- |\n| 1 |",
|
||||
Format: channel.MessageFormatMarkdown,
|
||||
})
|
||||
if !strings.Contains(formatted.Body, "☑ done") {
|
||||
t.Fatalf("expected task list checkbox in body, got %q", formatted.Body)
|
||||
}
|
||||
if !strings.Contains(formatted.FormattedBody, "<pre><code") || !strings.Contains(formatted.FormattedBody, "| A |") {
|
||||
t.Fatalf("expected table fallback code block in formatted body, got %q", formatted.FormattedBody)
|
||||
}
|
||||
if strings.Contains(formatted.FormattedBody, "<table") {
|
||||
t.Fatalf("expected no html table in formatted body, got %q", formatted.FormattedBody)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,231 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
type matrixOutboundStream struct {
|
||||
adapter *MatrixAdapter
|
||||
cfg Config
|
||||
target string
|
||||
reply *channel.ReplyRef
|
||||
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
|
||||
roomID string
|
||||
originalEventID string
|
||||
rawBuffer strings.Builder
|
||||
lastText string
|
||||
lastFormat channel.MessageFormat
|
||||
lastEditedAt time.Time
|
||||
}
|
||||
|
||||
func (s *matrixOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error {
|
||||
if s == nil || s.adapter == nil {
|
||||
return errors.New("matrix stream not configured")
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return errors.New("matrix stream is closed")
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case channel.StreamEventStatus,
|
||||
channel.StreamEventPhaseStart,
|
||||
channel.StreamEventToolCallEnd,
|
||||
channel.StreamEventAgentStart,
|
||||
channel.StreamEventAgentEnd,
|
||||
channel.StreamEventProcessingStarted,
|
||||
channel.StreamEventProcessingCompleted,
|
||||
channel.StreamEventProcessingFailed:
|
||||
return nil
|
||||
case channel.StreamEventPhaseEnd:
|
||||
if event.Phase != channel.StreamPhaseText {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
text := strings.TrimSpace(s.rawBuffer.String())
|
||||
s.mu.Unlock()
|
||||
return s.upsertText(ctx, text, channel.MessageFormatPlain, true)
|
||||
case channel.StreamEventToolCallStart:
|
||||
s.resetMessageState()
|
||||
return nil
|
||||
case channel.StreamEventDelta:
|
||||
if event.Phase == channel.StreamPhaseReasoning || event.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.rawBuffer.WriteString(event.Delta)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
case channel.StreamEventError:
|
||||
errText := strings.TrimSpace(event.Error)
|
||||
if errText == "" {
|
||||
return nil
|
||||
}
|
||||
return s.upsertText(ctx, "Error: "+errText, channel.MessageFormatPlain, true)
|
||||
case channel.StreamEventAttachment:
|
||||
return s.pushAttachments(ctx, event.Attachments)
|
||||
case channel.StreamEventFinal:
|
||||
if event.Final == nil {
|
||||
return errors.New("matrix stream final payload is required")
|
||||
}
|
||||
text := strings.TrimSpace(event.Final.Message.PlainText())
|
||||
format := event.Final.Message.Format
|
||||
if format == "" {
|
||||
format = channel.MessageFormatPlain
|
||||
}
|
||||
if text == "" {
|
||||
s.mu.Lock()
|
||||
text = strings.TrimSpace(s.rawBuffer.String())
|
||||
s.mu.Unlock()
|
||||
}
|
||||
if err := s.upsertText(ctx, text, format, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.pushAttachments(ctx, event.Final.Message.Attachments); err != nil {
|
||||
return err
|
||||
}
|
||||
s.resetMessageState()
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *matrixOutboundStream) Close(ctx context.Context) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
s.closed.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *matrixOutboundStream) upsertText(ctx context.Context, text string, format channel.MessageFormat, force bool) error {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
if format == "" {
|
||||
format = channel.MessageFormatPlain
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
roomID := s.roomID
|
||||
originalEventID := s.originalEventID
|
||||
lastText := s.lastText
|
||||
lastFormat := s.lastFormat
|
||||
lastEditedAt := s.lastEditedAt
|
||||
reply := s.reply
|
||||
s.mu.Unlock()
|
||||
|
||||
if roomID == "" {
|
||||
resolvedRoomID, err := s.adapter.resolveRoomTarget(ctx, s.cfg, s.target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomID = resolvedRoomID
|
||||
s.mu.Lock()
|
||||
s.roomID = resolvedRoomID
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
if originalEventID == "" {
|
||||
eventID, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(channel.Message{
|
||||
Text: text,
|
||||
Format: format,
|
||||
Reply: reply,
|
||||
}, false, ""))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.originalEventID = eventID
|
||||
s.lastText = text
|
||||
s.lastFormat = format
|
||||
s.lastEditedAt = time.Now()
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if text == lastText && format == lastFormat {
|
||||
return nil
|
||||
}
|
||||
if !force && time.Since(lastEditedAt) < matrixEditThrottle {
|
||||
return nil
|
||||
}
|
||||
_, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(channel.Message{
|
||||
Text: text,
|
||||
Format: format,
|
||||
}, true, originalEventID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.lastText = text
|
||||
s.lastFormat = format
|
||||
s.lastEditedAt = time.Now()
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *matrixOutboundStream) resetMessageState() {
|
||||
s.mu.Lock()
|
||||
s.originalEventID = ""
|
||||
s.rawBuffer.Reset()
|
||||
s.lastText = ""
|
||||
s.lastFormat = ""
|
||||
s.lastEditedAt = time.Time{}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *matrixOutboundStream) pushAttachments(ctx context.Context, attachments []channel.Attachment) error {
|
||||
if len(attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
roomID := s.roomID
|
||||
originalEventID := s.originalEventID
|
||||
reply := s.reply
|
||||
s.mu.Unlock()
|
||||
|
||||
if roomID == "" {
|
||||
resolvedRoomID, err := s.adapter.resolveRoomTarget(ctx, s.cfg, s.target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomID = resolvedRoomID
|
||||
s.mu.Lock()
|
||||
s.roomID = resolvedRoomID
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
for idx, att := range attachments {
|
||||
mediaMsg := channel.Message{}
|
||||
if idx == 0 && originalEventID == "" {
|
||||
mediaMsg.Reply = reply
|
||||
}
|
||||
if err := s.adapter.sendMediaAttachment(ctx, s.cfg, roomID, "", mediaMsg, att); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestMatrixStreamDoesNotSendDeltaBeforeTextPhaseEnds(t *testing.T) {
|
||||
requests := 0
|
||||
adapter := NewMatrixAdapter(nil)
|
||||
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
requests++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})}
|
||||
|
||||
stream := &matrixOutboundStream{
|
||||
adapter: adapter,
|
||||
cfg: Config{
|
||||
HomeserverURL: "https://matrix.example.com",
|
||||
AccessToken: "tok",
|
||||
},
|
||||
target: "!room:example.com",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "draft", Phase: channel.StreamPhaseText}); err != nil {
|
||||
t.Fatalf("push delta: %v", err)
|
||||
}
|
||||
if requests != 0 {
|
||||
t.Fatalf("expected no request before text phase ends, got %d", requests)
|
||||
}
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseText}); err != nil {
|
||||
t.Fatalf("push phase end: %v", err)
|
||||
}
|
||||
if requests != 1 {
|
||||
t.Fatalf("expected one request after text phase end, got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixStreamDropsBufferedTextWhenToolStarts(t *testing.T) {
|
||||
requests := 0
|
||||
adapter := NewMatrixAdapter(nil)
|
||||
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
requests++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})}
|
||||
|
||||
stream := &matrixOutboundStream{
|
||||
adapter: adapter,
|
||||
cfg: Config{
|
||||
HomeserverURL: "https://matrix.example.com",
|
||||
AccessToken: "tok",
|
||||
},
|
||||
target: "!room:example.com",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "I will inspect first", Phase: channel.StreamPhaseText}); err != nil {
|
||||
t.Fatalf("push delta: %v", err)
|
||||
}
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventToolCallStart}); err != nil {
|
||||
t.Fatalf("push tool call start: %v", err)
|
||||
}
|
||||
if requests != 0 {
|
||||
t.Fatalf("expected no request for discarded pre-tool text, got %d", requests)
|
||||
}
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "Final answer", Phase: channel.StreamPhaseText}); err != nil {
|
||||
t.Fatalf("push final delta: %v", err)
|
||||
}
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{Text: "Final answer"}}}); err != nil {
|
||||
t.Fatalf("push final: %v", err)
|
||||
}
|
||||
if requests != 1 {
|
||||
t.Fatalf("expected only final visible message to be sent, got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixStreamFinalMarkdownUpdatesFormattedContent(t *testing.T) {
|
||||
bodies := make([]string, 0, 2)
|
||||
adapter := NewMatrixAdapter(nil)
|
||||
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
payload, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bodies = append(bodies, string(payload))
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})}
|
||||
|
||||
stream := &matrixOutboundStream{
|
||||
adapter: adapter,
|
||||
cfg: Config{
|
||||
HomeserverURL: "https://matrix.example.com",
|
||||
AccessToken: "tok",
|
||||
},
|
||||
target: "!room:example.com",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "**bold**", Phase: channel.StreamPhaseText}); err != nil {
|
||||
t.Fatalf("push delta: %v", err)
|
||||
}
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseText}); err != nil {
|
||||
t.Fatalf("push phase end: %v", err)
|
||||
}
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{Text: "**bold**", Format: channel.MessageFormatMarkdown}}}); err != nil {
|
||||
t.Fatalf("push final: %v", err)
|
||||
}
|
||||
if len(bodies) != 2 {
|
||||
t.Fatalf("expected two sends, got %d", len(bodies))
|
||||
}
|
||||
if strings.Contains(bodies[0], "formatted_body") {
|
||||
t.Fatalf("expected plain interim send, got %s", bodies[0])
|
||||
}
|
||||
if !strings.Contains(bodies[1], "formatted_body") || !strings.Contains(bodies[1], "org.matrix.custom.html") {
|
||||
t.Fatalf("expected markdown final edit, got %s", bodies[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixStreamFinalSendsAttachments(t *testing.T) {
|
||||
bodies := make([]string, 0, 2)
|
||||
adapter := NewMatrixAdapter(nil)
|
||||
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
payload, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bodies = append(bodies, string(payload))
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"event_id":"$evt1"}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})}
|
||||
|
||||
stream := &matrixOutboundStream{
|
||||
adapter: adapter,
|
||||
cfg: Config{
|
||||
HomeserverURL: "https://matrix.example.com",
|
||||
AccessToken: "tok",
|
||||
},
|
||||
target: "!room:example.com",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{
|
||||
Text: "done",
|
||||
Attachments: []channel.Attachment{{
|
||||
Type: channel.AttachmentImage,
|
||||
PlatformKey: "mxc://matrix.example.com/media123",
|
||||
Name: "image.png",
|
||||
SourcePlatform: Type.String(),
|
||||
}},
|
||||
}}}); err != nil {
|
||||
t.Fatalf("push final: %v", err)
|
||||
}
|
||||
if len(bodies) != 2 {
|
||||
t.Fatalf("expected text and attachment sends, got %d", len(bodies))
|
||||
}
|
||||
if !strings.Contains(bodies[0], `"msgtype":"m.notice"`) {
|
||||
t.Fatalf("expected first payload to be text, got %s", bodies[0])
|
||||
}
|
||||
if !strings.Contains(bodies[1], `"msgtype":"m.image"`) || !strings.Contains(bodies[1], `mxc://matrix.example.com/media123`) {
|
||||
t.Fatalf("expected second payload to be attachment, got %s", bodies[1])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
func TestDetermineConversationKindTreatsDirectAsDirect(t *testing.T) {
|
||||
if got := determineConversationKind("", "direct"); got != conversation.KindDirect {
|
||||
t.Fatalf("unexpected conversation kind: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -151,6 +151,28 @@ func (s *Store) UpdateConfigDisabled(ctx context.Context, botID string, channelT
|
||||
return normalizeChannelConfigFromRow(row)
|
||||
}
|
||||
|
||||
// SaveMatrixSyncSinceToken persists the Matrix /sync cursor without mutating channel config updated_at.
|
||||
func (s *Store) SaveMatrixSyncSinceToken(ctx context.Context, configID string, since string) error {
|
||||
if s.queries == nil {
|
||||
return errors.New("channel queries not configured")
|
||||
}
|
||||
pgConfigID, err := db.ParseUUID(configID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := s.queries.SaveMatrixSyncSinceToken(ctx, sqlc.SaveMatrixSyncSinceTokenParams{
|
||||
ID: pgConfigID,
|
||||
SinceToken: strings.TrimSpace(since),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("%w", ErrChannelConfigNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertChannelIdentityConfig creates or updates a channel identity's channel binding.
|
||||
func (s *Store) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) {
|
||||
if s.queries == nil {
|
||||
|
||||
@@ -49,7 +49,7 @@ const (
|
||||
// channel abstraction domain: private/group/thread.
|
||||
func NormalizeConversationType(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case ConversationTypePrivate:
|
||||
case "", "p2p", "direct", ConversationTypePrivate:
|
||||
return ConversationTypePrivate
|
||||
case ConversationTypeThread:
|
||||
return ConversationTypeThread
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package channel
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGenerateRoutingKeyTreatsDirectConversationAsSharedRoute(t *testing.T) {
|
||||
got := GenerateRoutingKey("matrix", "bot-1", "!room:example.com", "direct", "@alex:example.com")
|
||||
if got != "matrix:bot-1:!room:example.com" {
|
||||
t.Fatalf("unexpected routing key: %q", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user