mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: message abort and web socket support (#222)
* feat: message abort and web socket support * fix(web): chat end * fix: lint * fix: lint
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
@@ -648,18 +649,9 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
|
||||
return
|
||||
}
|
||||
streamReq.Query = rc.payload.Query
|
||||
if !streamReq.UserMessagePersisted {
|
||||
if err := r.persistUserMessage(ctx, streamReq); err != nil {
|
||||
r.logger.Error("gateway stream persist user message failed",
|
||||
slog.String("bot_id", streamReq.BotID),
|
||||
slog.String("chat_id", streamReq.ChatID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
streamReq.UserMessagePersisted = true
|
||||
}
|
||||
// User message persistence is deferred to storeRound so that user +
|
||||
// assistant messages are written atomically. This prevents duplicate
|
||||
// user messages when concurrent requests hit the same bot.
|
||||
if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh, rc.model.ID); err != nil {
|
||||
r.logger.Error("gateway stream request failed",
|
||||
slog.String("bot_id", streamReq.BotID),
|
||||
@@ -674,6 +666,119 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
|
||||
return chunkCh, errCh
|
||||
}
|
||||
|
||||
// --- WebSocket streaming ---
|
||||
|
||||
// WSStreamEvent represents a raw JSON event forwarded from the agent gateway
|
||||
// WebSocket connection to the Go server's client WebSocket.
|
||||
type WSStreamEvent = json.RawMessage
|
||||
|
||||
// StreamChatWS resolves the agent context and streams agent events from the
|
||||
// gateway WebSocket endpoint. Events are sent on eventCh. When abortCh is
|
||||
// closed or receives a value, an abort message is forwarded to the gateway.
|
||||
// Terminal events (agent_end, agent_abort) trigger message persistence before
|
||||
// being forwarded.
|
||||
func (r *Resolver) StreamChatWS(
|
||||
ctx context.Context,
|
||||
req conversation.ChatRequest,
|
||||
eventCh chan<- WSStreamEvent,
|
||||
abortCh <-chan struct{},
|
||||
) error {
|
||||
rc, err := r.resolve(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve: %w", err)
|
||||
}
|
||||
req.Query = rc.payload.Query
|
||||
|
||||
wsURL := strings.Replace(r.gatewayBaseURL, "http://", "ws://", 1)
|
||||
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
|
||||
wsURL += "/chat/ws"
|
||||
|
||||
r.logger.Info("gateway ws connect",
|
||||
slog.String("url", wsURL),
|
||||
slog.String("bot_id", req.BotID),
|
||||
)
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: r.timeout,
|
||||
}
|
||||
conn, resp, err := dialer.DialContext(ctx, wsURL, nil)
|
||||
if resp != nil {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("gateway ws dial: %w", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
// The gateway WS handler uses the bearer field directly (not as an HTTP
|
||||
// header), so strip the "Bearer " prefix that the Token field carries.
|
||||
rawToken := strings.TrimSpace(req.Token)
|
||||
rawToken = strings.TrimPrefix(rawToken, "Bearer ")
|
||||
rawToken = strings.TrimPrefix(rawToken, "bearer ")
|
||||
|
||||
startPayload := struct {
|
||||
Type string `json:"type"`
|
||||
Bearer string `json:"bearer,omitempty"`
|
||||
gatewayRequest
|
||||
}{
|
||||
Type: "start",
|
||||
Bearer: rawToken,
|
||||
gatewayRequest: rc.payload,
|
||||
}
|
||||
if err := conn.WriteJSON(startPayload); err != nil {
|
||||
return fmt.Errorf("gateway ws write start: %w", err)
|
||||
}
|
||||
|
||||
// Forward abort signal to gateway.
|
||||
abortDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(abortDone)
|
||||
select {
|
||||
case <-abortCh:
|
||||
_ = conn.WriteJSON(map[string]string{"type": "abort"})
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
defer func() { <-abortDone }()
|
||||
|
||||
modelID := rc.model.ID
|
||||
stored := false
|
||||
for {
|
||||
_, msgData, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
break
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("gateway ws read: %w", err)
|
||||
}
|
||||
|
||||
if !stored {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if json.Unmarshal(msgData, &envelope) == nil && isTerminalStreamEvent(envelope.Type) {
|
||||
if _, storeErr := r.tryStoreStream(ctx, req, msgData, modelID); storeErr != nil {
|
||||
r.logger.Error("ws persist failed", slog.Any("error", storeErr))
|
||||
} else {
|
||||
stored = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case eventCh <- json.RawMessage(msgData):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
r.markInboxRead(ctx, req.BotID, rc.inboxItemIDs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- HTTP helpers ---
|
||||
|
||||
func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) {
|
||||
@@ -895,9 +1000,14 @@ func newJSONRequestWithContext(ctx context.Context, method, url string, payload
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// isTerminalStreamEvent returns true for event types that carry the final
|
||||
// message round (agent_end, agent_abort, done).
|
||||
func isTerminalStreamEvent(eventType string) bool {
|
||||
return eventType == "agent_end" || eventType == "agent_abort" || eventType == "done"
|
||||
}
|
||||
|
||||
// tryStoreStream attempts to extract final messages from a stream event and persist them.
|
||||
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string) (bool, error) {
|
||||
// data: {"type":"text_delta"|"agent_end"|"done", ...}
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
@@ -906,7 +1016,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
|
||||
Usages []json.RawMessage `json:"usages,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &envelope); err == nil {
|
||||
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
|
||||
if isTerminalStreamEvent(envelope.Type) && len(envelope.Messages) > 0 {
|
||||
return true, r.storeRound(ctx, req, envelope.Messages, envelope.Usage, envelope.Usages, modelID)
|
||||
}
|
||||
if envelope.Type == "done" && len(envelope.Data) > 0 {
|
||||
@@ -1381,52 +1491,17 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversatio
|
||||
|
||||
// --- store helpers ---
|
||||
|
||||
func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.ChatRequest) error {
|
||||
if r.messageService == nil {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(req.BotID) == "" {
|
||||
return errors.New("bot id is required for persistence")
|
||||
}
|
||||
text := strings.TrimSpace(req.Query)
|
||||
if text == "" && len(req.Attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
message := conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: conversation.NewTextContent(text),
|
||||
}
|
||||
content, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
|
||||
meta := buildRouteMetadata(req)
|
||||
if meta == nil {
|
||||
meta = map[string]any{}
|
||||
}
|
||||
meta["trigger_mode"] = "active_chat"
|
||||
_, err = r.messageService.Persist(ctx, messagepkg.PersistInput{
|
||||
BotID: req.BotID,
|
||||
RouteID: req.RouteID,
|
||||
SenderChannelIdentityID: senderChannelIdentityID,
|
||||
SenderUserID: senderUserID,
|
||||
Platform: req.CurrentChannel,
|
||||
ExternalMessageID: req.ExternalMessageID,
|
||||
Role: "user",
|
||||
Content: content,
|
||||
Metadata: meta,
|
||||
Assets: chatAttachmentsToAssetRefs(req.Attachments),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage, usage json.RawMessage, usages []json.RawMessage, modelID string) error {
|
||||
fullRound := make([]conversation.ModelMessage, 0, len(messages))
|
||||
roundUsages := make([]json.RawMessage, 0, len(usages))
|
||||
|
||||
// When the user message was already persisted by a channel adapter, skip
|
||||
// the duplicate from the round. Otherwise keep it so that user + assistant
|
||||
// messages are written atomically (deferred persistence).
|
||||
skipUserQuery := req.UserMessagePersisted
|
||||
for i, m := range messages {
|
||||
if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
|
||||
if skipUserQuery && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
|
||||
skipUserQuery = false // only skip the first matching user message
|
||||
continue
|
||||
}
|
||||
fullRound = append(fullRound, m)
|
||||
|
||||
Reference in New Issue
Block a user