Files
Memoh/internal/channel/route/service.go
T
2026-03-14 17:15:41 +08:00

380 lines
12 KiB
Go

package route
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/conversation"
dbpkg "github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
// ConversationService contains the minimal conversation behavior required by route resolution.
type ConversationService interface {
Create(ctx context.Context, botID, channelIdentityID string, req conversation.CreateRequest) (conversation.Conversation, error)
}
// DBService manages channel routes and route-to-conversation resolution.
type DBService struct {
queries *sqlc.Queries
conversation ConversationService
logger *slog.Logger
}
// NewService creates a channel route service.
func NewService(log *slog.Logger, queries *sqlc.Queries, conversationService ConversationService) *DBService {
if log == nil {
log = slog.Default()
}
return &DBService{
queries: queries,
conversation: conversationService,
logger: log.With(slog.String("service", "channel/route")),
}
}
// Create creates a route.
func (s *DBService) Create(ctx context.Context, input CreateInput) (Route, error) {
pgConversationID, err := dbpkg.ParseUUID(input.ChatID)
if err != nil {
return Route{}, err
}
pgBotID, err := dbpkg.ParseUUID(input.BotID)
if err != nil {
return Route{}, err
}
var pgConfigID pgtype.UUID
if strings.TrimSpace(input.ChannelConfigID) != "" {
pgConfigID, err = dbpkg.ParseUUID(input.ChannelConfigID)
if err != nil {
return Route{}, err
}
}
metadata, err := json.Marshal(nonNilMap(input.Metadata))
if err != nil {
return Route{}, fmt.Errorf("marshal route metadata: %w", err)
}
row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{
ChatID: pgConversationID,
BotID: pgBotID,
Platform: input.Platform,
ChannelConfigID: pgConfigID,
ConversationID: input.ConversationID,
ThreadID: toPgText(input.ThreadID),
ConversationType: toPgText(input.ConversationType),
ReplyTarget: toPgText(input.ReplyTarget),
Metadata: metadata,
})
if err != nil {
return Route{}, fmt.Errorf("create route: %w", err)
}
return toRouteFromCreate(row), nil
}
// Find finds a route by bot/platform/external-conversation/thread.
func (s *DBService) Find(ctx context.Context, botID, platform, conversationID, threadID string) (Route, error) {
pgBotID, err := dbpkg.ParseUUID(botID)
if err != nil {
return Route{}, err
}
row, err := s.queries.FindChatRoute(ctx, sqlc.FindChatRouteParams{
BotID: pgBotID,
Platform: platform,
ConversationID: conversationID,
ThreadID: toPgText(threadID),
})
if err != nil {
return Route{}, err
}
return toRouteFromFind(row), nil
}
// GetByID gets a route by ID.
func (s *DBService) GetByID(ctx context.Context, routeID string) (Route, error) {
pgID, err := dbpkg.ParseUUID(routeID)
if err != nil {
return Route{}, err
}
row, err := s.queries.GetChatRouteByID(ctx, pgID)
if err != nil {
return Route{}, err
}
return toRouteFromGet(row), nil
}
// List lists all routes for a conversation.
func (s *DBService) List(ctx context.Context, conversationID string) ([]Route, error) {
pgID, err := dbpkg.ParseUUID(conversationID)
if err != nil {
return nil, err
}
rows, err := s.queries.ListChatRoutes(ctx, pgID)
if err != nil {
return nil, err
}
routes := make([]Route, 0, len(rows))
for _, row := range rows {
routes = append(routes, toRouteFromList(row))
}
return routes, nil
}
// Delete deletes a route by ID.
func (s *DBService) Delete(ctx context.Context, routeID string) error {
pgID, err := dbpkg.ParseUUID(routeID)
if err != nil {
return err
}
return s.queries.DeleteChatRoute(ctx, pgID)
}
// UpdateReplyTarget updates default reply target.
func (s *DBService) UpdateReplyTarget(ctx context.Context, routeID, replyTarget string) error {
pgID, err := dbpkg.ParseUUID(routeID)
if err != nil {
return err
}
return s.queries.UpdateChatRouteReplyTarget(ctx, sqlc.UpdateChatRouteReplyTargetParams{
ID: pgID,
ReplyTarget: toPgText(replyTarget),
})
}
// UpdateMetadata replaces the route metadata.
func (s *DBService) UpdateMetadata(ctx context.Context, routeID string, metadata map[string]any) error {
pgID, err := dbpkg.ParseUUID(routeID)
if err != nil {
return err
}
data, err := json.Marshal(nonNilMap(metadata))
if err != nil {
return fmt.Errorf("marshal route metadata: %w", err)
}
return s.queries.UpdateChatRouteMetadata(ctx, sqlc.UpdateChatRouteMetadataParams{
ID: pgID,
Metadata: data,
})
}
// ResolveConversation finds or creates a conversation route for an inbound message.
func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput) (ResolveConversationResult, error) {
route, err := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, input.ThreadID)
if err == nil {
if strings.TrimSpace(input.ReplyTarget) != "" && input.ReplyTarget != route.ReplyTarget {
if updateErr := s.UpdateReplyTarget(ctx, route.ID, input.ReplyTarget); updateErr != nil && s.logger != nil {
s.logger.Warn("update route reply target failed", slog.Any("error", updateErr))
}
}
if len(input.Metadata) > 0 && metadataChanged(route.Metadata, input.Metadata) {
merged := mergeMetadata(route.Metadata, input.Metadata)
if updateErr := s.UpdateMetadata(ctx, route.ID, merged); updateErr != nil && s.logger != nil {
s.logger.Warn("update route metadata failed", slog.Any("error", updateErr))
}
}
pgConversationID, parseErr := dbpkg.ParseUUID(route.ChatID)
if parseErr != nil {
return ResolveConversationResult{}, fmt.Errorf("parse route conversation id: %w", parseErr)
}
if touchErr := s.queries.TouchChat(ctx, pgConversationID); touchErr != nil && s.logger != nil {
s.logger.Warn("touch conversation failed", slog.Any("error", touchErr))
}
return ResolveConversationResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil
}
if s.conversation == nil {
return ResolveConversationResult{}, errors.New("conversation service not configured")
}
kind := determineConversationKind(input.ThreadID, input.ConversationType)
creatorChannelIdentityID := s.resolveConversationCreatorChannelIdentityID(ctx, input.BotID, input.ChannelIdentityID, kind)
var parentConversationID string
if kind == conversation.KindThread {
parentRoute, parentErr := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, "")
if parentErr == nil {
parentConversationID = parentRoute.ChatID
}
}
createdConversation, err := s.conversation.Create(ctx, input.BotID, creatorChannelIdentityID, conversation.CreateRequest{
Kind: kind,
ParentChatID: parentConversationID,
})
if err != nil {
return ResolveConversationResult{}, fmt.Errorf("create conversation: %w", err)
}
newRoute, err := s.Create(ctx, CreateInput{
ChatID: createdConversation.ID,
BotID: input.BotID,
Platform: input.Platform,
ChannelConfigID: input.ChannelConfigID,
ConversationID: input.ConversationID,
ThreadID: input.ThreadID,
ConversationType: input.ConversationType,
ReplyTarget: input.ReplyTarget,
Metadata: input.Metadata,
})
if err != nil {
// Concurrent insert race: another goroutine created the same route between
// our Find and Create calls. Fall back to Find the winning row.
if dbpkg.IsUniqueViolation(err) {
existing, findErr := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, input.ThreadID)
if findErr == nil {
return ResolveConversationResult{ChatID: existing.ChatID, RouteID: existing.ID, Created: false}, nil
}
}
return ResolveConversationResult{}, fmt.Errorf("create route: %w", err)
}
return ResolveConversationResult{ChatID: createdConversation.ID, RouteID: newRoute.ID, Created: true}, nil
}
func determineConversationKind(threadID, conversationType string) string {
if strings.TrimSpace(threadID) != "" {
return conversation.KindThread
}
switch channel.NormalizeConversationType(conversationType) {
case channel.ConversationTypeThread:
return conversation.KindThread
case channel.ConversationTypePrivate:
return conversation.KindDirect
default:
return conversation.KindGroup
}
}
func (s *DBService) resolveConversationCreatorChannelIdentityID(ctx context.Context, botID, fallbackChannelIdentityID, kind string) string {
fallback := strings.TrimSpace(fallbackChannelIdentityID)
if kind != conversation.KindGroup || s.queries == nil {
return fallback
}
pgBotID, err := dbpkg.ParseUUID(botID)
if err != nil {
return fallback
}
row, err := s.queries.GetBotByID(ctx, pgBotID)
if err != nil {
if s.logger != nil {
s.logger.Warn("resolve bot owner for group conversation failed", slog.Any("error", err))
}
return fallback
}
// NOTE: OwnerUserID is the bot owner's user ID. Used as fallback creator for group conversations.
ownerUserID := row.OwnerUserID.String()
if strings.TrimSpace(ownerUserID) == "" {
return fallback
}
return ownerUserID
}
func toRouteFromCreate(row sqlc.CreateChatRouteRow) Route {
return toRouteFields(
row.ID, row.ChatID, row.BotID, row.Platform, row.ChannelConfigID,
row.ConversationID, row.ThreadID, row.ConversationType, row.ReplyTarget,
row.Metadata, row.CreatedAt, row.UpdatedAt,
)
}
func toRouteFromFind(row sqlc.FindChatRouteRow) Route {
return toRouteFields(
row.ID, row.ChatID, row.BotID, row.Platform, row.ChannelConfigID,
row.ConversationID, row.ThreadID, row.ConversationType, row.ReplyTarget,
row.Metadata, row.CreatedAt, row.UpdatedAt,
)
}
func toRouteFromGet(row sqlc.GetChatRouteByIDRow) Route {
return toRouteFields(
row.ID, row.ChatID, row.BotID, row.Platform, row.ChannelConfigID,
row.ConversationID, row.ThreadID, row.ConversationType, row.ReplyTarget,
row.Metadata, row.CreatedAt, row.UpdatedAt,
)
}
func toRouteFromList(row sqlc.ListChatRoutesRow) Route {
return toRouteFields(
row.ID, row.ChatID, row.BotID, row.Platform, row.ChannelConfigID,
row.ConversationID, row.ThreadID, row.ConversationType, row.ReplyTarget,
row.Metadata, row.CreatedAt, row.UpdatedAt,
)
}
func toRouteFields(id, conversationID, botID pgtype.UUID, platform string, channelConfigID pgtype.UUID, externalConversationID string, threadID, conversationType, replyTarget pgtype.Text, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Route {
return Route{
ID: id.String(),
ChatID: conversationID.String(),
BotID: botID.String(),
Platform: platform,
ChannelConfigID: channelConfigID.String(),
ConversationID: externalConversationID,
ThreadID: dbpkg.TextToString(threadID),
ConversationType: dbpkg.TextToString(conversationType),
ReplyTarget: dbpkg.TextToString(replyTarget),
Metadata: parseJSONMap(metadata),
CreatedAt: createdAt.Time,
UpdatedAt: updatedAt.Time,
}
}
func toPgText(value string) pgtype.Text {
value = strings.TrimSpace(value)
if value == "" {
return pgtype.Text{}
}
return pgtype.Text{String: value, Valid: true}
}
func nonNilMap(m map[string]any) map[string]any {
if m == nil {
return map[string]any{}
}
return m
}
func parseJSONMap(data []byte) map[string]any {
if len(data) == 0 {
return nil
}
var m map[string]any
_ = json.Unmarshal(data, &m)
return m
}
// metadataChanged returns true when any key in incoming differs from existing.
func metadataChanged(existing, incoming map[string]any) bool {
for k, v := range incoming {
old, ok := existing[k]
if !ok {
return true
}
oldJSON, _ := json.Marshal(old)
newJSON, _ := json.Marshal(v)
if string(oldJSON) != string(newJSON) {
return true
}
}
return false
}
// mergeMetadata merges incoming keys into existing, preserving keys not in incoming.
func mergeMetadata(existing, incoming map[string]any) map[string]any {
merged := make(map[string]any, len(existing)+len(incoming))
for k, v := range existing {
merged[k] = v
}
for k, v := range incoming {
merged[k] = v
}
return merged
}