Files
Memoh/internal/router/channel.go
T
BBQ 83b6ee608c refactor: bind container lifecycle to bot and improve schedule trigger flow
- Add SetupBotContainer to ContainerLifecycle interface so containers
  are automatically created when a bot is created, matching the existing
  cleanup-on-delete behavior.
- Refactor schedule tools to use bot-scoped API paths and pass identity
  context for proper authorization.
- Introduce dedicated trigger-schedule endpoint in chat resolver with
  explicit schedule payload instead of reusing the generic chat path.
- Generate short-lived JWT tokens for schedule trigger callbacks with
  resolved bot owner identity.
- Validate required parameters in NewLLMClient and NewOpenAIEmbedder
  constructors, returning errors instead of falling back to defaults.
- Add unit tests for schedule token generation and chat resolver.
2026-02-07 12:04:37 +08:00

538 lines
14 KiB
Go

package router
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"unicode"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/chat"
"github.com/memohai/memoh/internal/contacts"
)
// ChatGateway abstracts the chat capability so that the routing layer does
// not depend directly on a concrete implementation.
type ChatGateway interface {
// Chat submits one chat request and returns the resulting response, or an
// error if the request could not be served.
Chat(ctx context.Context, req chat.ChatRequest) (chat.ChatResponse, error)
}
// ContactService is the contact-store dependency that
// NewChannelInboundProcessor hands to the identity resolver. The processor
// itself never calls these methods directly in this file.
//
// NOTE(review): the per-method descriptions below are read off the method
// names and signatures only; confirm against the contacts package.
type ContactService interface {
// GetByID looks a contact up by its contact ID.
GetByID(ctx context.Context, contactID string) (contacts.Contact, error)
// GetByUserID looks a contact up by bot ID and user ID.
GetByUserID(ctx context.Context, botID, userID string) (contacts.Contact, error)
// GetByChannelIdentity looks a contact channel up by bot ID, platform and
// the platform-side external ID.
GetByChannelIdentity(ctx context.Context, botID, platform, externalID string) (contacts.ContactChannel, error)
// Create creates a contact from the given request.
Create(ctx context.Context, req contacts.CreateRequest) (contacts.Contact, error)
// CreateGuest creates a contact for a bot from only a display name.
CreateGuest(ctx context.Context, botID, displayName string) (contacts.Contact, error)
// UpsertChannel creates or updates the channel binding of a contact.
UpsertChannel(ctx context.Context, botID, contactID, platform, externalID string, metadata map[string]any) (contacts.ContactChannel, error)
}
const (
// silentReplyToken is the marker checked by isSilentReplyText: an assistant
// output that starts or ends with it (on a word boundary) is not sent.
silentReplyToken = "NO_REPLY"
// minDuplicateTextLength is the minimum length, in bytes, a normalized text
// must have before isMessagingToolDuplicate considers it for de-duplication.
minDuplicateTextLength = 10
)
var (
// whitespacePattern matches runs of whitespace; normalizeTextForComparison
// uses it to collapse each run into a single space.
whitespacePattern = regexp.MustCompile(`\s+`)
)
// ChannelInboundProcessor routes inbound channel messages to chat and returns
// the replies that can be sent back.
type ChannelInboundProcessor struct {
// chat is the gateway HandleInbound calls to obtain the assistant response.
chat ChatGateway
// registry is used for channel descriptor lookup and reply-target
// normalization; it may be nil.
registry *channel.Registry
// logger is tagged with component=channel_router by the constructor.
logger *slog.Logger
// jwtSecret signs the tokens issued per inbound message; when empty no
// tokens are issued.
jwtSecret string
// tokenTTL is the lifetime passed to the token generators (the constructor
// substitutes 5 minutes for non-positive values).
tokenTTL time.Duration
// identity resolves the identity state of an inbound message when it is not
// already present on the context.
identity *IdentityResolver
}
// NewChannelInboundProcessor assembles a ChannelInboundProcessor.
//
// A nil log falls back to slog.Default(), a non-positive tokenTTL falls back
// to five minutes, and jwtSecret is stored with surrounding whitespace
// removed. store, contactService, policyService and preauthService are only
// forwarded to the identity resolver built here.
func NewChannelInboundProcessor(log *slog.Logger, registry *channel.Registry, store channel.ConfigStore, chatGateway ChatGateway, contactService ContactService, policyService PolicyService, preauthService PreauthService, jwtSecret string, tokenTTL time.Duration) *ChannelInboundProcessor {
	if log == nil {
		log = slog.Default()
	}

	ttl := tokenTTL
	if ttl <= 0 {
		ttl = 5 * time.Minute
	}

	processor := &ChannelInboundProcessor{
		chat:      chatGateway,
		registry:  registry,
		jwtSecret: strings.TrimSpace(jwtSecret),
		tokenTTL:  ttl,
	}
	// The resolver receives the untagged logger; only the processor's own
	// logger carries the component attribute.
	processor.identity = NewIdentityResolver(log, registry, store, contactService, policyService, preauthService, "", "")
	processor.logger = log.With(slog.String("component", "channel_router"))
	return processor
}
// IdentityMiddleware exposes the middleware of the processor's identity
// resolver. It returns nil when the receiver or its resolver is nil.
func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware {
	if p != nil && p.identity != nil {
		return p.identity.Middleware()
	}
	return nil
}
// HandleInbound processes one inbound channel message end to end:
//
//  1. builds the query text (message text plus attachment labels) and drops
//     the message silently when it is blank;
//  2. obtains the identity state; if its decision says stop, sends the
//     decision's reply (when non-empty) and returns;
//  3. issues a session token and a bearer token when a JWT secret is set
//     (failures are logged and the token is left empty);
//  4. calls the chat gateway and relays each assistant output to sender,
//     skipping empty outputs, silent replies and texts the send_message tool
//     already delivered.
//
// It returns an error when the processor or sender is not configured, when
// identity resolution, the chat gateway or a send fails, or when there are
// outputs to send but the message has no reply target.
func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.ReplySender) error {
	switch {
	case p.chat == nil:
		return fmt.Errorf("channel inbound processor not configured")
	case sender == nil:
		return fmt.Errorf("reply sender not configured")
	}

	query := buildInboundQuery(msg.Message)
	if strings.TrimSpace(query) == "" {
		return nil
	}

	state, err := p.requireIdentity(ctx, cfg, msg)
	if err != nil {
		return err
	}

	replyTarget := strings.TrimSpace(msg.ReplyTarget)
	if decision := state.Decision; decision != nil && decision.Stop {
		if decision.Reply.IsEmpty() {
			return nil
		}
		return sender.Send(ctx, channel.OutboundMessage{
			Target:  replyTarget,
			Message: decision.Reply,
		})
	}

	identity := state.Identity
	channelName := msg.Channel.String()
	hasSecret := p.jwtSecret != ""

	// Token issuance is best effort: a failure is logged and the request
	// continues with an empty token.
	var sessionToken string
	if hasSecret && replyTarget != "" {
		signed, _, tokenErr := auth.GenerateSessionToken(auth.SessionToken{
			BotID:       identity.BotID,
			Platform:    channelName,
			ReplyTarget: replyTarget,
			SessionID:   identity.SessionID,
			ContactID:   identity.ContactID,
		}, p.jwtSecret, p.tokenTTL)
		switch {
		case tokenErr == nil:
			sessionToken = signed
		case p.logger != nil:
			p.logger.Warn("issue session token failed", slog.Any("error", tokenErr))
		}
	}

	var bearer string
	if identity.UserID != "" && hasSecret {
		signed, _, tokenErr := auth.GenerateToken(identity.UserID, p.jwtSecret, p.tokenTTL)
		switch {
		case tokenErr == nil:
			bearer = "Bearer " + signed
		case p.logger != nil:
			p.logger.Warn("issue channel token failed", slog.Any("error", tokenErr))
		}
	}

	var desc channel.Descriptor
	if p.registry != nil {
		desc, _ = p.registry.GetDescriptor(msg.Channel)
	}

	resp, chatErr := p.chat.Chat(ctx, chat.ChatRequest{
		BotID:          identity.BotID,
		SessionID:      identity.SessionID,
		Token:          bearer,
		UserID:         identity.UserID,
		ContactID:      identity.ContactID,
		ContactName:    strings.TrimSpace(identity.Contact.DisplayName),
		ContactAlias:   strings.TrimSpace(identity.Contact.Alias),
		ReplyTarget:    replyTarget,
		SessionToken:   sessionToken,
		Query:          query,
		CurrentChannel: channelName,
		Channels:       []string{channelName},
	})
	if chatErr != nil {
		if p.logger != nil {
			p.logger.Error("chat gateway failed", slog.String("channel", channelName), slog.String("user_id", identity.UserID), slog.Any("error", chatErr))
		}
		return chatErr
	}

	outputs := chat.ExtractAssistantOutputs(resp.Messages)
	if len(outputs) == 0 {
		return nil
	}
	// A reply target is only mandatory once there is something to send.
	if replyTarget == "" {
		return fmt.Errorf("reply target missing")
	}

	alreadySent, suppress := collectMessageToolContext(p.registry, resp.Messages, msg.Channel, replyTarget)
	if suppress {
		return nil
	}

	for i := range outputs {
		reply := buildChannelMessage(outputs[i], desc.Capabilities)
		if reply.IsEmpty() {
			continue
		}
		plain := strings.TrimSpace(reply.PlainText())
		if isSilentReplyText(plain) || isMessagingToolDuplicate(plain, alreadySent) {
			continue
		}
		sendErr := sender.Send(ctx, channel.OutboundMessage{
			Target:  replyTarget,
			Message: reply,
		})
		if sendErr != nil {
			return sendErr
		}
	}
	return nil
}
// buildChannelMessage converts one assistant output into a channel message
// shaped by the channel's capabilities.
//
// The trimmed Content becomes the text, tagged as Markdown when it looks like
// Markdown and the channel supports Markdown or rich text. When the output
// has parts, a rich-text channel receives the non-empty parts as message
// parts (format rich); any other channel receives the parts flattened into
// newline-joined text, which replaces the Content-derived text.
func buildChannelMessage(output chat.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message {
	var msg channel.Message

	if content := strings.TrimSpace(output.Content); content != "" {
		msg.Text = content
		if containsMarkdown(content) && (capabilities.Markdown || capabilities.RichText) {
			msg.Format = channel.MessageFormatMarkdown
		}
	}
	if len(output.Parts) == 0 {
		return msg
	}

	if !capabilities.RichText {
		// Plain channels: flatten every non-empty part to a line of text.
		flattened := make([]string, 0, len(output.Parts))
		for _, part := range output.Parts {
			if contentPartHasValue(part) {
				flattened = append(flattened, strings.TrimSpace(contentPartText(part)))
			}
		}
		if len(flattened) == 0 {
			return msg
		}
		msg.Text = strings.Join(flattened, "\n")
		// Keep a format already chosen from Content; only fill in when unset.
		if msg.Format == "" && containsMarkdown(msg.Text) && (capabilities.Markdown || capabilities.RichText) {
			msg.Format = channel.MessageFormatMarkdown
		}
		return msg
	}

	rich := make([]channel.MessagePart, 0, len(output.Parts))
	for _, part := range output.Parts {
		if contentPartHasValue(part) {
			rich = append(rich, channel.MessagePart{
				Type:     normalizeContentPartType(part.Type),
				Text:     part.Text,
				URL:      part.URL,
				Styles:   normalizeContentPartStyles(part.Styles),
				Language: part.Language,
				UserID:   part.UserID,
				Emoji:    part.Emoji,
			})
		}
	}
	if len(rich) > 0 {
		msg.Parts = rich
		msg.Format = channel.MessageFormatRich
	}
	return msg
}
// markdownPatterns holds the heuristics containsMarkdown uses to recognise
// common Markdown constructs. They are compiled once at package
// initialisation instead of on every call.
var markdownPatterns = []*regexp.Regexp{
	regexp.MustCompile(`\*\*[^*]+\*\*`),    // **bold**
	regexp.MustCompile(`\*[^*]+\*`),        // *italic*
	regexp.MustCompile(`~~[^~]+~~`),        // ~~strikethrough~~
	regexp.MustCompile("`[^`]+`"),          // `inline code`
	regexp.MustCompile("```[\\s\\S]*```"),  // fenced code block
	regexp.MustCompile(`\[.+\]\(.+\)`),     // [label](url)
	regexp.MustCompile(`(?m)^#{1,6}\s`),    // # heading
	regexp.MustCompile(`(?m)^[-*]\s`),      // - bullet / * bullet
	regexp.MustCompile(`(?m)^\d+\.\s`),     // 1. ordered item
}

// containsMarkdown reports whether text appears to contain Markdown syntax
// (emphasis, strikethrough, code, links, headings or list items). Blank text
// is never Markdown.
//
// The patterns were previously written with doubled backslashes inside raw
// string literals, which made e.g. `\\*` mean "zero or more backslashes"
// rather than a literal asterisk, so nearly any text was reported as
// Markdown. They are now escaped once, as the regexp syntax requires.
func containsMarkdown(text string) bool {
	if strings.TrimSpace(text) == "" {
		return false
	}
	for _, pattern := range markdownPatterns {
		if pattern.MatchString(text) {
			return true
		}
	}
	return false
}
// contentPartHasValue reports whether the part carries any non-blank payload
// in its Text, URL or Emoji field.
func contentPartHasValue(part chat.ContentPart) bool {
	for _, field := range [...]string{part.Text, part.URL, part.Emoji} {
		if strings.TrimSpace(field) != "" {
			return true
		}
	}
	return false
}
// contentPartText returns the textual representation of a part: the first of
// Text, URL and Emoji that is not blank, returned untrimmed. It returns ""
// when all three are blank.
func contentPartText(part chat.ContentPart) string {
	for _, candidate := range [...]string{part.Text, part.URL, part.Emoji} {
		if strings.TrimSpace(candidate) != "" {
			return candidate
		}
	}
	return ""
}
// buildInboundQuery renders an inbound message as the query text sent to
// chat: the trimmed plain text followed by one "[attachment:<type>] <label>"
// line per attachment. The label is the attachment name, falling back to its
// URL and then to "unknown".
func buildInboundQuery(message channel.Message) string {
	body := strings.TrimSpace(message.PlainText())
	if len(message.Attachments) == 0 {
		return body
	}

	var query strings.Builder
	query.WriteString(body)
	for _, att := range message.Attachments {
		label := strings.TrimSpace(att.Name)
		if label == "" {
			label = strings.TrimSpace(att.URL)
		}
		if label == "" {
			label = "unknown"
		}
		// Separate lines with "\n", without a leading newline when the
		// message had no text of its own.
		if query.Len() > 0 {
			query.WriteString("\n")
		}
		fmt.Fprintf(&query, "[attachment:%s] %s", att.Type, label)
	}
	return query.String()
}
// normalizeContentPartType maps a raw part type name (case-insensitive,
// surrounding whitespace ignored) to a channel message part type. Anything
// other than "link", "code_block", "mention" or "emoji" is treated as text.
func normalizeContentPartType(raw string) channel.MessagePartType {
	kind := strings.TrimSpace(strings.ToLower(raw))
	if kind == "link" {
		return channel.MessagePartLink
	}
	if kind == "code_block" {
		return channel.MessagePartCodeBlock
	}
	if kind == "mention" {
		return channel.MessagePartMention
	}
	if kind == "emoji" {
		return channel.MessagePartEmoji
	}
	return channel.MessagePartText
}
// normalizeContentPartStyles maps raw style names (case-insensitive,
// surrounding whitespace ignored) to channel text styles. Unknown names are
// dropped; nil is returned when no style survives.
//
// Recognised names: "bold", "italic", "strikethrough" / "lineThrough" and
// "code".
func normalizeContentPartStyles(styles []string) []channel.MessageTextStyle {
	if len(styles) == 0 {
		return nil
	}
	result := make([]channel.MessageTextStyle, 0, len(styles))
	for _, style := range styles {
		switch strings.TrimSpace(strings.ToLower(style)) {
		case "bold":
			result = append(result, channel.MessageStyleBold)
		case "italic":
			result = append(result, channel.MessageStyleItalic)
		// The value is lower-cased above, so the camelCase spelling
		// "lineThrough" must be matched as "linethrough"; the previous
		// mixed-case label could never match.
		case "strikethrough", "linethrough":
			result = append(result, channel.MessageStyleStrikethrough)
		case "code":
			result = append(result, channel.MessageStyleCode)
		}
	}
	if len(result) == 0 {
		return nil
	}
	return result
}
// sendMessageToolArgs is the JSON shape decoded from the arguments of a
// "send_message" tool call (see collectMessageToolContext).
type sendMessageToolArgs struct {
// Platform is the destination platform; when blank the current channel is
// assumed by shouldSuppressForToolCall.
Platform string `json:"platform"`
// Target is the explicit destination the tool sent to.
Target string `json:"target"`
// UserID is an alternative addressee; when both Target and UserID are blank
// the current reply target is assumed.
UserID string `json:"user_id"`
// Text is the plain text sent; it takes precedence over Message.
Text string `json:"text"`
// Message is the structured message sent; may be nil.
Message *channel.Message `json:"message"`
}
// collectMessageToolContext scans the model messages for "send_message" tool
// calls whose arguments can be decoded. It returns the non-blank texts those
// calls sent, and whether any of them addressed the current channel and
// reply target — in which case the caller should not send further replies.
func collectMessageToolContext(registry *channel.Registry, messages []chat.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) {
	var (
		delivered []string
		suppress  bool
	)
	for i := range messages {
		for _, call := range messages[i].ToolCalls {
			if call.Function.Name != "send_message" {
				continue
			}
			var parsed sendMessageToolArgs
			if ok := parseToolArguments(call.Function.Arguments, &parsed); !ok {
				continue
			}
			body := strings.TrimSpace(extractSendMessageText(parsed))
			if body != "" {
				delivered = append(delivered, body)
			}
			if shouldSuppressForToolCall(registry, parsed, channelType, replyTarget) {
				suppress = true
			}
		}
	}
	return delivered, suppress
}
// parseToolArguments decodes tool-call arguments into out and reports
// whether decoding succeeded. raw may be a JSON document, or a JSON string
// that itself contains the JSON document (double-encoded arguments). Blank
// input, or a blank inner string, yields false.
func parseToolArguments(raw string, out any) bool {
	if len(strings.TrimSpace(raw)) == 0 {
		return false
	}

	payload := []byte(raw)
	if json.Unmarshal(payload, out) == nil {
		return true
	}

	// Direct decoding failed: try treating raw as a JSON-encoded string
	// wrapping the real document.
	var inner string
	if json.Unmarshal(payload, &inner) != nil || strings.TrimSpace(inner) == "" {
		return false
	}
	return json.Unmarshal([]byte(inner), out) == nil
}
// extractSendMessageText returns the trimmed text a send_message call
// delivered: the Text argument when non-blank, otherwise the plain text of
// the Message argument, otherwise "".
func extractSendMessageText(args sendMessageToolArgs) string {
	if direct := strings.TrimSpace(args.Text); direct != "" {
		return direct
	}
	if args.Message != nil {
		return strings.TrimSpace(args.Message.PlainText())
	}
	return ""
}
// shouldSuppressForToolCall reports whether a send_message call was addressed
// to the very conversation being replied to, i.e. the same platform
// (case-insensitive; blank means the current channel) and the same
// normalized target. A call with neither Target nor UserID is taken to
// address the current reply target.
func shouldSuppressForToolCall(registry *channel.Registry, args sendMessageToolArgs, channelType channel.ChannelType, replyTarget string) bool {
	current := string(channelType)
	if platform := strings.TrimSpace(args.Platform); platform != "" && !strings.EqualFold(platform, current) {
		return false
	}

	destination := strings.TrimSpace(args.Target)
	if destination == "" && strings.TrimSpace(args.UserID) == "" {
		destination = replyTarget
	}
	if strings.TrimSpace(destination) == "" || strings.TrimSpace(replyTarget) == "" {
		return false
	}

	sentTo := normalizeReplyTarget(registry, channelType, destination)
	replyTo := normalizeReplyTarget(registry, channelType, replyTarget)
	// Equality with a non-empty sentTo implies replyTo is non-empty as well.
	return sentTo != "" && sentTo == replyTo
}
// normalizeReplyTarget returns the registry's normalized form of target for
// the given channel type, trimmed. When there is no registry, or it cannot
// normalize the target or yields a blank result, the trimmed input is
// returned instead.
func normalizeReplyTarget(registry *channel.Registry, channelType channel.ChannelType, target string) string {
	fallback := strings.TrimSpace(target)
	if registry == nil {
		return fallback
	}
	if normalized, ok := registry.NormalizeTarget(channelType, target); ok {
		if cleaned := strings.TrimSpace(normalized); cleaned != "" {
			return cleaned
		}
	}
	return fallback
}
// isSilentReplyText reports whether text, after trimming, begins or ends with
// silentReplyToken as a whole word. Matching is done on runes and is
// case-sensitive.
func isSilentReplyText(text string) bool {
	candidate := []rune(strings.TrimSpace(text))
	if len(candidate) == 0 {
		return false
	}
	marker := []rune(silentReplyToken)
	if len(candidate) < len(marker) {
		return false
	}
	return hasTokenPrefix(candidate, marker) || hasTokenSuffix(candidate, marker)
}
// hasTokenPrefix reports whether value starts with token and the token is
// not immediately followed by a word character (so it stands as a whole
// word at the start of value).
func hasTokenPrefix(value []rune, token []rune) bool {
	n := len(token)
	if len(value) < n {
		return false
	}
	for i, r := range token {
		if value[i] != r {
			return false
		}
	}
	// Either value is exactly the token, or the next rune ends the word.
	return len(value) == n || !isWordChar(value[n])
}
// hasTokenSuffix reports whether value ends with token and the token is not
// immediately preceded by a word character (so it stands as a whole word at
// the end of value).
func hasTokenSuffix(value []rune, token []rune) bool {
	offset := len(value) - len(token)
	if offset < 0 {
		return false
	}
	for i, r := range token {
		if value[offset+i] != r {
			return false
		}
	}
	// Either value is exactly the token, or the preceding rune ends a word.
	return offset == 0 || !isWordChar(value[offset-1])
}
// isWordChar reports whether value can be part of a word: an underscore, a
// Unicode letter or a Unicode digit.
func isWordChar(value rune) bool {
	switch {
	case value == '_':
		return true
	case unicode.IsLetter(value):
		return true
	default:
		return unicode.IsDigit(value)
	}
}
// normalizeTextForComparison canonicalises text for duplicate detection: it
// lower-cases the text, trims it, and collapses every run of whitespace into
// a single space. Blank input yields "".
func normalizeTextForComparison(text string) string {
	lowered := strings.TrimSpace(strings.ToLower(text))
	// An empty string passes through the replacement and trim unchanged, so
	// no separate blank check is needed.
	collapsed := whitespacePattern.ReplaceAllString(lowered, " ")
	return strings.TrimSpace(collapsed)
}
// isMessagingToolDuplicate reports whether text repeats something already
// delivered through the messaging tool. Texts are compared in normalized
// form, and a match means one contains the other. Texts whose normalized
// form is shorter than minDuplicateTextLength bytes are never treated as
// duplicates, on either side of the comparison.
func isMessagingToolDuplicate(text string, sentTexts []string) bool {
	if len(sentTexts) == 0 {
		return false
	}
	candidate := normalizeTextForComparison(text)
	if len(candidate) < minDuplicateTextLength {
		return false
	}
	for i := range sentTexts {
		prior := normalizeTextForComparison(sentTexts[i])
		if len(prior) >= minDuplicateTextLength &&
			(strings.Contains(candidate, prior) || strings.Contains(prior, candidate)) {
			return true
		}
	}
	return false
}
// requireIdentity returns the identity state for an inbound message. A state
// already stored on ctx is used as is; otherwise the processor's identity
// resolver is asked to resolve it. It fails when no state is on ctx and no
// resolver is configured.
func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) {
	state, found := IdentityStateFromContext(ctx)
	switch {
	case found:
		return state, nil
	case p.identity == nil:
		return IdentityState{}, fmt.Errorf("identity resolver not configured")
	default:
		return p.identity.Resolve(ctx, cfg, msg)
	}
}