mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: add per-route message dispatch modes (inject/parallel/queue)
Introduce three inbound message handling modes for channel adapters: - inject (default, /btw): when a route has an active agent stream, inject the new user message into the running stream via the SDK's PrepareStep hook between tool rounds. The message is interleaved at the correct position in the persisted round. - parallel (/now): start a new agent stream immediately, running concurrently with any existing stream (preserves current behavior). - queue (/next): enqueue the message and process it after the current stream completes. Key components: - RouteDispatcher: per-route state management with inject channel, task queue, and active-stream tracking. - PrepareStep integration: drains inject channel between tool rounds, records insertion position via InjectedRecorder for correct persistence ordering. - interleaveInjectedMessages: inserts injected user messages at their actual injection position within the persisted message round. - Parallel mode isolation: /now streams do not interact with the dispatcher, preventing them from clearing another stream's active state.
This commit is contained in:
@@ -98,6 +98,47 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
if readMediaState != nil {
|
||||
prepareStep = readMediaState.prepareStep
|
||||
}
|
||||
|
||||
initialMsgCount := len(cfg.Messages)
|
||||
|
||||
if cfg.InjectCh != nil {
|
||||
basePrepare := prepareStep
|
||||
prepareStep = func(p *sdk.GenerateParams) *sdk.GenerateParams {
|
||||
if basePrepare != nil {
|
||||
if override := basePrepare(p); override != nil {
|
||||
p = override
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case injected, ok := <-cfg.InjectCh:
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
text := strings.TrimSpace(injected.HeaderifiedText)
|
||||
if text == "" {
|
||||
text = strings.TrimSpace(injected.Text)
|
||||
}
|
||||
if text != "" {
|
||||
insertAfter := len(p.Messages) - initialMsgCount
|
||||
p.Messages = append(p.Messages, sdk.UserMessage(text))
|
||||
if cfg.InjectedRecorder != nil {
|
||||
cfg.InjectedRecorder(text, insertAfter)
|
||||
}
|
||||
a.logger.Info("injected user message into agent stream",
|
||||
slog.String("bot_id", cfg.Identity.BotID),
|
||||
slog.Int("insert_after", insertAfter),
|
||||
)
|
||||
}
|
||||
continue
|
||||
default:
|
||||
}
|
||||
break
|
||||
}
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
opts := a.buildGenerateOptions(cfg, tools, prepareStep)
|
||||
|
||||
streamResult, err := a.client.StreamText(ctx, opts...)
|
||||
|
||||
@@ -46,6 +46,13 @@ type LoopDetectionConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// InjectMessage carries a user message to be injected into a running agent
|
||||
// stream between tool rounds via the PrepareStep hook.
|
||||
type InjectMessage struct {
|
||||
Text string
|
||||
HeaderifiedText string
|
||||
}
|
||||
|
||||
// RunConfig holds everything needed for a single agent invocation.
|
||||
type RunConfig struct {
|
||||
Model *sdk.Model
|
||||
@@ -60,6 +67,17 @@ type RunConfig struct {
|
||||
Identity SessionContext
|
||||
Skills []SkillEntry
|
||||
LoopDetection LoopDetectionConfig
|
||||
|
||||
// InjectCh receives user messages to inject between tool rounds.
|
||||
// When non-nil, a PrepareStep hook drains this channel and appends
|
||||
// user messages to the conversation before the next LLM call.
|
||||
InjectCh <-chan InjectMessage
|
||||
|
||||
// InjectedRecorder is called each time a message is injected via
|
||||
// PrepareStep, recording the headerified text and the number of SDK
|
||||
// output messages that preceded the injection. Used by the resolver
|
||||
// to interleave injected messages at the correct position in storeRound.
|
||||
InjectedRecorder func(headerifiedText string, insertAfter int)
|
||||
}
|
||||
|
||||
// GenerateResult holds the result of a non-streaming agent invocation.
|
||||
|
||||
@@ -98,6 +98,7 @@ type ChannelInboundProcessor struct {
|
||||
tokenTTL time.Duration
|
||||
identity *IdentityResolver
|
||||
policy PolicyService
|
||||
dispatcher *RouteDispatcher
|
||||
acl chatACL
|
||||
observer channel.StreamObserver
|
||||
ttsService ttsSynthesizer
|
||||
@@ -206,6 +207,14 @@ func (p *ChannelInboundProcessor) SetCommandHandler(handler *command.Handler) {
|
||||
p.commandHandler = handler
|
||||
}
|
||||
|
||||
// SetDispatcher configures the per-route message dispatcher for inject/queue/parallel modes.
|
||||
func (p *ChannelInboundProcessor) SetDispatcher(dispatcher *RouteDispatcher) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
p.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
// HandleInbound processes an inbound channel message through identity resolution and chat gateway.
|
||||
func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error {
|
||||
if p.runner == nil {
|
||||
@@ -270,7 +279,9 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
return p.handleNewSessionCommand(ctx, cfg, msg, sender, identity)
|
||||
}
|
||||
|
||||
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
// Skip generic command handler for mode-prefix commands (/btw, /now, /next)
|
||||
// so they pass through to mode detection below.
|
||||
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText)
|
||||
if err != nil {
|
||||
reply = "Error: " + err.Error()
|
||||
@@ -284,6 +295,14 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
resolvedAttachments := p.ingestInboundAttachments(ctx, cfg, msg, strings.TrimSpace(identity.BotID), msg.Message.Attachments)
|
||||
attachments := mapChannelToChatAttachments(resolvedAttachments)
|
||||
text = buildInboundQuery(msg.Message, attachments)
|
||||
|
||||
// Detect inbound mode from message prefix (/btw, /now, /next).
|
||||
// Only applies to non-local channels; WebUI always uses the default flow.
|
||||
// Must run after buildInboundQuery so the prefix is stripped from the final text.
|
||||
inboundMode := ModeInject
|
||||
if !isLocalChannelType(msg.Channel) {
|
||||
inboundMode, text = DetectMode(text)
|
||||
}
|
||||
threadID := extractThreadID(msg)
|
||||
|
||||
// Resolve or create the route via channel_routes.
|
||||
@@ -374,6 +393,66 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
return nil
|
||||
}
|
||||
|
||||
routeID := strings.TrimSpace(resolved.RouteID)
|
||||
|
||||
// --- Dispatcher-based mode handling (inject / queue) ---
|
||||
// For non-parallel modes, when a route already has an active agent stream,
|
||||
// short-circuit here instead of starting a new stream.
|
||||
if p.dispatcher != nil && !isLocalChannelType(msg.Channel) && inboundMode != ModeParallel {
|
||||
if p.dispatcher.IsActive(routeID) {
|
||||
headerifiedText := flow.FormatUserHeader(
|
||||
strings.TrimSpace(msg.Message.ID),
|
||||
strings.TrimSpace(identity.ChannelIdentityID),
|
||||
strings.TrimSpace(identity.DisplayName),
|
||||
msg.Channel.String(),
|
||||
strings.TrimSpace(msg.Conversation.Type),
|
||||
strings.TrimSpace(msg.Conversation.Name),
|
||||
collectAttachmentPaths(attachments),
|
||||
time.Now().UTC(),
|
||||
"",
|
||||
text,
|
||||
)
|
||||
|
||||
switch inboundMode {
|
||||
case ModeInject:
|
||||
// Don't persist here — the injected message will be interleaved
|
||||
// at the correct position within the round by
|
||||
// interleaveInjectedMessages in storeRound.
|
||||
injected := p.dispatcher.Inject(routeID, InjectMessage{
|
||||
Text: text,
|
||||
Attachments: attachments,
|
||||
HeaderifiedText: headerifiedText,
|
||||
})
|
||||
if injected {
|
||||
p.sendModeConfirmation(ctx, sender, msg, identity, "inject")
|
||||
} else {
|
||||
if p.logger != nil {
|
||||
p.logger.Warn("inject failed (channel full), falling through to new stream",
|
||||
slog.String("route_id", routeID))
|
||||
}
|
||||
goto startStream
|
||||
}
|
||||
return nil
|
||||
|
||||
case ModeQueue:
|
||||
p.persistPassiveMessage(ctx, identity, msg, text, attachments, routeID, sessionID)
|
||||
p.dispatcher.Enqueue(routeID, QueuedTask{
|
||||
Ctx: ctx,
|
||||
Cfg: cfg,
|
||||
Msg: msg,
|
||||
Sender: sender,
|
||||
Ident: identity,
|
||||
Text: text,
|
||||
Attachs: attachments,
|
||||
})
|
||||
p.sendModeConfirmation(ctx, sender, msg, identity, "queue")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
startStream:
|
||||
|
||||
// Issue chat token for reply routing.
|
||||
chatToken := ""
|
||||
if p.jwtSecret != "" && strings.TrimSpace(msg.ReplyTarget) != "" {
|
||||
@@ -516,6 +595,33 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
return result
|
||||
}
|
||||
|
||||
// Mark this route as active in the dispatcher so subsequent messages
|
||||
// can be injected or queued. Produces the inject channel for this stream.
|
||||
// Parallel mode (/now) skips the dispatcher entirely — it must not
|
||||
// interfere with the active flag or drain the queue of another stream.
|
||||
var injectCh <-chan InjectMessage
|
||||
if p.dispatcher != nil && !isLocalChannelType(msg.Channel) && inboundMode != ModeParallel {
|
||||
injectCh = p.dispatcher.MarkActive(routeID)
|
||||
defer func() {
|
||||
p.drainQueue(context.WithoutCancel(ctx), routeID)
|
||||
}()
|
||||
}
|
||||
// Convert inbound InjectMessage channel to conversation.InjectMessage channel.
|
||||
var convInjectCh chan conversation.InjectMessage
|
||||
if injectCh != nil {
|
||||
convInjectCh = make(chan conversation.InjectMessage, injectChBuffer)
|
||||
go func() {
|
||||
for im := range injectCh {
|
||||
convInjectCh <- conversation.InjectMessage{
|
||||
Text: im.Text,
|
||||
Attachments: im.Attachments,
|
||||
HeaderifiedText: im.HeaderifiedText,
|
||||
}
|
||||
}
|
||||
close(convInjectCh)
|
||||
}()
|
||||
}
|
||||
|
||||
chatReq := conversation.ChatRequest{
|
||||
BotID: identity.BotID,
|
||||
ChatID: activeChatID,
|
||||
@@ -537,6 +643,9 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
Attachments: attachments,
|
||||
OutboundAssetCollector: assetCollector,
|
||||
}
|
||||
if convInjectCh != nil {
|
||||
chatReq.InjectCh = convInjectCh
|
||||
}
|
||||
if mid, _ := msg.Metadata["model_id"].(string); strings.TrimSpace(mid) != "" {
|
||||
chatReq.Model = strings.TrimSpace(mid)
|
||||
}
|
||||
@@ -687,6 +796,75 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendModeConfirmation sends a lightweight acknowledgement to the user when
|
||||
// their message is injected or queued rather than triggering a new stream.
|
||||
func (p *ChannelInboundProcessor) sendModeConfirmation(
|
||||
ctx context.Context,
|
||||
_ channel.StreamReplySender,
|
||||
msg channel.InboundMessage,
|
||||
identity InboundIdentity,
|
||||
mode string,
|
||||
) {
|
||||
target := strings.TrimSpace(msg.ReplyTarget)
|
||||
sourceMessageID := strings.TrimSpace(msg.Message.ID)
|
||||
if target == "" || sourceMessageID == "" {
|
||||
return
|
||||
}
|
||||
if p.reactor != nil {
|
||||
emoji := "👀"
|
||||
if mode == "queue" {
|
||||
emoji = "📋"
|
||||
}
|
||||
_ = p.reactor.React(ctx, strings.TrimSpace(identity.BotID), msg.Channel, channel.ReactRequest{
|
||||
Target: target,
|
||||
MessageID: sourceMessageID,
|
||||
Emoji: emoji,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// drainQueue marks the route as done and processes any queued tasks.
|
||||
func (p *ChannelInboundProcessor) drainQueue(ctx context.Context, routeID string) {
|
||||
if p.dispatcher == nil {
|
||||
return
|
||||
}
|
||||
result := p.dispatcher.MarkDone(routeID)
|
||||
|
||||
for _, fn := range result.PendingPersists {
|
||||
fn(ctx)
|
||||
}
|
||||
|
||||
for _, task := range result.QueuedTasks {
|
||||
if p.logger != nil {
|
||||
p.logger.Info("processing queued task",
|
||||
slog.String("route_id", routeID),
|
||||
slog.String("query", strings.TrimSpace(task.Text)),
|
||||
)
|
||||
}
|
||||
if err := p.HandleInbound(ctx, task.Cfg, task.Msg, task.Sender); err != nil { //nolint:contextcheck // ctx is already WithoutCancel from the defer caller
|
||||
if p.logger != nil {
|
||||
p.logger.Error("queued task processing failed",
|
||||
slog.String("route_id", routeID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectAttachmentPaths(attachments []conversation.ChatAttachment) []string {
|
||||
if len(attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
paths := make([]string, 0, len(attachments))
|
||||
for _, att := range attachments {
|
||||
if p := strings.TrimSpace(att.Path); p != "" {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func shouldTriggerAssistantResponse(msg channel.InboundMessage) bool {
|
||||
if isDirectConversationType(msg.Conversation.Type) {
|
||||
return true
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
// InboundMode determines how a new inbound message is handled when an agent
|
||||
// stream is already active for the same route.
|
||||
type InboundMode int
|
||||
|
||||
const (
|
||||
// ModeInject (default, command /btw) injects the message into the active
|
||||
// agent stream via the PrepareStep hook so the LLM sees it between tool
|
||||
// rounds. When no stream is active, starts one normally.
|
||||
ModeInject InboundMode = iota
|
||||
// ModeParallel (command /now) starts a new agent stream immediately,
|
||||
// running concurrently with any existing stream.
|
||||
ModeParallel
|
||||
// ModeQueue (command /next) queues the message and processes it after the
|
||||
// current agent stream completes.
|
||||
ModeQueue
|
||||
)
|
||||
|
||||
// InjectMessage carries a user message to be injected into a running agent stream.
|
||||
type InjectMessage struct {
|
||||
Text string
|
||||
Attachments []conversation.ChatAttachment
|
||||
// HeaderifiedText is the formatted user header text ready for SDK injection.
|
||||
HeaderifiedText string
|
||||
}
|
||||
|
||||
// QueuedTask holds everything needed to start an agent stream for a queued message.
|
||||
type QueuedTask struct {
|
||||
Ctx context.Context
|
||||
Cfg channel.ChannelConfig
|
||||
Msg channel.InboundMessage
|
||||
Sender channel.StreamReplySender
|
||||
Ident InboundIdentity
|
||||
Text string
|
||||
Attachs []conversation.ChatAttachment
|
||||
}
|
||||
|
||||
// PersistFunc is a deferred persistence closure called after the active stream
|
||||
// completes (and its storeRound has run), ensuring correct created_at ordering.
|
||||
type PersistFunc func(ctx context.Context)
|
||||
|
||||
// routeState tracks in-flight agent activity for a single route.
|
||||
type routeState struct {
|
||||
mu sync.Mutex
|
||||
active bool
|
||||
injectCh chan InjectMessage
|
||||
queue []QueuedTask
|
||||
pendingPersists []PersistFunc
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
// RouteDispatcher manages per-route concurrency for inbound message processing.
|
||||
// It decides whether a new message should be injected into an active stream,
|
||||
// run in parallel, or be queued.
|
||||
type RouteDispatcher struct {
|
||||
mu sync.RWMutex
|
||||
routes map[string]*routeState
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewRouteDispatcher creates a dispatcher with background cleanup.
|
||||
func NewRouteDispatcher(logger *slog.Logger) *RouteDispatcher {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &RouteDispatcher{
|
||||
routes: make(map[string]*routeState),
|
||||
logger: logger.With(slog.String("component", "route_dispatcher")),
|
||||
}
|
||||
}
|
||||
|
||||
const injectChBuffer = 16
|
||||
|
||||
func (d *RouteDispatcher) getOrCreate(routeID string) *routeState {
|
||||
d.mu.RLock()
|
||||
rs, ok := d.routes[routeID]
|
||||
d.mu.RUnlock()
|
||||
if ok {
|
||||
return rs
|
||||
}
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
if rs, ok = d.routes[routeID]; ok {
|
||||
return rs
|
||||
}
|
||||
rs = &routeState{
|
||||
injectCh: make(chan InjectMessage, injectChBuffer),
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
d.routes[routeID] = rs
|
||||
return rs
|
||||
}
|
||||
|
||||
// IsActive reports whether the given route has an active agent stream.
|
||||
func (d *RouteDispatcher) IsActive(routeID string) bool {
|
||||
routeID = strings.TrimSpace(routeID)
|
||||
if routeID == "" {
|
||||
return false
|
||||
}
|
||||
rs := d.getOrCreate(routeID)
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
return rs.active
|
||||
}
|
||||
|
||||
// MarkActive marks a route as having an active stream and returns the inject
|
||||
// channel that the agent should drain via PrepareStep.
|
||||
func (d *RouteDispatcher) MarkActive(routeID string) <-chan InjectMessage {
|
||||
routeID = strings.TrimSpace(routeID)
|
||||
if routeID == "" {
|
||||
return nil
|
||||
}
|
||||
rs := d.getOrCreate(routeID)
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
rs.active = true
|
||||
rs.lastUsed = time.Now()
|
||||
return rs.injectCh
|
||||
}
|
||||
|
||||
// MarkDoneResult holds the data returned when a route transitions from active to idle.
|
||||
type MarkDoneResult struct {
|
||||
PendingPersists []PersistFunc
|
||||
QueuedTasks []QueuedTask
|
||||
}
|
||||
|
||||
// MarkDone marks a route as idle and returns pending persist functions (to be
|
||||
// called now that storeRound has completed) and any queued tasks.
|
||||
func (d *RouteDispatcher) MarkDone(routeID string) MarkDoneResult {
|
||||
routeID = strings.TrimSpace(routeID)
|
||||
if routeID == "" {
|
||||
return MarkDoneResult{}
|
||||
}
|
||||
rs := d.getOrCreate(routeID)
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
rs.active = false
|
||||
rs.lastUsed = time.Now()
|
||||
|
||||
drainInjectCh(rs.injectCh)
|
||||
|
||||
var persists []PersistFunc
|
||||
if len(rs.pendingPersists) > 0 {
|
||||
persists = rs.pendingPersists
|
||||
rs.pendingPersists = nil
|
||||
}
|
||||
|
||||
var tasks []QueuedTask
|
||||
if len(rs.queue) > 0 {
|
||||
tasks = rs.queue
|
||||
rs.queue = nil
|
||||
}
|
||||
|
||||
return MarkDoneResult{PendingPersists: persists, QueuedTasks: tasks}
|
||||
}
|
||||
|
||||
// AddPendingPersist records a deferred persist closure to be executed after the
|
||||
// active stream completes. This ensures injected messages get a created_at
|
||||
// timestamp after the triggering message's round.
|
||||
func (d *RouteDispatcher) AddPendingPersist(routeID string, fn PersistFunc) {
|
||||
routeID = strings.TrimSpace(routeID)
|
||||
if routeID == "" || fn == nil {
|
||||
return
|
||||
}
|
||||
rs := d.getOrCreate(routeID)
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
rs.pendingPersists = append(rs.pendingPersists, fn)
|
||||
}
|
||||
|
||||
// Inject sends a message to the inject channel of an active route.
|
||||
// Returns true if the message was accepted (route is active and channel not full).
|
||||
func (d *RouteDispatcher) Inject(routeID string, msg InjectMessage) bool {
|
||||
routeID = strings.TrimSpace(routeID)
|
||||
if routeID == "" {
|
||||
return false
|
||||
}
|
||||
rs := d.getOrCreate(routeID)
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
if !rs.active {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case rs.injectCh <- msg:
|
||||
if d.logger != nil {
|
||||
d.logger.Info("message injected into active stream",
|
||||
slog.String("route_id", routeID),
|
||||
)
|
||||
}
|
||||
return true
|
||||
default:
|
||||
if d.logger != nil {
|
||||
d.logger.Warn("inject channel full, message dropped",
|
||||
slog.String("route_id", routeID),
|
||||
)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue adds a task to the route's queue for later processing.
|
||||
func (d *RouteDispatcher) Enqueue(routeID string, task QueuedTask) {
|
||||
routeID = strings.TrimSpace(routeID)
|
||||
if routeID == "" {
|
||||
return
|
||||
}
|
||||
rs := d.getOrCreate(routeID)
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
rs.queue = append(rs.queue, task)
|
||||
rs.lastUsed = time.Now()
|
||||
if d.logger != nil {
|
||||
d.logger.Info("message queued",
|
||||
slog.String("route_id", routeID),
|
||||
slog.Int("queue_size", len(rs.queue)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup removes idle route states older than maxAge.
|
||||
func (d *RouteDispatcher) Cleanup(maxAge time.Duration) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
for id, rs := range d.routes {
|
||||
rs.mu.Lock()
|
||||
idle := !rs.active && rs.lastUsed.Before(cutoff) && len(rs.queue) == 0
|
||||
rs.mu.Unlock()
|
||||
if idle {
|
||||
delete(d.routes, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func drainInjectCh(ch chan InjectMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DetectMode parses a message prefix to determine the inbound mode.
|
||||
// Returns the mode and the text with the prefix stripped.
|
||||
func DetectMode(text string) (InboundMode, string) {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return ModeInject, trimmed
|
||||
}
|
||||
|
||||
type modePrefix struct {
|
||||
prefix string
|
||||
mode InboundMode
|
||||
}
|
||||
prefixes := []modePrefix{
|
||||
{"/now ", ModeParallel},
|
||||
{"/next ", ModeQueue},
|
||||
{"/btw ", ModeInject},
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
for _, mp := range prefixes {
|
||||
if strings.HasPrefix(lower, mp.prefix) {
|
||||
return mp.mode, strings.TrimSpace(trimmed[len(mp.prefix):])
|
||||
}
|
||||
}
|
||||
// Exact match without trailing text (bare command)
|
||||
barePrefixes := []modePrefix{
|
||||
{"/now", ModeParallel},
|
||||
{"/next", ModeQueue},
|
||||
{"/btw", ModeInject},
|
||||
}
|
||||
for _, mp := range barePrefixes {
|
||||
if lower == mp.prefix {
|
||||
return mp.mode, ""
|
||||
}
|
||||
}
|
||||
return ModeInject, trimmed
|
||||
}
|
||||
|
||||
// IsModeCommand reports whether the text is a mode-prefix command
|
||||
// (/btw, /now, /next), so the generic command handler should skip it.
|
||||
func IsModeCommand(text string) bool {
|
||||
trimmed := strings.ToLower(strings.TrimSpace(text))
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range []string{"/now", "/next", "/btw"} {
|
||||
if trimmed == prefix || strings.HasPrefix(trimmed, prefix+" ") || strings.HasPrefix(trimmed, prefix+"\t") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDetectMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantMode InboundMode
|
||||
wantText string
|
||||
}{
|
||||
{"hello world", ModeInject, "hello world"},
|
||||
{"/btw hello", ModeInject, "hello"},
|
||||
{"/now hello", ModeParallel, "hello"},
|
||||
{"/next hello", ModeQueue, "hello"},
|
||||
{"/BTW hello", ModeInject, "hello"},
|
||||
{"/NOW hello", ModeParallel, "hello"},
|
||||
{"/NEXT hello", ModeQueue, "hello"},
|
||||
{"/now", ModeParallel, ""},
|
||||
{"/next", ModeQueue, ""},
|
||||
{"/btw", ModeInject, ""},
|
||||
{" /now hello ", ModeParallel, "hello"},
|
||||
{"/unknown hello", ModeInject, "/unknown hello"},
|
||||
{"", ModeInject, ""},
|
||||
{"/new session", ModeInject, "/new session"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
mode, text := DetectMode(tt.input)
|
||||
if mode != tt.wantMode {
|
||||
t.Errorf("DetectMode(%q) mode = %d, want %d", tt.input, mode, tt.wantMode)
|
||||
}
|
||||
if text != tt.wantText {
|
||||
t.Errorf("DetectMode(%q) text = %q, want %q", tt.input, text, tt.wantText)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsModeCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"/btw hello", true},
|
||||
{"/now hello", true},
|
||||
{"/next hello", true},
|
||||
{"/btw", true},
|
||||
{"/now", true},
|
||||
{"/next", true},
|
||||
{"/new", false},
|
||||
{"/fs list", false},
|
||||
{"hello", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := IsModeCommand(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsModeCommand(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_InjectWhenActive(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
if d.IsActive(routeID) {
|
||||
t.Fatal("expected route to be inactive initially")
|
||||
}
|
||||
|
||||
injectCh := d.MarkActive(routeID)
|
||||
if injectCh == nil {
|
||||
t.Fatal("expected non-nil inject channel")
|
||||
}
|
||||
if !d.IsActive(routeID) {
|
||||
t.Fatal("expected route to be active after MarkActive")
|
||||
}
|
||||
|
||||
msg := InjectMessage{Text: "hello", HeaderifiedText: "[User] hello"}
|
||||
if !d.Inject(routeID, msg) {
|
||||
t.Fatal("expected inject to succeed when route is active")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-injectCh:
|
||||
if got.Text != "hello" {
|
||||
t.Errorf("got text %q, want %q", got.Text, "hello")
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected message on inject channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_InjectWhenInactive(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
msg := InjectMessage{Text: "hello"}
|
||||
if d.Inject(routeID, msg) {
|
||||
t.Fatal("expected inject to fail when route is inactive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_QueueAndDrain(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
d.MarkActive(routeID)
|
||||
|
||||
d.Enqueue(routeID, QueuedTask{Text: "task-1"})
|
||||
d.Enqueue(routeID, QueuedTask{Text: "task-2"})
|
||||
|
||||
result := d.MarkDone(routeID)
|
||||
if len(result.QueuedTasks) != 2 {
|
||||
t.Fatalf("expected 2 queued tasks, got %d", len(result.QueuedTasks))
|
||||
}
|
||||
if result.QueuedTasks[0].Text != "task-1" || result.QueuedTasks[1].Text != "task-2" {
|
||||
t.Errorf("unexpected task order: %v", result.QueuedTasks)
|
||||
}
|
||||
if d.IsActive(routeID) {
|
||||
t.Fatal("expected route to be inactive after MarkDone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_MarkDoneReturnsNilWhenEmpty(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
d.MarkActive(routeID)
|
||||
result := d.MarkDone(routeID)
|
||||
if result.QueuedTasks != nil {
|
||||
t.Fatalf("expected nil queued tasks, got %v", result.QueuedTasks)
|
||||
}
|
||||
if result.PendingPersists != nil {
|
||||
t.Fatalf("expected nil pending persists, got %v", result.PendingPersists)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_ConcurrentInject(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
injectCh := d.MarkActive(routeID)
|
||||
|
||||
const numMessages = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numMessages)
|
||||
for i := 0; i < numMessages; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
d.Inject(routeID, InjectMessage{Text: "msg"})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case <-injectCh:
|
||||
count++
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
if count != numMessages {
|
||||
t.Errorf("expected %d messages, got %d", numMessages, count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_ParallelBypass(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
d.MarkActive(routeID)
|
||||
|
||||
// In parallel mode, the caller does not interact with the dispatcher
|
||||
// at all — it just starts a new stream directly. Verify the route
|
||||
// stays active without interference.
|
||||
if !d.IsActive(routeID) {
|
||||
t.Fatal("expected route to still be active")
|
||||
}
|
||||
|
||||
d.MarkDone(routeID)
|
||||
if d.IsActive(routeID) {
|
||||
t.Fatal("expected route to be inactive after MarkDone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_Cleanup(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
|
||||
d.MarkActive("route-1")
|
||||
d.MarkDone("route-1")
|
||||
|
||||
d.MarkActive("route-2")
|
||||
|
||||
d.mu.RLock()
|
||||
initialCount := len(d.routes)
|
||||
d.mu.RUnlock()
|
||||
if initialCount != 2 {
|
||||
t.Fatalf("expected 2 routes, got %d", initialCount)
|
||||
}
|
||||
|
||||
d.Cleanup(0)
|
||||
|
||||
d.mu.RLock()
|
||||
afterCount := len(d.routes)
|
||||
d.mu.RUnlock()
|
||||
|
||||
// route-1 is idle → cleaned up; route-2 is active → kept
|
||||
if afterCount != 1 {
|
||||
t.Fatalf("expected 1 route after cleanup, got %d", afterCount)
|
||||
}
|
||||
if d.IsActive("route-2") != true {
|
||||
t.Fatal("expected route-2 to still be active")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_InjectChannelFull(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
d.MarkActive(routeID)
|
||||
|
||||
// Fill the inject channel to capacity
|
||||
for i := 0; i < injectChBuffer; i++ {
|
||||
if !d.Inject(routeID, InjectMessage{Text: "fill"}) {
|
||||
t.Fatalf("expected inject %d to succeed", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Next inject should fail (channel full)
|
||||
if d.Inject(routeID, InjectMessage{Text: "overflow"}) {
|
||||
t.Fatal("expected inject to fail when channel is full")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_QueueWhenInactive(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
// Enqueue without marking active — still stores in queue
|
||||
d.Enqueue(routeID, QueuedTask{Text: "task-1"})
|
||||
|
||||
// MarkActive then MarkDone should return the queued task
|
||||
d.MarkActive(routeID)
|
||||
result := d.MarkDone(routeID)
|
||||
if len(result.QueuedTasks) != 1 {
|
||||
t.Fatalf("expected 1 queued task, got %d", len(result.QueuedTasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_MultipleMarkActive(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
ch1 := d.MarkActive(routeID)
|
||||
ch2 := d.MarkActive(routeID)
|
||||
|
||||
if ch1 == nil || ch2 == nil {
|
||||
t.Fatal("expected non-nil channels")
|
||||
}
|
||||
|
||||
_ = time.Now()
|
||||
}
|
||||
|
||||
func TestRouteDispatcher_PendingPersistOrder(t *testing.T) {
|
||||
d := NewRouteDispatcher(slog.Default())
|
||||
routeID := "route-1"
|
||||
|
||||
d.MarkActive(routeID)
|
||||
|
||||
var order []string
|
||||
d.AddPendingPersist(routeID, func(_ context.Context) {
|
||||
order = append(order, "B")
|
||||
})
|
||||
d.AddPendingPersist(routeID, func(_ context.Context) {
|
||||
order = append(order, "C")
|
||||
})
|
||||
|
||||
result := d.MarkDone(routeID)
|
||||
if len(result.PendingPersists) != 2 {
|
||||
t.Fatalf("expected 2 pending persists, got %d", len(result.PendingPersists))
|
||||
}
|
||||
|
||||
// Execute persists — they should run in insertion order (B then C)
|
||||
for _, fn := range result.PendingPersists {
|
||||
fn(context.Background())
|
||||
}
|
||||
if len(order) != 2 || order[0] != "B" || order[1] != "C" {
|
||||
t.Errorf("expected [B C], got %v", order)
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
@@ -136,10 +137,11 @@ type usageInfo struct {
|
||||
}
|
||||
|
||||
type resolvedContext struct {
|
||||
runConfig agentpkg.RunConfig
|
||||
model models.GetResponse
|
||||
provider sqlc.LlmProvider
|
||||
query string // headerified query
|
||||
runConfig agentpkg.RunConfig
|
||||
model models.GetResponse
|
||||
provider sqlc.LlmProvider
|
||||
query string // headerified query
|
||||
injectedRecords *[]conversation.InjectedMessageRecord
|
||||
}
|
||||
|
||||
func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) {
|
||||
@@ -292,7 +294,40 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
LoopDetection: agentpkg.LoopDetectionConfig{Enabled: loopDetectionEnabled},
|
||||
}
|
||||
|
||||
return resolvedContext{runConfig: runCfg, model: chatModel, provider: provider, query: headerifiedQuery}, nil
|
||||
var injectedRecords *[]conversation.InjectedMessageRecord
|
||||
if req.InjectCh != nil {
|
||||
agentInjectCh := make(chan agentpkg.InjectMessage, cap(req.InjectCh))
|
||||
go func() {
|
||||
for msg := range req.InjectCh {
|
||||
agentInjectCh <- agentpkg.InjectMessage{
|
||||
Text: msg.Text,
|
||||
HeaderifiedText: msg.HeaderifiedText,
|
||||
}
|
||||
}
|
||||
close(agentInjectCh)
|
||||
}()
|
||||
runCfg.InjectCh = agentInjectCh
|
||||
|
||||
records := make([]conversation.InjectedMessageRecord, 0)
|
||||
injectedRecords = &records
|
||||
var recMu sync.Mutex
|
||||
runCfg.InjectedRecorder = func(headerifiedText string, insertAfter int) {
|
||||
recMu.Lock()
|
||||
*injectedRecords = append(*injectedRecords, conversation.InjectedMessageRecord{
|
||||
HeaderifiedText: headerifiedText,
|
||||
InsertAfter: insertAfter,
|
||||
})
|
||||
recMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return resolvedContext{
|
||||
runConfig: runCfg,
|
||||
model: chatModel,
|
||||
provider: provider,
|
||||
query: headerifiedQuery,
|
||||
injectedRecords: injectedRecords,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat sends a synchronous chat request and stores the result.
|
||||
|
||||
@@ -162,6 +162,10 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
|
||||
outputMessages := sdkMessagesToModelMessages(sdkMsgs)
|
||||
roundMessages := prependUserMessage(req.Query, outputMessages)
|
||||
|
||||
if rc.injectedRecords != nil && len(*rc.injectedRecords) > 0 {
|
||||
roundMessages = interleaveInjectedMessages(roundMessages, *rc.injectedRecords)
|
||||
}
|
||||
|
||||
if err := r.storeRound(ctx, req, roundMessages, modelID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -173,6 +177,37 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// interleaveInjectedMessages inserts injected user messages at their correct
|
||||
// positions within the round. Each record's InsertAfter value indicates how
|
||||
// many output messages preceded the injection.
|
||||
//
|
||||
// round layout: [user_A, output_0, output_1, ..., output_N]
|
||||
// InsertAfter=K → insert after round[K] (i.e. after the K-th output message).
|
||||
func interleaveInjectedMessages(round []conversation.ModelMessage, injections []conversation.InjectedMessageRecord) []conversation.ModelMessage {
|
||||
if len(injections) == 0 {
|
||||
return round
|
||||
}
|
||||
result := make([]conversation.ModelMessage, 0, len(round)+len(injections))
|
||||
injIdx := 0
|
||||
for i, msg := range round {
|
||||
result = append(result, msg)
|
||||
for injIdx < len(injections) && injections[injIdx].InsertAfter == i {
|
||||
result = append(result, conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: conversation.NewTextContent(injections[injIdx].HeaderifiedText),
|
||||
})
|
||||
injIdx++
|
||||
}
|
||||
}
|
||||
for ; injIdx < len(injections); injIdx++ {
|
||||
result = append(result, conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: conversation.NewTextContent(injections[injIdx].HeaderifiedText),
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func extractInputTokensFromUsage(raw json.RawMessage) int {
|
||||
if len(raw) == 0 {
|
||||
return 0
|
||||
|
||||
@@ -241,6 +241,10 @@ type ChatRequest struct {
|
||||
// Set by the inbound channel processor; called by the resolver at persist time.
|
||||
OutboundAssetCollector func() []OutboundAssetRef `json:"-"`
|
||||
|
||||
// InjectCh receives user messages to inject into the active agent stream
|
||||
// between tool rounds via the PrepareStep hook. Nil means no injection.
|
||||
InjectCh <-chan InjectMessage `json:"-"`
|
||||
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
@@ -251,6 +255,24 @@ type ChatRequest struct {
|
||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||
}
|
||||
|
||||
// InjectMessage carries a user message to be injected into a running agent
|
||||
// stream between tool rounds.
|
||||
type InjectMessage struct {
|
||||
Text string
|
||||
Attachments []ChatAttachment
|
||||
HeaderifiedText string
|
||||
}
|
||||
|
||||
// InjectedMessageRecord records a message that was injected via PrepareStep,
|
||||
// together with its position in the output message sequence.
|
||||
type InjectedMessageRecord struct {
|
||||
HeaderifiedText string
|
||||
// InsertAfter is the number of SDK output messages that existed before
|
||||
// this injection. Used to determine the correct insertion position when
|
||||
// interleaving injected messages into the persisted round.
|
||||
InsertAfter int
|
||||
}
|
||||
|
||||
// ChatResponse is the output of a non-streaming chat call.
|
||||
type ChatResponse struct {
|
||||
Messages []ModelMessage `json:"messages"`
|
||||
|
||||
Reference in New Issue
Block a user