mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
470 lines
15 KiB
Go
470 lines
15 KiB
Go
package handlers
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
|
|
"github.com/memohai/memoh/internal/accounts"
|
|
"github.com/memohai/memoh/internal/bots"
|
|
"github.com/memohai/memoh/internal/conversation"
|
|
"github.com/memohai/memoh/internal/media"
|
|
messagepkg "github.com/memohai/memoh/internal/message"
|
|
messageevent "github.com/memohai/memoh/internal/message/event"
|
|
)
|
|
|
|
// MessageHandler handles bot-scoped messaging endpoints: listing history,
// streaming message events over SSE, deleting history, and serving stored
// media assets (see Register for the routes).
type MessageHandler struct {
	// conversationService backs the read-access check in requireReadable.
	conversationService conversation.Accessor
	// messageService lists and deletes persisted messages.
	messageService messagepkg.Service
	// messageEvents is the live event source for StreamMessageEvents; when nil
	// that endpoint responds with 500 "message events not configured".
	messageEvents messageevent.Subscriber
	// mediaService is optional and installed via SetMediaService; when nil,
	// ServeMedia responds with 500 and asset metadata is not filled in.
	mediaService *media.Service
	// botService and accountService are passed to the bot authorization
	// helpers; accountService also drives the admin bypass in requireReadable.
	botService     *bots.Service
	accountService *accounts.Service
	// logger is tagged with handler=conversation by NewMessageHandler.
	logger *slog.Logger
}
|
|
|
|
// NewMessageHandler creates a MessageHandler.
|
|
func NewMessageHandler(log *slog.Logger, conversationService conversation.Accessor, messageService messagepkg.Service, botService *bots.Service, accountService *accounts.Service, eventSubscribers ...messageevent.Subscriber) *MessageHandler {
|
|
var messageEvents messageevent.Subscriber
|
|
if len(eventSubscribers) > 0 {
|
|
messageEvents = eventSubscribers[0]
|
|
}
|
|
return &MessageHandler{
|
|
conversationService: conversationService,
|
|
messageService: messageService,
|
|
messageEvents: messageEvents,
|
|
botService: botService,
|
|
accountService: accountService,
|
|
logger: log.With(slog.String("handler", "conversation")),
|
|
}
|
|
}
|
|
|
|
// SetMediaService sets the optional media service for asset serving.
|
|
func (h *MessageHandler) SetMediaService(svc *media.Service) {
|
|
h.mediaService = svc
|
|
}
|
|
|
|
// Register registers all conversation routes.
|
|
func (h *MessageHandler) Register(e *echo.Echo) {
|
|
// Bot-scoped message container (single shared history per bot).
|
|
botGroup := e.Group("/bots/:bot_id")
|
|
botGroup.GET("/messages", h.ListMessages)
|
|
botGroup.GET("/messages/events", h.StreamMessageEvents)
|
|
botGroup.DELETE("/messages", h.DeleteMessages)
|
|
botGroup.GET("/media/:content_hash", h.ServeMedia)
|
|
}
|
|
|
|
// --- Messages ---
|
|
|
|
// writeSSEData writes payload as one Server-Sent Events "data:" frame
// (terminated by a blank line), drains the buffered writer, and then flushes
// the HTTP response so the client sees the frame immediately.
//
// No per-line "data:" prefixing is done, so payload must not contain raw
// newlines; the callers in this file pass compact json.Marshal output.
// If writing or draining the buffer fails, the error is returned and the
// response is not flushed.
func writeSSEData(writer *bufio.Writer, flusher http.Flusher, payload string) error {
	// Fprint inserts no separators between string operands.
	_, err := fmt.Fprint(writer, "data: ", payload, "\n\n")
	if err == nil {
		err = writer.Flush()
	}
	if err != nil {
		return err
	}
	flusher.Flush()
	return nil
}
|
|
|
|
func writeSSEJSON(writer *bufio.Writer, flusher http.Flusher, payload any) error {
|
|
data, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return writeSSEData(writer, flusher, string(data))
|
|
}
|
|
|
|
// parseSinceParam interprets the optional "since" query value.
//
// Accepted forms, tried in order: RFC 3339 with fractional seconds, plain
// RFC 3339, then a base-10 integer of Unix epoch milliseconds. The result is
// converted to UTC. A blank (empty or all-whitespace) value yields
// (zero, false, nil); any other value that matches none of the forms yields
// the error "invalid since parameter".
func parseSinceParam(raw string) (time.Time, bool, error) {
	value := strings.TrimSpace(raw)
	if len(value) == 0 {
		return time.Time{}, false, nil
	}
	for _, layout := range [...]string{time.RFC3339Nano, time.RFC3339} {
		if t, parseErr := time.Parse(layout, value); parseErr == nil {
			return t.UTC(), true, nil
		}
	}
	millis, convErr := strconv.ParseInt(value, 10, 64)
	if convErr != nil {
		return time.Time{}, false, errors.New("invalid since parameter")
	}
	return time.UnixMilli(millis).UTC(), true, nil
}
|
|
|
|
// ListMessages godoc
|
|
// @Summary List bot history messages
|
|
// @Description List messages for a bot history with optional pagination
|
|
// @Tags messages
|
|
// @Produce json
|
|
// @Param bot_id path string true "Bot ID"
|
|
// @Param limit query int false "Limit"
|
|
// @Param before query string false "Before"
|
|
// @Success 200 {object} map[string][]messagepkg.Message
|
|
// @Failure 400 {object} ErrorResponse
|
|
// @Failure 403 {object} ErrorResponse
|
|
// @Failure 500 {object} ErrorResponse
|
|
// @Router /bots/{bot_id}/messages [get].
|
|
func (h *MessageHandler) ListMessages(c echo.Context) error {
|
|
channelIdentityID, err := h.requireChannelIdentityID(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
botID := strings.TrimSpace(c.Param("bot_id"))
|
|
if botID == "" {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
|
}
|
|
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
|
return err
|
|
}
|
|
if err := h.requireReadable(c.Request().Context(), botID, channelIdentityID); err != nil {
|
|
return err
|
|
}
|
|
|
|
if h.messageService == nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured")
|
|
}
|
|
|
|
limit := int32(30)
|
|
if s := strings.TrimSpace(c.QueryParam("limit")); s != "" {
|
|
if n, err := strconv.ParseInt(s, 10, 32); err == nil && n > 0 && n <= 100 {
|
|
limit = int32(n)
|
|
}
|
|
}
|
|
|
|
before, hasBefore := parseBeforeParam(c.QueryParam("before"))
|
|
format := strings.ToLower(strings.TrimSpace(c.QueryParam("format")))
|
|
|
|
sessionID := strings.TrimSpace(c.QueryParam("session_id"))
|
|
|
|
var messages []messagepkg.Message
|
|
if sessionID != "" {
|
|
if hasBefore {
|
|
messages, err = h.messageService.ListBeforeBySession(c.Request().Context(), sessionID, before, limit)
|
|
} else {
|
|
messages, err = h.messageService.ListLatestBySession(c.Request().Context(), sessionID, limit)
|
|
if err == nil {
|
|
reverseMessages(messages)
|
|
}
|
|
}
|
|
} else {
|
|
if hasBefore {
|
|
messages, err = h.messageService.ListBefore(c.Request().Context(), botID, before, limit)
|
|
} else {
|
|
messages, err = h.messageService.ListLatest(c.Request().Context(), botID, limit)
|
|
if err == nil {
|
|
reverseMessages(messages)
|
|
}
|
|
}
|
|
}
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
h.fillAssetMimeFromStorage(c.Request().Context(), botID, messages)
|
|
if format == "ui" {
|
|
return c.JSON(http.StatusOK, map[string]any{
|
|
"items": conversation.ConvertMessagesToUITurns(messages),
|
|
})
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]any{"items": messages})
|
|
}
|
|
|
|
// fillAssetMimeFromStorage fills mime, storage_key, size_bytes from storage (soft link: DB only has content_hash).
|
|
func (h *MessageHandler) fillAssetMimeFromStorage(ctx context.Context, botID string, messages []messagepkg.Message) {
|
|
if h.mediaService == nil {
|
|
return
|
|
}
|
|
for i := range messages {
|
|
for j := range messages[i].Assets {
|
|
a := &messages[i].Assets[j] //nolint:gosec // G602: j is bounded by range loop
|
|
if strings.TrimSpace(a.ContentHash) == "" {
|
|
continue
|
|
}
|
|
asset, err := h.mediaService.Resolve(ctx, botID, a.ContentHash)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
a.Mime = asset.Mime
|
|
a.StorageKey = asset.StorageKey
|
|
a.SizeBytes = asset.SizeBytes
|
|
}
|
|
}
|
|
}
|
|
|
|
// parseBeforeParam interprets the optional "before" pagination cursor.
//
// It accepts the same forms as parseSinceParam — RFC 3339 with or without
// fractional seconds, or base-10 Unix epoch milliseconds — and returns the
// instant in UTC. Unlike parseSinceParam it never fails: a blank or
// unparseable value is reported as (zero, false).
func parseBeforeParam(s string) (time.Time, bool) {
	value := strings.TrimSpace(s)
	if value == "" {
		return time.Time{}, false
	}
	for _, layout := range []string{time.RFC3339Nano, time.RFC3339} {
		if t, err := time.Parse(layout, value); err == nil {
			return t.UTC(), true
		}
	}
	millis, err := strconv.ParseInt(value, 10, 64)
	if err != nil {
		return time.Time{}, false
	}
	return time.UnixMilli(millis).UTC(), true
}
|
|
|
|
func reverseMessages(m []messagepkg.Message) {
|
|
for i, j := 0, len(m)-1; i < j; i, j = i+1, j-1 {
|
|
m[i], m[j] = m[j], m[i]
|
|
}
|
|
}
|
|
|
|
// StreamMessageEvents streams bot-scoped message events to clients.
|
|
func (h *MessageHandler) StreamMessageEvents(c echo.Context) error {
|
|
channelIdentityID, err := h.requireChannelIdentityID(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
botID := strings.TrimSpace(c.Param("bot_id"))
|
|
if botID == "" {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
|
}
|
|
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
|
return err
|
|
}
|
|
if err := h.requireReadable(c.Request().Context(), botID, channelIdentityID); err != nil {
|
|
return err
|
|
}
|
|
if h.messageService == nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured")
|
|
}
|
|
if h.messageEvents == nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "message events not configured")
|
|
}
|
|
|
|
since, hasSince, err := parseSinceParam(c.QueryParam("since"))
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
|
}
|
|
|
|
c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
|
|
c.Response().Header().Set(echo.HeaderCacheControl, "no-cache")
|
|
c.Response().Header().Set(echo.HeaderConnection, "keep-alive")
|
|
c.Response().WriteHeader(http.StatusOK)
|
|
|
|
flusher, ok := c.Response().Writer.(http.Flusher)
|
|
if !ok {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "streaming not supported")
|
|
}
|
|
writer := bufio.NewWriter(c.Response().Writer)
|
|
|
|
sentMessageIDs := map[string]struct{}{}
|
|
writeCreatedEvent := func(message messagepkg.Message) error {
|
|
msgID := strings.TrimSpace(message.ID)
|
|
if msgID != "" {
|
|
if _, exists := sentMessageIDs[msgID]; exists {
|
|
return nil
|
|
}
|
|
sentMessageIDs[msgID] = struct{}{}
|
|
}
|
|
return writeSSEJSON(writer, flusher, map[string]any{
|
|
"type": string(messageevent.EventTypeMessageCreated),
|
|
"bot_id": botID,
|
|
"message": message,
|
|
})
|
|
}
|
|
|
|
_, stream, cancel := h.messageEvents.Subscribe(botID, 128)
|
|
defer cancel()
|
|
|
|
if hasSince {
|
|
backlog, err := h.messageService.ListSince(c.Request().Context(), botID, since)
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
h.fillAssetMimeFromStorage(c.Request().Context(), botID, backlog)
|
|
for _, message := range backlog {
|
|
if err := writeCreatedEvent(message); err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
heartbeatTicker := time.NewTicker(20 * time.Second)
|
|
defer heartbeatTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-c.Request().Context().Done():
|
|
return nil
|
|
case <-heartbeatTicker.C:
|
|
if err := writeSSEJSON(writer, flusher, map[string]any{"type": "ping"}); err != nil {
|
|
return nil
|
|
}
|
|
case event, ok := <-stream:
|
|
if !ok {
|
|
return nil
|
|
}
|
|
if strings.TrimSpace(event.BotID) != botID {
|
|
continue
|
|
}
|
|
if len(event.Data) == 0 {
|
|
continue
|
|
}
|
|
switch event.Type {
|
|
case messageevent.EventTypeMessageCreated:
|
|
var message messagepkg.Message
|
|
if err := json.Unmarshal(event.Data, &message); err != nil {
|
|
h.logger.Warn("decode message event failed", slog.Any("error", err))
|
|
continue
|
|
}
|
|
h.fillAssetMimeFromStorage(c.Request().Context(), botID, []messagepkg.Message{message})
|
|
if err := writeCreatedEvent(message); err != nil {
|
|
return nil
|
|
}
|
|
case messageevent.EventTypeSessionTitleUpdated:
|
|
var payload map[string]string
|
|
if err := json.Unmarshal(event.Data, &payload); err != nil {
|
|
continue
|
|
}
|
|
if err := writeSSEJSON(writer, flusher, map[string]any{
|
|
"type": string(messageevent.EventTypeSessionTitleUpdated),
|
|
"bot_id": botID,
|
|
"session_id": payload["session_id"],
|
|
"title": payload["title"],
|
|
}); err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// DeleteMessages godoc
|
|
// @Summary Delete all bot history messages
|
|
// @Description Clear all persisted bot-level history messages
|
|
// @Tags messages
|
|
// @Produce json
|
|
// @Param bot_id path string true "Bot ID"
|
|
// @Success 204 "No Content"
|
|
// @Failure 400 {object} ErrorResponse
|
|
// @Failure 403 {object} ErrorResponse
|
|
// @Failure 500 {object} ErrorResponse
|
|
// @Router /bots/{bot_id}/messages [delete].
|
|
func (h *MessageHandler) DeleteMessages(c echo.Context) error {
|
|
channelIdentityID, err := h.requireChannelIdentityID(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
botID := strings.TrimSpace(c.Param("bot_id"))
|
|
if botID == "" {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
|
}
|
|
if _, err := h.authorizeBotManage(c.Request().Context(), channelIdentityID, botID); err != nil {
|
|
return err
|
|
}
|
|
if h.messageService == nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "message service not configured")
|
|
}
|
|
sessionID := strings.TrimSpace(c.QueryParam("session_id"))
|
|
if sessionID != "" {
|
|
if err := h.messageService.DeleteBySession(c.Request().Context(), sessionID); err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
} else {
|
|
if err := h.messageService.DeleteByBot(c.Request().Context(), botID); err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
|
|
// --- helpers ---
|
|
|
|
func (*MessageHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
|
return RequireChannelIdentityID(c)
|
|
}
|
|
|
|
func (h *MessageHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
|
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
|
|
}
|
|
|
|
func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
|
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
|
|
}
|
|
|
|
func (h *MessageHandler) requireReadable(ctx context.Context, conversationID, channelIdentityID string) error {
|
|
if h.conversationService == nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "conversation service not configured")
|
|
}
|
|
// Admin bypass.
|
|
if h.accountService != nil {
|
|
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
if isAdmin {
|
|
return nil
|
|
}
|
|
}
|
|
_, err := h.conversationService.GetReadAccess(ctx, conversationID, channelIdentityID)
|
|
if err != nil {
|
|
if errors.Is(err, conversation.ErrPermissionDenied) {
|
|
return echo.NewHTTPError(http.StatusForbidden, "not allowed to read conversation")
|
|
}
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ServeMedia streams a media asset by bot_id + content_hash with read-access authorization.
|
|
func (h *MessageHandler) ServeMedia(c echo.Context) error {
|
|
channelIdentityID, err := h.requireChannelIdentityID(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
botID := strings.TrimSpace(c.Param("bot_id"))
|
|
if botID == "" {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
|
}
|
|
contentHash := strings.TrimSpace(c.Param("content_hash"))
|
|
if contentHash == "" {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "content hash is required")
|
|
}
|
|
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
|
return err
|
|
}
|
|
if err := h.requireReadable(c.Request().Context(), botID, channelIdentityID); err != nil {
|
|
return err
|
|
}
|
|
if h.mediaService == nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "media service not configured")
|
|
}
|
|
reader, asset, err := h.mediaService.Open(c.Request().Context(), botID, contentHash)
|
|
if err != nil {
|
|
if errors.Is(err, media.ErrAssetNotFound) {
|
|
return echo.NewHTTPError(http.StatusNotFound, "asset not found")
|
|
}
|
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
defer func() { _ = reader.Close() }()
|
|
contentType := asset.Mime
|
|
if contentType == "" {
|
|
contentType = "application/octet-stream"
|
|
}
|
|
c.Response().Header().Set("Content-Type", contentType)
|
|
c.Response().Header().Set("Cache-Control", "private, max-age=86400")
|
|
c.Response().WriteHeader(http.StatusOK)
|
|
if _, err := io.Copy(c.Response().Writer, reader); err != nil {
|
|
h.logger.Warn("serve media stream failed", slog.Any("error", err))
|
|
}
|
|
return nil
|
|
}
|