diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 9b0fb9ae..f80bda6d 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -553,6 +553,7 @@ func provideChannelRouter( processor.SetACLService(aclService) processor.SetMediaService(mediaService) processor.SetStreamObserver(local.NewRouteHubBroadcaster(hub)) + processor.SetDispatcher(inbound.NewRouteDispatcher(log)) processor.SetTtsService(ttsService, &settingsTtsModelResolver{settings: settingsService}) processor.SetCommandHandler(command.NewHandler( log, diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index a61c3dd0..ce45f56c 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -457,6 +457,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc processor.SetACLService(aclService) processor.SetMediaService(mediaService) processor.SetStreamObserver(local.NewRouteHubBroadcaster(hub)) + processor.SetDispatcher(inbound.NewRouteDispatcher(log)) processor.SetTtsService(ttsService, &settingsTtsModelResolver{settings: settingsService}) processor.SetCommandHandler(command.NewHandler( log, diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f11af9a8..c31a7cf3 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -98,6 +98,47 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv if readMediaState != nil { prepareStep = readMediaState.prepareStep } + + initialMsgCount := len(cfg.Messages) + + if cfg.InjectCh != nil { + basePrepare := prepareStep + prepareStep = func(p *sdk.GenerateParams) *sdk.GenerateParams { + if basePrepare != nil { + if override := basePrepare(p); override != nil { + p = override + } + } + for { + select { + case injected, ok := <-cfg.InjectCh: + if !ok { + break + } + text := strings.TrimSpace(injected.HeaderifiedText) + if text == "" { + text = strings.TrimSpace(injected.Text) + } + if text != "" { + insertAfter := len(p.Messages) - initialMsgCount + p.Messages = append(p.Messages, sdk.UserMessage(text)) + if cfg.InjectedRecorder != nil { + cfg.InjectedRecorder(text, insertAfter) + } + a.logger.Info("injected user message into agent stream", + slog.String("bot_id", cfg.Identity.BotID), + slog.Int("insert_after", insertAfter), + ) + } + continue + default: + } + break + } + return p + } + } + opts := a.buildGenerateOptions(cfg, tools, prepareStep) streamResult, err := a.client.StreamText(ctx, opts...) diff --git a/internal/agent/types.go b/internal/agent/types.go index 3b7956c5..9ba5c52d 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -46,6 +46,13 @@ type LoopDetectionConfig struct { Enabled bool } +// InjectMessage carries a user message to be injected into a running agent +// stream between tool rounds via the PrepareStep hook. +type InjectMessage struct { + Text string + HeaderifiedText string +} + // RunConfig holds everything needed for a single agent invocation. type RunConfig struct { Model *sdk.Model @@ -60,6 +67,17 @@ type RunConfig struct { Identity SessionContext Skills []SkillEntry LoopDetection LoopDetectionConfig + + // InjectCh receives user messages to inject between tool rounds. + // When non-nil, a PrepareStep hook drains this channel and appends + // user messages to the conversation before the next LLM call. + InjectCh <-chan InjectMessage + + // InjectedRecorder is called each time a message is injected via + // PrepareStep, recording the headerified text and the number of SDK + // output messages that preceded the injection. Used by the resolver + // to interleave injected messages at the correct position in storeRound. + InjectedRecorder func(headerifiedText string, insertAfter int) } // GenerateResult holds the result of a non-streaming agent invocation. diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index 04cb2525..68f9aa44 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -98,6 +98,7 @@ type ChannelInboundProcessor struct { tokenTTL time.Duration identity *IdentityResolver policy PolicyService + dispatcher *RouteDispatcher acl chatACL observer channel.StreamObserver ttsService ttsSynthesizer @@ -206,6 +207,14 @@ func (p *ChannelInboundProcessor) SetCommandHandler(handler *command.Handler) { p.commandHandler = handler } +// SetDispatcher configures the per-route message dispatcher for inject/queue/parallel modes. +func (p *ChannelInboundProcessor) SetDispatcher(dispatcher *RouteDispatcher) { + if p == nil { + return + } + p.dispatcher = dispatcher +} + // HandleInbound processes an inbound channel message through identity resolution and chat gateway. func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error { if p.runner == nil { @@ -270,7 +279,9 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel return p.handleNewSessionCommand(ctx, cfg, msg, sender, identity) } - if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && isDirectedAtBot(msg) { + // Skip generic command handler for mode-prefix commands (/btw, /now, /next) + // so they pass through to mode detection below. + if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) { reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText) if err != nil { reply = "Error: " + err.Error() @@ -284,6 +295,14 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel resolvedAttachments := p.ingestInboundAttachments(ctx, cfg, msg, strings.TrimSpace(identity.BotID), msg.Message.Attachments) attachments := mapChannelToChatAttachments(resolvedAttachments) text = buildInboundQuery(msg.Message, attachments) + + // Detect inbound mode from message prefix (/btw, /now, /next). + // Only applies to non-local channels; WebUI always uses the default flow. + // Must run after buildInboundQuery so the prefix is stripped from the final text. + inboundMode := ModeInject + if !isLocalChannelType(msg.Channel) { + inboundMode, text = DetectMode(text) + } threadID := extractThreadID(msg) // Resolve or create the route via channel_routes. @@ -374,6 +393,66 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel return nil } + routeID := strings.TrimSpace(resolved.RouteID) + + // --- Dispatcher-based mode handling (inject / queue) --- + // For non-parallel modes, when a route already has an active agent stream, + // short-circuit here instead of starting a new stream. + if p.dispatcher != nil && !isLocalChannelType(msg.Channel) && inboundMode != ModeParallel { + if p.dispatcher.IsActive(routeID) { + headerifiedText := flow.FormatUserHeader( + strings.TrimSpace(msg.Message.ID), + strings.TrimSpace(identity.ChannelIdentityID), + strings.TrimSpace(identity.DisplayName), + msg.Channel.String(), + strings.TrimSpace(msg.Conversation.Type), + strings.TrimSpace(msg.Conversation.Name), + collectAttachmentPaths(attachments), + time.Now().UTC(), + "", + text, + ) + + switch inboundMode { + case ModeInject: + // Don't persist here — the injected message will be interleaved + // at the correct position within the round by + // interleaveInjectedMessages in storeRound. + injected := p.dispatcher.Inject(routeID, InjectMessage{ + Text: text, + Attachments: attachments, + HeaderifiedText: headerifiedText, + }) + if injected { + p.sendModeConfirmation(ctx, sender, msg, identity, "inject") + } else { + if p.logger != nil { + p.logger.Warn("inject failed (channel full), falling through to new stream", + slog.String("route_id", routeID)) + } + goto startStream + } + return nil + + case ModeQueue: + p.persistPassiveMessage(ctx, identity, msg, text, attachments, routeID, sessionID) + p.dispatcher.Enqueue(routeID, QueuedTask{ + Ctx: ctx, + Cfg: cfg, + Msg: msg, + Sender: sender, + Ident: identity, + Text: text, + Attachs: attachments, + }) + p.sendModeConfirmation(ctx, sender, msg, identity, "queue") + return nil + } + } + } + +startStream: + // Issue chat token for reply routing. chatToken := "" if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTarget) != "" { @@ -516,6 +595,33 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel return result } + // Mark this route as active in the dispatcher so subsequent messages + // can be injected or queued. Produces the inject channel for this stream. + // Parallel mode (/now) skips the dispatcher entirely — it must not + // interfere with the active flag or drain the queue of another stream. + var injectCh <-chan InjectMessage + if p.dispatcher != nil && !isLocalChannelType(msg.Channel) && inboundMode != ModeParallel { + injectCh = p.dispatcher.MarkActive(routeID) + defer func() { + p.drainQueue(context.WithoutCancel(ctx), routeID) + }() + } + // Convert inbound InjectMessage channel to conversation.InjectMessage channel. + var convInjectCh chan conversation.InjectMessage + if injectCh != nil { + convInjectCh = make(chan conversation.InjectMessage, injectChBuffer) + go func() { + for im := range injectCh { + convInjectCh <- conversation.InjectMessage{ + Text: im.Text, + Attachments: im.Attachments, + HeaderifiedText: im.HeaderifiedText, + } + } + close(convInjectCh) + }() + } + chatReq := conversation.ChatRequest{ BotID: identity.BotID, ChatID: activeChatID, @@ -537,6 +643,9 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel Attachments: attachments, OutboundAssetCollector: assetCollector, } + if convInjectCh != nil { + chatReq.InjectCh = convInjectCh + } if mid, _ := msg.Metadata["model_id"].(string); strings.TrimSpace(mid) != "" { chatReq.Model = strings.TrimSpace(mid) } @@ -687,6 +796,75 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel return nil } +// sendModeConfirmation sends a lightweight acknowledgement to the user when +// their message is injected or queued rather than triggering a new stream. +func (p *ChannelInboundProcessor) sendModeConfirmation( + ctx context.Context, + _ channel.StreamReplySender, + msg channel.InboundMessage, + identity InboundIdentity, + mode string, +) { + target := strings.TrimSpace(msg.ReplyTarget) + sourceMessageID := strings.TrimSpace(msg.Message.ID) + if target == "" || sourceMessageID == "" { + return + } + if p.reactor != nil { + emoji := "👀" + if mode == "queue" { + emoji = "📋" + } + _ = p.reactor.React(ctx, strings.TrimSpace(identity.BotID), msg.Channel, channel.ReactRequest{ + Target: target, + MessageID: sourceMessageID, + Emoji: emoji, + }) + } +} + +// drainQueue marks the route as done and processes any queued tasks. +func (p *ChannelInboundProcessor) drainQueue(ctx context.Context, routeID string) { + if p.dispatcher == nil { + return + } + result := p.dispatcher.MarkDone(routeID) + + for _, fn := range result.PendingPersists { + fn(ctx) + } + + for _, task := range result.QueuedTasks { + if p.logger != nil { + p.logger.Info("processing queued task", + slog.String("route_id", routeID), + slog.String("query", strings.TrimSpace(task.Text)), + ) + } + if err := p.HandleInbound(ctx, task.Cfg, task.Msg, task.Sender); err != nil { //nolint:contextcheck // ctx is already WithoutCancel from the defer caller + if p.logger != nil { + p.logger.Error("queued task processing failed", + slog.String("route_id", routeID), + slog.Any("error", err), + ) + } + } + } +} + +func collectAttachmentPaths(attachments []conversation.ChatAttachment) []string { + if len(attachments) == 0 { + return nil + } + paths := make([]string, 0, len(attachments)) + for _, att := range attachments { + if p := strings.TrimSpace(att.Path); p != "" { + paths = append(paths, p) + } + } + return paths +} + func shouldTriggerAssistantResponse(msg channel.InboundMessage) bool { if isDirectConversationType(msg.Conversation.Type) { return true diff --git a/internal/channel/inbound/dispatcher.go b/internal/channel/inbound/dispatcher.go new file mode 100644 index 00000000..ba46e316 --- /dev/null +++ b/internal/channel/inbound/dispatcher.go @@ -0,0 +1,308 @@ +package inbound + +import ( + "context" + "log/slog" + "strings" + "sync" + "time" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/conversation" +) + +// InboundMode determines how a new inbound message is handled when an agent +// stream is already active for the same route. +type InboundMode int + +const ( + // ModeInject (default, command /btw) injects the message into the active + // agent stream via the PrepareStep hook so the LLM sees it between tool + // rounds. When no stream is active, starts one normally. + ModeInject InboundMode = iota + // ModeParallel (command /now) starts a new agent stream immediately, + // running concurrently with any existing stream. + ModeParallel + // ModeQueue (command /next) queues the message and processes it after the + // current agent stream completes. + ModeQueue +) + +// InjectMessage carries a user message to be injected into a running agent stream. +type InjectMessage struct { + Text string + Attachments []conversation.ChatAttachment + // HeaderifiedText is the formatted user header text ready for SDK injection. + HeaderifiedText string +} + +// QueuedTask holds everything needed to start an agent stream for a queued message. +type QueuedTask struct { + Ctx context.Context + Cfg channel.ChannelConfig + Msg channel.InboundMessage + Sender channel.StreamReplySender + Ident InboundIdentity + Text string + Attachs []conversation.ChatAttachment +} + +// PersistFunc is a deferred persistence closure called after the active stream +// completes (and its storeRound has run), ensuring correct created_at ordering. +type PersistFunc func(ctx context.Context) + +// routeState tracks in-flight agent activity for a single route. +type routeState struct { + mu sync.Mutex + active bool + injectCh chan InjectMessage + queue []QueuedTask + pendingPersists []PersistFunc + lastUsed time.Time +} + +// RouteDispatcher manages per-route concurrency for inbound message processing. +// It decides whether a new message should be injected into an active stream, +// run in parallel, or be queued. +type RouteDispatcher struct { + mu sync.RWMutex + routes map[string]*routeState + logger *slog.Logger +} + +// NewRouteDispatcher creates a dispatcher with background cleanup. +func NewRouteDispatcher(logger *slog.Logger) *RouteDispatcher { + if logger == nil { + logger = slog.Default() + } + return &RouteDispatcher{ + routes: make(map[string]*routeState), + logger: logger.With(slog.String("component", "route_dispatcher")), + } +} + +const injectChBuffer = 16 + +func (d *RouteDispatcher) getOrCreate(routeID string) *routeState { + d.mu.RLock() + rs, ok := d.routes[routeID] + d.mu.RUnlock() + if ok { + return rs + } + d.mu.Lock() + defer d.mu.Unlock() + if rs, ok = d.routes[routeID]; ok { + return rs + } + rs = &routeState{ + injectCh: make(chan InjectMessage, injectChBuffer), + lastUsed: time.Now(), + } + d.routes[routeID] = rs + return rs +} + +// IsActive reports whether the given route has an active agent stream. +func (d *RouteDispatcher) IsActive(routeID string) bool { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return false + } + rs := d.getOrCreate(routeID) + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.active +} + +// MarkActive marks a route as having an active stream and returns the inject +// channel that the agent should drain via PrepareStep. +func (d *RouteDispatcher) MarkActive(routeID string) <-chan InjectMessage { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return nil + } + rs := d.getOrCreate(routeID) + rs.mu.Lock() + defer rs.mu.Unlock() + rs.active = true + rs.lastUsed = time.Now() + return rs.injectCh +} + +// MarkDoneResult holds the data returned when a route transitions from active to idle. +type MarkDoneResult struct { + PendingPersists []PersistFunc + QueuedTasks []QueuedTask +} + +// MarkDone marks a route as idle and returns pending persist functions (to be +// called now that storeRound has completed) and any queued tasks. +func (d *RouteDispatcher) MarkDone(routeID string) MarkDoneResult { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return MarkDoneResult{} + } + rs := d.getOrCreate(routeID) + rs.mu.Lock() + defer rs.mu.Unlock() + rs.active = false + rs.lastUsed = time.Now() + + drainInjectCh(rs.injectCh) + + var persists []PersistFunc + if len(rs.pendingPersists) > 0 { + persists = rs.pendingPersists + rs.pendingPersists = nil + } + + var tasks []QueuedTask + if len(rs.queue) > 0 { + tasks = rs.queue + rs.queue = nil + } + + return MarkDoneResult{PendingPersists: persists, QueuedTasks: tasks} +} + +// AddPendingPersist records a deferred persist closure to be executed after the +// active stream completes. This ensures injected messages get a created_at +// timestamp after the triggering message's round. +func (d *RouteDispatcher) AddPendingPersist(routeID string, fn PersistFunc) { + routeID = strings.TrimSpace(routeID) + if routeID == "" || fn == nil { + return + } + rs := d.getOrCreate(routeID) + rs.mu.Lock() + defer rs.mu.Unlock() + rs.pendingPersists = append(rs.pendingPersists, fn) +} + +// Inject sends a message to the inject channel of an active route. +// Returns true if the message was accepted (route is active and channel not full). +func (d *RouteDispatcher) Inject(routeID string, msg InjectMessage) bool { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return false + } + rs := d.getOrCreate(routeID) + rs.mu.Lock() + defer rs.mu.Unlock() + if !rs.active { + return false + } + select { + case rs.injectCh <- msg: + if d.logger != nil { + d.logger.Info("message injected into active stream", + slog.String("route_id", routeID), + ) + } + return true + default: + if d.logger != nil { + d.logger.Warn("inject channel full, message dropped", + slog.String("route_id", routeID), + ) + } + return false + } +} + +// Enqueue adds a task to the route's queue for later processing. +func (d *RouteDispatcher) Enqueue(routeID string, task QueuedTask) { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return + } + rs := d.getOrCreate(routeID) + rs.mu.Lock() + defer rs.mu.Unlock() + rs.queue = append(rs.queue, task) + rs.lastUsed = time.Now() + if d.logger != nil { + d.logger.Info("message queued", + slog.String("route_id", routeID), + slog.Int("queue_size", len(rs.queue)), + ) + } +} + +// Cleanup removes idle route states older than maxAge. +func (d *RouteDispatcher) Cleanup(maxAge time.Duration) { + d.mu.Lock() + defer d.mu.Unlock() + cutoff := time.Now().Add(-maxAge) + for id, rs := range d.routes { + rs.mu.Lock() + idle := !rs.active && rs.lastUsed.Before(cutoff) && len(rs.queue) == 0 + rs.mu.Unlock() + if idle { + delete(d.routes, id) + } + } +} + +func drainInjectCh(ch chan InjectMessage) { + for { + select { + case <-ch: + default: + return + } + } +} + +// DetectMode parses a message prefix to determine the inbound mode. +// Returns the mode and the text with the prefix stripped. +func DetectMode(text string) (InboundMode, string) { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return ModeInject, trimmed + } + + type modePrefix struct { + prefix string + mode InboundMode + } + prefixes := []modePrefix{ + {"/now ", ModeParallel}, + {"/next ", ModeQueue}, + {"/btw ", ModeInject}, + } + lower := strings.ToLower(trimmed) + for _, mp := range prefixes { + if strings.HasPrefix(lower, mp.prefix) { + return mp.mode, strings.TrimSpace(trimmed[len(mp.prefix):]) + } + } + // Exact match without trailing text (bare command) + barePrefixes := []modePrefix{ + {"/now", ModeParallel}, + {"/next", ModeQueue}, + {"/btw", ModeInject}, + } + for _, mp := range barePrefixes { + if lower == mp.prefix { + return mp.mode, "" + } + } + return ModeInject, trimmed +} + +// IsModeCommand reports whether the text is a mode-prefix command +// (/btw, /now, /next), so the generic command handler should skip it. +func IsModeCommand(text string) bool { + trimmed := strings.ToLower(strings.TrimSpace(text)) + if trimmed == "" { + return false + } + for _, prefix := range []string{"/now", "/next", "/btw"} { + if trimmed == prefix || strings.HasPrefix(trimmed, prefix+" ") || strings.HasPrefix(trimmed, prefix+"\t") { + return true + } + } + return false +} diff --git a/internal/channel/inbound/dispatcher_test.go b/internal/channel/inbound/dispatcher_test.go new file mode 100644 index 00000000..ba642d25 --- /dev/null +++ b/internal/channel/inbound/dispatcher_test.go @@ -0,0 +1,302 @@ +package inbound + +import ( + "context" + "log/slog" + "sync" + "testing" + "time" +) + +func TestDetectMode(t *testing.T) { + tests := []struct { + input string + wantMode InboundMode + wantText string + }{ + {"hello world", ModeInject, "hello world"}, + {"/btw hello", ModeInject, "hello"}, + {"/now hello", ModeParallel, "hello"}, + {"/next hello", ModeQueue, "hello"}, + {"/BTW hello", ModeInject, "hello"}, + {"/NOW hello", ModeParallel, "hello"}, + {"/NEXT hello", ModeQueue, "hello"}, + {"/now", ModeParallel, ""}, + {"/next", ModeQueue, ""}, + {"/btw", ModeInject, ""}, + {" /now hello ", ModeParallel, "hello"}, + {"/unknown hello", ModeInject, "/unknown hello"}, + {"", ModeInject, ""}, + {"/new session", ModeInject, "/new session"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + mode, text := DetectMode(tt.input) + if mode != tt.wantMode { + t.Errorf("DetectMode(%q) mode = %d, want %d", tt.input, mode, tt.wantMode) + } + if text != tt.wantText { + t.Errorf("DetectMode(%q) text = %q, want %q", tt.input, text, tt.wantText) + } + }) + } +} + +func TestIsModeCommand(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"/btw hello", true}, + {"/now hello", true}, + {"/next hello", true}, + {"/btw", true}, + {"/now", true}, + {"/next", true}, + {"/new", false}, + {"/fs list", false}, + {"hello", false}, + {"", false}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := IsModeCommand(tt.input) + if got != tt.want { + t.Errorf("IsModeCommand(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestRouteDispatcher_InjectWhenActive(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + if d.IsActive(routeID) { + t.Fatal("expected route to be inactive initially") + } + + injectCh := d.MarkActive(routeID) + if injectCh == nil { + t.Fatal("expected non-nil inject channel") + } + if !d.IsActive(routeID) { + t.Fatal("expected route to be active after MarkActive") + } + + msg := InjectMessage{Text: "hello", HeaderifiedText: "[User] hello"} + if !d.Inject(routeID, msg) { + t.Fatal("expected inject to succeed when route is active") + } + + select { + case got := <-injectCh: + if got.Text != "hello" { + t.Errorf("got text %q, want %q", got.Text, "hello") + } + default: + t.Fatal("expected message on inject channel") + } +} + +func TestRouteDispatcher_InjectWhenInactive(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + msg := InjectMessage{Text: "hello"} + if d.Inject(routeID, msg) { + t.Fatal("expected inject to fail when route is inactive") + } +} + +func TestRouteDispatcher_QueueAndDrain(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + d.MarkActive(routeID) + + d.Enqueue(routeID, QueuedTask{Text: "task-1"}) + d.Enqueue(routeID, QueuedTask{Text: "task-2"}) + + result := d.MarkDone(routeID) + if len(result.QueuedTasks) != 2 { + t.Fatalf("expected 2 queued tasks, got %d", len(result.QueuedTasks)) + } + if result.QueuedTasks[0].Text != "task-1" || result.QueuedTasks[1].Text != "task-2" { + t.Errorf("unexpected task order: %v", result.QueuedTasks) + } + if d.IsActive(routeID) { + t.Fatal("expected route to be inactive after MarkDone") + } +} + +func TestRouteDispatcher_MarkDoneReturnsNilWhenEmpty(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + d.MarkActive(routeID) + result := d.MarkDone(routeID) + if result.QueuedTasks != nil { + t.Fatalf("expected nil queued tasks, got %v", result.QueuedTasks) + } + if result.PendingPersists != nil { + t.Fatalf("expected nil pending persists, got %v", result.PendingPersists) + } +} + +func TestRouteDispatcher_ConcurrentInject(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + injectCh := d.MarkActive(routeID) + + const numMessages = 10 + var wg sync.WaitGroup + wg.Add(numMessages) + for i := 0; i < numMessages; i++ { + go func() { + defer wg.Done() + d.Inject(routeID, InjectMessage{Text: "msg"}) + }() + } + wg.Wait() + + count := 0 + for { + select { + case <-injectCh: + count++ + default: + goto done + } + } +done: + if count != numMessages { + t.Errorf("expected %d messages, got %d", numMessages, count) + } +} + +func TestRouteDispatcher_ParallelBypass(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + d.MarkActive(routeID) + + // In parallel mode, the caller does not interact with the dispatcher + // at all — it just starts a new stream directly. Verify the route + // stays active without interference. + if !d.IsActive(routeID) { + t.Fatal("expected route to still be active") + } + + d.MarkDone(routeID) + if d.IsActive(routeID) { + t.Fatal("expected route to be inactive after MarkDone") + } +} + +func TestRouteDispatcher_Cleanup(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + + d.MarkActive("route-1") + d.MarkDone("route-1") + + d.MarkActive("route-2") + + d.mu.RLock() + initialCount := len(d.routes) + d.mu.RUnlock() + if initialCount != 2 { + t.Fatalf("expected 2 routes, got %d", initialCount) + } + + d.Cleanup(0) + + d.mu.RLock() + afterCount := len(d.routes) + d.mu.RUnlock() + + // route-1 is idle → cleaned up; route-2 is active → kept + if afterCount != 1 { + t.Fatalf("expected 1 route after cleanup, got %d", afterCount) + } + if d.IsActive("route-2") != true { + t.Fatal("expected route-2 to still be active") + } +} + +func TestRouteDispatcher_InjectChannelFull(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + d.MarkActive(routeID) + + // Fill the inject channel to capacity + for i := 0; i < injectChBuffer; i++ { + if !d.Inject(routeID, InjectMessage{Text: "fill"}) { + t.Fatalf("expected inject %d to succeed", i) + } + } + + // Next inject should fail (channel full) + if d.Inject(routeID, InjectMessage{Text: "overflow"}) { + t.Fatal("expected inject to fail when channel is full") + } +} + +func TestRouteDispatcher_QueueWhenInactive(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + // Enqueue without marking active — still stores in queue + d.Enqueue(routeID, QueuedTask{Text: "task-1"}) + + // MarkActive then MarkDone should return the queued task + d.MarkActive(routeID) + result := d.MarkDone(routeID) + if len(result.QueuedTasks) != 1 { + t.Fatalf("expected 1 queued task, got %d", len(result.QueuedTasks)) + } +} + +func TestRouteDispatcher_MultipleMarkActive(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + ch1 := d.MarkActive(routeID) + ch2 := d.MarkActive(routeID) + + if ch1 == nil || ch2 == nil { + t.Fatal("expected non-nil channels") + } + + _ = time.Now() +} + +func TestRouteDispatcher_PendingPersistOrder(t *testing.T) { + d := NewRouteDispatcher(slog.Default()) + routeID := "route-1" + + d.MarkActive(routeID) + + var order []string + d.AddPendingPersist(routeID, func(_ context.Context) { + order = append(order, "B") + }) + d.AddPendingPersist(routeID, func(_ context.Context) { + order = append(order, "C") + }) + + result := d.MarkDone(routeID) + if len(result.PendingPersists) != 2 { + t.Fatalf("expected 2 pending persists, got %d", len(result.PendingPersists)) + } + + // Execute persists — they should run in insertion order (B then C) + for _, fn := range result.PendingPersists { + fn(context.Background()) + } + if len(order) != 2 || order[0] != "B" || order[1] != "C" { + t.Errorf("expected [B C], got %v", order) + } +} diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index edbe2afa..3d5bdb2c 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -12,6 +12,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" sdk "github.com/memohai/twilight-ai/sdk" @@ -136,10 +137,11 @@ type usageInfo struct { } type resolvedContext struct { - runConfig agentpkg.RunConfig - model models.GetResponse - provider sqlc.LlmProvider - query string // headerified query + runConfig agentpkg.RunConfig + model models.GetResponse + provider sqlc.LlmProvider + query string // headerified query + injectedRecords *[]conversation.InjectedMessageRecord } func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) { @@ -292,7 +294,40 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r LoopDetection: agentpkg.LoopDetectionConfig{Enabled: loopDetectionEnabled}, } - return resolvedContext{runConfig: runCfg, model: chatModel, provider: provider, query: headerifiedQuery}, nil + var injectedRecords *[]conversation.InjectedMessageRecord + if req.InjectCh != nil { + agentInjectCh := make(chan agentpkg.InjectMessage, cap(req.InjectCh)) + go func() { + for msg := range req.InjectCh { + agentInjectCh <- agentpkg.InjectMessage{ + Text: msg.Text, + HeaderifiedText: msg.HeaderifiedText, + } + } + close(agentInjectCh) + }() + runCfg.InjectCh = agentInjectCh + + records := make([]conversation.InjectedMessageRecord, 0) + injectedRecords = &records + var recMu sync.Mutex + runCfg.InjectedRecorder = func(headerifiedText string, insertAfter int) { + recMu.Lock() + *injectedRecords = append(*injectedRecords, conversation.InjectedMessageRecord{ + HeaderifiedText: headerifiedText, + InsertAfter: insertAfter, + }) + recMu.Unlock() + } + } + + return resolvedContext{ + runConfig: runCfg, + model: chatModel, + provider: provider, + query: headerifiedQuery, + injectedRecords: injectedRecords, + }, nil } // Chat sends a synchronous chat request and stores the result. diff --git a/internal/conversation/flow/resolver_stream.go b/internal/conversation/flow/resolver_stream.go index 7f49ed2a..db9165b1 100644 --- a/internal/conversation/flow/resolver_stream.go +++ b/internal/conversation/flow/resolver_stream.go @@ -162,6 +162,10 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ outputMessages := sdkMessagesToModelMessages(sdkMsgs) roundMessages := prependUserMessage(req.Query, outputMessages) + if rc.injectedRecords != nil && len(*rc.injectedRecords) > 0 { + roundMessages = interleaveInjectedMessages(roundMessages, *rc.injectedRecords) + } + if err := r.storeRound(ctx, req, roundMessages, modelID); err != nil { return false, err } @@ -173,6 +177,37 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ return true, nil } +// interleaveInjectedMessages inserts injected user messages at their correct +// positions within the round. Each record's InsertAfter value indicates how +// many output messages preceded the injection. +// +// round layout: [user_A, output_0, output_1, ..., output_N] +// InsertAfter=K → insert after round[K] (i.e. after the K-th output message). +func interleaveInjectedMessages(round []conversation.ModelMessage, injections []conversation.InjectedMessageRecord) []conversation.ModelMessage { + if len(injections) == 0 { + return round + } + result := make([]conversation.ModelMessage, 0, len(round)+len(injections)) + injIdx := 0 + for i, msg := range round { + result = append(result, msg) + for injIdx < len(injections) && injections[injIdx].InsertAfter == i { + result = append(result, conversation.ModelMessage{ + Role: "user", + Content: conversation.NewTextContent(injections[injIdx].HeaderifiedText), + }) + injIdx++ + } + } + for ; injIdx < len(injections); injIdx++ { + result = append(result, conversation.ModelMessage{ + Role: "user", + Content: conversation.NewTextContent(injections[injIdx].HeaderifiedText), + }) + } + return result +} + func extractInputTokensFromUsage(raw json.RawMessage) int { if len(raw) == 0 { return 0 diff --git a/internal/conversation/types.go b/internal/conversation/types.go index 6c61725d..03b75ed1 100644 --- a/internal/conversation/types.go +++ b/internal/conversation/types.go @@ -241,6 +241,10 @@ type ChatRequest struct { // Set by the inbound channel processor; called by the resolver at persist time. OutboundAssetCollector func() []OutboundAssetRef `json:"-"` + // InjectCh receives user messages to inject into the active agent stream + // between tool rounds via the PrepareStep hook. Nil means no injection. + InjectCh <-chan InjectMessage `json:"-"` + Query string `json:"query"` Model string `json:"model,omitempty"` Provider string `json:"provider,omitempty"` @@ -251,6 +255,24 @@ type ChatRequest struct { Attachments []ChatAttachment `json:"attachments,omitempty"` } +// InjectMessage carries a user message to be injected into a running agent +// stream between tool rounds. +type InjectMessage struct { + Text string + Attachments []ChatAttachment + HeaderifiedText string +} + +// InjectedMessageRecord records a message that was injected via PrepareStep, +// together with its position in the output message sequence. +type InjectedMessageRecord struct { + HeaderifiedText string + // InsertAfter is the number of SDK output messages that existed before + // this injection. Used to determine the correct insertion position when + // interleaving injected messages into the persisted round. + InsertAfter int +} + // ChatResponse is the output of a non-streaming chat call. type ChatResponse struct { Messages []ModelMessage `json:"messages"`