feat: message abort and web socket support (#222)

* feat: message abort and web socket support

* fix(web): chat end

* fix: lint

* fix: lint
This commit is contained in:
Acbox Liu
2026-03-09 23:27:50 +08:00
committed by GitHub
parent 36d50738b5
commit 23d49a1c7b
21 changed files with 1050 additions and 110 deletions
+131 -56
View File
@@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
@@ -648,18 +649,9 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
return
}
streamReq.Query = rc.payload.Query
if !streamReq.UserMessagePersisted {
if err := r.persistUserMessage(ctx, streamReq); err != nil {
r.logger.Error("gateway stream persist user message failed",
slog.String("bot_id", streamReq.BotID),
slog.String("chat_id", streamReq.ChatID),
slog.Any("error", err),
)
errCh <- err
return
}
streamReq.UserMessagePersisted = true
}
// User message persistence is deferred to storeRound so that user +
// assistant messages are written atomically. This prevents duplicate
// user messages when concurrent requests hit the same bot.
if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh, rc.model.ID); err != nil {
r.logger.Error("gateway stream request failed",
slog.String("bot_id", streamReq.BotID),
@@ -674,6 +666,119 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
return chunkCh, errCh
}
// --- WebSocket streaming ---
// WSStreamEvent represents a raw JSON event forwarded from the agent gateway
// WebSocket connection to the Go server's client WebSocket.
type WSStreamEvent = json.RawMessage
// StreamChatWS resolves the agent context and streams agent events from the
// gateway WebSocket endpoint. Events are sent on eventCh. When abortCh is
// closed or receives a value, an abort message is forwarded to the gateway.
// Terminal events (agent_end, agent_abort) trigger message persistence before
// being forwarded.
func (r *Resolver) StreamChatWS(
ctx context.Context,
req conversation.ChatRequest,
eventCh chan<- WSStreamEvent,
abortCh <-chan struct{},
) error {
rc, err := r.resolve(ctx, req)
if err != nil {
return fmt.Errorf("resolve: %w", err)
}
req.Query = rc.payload.Query
wsURL := strings.Replace(r.gatewayBaseURL, "http://", "ws://", 1)
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
wsURL += "/chat/ws"
r.logger.Info("gateway ws connect",
slog.String("url", wsURL),
slog.String("bot_id", req.BotID),
)
dialer := websocket.Dialer{
HandshakeTimeout: r.timeout,
}
conn, resp, err := dialer.DialContext(ctx, wsURL, nil)
if resp != nil {
defer func() { _ = resp.Body.Close() }()
}
if err != nil {
return fmt.Errorf("gateway ws dial: %w", err)
}
defer func() { _ = conn.Close() }()
// The gateway WS handler uses the bearer field directly (not as an HTTP
// header), so strip the "Bearer " prefix that the Token field carries.
rawToken := strings.TrimSpace(req.Token)
rawToken = strings.TrimPrefix(rawToken, "Bearer ")
rawToken = strings.TrimPrefix(rawToken, "bearer ")
startPayload := struct {
Type string `json:"type"`
Bearer string `json:"bearer,omitempty"`
gatewayRequest
}{
Type: "start",
Bearer: rawToken,
gatewayRequest: rc.payload,
}
if err := conn.WriteJSON(startPayload); err != nil {
return fmt.Errorf("gateway ws write start: %w", err)
}
// Forward abort signal to gateway.
abortDone := make(chan struct{})
go func() {
defer close(abortDone)
select {
case <-abortCh:
_ = conn.WriteJSON(map[string]string{"type": "abort"})
case <-ctx.Done():
}
}()
defer func() { <-abortDone }()
modelID := rc.model.ID
stored := false
for {
_, msgData, err := conn.ReadMessage()
if err != nil {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
break
}
if ctx.Err() != nil {
break
}
return fmt.Errorf("gateway ws read: %w", err)
}
if !stored {
var envelope struct {
Type string `json:"type"`
}
if json.Unmarshal(msgData, &envelope) == nil && isTerminalStreamEvent(envelope.Type) {
if _, storeErr := r.tryStoreStream(ctx, req, msgData, modelID); storeErr != nil {
r.logger.Error("ws persist failed", slog.Any("error", storeErr))
} else {
stored = true
}
}
}
select {
case eventCh <- json.RawMessage(msgData):
case <-ctx.Done():
return ctx.Err()
}
}
r.markInboxRead(ctx, req.BotID, rc.inboxItemIDs)
return nil
}
// --- HTTP helpers ---
func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) {
@@ -895,9 +1000,14 @@ func newJSONRequestWithContext(ctx context.Context, method, url string, payload
return req, nil
}
// isTerminalStreamEvent returns true for event types that carry the final
// message round (agent_end, agent_abort, done).
func isTerminalStreamEvent(eventType string) bool {
return eventType == "agent_end" || eventType == "agent_abort" || eventType == "done"
}
// tryStoreStream attempts to extract final messages from a stream event and persist them.
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string) (bool, error) {
// data: {"type":"text_delta"|"agent_end"|"done", ...}
var envelope struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
@@ -906,7 +1016,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
Usages []json.RawMessage `json:"usages,omitempty"`
}
if err := json.Unmarshal(data, &envelope); err == nil {
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
if isTerminalStreamEvent(envelope.Type) && len(envelope.Messages) > 0 {
return true, r.storeRound(ctx, req, envelope.Messages, envelope.Usage, envelope.Usages, modelID)
}
if envelope.Type == "done" && len(envelope.Data) > 0 {
@@ -1381,52 +1491,17 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversatio
// --- store helpers ---
func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.ChatRequest) error {
if r.messageService == nil {
return nil
}
if strings.TrimSpace(req.BotID) == "" {
return errors.New("bot id is required for persistence")
}
text := strings.TrimSpace(req.Query)
if text == "" && len(req.Attachments) == 0 {
return nil
}
message := conversation.ModelMessage{
Role: "user",
Content: conversation.NewTextContent(text),
}
content, err := json.Marshal(message)
if err != nil {
return err
}
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
meta := buildRouteMetadata(req)
if meta == nil {
meta = map[string]any{}
}
meta["trigger_mode"] = "active_chat"
_, err = r.messageService.Persist(ctx, messagepkg.PersistInput{
BotID: req.BotID,
RouteID: req.RouteID,
SenderChannelIdentityID: senderChannelIdentityID,
SenderUserID: senderUserID,
Platform: req.CurrentChannel,
ExternalMessageID: req.ExternalMessageID,
Role: "user",
Content: content,
Metadata: meta,
Assets: chatAttachmentsToAssetRefs(req.Attachments),
})
return err
}
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage, usage json.RawMessage, usages []json.RawMessage, modelID string) error {
fullRound := make([]conversation.ModelMessage, 0, len(messages))
roundUsages := make([]json.RawMessage, 0, len(usages))
// When the user message was already persisted by a channel adapter, skip
// the duplicate from the round. Otherwise keep it so that user + assistant
// messages are written atomically (deferred persistence).
skipUserQuery := req.UserMessagePersisted
for i, m := range messages {
if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
if skipUserQuery && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
skipUserQuery = false // only skip the first matching user message
continue
}
fullRound = append(fullRound, m)