Files
Memoh/internal/conversation/flow/resolver_stream.go
T
Acbox a31995424c feat: add per-route message dispatch modes (inject/parallel/queue)
Introduce three inbound message handling modes for channel adapters:

- inject (default, /btw): when a route has an active agent stream,
  inject the new user message into the running stream via the SDK's
  PrepareStep hook between tool rounds. The message is interleaved at
  the correct position in the persisted round.
- parallel (/now): start a new agent stream immediately, running
  concurrently with any existing stream (this matches the behavior
  before this change).
- queue (/next): enqueue the message and process it after the current
  stream completes.

Key components:
- RouteDispatcher: per-route state management with inject channel,
  task queue, and active-stream tracking.
- PrepareStep integration: drains inject channel between tool rounds,
  records insertion position via InjectedRecorder for correct
  persistence ordering.
- interleaveInjectedMessages: inserts injected user messages at their
  actual injection position within the persisted message round.
- Parallel mode isolation: /now streams do not interact with the
  dispatcher, preventing them from clearing another stream's active
  state.
2026-04-03 01:17:33 +08:00

223 lines
6.0 KiB
Go

package flow
import (
"context"
"encoding/json"
"fmt"
"log/slog"
sdk "github.com/memohai/twilight-ai/sdk"
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/conversation"
)
// WSStreamEvent represents a raw JSON event forwarded from the agent.
// It is a type alias (not a distinct named type) of json.RawMessage, so
// values convert freely; StreamChatWS sends each JSON-encoded agent event
// on a channel of this type.
type WSStreamEvent = json.RawMessage
// StreamChat runs a streaming chat via the internal agent.
//
// It returns immediately; a background goroutine owns both returned channels
// and closes them when it finishes. Each agent event is JSON-encoded and sent
// on the chunk channel. The first terminal event that carries messages and is
// successfully persisted via tryStoreStream marks the round as stored; that
// persistence happens before the event is forwarded. Session title generation
// is started in the background with a non-cancelable context.
//
// The error channel is buffered and receives at most one value: the resolve
// error if context resolution fails, or ctx.Err() if ctx was cancelled while
// a chunk was waiting to be delivered.
func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) {
	chunkCh := make(chan conversation.StreamChunk)
	errCh := make(chan error, 1)
	r.logger.Info("agent stream start",
		slog.String("bot_id", req.BotID),
		slog.String("chat_id", req.ChatID),
	)
	go func() {
		defer close(chunkCh)
		defer close(errCh)
		streamReq := req
		rc, err := r.resolve(ctx, streamReq)
		if err != nil {
			r.logger.Error("agent stream resolve failed",
				slog.String("bot_id", streamReq.BotID),
				slog.String("chat_id", streamReq.ChatID),
				slog.Any("error", err),
			)
			errCh <- err
			return
		}
		streamReq.Query = rc.query
		go r.maybeGenerateSessionTitle(context.WithoutCancel(ctx), streamReq, streamReq.Query)
		cfg := r.prepareRunConfig(ctx, rc.runConfig)
		eventCh := r.agent.Stream(ctx, cfg)
		stored := false
		// delivering turns false once ctx is cancelled while a send is pending.
		// After that, events are still consumed (so the agent producer is never
		// left blocked and the terminal event can still be persisted) but are
		// no longer forwarded to a consumer that may have stopped reading.
		delivering := true
		for event := range eventCh {
			if event.Type == agentpkg.EventError {
				r.logger.Error("agent stream error",
					slog.String("bot_id", streamReq.BotID),
					slog.String("chat_id", streamReq.ChatID),
					slog.String("model_id", rc.model.ID),
					slog.String("error", event.Error),
				)
			}
			data, err := json.Marshal(event)
			if err != nil {
				r.logger.Error("agent stream event encode failed",
					slog.String("bot_id", streamReq.BotID),
					slog.String("chat_id", streamReq.ChatID),
					slog.Any("error", err),
				)
				continue
			}
			if !stored && event.IsTerminal() && len(event.Messages) > 0 {
				// Only mark the round as stored when tryStoreStream reports it
				// actually persisted something; (false, nil) means nothing was
				// written, so a later terminal event may still be stored.
				ok, storeErr := r.tryStoreStream(ctx, streamReq, data, rc.model.ID, rc)
				switch {
				case storeErr != nil:
					r.logger.Error("stream persist failed", slog.Any("error", storeErr))
				case ok:
					stored = true
				}
			}
			if !delivering {
				continue
			}
			select {
			case chunkCh <- conversation.StreamChunk(data):
			case <-ctx.Done():
				delivering = false
			}
		}
		if !delivering {
			// errCh is buffered (cap 1) and nothing else was sent on this
			// path, so this send cannot block.
			errCh <- ctx.Err()
		}
	}()
	return chunkCh, errCh
}
// StreamChatWS resolves the agent context for req and relays every agent
// event to eventCh as raw JSON.
//
// Closing abortCh cancels only the agent run context. Deliveries to eventCh
// are abandoned only when the caller's ctx is done. The first terminal event
// carrying messages is persisted using the caller's ctx, so an abort alone
// does not cancel persistence.
//
// It returns a wrapped resolve error, ctx.Err() if ctx is cancelled while an
// event is waiting to be delivered, or nil once the agent event channel is
// closed.
func (r *Resolver) StreamChatWS(
	ctx context.Context,
	req conversation.ChatRequest,
	eventCh chan<- WSStreamEvent,
	abortCh <-chan struct{},
) error {
	resolved, resolveErr := r.resolve(ctx, req)
	if resolveErr != nil {
		return fmt.Errorf("resolve: %w", resolveErr)
	}
	req.Query = resolved.query
	go r.maybeGenerateSessionTitle(context.WithoutCancel(ctx), req, req.Query)

	runCtx, stopRun := context.WithCancel(ctx)
	defer stopRun()
	// Translate an abort signal into cancellation of runCtx. The watcher
	// exits on its own once runCtx is done, at the latest when this
	// function returns and the deferred stopRun fires.
	go func() {
		select {
		case <-abortCh:
			stopRun()
		case <-runCtx.Done():
		}
	}()

	agentEvents := r.agent.Stream(runCtx, r.prepareRunConfig(runCtx, resolved.runConfig))
	modelID := resolved.model.ID
	persisted := false
	for ev := range agentEvents {
		if ev.Type == agentpkg.EventError {
			r.logger.Error("agent stream error",
				slog.String("bot_id", req.BotID),
				slog.String("chat_id", req.ChatID),
				slog.String("model_id", modelID),
				slog.String("error", ev.Error),
			)
		}
		raw, marshalErr := json.Marshal(ev)
		if marshalErr != nil {
			continue
		}
		if !persisted && ev.IsTerminal() && len(ev.Messages) > 0 {
			_, storeErr := r.tryStoreStream(ctx, req, raw, modelID, resolved)
			if storeErr == nil {
				persisted = true
			} else {
				r.logger.Error("ws persist failed", slog.Any("error", storeErr))
			}
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case eventCh <- WSStreamEvent(raw):
		}
	}
	return nil
}
// tryStoreStream decodes the "messages" (and optional "usage") fields of an
// encoded agent event and, when at least one SDK message is present,
// persists them as a single round via storeRound.
//
// The persisted round is the user query (via prependUserMessage) followed by
// the converted output messages. Any injected user messages recorded in
// rc.injectedRecords are interleaved at their recorded positions. After a
// successful store, if the usage block reports a positive inputTokens count,
// maybeCompact is launched in a goroutine with a non-cancelable context.
//
// It returns (true, nil) when the round was stored and (false, err) when
// storeRound failed. It returns (false, nil) when there was nothing to store,
// including when data or its messages field could not be decoded.
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string, rc resolvedContext) (bool, error) {
	var envelope struct {
		Type     string          `json:"type"`
		Messages json.RawMessage `json:"messages"`
		Usage    json.RawMessage `json:"usage,omitempty"`
	}
	if json.Unmarshal(data, &envelope) != nil || len(envelope.Messages) == 0 {
		return false, nil
	}
	var decoded []sdk.Message
	if json.Unmarshal(envelope.Messages, &decoded) != nil || len(decoded) == 0 {
		return false, nil
	}

	round := prependUserMessage(req.Query, sdkMessagesToModelMessages(decoded))
	if recs := rc.injectedRecords; recs != nil && len(*recs) > 0 {
		round = interleaveInjectedMessages(round, *recs)
	}
	if err := r.storeRound(ctx, req, round, modelID); err != nil {
		return false, err
	}

	if tokens := extractInputTokensFromUsage(envelope.Usage); tokens > 0 {
		go r.maybeCompact(context.WithoutCancel(ctx), req, rc, tokens)
	}
	return true, nil
}
// interleaveInjectedMessages inserts injected user messages at their recorded
// positions within the round. Each record's InsertAfter value is the number
// of output messages that preceded the injection.
//
// round layout: [user_A, output_0, output_1, ..., output_N]
// InsertAfter=K → insert after round[K] (i.e. after the K-th output message;
// K=0 means directly after user_A).
//
// Records are placed by their InsertAfter value regardless of their order in
// the injections slice. Records sharing a position keep their relative order.
// A negative InsertAfter is treated as 0, and a value at or past the end of
// the round places the message at the end. Each injected message becomes a
// "user" ModelMessage carrying the record's HeaderifiedText.
//
// When injections is empty, round is returned as-is. Otherwise a new slice is
// built and neither input slice is modified.
func interleaveInjectedMessages(round []conversation.ModelMessage, injections []conversation.InjectedMessageRecord) []conversation.ModelMessage {
	if len(injections) == 0 {
		return round
	}
	// slots[i] lists (as indices into injections) the records to emit right
	// after round[i]; the extra final slot collects records positioned past
	// the end of the round. Bucketing by position, rather than advancing a
	// single cursor that only matches InsertAfter == i, keeps an
	// out-of-order or negative record from dragging later records to the tail.
	slots := make([][]int, len(round)+1)
	for j, inj := range injections {
		pos := inj.InsertAfter
		if pos < 0 {
			pos = 0
		}
		if pos > len(round) {
			pos = len(round)
		}
		slots[pos] = append(slots[pos], j)
	}

	result := make([]conversation.ModelMessage, 0, len(round)+len(injections))
	emit := func(idxs []int) {
		for _, j := range idxs {
			result = append(result, conversation.ModelMessage{
				Role:    "user",
				Content: conversation.NewTextContent(injections[j].HeaderifiedText),
			})
		}
	}
	for i, msg := range round {
		result = append(result, msg)
		emit(slots[i])
	}
	emit(slots[len(round)])
	return result
}
// extractInputTokensFromUsage returns the "inputTokens" count from a raw
// usage JSON object. It yields 0 when raw is empty, when raw cannot be
// decoded into an object with an integer inputTokens field, or when the
// field is absent.
func extractInputTokensFromUsage(raw json.RawMessage) (tokens int) {
	var usage struct {
		InputTokens int `json:"inputTokens"`
	}
	switch {
	case len(raw) == 0:
		// Nothing to decode; tokens stays 0.
	case json.Unmarshal(raw, &usage) == nil:
		tokens = usage.InputTokens
	}
	return tokens
}