From c4114227e56376d5836ff3fccc2c98bbfe417466 Mon Sep 17 00:00:00 2001 From: EYHN Date: Mon, 13 Apr 2026 20:28:42 +0800 Subject: [PATCH] feat(agent): add background task execution and notifications (#365) --- cmd/agent/main.go | 42 +- cmd/bridge/server.go | 31 +- cmd/memoh/serve.go | 32 +- internal/agent/agent.go | 66 ++ internal/agent/background/manager.go | 659 ++++++++++++++++++ internal/agent/background/manager_test.go | 429 ++++++++++++ internal/agent/background/types.go | 159 +++++ internal/agent/background_exec_e2e_test.go | 580 +++++++++++++++ internal/agent/read_media_test.go | 6 +- internal/agent/tools/container.go | 341 ++++++++- internal/agent/tools/container_test.go | 35 + internal/agent/tools/message.go | 4 +- internal/agent/types.go | 8 + internal/conversation/flow/resolver.go | 29 +- internal/conversation/flow/resolver_stream.go | 7 +- .../conversation/flow/resolver_trigger.go | 192 +++++ .../flow/resolver_trigger_test.go | 233 +++++++ internal/conversation/flow/resolver_turns.go | 112 +++ 18 files changed, 2934 insertions(+), 31 deletions(-) create mode 100644 internal/agent/background/manager.go create mode 100644 internal/agent/background/manager_test.go create mode 100644 internal/agent/background/types.go create mode 100644 internal/agent/background_exec_e2e_test.go create mode 100644 internal/agent/tools/container_test.go create mode 100644 internal/conversation/flow/resolver_trigger_test.go create mode 100644 internal/conversation/flow/resolver_turns.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 198f4980..9f2c4132 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -24,6 +24,7 @@ import ( "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/acl" agentpkg "github.com/memohai/memoh/internal/agent" + "github.com/memohai/memoh/internal/agent/background" agenttools "github.com/memohai/memoh/internal/agent/tools" "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/boot" @@ 
-230,6 +231,7 @@ func runServe() { provideContainerdHandler, provideFederationGateway, provideToolGatewayService, + provideBackgroundManager, provideToolProviders, // http handlers (group:"server_handlers") @@ -280,9 +282,11 @@ func runServe() { startScheduleService, startHeartbeatService, + wireResolverOutbound, startChannelManager, startEmailManager, startContainerReconciliation, + startBackgroundTaskCleanup, startTtsTempStoreCleanup, startServer, ), @@ -486,15 +490,20 @@ func injectToolProviders(a *agentpkg.Agent, msgService *message.DBService, provi } } -func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig) *flow.Resolver { +func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, routeService *route.DBService, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver { resolver := flow.NewResolver(log, modelsService, queries, chatService, msgService, settingsService, accountService, a, rc.TimezoneLocation, 120*time.Second) resolver.SetMemoryRegistry(memoryRegistry) resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) 
resolver.SetGatewayAssetLoader(&gatewayAssetLoaderAdapter{media: mediaService}) + resolver.SetRouteService(routeService) resolver.SetSessionService(sessionService) resolver.SetEventPublisher(eventHub) resolver.SetCompactionService(compactionService) resolver.SetPipeline(pipeline) + resolver.SetBackgroundManager(bgManager) + bgManager.SetWakeFunc(func(botID, sessionID string) { + resolver.TriggerBackgroundNotification(context.Background(), botID, sessionID) + }) return resolver } @@ -669,7 +678,11 @@ func provideToolGatewayService(log *slog.Logger, fedGateway *handlers.MCPFederat return svc } -func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *workspace.Manager, mediaService *media.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, mcpConnService *mcp.ConnectionService, modelsService *models.Service, browserContextService *browsercontexts.Service, queries *dbsqlc.Queries, ttsService *ttspkg.Service, sessionService *sessionpkg.Service) []agenttools.ToolProvider { +func provideBackgroundManager(log *slog.Logger) *background.Manager { + return background.New(log) +} + +func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *workspace.Manager, mediaService *media.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, mcpConnService *mcp.ConnectionService, modelsService *models.Service, browserContextService 
*browsercontexts.Service, queries *dbsqlc.Queries, ttsService *ttspkg.Service, sessionService *sessionpkg.Service, bgManager *background.Manager) []agenttools.ToolProvider { var assetResolver messaging.AssetResolver if mediaService != nil { assetResolver = &mediaAssetResolverAdapter{media: mediaService} @@ -681,7 +694,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c agenttools.NewScheduleProvider(log, scheduleService), agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewWebProvider(log, settingsService, searchProviderService), - agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), + agenttools.NewContainerProvider(log, manager, bgManager, config.DefaultDataMount), agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewWebFetchProvider(log), agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService), @@ -771,6 +784,20 @@ func startTtsTempStoreCleanup(lc fx.Lifecycle, store *ttspkg.TempStore) { }) } +func startBackgroundTaskCleanup(lc fx.Lifecycle, mgr *background.Manager) { + done := make(chan struct{}) + lc.Append(fx.Hook{ + OnStart: func(_ context.Context) error { + go mgr.StartCleanupLoop(done, background.DefaultCleanupInterval, background.DefaultTaskRetention) + return nil + }, + OnStop: func(_ context.Context) error { + close(done) + return nil + }, + }) +} + // settingsTtsModelResolver adapts settings.Service to the ttsModelResolver interface // expected by ChannelInboundProcessor and LocalChannelHandler. // sessionEnsurerAdapter adapts session.Service to the inbound sessionEnsurer interface. 
@@ -958,6 +985,15 @@ func startHeartbeatService(lc fx.Lifecycle, heartbeatService *heartbeat.Service) }) } +func wireResolverOutbound(resolver *flow.Resolver, channelManager *channel.Manager) { + resolver.SetOutboundFn(func(ctx context.Context, botID, channelType, target, text string) error { + return channelManager.Send(ctx, botID, channel.ChannelType(channelType), channel.SendRequest{ + Target: target, + Message: channel.Message{Text: text}, + }) + }) +} + func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager) { ctx, cancel := context.WithCancel(context.Background()) lc.Append(fx.Hook{ diff --git a/cmd/bridge/server.go b/cmd/bridge/server.go index 74e4b920..24de8086 100644 --- a/cmd/bridge/server.go +++ b/cmd/bridge/server.go @@ -369,10 +369,14 @@ func execPipe(stream pb.ContainerService_ExecServer, firstMsg *pb.ExecInput) err timeout = defaultTimeout } - ctx, cancel := context.WithTimeout(stream.Context(), time.Duration(timeout)*time.Second) - defer cancel() + // Process context is independent of the gRPC stream so the process keeps + // running even if the stream is cancelled (e.g. background tasks whose client + // disconnects or whose stream context dies after the process completes). + // Only the process-level timeout kills the process, not stream death. + procCtx, procCancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer procCancel() - cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) //nolint:gosec // G204: MCP exec tool intentionally executes agent-issued shell commands inside the container + cmd := exec.CommandContext(procCtx, "/bin/sh", "-c", command) //nolint:gosec // G204: MCP exec tool intentionally executes agent-issued shell commands inside the container cmd.Dir = workDir if len(firstMsg.GetEnv()) > 0 { cmd.Env = append(os.Environ(), firstMsg.GetEnv()...) 
@@ -396,13 +400,15 @@ func execPipe(stream pb.ContainerService_ExecServer, firstMsg *pb.ExecInput) err return status.Errorf(codes.Internal, "start: %v", err) } - // When the context deadline fires, exec.CommandContext sends SIGKILL to the - // main process. However, child processes may still hold the stdout/stderr - // pipe file descriptors open, causing streamPipe's Read to block forever. - // Closing the pipes here unblocks those reads so the function can proceed - // to cmd.Wait and send the EXIT message back to the client. + // Close pipes when EITHER the process finishes (procCtx done) OR the gRPC + // stream dies (stream.Context done). Closing unblocks streamPipe's Read so + // the goroutines can exit. We do NOT cancel procCtx on stream death — the + // process keeps running so background tasks survive client disconnects. go func() { - <-ctx.Done() + select { + case <-procCtx.Done(): + case <-stream.Context().Done(): + } _ = stdoutPipe.Close() _ = stderrPipe.Close() }() @@ -439,10 +445,15 @@ func execPipe(stream pb.ContainerService_ExecServer, firstMsg *pb.ExecInput) err } } - return stream.Send(&pb.ExecOutput{ + // Send exit code to the client. If the stream is already gone (e.g. the + // client is a background task manager that got a stream error when the + // process completed), the send will fail but we return nil so the gRPC + // handler does not propagate a spurious "context canceled" error status. 
+ _ = stream.Send(&pb.ExecOutput{ Stream: pb.ExecOutput_EXIT, ExitCode: exitCode, }) + return nil } func (*containerServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 393aad64..fc54b734 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -24,6 +24,7 @@ import ( "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/acl" agentpkg "github.com/memohai/memoh/internal/agent" + "github.com/memohai/memoh/internal/agent/background" agenttools "github.com/memohai/memoh/internal/agent/tools" "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/bind" @@ -156,6 +157,7 @@ func runServe() { provideContainerdHandler, provideFederationGateway, provideToolGatewayService, + provideBackgroundManager, provideToolProviders, provideServerHandler(handlers.NewPingHandler), provideServerHandler(provideMemohAuthHandler), @@ -207,6 +209,7 @@ func runServe() { startChannelManager, startEmailManager, startContainerReconciliation, + startBackgroundTaskCleanup, startTtsTempStoreCleanup, startServer, ), @@ -417,15 +420,20 @@ func injectToolProviders(a *agentpkg.Agent, msgService *message.DBService, provi } } -func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig) *flow.Resolver { +func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService 
*settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, routeService *route.DBService, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver { resolver := flow.NewResolver(log, modelsService, queries, chatService, msgService, settingsService, accountService, a, rc.TimezoneLocation, 120*time.Second) resolver.SetMemoryRegistry(memoryRegistry) resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) resolver.SetGatewayAssetLoader(&gatewayAssetLoaderAdapter{media: mediaService}) + resolver.SetRouteService(routeService) resolver.SetSessionService(sessionService) resolver.SetEventPublisher(eventHub) resolver.SetCompactionService(compactionService) resolver.SetPipeline(pipeline) + resolver.SetBackgroundManager(bgManager) + bgManager.SetWakeFunc(func(botID, sessionID string) { + resolver.TriggerBackgroundNotification(context.Background(), botID, sessionID) + }) return resolver } @@ -552,7 +560,11 @@ func provideToolGatewayService(log *slog.Logger, fedGateway *handlers.MCPFederat return svc } -func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *workspace.Manager, mediaService *media.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, mcpConnService *mcp.ConnectionService, modelsService *models.Service, browserContextService *browsercontexts.Service, queries *dbsqlc.Queries, ttsService *ttspkg.Service, sessionService *sessionpkg.Service) []agenttools.ToolProvider { +func 
provideBackgroundManager(log *slog.Logger) *background.Manager { + return background.New(log) +} + +func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *workspace.Manager, mediaService *media.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, mcpConnService *mcp.ConnectionService, modelsService *models.Service, browserContextService *browsercontexts.Service, queries *dbsqlc.Queries, ttsService *ttspkg.Service, sessionService *sessionpkg.Service, bgManager *background.Manager) []agenttools.ToolProvider { var assetResolver messaging.AssetResolver if mediaService != nil { assetResolver = &mediaAssetResolverAdapter{media: mediaService} @@ -564,7 +576,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c agenttools.NewScheduleProvider(log, scheduleService), agenttools.NewMemoryProvider(log, memoryRegistry, settingsService), agenttools.NewWebProvider(log, settingsService, searchProviderService), - agenttools.NewContainerProvider(log, manager, config.DefaultDataMount), + agenttools.NewContainerProvider(log, manager, bgManager, config.DefaultDataMount), agenttools.NewEmailProvider(log, emailService, emailManager), agenttools.NewWebFetchProvider(log), agenttools.NewSpawnProvider(log, settingsService, modelsService, queries, sessionService), @@ -869,6 +881,20 @@ func startTtsTempStoreCleanup(lc fx.Lifecycle, store *ttspkg.TempStore) { }) } +func startBackgroundTaskCleanup(lc fx.Lifecycle, mgr *background.Manager) { + done := make(chan struct{}) + lc.Append(fx.Hook{ + OnStart: func(_ context.Context) error { + go mgr.StartCleanupLoop(done, background.DefaultCleanupInterval, background.DefaultTaskRetention) + return nil 
+ }, + OnStop: func(_ context.Context) error { + close(done) + return nil + }, + }) +} + // settingsTtsModelResolver adapts settings.Service to the ttsModelResolver interface // expected by ChannelInboundProcessor and LocalChannelHandler. type sessionEnsurerAdapter struct { diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f77f1622..efdc4808 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -11,6 +11,7 @@ import ( sdk "github.com/memohai/twilight-ai/sdk" + "github.com/memohai/memoh/internal/agent/background" "github.com/memohai/memoh/internal/agent/tools" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/workspace/bridge" @@ -169,6 +170,23 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } } + // Drain background task notifications at step boundaries. + // Each notification is injected as a user message so the model + // discovers completed background work naturally. + if cfg.BackgroundManager != nil { + basePrepare := prepareStep + baseSystem := cfg.System // capture original system prompt to avoid accumulation + prepareStep = func(p *sdk.GenerateParams) *sdk.GenerateParams { + if basePrepare != nil { + if override := basePrepare(p); override != nil { + p = override + } + } + p = drainBackgroundNotifications(p, cfg.BackgroundManager, baseSystem, cfg.Identity.BotID, cfg.Identity.SessionID, a.logger) + return p + } + } + opts := a.buildGenerateOptions(cfg, sdkTools, prepareStep) retryCfg := cfg.Retry @@ -452,6 +470,22 @@ func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult if readMediaState != nil { prepareStep = readMediaState.prepareStep } + + // Drain background task notifications at step boundaries (non-streaming). 
+ if cfg.BackgroundManager != nil { + basePrepare := prepareStep + baseSystem := cfg.System + prepareStep = func(p *sdk.GenerateParams) *sdk.GenerateParams { + if basePrepare != nil { + if override := basePrepare(p); override != nil { + p = override + } + } + p = drainBackgroundNotifications(p, cfg.BackgroundManager, baseSystem, cfg.Identity.BotID, cfg.Identity.SessionID, a.logger) + return p + } + } + opts := a.buildGenerateOptions(cfg, sdkTools, prepareStep) opts = append(opts, sdk.WithOnStep(func(step *sdk.StepResult) *sdk.GenerateParams { @@ -635,6 +669,38 @@ func toolStreamEventToAgentEvent(evt tools.ToolStreamEvent) StreamEvent { } } +// drainBackgroundNotifications non-blockingly drains pending background task +// notifications for the given bot+session and injects them as user messages +// into the next LLM step at step boundaries. +func drainBackgroundNotifications( + p *sdk.GenerateParams, + mgr *background.Manager, + baseSystem string, + botID, sessionID string, + logger *slog.Logger, +) *sdk.GenerateParams { + // Inject running tasks summary into system prompt so the model + // knows about ongoing background work even after compaction. + // Always start from baseSystem to avoid accumulating summaries across steps. 
+ if summary := mgr.RunningTasksSummary(botID, sessionID); summary != "" { + p.System = baseSystem + "\n\n" + summary + } else { + p.System = baseSystem + } + + notifications := mgr.DrainNotifications(botID, sessionID) + for _, n := range notifications { + p.Messages = append(p.Messages, sdk.UserMessage(n.MessageText())) + logger.Info("injected background task notification", + slog.String("task_id", n.TaskID), + slog.String("status", string(n.Status)), + slog.Bool("stalled", n.Stalled), + slog.String("bot_id", botID), + ) + } + return p +} + func wrapToolsWithLoopGuard(tools []sdk.Tool, guard *ToolLoopGuard, abortCallIDs map[string]struct{}) []sdk.Tool { wrapped := make([]sdk.Tool, len(tools)) for i, tool := range tools { diff --git a/internal/agent/background/manager.go b/internal/agent/background/manager.go new file mode 100644 index 00000000..e80472cf --- /dev/null +++ b/internal/agent/background/manager.go @@ -0,0 +1,659 @@ +// Package background implements a background task manager for long-running +// commands executed inside bot containers. It follows a task-notification +// architecture: +// +// 1. Commands can be started in the background (fire-and-forget). +// 2. Output is collected asynchronously and written to a file in the container. +// 3. When a task completes, a structured Notification is enqueued. +// 4. Notifications are scoped to (botID, sessionID); the agent loop drains +// them at step boundaries and injects them as context messages so the +// model learns about completed work. +// +// The manager is a server-level singleton, safe for concurrent use. +package background + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "log/slog" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/memohai/memoh/internal/workspace/bridge" +) + +const ( + // DefaultExecTimeout is the default timeout for foreground exec calls. + DefaultExecTimeout int32 = 30 + // MaxExecTimeout is the maximum allowed timeout (10 minutes). 
+ MaxExecTimeout int32 = 600 + // BackgroundExecTimeout is the timeout for background tasks (30 minutes). + BackgroundExecTimeout int32 = 1800 + // DefaultCleanupInterval is how often the manager prunes old completed tasks. + DefaultCleanupInterval = time.Hour + // DefaultTaskRetention is how long completed tasks are retained in memory. + DefaultTaskRetention = 24 * time.Hour + // OutputLogDir is the directory inside the container where background + // task output logs are written. + OutputLogDir = "/tmp/memoh-bg" + + // stallCheckInterval is how often the stall watchdog checks output growth. + stallCheckInterval = 5 * time.Second + // stallThreshold is the duration of zero output growth before we consider + // the command stalled and possibly waiting for interactive input. + stallThreshold = 45 * time.Second +) + +// ExecFunc executes a command in a container and returns the result. +// This is the signature that bridge.Client.Exec satisfies. +type ExecFunc func(ctx context.Context, command, workDir string, timeout int32) (*bridge.ExecResult, error) + +// WriteFileFunc writes content to a file in the container. +type WriteFileFunc func(ctx context.Context, path string, data []byte) error + +// ReadFileFunc reads content from a file in the container. +type ReadFileFunc func(ctx context.Context, path string) ([]byte, error) + +// Manager tracks background tasks and delivers completion notifications. +type Manager struct { + mu sync.Mutex + tasks map[string]*Task // taskID -> Task + notifications []Notification // pending notifications, protected by mu + logger *slog.Logger + wakeFunc func(botID, sessionID string) // optional callback to wake agent on new notification +} + +// New creates a new background task Manager. 
+func New(logger *slog.Logger) *Manager { + if logger == nil { + logger = slog.Default() + } + return &Manager{ + tasks: make(map[string]*Task), + logger: logger.With(slog.String("service", "background")), + } +} + +// SetWakeFunc registers a callback that is invoked (in a goroutine) whenever a +// new notification is enqueued. Use this to wake up a sleeping agent so it +// can drain the notification immediately instead of waiting for user input. +func (m *Manager) SetWakeFunc(fn func(botID, sessionID string)) { + m.mu.Lock() + m.wakeFunc = fn + m.mu.Unlock() +} + +// enqueueNotification appends n to the pending list and, if a wake function is +// registered, calls it asynchronously so the agent can process the notification. +func (m *Manager) enqueueNotification(n Notification) { + m.mu.Lock() + m.notifications = append(m.notifications, n) + wakeFn := m.wakeFunc + m.mu.Unlock() + m.logger.Info("notification enqueued", + slog.String("task_id", n.TaskID), + slog.String("bot_id", n.BotID), + slog.Bool("has_wake_func", wakeFn != nil), + ) + if wakeFn != nil { + go wakeFn(n.BotID, n.SessionID) + } +} + +// Spawn starts a command in the background. It returns the task ID immediately. +// The command runs asynchronously; when it completes, a Notification is sent +// to the Notifications channel. +// +// execFn should call bridge.Client.Exec (or equivalent). +// writeFn should call bridge.Client.WriteFile to persist output logs. 
+func (m *Manager) Spawn( + parentCtx context.Context, + botID, sessionID, command, workDir, description string, + execFn ExecFunc, + writeFn WriteFileFunc, + readFn ReadFileFunc, +) (taskID, outputFile string) { + m.mu.Lock() + taskID = m.newTaskIDLocked(botID) + outputFile = fmt.Sprintf("%s/%s.log", OutputLogDir, taskID) + + task := &Task{ + ID: taskID, + BotID: botID, + SessionID: sessionID, + Command: command, + Description: description, + WorkDir: workDir, + Status: TaskRunning, + OutputFile: outputFile, + StartedAt: time.Now(), + } + m.tasks[taskID] = task + m.mu.Unlock() + + m.logger.Info("background task spawned", + slog.String("task_id", taskID), + slog.String("bot_id", botID), + slog.String("command", truncate(command, 120)), + ) + + go m.run(parentCtx, task, execFn, writeFn, readFn) + return taskID, outputFile +} + +// SpawnAdopt registers a background task for a command that is already running +// externally (e.g. via ExecStream). Instead of re-executing the command, it +// waits for the result on the provided channel. This enables "flip to background" +// where a foreground stream is handed off without killing the process. 
+func (m *Manager) SpawnAdopt( + parentCtx context.Context, + botID, sessionID, command, workDir, description string, + resultCh <-chan AdoptResult, + writeFn WriteFileFunc, +) (taskID, outputFile string) { + m.mu.Lock() + taskID = m.newTaskIDLocked(botID) + outputFile = fmt.Sprintf("%s/%s.log", OutputLogDir, taskID) + + task := &Task{ + ID: taskID, + BotID: botID, + SessionID: sessionID, + Command: command, + Description: description, + WorkDir: workDir, + Status: TaskRunning, + OutputFile: outputFile, + StartedAt: time.Now(), + } + m.tasks[taskID] = task + m.mu.Unlock() + + m.logger.Info("background task adopted", + slog.String("task_id", taskID), + slog.String("bot_id", botID), + slog.String("command", truncate(command, 120)), + ) + + go m.runAdopt(parentCtx, task, resultCh, writeFn) + return taskID, outputFile +} + +func (m *Manager) newTaskIDLocked(botID string) string { + prefix := botID[:min(8, len(botID))] + for { + id := fmt.Sprintf("bg_%s_%s", prefix, shortRandHex(4)) + if _, exists := m.tasks[id]; !exists { + return id + } + } +} + +func shortRandHex(n int) string { + if n <= 0 { + n = 4 + } + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + panic(fmt.Errorf("background: read random bytes: %w", err)) + } + return hex.EncodeToString(buf) +} + +// runAdopt waits for the adopted stream result and handles completion. +func (m *Manager) runAdopt(parentCtx context.Context, task *Task, resultCh <-chan AdoptResult, writeFn WriteFileFunc) { + ctx, cancel := detachedContextWithTimeout(parentCtx, time.Duration(BackgroundExecTimeout)*time.Second) + task.mu.Lock() + task.cancel = cancel + task.mu.Unlock() + defer cancel() + + // Ensure output directory exists. + _ = ensureOutputDir(ctx, writeFn) + + // Start stall watchdog. + go m.stallWatchdog(ctx, task) + + // Wait for the result from the already-running stream. 
+ var result AdoptResult + select { + case result = <-resultCh: + case <-ctx.Done(): + result = AdoptResult{Err: ctx.Err()} + } + + // Write output to log file in container. + if writeFn != nil && result.Err == nil { + combined := result.Stdout + if result.Stderr != "" { + combined += "\n--- stderr ---\n" + result.Stderr + } + _ = writeFn(context.WithoutCancel(ctx), task.OutputFile, []byte(combined)) + } + + m.completeTask(task, result.Stdout, result.Stderr, result.Err, result.ExitCode) +} + +func (m *Manager) run(parentCtx context.Context, task *Task, execFn ExecFunc, writeFn WriteFileFunc, readFn ReadFileFunc) { + ctx, cancel := detachedContextWithTimeout(parentCtx, time.Duration(BackgroundExecTimeout)*time.Second) + task.mu.Lock() + task.cancel = cancel + task.mu.Unlock() + defer cancel() + + // Ensure output directory exists. + _ = ensureOutputDir(ctx, writeFn) + + // Start stall watchdog to detect commands waiting for interactive input. + go m.stallWatchdog(ctx, task) + + // Wrap command to tee output to the log file inside the container and + // capture the command exit code into a sentinel file via fd 3 redirect. + // Even if the gRPC stream dies after process completion, we can recover + // the actual exit code by reading the sentinel file. + wrappedCmd := fmt.Sprintf( + "{ { ( %s ) ; echo $? >&3 ; } 2>&1 | tee %s ; } 3>%s.exit", + task.Command, task.OutputFile, task.OutputFile, + ) + + result, err := execFn(ctx, wrappedCmd, task.WorkDir, BackgroundExecTimeout) + if err != nil { + m.logger.Warn("background task: execFn returned error", + slog.String("task_id", task.ID), + slog.Any("exec_error", err), + ) + } + + // Always prefer the sentinel file for the real exit code. + // The wrappedCmd uses a pipeline: the shell exits with tee's code (0), + // not the actual command's code. The sentinel captures the real value. + // On gRPC error the sentinel also lets us recover without -1. 
+ if readFn != nil { + ec, recoverErr := readSentinelExitCode(ctx, task.OutputFile+".exit", readFn) + if recoverErr == nil { + if err != nil { + m.logger.Info("background task: recovered exit code from sentinel file after stream error", + slog.String("task_id", task.ID), + slog.Int("recovered_exit_code", int(ec)), + slog.Any("stream_error", err), + ) + } + result = &bridge.ExecResult{ExitCode: ec} + err = nil + } else if err != nil { + m.logger.Warn("background task: sentinel recovery failed", + slog.String("task_id", task.ID), + slog.Any("recover_error", recoverErr), + ) + } + // If err==nil but sentinel unreadable: fall through to use gRPC exit code + } + + var stdout, stderr string + var exitCode int32 + if result != nil { + stdout = result.Stdout + stderr = result.Stderr + exitCode = result.ExitCode + } + m.completeTask(task, stdout, stderr, err, exitCode) +} + +func (m *Manager) completeTask(task *Task, stdout, stderr string, execErr error, exitCode int32) { + if execErr != nil { + task.AppendOutput(fmt.Sprintf("[error] %v\n", execErr)) + } else { + task.AppendOutput(stdout) + if stderr != "" { + task.AppendOutput(stderr) + } + } + + task.mu.Lock() + if task.Status == TaskKilled { + task.mu.Unlock() + return + } + task.CompletedAt = time.Now() + if execErr != nil { + task.Status = TaskFailed + task.ExitCode = -1 + } else { + task.ExitCode = exitCode + if exitCode == 0 { + task.Status = TaskCompleted + } else { + task.Status = TaskFailed + } + } + status := task.Status + finalExitCode := task.ExitCode + duration := task.CompletedAt.Sub(task.StartedAt) + task.mu.Unlock() + + m.logger.Info("background task finished", + slog.String("task_id", task.ID), + slog.String("status", string(status)), + slog.Int("exit_code", int(finalExitCode)), + slog.Duration("duration", duration), + ) + + // Guard against double notification when Kill or an auto-background race + // already enqueued one for this task. 
+ if !task.MarkNotified() { + return + } + + m.enqueueNotification(Notification{ + TaskID: task.ID, + BotID: task.BotID, + SessionID: task.SessionID, + Status: status, + Command: task.Command, + Description: task.Description, + ExitCode: finalExitCode, + OutputFile: task.OutputFile, + OutputTail: task.OutputTail(), + Duration: duration, + }) +} + +func readSentinelExitCode(ctx context.Context, path string, readFn ReadFileFunc) (int32, error) { + data, err := readFn(ctx, path) + if err != nil { + return 0, err + } + ec, err := strconv.Atoi(strings.TrimSpace(string(data))) + if err != nil { + return 0, fmt.Errorf("parse exit code %q: %w", string(data), err) + } + return int32(ec), nil //nolint:gosec // G115: exit codes are 0-255 +} + +func ensureOutputDir(ctx context.Context, writeFn WriteFileFunc) error { + if writeFn == nil { + return nil + } + // Create a marker file to ensure the directory exists. + return writeFn(ctx, OutputLogDir+"/.keep", []byte("")) +} + +// Kill cancels a running background task. +func (m *Manager) Kill(taskID string) error { + m.mu.Lock() + task, ok := m.tasks[taskID] + m.mu.Unlock() + if !ok { + return fmt.Errorf("task %s not found", taskID) + } + task.mu.Lock() + if task.Status != TaskRunning { + task.mu.Unlock() + return fmt.Errorf("task %s is not running (status: %s)", taskID, task.Status) + } + task.Status = TaskKilled + task.CompletedAt = time.Now() + task.mu.Unlock() + + task.Cancel() + m.logger.Info("background task killed", slog.String("task_id", taskID)) + return nil +} + +// Get returns a task by ID, or nil if not found. +func (m *Manager) Get(taskID string) *Task { + m.mu.Lock() + defer m.mu.Unlock() + return m.tasks[taskID] +} + +// GetForSession returns a task by ID only if it belongs to the provided +// bot+session. 
+func (m *Manager) GetForSession(botID, sessionID, taskID string) *Task { + m.mu.Lock() + defer m.mu.Unlock() + task := m.tasks[taskID] + if task == nil || task.BotID != botID || task.SessionID != sessionID { + return nil + } + return task +} + +// ListForSession returns all tasks for a given bot+session, most recent first. +func (m *Manager) ListForSession(botID, sessionID string) []*Task { + m.mu.Lock() + defer m.mu.Unlock() + var result []*Task + for _, t := range m.tasks { + if t.BotID == botID && t.SessionID == sessionID { + result = append(result, t) + } + } + return result +} + +// KillForSession cancels a running background task only when it belongs to the +// provided bot+session. +func (m *Manager) KillForSession(botID, sessionID, taskID string) error { + task := m.GetForSession(botID, sessionID, taskID) + if task == nil { + return fmt.Errorf("task %s not found", taskID) + } + return m.Kill(taskID) +} + +// DrainNotifications returns all pending notifications for a given +// bot+session without blocking. Used by the resolver to inject +// notifications at the start of a new agent run. +func (m *Manager) DrainNotifications(botID, sessionID string) []Notification { + m.mu.Lock() + defer m.mu.Unlock() + + var matched []Notification + remaining := m.notifications[:0] // reuse backing array + for _, n := range m.notifications { + if n.BotID == botID && n.SessionID == sessionID { + matched = append(matched, n) + } else { + remaining = append(remaining, n) + } + } + m.notifications = remaining + return matched +} + +// HasNotifications reports whether there are pending notifications for the +// given bot+session without consuming them. 
+func (m *Manager) HasNotifications(botID, sessionID string) bool { + m.mu.Lock() + defer m.mu.Unlock() + for _, n := range m.notifications { + if n.BotID == botID && n.SessionID == sessionID { + return true + } + } + return false +} + +// RunningTasksSummary returns a text summary of currently running tasks +// for a given bot+session. This is injected into the system prompt so the +// agent knows about ongoing background work. +func (m *Manager) RunningTasksSummary(botID, sessionID string) string { + m.mu.Lock() + defer m.mu.Unlock() + var lines []string + for _, t := range m.tasks { + if t.BotID == botID && t.SessionID == sessionID && t.Status == TaskRunning { + desc := t.Description + if desc == "" { + desc = truncate(t.Command, 80) + } + lines = append(lines, fmt.Sprintf("- [%s] %s (started %s ago, output: %s)", + t.ID, desc, + time.Since(t.StartedAt).Round(time.Second), + t.OutputFile, + )) + } + } + if len(lines) == 0 { + return "" + } + return "Currently running background tasks:\n" + joinLines(lines) +} + +// Cleanup removes completed tasks older than the given duration. +func (m *Manager) Cleanup(maxAge time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + cutoff := time.Now().Add(-maxAge) + for id, t := range m.tasks { + if t.Status != TaskRunning && t.CompletedAt.Before(cutoff) { + delete(m.tasks, id) + } + } +} + +// StartCleanupLoop periodically removes old completed tasks until done is closed. +func (m *Manager) StartCleanupLoop(done <-chan struct{}, interval, maxAge time.Duration) { + if interval <= 0 { + interval = DefaultCleanupInterval + } + if maxAge <= 0 { + maxAge = DefaultTaskRetention + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + m.Cleanup(maxAge) + case <-done: + return + } + } +} + +// RequeueNotifications puts notifications back into the pending queue. +// Used when proactive delivery for a session fails and the batch should be retried. 
+func (m *Manager) RequeueNotifications(ns []Notification) { + if len(ns) == 0 { + return + } + m.mu.Lock() + m.notifications = append(m.notifications, ns...) + m.mu.Unlock() +} + +// promptPatterns matches common interactive prompt endings that indicate +// a command is waiting for user input. +var promptPatterns = regexp.MustCompile( + `(?i)(\$ ?$|> ?$|# ?$|password\s*:|passphrase\s*:|y/n\]|yes/no\)|enter .*:|Press .* to continue|Are you sure|Continue\?|Proceed\?)`, +) + +// stallWatchdog monitors a background task's output for stalls that might +// indicate the command is waiting for interactive input. If detected, it +// enqueues a notification advising the agent to kill and retry. +func (m *Manager) stallWatchdog(ctx context.Context, task *Task) { + ticker := time.NewTicker(stallCheckInterval) + defer ticker.Stop() + + var lastLen int + var stalledSince time.Time + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + + task.mu.Lock() + if task.Status != TaskRunning { + task.mu.Unlock() + return + } + currentLen := task.output.Len() + // Read tail inline (we already hold the lock). + tail := task.output.String() + if len(tail) > maxTailBytes { + tail = tail[len(tail)-maxTailBytes:] + } + task.mu.Unlock() + + if currentLen != lastLen { + // Output is still growing — reset stall timer. + lastLen = currentLen + stalledSince = time.Time{} + continue + } + + // Output hasn't grown. + if stalledSince.IsZero() { + stalledSince = time.Now() + continue + } + + if time.Since(stalledSince) < stallThreshold { + continue + } + + // Stalled long enough. Check if the tail looks like an interactive prompt. + if !promptPatterns.MatchString(tail) { + continue + } + + m.logger.Warn("background task appears stalled on interactive prompt", + slog.String("task_id", task.ID), + ) + + // Enqueue a stall notification (only once). 
+ if !task.MarkNotified() { + return + } + + n := Notification{ + TaskID: task.ID, + BotID: task.BotID, + SessionID: task.SessionID, + Status: TaskRunning, // still running, but stalled + Command: task.Command, + Description: task.Description, + ExitCode: 0, + OutputFile: task.OutputFile, + OutputTail: tail, + Duration: time.Since(task.StartedAt), + Stalled: true, + } + + m.enqueueNotification(n) + return // only notify once per task + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func joinLines(lines []string) string { + if len(lines) == 0 { + return "" + } + return strings.Join(lines, "\n") + "\n" +} + +func detachedContextWithTimeout(parentCtx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if parentCtx == nil { + parentCtx = context.Background() + } + return context.WithTimeout(context.WithoutCancel(parentCtx), timeout) +} diff --git a/internal/agent/background/manager_test.go b/internal/agent/background/manager_test.go new file mode 100644 index 00000000..e53907d7 --- /dev/null +++ b/internal/agent/background/manager_test.go @@ -0,0 +1,429 @@ +package background + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/memohai/memoh/internal/workspace/bridge" +) + +// waitDrain polls DrainNotifications until the expected count is reached or timeout. +func waitDrain(t *testing.T, mgr *Manager, botID, sessionID string, wantCount int) []Notification { + t.Helper() + deadline := time.After(5 * time.Second) + var all []Notification + for { + all = append(all, mgr.DrainNotifications(botID, sessionID)...) 
+ if len(all) >= wantCount { + return all + } + select { + case <-deadline: + t.Fatalf("timed out waiting for %d notifications, got %d", wantCount, len(all)) + case <-time.After(10 * time.Millisecond): + } + } +} + +func TestSpawnAndNotify(t *testing.T) { + mgr := New(nil) + + called := make(chan struct{}) + execFn := func(_ context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + close(called) + return &bridge.ExecResult{Stdout: "hello world\n", ExitCode: 0}, nil + } + + taskID, outputFile := mgr.Spawn(context.Background(), "bot1", "sess1", "echo hello", "/data", "test echo", execFn, nil, nil) + + if taskID == "" { + t.Fatal("expected non-empty task ID") + } + if outputFile == "" { + t.Fatal("expected non-empty output file") + } + + // Wait for exec to be called. + select { + case <-called: + case <-time.After(5 * time.Second): + t.Fatal("execFn was not called within timeout") + } + + // Wait for notification. + notifications := waitDrain(t, mgr, "bot1", "sess1", 1) + n := notifications[0] + if n.TaskID != taskID { + t.Errorf("expected task ID %s, got %s", taskID, n.TaskID) + } + if n.Status != TaskCompleted { + t.Errorf("expected status completed, got %s", n.Status) + } + if n.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", n.ExitCode) + } + if n.BotID != "bot1" || n.SessionID != "sess1" { + t.Errorf("unexpected bot/session: %s/%s", n.BotID, n.SessionID) + } + + // Verify task state. 
+ task := mgr.Get(taskID) + if task == nil { + t.Fatal("task not found after completion") + } + if task.Status != TaskCompleted { + t.Errorf("expected task status completed, got %s", task.Status) + } +} + +func TestSpawnFailedCommand(t *testing.T) { + mgr := New(nil) + + execFn := func(_ context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + return &bridge.ExecResult{ + Stdout: "some output\n", + Stderr: "error: not found\n", + ExitCode: 1, + }, nil + } + + taskID, _ := mgr.Spawn(context.Background(), "bot1", "sess1", "false", "/data", "failing cmd", execFn, nil, nil) + + notifications := waitDrain(t, mgr, "bot1", "sess1", 1) + n := notifications[0] + if n.TaskID != taskID { + t.Errorf("expected task ID %s, got %s", taskID, n.TaskID) + } + if n.Status != TaskFailed { + t.Errorf("expected status failed, got %s", n.Status) + } + if n.ExitCode != 1 { + t.Errorf("expected exit code 1, got %d", n.ExitCode) + } +} + +func TestKillTask(t *testing.T) { + mgr := New(nil) + + started := make(chan struct{}) + execFn := func(ctx context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + close(started) + <-ctx.Done() + return &bridge.ExecResult{ExitCode: -1}, ctx.Err() + } + + taskID, _ := mgr.Spawn(context.Background(), "bot1", "sess1", "sleep 300", "/data", "long task", execFn, nil, nil) + + // Wait for the task to start. + select { + case <-started: + case <-time.After(5 * time.Second): + t.Fatal("task did not start within timeout") + } + + if err := mgr.Kill(taskID); err != nil { + t.Fatalf("kill failed: %v", err) + } + + task := mgr.Get(taskID) + if task == nil { + t.Fatal("task not found") + } + if task.Status != TaskKilled { + t.Errorf("expected status killed, got %s", task.Status) + } + + // Killed tasks should not produce notifications. 
+ time.Sleep(50 * time.Millisecond) // give goroutine time to finish + notifications := mgr.DrainNotifications("bot1", "sess1") + if len(notifications) != 0 { + t.Errorf("expected no notifications for killed task, got %d", len(notifications)) + } +} + +func TestGetForSession(t *testing.T) { + mgr := New(nil) + + taskID, _ := mgr.Spawn(context.Background(), "bot1", "sess1", "echo hello", "/data", "", func(_ context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + return &bridge.ExecResult{Stdout: "hello\n", ExitCode: 0}, nil + }, nil, nil) + + if task := mgr.GetForSession("bot1", "sess1", taskID); task == nil { + t.Fatal("expected task to be visible within the owning session") + } + if task := mgr.GetForSession("bot1", "sess2", taskID); task != nil { + t.Fatal("expected task to be hidden from other sessions") + } +} + +func TestKillForSession(t *testing.T) { + mgr := New(nil) + + started := make(chan struct{}) + execFn := func(ctx context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + close(started) + <-ctx.Done() + return &bridge.ExecResult{ExitCode: -1}, ctx.Err() + } + + taskID, _ := mgr.Spawn(context.Background(), "bot1", "sess1", "sleep 300", "/data", "long task", execFn, nil, nil) + select { + case <-started: + case <-time.After(5 * time.Second): + t.Fatal("task did not start within timeout") + } + + if err := mgr.KillForSession("bot1", "sess2", taskID); err == nil { + t.Fatal("expected kill from another session to fail") + } + if err := mgr.KillForSession("bot1", "sess1", taskID); err != nil { + t.Fatalf("expected kill from owning session to succeed: %v", err) + } +} + +func TestListForSession(t *testing.T) { + mgr := New(nil) + + started := make(chan struct{}, 2) + execFn := func(ctx context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + started <- struct{}{} + <-ctx.Done() + return &bridge.ExecResult{ExitCode: -1}, ctx.Err() + } + + mgr.Spawn(context.Background(), "bot1", "sess1", "cmd1", "/data", "d1", execFn, 
nil, nil) + mgr.Spawn(context.Background(), "bot1", "sess1", "cmd2", "/data", "d2", execFn, nil, nil) + mgr.Spawn(context.Background(), "bot2", "sess2", "cmd3", "/data", "d3", execFn, nil, nil) + + // Wait for all to start. + for range 3 { + <-started + } + + tasks := mgr.ListForSession("bot1", "sess1") + if len(tasks) != 2 { + t.Errorf("expected 2 tasks for bot1/sess1, got %d", len(tasks)) + } + + tasks = mgr.ListForSession("bot2", "sess2") + if len(tasks) != 1 { + t.Errorf("expected 1 task for bot2/sess2, got %d", len(tasks)) + } +} + +func TestDrainNotifications(t *testing.T) { + mgr := New(nil) + + done := make(chan struct{}, 3) + execFn := func(_ context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + defer func() { done <- struct{}{} }() + return &bridge.ExecResult{Stdout: "ok\n", ExitCode: 0}, nil + } + + mgr.Spawn(context.Background(), "bot1", "sess1", "echo 1", "/data", "", execFn, nil, nil) + mgr.Spawn(context.Background(), "bot1", "sess2", "echo 2", "/data", "", execFn, nil, nil) + mgr.Spawn(context.Background(), "bot2", "sess1", "echo 3", "/data", "", execFn, nil, nil) + + // Wait for all tasks to complete. + for range 3 { + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("task did not complete within timeout") + } + } + + // Drain only bot1/sess1. + notifications := waitDrain(t, mgr, "bot1", "sess1", 1) + if len(notifications) != 1 { + t.Errorf("expected 1 notification for bot1/sess1, got %d", len(notifications)) + } + + // The other two should still be pending. + notifications = waitDrain(t, mgr, "bot1", "sess2", 1) + if len(notifications) != 1 { + t.Errorf("expected 1 notification for bot1/sess2, got %d", len(notifications)) + } +} + +func TestMarkNotifiedPreventsDoubleNotification(t *testing.T) { + mgr := New(nil) + + // Simulate two goroutines racing to complete/notify. 
+ execFn := func(_ context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + return &bridge.ExecResult{Stdout: "ok\n", ExitCode: 0}, nil + } + + taskID, _ := mgr.Spawn(context.Background(), "bot1", "sess1", "echo hi", "/data", "", execFn, nil, nil) + notifications := waitDrain(t, mgr, "bot1", "sess1", 1) + if len(notifications) != 1 { + t.Fatalf("expected exactly 1 notification, got %d", len(notifications)) + } + if notifications[0].TaskID != taskID { + t.Errorf("unexpected task ID: %s", notifications[0].TaskID) + } + + // Calling MarkNotified again should return false (already notified). + task := mgr.Get(taskID) + if task.MarkNotified() { + t.Error("MarkNotified should return false on second call") + } + + // No additional notifications should appear. + extra := mgr.DrainNotifications("bot1", "sess1") + if len(extra) != 0 { + t.Errorf("expected no extra notifications, got %d", len(extra)) + } +} + +func TestStalledNotificationFormat(t *testing.T) { + n := Notification{ + TaskID: "bg_test_2", + Status: TaskRunning, + Command: "apt install -q foo", + OutputFile: "/tmp/memoh-bg/bg_test_2.log", + OutputTail: "Do you want to continue? [Y/n]", + Duration: 50 * time.Second, + Stalled: true, + } + + text := n.FormatForAgent() + for _, want := range []string{ + "stalled", + "", + "non-interactive", + } { + if !strings.Contains(text, want) { + t.Errorf("stalled notification missing %q:\n%s", want, text) + } + } + // Stalled notifications should NOT have exit-code. 
+ if strings.Contains(text, "exit-code") { + t.Errorf("stalled notification should not contain exit-code:\n%s", text) + } +} + +func TestRunningTasksSummary(t *testing.T) { + mgr := New(nil) + + started := make(chan struct{}) + execFn := func(ctx context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + close(started) + <-ctx.Done() + return &bridge.ExecResult{ExitCode: -1}, ctx.Err() + } + + mgr.Spawn(context.Background(), "bot1", "sess1", "npm test", "/data", "Run tests", execFn, nil, nil) + <-started + + summary := mgr.RunningTasksSummary("bot1", "sess1") + if !strings.Contains(summary, "Run tests") { + t.Errorf("summary should mention description, got: %s", summary) + } + if !strings.Contains(summary, "Currently running background tasks:") { + t.Errorf("summary should have header, got: %s", summary) + } + + // No tasks for other session. + if s := mgr.RunningTasksSummary("bot2", "sess2"); s != "" { + t.Errorf("expected empty summary for other session, got: %s", s) + } +} + +func TestCleanupRemovesOnlyOldCompletedTasks(t *testing.T) { + mgr := New(nil) + now := time.Now() + + mgr.tasks["old_done"] = &Task{ + ID: "old_done", + BotID: "bot1", + SessionID: "sess1", + Status: TaskCompleted, + CompletedAt: now.Add(-2 * time.Hour), + } + mgr.tasks["recent_done"] = &Task{ + ID: "recent_done", + BotID: "bot1", + SessionID: "sess1", + Status: TaskCompleted, + CompletedAt: now.Add(-10 * time.Minute), + } + mgr.tasks["running"] = &Task{ + ID: "running", + BotID: "bot1", + SessionID: "sess1", + Status: TaskRunning, + StartedAt: now.Add(-2 * time.Hour), + } + + mgr.Cleanup(time.Hour) + + if mgr.Get("old_done") != nil { + t.Fatal("expected old completed task to be cleaned up") + } + if mgr.Get("recent_done") == nil { + t.Fatal("expected recent completed task to be retained") + } + if mgr.Get("running") == nil { + t.Fatal("expected running task to be retained") + } +} + +func TestSpawnUsesRestartSafeTaskIDs(t *testing.T) { + mgr1 := New(nil) + mgr2 := New(nil) + + 
execFn := func(_ context.Context, _, _ string, _ int32) (*bridge.ExecResult, error) { + return &bridge.ExecResult{Stdout: "ok\n", ExitCode: 0}, nil + } + + taskID1, outputFile1 := mgr1.Spawn(context.Background(), "bot123456789", "sess1", "echo one", "/data", "", execFn, nil, nil) + taskID2, outputFile2 := mgr2.Spawn(context.Background(), "bot123456789", "sess1", "echo two", "/data", "", execFn, nil, nil) + + if taskID1 == taskID2 { + t.Fatalf("expected distinct task IDs across fresh managers, got %q", taskID1) + } + if outputFile1 == outputFile2 { + t.Fatalf("expected distinct output files across fresh managers, got %q", outputFile1) + } + if !strings.HasPrefix(taskID1, "bg_bot12345_") { + t.Fatalf("unexpected task ID format: %q", taskID1) + } + if !strings.HasPrefix(taskID2, "bg_bot12345_") { + t.Fatalf("unexpected task ID format: %q", taskID2) + } +} + +func TestNotificationFormat(t *testing.T) { + n := Notification{ + TaskID: "bg_test_1", + Status: TaskCompleted, + Command: "npm install", + Description: "Install dependencies", + ExitCode: 0, + OutputFile: "/tmp/memoh-bg/bg_test_1.log", + OutputTail: "added 1337 packages\n", + Duration: 45 * time.Second, + } + + text := n.FormatForAgent() + if text == "" { + t.Fatal("expected non-empty notification text") + } + for _, want := range []string{ + "", + "bg_test_1", + "completed", + "npm install", + "Install dependencies", + "/tmp/memoh-bg/bg_test_1.log", + "added 1337 packages", + "", + } { + if !strings.Contains(text, want) { + t.Errorf("notification text missing %q:\n%s", want, text) + } + } +} diff --git a/internal/agent/background/types.go b/internal/agent/background/types.go new file mode 100644 index 00000000..5cd6a335 --- /dev/null +++ b/internal/agent/background/types.go @@ -0,0 +1,159 @@ +package background + +import ( + "context" + "fmt" + "strings" + "sync" + "time" +) + +// TaskStatus represents the lifecycle state of a background task. 
+type TaskStatus string + +const ( + TaskRunning TaskStatus = "running" + TaskCompleted TaskStatus = "completed" + TaskFailed TaskStatus = "failed" + TaskKilled TaskStatus = "killed" +) + +// Task represents a single background command execution. +type Task struct { + ID string + BotID string + SessionID string + Command string + Description string + WorkDir string + Status TaskStatus + ExitCode int32 + OutputFile string // path inside container where output is being written + StartedAt time.Time + CompletedAt time.Time + + mu sync.Mutex + cancel context.CancelFunc + notified bool // true once a notification has been enqueued; prevents duplicates + output strings.Builder // buffered output tail +} + +// MarkNotified atomically sets the notified flag. Returns true if this call +// was the one that flipped it (i.e., the caller should enqueue the notification). +func (t *Task) MarkNotified() bool { + t.mu.Lock() + defer t.mu.Unlock() + if t.notified { + return false + } + t.notified = true + return true +} + +// Cancel requests cancellation of the task's context. +func (t *Task) Cancel() { + t.mu.Lock() + defer t.mu.Unlock() + if t.cancel != nil { + t.cancel() + } +} + +// AppendOutput appends text to the buffered output tail. +// Only the last maxTailBytes are kept. +func (t *Task) AppendOutput(s string) { + t.mu.Lock() + defer t.mu.Unlock() + t.output.WriteString(s) + // Keep tail bounded + if t.output.Len() > maxTailBytes*2 { + tail := t.output.String() + t.output.Reset() + if len(tail) > maxTailBytes { + t.output.WriteString(tail[len(tail)-maxTailBytes:]) + } else { + t.output.WriteString(tail) + } + } +} + +// OutputTail returns the last portion of collected output. 
+func (t *Task) OutputTail() string { + t.mu.Lock() + defer t.mu.Unlock() + s := t.output.String() + if len(s) > maxTailBytes { + return s[len(s)-maxTailBytes:] + } + return s +} + +const maxTailBytes = 4096 + +// AdoptResult carries the outcome of a command whose execution was started +// externally (e.g. via ExecStream) and then handed off to the Manager. +type AdoptResult struct { + Stdout string + Stderr string + ExitCode int32 + Err error +} + +// Notification is the structured event sent to the agent when a background +// task reaches a terminal state or requires attention (e.g. stalled). +type Notification struct { + TaskID string + BotID string + SessionID string + Status TaskStatus + Command string + Description string + ExitCode int32 + OutputFile string + OutputTail string // last N bytes of output for quick summary + Duration time.Duration + Stalled bool // true when task appears stuck on interactive input +} + +// MessageText returns the full user-message text that should be injected into +// the agent's message stream — a human lead-in line followed by the +// block. +func (n Notification) MessageText() string { + lead := "A background task completed:" + if n.Stalled { + lead = "A background task appears stuck and may need attention:" + } + return lead + "\n" + n.FormatForAgent() +} + +// FormatForAgent returns a human-readable task-notification block that can be +// injected into the agent's message stream. 
+func (n Notification) FormatForAgent() string { + var b strings.Builder + fmt.Fprintf(&b, "\n") + fmt.Fprintf(&b, " %s\n", n.TaskID) + if n.Stalled { + fmt.Fprintf(&b, " stalled\n") + } else { + fmt.Fprintf(&b, " %s\n", n.Status) + } + fmt.Fprintf(&b, " %s\n", n.Command) + if n.Description != "" { + fmt.Fprintf(&b, " %s\n", n.Description) + } + if !n.Stalled { + fmt.Fprintf(&b, " %d\n", n.ExitCode) + } + fmt.Fprintf(&b, " %s\n", n.Duration.Round(time.Millisecond)) + if n.OutputFile != "" { + fmt.Fprintf(&b, " %s\n", n.OutputFile) + } + if n.OutputTail != "" { + fmt.Fprintf(&b, " \n%s\n \n", strings.TrimRight(n.OutputTail, "\n")) + } + if n.Stalled { + fmt.Fprintf(&b, " This command appears to be waiting for interactive input. Kill it with bg_status and retry with a non-interactive flag (e.g. -y, --yes, --non-interactive).\n") + } + fmt.Fprintf(&b, "") + return b.String() +} diff --git a/internal/agent/background_exec_e2e_test.go b/internal/agent/background_exec_e2e_test.go new file mode 100644 index 00000000..1073eb98 --- /dev/null +++ b/internal/agent/background_exec_e2e_test.go @@ -0,0 +1,580 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "net" + "strings" + "sync" + "testing" + "time" + + sdk "github.com/memohai/twilight-ai/sdk" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + "github.com/memohai/memoh/internal/agent/background" + agenttools "github.com/memohai/memoh/internal/agent/tools" + "github.com/memohai/memoh/internal/workspace/bridge" + pb "github.com/memohai/memoh/internal/workspace/bridgepb" +) + +// --------------------------------------------------------------------------- +// Mock container service with controllable Exec behavior +// --------------------------------------------------------------------------- + +type execBehavior struct { + stdout string + stderr string + exitCode int32 + delay time.Duration // how long before sending output +} + +type 
mockExecContainerService struct { + pb.UnimplementedContainerServiceServer + + mu sync.Mutex + behaviors map[string]execBehavior // command prefix -> behavior + written map[string][]byte // path -> content (WriteFile) +} + +func newMockExecContainerService() *mockExecContainerService { + return &mockExecContainerService{ + behaviors: make(map[string]execBehavior), + written: make(map[string][]byte), + } +} + +func (s *mockExecContainerService) setBehavior(cmdPrefix string, b execBehavior) { + s.mu.Lock() + defer s.mu.Unlock() + s.behaviors[cmdPrefix] = b +} + +func (s *mockExecContainerService) findBehavior(cmd string) (execBehavior, bool) { + s.mu.Lock() + defer s.mu.Unlock() + for prefix, b := range s.behaviors { + if strings.Contains(cmd, prefix) { + return b, true + } + } + return execBehavior{}, false +} + +func (s *mockExecContainerService) Exec(stream pb.ContainerService_ExecServer) error { + // Read config message. + input, err := stream.Recv() + if err != nil { + return err + } + cmd := input.GetCommand() + + b, ok := s.findBehavior(cmd) + if !ok { + // Default: instant success with echoed command. 
+ b = execBehavior{stdout: fmt.Sprintf("[executed] %s\n", cmd), exitCode: 0} + } + + if b.delay > 0 { + select { + case <-time.After(b.delay): + case <-stream.Context().Done(): + return stream.Context().Err() + } + } + + if b.stdout != "" { + if err := stream.Send(&pb.ExecOutput{ + Stream: pb.ExecOutput_STDOUT, + Data: []byte(b.stdout), + }); err != nil { + return err + } + } + if b.stderr != "" { + if err := stream.Send(&pb.ExecOutput{ + Stream: pb.ExecOutput_STDERR, + Data: []byte(b.stderr), + }); err != nil { + return err + } + } + return stream.Send(&pb.ExecOutput{ + Stream: pb.ExecOutput_EXIT, + ExitCode: b.exitCode, + }) +} + +func (s *mockExecContainerService) WriteFile(_ context.Context, req *pb.WriteFileRequest) (*pb.WriteFileResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.written[req.GetPath()] = req.GetContent() + return &pb.WriteFileResponse{}, nil +} + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +func setupExecTestInfra(t *testing.T, svc *mockExecContainerService) (bridge.Provider, func()) { + t.Helper() + + lis := bufconn.Listen(1 << 20) + srv := grpc.NewServer() + pb.RegisterContainerServiceServer(srv, svc) + + done := make(chan struct{}) + go func() { + defer close(done) + _ = srv.Serve(lis) + }() + + dialer := func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + } + conn, err := grpc.NewClient( + "passthrough://bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("grpc.NewClient: %v", err) + } + + cleanup := func() { + _ = conn.Close() + srv.Stop() + <-done + } + + bp := &agentReadMediaBridgeProvider{client: bridge.NewClientFromConn(conn)} + return bp, cleanup +} + +// --------------------------------------------------------------------------- +// E2E Test: Explicit background exec +// 
--------------------------------------------------------------------------- + +func TestE2E_ExplicitBackgroundExec(t *testing.T) { + t.Parallel() + + svc := newMockExecContainerService() + svc.setBehavior("npm install", execBehavior{ + stdout: "added 42 packages\n", + exitCode: 0, + delay: 100 * time.Millisecond, // simulate some work + }) + + bp, cleanup := setupExecTestInfra(t, svc) + defer cleanup() + + bgMgr := background.New(nil) + + // Model calls exec with run_in_background, then on step 2 sees notification. + var step2Params sdk.GenerateParams + modelProvider := &agentReadMediaMockProvider{ + handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + switch call { + case 1: + // Model decides to run npm install in background. + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-1", + ToolName: "exec", + Input: map[string]any{ + "command": "npm install", + "run_in_background": true, + "description": "Install dependencies", + }, + }}, + }, nil + case 2: + // Model sees tool result with background_started. + // It should do something else or reply. + // Simulate waiting a bit so the background task has time to complete. + time.Sleep(300 * time.Millisecond) + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-2", + ToolName: "exec", + Input: map[string]any{ + "command": "echo hello", + }, + }}, + }, nil + case 3: + // Step 3: model should see the background notification injected. 
+ step2Params = params + return &sdk.GenerateResult{ + Text: "All done!", + FinishReason: sdk.FinishReasonStop, + }, nil + default: + return &sdk.GenerateResult{ + Text: "unexpected", + FinishReason: sdk.FinishReasonStop, + }, nil + } + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + agenttools.NewContainerProvider(nil, bp, bgMgr, "/data"), + }) + + result, err := a.Generate(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("install deps and say hi")}, + System: "You are a helpful bot.", + SupportsToolCall: true, + Identity: SessionContext{BotID: "bot-test-1", SessionID: "sess-1"}, + BackgroundManager: bgMgr, + }) + if err != nil { + t.Fatalf("Generate error: %v", err) + } + if result.Text != "All done!" { + t.Errorf("unexpected text: %q", result.Text) + } + + // Verify step 2 params contain the background notification. + found := false + for _, msg := range step2Params.Messages { + if msg.Role == sdk.MessageRoleUser { + for _, part := range msg.Content { + if tp, ok := part.(sdk.TextPart); ok { + if strings.Contains(tp.Text, "task-notification") && + strings.Contains(tp.Text, "completed") { + found = true + } + } + } + } + } + if !found { + t.Error("expected background task notification to be injected into step 3 messages") + } +} + +// --------------------------------------------------------------------------- +// E2E Test: Foreground timeout flips to background +// --------------------------------------------------------------------------- + +func TestE2E_ForegroundTimeoutFlip(t *testing.T) { + t.Parallel() + + svc := newMockExecContainerService() + // Command takes 3 seconds — longer than our 1-second soft timeout. 
+ svc.setBehavior("slow-build", execBehavior{ + stdout: "build completed successfully\n", + exitCode: 0, + delay: 3 * time.Second, + }) + + bp, cleanup := setupExecTestInfra(t, svc) + defer cleanup() + + bgMgr := background.New(nil) + + var toolResult map[string]any + modelProvider := &agentReadMediaMockProvider{ + handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + switch call { + case 1: + // Model runs a command with short timeout (will flip). + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-1", + ToolName: "exec", + Input: map[string]any{ + "command": "slow-build", + "timeout": 1, // 1 second — will flip + "description": "Run slow build", + }, + }}, + }, nil + case 2: + // Extract the tool result from step 1. + toolResult = extractToolResult(t, params, "call-1") + return &sdk.GenerateResult{ + Text: "Build moved to background.", + FinishReason: sdk.FinishReasonStop, + }, nil + default: + return &sdk.GenerateResult{ + Text: "unexpected", + FinishReason: sdk.FinishReasonStop, + }, nil + } + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + agenttools.NewContainerProvider(nil, bp, bgMgr, "/data"), + }) + + result, err := a.Generate(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("run the build")}, + System: "You are a helpful bot.", + SupportsToolCall: true, + Identity: SessionContext{BotID: "bot-test-2", SessionID: "sess-2"}, + BackgroundManager: bgMgr, + }) + if err != nil { + t.Fatalf("Generate error: %v", err) + } + if result.Text != "Build moved to background." { + t.Errorf("unexpected text: %q", result.Text) + } + + // The tool result should indicate auto_backgrounded. 
+ if toolResult == nil { + t.Fatal("tool result not captured") + } + status, _ := toolResult["status"].(string) + if status != "auto_backgrounded" { + t.Errorf("expected status auto_backgrounded, got %q", status) + } + taskID, _ := toolResult["task_id"].(string) + if taskID == "" { + t.Error("expected non-empty task_id") + } + msg, _ := toolResult["message"].(string) + if !strings.Contains(msg, "no work was lost") { + t.Errorf("expected flip message mentioning no work lost, got %q", msg) + } + + // Wait for the background task to complete and verify notification. + deadline := time.After(10 * time.Second) + for { + notifications := bgMgr.DrainNotifications("bot-test-2", "sess-2") + if len(notifications) > 0 { + n := notifications[0] + if n.Status != background.TaskCompleted { + t.Errorf("expected completed, got %s", n.Status) + } + if !strings.Contains(n.OutputTail, "build completed") { + t.Errorf("expected build output in tail, got %q", n.OutputTail) + } + break + } + select { + case <-deadline: + t.Fatal("timed out waiting for background task notification") + case <-time.After(50 * time.Millisecond): + } + } +} + +// --------------------------------------------------------------------------- +// E2E Test: Sleep command rejection +// --------------------------------------------------------------------------- + +func TestE2E_SleepRejection(t *testing.T) { + t.Parallel() + + svc := newMockExecContainerService() + bp, cleanup := setupExecTestInfra(t, svc) + defer cleanup() + + bgMgr := background.New(nil) + + var sleepToolResult map[string]any + var sleepWasError bool + modelProvider := &agentReadMediaMockProvider{ + handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + switch call { + case 1: + // Model tries to sleep 10. 
+ return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-1", + ToolName: "exec", + Input: map[string]any{ + "command": "sleep 10", + }, + }}, + }, nil + case 2: + // Check the tool result — should be an error. + sleepToolResult, sleepWasError = extractToolResultWithError(params, "call-1") + return &sdk.GenerateResult{ + Text: "Got it, won't sleep.", + FinishReason: sdk.FinishReasonStop, + }, nil + default: + return &sdk.GenerateResult{ + Text: "unexpected", + FinishReason: sdk.FinishReasonStop, + }, nil + } + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + agenttools.NewContainerProvider(nil, bp, bgMgr, "/data"), + }) + + result, err := a.Generate(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("wait 10 seconds")}, + System: "You are a bot.", + SupportsToolCall: true, + Identity: SessionContext{BotID: "bot-test-3", SessionID: "sess-3"}, + BackgroundManager: bgMgr, + }) + if err != nil { + t.Fatalf("Generate error: %v", err) + } + if result.Text != "Got it, won't sleep." { + t.Errorf("unexpected text: %q", result.Text) + } + + if !sleepWasError { + t.Error("expected sleep command to return is_error=true") + } + _ = sleepToolResult // the error message is in the tool result +} + +// --------------------------------------------------------------------------- +// E2E Test: Running tasks summary injection +// --------------------------------------------------------------------------- + +func TestE2E_RunningTasksSummaryInjected(t *testing.T) { + t.Parallel() + + svc := newMockExecContainerService() + // Long-running task that won't complete during the test. 
+ svc.setBehavior("long-task", execBehavior{ + delay: 30 * time.Second, + }) + + bp, cleanup := setupExecTestInfra(t, svc) + defer cleanup() + + bgMgr := background.New(nil) + + var step3System string + modelProvider := &agentReadMediaMockProvider{ + handler: func(call int, params sdk.GenerateParams) (*sdk.GenerateResult, error) { + switch call { + case 1: + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-1", + ToolName: "exec", + Input: map[string]any{ + "command": "long-task", + "run_in_background": true, + "description": "Long running task", + }, + }}, + }, nil + case 2: + // Do another tool call so prepareStep fires again. + return &sdk.GenerateResult{ + FinishReason: sdk.FinishReasonToolCalls, + ToolCalls: []sdk.ToolCall{{ + ToolCallID: "call-2", + ToolName: "exec", + Input: map[string]any{ + "command": "echo check", + }, + }}, + }, nil + case 3: + // Capture the system prompt which should include running tasks. 
+ step3System = params.System + return &sdk.GenerateResult{ + Text: "Done checking.", + FinishReason: sdk.FinishReasonStop, + }, nil + default: + return &sdk.GenerateResult{ + Text: "unexpected", + FinishReason: sdk.FinishReasonStop, + }, nil + } + }, + } + + a := New(Deps{}) + a.SetToolProviders([]agenttools.ToolProvider{ + agenttools.NewContainerProvider(nil, bp, bgMgr, "/data"), + }) + + _, err := a.Generate(context.Background(), RunConfig{ + Model: &sdk.Model{ID: "mock", Provider: modelProvider}, + Messages: []sdk.Message{sdk.UserMessage("start background and check")}, + System: "You are a bot.", + SupportsToolCall: true, + Identity: SessionContext{BotID: "bot-test-4", SessionID: "sess-4"}, + BackgroundManager: bgMgr, + }) + if err != nil { + t.Fatalf("Generate error: %v", err) + } + + if !strings.Contains(step3System, "Currently running background tasks:") { + t.Error("expected running tasks summary in system prompt") + } + if !strings.Contains(step3System, "Long running task") { + t.Errorf("expected task description in system prompt, got: %s", step3System) + } +} + +// --------------------------------------------------------------------------- +// Helpers for extracting tool results from params +// --------------------------------------------------------------------------- + +func extractToolResult(t *testing.T, params sdk.GenerateParams, toolCallID string) map[string]any { + t.Helper() + for _, msg := range params.Messages { + if msg.Role != sdk.MessageRoleTool { + continue + } + for _, part := range msg.Content { + tr, ok := part.(sdk.ToolResultPart) + if !ok || tr.ToolCallID != toolCallID { + continue + } + raw, _ := json.Marshal(tr.Result) + var m map[string]any + _ = json.Unmarshal(raw, &m) + return m + } + } + t.Fatalf("tool result for %s not found in params", toolCallID) + return nil +} + +func extractToolResultWithError(params sdk.GenerateParams, toolCallID string) (map[string]any, bool) { + for _, msg := range params.Messages { + if msg.Role != 
sdk.MessageRoleTool { + continue + } + for _, part := range msg.Content { + tr, ok := part.(sdk.ToolResultPart) + if !ok || tr.ToolCallID != toolCallID { + continue + } + raw, _ := json.Marshal(tr.Result) + var m map[string]any + _ = json.Unmarshal(raw, &m) + return m, tr.IsError + } + } + return nil, false +} diff --git a/internal/agent/read_media_test.go b/internal/agent/read_media_test.go index 2c7b9878..49cd0cb2 100644 --- a/internal/agent/read_media_test.go +++ b/internal/agent/read_media_test.go @@ -260,7 +260,7 @@ func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) { a := New(Deps{}) a.SetToolProviders([]agenttools.ToolProvider{ - agenttools.NewContainerProvider(nil, bp, "/data"), + agenttools.NewContainerProvider(nil, bp, nil, "/data"), }) result, err := a.Generate(context.Background(), RunConfig{ @@ -340,7 +340,7 @@ func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing. a := New(Deps{}) a.SetToolProviders([]agenttools.ToolProvider{ - agenttools.NewContainerProvider(nil, bp, "/data"), + agenttools.NewContainerProvider(nil, bp, nil, "/data"), }) _, err := a.Generate(context.Background(), RunConfig{ @@ -390,7 +390,7 @@ func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing. 
a := New(Deps{}) a.SetToolProviders([]agenttools.ToolProvider{ - agenttools.NewContainerProvider(nil, bp, "/data"), + agenttools.NewContainerProvider(nil, bp, nil, "/data"), }) var terminal StreamEvent diff --git a/internal/agent/tools/container.go b/internal/agent/tools/container.go index b5fc76b0..a048a00e 100644 --- a/internal/agent/tools/container.go +++ b/internal/agent/tools/container.go @@ -8,14 +8,22 @@ import ( "io" "log/slog" "math" + "regexp" + "strconv" "strings" "time" sdk "github.com/memohai/twilight-ai/sdk" + "github.com/memohai/memoh/internal/agent/background" "github.com/memohai/memoh/internal/workspace/bridge" + pb "github.com/memohai/memoh/internal/workspace/bridgepb" ) +// blockedSleepPattern matches standalone `sleep N` where N >= 2. +// Does not match sleep inside pipelines, subshells, or scripts. +var blockedSleepPattern = regexp.MustCompile(`^sleep\s+(\d+(?:\.\d+)?)(?:\s*[;&]|$)`) + const defaultContainerExecWorkDir = "/data" // containerOpTimeout is the maximum time allowed for individual file @@ -29,11 +37,12 @@ const largeFileThreshold = 512 * 1024 // 512 KB type ContainerProvider struct { clients bridge.Provider + bgManager *background.Manager execWorkDir string logger *slog.Logger } -func NewContainerProvider(log *slog.Logger, clients bridge.Provider, execWorkDir string) *ContainerProvider { +func NewContainerProvider(log *slog.Logger, clients bridge.Provider, bgManager *background.Manager, execWorkDir string) *ContainerProvider { if log == nil { log = slog.Default() } @@ -41,7 +50,7 @@ func NewContainerProvider(log *slog.Logger, clients bridge.Provider, execWorkDir if wd == "" { wd = defaultContainerExecWorkDir } - return &ContainerProvider{clients: clients, execWorkDir: wd, logger: log.With(slog.String("tool", "container"))} + return &ContainerProvider{clients: clients, bgManager: bgManager, execWorkDir: wd, logger: log.With(slog.String("tool", "container"))} } func (p *ContainerProvider) Tools(_ context.Context, session 
SessionContext) ([]sdk.Tool, error) { @@ -119,13 +128,28 @@ func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([] }, }, { - Name: "exec", - Description: fmt.Sprintf("Execute a command in the bot container. Runs in the bot's data directory (%s) by default.", wd), + Name: "exec", + Description: fmt.Sprintf(`Execute a shell command in the bot container. Runs in the bot's data directory (%s) by default. + +# Instructions +- Use this tool to run shell commands for installing packages, running scripts, building code, running tests, and other system operations. +- If your command will take a long time (package installs, builds, test suites), set run_in_background to true. You will be notified when it completes. You do not need to add '&' at the end of the command when using this parameter. +- If waiting for a background task, you will be notified when it completes — do NOT poll or sleep. +- You may specify a custom timeout (up to %d seconds) for commands you know will take longer than the default %d seconds. If a foreground command times out, it will be automatically moved to the background and you will be notified when it completes. +- Avoid unnecessary sleep commands: + - Do not sleep between commands that can run immediately — just run them. + - If your command is long running, use run_in_background. No sleep needed. + - Do not retry failing commands in a sleep loop — diagnose the root cause. + - If waiting for a background task, you will be notified when it completes automatically. + - sleep N (N >= 2) in foreground is blocked. If you genuinely need a short delay, keep it under 2 seconds.`, wd, background.MaxExecTimeout, background.DefaultExecTimeout), Parameters: map[string]any{ "type": "object", "properties": map[string]any{ - "command": map[string]any{"type": "string", "description": "Shell command to run (e.g. 
ls -la, cat file.txt)"}, - "work_dir": map[string]any{"type": "string", "description": fmt.Sprintf("Working directory inside the container (default: %s)", wd)}, + "command": map[string]any{"type": "string", "description": "Shell command to run (e.g. ls -la, npm install, python script.py)"}, + "work_dir": map[string]any{"type": "string", "description": fmt.Sprintf("Working directory inside the container (default: %s)", wd)}, + "description": map[string]any{"type": "string", "description": `Clear, concise description of what this command does in active voice. For simple commands keep it brief (5-10 words): ls -la → "List files with details". For complex commands add enough context: curl -s url | jq '.data[]' → "Fetch JSON and extract data array".`}, + "timeout": map[string]any{"type": "integer", "description": fmt.Sprintf("Timeout in seconds (default: %d, max: %d). Only applies to foreground execution. Commands that exceed this timeout are automatically moved to background.", background.DefaultExecTimeout, background.MaxExecTimeout), "minimum": 1, "maximum": background.MaxExecTimeout}, + "run_in_background": map[string]any{"type": "boolean", "description": "If true, run the command in the background. Returns immediately with a task ID. You will be notified when it completes. Use for long-running commands (installs, builds, test suites). 
You do not need to use '&' at the end of the command."}, }, "required": []string{"command"}, }, @@ -133,6 +157,21 @@ func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([] return p.execExec(ctx.Context, sess, inputAsMap(input)) }, }, + { + Name: "bg_status", + Description: "Check the status of background tasks or kill a running one.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{"type": "string", "enum": []string{"list", "status", "kill"}, "description": "Action to perform: list all tasks, get status of one task, or kill a running task"}, + "task_id": map[string]any{"type": "string", "description": "Task ID (required for status and kill actions)"}, + }, + "required": []string{"action"}, + }, + Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) { + return p.execBgStatus(ctx.Context, sess, inputAsMap(input)) + }, + }, }, nil } @@ -421,7 +460,44 @@ func (p *ContainerProvider) execExec(ctx context.Context, session SessionContext if workDir == "" { workDir = p.execWorkDir } - result, err := client.Exec(ctx, command, workDir, 30) + description := strings.TrimSpace(StringArg(args, "description")) + + // Parse timeout (default 30s, max 600s). + timeout := background.DefaultExecTimeout + if t, ok, err := IntArg(args, "timeout"); err != nil { + return nil, fmt.Errorf("invalid timeout: %w", err) + } else if ok { + if t < 1 { + return nil, errors.New("timeout must be >= 1") + } + maxTimeout := int(background.MaxExecTimeout) + if t > maxTimeout { + t = maxTimeout + } + timeout = int32(t) //nolint:gosec // bounded above + } + + // Block sleep N (N>=2) in foreground — nudge model toward run_in_background. + runInBg, _, _ := BoolArg(args, "run_in_background") + if !runInBg { + if reason := detectBlockedSleep(command); reason != "" { + return nil, fmt.Errorf("blocked: %s. 
Run blocking commands in the background with run_in_background: true — you'll get a completion notification when done. If you genuinely need a delay (rate limiting, deliberate pacing), keep it under 2 seconds", reason) + } + } + + // Background execution path. + if runInBg && p.bgManager != nil { + return p.execExecBackground(ctx, session, client, command, workDir, description) + } + + // If we have a background manager, use streaming exec so we can flip + // to background on timeout without killing the process. + if p.bgManager != nil { + return p.execExecWithFlip(ctx, session, client, command, workDir, description, timeout) + } + + // Fallback: no background manager, plain synchronous exec. + result, err := client.Exec(ctx, command, workDir, timeout) if err != nil { return nil, err } @@ -430,6 +506,257 @@ func (p *ContainerProvider) execExec(ctx context.Context, session SessionContext return map[string]any{"stdout": stdout, "stderr": stderr, "exit_code": result.ExitCode}, nil } +// execExecWithFlip runs a command via ExecStream with a client-side soft timeout. +// If the command finishes within the timeout, it returns the result normally. +// If the soft timeout fires first, the running stream is handed off to the +// background manager — the process keeps running in the container, and the +// agent gets an immediate "auto_backgrounded" response. +func (p *ContainerProvider) execExecWithFlip( + ctx context.Context, session SessionContext, client *bridge.Client, + command, workDir, description string, softTimeout int32, +) (any, error) { + // Start streaming exec with a large container-side timeout so the process + // keeps running even after we stop reading in the foreground. + // Use a fully independent context (not derived from the agent request ctx) + // so the gRPC stream is never cancelled when the foreground session ends. 
+ streamCtx, streamCancel := context.WithTimeout(context.WithoutCancel(ctx), time.Duration(background.BackgroundExecTimeout)*time.Second) + stream, err := client.ExecStream(streamCtx, command, workDir, background.BackgroundExecTimeout) + if err != nil { + streamCancel() + return nil, err + } + + resultCh := make(chan background.AdoptResult, 1) + go func() { + defer streamCancel() + var stdout, stderr strings.Builder + var exitCode int32 + for { + msg, recvErr := stream.Recv() + if errors.Is(recvErr, io.EOF) { + break + } + if recvErr != nil { + p.logger.Warn("flip-to-background: stream recv error", + slog.String("command", truncateStr(command, 80)), + slog.Any("error", recvErr), + ) + resultCh <- background.AdoptResult{Err: recvErr} + return + } + switch msg.GetStream() { + case pb.ExecOutput_STDOUT: + stdout.Write(msg.GetData()) + case pb.ExecOutput_STDERR: + stderr.Write(msg.GetData()) + case pb.ExecOutput_EXIT: + exitCode = msg.GetExitCode() + } + } + resultCh <- background.AdoptResult{ + Stdout: stdout.String(), + Stderr: stderr.String(), + ExitCode: exitCode, + } + }() + + // Wait for either the result or soft timeout. + timer := time.NewTimer(time.Duration(softTimeout) * time.Second) + defer timer.Stop() + + select { + case r := <-resultCh: + // Command finished within the soft timeout — return normally. + if r.Err != nil { + return nil, r.Err + } + stdout := pruneToolOutputText(r.Stdout, "tool result (exec stdout)") + stderr := pruneToolOutputText(r.Stderr, "tool result (exec stderr)") + return map[string]any{"stdout": stdout, "stderr": stderr, "exit_code": r.ExitCode}, nil + + case <-timer.C: + // Soft timeout fired — flip the running stream to background. + // The container process is still alive; we hand off the stream reader + // goroutine to the background manager. 
+ return p.flipToBackground(ctx, session, client, resultCh, command, workDir, description, softTimeout) + + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// flipToBackground registers the already-running stream as a background task. +// The goroutine reading from the stream continues; its result feeds the task. +func (p *ContainerProvider) flipToBackground( + ctx context.Context, + session SessionContext, client *bridge.Client, + resultCh <-chan background.AdoptResult, + command, workDir, description string, softTimeout int32, +) (any, error) { + writeFn := func(ctx context.Context, path string, data []byte) error { + return client.WriteFile(ctx, path, data) + } + + taskID, outputFile := p.bgManager.SpawnAdopt( + ctx, + session.BotID, session.SessionID, command, workDir, description, + resultCh, writeFn, + ) + + p.logger.Info("foreground exec flipped to background", + slog.String("task_id", taskID), + slog.String("command", truncateStr(command, 120)), + slog.Int("soft_timeout_seconds", int(softTimeout)), + ) + + return map[string]any{ + "status": "auto_backgrounded", + "task_id": taskID, + "output_file": outputFile, + "message": fmt.Sprintf( + "Command exceeded the foreground timeout (%ds) and has been moved to the background with task ID: %s. "+ + "The process is still running — no work was lost. "+ + "You will be notified when it completes. Output is being written to: %s. "+ + "For long-running commands, use run_in_background: true from the start to avoid this delay.", + softTimeout, taskID, outputFile, + ), + }, nil +} + +// detectBlockedSleep checks if the command starts with `sleep N` where N >= 2. +// Returns a human-readable reason string, or "" if the command is allowed. 
+func detectBlockedSleep(command string) string { + cmd := strings.TrimSpace(command) + m := blockedSleepPattern.FindStringSubmatch(cmd) + if m == nil { + return "" + } + seconds, err := strconv.ParseFloat(m[1], 64) + if err != nil || seconds < 2 { + return "" + } + return fmt.Sprintf("sleep %.0f is not allowed in foreground execution", seconds) +} + +// spawnBackground is the shared helper that registers a background task and +// returns (taskID, outputFile). Used by both explicit and auto-background paths. +func (p *ContainerProvider) spawnBackground( + ctx context.Context, + session SessionContext, client *bridge.Client, + command, workDir, description string, +) (taskID, outputFile string) { + execFn := func(ctx context.Context, cmd, wd string, timeout int32) (*bridge.ExecResult, error) { + return client.Exec(ctx, cmd, wd, timeout) + } + writeFn := func(ctx context.Context, path string, data []byte) error { + return client.WriteFile(ctx, path, data) + } + readFn := func(ctx context.Context, path string) ([]byte, error) { + // Use pool to get a fresh client — the original client may be in a failed + // state if the streaming exec errored, but the pool will re-dial as needed. + c, err := p.clients.MCPClient(ctx, session.BotID) + if err != nil { + return nil, err + } + resp, err := c.ReadFile(ctx, path, 1, 10) + if err != nil { + return nil, err + } + return []byte(resp.GetContent()), nil + } + + taskID, outputFile = p.bgManager.Spawn( + ctx, + session.BotID, session.SessionID, command, workDir, description, + execFn, writeFn, readFn, + ) + return taskID, outputFile +} + +// execExecBackground spawns the command as a background task and returns immediately. 
+func (p *ContainerProvider) execExecBackground( + ctx context.Context, session SessionContext, client *bridge.Client, + command, workDir, description string, +) (any, error) { + taskID, outputFile := p.spawnBackground(ctx, session, client, command, workDir, description) + + return map[string]any{ + "status": "background_started", + "task_id": taskID, + "output_file": outputFile, + "message": fmt.Sprintf("Command started in background with task ID: %s. You will be notified when it completes. Output is being written to: %s. Do NOT poll or sleep — you will receive a notification automatically.", taskID, outputFile), + }, nil +} + +// execBgStatus handles the bg_status tool for listing/checking/killing background tasks. +func (p *ContainerProvider) execBgStatus(_ context.Context, session SessionContext, args map[string]any) (any, error) { + if p.bgManager == nil { + return nil, errors.New("background task manager not available") + } + + action := strings.TrimSpace(StringArg(args, "action")) + taskID := strings.TrimSpace(StringArg(args, "task_id")) + + switch action { + case "list": + tasks := p.bgManager.ListForSession(session.BotID, session.SessionID) + entries := make([]map[string]any, 0, len(tasks)) + for _, t := range tasks { + entries = append(entries, map[string]any{ + "task_id": t.ID, + "command": truncateStr(t.Command, 120), + "description": t.Description, + "status": string(t.Status), + "output_file": t.OutputFile, + "started_at": session.FormatTime(t.StartedAt), + }) + } + return map[string]any{"tasks": entries, "count": len(entries)}, nil + + case "status": + if taskID == "" { + return nil, errors.New("task_id is required for status action") + } + task := p.bgManager.GetForSession(session.BotID, session.SessionID, taskID) + if task == nil { + return nil, fmt.Errorf("task %s not found", taskID) + } + result := map[string]any{ + "task_id": task.ID, + "command": task.Command, + "description": task.Description, + "status": string(task.Status), + "output_file": 
// truncateStr shortens s to at most n bytes plus an ellipsis marker.
// The cut is backed up to a UTF-8 rune boundary so a multi-byte character
// (commands and descriptions may contain non-ASCII text) is never split,
// which would emit invalid UTF-8 into logs and tool results.
// Strings that already fit are returned unchanged; pure-ASCII behavior is
// identical to a plain s[:n] cut.
func truncateStr(s string, n int) string {
	if len(s) <= n {
		return s
	}
	// UTF-8 continuation bytes have the form 0b10xxxxxx; step back until
	// the cut position lands on the first byte of a rune.
	i := n
	for i > 0 && s[i]&0xC0 == 0x80 {
		i--
	}
	return s[:i] + "..."
}
result) + } + } +} diff --git a/internal/agent/tools/message.go b/internal/agent/tools/message.go index f7f6a658..2dcd2b19 100644 --- a/internal/agent/tools/message.go +++ b/internal/agent/tools/message.go @@ -87,8 +87,8 @@ func (p *MessageProvider) execSend(ctx context.Context, session SessionContext, if err != nil { return nil, err } - // Discuss mode: same-conversation sends must go through the channel - // adapter directly — there is no active stream to emit events into. + // Discuss mode: same-conversation sends must go through the channel adapter + // directly — there is no active stream to emit into. if result.Local && session.SessionType == "discuss" { sendResult, err := p.exec.SendDirect(ctx, toMessagingSession(session), result.Target, args) if err != nil { diff --git a/internal/agent/types.go b/internal/agent/types.go index f4dc7fed..ac09a457 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -6,6 +6,8 @@ import ( "time" sdk "github.com/memohai/twilight-ai/sdk" + + "github.com/memohai/memoh/internal/agent/background" ) // SessionContext carries request-scoped identity and routing information. @@ -93,6 +95,12 @@ type RunConfig struct { // output messages that preceded the injection. Used by the resolver // to interleave injected messages at the correct position in storeRound. InjectedRecorder func(headerifiedText string, insertAfter int) + + // BackgroundManager provides access to the background task system. + // When non-nil, the agent loop drains pending notifications at step + // boundaries and injects them as user messages so the model learns + // about completed background work. + BackgroundManager *background.Manager } // GenerateResult holds the result of a non-streaming agent invocation. 
diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 0fbf4521..0a45887f 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -21,6 +21,7 @@ import ( "github.com/memohai/memoh/internal/accounts" agentpkg "github.com/memohai/memoh/internal/agent" + "github.com/memohai/memoh/internal/agent/background" "github.com/memohai/memoh/internal/compaction" "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/db/sqlc" @@ -72,12 +73,18 @@ type Resolver struct { settingsService *settings.Service accountService *accounts.Service sessionService SessionService + routeService RouteService compactionService *compaction.Service eventPublisher messageevent.Publisher skillLoader SkillLoader assetLoader gatewayAssetLoader pipeline *pipelinepkg.Pipeline streamHTTPClient *http.Client + bgManager *background.Manager + outboundFn func(ctx context.Context, botID, channelType, target, text string) error + bgNotifDeferred sync.Map // key: "botID:sessionID" → wake arrived while a session turn was active + sessionTurnMu sync.Mutex + sessionTurnRefs map[string]int // key: "botID:sessionID" → active turn refcount timeout time.Duration clockLocation *time.Location logger *slog.Logger @@ -130,6 +137,7 @@ func NewResolver( settingsService: settingsService, accountService: accountService, streamHTTPClient: streamHTTPClient, + sessionTurnRefs: make(map[string]int), timeout: timeout, clockLocation: clockLocation, logger: log.With(slog.String("service", "conversation_resolver")), @@ -157,6 +165,19 @@ func (r *Resolver) SetCompactionService(s *compaction.Service) { r.compactionService = s } +// SetBackgroundManager configures the background task manager so that +// background exec notifications are injected into the agent loop. 
+func (r *Resolver) SetBackgroundManager(m *background.Manager) { + r.bgManager = m +} + +// SetOutboundFn configures the function used to deliver background notification +// responses to the user. The agent's text output is delivered through the same +// path as normal responses. +func (r *Resolver) SetOutboundFn(fn func(ctx context.Context, botID, channelType, target, text string) error) { + r.outboundFn = fn +} + // SetPipeline configures the DCP pipeline for RC-based context assembly. // When set, resolve() will use RC from the pipeline instead of loading // history from bot_history_messages for sessions that have pipeline data. @@ -413,6 +434,9 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r // Chat sends a synchronous chat request and stores the result. func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) { + doneTurn := r.enterSessionTurn(ctx, req.BotID, req.SessionID) + defer doneTurn() + rc, err := r.resolve(ctx, req) if err != nil { return conversation.ChatResponse{}, err @@ -551,8 +575,9 @@ func (r *Resolver) buildBaseRunConfig(ctx context.Context, p baseRunConfigParams TimezoneLocation: userClockLocation, SessionToken: p.SessionToken, }, - Skills: agentSkills, - LoopDetection: agentpkg.LoopDetectionConfig{Enabled: loopDetectionEnabled}, + Skills: agentSkills, + LoopDetection: agentpkg.LoopDetectionConfig{Enabled: loopDetectionEnabled}, + BackgroundManager: r.bgManager, } return cfg, chatModel, provider, nil diff --git a/internal/conversation/flow/resolver_stream.go b/internal/conversation/flow/resolver_stream.go index 3afdcd15..0def8fd8 100644 --- a/internal/conversation/flow/resolver_stream.go +++ b/internal/conversation/flow/resolver_stream.go @@ -23,8 +23,10 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) go func() { defer close(chunkCh) defer close(errCh) - streamReq := req + doneTurn := r.enterSessionTurn(ctx, 
streamReq.BotID, streamReq.SessionID) + defer doneTurn() + rc, err := r.resolve(ctx, streamReq) if err != nil { r.logger.Error("agent stream resolve failed", @@ -126,6 +128,9 @@ func (r *Resolver) StreamChatWS( eventCh chan<- WSStreamEvent, abortCh <-chan struct{}, ) error { + doneTurn := r.enterSessionTurn(ctx, req.BotID, req.SessionID) + defer doneTurn() + rc, err := r.resolve(ctx, req) if err != nil { r.logger.Error("StreamChatWS: resolve failed", diff --git a/internal/conversation/flow/resolver_trigger.go b/internal/conversation/flow/resolver_trigger.go index 08601104..18b76f7d 100644 --- a/internal/conversation/flow/resolver_trigger.go +++ b/internal/conversation/flow/resolver_trigger.go @@ -4,17 +4,32 @@ import ( "context" "encoding/json" "errors" + "fmt" + "log/slog" "strings" "time" sdk "github.com/memohai/twilight-ai/sdk" agentpkg "github.com/memohai/memoh/internal/agent" + "github.com/memohai/memoh/internal/channel/route" "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/heartbeat" "github.com/memohai/memoh/internal/schedule" ) +// RouteService is the interface the resolver uses to recover route-backed +// delivery context for proactive background notifications. +type RouteService interface { + GetByID(ctx context.Context, routeID string) (route.Route, error) +} + +// SetRouteService configures the route service used for background delivery +// context resolution. +func (r *Resolver) SetRouteService(s RouteService) { + r.routeService = s +} + // TriggerSchedule executes a scheduled command via the internal agent. 
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) (schedule.TriggerResult, error) { if strings.TrimSpace(botID) == "" { @@ -146,3 +161,180 @@ func isHeartbeatOK(text string) bool { t := strings.TrimSpace(text) return strings.HasPrefix(t, "HEARTBEAT_OK") || strings.HasSuffix(t, "HEARTBEAT_OK") || t == "HEARTBEAT_OK" } + +type backgroundDeliveryContext struct { + routeID string + channelType string + replyTarget string +} + +// TriggerBackgroundNotification is called when background-task notifications +// are enqueued for a session. Delivery is session-centric: all pending +// notifications for a session are drained together and delivered using the +// current session/route delivery context. It only runs when the session is +// currently idle; active turns consume notifications via mid-turn drain. +func (r *Resolver) TriggerBackgroundNotification(ctx context.Context, botID, sessionID string) { + r.logger.Info("background notification trigger called", + slog.String("bot_id", botID), + slog.String("session_id", sessionID), + ) + if strings.TrimSpace(botID) == "" || strings.TrimSpace(sessionID) == "" { + return + } + if r.bgManager == nil { + return + } + if !r.bgManager.HasNotifications(botID, sessionID) { + return + } + doneTurn, ok := r.tryEnterIdleSessionTurn(ctx, botID, sessionID) + if !ok { + r.markDeferredBackgroundNotification(botID, sessionID) + r.logger.Info("background notification trigger deferred: session turn active", + slog.String("bot_id", botID), + slog.String("session_id", sessionID), + ) + return + } + defer doneTurn() + + notifications := r.bgManager.DrainNotifications(botID, sessionID) + if len(notifications) == 0 { + return + } + + notifMessages := make([]sdk.Message, 0, len(notifications)) + for _, n := range notifications { + notifMessages = append(notifMessages, sdk.UserMessage(n.MessageText())) + } + + delivery, err := r.resolveBackgroundDeliveryContext(ctx, botID, sessionID) + 
if err != nil { + r.bgManager.RequeueNotifications(notifications) + r.logger.Warn("background notification trigger: resolve delivery context failed", + slog.String("bot_id", botID), + slog.String("session_id", sessionID), + slog.Any("error", err), + ) + return + } + + if err := r.deliverBackgroundNotifications(ctx, botID, sessionID, delivery, notifMessages); err != nil { + r.bgManager.RequeueNotifications(notifications) + r.logger.Warn("background notification trigger: deliver failed", + slog.String("bot_id", botID), + slog.String("session_id", sessionID), + slog.Any("error", err), + ) + } +} + +func (r *Resolver) resolveBackgroundDeliveryContext(ctx context.Context, botID, sessionID string) (backgroundDeliveryContext, error) { + if r.sessionService == nil { + return backgroundDeliveryContext{}, errors.New("session service not configured") + } + + sess, err := r.sessionService.Get(ctx, sessionID) + if err != nil { + return backgroundDeliveryContext{}, fmt.Errorf("get session: %w", err) + } + if sess.BotID != "" && botID != "" && sess.BotID != botID { + return backgroundDeliveryContext{}, fmt.Errorf("session %s belongs to bot %s, not %s", sessionID, sess.BotID, botID) + } + + channelType := strings.TrimSpace(sess.ChannelType) + if routeID := strings.TrimSpace(sess.RouteID); routeID != "" { + if r.routeService == nil { + return backgroundDeliveryContext{}, errors.New("route service not configured") + } + rt, err := r.routeService.GetByID(ctx, routeID) + if err != nil { + return backgroundDeliveryContext{}, fmt.Errorf("get route: %w", err) + } + if channelType == "" { + channelType = strings.TrimSpace(rt.Platform) + } + return backgroundDeliveryContext{ + routeID: routeID, + channelType: channelType, + replyTarget: strings.TrimSpace(rt.ReplyTarget), + }, nil + } + + if strings.EqualFold(channelType, "local") { + return backgroundDeliveryContext{ + channelType: "local", + replyTarget: botID, + }, nil + } + + return backgroundDeliveryContext{}, fmt.Errorf("session %s 
has no route-backed delivery context", sessionID) +} + +// deliverBackgroundNotifications runs a single agent call to deliver a batch of +// background-task notifications to the session's current delivery context. +func (r *Resolver) deliverBackgroundNotifications(ctx context.Context, botID, sessionID string, delivery backgroundDeliveryContext, notifMessages []sdk.Message) error { + r.logger.Info("background notification delivery", + slog.String("bot_id", botID), + slog.String("session_id", sessionID), + slog.String("route_id", delivery.routeID), + slog.String("platform", delivery.channelType), + slog.String("reply_target", delivery.replyTarget), + slog.Int("count", len(notifMessages)), + ) + req := conversation.ChatRequest{ + BotID: botID, + ChatID: botID, + SessionID: sessionID, + RouteID: delivery.routeID, + Query: "[background notification]", + CurrentChannel: delivery.channelType, + ReplyTarget: delivery.replyTarget, + } + rc, err := r.resolve(ctx, req) + if err != nil { + return fmt.Errorf("resolve background delivery: %w", err) + } + + cfg := rc.runConfig + // Inject drained notifications so the first LLM call sees them. + cfg.Messages = append(cfg.Messages, notifMessages...) + // Clear query so prepareRunConfig does not append a redundant user message. + cfg.Query = "" + // Use the natural session type — same system prompt, same tools, same + // personality as a regular conversation turn. Between-turn notifications + // should go through the same execution path as normal user messages. 
+ cfg = r.prepareRunConfig(ctx, cfg) + + result, err := r.agent.Generate(ctx, cfg) + if err != nil { + return fmt.Errorf("generate background delivery: %w", err) + } + r.logger.Info("background notification trigger: generate ok", + slog.String("bot_id", botID), + slog.String("platform", delivery.channelType), + slog.String("reply_target", delivery.replyTarget), + slog.Int("messages", len(result.Messages)), + ) + + if len(result.Messages) > 0 { + outputMessages := sdkMessagesToModelMessages(result.Messages) + notifModelMessages := sdkMessagesToModelMessages(notifMessages) + roundMessages := append(append(make([]conversation.ModelMessage, 0, len(notifModelMessages)+len(outputMessages)), notifModelMessages...), outputMessages...) + _ = r.storeRound(ctx, req, roundMessages, rc.model.ID) + } + + // Auto-deliver the agent's text response to the user through the normal + // outbound path, not through a special "send" tool call. + if text := strings.TrimSpace(result.Text); text != "" && r.outboundFn != nil { + if err := r.outboundFn(ctx, botID, delivery.channelType, delivery.replyTarget, text); err != nil { + r.logger.Warn("background notification: outbound delivery failed", + slog.String("bot_id", botID), + slog.String("platform", delivery.channelType), + slog.String("reply_target", delivery.replyTarget), + slog.Any("error", err), + ) + } + } + return nil +} diff --git a/internal/conversation/flow/resolver_trigger_test.go b/internal/conversation/flow/resolver_trigger_test.go new file mode 100644 index 00000000..00d83f63 --- /dev/null +++ b/internal/conversation/flow/resolver_trigger_test.go @@ -0,0 +1,233 @@ +package flow + +import ( + "context" + "errors" + "log/slog" + "testing" + "time" + + "github.com/memohai/memoh/internal/agent/background" + "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/session" +) + +type fakeBackgroundSessionService struct { + getFn func(ctx context.Context, sessionID string) (session.Session, error) +} + 
+func (f *fakeBackgroundSessionService) Get(ctx context.Context, sessionID string) (session.Session, error) { + if f == nil || f.getFn == nil { + return session.Session{}, errors.New("unexpected Get call") + } + return f.getFn(ctx, sessionID) +} + +func (*fakeBackgroundSessionService) UpdateTitle(context.Context, string, string) (session.Session, error) { + return session.Session{}, errors.New("unexpected UpdateTitle call") +} + +type fakeBackgroundRouteService struct { + getByIDFn func(ctx context.Context, routeID string) (route.Route, error) +} + +func (f *fakeBackgroundRouteService) GetByID(ctx context.Context, routeID string) (route.Route, error) { + if f == nil || f.getByIDFn == nil { + return route.Route{}, errors.New("unexpected GetByID call") + } + return f.getByIDFn(ctx, routeID) +} + +func TestResolveBackgroundDeliveryContext_RouteBackedSession(t *testing.T) { + t.Parallel() + + resolver := &Resolver{ + logger: slog.Default(), + sessionService: &fakeBackgroundSessionService{ + getFn: func(_ context.Context, sessionID string) (session.Session, error) { + if sessionID != "session-1" { + t.Fatalf("unexpected session id: %s", sessionID) + } + return session.Session{ + ID: sessionID, + BotID: "bot-1", + RouteID: "route-1", + ChannelType: "telegram", + }, nil + }, + }, + routeService: &fakeBackgroundRouteService{ + getByIDFn: func(_ context.Context, routeID string) (route.Route, error) { + if routeID != "route-1" { + t.Fatalf("unexpected route id: %s", routeID) + } + return route.Route{ + ID: routeID, + Platform: "telegram", + ReplyTarget: "chat-42", + }, nil + }, + }, + } + + delivery, err := resolver.resolveBackgroundDeliveryContext(context.Background(), "bot-1", "session-1") + if err != nil { + t.Fatalf("resolveBackgroundDeliveryContext returned error: %v", err) + } + if delivery.routeID != "route-1" { + t.Fatalf("unexpected route id: %q", delivery.routeID) + } + if delivery.channelType != "telegram" { + t.Fatalf("unexpected channel type: %q", 
delivery.channelType) + } + if delivery.replyTarget != "chat-42" { + t.Fatalf("unexpected reply target: %q", delivery.replyTarget) + } +} + +func TestResolveBackgroundDeliveryContext_LocalSessionFallback(t *testing.T) { + t.Parallel() + + resolver := &Resolver{ + logger: slog.Default(), + sessionService: &fakeBackgroundSessionService{ + getFn: func(_ context.Context, sessionID string) (session.Session, error) { + return session.Session{ + ID: sessionID, + BotID: "bot-1", + ChannelType: "local", + }, nil + }, + }, + } + + delivery, err := resolver.resolveBackgroundDeliveryContext(context.Background(), "bot-1", "session-1") + if err != nil { + t.Fatalf("resolveBackgroundDeliveryContext returned error: %v", err) + } + if delivery.routeID != "" { + t.Fatalf("expected empty route id, got %q", delivery.routeID) + } + if delivery.channelType != "local" { + t.Fatalf("unexpected channel type: %q", delivery.channelType) + } + if delivery.replyTarget != "bot-1" { + t.Fatalf("unexpected reply target: %q", delivery.replyTarget) + } +} + +func TestTriggerBackgroundNotification_RequeuesWholeBatchOnDeliveryContextFailure(t *testing.T) { + t.Parallel() + + bgMgr := background.New(nil) + batch := []background.Notification{ + {TaskID: "task-1", BotID: "bot-1", SessionID: "session-1", Status: background.TaskCompleted, Command: "cmd-1"}, + {TaskID: "task-2", BotID: "bot-1", SessionID: "session-1", Status: background.TaskFailed, Command: "cmd-2"}, + } + bgMgr.RequeueNotifications(batch) + + resolver := &Resolver{ + logger: slog.Default(), + bgManager: bgMgr, + sessionService: &fakeBackgroundSessionService{ + getFn: func(_ context.Context, _ string) (session.Session, error) { + return session.Session{}, errors.New("session lookup failed") + }, + }, + } + + resolver.TriggerBackgroundNotification(context.Background(), "bot-1", "session-1") + + remaining := bgMgr.DrainNotifications("bot-1", "session-1") + if len(remaining) != len(batch) { + t.Fatalf("expected %d notifications to be 
requeued, got %d", len(batch), len(remaining)) + } + for i, n := range remaining { + if n.TaskID != batch[i].TaskID { + t.Fatalf("unexpected task order after requeue at %d: got %q want %q", i, n.TaskID, batch[i].TaskID) + } + } +} + +func TestTriggerBackgroundNotification_DefersWhileSessionTurnActive(t *testing.T) { + t.Parallel() + + bgMgr := background.New(nil) + bgMgr.RequeueNotifications([]background.Notification{{ + TaskID: "task-1", + BotID: "bot-1", + SessionID: "session-1", + Status: background.TaskCompleted, + Command: "cmd-1", + }}) + + lookups := make(chan struct{}, 1) + resolver := &Resolver{ + logger: slog.Default(), + bgManager: bgMgr, + sessionService: &fakeBackgroundSessionService{ + getFn: func(_ context.Context, _ string) (session.Session, error) { + lookups <- struct{}{} + return session.Session{}, errors.New("unexpected session lookup") + }, + }, + } + + doneTurn := resolver.enterSessionTurn(context.Background(), "bot-1", "session-1") + resolver.TriggerBackgroundNotification(context.Background(), "bot-1", "session-1") + + select { + case <-lookups: + t.Fatal("expected trigger to defer while session turn is active") + case <-time.After(50 * time.Millisecond): + } + + if !bgMgr.HasNotifications("bot-1", "session-1") { + t.Fatal("expected notifications to remain queued while session turn is active") + } + + doneTurn() +} + +func TestSessionTurnExit_TriggersPendingBackgroundNotifications(t *testing.T) { + t.Parallel() + + bgMgr := background.New(nil) + bgMgr.RequeueNotifications([]background.Notification{{ + TaskID: "task-1", + BotID: "bot-1", + SessionID: "session-1", + Status: background.TaskCompleted, + Command: "cmd-1", + }}) + + lookups := make(chan struct{}, 1) + resolver := &Resolver{ + logger: slog.Default(), + bgManager: bgMgr, + sessionService: &fakeBackgroundSessionService{ + getFn: func(_ context.Context, _ string) (session.Session, error) { + lookups <- struct{}{} + return session.Session{}, errors.New("session lookup failed") + }, + }, 
+ } + + doneTurn := resolver.enterSessionTurn(context.Background(), "bot-1", "session-1") + resolver.TriggerBackgroundNotification(context.Background(), "bot-1", "session-1") + doneTurn() + + select { + case <-lookups: + case <-time.After(500 * time.Millisecond): + t.Fatal("expected idle transition to trigger pending background notifications") + } + + deadline := time.Now().Add(500 * time.Millisecond) + for !bgMgr.HasNotifications("bot-1", "session-1") { + if time.Now().After(deadline) { + t.Fatal("expected failed delivery attempt to requeue notifications") + } + time.Sleep(10 * time.Millisecond) + } +} diff --git a/internal/conversation/flow/resolver_turns.go b/internal/conversation/flow/resolver_turns.go new file mode 100644 index 00000000..fdce0d22 --- /dev/null +++ b/internal/conversation/flow/resolver_turns.go @@ -0,0 +1,112 @@ +package flow + +import ( + "context" + "log/slog" + "strings" + "sync" +) + +func sessionTurnKey(botID, sessionID string) string { + return strings.TrimSpace(botID) + ":" + strings.TrimSpace(sessionID) +} + +func (r *Resolver) enterSessionTurn(ctx context.Context, botID, sessionID string) func() { + botID = strings.TrimSpace(botID) + sessionID = strings.TrimSpace(sessionID) + if botID == "" || sessionID == "" { + return func() {} + } + + key := sessionTurnKey(botID, sessionID) + r.sessionTurnMu.Lock() + if r.sessionTurnRefs == nil { + r.sessionTurnRefs = make(map[string]int) + } + r.sessionTurnRefs[key]++ + r.sessionTurnMu.Unlock() + + return r.makeSessionTurnReleaser(ctx, key, botID, sessionID) +} + +func (r *Resolver) tryEnterIdleSessionTurn(ctx context.Context, botID, sessionID string) (func(), bool) { + botID = strings.TrimSpace(botID) + sessionID = strings.TrimSpace(sessionID) + if botID == "" || sessionID == "" { + return nil, false + } + + key := sessionTurnKey(botID, sessionID) + r.sessionTurnMu.Lock() + if r.sessionTurnRefs == nil { + r.sessionTurnRefs = make(map[string]int) + } + if r.sessionTurnRefs[key] > 0 { + 
r.sessionTurnMu.Unlock() + return nil, false + } + r.sessionTurnRefs[key] = 1 + r.sessionTurnMu.Unlock() + + return r.makeSessionTurnReleaser(ctx, key, botID, sessionID), true +} + +func (r *Resolver) makeSessionTurnReleaser(ctx context.Context, key, botID, sessionID string) func() { + var once sync.Once + return func() { + once.Do(func() { + becameIdle := false + + r.sessionTurnMu.Lock() + switch refs := r.sessionTurnRefs[key] - 1; { + case refs > 0: + r.sessionTurnRefs[key] = refs + default: + delete(r.sessionTurnRefs, key) + becameIdle = true + } + r.sessionTurnMu.Unlock() + + if becameIdle { + r.maybeTriggerDeferredBackgroundNotifications(ctx, botID, sessionID) + } + }) + } +} + +func (r *Resolver) markDeferredBackgroundNotification(botID, sessionID string) { + botID = strings.TrimSpace(botID) + sessionID = strings.TrimSpace(sessionID) + if botID == "" || sessionID == "" { + return + } + r.bgNotifDeferred.Store(sessionTurnKey(botID, sessionID), true) +} + +func (r *Resolver) takeDeferredBackgroundNotification(botID, sessionID string) bool { + botID = strings.TrimSpace(botID) + sessionID = strings.TrimSpace(sessionID) + if botID == "" || sessionID == "" { + return false + } + _, loaded := r.bgNotifDeferred.LoadAndDelete(sessionTurnKey(botID, sessionID)) + return loaded +} + +func (r *Resolver) maybeTriggerDeferredBackgroundNotifications(ctx context.Context, botID, sessionID string) { + if !r.takeDeferredBackgroundNotification(botID, sessionID) { + return + } + if r.bgManager == nil || !r.bgManager.HasNotifications(botID, sessionID) { + return + } + + r.logger.Info("background notification trigger queued after session became idle", + slog.String("bot_id", botID), + slog.String("session_id", sessionID), + ) + if ctx == nil { + return + } + go r.TriggerBackgroundNotification(context.WithoutCancel(ctx), botID, sessionID) +}