mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: message abort and web socket support (#222)
* feat: message abort and web socket support * fix(web): chat end * fix: lint * fix: lint
This commit is contained in:
@@ -5,10 +5,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
@@ -16,6 +18,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/channel/adapters/local"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/conversation/flow"
|
||||
)
|
||||
|
||||
// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history.
|
||||
@@ -27,6 +30,8 @@ type LocalChannelHandler struct {
|
||||
routeHub *local.RouteHub
|
||||
botService *bots.Service
|
||||
accountService *accounts.Service
|
||||
resolver *flow.Resolver
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewLocalChannelHandler creates a local channel handler.
|
||||
@@ -39,15 +44,22 @@ func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *cha
|
||||
routeHub: routeHub,
|
||||
botService: botService,
|
||||
accountService: accountService,
|
||||
logger: slog.Default().With(slog.String("handler", "local_channel")),
|
||||
}
|
||||
}
|
||||
|
||||
// SetResolver sets the flow resolver for WebSocket streaming.
|
||||
func (h *LocalChannelHandler) SetResolver(resolver *flow.Resolver) {
|
||||
h.resolver = resolver
|
||||
}
|
||||
|
||||
// Register registers the local channel routes.
|
||||
func (h *LocalChannelHandler) Register(e *echo.Echo) {
|
||||
prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String())
|
||||
group := e.Group(prefix)
|
||||
group.GET("/stream", h.StreamMessages)
|
||||
group.POST("/messages", h.PostMessage)
|
||||
group.GET("/ws", h.HandleWebSocket)
|
||||
}
|
||||
|
||||
// StreamMessages godoc
|
||||
@@ -196,6 +208,201 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(_ *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
type wsClientMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Attachments []json.RawMessage `json:"attachments,omitempty"`
|
||||
}
|
||||
|
||||
// wsWriter serialises all WebSocket writes through a single goroutine to
|
||||
// avoid concurrent write panics with gorilla/websocket.
|
||||
type wsWriter struct {
|
||||
conn *websocket.Conn
|
||||
ch chan []byte
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newWSWriter(conn *websocket.Conn) *wsWriter {
|
||||
w := &wsWriter{
|
||||
conn: conn,
|
||||
ch: make(chan []byte, 128),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go w.loop()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *wsWriter) loop() {
|
||||
defer close(w.done)
|
||||
for data := range w.ch {
|
||||
_ = w.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wsWriter) Send(data []byte) {
|
||||
select {
|
||||
case w.ch <- data:
|
||||
case <-w.done:
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wsWriter) SendJSON(v any) {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Send(data)
|
||||
}
|
||||
|
||||
func (w *wsWriter) Close() {
|
||||
close(w.ch)
|
||||
<-w.done
|
||||
}
|
||||
|
||||
// extractRawBearerToken returns the raw JWT token suitable for passing to the
|
||||
// gateway. The gateway WS handler receives the token directly (not as an HTTP
|
||||
// header), so we must strip the "Bearer " prefix if present.
|
||||
func extractRawBearerToken(c echo.Context) string {
|
||||
auth := strings.TrimSpace(c.Request().Header.Get("Authorization"))
|
||||
if auth != "" {
|
||||
return strings.TrimPrefix(auth, "Bearer ")
|
||||
}
|
||||
return strings.TrimSpace(c.QueryParam("token"))
|
||||
}
|
||||
|
||||
// HandleWebSocket godoc
|
||||
// @Summary WebSocket chat endpoint
|
||||
// @Description Upgrade to WebSocket for bidirectional chat streaming with abort support.
|
||||
// @Tags local-channel
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Success 101 {string} string "Switching Protocols"
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/web/ws [get]
|
||||
// @Router /bots/{bot_id}/cli/ws [get].
|
||||
func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
botID := strings.TrimSpace(c.Param("bot_id"))
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.ensureBotParticipant(c.Request().Context(), botID, channelIdentityID); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.resolver == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "resolver not configured")
|
||||
}
|
||||
|
||||
conn, err := wsUpgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
rawToken := extractRawBearerToken(c)
|
||||
bearerToken := "Bearer " + rawToken
|
||||
|
||||
writer := newWSWriter(conn)
|
||||
defer writer.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
abortCh := make(chan struct{}, 1)
|
||||
var activeCancel context.CancelFunc
|
||||
|
||||
for {
|
||||
_, raw, readErr := conn.ReadMessage()
|
||||
if readErr != nil {
|
||||
cancel()
|
||||
break
|
||||
}
|
||||
var msg wsClientMessage
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
writer.SendJSON(map[string]string{"type": "error", "message": "invalid message format"})
|
||||
continue
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case "abort":
|
||||
select {
|
||||
case abortCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
case "message":
|
||||
text := strings.TrimSpace(msg.Text)
|
||||
if text == "" {
|
||||
writer.SendJSON(map[string]string{"type": "error", "message": "message text is required"})
|
||||
continue
|
||||
}
|
||||
|
||||
chatAttachments := make([]conversation.ChatAttachment, 0, len(msg.Attachments))
|
||||
for _, rawAtt := range msg.Attachments {
|
||||
var att conversation.ChatAttachment
|
||||
if err := json.Unmarshal(rawAtt, &att); err == nil {
|
||||
chatAttachments = append(chatAttachments, att)
|
||||
}
|
||||
}
|
||||
|
||||
// Drain any previous abort signal.
|
||||
select {
|
||||
case <-abortCh:
|
||||
default:
|
||||
}
|
||||
|
||||
streamCtx, streamCancel := context.WithCancel(ctx)
|
||||
activeCancel = streamCancel
|
||||
eventCh := make(chan flow.WSStreamEvent, 64)
|
||||
|
||||
go func() {
|
||||
defer streamCancel()
|
||||
defer close(eventCh)
|
||||
req := conversation.ChatRequest{
|
||||
BotID: botID,
|
||||
ChatID: botID,
|
||||
Token: bearerToken,
|
||||
UserID: channelIdentityID,
|
||||
SourceChannelIdentityID: channelIdentityID,
|
||||
ConversationType: "p2p",
|
||||
Query: text,
|
||||
CurrentChannel: h.channelType.String(),
|
||||
Channels: []string{h.channelType.String()},
|
||||
Attachments: chatAttachments,
|
||||
}
|
||||
if streamErr := h.resolver.StreamChatWS(streamCtx, req, eventCh, abortCh); streamErr != nil {
|
||||
if ctx.Err() == nil {
|
||||
h.logger.Error("ws stream error", slog.Any("error", streamErr))
|
||||
writer.SendJSON(map[string]string{"type": "error", "message": streamErr.Error()})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for event := range eventCh {
|
||||
writer.Send(event)
|
||||
}
|
||||
}()
|
||||
|
||||
default:
|
||||
writer.SendJSON(map[string]string{"type": "error", "message": "unknown message type: " + msg.Type})
|
||||
}
|
||||
}
|
||||
_ = activeCancel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *LocalChannelHandler) ensureBotParticipant(ctx context.Context, botID, channelIdentityID string) error {
|
||||
if h.chatService == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured")
|
||||
|
||||
Reference in New Issue
Block a user