feat: add per-route message dispatch modes (inject/parallel/queue)

Introduce three inbound message handling modes for channel adapters:

- inject (default, /btw): when a route has an active agent stream,
  inject the new user message into the running stream via the SDK's
  PrepareStep hook between tool rounds. The message is interleaved at
  the correct position in the persisted round.
- parallel (/now): start a new agent stream immediately, running
  concurrently with any existing stream (preserves current behavior).
- queue (/next): enqueue the message and process it after the current
  stream completes.

Key components:
- RouteDispatcher: per-route state management with inject channel,
  task queue, and active-stream tracking.
- PrepareStep integration: drains inject channel between tool rounds,
  records insertion position via InjectedRecorder for correct
  persistence ordering.
- interleaveInjectedMessages: inserts injected user messages at their
  actual injection position within the persisted message round.
- Parallel mode isolation: /now streams do not interact with the
  dispatcher, preventing them from clearing another stream's active
  state.
This commit is contained in:
Acbox
2026-04-02 21:43:13 +08:00
parent 33b57ee345
commit a31995424c
10 changed files with 947 additions and 6 deletions
+41
View File
@@ -98,6 +98,47 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
if readMediaState != nil {
prepareStep = readMediaState.prepareStep
}
initialMsgCount := len(cfg.Messages)
if cfg.InjectCh != nil {
basePrepare := prepareStep
prepareStep = func(p *sdk.GenerateParams) *sdk.GenerateParams {
if basePrepare != nil {
if override := basePrepare(p); override != nil {
p = override
}
}
for {
select {
case injected, ok := <-cfg.InjectCh:
if !ok {
break
}
text := strings.TrimSpace(injected.HeaderifiedText)
if text == "" {
text = strings.TrimSpace(injected.Text)
}
if text != "" {
insertAfter := len(p.Messages) - initialMsgCount
p.Messages = append(p.Messages, sdk.UserMessage(text))
if cfg.InjectedRecorder != nil {
cfg.InjectedRecorder(text, insertAfter)
}
a.logger.Info("injected user message into agent stream",
slog.String("bot_id", cfg.Identity.BotID),
slog.Int("insert_after", insertAfter),
)
}
continue
default:
}
break
}
return p
}
}
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
streamResult, err := a.client.StreamText(ctx, opts...)
+18
View File
@@ -46,6 +46,13 @@ type LoopDetectionConfig struct {
Enabled bool
}
// InjectMessage carries a user message to be injected into a running agent
// stream between tool rounds via the PrepareStep hook.
type InjectMessage struct {
	// Text is the raw user message; used as a fallback when
	// HeaderifiedText is empty.
	Text string
	// HeaderifiedText is the header-formatted variant; the PrepareStep
	// hook prefers it over Text when non-empty.
	HeaderifiedText string
}
// RunConfig holds everything needed for a single agent invocation.
type RunConfig struct {
Model *sdk.Model
@@ -60,6 +67,17 @@ type RunConfig struct {
Identity SessionContext
Skills []SkillEntry
LoopDetection LoopDetectionConfig
// InjectCh receives user messages to inject between tool rounds.
// When non-nil, a PrepareStep hook drains this channel and appends
// user messages to the conversation before the next LLM call.
InjectCh <-chan InjectMessage
// InjectedRecorder is called each time a message is injected via
// PrepareStep, recording the headerified text and the number of SDK
// output messages that preceded the injection. Used by the resolver
// to interleave injected messages at the correct position in storeRound.
InjectedRecorder func(headerifiedText string, insertAfter int)
}
// GenerateResult holds the result of a non-streaming agent invocation.
+179 -1
View File
@@ -98,6 +98,7 @@ type ChannelInboundProcessor struct {
tokenTTL time.Duration
identity *IdentityResolver
policy PolicyService
dispatcher *RouteDispatcher
acl chatACL
observer channel.StreamObserver
ttsService ttsSynthesizer
@@ -206,6 +207,14 @@ func (p *ChannelInboundProcessor) SetCommandHandler(handler *command.Handler) {
p.commandHandler = handler
}
// SetDispatcher installs the per-route dispatcher that decides how an inbound
// message interacts with an already-running stream (inject, queue, or
// parallel). Calling it on a nil receiver is a safe no-op.
func (p *ChannelInboundProcessor) SetDispatcher(dispatcher *RouteDispatcher) {
	if p != nil {
		p.dispatcher = dispatcher
	}
}
// HandleInbound processes an inbound channel message through identity resolution and chat gateway.
func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error {
if p.runner == nil {
@@ -270,7 +279,9 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
return p.handleNewSessionCommand(ctx, cfg, msg, sender, identity)
}
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && isDirectedAtBot(msg) {
// Skip generic command handler for mode-prefix commands (/btw, /now, /next)
// so they pass through to mode detection below.
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText)
if err != nil {
reply = "Error: " + err.Error()
@@ -284,6 +295,14 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
resolvedAttachments := p.ingestInboundAttachments(ctx, cfg, msg, strings.TrimSpace(identity.BotID), msg.Message.Attachments)
attachments := mapChannelToChatAttachments(resolvedAttachments)
text = buildInboundQuery(msg.Message, attachments)
// Detect inbound mode from message prefix (/btw, /now, /next).
// Only applies to non-local channels; WebUI always uses the default flow.
// Must run after buildInboundQuery so the prefix is stripped from the final text.
inboundMode := ModeInject
if !isLocalChannelType(msg.Channel) {
inboundMode, text = DetectMode(text)
}
threadID := extractThreadID(msg)
// Resolve or create the route via channel_routes.
@@ -374,6 +393,66 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
return nil
}
routeID := strings.TrimSpace(resolved.RouteID)
// --- Dispatcher-based mode handling (inject / queue) ---
// For non-parallel modes, when a route already has an active agent stream,
// short-circuit here instead of starting a new stream.
if p.dispatcher != nil && !isLocalChannelType(msg.Channel) && inboundMode != ModeParallel {
if p.dispatcher.IsActive(routeID) {
headerifiedText := flow.FormatUserHeader(
strings.TrimSpace(msg.Message.ID),
strings.TrimSpace(identity.ChannelIdentityID),
strings.TrimSpace(identity.DisplayName),
msg.Channel.String(),
strings.TrimSpace(msg.Conversation.Type),
strings.TrimSpace(msg.Conversation.Name),
collectAttachmentPaths(attachments),
time.Now().UTC(),
"",
text,
)
switch inboundMode {
case ModeInject:
// Don't persist here — the injected message will be interleaved
// at the correct position within the round by
// interleaveInjectedMessages in storeRound.
injected := p.dispatcher.Inject(routeID, InjectMessage{
Text: text,
Attachments: attachments,
HeaderifiedText: headerifiedText,
})
if injected {
p.sendModeConfirmation(ctx, sender, msg, identity, "inject")
} else {
if p.logger != nil {
p.logger.Warn("inject failed (channel full), falling through to new stream",
slog.String("route_id", routeID))
}
goto startStream
}
return nil
case ModeQueue:
p.persistPassiveMessage(ctx, identity, msg, text, attachments, routeID, sessionID)
p.dispatcher.Enqueue(routeID, QueuedTask{
Ctx: ctx,
Cfg: cfg,
Msg: msg,
Sender: sender,
Ident: identity,
Text: text,
Attachs: attachments,
})
p.sendModeConfirmation(ctx, sender, msg, identity, "queue")
return nil
}
}
}
startStream:
// Issue chat token for reply routing.
chatToken := ""
if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTarget) != "" {
@@ -516,6 +595,33 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
return result
}
// Mark this route as active in the dispatcher so subsequent messages
// can be injected or queued. Produces the inject channel for this stream.
// Parallel mode (/now) skips the dispatcher entirely — it must not
// interfere with the active flag or drain the queue of another stream.
var injectCh <-chan InjectMessage
if p.dispatcher != nil && !isLocalChannelType(msg.Channel) && inboundMode != ModeParallel {
injectCh = p.dispatcher.MarkActive(routeID)
defer func() {
p.drainQueue(context.WithoutCancel(ctx), routeID)
}()
}
// Convert inbound InjectMessage channel to conversation.InjectMessage channel.
var convInjectCh chan conversation.InjectMessage
if injectCh != nil {
convInjectCh = make(chan conversation.InjectMessage, injectChBuffer)
go func() {
for im := range injectCh {
convInjectCh <- conversation.InjectMessage{
Text: im.Text,
Attachments: im.Attachments,
HeaderifiedText: im.HeaderifiedText,
}
}
close(convInjectCh)
}()
}
chatReq := conversation.ChatRequest{
BotID: identity.BotID,
ChatID: activeChatID,
@@ -537,6 +643,9 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
Attachments: attachments,
OutboundAssetCollector: assetCollector,
}
if convInjectCh != nil {
chatReq.InjectCh = convInjectCh
}
if mid, _ := msg.Metadata["model_id"].(string); strings.TrimSpace(mid) != "" {
chatReq.Model = strings.TrimSpace(mid)
}
@@ -687,6 +796,75 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
return nil
}
// sendModeConfirmation acknowledges an injected or queued message with an
// emoji reaction rather than a streamed reply: 👀 for inject, 📋 for queue.
// It is a no-op when the reply target or source message ID is missing, or
// when no reactor is configured. Reaction failures are deliberately ignored
// (best-effort acknowledgement only).
func (p *ChannelInboundProcessor) sendModeConfirmation(
	ctx context.Context,
	_ channel.StreamReplySender,
	msg channel.InboundMessage,
	identity InboundIdentity,
	mode string,
) {
	replyTarget := strings.TrimSpace(msg.ReplyTarget)
	messageID := strings.TrimSpace(msg.Message.ID)
	if replyTarget == "" || messageID == "" || p.reactor == nil {
		return
	}
	emoji := "👀"
	if mode == "queue" {
		emoji = "📋"
	}
	// Best-effort: a failed reaction must not affect message handling.
	_ = p.reactor.React(ctx, strings.TrimSpace(identity.BotID), msg.Channel, channel.ReactRequest{
		Target:    replyTarget,
		MessageID: messageID,
		Emoji:     emoji,
	})
}
// drainQueue marks the route as done and processes any queued tasks.
//
// It runs (via defer in HandleInbound) after a stream finishes. MarkDone
// flips the route back to idle, so each HandleInbound call below may
// re-activate the route; when that nested stream completes, its own deferred
// drainQueue fires, making queue processing effectively recursive until the
// queue is empty.
func (p *ChannelInboundProcessor) drainQueue(ctx context.Context, routeID string) {
	if p.dispatcher == nil {
		return
	}
	result := p.dispatcher.MarkDone(routeID)
	// Run deferred persists first so their created_at timestamps land after
	// the just-stored round.
	for _, fn := range result.PendingPersists {
		fn(ctx)
	}
	for _, task := range result.QueuedTasks {
		if p.logger != nil {
			p.logger.Info("processing queued task",
				slog.String("route_id", routeID),
				slog.String("query", strings.TrimSpace(task.Text)),
			)
		}
		// NOTE(review): task.Ctx is not used here; the replay runs under the
		// caller's WithoutCancel context instead — confirm that is intended.
		if err := p.HandleInbound(ctx, task.Cfg, task.Msg, task.Sender); err != nil { //nolint:contextcheck // ctx is already WithoutCancel from the defer caller
			if p.logger != nil {
				p.logger.Error("queued task processing failed",
					slog.String("route_id", routeID),
					slog.Any("error", err),
				)
			}
		}
	}
}
// collectAttachmentPaths returns the trimmed, non-empty Path values of the
// given attachments. It returns nil for an empty input slice; with a
// non-empty input it always returns a (possibly empty) non-nil slice.
func collectAttachmentPaths(attachments []conversation.ChatAttachment) []string {
	if len(attachments) == 0 {
		return nil
	}
	out := make([]string, 0, len(attachments))
	for i := range attachments {
		if trimmed := strings.TrimSpace(attachments[i].Path); trimmed != "" {
			out = append(out, trimmed)
		}
	}
	return out
}
func shouldTriggerAssistantResponse(msg channel.InboundMessage) bool {
if isDirectConversationType(msg.Conversation.Type) {
return true
+308
View File
@@ -0,0 +1,308 @@
package inbound
import (
"context"
"log/slog"
"strings"
"sync"
"time"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/conversation"
)
// InboundMode determines how a new inbound message is handled when an agent
// stream is already active for the same route. The zero value is ModeInject,
// which is also the default when no mode command is present.
type InboundMode int

const (
	// ModeInject (default, command /btw) injects the message into the active
	// agent stream via the PrepareStep hook so the LLM sees it between tool
	// rounds. When no stream is active, starts one normally.
	ModeInject InboundMode = iota
	// ModeParallel (command /now) starts a new agent stream immediately,
	// running concurrently with any existing stream.
	ModeParallel
	// ModeQueue (command /next) queues the message and processes it after the
	// current agent stream completes.
	ModeQueue
)
// InjectMessage carries a user message to be injected into a running agent stream.
type InjectMessage struct {
	// Text is the mode-stripped message text.
	Text string
	// Attachments resolved for this message.
	// NOTE(review): the resolver bridge copies only Text and HeaderifiedText
	// into the agent-level InjectMessage, so attachments are not forwarded to
	// the SDK injection path — confirm this is intended.
	Attachments []conversation.ChatAttachment
	// HeaderifiedText is the formatted user header text ready for SDK injection.
	HeaderifiedText string
}
// QueuedTask holds everything needed to start an agent stream for a queued message.
type QueuedTask struct {
	// Ctx is the context of the original inbound call.
	// NOTE(review): storing a context in a struct is discouraged, and
	// drainQueue replays tasks with its own WithoutCancel context rather
	// than this one — confirm whether Ctx is actually needed.
	Ctx    context.Context
	Cfg    channel.ChannelConfig
	Msg    channel.InboundMessage
	Sender channel.StreamReplySender
	// Ident is the resolved inbound identity captured at enqueue time.
	Ident InboundIdentity
	// Text is the mode-stripped message text.
	Text string
	// Attachs are the already-ingested attachments for the message.
	Attachs []conversation.ChatAttachment
}
// PersistFunc is a deferred persistence closure called after the active stream
// completes (and its storeRound has run), ensuring correct created_at ordering.
type PersistFunc func(ctx context.Context)

// routeState tracks in-flight agent activity for a single route.
// All fields are guarded by mu.
type routeState struct {
	mu     sync.Mutex
	active bool // true while an agent stream is running for this route
	// injectCh is created once per route and reused across activations; it is
	// drained (never closed) when a stream finishes.
	injectCh chan InjectMessage
	// queue holds /next tasks to replay after the active stream ends.
	queue []QueuedTask
	// pendingPersists are deferred writes executed after storeRound.
	pendingPersists []PersistFunc
	// lastUsed is the most recent activity, used by Cleanup to expire state.
	lastUsed time.Time
}
// RouteDispatcher manages per-route concurrency for inbound message processing.
// It decides whether a new message should be injected into an active stream,
// run in parallel, or be queued. It is safe for concurrent use: mu guards the
// routes map, and each routeState carries its own lock.
type RouteDispatcher struct {
	mu     sync.RWMutex
	routes map[string]*routeState // keyed by trimmed route ID
	logger *slog.Logger
}
// NewRouteDispatcher returns a dispatcher ready for use. A nil logger falls
// back to slog.Default. Note that no background goroutine is started; idle
// state is reclaimed only when the owner calls Cleanup.
func NewRouteDispatcher(logger *slog.Logger) *RouteDispatcher {
	base := logger
	if base == nil {
		base = slog.Default()
	}
	return &RouteDispatcher{
		routes: make(map[string]*routeState),
		logger: base.With(slog.String("component", "route_dispatcher")),
	}
}

// injectChBuffer bounds how many messages may wait on a route's inject
// channel before Inject starts rejecting them.
const injectChBuffer = 16
// getOrCreate returns the state for routeID, lazily creating it on first use.
// It uses a read-lock fast path for the common lookup and re-checks under the
// write lock before inserting, so concurrent callers converge on one state.
func (d *RouteDispatcher) getOrCreate(routeID string) *routeState {
	d.mu.RLock()
	state, found := d.routes[routeID]
	d.mu.RUnlock()
	if found {
		return state
	}
	d.mu.Lock()
	defer d.mu.Unlock()
	// Another goroutine may have inserted between the two lock acquisitions.
	if state, found := d.routes[routeID]; found {
		return state
	}
	created := &routeState{
		injectCh: make(chan InjectMessage, injectChBuffer),
		lastUsed: time.Now(),
	}
	d.routes[routeID] = created
	return created
}
// IsActive reports whether the given route has an active agent stream.
//
// As a pure query it must not allocate per-route state, so it reads the
// routes map directly instead of going through getOrCreate — the original
// implementation inserted a routeState for every probed route, growing the
// map (and Cleanup's workload) on a read path. An unknown or empty routeID
// is simply not active.
func (d *RouteDispatcher) IsActive(routeID string) bool {
	routeID = strings.TrimSpace(routeID)
	if routeID == "" {
		return false
	}
	d.mu.RLock()
	rs, ok := d.routes[routeID]
	d.mu.RUnlock()
	if !ok {
		return false
	}
	rs.mu.Lock()
	defer rs.mu.Unlock()
	return rs.active
}
// MarkActive marks a route as having an active stream and returns the inject
// channel that the agent should drain via PrepareStep.
//
// The same channel instance is returned for every activation of a route: it
// is created once in getOrCreate and is never closed (MarkDone only drains
// it). Callers must therefore not close it and must not range over it
// expecting termination. Returns nil for an empty routeID.
func (d *RouteDispatcher) MarkActive(routeID string) <-chan InjectMessage {
	routeID = strings.TrimSpace(routeID)
	if routeID == "" {
		return nil
	}
	rs := d.getOrCreate(routeID)
	rs.mu.Lock()
	defer rs.mu.Unlock()
	rs.active = true
	rs.lastUsed = time.Now()
	return rs.injectCh
}
// MarkDoneResult holds the data returned when a route transitions from active to idle.
type MarkDoneResult struct {
	// PendingPersists are deferred writes to execute now that the stream's
	// round has been stored; nil when nothing accumulated.
	PendingPersists []PersistFunc
	// QueuedTasks are queued (/next) messages to replay, in arrival order;
	// nil when the queue was empty.
	QueuedTasks []QueuedTask
}
// MarkDone marks a route as idle and returns pending persist functions (to be
// called now that storeRound has completed) and any queued tasks.
//
// Messages still buffered on the inject channel are discarded — they arrived
// too late for the finished stream. The channel itself stays open so the same
// routeState can serve the route's next activation.
// NOTE(review): because the channel is never closed, any goroutine ranging
// over it (e.g. a per-stream bridge goroutine) will not terminate on MarkDone
// — confirm consumers are stopped by other means.
func (d *RouteDispatcher) MarkDone(routeID string) MarkDoneResult {
	routeID = strings.TrimSpace(routeID)
	if routeID == "" {
		return MarkDoneResult{}
	}
	rs := d.getOrCreate(routeID)
	rs.mu.Lock()
	defer rs.mu.Unlock()
	rs.active = false
	rs.lastUsed = time.Now()
	// Drop undelivered injections; the stream they targeted has ended.
	drainInjectCh(rs.injectCh)
	var persists []PersistFunc
	if len(rs.pendingPersists) > 0 {
		persists = rs.pendingPersists
		rs.pendingPersists = nil
	}
	var tasks []QueuedTask
	if len(rs.queue) > 0 {
		tasks = rs.queue
		rs.queue = nil
	}
	return MarkDoneResult{PendingPersists: persists, QueuedTasks: tasks}
}
// AddPendingPersist registers a closure to run after the route's active
// stream finishes (it is handed back by MarkDone), so its writes land after
// the stream's own storeRound and keep created_at ordering intact. Empty
// route IDs and nil closures are ignored.
func (d *RouteDispatcher) AddPendingPersist(routeID string, fn PersistFunc) {
	id := strings.TrimSpace(routeID)
	if id == "" || fn == nil {
		return
	}
	state := d.getOrCreate(id)
	state.mu.Lock()
	state.pendingPersists = append(state.pendingPersists, fn)
	state.mu.Unlock()
}
// Inject attempts to hand msg to the route's active stream via its inject
// channel. It reports false when the route ID is empty, the route has no
// active stream, or the channel buffer is already full (in which case the
// message is dropped and a warning logged).
func (d *RouteDispatcher) Inject(routeID string, msg InjectMessage) bool {
	id := strings.TrimSpace(routeID)
	if id == "" {
		return false
	}
	state := d.getOrCreate(id)
	state.mu.Lock()
	defer state.mu.Unlock()
	if !state.active {
		return false
	}
	// Non-blocking send: never stall the inbound handler on a full buffer.
	select {
	case state.injectCh <- msg:
	default:
		if d.logger != nil {
			d.logger.Warn("inject channel full, message dropped",
				slog.String("route_id", id),
			)
		}
		return false
	}
	if d.logger != nil {
		d.logger.Info("message injected into active stream",
			slog.String("route_id", id),
		)
	}
	return true
}
// Enqueue appends task to the route's pending queue. Queued tasks are handed
// back by MarkDone once the active stream completes. Empty route IDs are
// ignored.
func (d *RouteDispatcher) Enqueue(routeID string, task QueuedTask) {
	id := strings.TrimSpace(routeID)
	if id == "" {
		return
	}
	state := d.getOrCreate(id)
	state.mu.Lock()
	defer state.mu.Unlock()
	state.queue = append(state.queue, task)
	state.lastUsed = time.Now()
	if d.logger == nil {
		return
	}
	d.logger.Info("message queued",
		slog.String("route_id", id),
		slog.Int("queue_size", len(state.queue)),
	)
}
// Cleanup removes route states that have been idle for longer than maxAge.
//
// A state is removable only when it has no active stream, no queued tasks,
// AND no pending persist closures — the original check ignored
// pendingPersists, so an idle route could be deleted while still holding
// deferred writes, silently losing them. Timestamps are compared against a
// single cutoff computed once per sweep.
func (d *RouteDispatcher) Cleanup(maxAge time.Duration) {
	d.mu.Lock()
	defer d.mu.Unlock()
	cutoff := time.Now().Add(-maxAge)
	for id, rs := range d.routes {
		rs.mu.Lock()
		idle := !rs.active &&
			rs.lastUsed.Before(cutoff) &&
			len(rs.queue) == 0 &&
			len(rs.pendingPersists) == 0
		rs.mu.Unlock()
		if idle {
			delete(d.routes, id)
		}
	}
}
// drainInjectCh discards any buffered messages without closing the channel,
// leaving it reusable for the route's next activation. It returns as soon as
// a receive would block.
func drainInjectCh(ch chan InjectMessage) {
	for drained := false; !drained; {
		select {
		case <-ch:
			// Drop the message and keep draining.
		default:
			drained = true
		}
	}
}
// DetectMode parses a leading mode command (/btw, /now, /next) from text.
// It returns the detected mode and the text with the command stripped and
// trimmed. Matching is case-insensitive. A command matches when it is the
// whole (trimmed) message ("/now") or is followed by a space or tab
// ("/now hello", "/now\thello"). Accepting a tab keeps DetectMode consistent
// with IsModeCommand, which already treats tab as a separator — previously
// "/now\thello" skipped the generic command handler yet fell through here to
// ModeInject with the literal command left in the text. Anything else
// ("/nowhere", "/unknown x", plain text, empty input) yields the default
// ModeInject with the trimmed input unchanged.
func DetectMode(text string) (InboundMode, string) {
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return ModeInject, trimmed
	}
	commands := []struct {
		prefix string
		mode   InboundMode
	}{
		{"/now", ModeParallel},
		{"/next", ModeQueue},
		{"/btw", ModeInject},
	}
	lower := strings.ToLower(trimmed)
	for _, c := range commands {
		if lower == c.prefix {
			// Bare command with no trailing text.
			return c.mode, ""
		}
		if strings.HasPrefix(lower, c.prefix) {
			rest := trimmed[len(c.prefix):]
			// Require a whitespace separator so "/nowhere" is not parsed
			// as "/now" + "here".
			if rest[0] == ' ' || rest[0] == '\t' {
				return c.mode, strings.TrimSpace(rest)
			}
		}
	}
	return ModeInject, trimmed
}
// IsModeCommand reports whether text begins with one of the dispatch-mode
// commands (/btw, /now, /next) — either as the entire (trimmed) message or
// followed by a space or tab. Matching is case-insensitive. The inbound
// processor uses this to keep mode-prefixed messages away from the generic
// command handler.
func IsModeCommand(text string) bool {
	normalized := strings.ToLower(strings.TrimSpace(text))
	if normalized == "" {
		return false
	}
	for _, cmd := range [...]string{"/now", "/next", "/btw"} {
		switch {
		case normalized == cmd:
			return true
		case strings.HasPrefix(normalized, cmd+" "), strings.HasPrefix(normalized, cmd+"\t"):
			return true
		}
	}
	return false
}
+302
View File
@@ -0,0 +1,302 @@
package inbound
import (
"context"
"log/slog"
"sync"
"testing"
"time"
)
// TestDetectMode covers prefix parsing: case-insensitivity, bare commands,
// surrounding whitespace, and unknown or empty input falling back to
// ModeInject with the text untouched.
func TestDetectMode(t *testing.T) {
	tests := []struct {
		input    string
		wantMode InboundMode
		wantText string
	}{
		{"hello world", ModeInject, "hello world"},
		{"/btw hello", ModeInject, "hello"},
		{"/now hello", ModeParallel, "hello"},
		{"/next hello", ModeQueue, "hello"},
		{"/BTW hello", ModeInject, "hello"},
		{"/NOW hello", ModeParallel, "hello"},
		{"/NEXT hello", ModeQueue, "hello"},
		{"/now", ModeParallel, ""},
		{"/next", ModeQueue, ""},
		{"/btw", ModeInject, ""},
		{" /now hello ", ModeParallel, "hello"},
		{"/unknown hello", ModeInject, "/unknown hello"},
		{"", ModeInject, ""},
		// /new is the session command, not a mode command: must pass through.
		{"/new session", ModeInject, "/new session"},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			mode, text := DetectMode(tt.input)
			if mode != tt.wantMode {
				t.Errorf("DetectMode(%q) mode = %d, want %d", tt.input, mode, tt.wantMode)
			}
			if text != tt.wantText {
				t.Errorf("DetectMode(%q) text = %q, want %q", tt.input, text, tt.wantText)
			}
		})
	}
}
// TestIsModeCommand checks that exactly the three mode commands are
// recognized, bare or with trailing text, and that other slash commands
// (/new, /fs) and plain text are not.
func TestIsModeCommand(t *testing.T) {
	tests := []struct {
		input string
		want  bool
	}{
		{"/btw hello", true},
		{"/now hello", true},
		{"/next hello", true},
		{"/btw", true},
		{"/now", true},
		{"/next", true},
		{"/new", false},
		{"/fs list", false},
		{"hello", false},
		{"", false},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			got := IsModeCommand(tt.input)
			if got != tt.want {
				t.Errorf("IsModeCommand(%q) = %v, want %v", tt.input, got, tt.want)
			}
		})
	}
}
// TestRouteDispatcher_InjectWhenActive exercises the happy path: MarkActive
// exposes an inject channel, and Inject delivers the message onto it.
func TestRouteDispatcher_InjectWhenActive(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	if d.IsActive(routeID) {
		t.Fatal("expected route to be inactive initially")
	}
	injectCh := d.MarkActive(routeID)
	if injectCh == nil {
		t.Fatal("expected non-nil inject channel")
	}
	if !d.IsActive(routeID) {
		t.Fatal("expected route to be active after MarkActive")
	}
	msg := InjectMessage{Text: "hello", HeaderifiedText: "[User] hello"}
	if !d.Inject(routeID, msg) {
		t.Fatal("expected inject to succeed when route is active")
	}
	// Non-blocking receive: the message must already be buffered.
	select {
	case got := <-injectCh:
		if got.Text != "hello" {
			t.Errorf("got text %q, want %q", got.Text, "hello")
		}
	default:
		t.Fatal("expected message on inject channel")
	}
}

