mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(channel): consistent markdown rendering across all Telegram paths (#210)
- Extract ContainsMarkdown to shared channel package - Auto-detect markdown in normalizeOutboundMessage and MCP send tool - Apply markdown-to-HTML conversion during streaming deltas, not just on the final message - Remove resolveTelegramParseMode which incorrectly returned Telegram's native "Markdown" mode instead of converting to HTML - Fix all 14 Telegram send/edit paths for consistent parse mode handling - Reset parseMode for plain-text error messages to avoid HTML corruption
This commit is contained in:
@@ -231,7 +231,7 @@ func (s *telegramOutboundStream) sendDraft(ctx context.Context, text string) err
|
||||
return err
|
||||
}
|
||||
|
||||
draftErr := sendTelegramDraft(bot, s.streamChatID, s.draftID, text, "")
|
||||
draftErr := sendTelegramDraft(bot, s.streamChatID, s.draftID, text, s.parseMode)
|
||||
if draftErr != nil {
|
||||
if isTelegramTooManyRequests(draftErr) {
|
||||
d := getTelegramRetryAfter(draftErr)
|
||||
@@ -295,10 +295,13 @@ func (s *telegramOutboundStream) pushToolCallStart(ctx context.Context) error {
|
||||
bufText := strings.TrimSpace(s.buf.String())
|
||||
hasMsg := s.streamMsgID != 0
|
||||
s.mu.Unlock()
|
||||
if bufText != "" {
|
||||
bufText = s.formatStreamContent(bufText)
|
||||
}
|
||||
if s.isPrivateChat {
|
||||
// In draft mode, send buffered text as a permanent message before tool execution.
|
||||
if bufText != "" {
|
||||
if err := s.sendPermanentMessage(ctx, bufText, ""); err != nil {
|
||||
if err := s.sendPermanentMessage(ctx, bufText, s.parseMode); err != nil {
|
||||
if s.adapter != nil && s.adapter.logger != nil {
|
||||
s.adapter.logger.Warn("telegram: draft permanent message failed", slog.Any("error", err))
|
||||
}
|
||||
@@ -346,6 +349,7 @@ func (s *telegramOutboundStream) pushPhaseEnd(ctx context.Context, event channel
|
||||
finalText := strings.TrimSpace(s.buf.String())
|
||||
s.mu.Unlock()
|
||||
if finalText != "" {
|
||||
finalText = s.formatStreamContent(finalText)
|
||||
if err := s.ensureStreamMessage(ctx, finalText); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -362,6 +366,7 @@ func (s *telegramOutboundStream) pushDelta(ctx context.Context, event channel.St
|
||||
s.buf.WriteString(event.Delta)
|
||||
content := s.buf.String()
|
||||
s.mu.Unlock()
|
||||
content = s.formatStreamContent(content)
|
||||
if s.isPrivateChat {
|
||||
return s.sendDraft(ctx, content)
|
||||
}
|
||||
@@ -384,7 +389,8 @@ func (s *telegramOutboundStream) pushFinal(ctx context.Context, event channel.St
|
||||
|
||||
if event.Final == nil || event.Final.Message.IsEmpty() {
|
||||
if bufText != "" {
|
||||
if err := s.deliverFinalText(ctx, bufText, ""); err != nil {
|
||||
bufText = s.formatStreamContent(bufText)
|
||||
if err := s.deliverFinalText(ctx, bufText, s.parseMode); err != nil {
|
||||
if s.adapter != nil && s.adapter.logger != nil {
|
||||
s.adapter.logger.Warn("telegram: deliver buffered final text failed", slog.Any("error", err))
|
||||
}
|
||||
@@ -418,7 +424,7 @@ func (s *telegramOutboundStream) pushFinal(ctx context.Context, event channel.St
|
||||
return err
|
||||
}
|
||||
replyTo := parseReplyToMessageID(s.reply)
|
||||
parseMode := resolveTelegramParseMode(msg.Format)
|
||||
parseMode := s.parseMode
|
||||
for i, att := range msg.Attachments {
|
||||
to := replyTo
|
||||
if i > 0 {
|
||||
@@ -438,6 +444,11 @@ func (s *telegramOutboundStream) pushError(ctx context.Context, event channel.St
|
||||
return nil
|
||||
}
|
||||
display := "Error: " + errText
|
||||
// Error messages are plain text; reset parseMode so HTML-mode
|
||||
// left over from earlier deltas does not corrupt the output.
|
||||
s.mu.Lock()
|
||||
s.parseMode = ""
|
||||
s.mu.Unlock()
|
||||
if s.isPrivateChat {
|
||||
return s.sendPermanentMessage(ctx, display, "")
|
||||
}
|
||||
@@ -480,6 +491,22 @@ func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamE
|
||||
}
|
||||
}
|
||||
|
||||
// formatStreamContent applies markdown-to-HTML conversion for the accumulated
|
||||
// stream buffer text and updates parseMode accordingly. Safe for incomplete
|
||||
// markdown — unclosed constructs are left as plain text.
|
||||
func (s *telegramOutboundStream) formatStreamContent(text string) string {
|
||||
if channel.ContainsMarkdown(text) {
|
||||
formatted, pm := formatTelegramOutput(text, channel.MessageFormatMarkdown)
|
||||
if pm != "" {
|
||||
s.mu.Lock()
|
||||
s.parseMode = pm
|
||||
s.mu.Unlock()
|
||||
return formatted
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func (s *telegramOutboundStream) Close(ctx context.Context) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
|
||||
@@ -1200,15 +1200,6 @@ func buildTelegramAnimation(target string, file tgbotapi.RequestFileData) (tgbot
|
||||
return animation, nil
|
||||
}
|
||||
|
||||
func resolveTelegramParseMode(format channel.MessageFormat) string {
|
||||
switch format {
|
||||
case channel.MessageFormatMarkdown:
|
||||
return tgbotapi.ModeMarkdown
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// extractTelegramMentionParts extracts structured mention parts from Telegram message entities.
|
||||
func extractTelegramMentionParts(msg *tgbotapi.Message) []channel.MessagePart {
|
||||
if msg == nil {
|
||||
|
||||
@@ -143,20 +143,6 @@ func TestParseReplyToMessageID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTelegramParseMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := resolveTelegramParseMode(channel.MessageFormatMarkdown); got != tgbotapi.ModeMarkdown {
|
||||
t.Fatalf("markdown should return ModeMarkdown: %s", got)
|
||||
}
|
||||
if got := resolveTelegramParseMode(channel.MessageFormatPlain); got != "" {
|
||||
t.Fatalf("plain should return empty: %s", got)
|
||||
}
|
||||
if got := resolveTelegramParseMode(channel.MessageFormatRich); got != "" {
|
||||
t.Fatalf("rich should return empty: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTelegramReplyRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ContainsMarkdown returns true if the text contains common Markdown constructs.
|
||||
func ContainsMarkdown(text string) bool {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return false
|
||||
}
|
||||
patterns := []string{
|
||||
`\*\*[^*]+\*\*`,
|
||||
`\*[^*]+\*`,
|
||||
`~~[^~]+~~`,
|
||||
"`[^`]+`",
|
||||
"```[\\s\\S]*```",
|
||||
`\[.+\]\(.+\)`,
|
||||
`(?m)^#{1,6}\s`,
|
||||
`(?m)^[-*]\s`,
|
||||
`(?m)^\d+\.\s`,
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if matched, _ := regexp.MatchString(pattern, text); matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package channel
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestContainsMarkdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
want bool
|
||||
}{
|
||||
{"empty", "", false},
|
||||
{"plain", "hello world", false},
|
||||
{"bold", "this is **bold** text", true},
|
||||
{"italic", "this is *italic* text", true},
|
||||
{"code", "use `fmt.Println`", true},
|
||||
{"fenced_code", "```go\nfmt.Println()\n```", true},
|
||||
{"heading", "# Title", true},
|
||||
{"link", "[click](https://example.com)", true},
|
||||
{"unordered_list", "- item one", true},
|
||||
{"ordered_list", "1. first item", true},
|
||||
{"strikethrough", "this is ~~deleted~~ text", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := ContainsMarkdown(tt.text)
|
||||
if got != tt.want {
|
||||
t.Errorf("ContainsMarkdown(%q) = %v, want %v", tt.text, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -792,7 +792,7 @@ func buildChannelMessage(output conversation.AssistantOutput, capabilities chann
|
||||
msg := channel.Message{}
|
||||
if strings.TrimSpace(output.Content) != "" {
|
||||
msg.Text = strings.TrimSpace(output.Content)
|
||||
if containsMarkdown(msg.Text) && (capabilities.Markdown || capabilities.RichText) {
|
||||
if channel.ContainsMarkdown(msg.Text) && (capabilities.Markdown || capabilities.RichText) {
|
||||
msg.Format = channel.MessageFormatMarkdown
|
||||
}
|
||||
}
|
||||
@@ -831,35 +831,13 @@ func buildChannelMessage(output conversation.AssistantOutput, capabilities chann
|
||||
}
|
||||
if len(textParts) > 0 {
|
||||
msg.Text = strings.Join(textParts, "\n")
|
||||
if msg.Format == "" && containsMarkdown(msg.Text) && (capabilities.Markdown || capabilities.RichText) {
|
||||
if msg.Format == "" && channel.ContainsMarkdown(msg.Text) && (capabilities.Markdown || capabilities.RichText) {
|
||||
msg.Format = channel.MessageFormatMarkdown
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func containsMarkdown(text string) bool {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return false
|
||||
}
|
||||
patterns := []string{
|
||||
`\\*\\*[^*]+\\*\\*`,
|
||||
`\\*[^*]+\\*`,
|
||||
`~~[^~]+~~`,
|
||||
"`[^`]+`",
|
||||
"```[\\s\\S]*```",
|
||||
`\\[.+\\]\\(.+\\)`,
|
||||
`(?m)^#{1,6}\\s`,
|
||||
`(?m)^[-*]\\s`,
|
||||
`(?m)^\\d+\\.\\s`,
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if matched, _ := regexp.MatchString(pattern, text); matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func contentPartHasValue(part conversation.ContentPart) bool {
|
||||
if strings.TrimSpace(part.Text) != "" {
|
||||
|
||||
@@ -287,7 +287,11 @@ func normalizeOutboundMessage(msg Message) Message {
|
||||
if len(msg.Parts) > 0 {
|
||||
msg.Format = MessageFormatRich
|
||||
} else if strings.TrimSpace(msg.Text) != "" {
|
||||
msg.Format = MessageFormatPlain
|
||||
if ContainsMarkdown(msg.Text) {
|
||||
msg.Format = MessageFormatMarkdown
|
||||
} else {
|
||||
msg.Format = MessageFormatPlain
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg
|
||||
|
||||
@@ -1038,3 +1038,35 @@ func TestIsNaturalBreakPoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOutboundMessage_MarkdownDetected(t *testing.T) {
|
||||
t.Parallel()
|
||||
msg := normalizeOutboundMessage(Message{Text: "Hello **world**"})
|
||||
if msg.Format != MessageFormatMarkdown {
|
||||
t.Errorf("expected %q, got %q", MessageFormatMarkdown, msg.Format)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOutboundMessage_PlainText(t *testing.T) {
|
||||
t.Parallel()
|
||||
msg := normalizeOutboundMessage(Message{Text: "Hello world"})
|
||||
if msg.Format != MessageFormatPlain {
|
||||
t.Errorf("expected %q, got %q", MessageFormatPlain, msg.Format)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOutboundMessage_ExplicitFormatPreserved(t *testing.T) {
|
||||
t.Parallel()
|
||||
msg := normalizeOutboundMessage(Message{Text: "Hello **world**", Format: MessageFormatPlain})
|
||||
if msg.Format != MessageFormatPlain {
|
||||
t.Errorf("expected explicit format %q preserved, got %q", MessageFormatPlain, msg.Format)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOutboundMessage_RichParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
msg := normalizeOutboundMessage(Message{Parts: []MessagePart{{Type: "text", Text: "hello"}}})
|
||||
if msg.Format != MessageFormatRich {
|
||||
t.Errorf("expected %q, got %q", MessageFormatRich, msg.Format)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,6 +214,11 @@ func (p *Executor) callSend(ctx context.Context, session mcpgw.ToolSessionContex
|
||||
outboundMessage.Reply = &channel.ReplyRef{MessageID: replyTo}
|
||||
}
|
||||
|
||||
// Auto-detect markdown when format is not explicitly set.
|
||||
if outboundMessage.Format == "" && channel.ContainsMarkdown(outboundMessage.Text) {
|
||||
outboundMessage.Format = channel.MessageFormatMarkdown
|
||||
}
|
||||
|
||||
target := mcpgw.FirstStringArg(arguments, "target")
|
||||
if target == "" {
|
||||
target = strings.TrimSpace(session.ReplyTarget)
|
||||
|
||||
Reference in New Issue
Block a user