feat: message abort and web socket support (#222)

* feat: message abort and web socket support

* fix(web): chat end

* fix: lint

* fix: lint
This commit is contained in:
Acbox Liu
2026-03-09 23:27:50 +08:00
committed by GitHub
parent 36d50738b5
commit 23d49a1c7b
21 changed files with 1050 additions and 110 deletions
+96 -27
View File
@@ -1,6 +1,6 @@
import { Elysia } from 'elysia'
import z from 'zod'
import { createAgent, ModelConfig } from '@memoh/agent'
import { createAgent, ModelConfig, type AgentStreamAction } from '@memoh/agent'
import { createAuthFetcher, getBaseUrl } from '../index'
import { bearerMiddleware } from '../middlewares/bearer'
import { AgentSkillModel, AttachmentModel, HeartbeatModel, IdentityContextModel, InboxItemModel, LoopDetectionModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models'
@@ -21,6 +21,37 @@ const AgentModel = z.object({
loopDetection: LoopDetectionModel,
})
const StreamBodyModel = AgentModel.extend({
query: z.string().optional().default(''),
})
function buildAgentAndStream(body: z.infer<typeof StreamBodyModel>, bearer: string, signal?: AbortSignal) {
const auth = {
bearer,
baseUrl: getBaseUrl(),
}
const authFetcher = createAuthFetcher(auth)
const { stream } = createAgent({
model: body.model as ModelConfig,
activeContextTime: body.activeContextTime,
channels: body.channels,
currentChannel: body.currentChannel,
identity: body.identity,
auth,
skills: body.usableSkills,
mcpConnections: body.mcpConnections,
inbox: body.inbox,
loopDetection: body.loopDetection,
}, authFetcher)
return stream({
query: body.query,
messages: body.messages,
skills: body.skills,
attachments: body.attachments,
signal,
})
}
export const chatModule = new Elysia({ prefix: '/chat' })
.use(bearerMiddleware)
.post('/', async ({ body, bearer }) => {
@@ -55,33 +86,13 @@ export const chatModule = new Elysia({ prefix: '/chat' })
})
.post('/stream', async function* ({ body, bearer }) {
console.log('stream', body)
const abortController = new AbortController()
try {
const auth = {
bearer: bearer!,
baseUrl: getBaseUrl(),
}
const authFetcher = createAuthFetcher(auth)
const { stream } = createAgent({
model: body.model as ModelConfig,
activeContextTime: body.activeContextTime,
channels: body.channels,
currentChannel: body.currentChannel,
identity: body.identity,
auth,
skills: body.usableSkills,
mcpConnections: body.mcpConnections,
inbox: body.inbox,
loopDetection: body.loopDetection,
}, authFetcher)
for await (const action of stream({
query: body.query,
messages: body.messages,
skills: body.skills,
attachments: body.attachments,
})) {
for await (const action of buildAgentAndStream(body, bearer!, abortController.signal)) {
yield sseChunked(JSON.stringify(action))
}
} catch (error) {
if (abortController.signal.aborted) return
console.error(error)
const message = error instanceof Error && error.message.trim()
? error.message
@@ -90,12 +101,70 @@ export const chatModule = new Elysia({ prefix: '/chat' })
type: 'error',
message,
}))
} finally {
abortController.abort()
}
}, {
body: AgentModel.extend({
query: z.string().optional().default(''),
}),
body: StreamBodyModel,
})
.ws('/ws', (() => {
const sessions = new Map<unknown, { abortController: AbortController | null; streaming: boolean }>()
return {
open(ws: { raw: unknown }) {
sessions.set(ws.raw, { abortController: null, streaming: false })
},
async message(ws: { raw: unknown; send: (data: string) => void }, raw: unknown) {
const parsed = typeof raw === 'string' ? JSON.parse(raw) : raw
const session = sessions.get(ws.raw)
if (!session) return
if (parsed.type === 'abort') {
session.abortController?.abort()
return
}
if (parsed.type === 'start') {
if (session.streaming) {
ws.send(JSON.stringify({ type: 'error', message: 'Already streaming' }))
return
}
session.streaming = true
const abortController = new AbortController()
session.abortController = abortController
const bearer = parsed.bearer as string | undefined
if (!bearer) {
ws.send(JSON.stringify({ type: 'error', message: 'Missing bearer token' }))
session.streaming = false
return
}
try {
const body = StreamBodyModel.parse(parsed)
const streamIter = buildAgentAndStream(body, bearer, abortController.signal)
for await (const action of streamIter) {
ws.send(JSON.stringify(action))
}
} catch (error) {
if (!abortController.signal.aborted) {
console.error(error)
const message = error instanceof Error && error.message.trim()
? error.message
: 'Internal server error'
ws.send(JSON.stringify({ type: 'error', message }))
}
} finally {
session.streaming = false
session.abortController = null
}
}
},
close(ws: { raw: unknown }) {
const session = sessions.get(ws.raw)
if (session) {
session.abortController?.abort()
sessions.delete(ws.raw)
}
},
}
})())
.post('/trigger-schedule', async ({ body, bearer }) => {
console.log('trigger-schedule', body)
const auth = {
+1
View File
@@ -2,3 +2,4 @@ export * from './useChat.types'
export * from './useChat.chat-api'
export * from './useChat.message-api'
export * from './useChat.content'
export * from './useChat.ws'
@@ -47,7 +47,7 @@ export interface StreamEvent {
| 'reasoning_start' | 'reasoning_delta' | 'reasoning_end'
| 'tool_call_start' | 'tool_call_end'
| 'attachment_delta'
| 'agent_start' | 'agent_end'
| 'agent_start' | 'agent_end' | 'agent_abort'
| 'processing_started' | 'processing_completed' | 'processing_failed'
| 'error'
delta?: string
+139
View File
@@ -0,0 +1,139 @@
import { client } from '@memoh/sdk/client'
import type { StreamEvent, MessageStreamEvent, ChatAttachment, StreamEventHandler } from './useChat.types'
export interface WSClientMessage {
type: 'message' | 'abort'
text?: string
attachments?: ChatAttachment[]
}
export interface ChatWebSocket {
send: (msg: WSClientMessage) => void
abort: () => void
close: () => void
readonly connected: boolean
onOpen: (() => void) | null
onClose: (() => void) | null
}
function resolveWebSocketUrl(botId: string): string {
const baseUrl = String(client.getConfig().baseUrl || '').trim()
const path = `/bots/${encodeURIComponent(botId)}/web/ws`
if (!baseUrl || baseUrl.startsWith('/')) {
const loc = window.location
const proto = loc.protocol === 'https:' ? 'wss:' : 'ws:'
const base = baseUrl || '/api'
return `${proto}//${loc.host}${base.replace(/\/+$/, '')}${path}`
}
try {
const url = new URL(path, baseUrl)
url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:'
return url.toString()
} catch {
const loc = window.location
const proto = loc.protocol === 'https:' ? 'wss:' : 'ws:'
return `${proto}//${loc.host}/api${path}`
}
}
export function connectWebSocket(
botId: string,
onStreamEvent: StreamEventHandler,
onMessageEvent?: (event: MessageStreamEvent) => void,
): ChatWebSocket {
const id = botId.trim()
if (!id) throw new Error('bot id is required')
const wsUrl = resolveWebSocketUrl(id)
const token = localStorage.getItem('token') ?? ''
const url = token ? `${wsUrl}?token=${encodeURIComponent(token)}` : wsUrl
let ws: WebSocket | null = null
let isConnected = false
let closed = false
let reconnectTimer: ReturnType<typeof setTimeout> | null = null
let reconnectDelay = 1000
const handle: ChatWebSocket = {
send(msg: WSClientMessage) {
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify(msg))
}
},
abort() {
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ type: 'abort' }))
}
},
close() {
closed = true
if (reconnectTimer) {
clearTimeout(reconnectTimer)
reconnectTimer = null
}
if (ws) {
ws.close()
ws = null
}
isConnected = false
},
get connected() {
return isConnected
},
onOpen: null,
onClose: null,
}
function connect() {
if (closed) return
ws = new WebSocket(url)
ws.onopen = () => {
isConnected = true
reconnectDelay = 1000
handle.onOpen?.()
}
ws.onclose = () => {
isConnected = false
handle.onClose?.()
scheduleReconnect()
}
ws.onerror = () => {
// onerror is always followed by onclose; reconnect handled there.
}
ws.onmessage = (event) => {
if (typeof event.data !== 'string') return
try {
const parsed = JSON.parse(event.data)
if (!parsed || typeof parsed !== 'object') return
const eventType = String(parsed.type ?? '').trim()
if (eventType === 'message_created' && onMessageEvent) {
onMessageEvent(parsed as MessageStreamEvent)
return
}
onStreamEvent(parsed as StreamEvent)
} catch {
// Ignore unparsable messages.
}
}
}
function scheduleReconnect() {
if (closed) return
reconnectTimer = setTimeout(() => {
reconnectTimer = null
connect()
}, reconnectDelay)
reconnectDelay = Math.min(reconnectDelay * 1.5, 10000)
}
connect()
return handle
}
+40 -2
View File
@@ -21,7 +21,9 @@ import {
sendLocalChannelMessage,
streamLocalChannel,
streamMessageEvents,
connectWebSocket,
type ChatAttachment,
type ChatWebSocket,
} from '@/composables/api/useChat'
// ---- Message model (blocks-based, aligned with main branch) ----
@@ -103,6 +105,7 @@ export const useChatStore = defineStore('chat', () => {
let pendingAssistantStream: PendingAssistantStream | null = null
const messageEventsStream = useRetryingStream()
const localStream = useRetryingStream()
let activeWs: ChatWebSocket | null = null
const participantChats = computed(() =>
chats.value.filter((c) => (c.access_mode ?? 'participant') === 'participant'),
@@ -123,6 +126,7 @@ export const useChatStore = defineStore('chat', () => {
} else {
stopMessageEvents()
stopLocalStream()
stopWebSocket()
rejectPendingAssistantStream(new Error('Bot stream stopped'))
messageEventsSince = ''
chats.value = []
@@ -332,6 +336,9 @@ export const useChatStore = defineStore('chat', () => {
// ---- Abort ----
function abort() {
if (activeWs) {
activeWs.abort()
}
abortFn?.()
abortFn = null
for (const msg of messages) {
@@ -356,6 +363,25 @@ export const useChatStore = defineStore('chat', () => {
localStream.stop()
}
function stopWebSocket() {
if (activeWs) {
activeWs.close()
activeWs = null
}
}
function startWebSocket(targetBotId: string) {
const bid = targetBotId.trim()
stopWebSocket()
if (!bid) return
activeWs = connectWebSocket(
bid,
handleLocalStreamEvent,
(e) => handleStreamEvent(bid, e),
)
}
function pushAssistantBlock(session: PendingAssistantStream, block: ContentBlock): number {
session.assistantMsg.blocks.push(block)
return session.assistantMsg.blocks.length - 1
@@ -587,8 +613,11 @@ export const useChatStore = defineStore('chat', () => {
rejectPendingAssistantStream(new Error(message))
break
}
case 'agent_start':
case 'agent_abort':
case 'agent_end':
resolvePendingAssistantStream()
break
case 'agent_start':
default: {
const fallback = extractFallbackText(event)
if (fallback) {
@@ -750,6 +779,7 @@ export const useChatStore = defineStore('chat', () => {
loadingChats.value = true
stopMessageEvents()
stopLocalStream()
stopWebSocket()
try {
const bid = await ensureBot()
if (!bid) {
@@ -772,6 +802,7 @@ export const useChatStore = defineStore('chat', () => {
: visible[0]!.id
chatId.value = activeChatId
await loadMessages(bid, activeChatId)
startWebSocket(bid)
startMessageEvents(bid)
startLocalStream(bid)
} finally {
@@ -900,10 +931,17 @@ export const useChatStore = defineStore('chat', () => {
abortFn = () => {
const abortError = new Error('aborted')
abortError.name = 'AbortError'
if (activeWs) {
activeWs.abort()
}
rejectPendingAssistantStream(abortError)
}
await sendLocalChannelMessage(bid, trimmed, attachments)
if (activeWs?.connected) {
activeWs.send({ type: 'message', text: trimmed, attachments })
} else {
await sendLocalChannelMessage(bid, trimmed, attachments)
}
await completion
assistantMsg.streaming = false
+4 -2
View File
@@ -46,7 +46,8 @@ export default defineConfig(({ command }) => {
'/api': {
target: baseUrl,
changeOrigin: true,
rewrite: (path: string) => path.replace(/^\/api/, '')
rewrite: (path: string) => path.replace(/^\/api/, ''),
ws: true,
}
},
},
@@ -57,7 +58,8 @@ export default defineConfig(({ command }) => {
'/api': {
target: baseUrl,
changeOrigin: true,
rewrite: (path: string) => path.replace(/^\/api/, '')
rewrite: (path: string) => path.replace(/^\/api/, ''),
ws: true,
}
},
allowedHosts: true,
+8 -4
View File
@@ -535,12 +535,16 @@ func provideUsersHandler(log *slog.Logger, accountService *accounts.Service, ide
return handlers.NewUsersHandler(log, accountService, identityService, botService, routeService, channelStore, channelLifecycle, channelManager, registry)
}
func provideCLIHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler {
return handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelStore, chatService, hub, botService, accountService)
func provideCLIHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service, resolver *flow.Resolver) *handlers.LocalChannelHandler {
h := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelStore, chatService, hub, botService, accountService)
h.SetResolver(resolver)
return h
}
func provideWebHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler {
return handlers.NewLocalChannelHandler(local.WebType, channelManager, channelStore, chatService, hub, botService, accountService)
func provideWebHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service, resolver *flow.Resolver) *handlers.LocalChannelHandler {
h := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelStore, chatService, hub, botService, accountService)
h.SetResolver(resolver)
return h
}
// ---------------------------------------------------------------------------
+8 -4
View File
@@ -382,12 +382,16 @@ func provideUsersHandler(log *slog.Logger, accountService *accounts.Service, ide
return handlers.NewUsersHandler(log, accountService, identityService, botService, routeService, channelStore, channelLifecycle, channelManager, registry)
}
func provideCLIHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler {
return handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelStore, chatService, hub, botService, accountService)
func provideCLIHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service, resolver *flow.Resolver) *handlers.LocalChannelHandler {
h := handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelStore, chatService, hub, botService, accountService)
h.SetResolver(resolver)
return h
}
func provideWebHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler {
return handlers.NewLocalChannelHandler(local.WebType, channelManager, channelStore, chatService, hub, botService, accountService)
func provideWebHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service, resolver *flow.Resolver) *handlers.LocalChannelHandler {
h := handlers.NewLocalChannelHandler(local.WebType, channelManager, channelStore, chatService, hub, botService, accountService)
h.SetResolver(resolver)
return h
}
type serverParams struct {
+2 -2
View File
@@ -35,8 +35,8 @@ server {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
proxy_send_timeout 300s;
proxy_read_timeout 300s;
# Swagger 文档(保留 /api 前缀)
location ~ ^/api/(docs(?:/.*)?|swagger\.json)$ {
+131 -56
View File
@@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
@@ -648,18 +649,9 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
return
}
streamReq.Query = rc.payload.Query
if !streamReq.UserMessagePersisted {
if err := r.persistUserMessage(ctx, streamReq); err != nil {
r.logger.Error("gateway stream persist user message failed",
slog.String("bot_id", streamReq.BotID),
slog.String("chat_id", streamReq.ChatID),
slog.Any("error", err),
)
errCh <- err
return
}
streamReq.UserMessagePersisted = true
}
// User message persistence is deferred to storeRound so that user +
// assistant messages are written atomically. This prevents duplicate
// user messages when concurrent requests hit the same bot.
if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh, rc.model.ID); err != nil {
r.logger.Error("gateway stream request failed",
slog.String("bot_id", streamReq.BotID),
@@ -674,6 +666,119 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
return chunkCh, errCh
}
// --- WebSocket streaming ---
// WSStreamEvent represents a raw JSON event forwarded from the agent gateway
// WebSocket connection to the Go server's client WebSocket.
type WSStreamEvent = json.RawMessage
// StreamChatWS resolves the agent context and streams agent events from the
// gateway WebSocket endpoint. Events are sent on eventCh. When abortCh is
// closed or receives a value, an abort message is forwarded to the gateway.
// Terminal events (agent_end, agent_abort) trigger message persistence before
// being forwarded.
func (r *Resolver) StreamChatWS(
ctx context.Context,
req conversation.ChatRequest,
eventCh chan<- WSStreamEvent,
abortCh <-chan struct{},
) error {
rc, err := r.resolve(ctx, req)
if err != nil {
return fmt.Errorf("resolve: %w", err)
}
req.Query = rc.payload.Query
wsURL := strings.Replace(r.gatewayBaseURL, "http://", "ws://", 1)
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
wsURL += "/chat/ws"
r.logger.Info("gateway ws connect",
slog.String("url", wsURL),
slog.String("bot_id", req.BotID),
)
dialer := websocket.Dialer{
HandshakeTimeout: r.timeout,
}
conn, resp, err := dialer.DialContext(ctx, wsURL, nil)
if resp != nil {
defer func() { _ = resp.Body.Close() }()
}
if err != nil {
return fmt.Errorf("gateway ws dial: %w", err)
}
defer func() { _ = conn.Close() }()
// The gateway WS handler uses the bearer field directly (not as an HTTP
// header), so strip the "Bearer " prefix that the Token field carries.
rawToken := strings.TrimSpace(req.Token)
rawToken = strings.TrimPrefix(rawToken, "Bearer ")
rawToken = strings.TrimPrefix(rawToken, "bearer ")
startPayload := struct {
Type string `json:"type"`
Bearer string `json:"bearer,omitempty"`
gatewayRequest
}{
Type: "start",
Bearer: rawToken,
gatewayRequest: rc.payload,
}
if err := conn.WriteJSON(startPayload); err != nil {
return fmt.Errorf("gateway ws write start: %w", err)
}
// Forward abort signal to gateway.
abortDone := make(chan struct{})
go func() {
defer close(abortDone)
select {
case <-abortCh:
_ = conn.WriteJSON(map[string]string{"type": "abort"})
case <-ctx.Done():
}
}()
defer func() { <-abortDone }()
modelID := rc.model.ID
stored := false
for {
_, msgData, err := conn.ReadMessage()
if err != nil {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
break
}
if ctx.Err() != nil {
break
}
return fmt.Errorf("gateway ws read: %w", err)
}
if !stored {
var envelope struct {
Type string `json:"type"`
}
if json.Unmarshal(msgData, &envelope) == nil && isTerminalStreamEvent(envelope.Type) {
if _, storeErr := r.tryStoreStream(ctx, req, msgData, modelID); storeErr != nil {
r.logger.Error("ws persist failed", slog.Any("error", storeErr))
} else {
stored = true
}
}
}
select {
case eventCh <- json.RawMessage(msgData):
case <-ctx.Done():
return ctx.Err()
}
}
r.markInboxRead(ctx, req.BotID, rc.inboxItemIDs)
return nil
}
// --- HTTP helpers ---
func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) {
@@ -895,9 +1000,14 @@ func newJSONRequestWithContext(ctx context.Context, method, url string, payload
return req, nil
}
// isTerminalStreamEvent returns true for event types that carry the final
// message round (agent_end, agent_abort, done).
func isTerminalStreamEvent(eventType string) bool {
return eventType == "agent_end" || eventType == "agent_abort" || eventType == "done"
}
// tryStoreStream attempts to extract final messages from a stream event and persist them.
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string) (bool, error) {
// data: {"type":"text_delta"|"agent_end"|"done", ...}
var envelope struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
@@ -906,7 +1016,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
Usages []json.RawMessage `json:"usages,omitempty"`
}
if err := json.Unmarshal(data, &envelope); err == nil {
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
if isTerminalStreamEvent(envelope.Type) && len(envelope.Messages) > 0 {
return true, r.storeRound(ctx, req, envelope.Messages, envelope.Usage, envelope.Usages, modelID)
}
if envelope.Type == "done" && len(envelope.Data) > 0 {
@@ -1381,52 +1491,17 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversatio
// --- store helpers ---
func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.ChatRequest) error {
if r.messageService == nil {
return nil
}
if strings.TrimSpace(req.BotID) == "" {
return errors.New("bot id is required for persistence")
}
text := strings.TrimSpace(req.Query)
if text == "" && len(req.Attachments) == 0 {
return nil
}
message := conversation.ModelMessage{
Role: "user",
Content: conversation.NewTextContent(text),
}
content, err := json.Marshal(message)
if err != nil {
return err
}
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
meta := buildRouteMetadata(req)
if meta == nil {
meta = map[string]any{}
}
meta["trigger_mode"] = "active_chat"
_, err = r.messageService.Persist(ctx, messagepkg.PersistInput{
BotID: req.BotID,
RouteID: req.RouteID,
SenderChannelIdentityID: senderChannelIdentityID,
SenderUserID: senderUserID,
Platform: req.CurrentChannel,
ExternalMessageID: req.ExternalMessageID,
Role: "user",
Content: content,
Metadata: meta,
Assets: chatAttachmentsToAssetRefs(req.Attachments),
})
return err
}
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage, usage json.RawMessage, usages []json.RawMessage, modelID string) error {
fullRound := make([]conversation.ModelMessage, 0, len(messages))
roundUsages := make([]json.RawMessage, 0, len(usages))
// When the user message was already persisted by a channel adapter, skip
// the duplicate from the round. Otherwise keep it so that user + assistant
// messages are written atomically (deferred persistence).
skipUserQuery := req.UserMessagePersisted
for i, m := range messages {
if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
if skipUserQuery && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
skipUserQuery = false // only skip the first matching user message
continue
}
fullRound = append(fullRound, m)
+207
View File
@@ -5,10 +5,12 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
@@ -16,6 +18,7 @@ import (
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/adapters/local"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/conversation/flow"
)
// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history.
@@ -27,6 +30,8 @@ type LocalChannelHandler struct {
routeHub *local.RouteHub
botService *bots.Service
accountService *accounts.Service
resolver *flow.Resolver
logger *slog.Logger
}
// NewLocalChannelHandler creates a local channel handler.
@@ -39,15 +44,22 @@ func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *cha
routeHub: routeHub,
botService: botService,
accountService: accountService,
logger: slog.Default().With(slog.String("handler", "local_channel")),
}
}
// SetResolver sets the flow resolver for WebSocket streaming.
func (h *LocalChannelHandler) SetResolver(resolver *flow.Resolver) {
h.resolver = resolver
}
// Register registers the local channel routes.
func (h *LocalChannelHandler) Register(e *echo.Echo) {
prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String())
group := e.Group(prefix)
group.GET("/stream", h.StreamMessages)
group.POST("/messages", h.PostMessage)
group.GET("/ws", h.HandleWebSocket)
}
// StreamMessages godoc
@@ -196,6 +208,201 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
}
var wsUpgrader = websocket.Upgrader{
CheckOrigin: func(_ *http.Request) bool { return true },
}
type wsClientMessage struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Attachments []json.RawMessage `json:"attachments,omitempty"`
}
// wsWriter serialises all WebSocket writes through a single goroutine to
// avoid concurrent write panics with gorilla/websocket.
type wsWriter struct {
conn *websocket.Conn
ch chan []byte
done chan struct{}
}
func newWSWriter(conn *websocket.Conn) *wsWriter {
w := &wsWriter{
conn: conn,
ch: make(chan []byte, 128),
done: make(chan struct{}),
}
go w.loop()
return w
}
func (w *wsWriter) loop() {
defer close(w.done)
for data := range w.ch {
_ = w.conn.WriteMessage(websocket.TextMessage, data)
}
}
func (w *wsWriter) Send(data []byte) {
select {
case w.ch <- data:
case <-w.done:
}
}
func (w *wsWriter) SendJSON(v any) {
data, err := json.Marshal(v)
if err != nil {
return
}
w.Send(data)
}
func (w *wsWriter) Close() {
close(w.ch)
<-w.done
}
// extractRawBearerToken returns the raw JWT token suitable for passing to the
// gateway. The gateway WS handler receives the token directly (not as an HTTP
// header), so we must strip the "Bearer " prefix if present.
func extractRawBearerToken(c echo.Context) string {
auth := strings.TrimSpace(c.Request().Header.Get("Authorization"))
if auth != "" {
return strings.TrimPrefix(auth, "Bearer ")
}
return strings.TrimSpace(c.QueryParam("token"))
}
// HandleWebSocket godoc
// @Summary WebSocket chat endpoint
// @Description Upgrade to WebSocket for bidirectional chat streaming with abort support.
// @Tags local-channel
// @Param bot_id path string true "Bot ID"
// @Success 101 {string} string "Switching Protocols"
// @Failure 400 {object} ErrorResponse
// @Failure 403 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/web/ws [get]
// @Router /bots/{bot_id}/cli/ws [get].
func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error {
channelIdentityID, err := h.requireChannelIdentityID(c)
if err != nil {
return err
}
botID := strings.TrimSpace(c.Param("bot_id"))
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
return err
}
if err := h.ensureBotParticipant(c.Request().Context(), botID, channelIdentityID); err != nil {
return err
}
if h.resolver == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "resolver not configured")
}
conn, err := wsUpgrader.Upgrade(c.Response(), c.Request(), nil)
if err != nil {
return err
}
defer func() { _ = conn.Close() }()
rawToken := extractRawBearerToken(c)
bearerToken := "Bearer " + rawToken
writer := newWSWriter(conn)
defer writer.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
abortCh := make(chan struct{}, 1)
var activeCancel context.CancelFunc
for {
_, raw, readErr := conn.ReadMessage()
if readErr != nil {
cancel()
break
}
var msg wsClientMessage
if err := json.Unmarshal(raw, &msg); err != nil {
writer.SendJSON(map[string]string{"type": "error", "message": "invalid message format"})
continue
}
switch msg.Type {
case "abort":
select {
case abortCh <- struct{}{}:
default:
}
case "message":
text := strings.TrimSpace(msg.Text)
if text == "" {
writer.SendJSON(map[string]string{"type": "error", "message": "message text is required"})
continue
}
chatAttachments := make([]conversation.ChatAttachment, 0, len(msg.Attachments))
for _, rawAtt := range msg.Attachments {
var att conversation.ChatAttachment
if err := json.Unmarshal(rawAtt, &att); err == nil {
chatAttachments = append(chatAttachments, att)
}
}
// Drain any previous abort signal.
select {
case <-abortCh:
default:
}
streamCtx, streamCancel := context.WithCancel(ctx)
activeCancel = streamCancel
eventCh := make(chan flow.WSStreamEvent, 64)
go func() {
defer streamCancel()
defer close(eventCh)
req := conversation.ChatRequest{
BotID: botID,
ChatID: botID,
Token: bearerToken,
UserID: channelIdentityID,
SourceChannelIdentityID: channelIdentityID,
ConversationType: "p2p",
Query: text,
CurrentChannel: h.channelType.String(),
Channels: []string{h.channelType.String()},
Attachments: chatAttachments,
}
if streamErr := h.resolver.StreamChatWS(streamCtx, req, eventCh, abortCh); streamErr != nil {
if ctx.Err() == nil {
h.logger.Error("ws stream error", slog.Any("error", streamErr))
writer.SendJSON(map[string]string{"type": "error", "message": streamErr.Error()})
}
}
}()
go func() {
for event := range eventCh {
writer.Send(event)
}
}()
default:
writer.SendJSON(map[string]string{"type": "error", "message": "unknown message type: " + msg.Type})
}
}
_ = activeCancel
return nil
}
func (h *LocalChannelHandler) ensureBotParticipant(ctx context.Context, botID, channelIdentityID string) error {
if h.chatService == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured")
+51 -7
View File
@@ -4,6 +4,7 @@ import {
LanguageModelUsage,
ModelMessage,
stepCountIs,
type StepResult,
streamText,
ToolSet,
UserModelMessage,
@@ -85,6 +86,16 @@ export const buildNativeImageParts = (attachments: GatewayInputAttachment[]): Im
.filter((attachment): attachment is ImagePart => attachment != null)
}
const rebuildPartialMessages = (steps: StepResult<ToolSet>[]): ModelMessage[] => {
const messages: ModelMessage[] = []
for (const step of steps) {
if (step.response?.messages) {
messages.push(...(step.response.messages as ModelMessage[]))
}
}
return messages
}
export const createAgent = (
{
model: modelConfig,
@@ -508,7 +519,6 @@ export const createAgent = (
basePrepareStep: () => ({ system: systemPrompt }),
})
const tools = { ...baseTools, ...readMediaTools }
// Stream path needs deferred abort to keep tool_call_start/tool_call_end event pairing.
const guardedTools = buildGuardedTools(tools, (toolCallId) => {
toolLoopAbortCallIds.add(toolCallId)
})
@@ -519,6 +529,17 @@ export const createAgent = (
}
await closePromise
}
const abortController = new AbortController()
if (input.signal) {
if (input.signal.aborted) {
abortController.abort(input.signal.reason)
} else {
input.signal.addEventListener('abort', () => abortController.abort(input.signal!.reason), { once: true })
}
}
const abortedSteps: StepResult<ToolSet>[] = []
let wasAborted = false
let streamError: unknown = null
try {
const { fullStream } = streamText({
@@ -529,6 +550,7 @@ export const createAgent = (
stopWhen: stepCountIs(Infinity),
prepareStep,
tools: guardedTools,
abortSignal: abortController.signal,
onFinish: async ({ usage, reasoning, response, steps }) => {
await closeTools()
result.usage = usage as never
@@ -536,6 +558,10 @@ export const createAgent = (
result.messages = response.messages
result.usages = buildStepUsages(steps)
},
onAbort: ({ steps }) => {
wasAborted = true
abortedSteps.push(...steps)
},
})
yield {
type: 'agent_start',
@@ -593,7 +619,6 @@ export const createAgent = (
break
}
case 'text-end': {
// Flush any remaining buffered content before ending the text stream.
const remainder = attachmentsExtractor.flushRemainder()
if (remainder.visibleText) {
if (textLoopProbeBuffer) {
@@ -620,7 +645,6 @@ export const createAgent = (
break
}
case 'tool-call':
// Flush any remaining buffered content before ending the text stream.
const remainder = attachmentsExtractor.flushRemainder()
if (remainder.visibleText) {
if (textLoopProbeBuffer) {
@@ -649,8 +673,6 @@ export const createAgent = (
}
break
case 'tool-result':
// Always emit the terminal tool event first so downstream reducers
// can close the in-flight tool block before the stream aborts.
const shouldAbortForToolLoop = toolLoopAbortCallIds.delete(chunk.toolCallId)
yield {
type: 'tool_call_end',
@@ -702,8 +724,30 @@ export const createAgent = (
}
} catch (error) {
streamError = error
console.error(error)
throw error
if (wasAborted || abortController.signal.aborted) {
const partialMessages = rebuildPartialMessages(abortedSteps)
const partialUsages = abortedSteps.length > 0
? buildStepUsages(abortedSteps as Parameters<typeof buildStepUsages>[0])
: []
const partialUsage = abortedSteps.length > 0
? abortedSteps[abortedSteps.length - 1].usage
: null
const partialReasoning = abortedSteps.flatMap(
(s) => (s.reasoning ?? []).map((r: { text: string }) => r.text),
)
yield {
type: 'agent_abort',
messages: [userPrompt, ...partialMessages],
usages: [null, ...partialUsages],
reasoning: partialReasoning,
usage: partialUsage as LanguageModelUsage | null,
skills: getEnabledSkills(),
}
} else {
console.error(error)
throw error
}
} finally {
try {
await closeTools()
+10
View File
@@ -67,6 +67,15 @@ export interface AgentEndAction extends BaseAction {
usages: (LanguageModelUsage | null)[]
}
export interface AgentAbortAction extends BaseAction {
type: 'agent_abort'
messages: ModelMessage[]
skills: string[]
reasoning: string[]
usage: LanguageModelUsage | null
usages: (LanguageModelUsage | null)[]
}
export type AgentStreamAction =
| AgentStartAction
| ReasoningStartAction
@@ -79,3 +88,4 @@ export type AgentStreamAction =
| ToolCallStartAction
| ToolCallEndAction
| AgentEndAction
| AgentAbortAction
+1
View File
@@ -50,6 +50,7 @@ export interface AgentInput {
attachments: GatewayInputAttachment[]
skills: string[]
query: string
signal?: AbortSignal
}
export interface AgentSkill {
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+58
View File
@@ -1562,6 +1562,35 @@ export type GetBotsByBotIdCliStreamResponses = {
export type GetBotsByBotIdCliStreamResponse = GetBotsByBotIdCliStreamResponses[keyof GetBotsByBotIdCliStreamResponses];
export type GetBotsByBotIdCliWsData = {
body?: never;
path: {
/**
* Bot ID
*/
bot_id: string;
};
query?: never;
url: '/bots/{bot_id}/cli/ws';
};
export type GetBotsByBotIdCliWsErrors = {
/**
* Bad Request
*/
400: HandlersErrorResponse;
/**
* Forbidden
*/
403: HandlersErrorResponse;
/**
* Internal Server Error
*/
500: HandlersErrorResponse;
};
export type GetBotsByBotIdCliWsError = GetBotsByBotIdCliWsErrors[keyof GetBotsByBotIdCliWsErrors];
export type DeleteBotsByBotIdContainerData = {
body?: never;
path: {
@@ -4891,6 +4920,35 @@ export type GetBotsByBotIdWebStreamResponses = {
export type GetBotsByBotIdWebStreamResponse = GetBotsByBotIdWebStreamResponses[keyof GetBotsByBotIdWebStreamResponses];
export type GetBotsByBotIdWebWsData = {
body?: never;
path: {
/**
* Bot ID
*/
bot_id: string;
};
query?: never;
url: '/bots/{bot_id}/web/ws';
};
export type GetBotsByBotIdWebWsErrors = {
/**
* Bad Request
*/
400: HandlersErrorResponse;
/**
* Forbidden
*/
403: HandlersErrorResponse;
/**
* Internal Server Error
*/
500: HandlersErrorResponse;
};
export type GetBotsByBotIdWebWsError = GetBotsByBotIdWebWsErrors[keyof GetBotsByBotIdWebWsErrors];
export type DeleteBotsByIdData = {
body?: never;
path: {
+88
View File
@@ -291,6 +291,50 @@ const docTemplate = `{
}
}
},
"/bots/{bot_id}/cli/ws": {
"get": {
"description": "Upgrade to WebSocket for bidirectional chat streaming with abort support.",
"tags": [
"local-channel"
],
"summary": "WebSocket chat endpoint",
"parameters": [
{
"type": "string",
"description": "Bot ID",
"name": "bot_id",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols",
"schema": {
"type": "string"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"403": {
"description": "Forbidden",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/bots/{bot_id}/container": {
"get": {
"tags": [
@@ -4319,6 +4363,50 @@ const docTemplate = `{
}
}
},
"/bots/{bot_id}/web/ws": {
"get": {
"description": "Upgrade to WebSocket for bidirectional chat streaming with abort support.",
"tags": [
"local-channel"
],
"summary": "WebSocket chat endpoint",
"parameters": [
{
"type": "string",
"description": "Bot ID",
"name": "bot_id",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols",
"schema": {
"type": "string"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"403": {
"description": "Forbidden",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/bots/{id}": {
"get": {
"description": "Get a bot by ID (owner/admin only)",
+88
View File
@@ -282,6 +282,50 @@
}
}
},
"/bots/{bot_id}/cli/ws": {
"get": {
"description": "Upgrade to WebSocket for bidirectional chat streaming with abort support.",
"tags": [
"local-channel"
],
"summary": "WebSocket chat endpoint",
"parameters": [
{
"type": "string",
"description": "Bot ID",
"name": "bot_id",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols",
"schema": {
"type": "string"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"403": {
"description": "Forbidden",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/bots/{bot_id}/container": {
"get": {
"tags": [
@@ -4310,6 +4354,50 @@
}
}
},
"/bots/{bot_id}/web/ws": {
"get": {
"description": "Upgrade to WebSocket for bidirectional chat streaming with abort support.",
"tags": [
"local-channel"
],
"summary": "WebSocket chat endpoint",
"parameters": [
{
"type": "string",
"description": "Bot ID",
"name": "bot_id",
"in": "path",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols",
"schema": {
"type": "string"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"403": {
"description": "Forbidden",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/bots/{id}": {
"get": {
"description": "Get a bot by ID (owner/admin only)",
+60
View File
@@ -2398,6 +2398,36 @@ paths:
summary: Subscribe to local channel events via SSE
tags:
- local-channel
/bots/{bot_id}/cli/ws:
get:
description: Upgrade to WebSocket for bidirectional chat streaming with abort
support.
parameters:
- description: Bot ID
in: path
name: bot_id
required: true
type: string
responses:
"101":
description: Switching Protocols
schema:
type: string
"400":
description: Bad Request
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"403":
description: Forbidden
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/handlers.ErrorResponse'
summary: WebSocket chat endpoint
tags:
- local-channel
/bots/{bot_id}/container:
delete:
parameters:
@@ -5081,6 +5111,36 @@ paths:
summary: Subscribe to local channel events via SSE
tags:
- local-channel
/bots/{bot_id}/web/ws:
get:
description: Upgrade to WebSocket for bidirectional chat streaming with abort
support.
parameters:
- description: Bot ID
in: path
name: bot_id
required: true
type: string
responses:
"101":
description: Switching Protocols
schema:
type: string
"400":
description: Bad Request
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"403":
description: Forbidden
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/handlers.ErrorResponse'
summary: WebSocket chat endpoint
tags:
- local-channel
/bots/{id}:
delete:
description: Delete a bot user (owner/admin only)