// TestRouteDispatcher_InjectWhenInactive checks that Inject refuses messages
// for routes that have no active stream.
func TestRouteDispatcher_InjectWhenInactive(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	msg := InjectMessage{Text: "hello"}
	if d.Inject(routeID, msg) {
		t.Fatal("expected inject to fail when route is inactive")
	}
}
// TestRouteDispatcher_QueueAndDrain verifies that queued tasks come back from
// MarkDone in FIFO order and that the route is idle afterwards.
func TestRouteDispatcher_QueueAndDrain(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	d.MarkActive(routeID)
	d.Enqueue(routeID, QueuedTask{Text: "task-1"})
	d.Enqueue(routeID, QueuedTask{Text: "task-2"})
	result := d.MarkDone(routeID)
	if len(result.QueuedTasks) != 2 {
		t.Fatalf("expected 2 queued tasks, got %d", len(result.QueuedTasks))
	}
	if result.QueuedTasks[0].Text != "task-1" || result.QueuedTasks[1].Text != "task-2" {
		t.Errorf("unexpected task order: %v", result.QueuedTasks)
	}
	if d.IsActive(routeID) {
		t.Fatal("expected route to be inactive after MarkDone")
	}
}

// TestRouteDispatcher_MarkDoneReturnsNilWhenEmpty pins the contract that
// MarkDone returns nil slices (not empty ones) when nothing accumulated.
func TestRouteDispatcher_MarkDoneReturnsNilWhenEmpty(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	d.MarkActive(routeID)
	result := d.MarkDone(routeID)
	if result.QueuedTasks != nil {
		t.Fatalf("expected nil queued tasks, got %v", result.QueuedTasks)
	}
	if result.PendingPersists != nil {
		t.Fatalf("expected nil pending persists, got %v", result.PendingPersists)
	}
}
// TestRouteDispatcher_ConcurrentInject verifies that concurrent Inject calls
// on one route are all accepted (the buffer holds them) and that every
// message ends up on the inject channel.
func TestRouteDispatcher_ConcurrentInject(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	injectCh := d.MarkActive(routeID)
	const numMessages = 10
	var wg sync.WaitGroup
	wg.Add(numMessages)
	for i := 0; i < numMessages; i++ {
		go func() {
			defer wg.Done()
			d.Inject(routeID, InjectMessage{Text: "msg"})
		}()
	}
	wg.Wait()
	// Drain the buffered channel without blocking (no goto needed).
	count := 0
	for draining := true; draining; {
		select {
		case <-injectCh:
			count++
		default:
			draining = false
		}
	}
	if count != numMessages {
		t.Errorf("expected %d messages, got %d", numMessages, count)
	}
}
// TestRouteDispatcher_ParallelBypass documents the /now contract: parallel
// streams never touch the dispatcher, so an active flag set by another
// stream stays untouched.
func TestRouteDispatcher_ParallelBypass(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	d.MarkActive(routeID)
	// In parallel mode, the caller does not interact with the dispatcher
	// at all — it just starts a new stream directly. Verify the route
	// stays active without interference.
	if !d.IsActive(routeID) {
		t.Fatal("expected route to still be active")
	}
	d.MarkDone(routeID)
	if d.IsActive(routeID) {
		t.Fatal("expected route to be inactive after MarkDone")
	}
}

// TestRouteDispatcher_Cleanup checks that Cleanup removes only idle routes
// and keeps ones with an active stream (maxAge 0 makes every idle route
// immediately expired).
func TestRouteDispatcher_Cleanup(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	d.MarkActive("route-1")
	d.MarkDone("route-1")
	d.MarkActive("route-2")
	d.mu.RLock()
	initialCount := len(d.routes)
	d.mu.RUnlock()
	if initialCount != 2 {
		t.Fatalf("expected 2 routes, got %d", initialCount)
	}
	d.Cleanup(0)
	d.mu.RLock()
	afterCount := len(d.routes)
	d.mu.RUnlock()
	// route-1 is idle → cleaned up; route-2 is active → kept
	if afterCount != 1 {
		t.Fatalf("expected 1 route after cleanup, got %d", afterCount)
	}
	if d.IsActive("route-2") != true {
		t.Fatal("expected route-2 to still be active")
	}
}
// TestRouteDispatcher_InjectChannelFull verifies the back-pressure behavior:
// once the buffered inject channel is at capacity, Inject reports false
// instead of blocking.
func TestRouteDispatcher_InjectChannelFull(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	d.MarkActive(routeID)
	// Fill the inject channel to capacity
	for i := 0; i < injectChBuffer; i++ {
		if !d.Inject(routeID, InjectMessage{Text: "fill"}) {
			t.Fatalf("expected inject %d to succeed", i)
		}
	}
	// Next inject should fail (channel full)
	if d.Inject(routeID, InjectMessage{Text: "overflow"}) {
		t.Fatal("expected inject to fail when channel is full")
	}
}

// TestRouteDispatcher_QueueWhenInactive documents that Enqueue works even
// before MarkActive; the task is still delivered by the next MarkDone.
func TestRouteDispatcher_QueueWhenInactive(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	// Enqueue without marking active — still stores in queue
	d.Enqueue(routeID, QueuedTask{Text: "task-1"})
	// MarkActive then MarkDone should return the queued task
	d.MarkActive(routeID)
	result := d.MarkDone(routeID)
	if len(result.QueuedTasks) != 1 {
		t.Fatalf("expected 1 queued task, got %d", len(result.QueuedTasks))
	}
}
// TestRouteDispatcher_MultipleMarkActive verifies that calling MarkActive
// twice for the same route is safe and hands back the same inject channel
// both times — the channel belongs to the route state, not to a single
// activation. The original test only checked non-nil, which would pass even
// if each activation created a fresh channel and stranded earlier injections.
func TestRouteDispatcher_MultipleMarkActive(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	ch1 := d.MarkActive(routeID)
	ch2 := d.MarkActive(routeID)
	if ch1 == nil || ch2 == nil {
		t.Fatal("expected non-nil channels")
	}
	if ch1 != ch2 {
		t.Fatal("expected MarkActive to return the same channel for the same route")
	}
	_ = time.Now() // keeps the time import referenced in this file
}
// TestRouteDispatcher_PendingPersistOrder verifies that pending persist
// closures are returned by MarkDone and execute in insertion order, which is
// what preserves created_at ordering of deferred writes.
func TestRouteDispatcher_PendingPersistOrder(t *testing.T) {
	d := NewRouteDispatcher(slog.Default())
	routeID := "route-1"
	d.MarkActive(routeID)
	var order []string
	d.AddPendingPersist(routeID, func(_ context.Context) {
		order = append(order, "B")
	})
	d.AddPendingPersist(routeID, func(_ context.Context) {
		order = append(order, "C")
	})
	result := d.MarkDone(routeID)
	if len(result.PendingPersists) != 2 {
		t.Fatalf("expected 2 pending persists, got %d", len(result.PendingPersists))
	}
	// Execute persists — they should run in insertion order (B then C)
	for _, fn := range result.PendingPersists {
		fn(context.Background())
	}
	if len(order) != 2 || order[0] != "B" || order[1] != "C" {
		t.Errorf("expected [B C], got %v", order)
	}
}
+40 -5
View File
@@ -12,6 +12,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -136,10 +137,11 @@ type usageInfo struct {
}
type resolvedContext struct {
runConfig agentpkg.RunConfig
model models.GetResponse
provider sqlc.LlmProvider
query string // headerified query
runConfig agentpkg.RunConfig
model models.GetResponse
provider sqlc.LlmProvider
query string // headerified query
injectedRecords *[]conversation.InjectedMessageRecord
}
func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) {
@@ -292,7 +294,40 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
LoopDetection: agentpkg.LoopDetectionConfig{Enabled: loopDetectionEnabled},
}
return resolvedContext{runConfig: runCfg, model: chatModel, provider: provider, query: headerifiedQuery}, nil
var injectedRecords *[]conversation.InjectedMessageRecord
if req.InjectCh != nil {
agentInjectCh := make(chan agentpkg.InjectMessage, cap(req.InjectCh))
go func() {
for msg := range req.InjectCh {
agentInjectCh <- agentpkg.InjectMessage{
Text: msg.Text,
HeaderifiedText: msg.HeaderifiedText,
}
}
close(agentInjectCh)
}()
runCfg.InjectCh = agentInjectCh
records := make([]conversation.InjectedMessageRecord, 0)
injectedRecords = &records
var recMu sync.Mutex
runCfg.InjectedRecorder = func(headerifiedText string, insertAfter int) {
recMu.Lock()
*injectedRecords = append(*injectedRecords, conversation.InjectedMessageRecord{
HeaderifiedText: headerifiedText,
InsertAfter: insertAfter,
})
recMu.Unlock()
}
}
return resolvedContext{
runConfig: runCfg,
model: chatModel,
provider: provider,
query: headerifiedQuery,
injectedRecords: injectedRecords,
}, nil
}
// Chat sends a synchronous chat request and stores the result.
@@ -162,6 +162,10 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
outputMessages := sdkMessagesToModelMessages(sdkMsgs)
roundMessages := prependUserMessage(req.Query, outputMessages)
if rc.injectedRecords != nil && len(*rc.injectedRecords) > 0 {
roundMessages = interleaveInjectedMessages(roundMessages, *rc.injectedRecords)
}
if err := r.storeRound(ctx, req, roundMessages, modelID); err != nil {
return false, err
}
@@ -173,6 +177,37 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
return true, nil
}
// interleaveInjectedMessages inserts injected user messages at their correct
// positions within the round. Each record's InsertAfter value indicates how
// many output messages preceded the injection.
//
// round layout: [user_A, output_0, output_1, ..., output_N]
// InsertAfter=K → insert after round[K] (i.e. after the K-th output message).
//
// NOTE(review): this assumes injections is ordered by non-decreasing
// InsertAfter; records are appended in injection order by the recorder,
// which should satisfy that — confirm against the recorder in resolve.
func interleaveInjectedMessages(round []conversation.ModelMessage, injections []conversation.InjectedMessageRecord) []conversation.ModelMessage {
	if len(injections) == 0 {
		return round
	}
	result := make([]conversation.ModelMessage, 0, len(round)+len(injections))
	injIdx := 0
	for i, msg := range round {
		result = append(result, msg)
		// Emit every injection recorded immediately after position i.
		for injIdx < len(injections) && injections[injIdx].InsertAfter == i {
			result = append(result, conversation.ModelMessage{
				Role:    "user",
				Content: conversation.NewTextContent(injections[injIdx].HeaderifiedText),
			})
			injIdx++
		}
	}
	// Append any leftovers whose InsertAfter points past the final message
	// (or whose position was never matched exactly) so nothing is lost.
	for ; injIdx < len(injections); injIdx++ {
		result = append(result, conversation.ModelMessage{
			Role:    "user",
			Content: conversation.NewTextContent(injections[injIdx].HeaderifiedText),
		})
	}
	return result
}
func extractInputTokensFromUsage(raw json.RawMessage) int {
if len(raw) == 0 {
return 0
+22
View File
@@ -241,6 +241,10 @@ type ChatRequest struct {
// Set by the inbound channel processor; called by the resolver at persist time.
OutboundAssetCollector func() []OutboundAssetRef `json:"-"`
// InjectCh receives user messages to inject into the active agent stream
// between tool rounds via the PrepareStep hook. Nil means no injection.
InjectCh <-chan InjectMessage `json:"-"`
Query string `json:"query"`
Model string `json:"model,omitempty"`
Provider string `json:"provider,omitempty"`
@@ -251,6 +255,24 @@ type ChatRequest struct {
Attachments []ChatAttachment `json:"attachments,omitempty"`
}
// InjectMessage carries a user message to be injected into a running agent
// stream between tool rounds.
type InjectMessage struct {
	// Text is the raw message text.
	Text string
	// Attachments resolved for this message.
	// NOTE(review): the resolver copies only Text and HeaderifiedText into
	// the agent-level InjectMessage, so these are not forwarded to the SDK
	// injection path — confirm intended.
	Attachments []ChatAttachment
	// HeaderifiedText is the header-formatted text used for SDK injection.
	HeaderifiedText string
}

// InjectedMessageRecord records a message that was injected via PrepareStep,
// together with its position in the output message sequence.
type InjectedMessageRecord struct {
	// HeaderifiedText is the exact text that was injected.
	HeaderifiedText string
	// InsertAfter is the number of SDK output messages that existed before
	// this injection. Used to determine the correct insertion position when
	// interleaving injected messages into the persisted round.
	InsertAfter int
}
// ChatResponse is the output of a non-streaming chat call.
type ChatResponse struct {
Messages []ModelMessage `json:"messages"`