diff --git a/.gitignore b/.gitignore index d15e6eea..3b78015d 100644 --- a/.gitignore +++ b/.gitignore @@ -109,3 +109,4 @@ data _main-ref/ .toolkit/ /scripts/vendor +Memoh diff --git a/apps/web/src/components/channel-icon/index.vue b/apps/web/src/components/channel-icon/index.vue index 7ceac53d..bcd8ac0a 100644 --- a/apps/web/src/components/channel-icon/index.vue +++ b/apps/web/src/components/channel-icon/index.vue @@ -24,6 +24,7 @@ import { Wechatoa, Wecom, Matrix, + Misskey, } from '@memohai/icon' const channelIcons: Record = { @@ -37,6 +38,7 @@ const channelIcons: Record = { wechatoa: Wechatoa, wecom: Wecom, matrix: Matrix, + misskey: Misskey, dingtalk: Dingtalk, } diff --git a/apps/web/src/composables/api/useChat.ws.ts b/apps/web/src/composables/api/useChat.ws.ts index fe50a594..1fc6a1f9 100644 --- a/apps/web/src/composables/api/useChat.ws.ts +++ b/apps/web/src/composables/api/useChat.ws.ts @@ -21,7 +21,7 @@ export interface ChatWebSocket { function resolveWebSocketUrl(botId: string): string { const baseUrl = String(client.getConfig().baseUrl || '').trim() - const path = `/bots/${encodeURIComponent(botId)}/local/ws` + const path = `/bots/${encodeURIComponent(botId)}/web/ws` if (!baseUrl || baseUrl.startsWith('/')) { const loc = window.location diff --git a/apps/web/src/i18n/locales/en.json b/apps/web/src/i18n/locales/en.json index 57e0751d..6df0ccc5 100644 --- a/apps/web/src/i18n/locales/en.json +++ b/apps/web/src/i18n/locales/en.json @@ -144,6 +144,8 @@ }, "chat": { "greeting": "Hi! 
How can I help you today?", + "emptySubagent": "No messages recorded for this subagent task", + "emptySystemSession": "No messages recorded for this system session", "selectBot": "Select a Bot", "selectBotHint": "Choose a bot from the sidebar to start chatting", "thinking": "Thinking…", @@ -1022,6 +1024,7 @@ "discord": "Discord", "qq": "QQ", "matrix": "Matrix", + "misskey": "Misskey", "telegram": "Telegram", "weixin": "WeChat", "wechatoa": "WeChat Official Account", @@ -1036,6 +1039,7 @@ "discord": "DC", "qq": "QQ", "matrix": "MX", + "misskey": "MK", "telegram": "TG", "weixin": "WX", "wechatoa": "OA", diff --git a/apps/web/src/i18n/locales/zh.json b/apps/web/src/i18n/locales/zh.json index e8932035..c1b55e69 100644 --- a/apps/web/src/i18n/locales/zh.json +++ b/apps/web/src/i18n/locales/zh.json @@ -140,6 +140,8 @@ }, "chat": { "greeting": "你好!有什么我可以帮你的吗?", + "emptySubagent": "子代理任务暂无记录", + "emptySystemSession": "系统会话暂无记录", "selectBot": "选择一个 Bot", "selectBotHint": "从侧边栏选择一个 Bot 开始对话", "thinking": "思考中…", @@ -1018,6 +1020,7 @@ "discord": "Discord", "qq": "QQ", "matrix": "Matrix", + "misskey": "Misskey", "telegram": "Telegram", "weixin": "微信", "wechatoa": "微信服务号", @@ -1032,6 +1035,7 @@ "discord": "DC", "qq": "QQ", "matrix": "MX", + "misskey": "MK", "telegram": "TG", "weixin": "WX", "wechatoa": "OA", diff --git a/apps/web/src/pages/bots/components/bot-settings.vue b/apps/web/src/pages/bots/components/bot-settings.vue index c85d3169..7fd4632b 100644 --- a/apps/web/src/pages/bots/components/bot-settings.vue +++ b/apps/web/src/pages/bots/components/bot-settings.vue @@ -212,17 +212,16 @@ /> +
-

- {{ $t('bots.timezoneInheritedHint') }} -

@@ -237,8 +236,21 @@ /> - + +
+ + +
+ + +
@@ -348,7 +360,7 @@ import { getBotsById, putBotsById, getBotsByBotIdSettings, putBotsByBotIdSetting import type { SettingsSettings } from '@memohai/sdk' import type { Ref } from 'vue' import { resolveApiErrorMessage } from '@/utils/api-error' -import { useUserStore } from '@/store/user' +import { emptyTimezoneValue } from '@/utils/timezones' const props = defineProps<{ botId: string @@ -356,13 +368,8 @@ const props = defineProps<{ const { t } = useI18n() const router = useRouter() -const userStore = useUserStore() const botIdRef = computed(() => props.botId) as Ref -const defaultTimezone = computed(() => userStore.userInfo.timezone || 'UTC') -const timezoneEmptyLabel = computed(() => - `${t('bots.timezoneInherited')} (${defaultTimezone.value})`, -) // ---- Data ---- const queryCache = useQueryCache() @@ -639,6 +646,7 @@ watch(settings, (val) => { form.tts_model_id = val.tts_model_id ?? '' form.browser_context_id = val.browser_context_id ?? '' form.language = val.language ?? '' + form.timezone = val.timezone ?? '' form.reasoning_enabled = val.reasoning_enabled ?? false form.reasoning_effort = val.reasoning_effort || 'medium' } @@ -660,6 +668,7 @@ const hasSettingsChanges = computed(() => { || form.tts_model_id !== (s.tts_model_id ?? '') || form.browser_context_id !== (s.browser_context_id ?? '') || form.language !== (s.language ?? '') + || form.timezone !== (s.timezone ?? '') || form.reasoning_enabled !== (s.reasoning_enabled ?? 
false) || form.reasoning_effort !== (s.reasoning_effort || 'medium') ) diff --git a/apps/web/src/pages/bots/components/create-bot.vue b/apps/web/src/pages/bots/components/create-bot.vue index f6707f14..846fcce1 100644 --- a/apps/web/src/pages/bots/components/create-bot.vue +++ b/apps/web/src/pages/bots/components/create-bot.vue @@ -58,7 +58,7 @@ @@ -67,28 +67,13 @@ ({{ $t('common.optional') }}) - + @@ -133,12 +118,6 @@ import { FormItem, Separator, Label, - Select, - SelectContent, - SelectGroup, - SelectItem, - SelectTrigger, - SelectValue, Spinner, } from '@memohai/ui' import { Plus } from 'lucide-vue-next' @@ -150,7 +129,8 @@ import { useMutation, useQueryCache } from '@pinia/colada' import { postBotsMutation, getBotsQueryKey } from '@memohai/sdk/colada' import { useI18n } from 'vue-i18n' import { useDialogMutation } from '@/composables/useDialogMutation' -import { emptyTimezoneValue, timezones } from '@/utils/timezones' +import { emptyTimezoneValue } from '@/utils/timezones' +import TimezoneSelect from '@/components/timezone-select/index.vue' const open = defineModel('open', { default: false }) const { t } = useI18n() diff --git a/apps/web/src/pages/home/components/chat-area.vue b/apps/web/src/pages/home/components/chat-area.vue index 12826b0d..5c612ed1 100644 --- a/apps/web/src/pages/home/components/chat-area.vue +++ b/apps/web/src/pages/home/components/chat-area.vue @@ -40,7 +40,22 @@ v-if="messages.length === 0 && !loadingChats" class="flex items-center justify-center min-h-[300px]" > -

+

+ {{ $t('chat.emptySubagent') }} +

+

+ {{ $t('chat.emptySystemSession') }} +

+

{{ $t('chat.greeting') }}

diff --git a/apps/web/src/pages/home/components/tool-call-edit.vue b/apps/web/src/pages/home/components/tool-call-edit.vue index 13022080..5cc77a30 100644 --- a/apps/web/src/pages/home/components/tool-call-edit.vue +++ b/apps/web/src/pages/home/components/tool-call-edit.vue @@ -51,11 +51,13 @@ > +
+
diff --git a/apps/web/src/pages/home/components/tool-call-write.vue b/apps/web/src/pages/home/components/tool-call-write.vue index 424a25e6..8dab2858 100644 --- a/apps/web/src/pages/home/components/tool-call-write.vue +++ b/apps/web/src/pages/home/components/tool-call-write.vue @@ -51,11 +51,13 @@ > +
+
diff --git a/apps/web/src/store/chat-list.ts b/apps/web/src/store/chat-list.ts index e665105a..44904963 100644 --- a/apps/web/src/store/chat-list.ts +++ b/apps/web/src/store/chat-list.ts @@ -571,7 +571,7 @@ export const useChatStore = defineStore('chat', () => { } else { const activeSessionId = sessionId.value && visible.some(session => session.id === sessionId.value) ? sessionId.value - : visible[0]!.id + : (visible.find((s) => s.type === 'chat' || s.type === 'discuss')?.id ?? visible[0]!.id) sessionId.value = activeSessionId await loadMessages(bid, activeSessionId) } diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 8c898a14..198f4980 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -35,6 +35,7 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/matrix" + "github.com/memohai/memoh/internal/channel/adapters/misskey" "github.com/memohai/memoh/internal/channel/adapters/qq" "github.com/memohai/memoh/internal/channel/adapters/telegram" "github.com/memohai/memoh/internal/channel/adapters/wechatoa" @@ -361,12 +362,13 @@ func provideWorkspaceManager(log *slog.Logger, service ctr.Service, cfg config.C // memory providers // --------------------------------------------------------------------------- -func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memprovider.LLM { +func provideMemoryLLM(modelsService *models.Service, settingsService *settings.Service, queries *dbsqlc.Queries, log *slog.Logger) memprovider.LLM { return &lazyLLMClient{ - modelsService: modelsService, - queries: queries, - timeout: 30 * time.Second, - logger: log, + modelsService: modelsService, + settingsService: settingsService, + queries: queries, + timeout: 30 * time.Second, + logger: log, } } @@ -531,6 +533,10 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService 
weixinAdapter.SetAssetOpener(mediaService) registry.MustRegister(weixinAdapter) registry.MustRegister(local.NewWebAdapter(hub)) + + // Misskey + registry.MustRegister(misskey.NewMisskeyAdapter(log)) + return registry } @@ -1084,14 +1090,15 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer // --------------------------------------------------------------------------- type lazyLLMClient struct { - modelsService *models.Service - queries *dbsqlc.Queries - timeout time.Duration - logger *slog.Logger + modelsService *models.Service + settingsService *settings.Service + queries *dbsqlc.Queries + timeout time.Duration + logger *slog.Logger } func (c *lazyLLMClient) Extract(ctx context.Context, req memprovider.ExtractRequest) (memprovider.ExtractResponse, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, req.BotID) if err != nil { return memprovider.ExtractResponse{}, err } @@ -1099,7 +1106,7 @@ func (c *lazyLLMClient) Extract(ctx context.Context, req memprovider.ExtractRequ } func (c *lazyLLMClient) Decide(ctx context.Context, req memprovider.DecideRequest) (memprovider.DecideResponse, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, req.BotID) if err != nil { return memprovider.DecideResponse{}, err } @@ -1107,7 +1114,7 @@ func (c *lazyLLMClient) Decide(ctx context.Context, req memprovider.DecideReques } func (c *lazyLLMClient) Compact(ctx context.Context, req memprovider.CompactRequest) (memprovider.CompactResponse, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, "") if err != nil { return memprovider.CompactResponse{}, err } @@ -1115,18 +1122,32 @@ func (c *lazyLLMClient) Compact(ctx context.Context, req memprovider.CompactRequ } func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, "") if err != nil { return "", err } return client.DetectLanguage(ctx, text) } -func 
(c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) { +func (c *lazyLLMClient) resolve(ctx context.Context, botID string) (memprovider.LLM, error) { if c.modelsService == nil || c.queries == nil { return nil, errors.New("models service not configured") } - memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, "") + + // Try to use the bot's configured chat model for memory operations. + chatModelID := "" + if c.settingsService != nil && strings.TrimSpace(botID) != "" { + if botSettings, err := c.settingsService.GetBot(ctx, botID); err == nil { + // Prefer compaction model (smaller/cheaper), then chat model. + if id := strings.TrimSpace(botSettings.CompactionModelID); id != "" { + chatModelID = id + } else if id := strings.TrimSpace(botSettings.ChatModelID); id != "" { + chatModelID = id + } + } + } + + memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, chatModelID) if err != nil { return nil, err } diff --git a/cmd/bridge/main.go b/cmd/bridge/main.go index b8dde819..d3c926ef 100644 --- a/cmd/bridge/main.go +++ b/cmd/bridge/main.go @@ -12,6 +12,7 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/reflection" "github.com/memohai/memoh/internal/logger" @@ -49,7 +50,7 @@ func initDataDir() { if err != nil { continue } - if err := os.WriteFile(dst, data, fs.FileMode(0o644)); err != nil { + if err := os.WriteFile(dst, data, fs.FileMode(0o644)); err != nil { //nolint:gosec // G703: dst is built from filepath.Join(defaultWorkDir, e.Name()) where e comes from os.ReadDir logger.Warn("failed to seed template", slog.String("file", e.Name()), slog.Any("error", err)) } } @@ -91,7 +92,21 @@ func main() { return } - srv := grpc.NewServer() + srv := grpc.NewServer( + grpc.MaxRecvMsgSize(16*1024*1024), + grpc.MaxSendMsgSize(16*1024*1024), + grpc.KeepaliveParams(keepalive.ServerParameters{ + MaxConnectionIdle: 5 * 
time.Minute, + MaxConnectionAge: 30 * time.Minute, + MaxConnectionAgeGrace: 10 * time.Second, + Time: 60 * time.Second, + Timeout: 15 * time.Second, + }), + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: 10 * time.Second, + PermitWithoutStream: true, + }), + ) pb.RegisterContainerServiceServer(srv, &containerServer{}) reflection.Register(srv) diff --git a/cmd/bridge/server.go b/cmd/bridge/server.go index 03f933a9..74e4b920 100644 --- a/cmd/bridge/server.go +++ b/cmd/bridge/server.go @@ -24,14 +24,15 @@ import ( ) const ( - readMaxLines = 200 - readMaxBytes = 5120 - readMaxLineLen = 1000 - listMaxEntries = 200 - binaryProbeBytes = 8 * 1024 - rawChunkSize = 64 * 1024 - defaultWorkDir = "/data" - defaultTimeout = 30 + readMaxLines = 2000 + readMaxBytes = 0 // 0 = no byte limit (line count only) + readMaxLineLen = 0 // 0 = no per-line truncation + listMaxEntries = 200 + binaryProbeBytes = 8 * 1024 + rawChunkSize = 64 * 1024 + defaultWorkDir = "/data" + defaultTimeout = 30 + defaultPTYTimeout = 5 * 60 // 5 minutes max for PTY sessions (agent tool calls) ) type containerServer struct { @@ -89,12 +90,12 @@ func (*containerServer) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*p } line := scanner.Text() - if utf8.RuneCountInString(line) > readMaxLineLen { + if readMaxLineLen > 0 && utf8.RuneCountInString(line) > readMaxLineLen { line = truncateRunes(line, readMaxLineLen) + "..." 
} entry := line + "\n" - if bytesWritten+len(entry) > readMaxBytes { + if readMaxBytes > 0 && bytesWritten+len(entry) > readMaxBytes { break } out.WriteString(entry) @@ -288,11 +289,18 @@ func execPTY(stream pb.ContainerService_ExecServer, firstMsg *pb.ExecInput) erro workDir = defaultWorkDir } + timeout := int(firstMsg.GetTimeoutSeconds()) + if timeout <= 0 { + timeout = defaultPTYTimeout + } + ctx, cancel := context.WithTimeout(stream.Context(), time.Duration(timeout)*time.Second) + defer cancel() + var cmd *exec.Cmd if isBarePath(command) { - cmd = exec.CommandContext(stream.Context(), command) //nolint:gosec // G204: intentional + cmd = exec.CommandContext(ctx, command) //nolint:gosec // G204: intentional } else { - cmd = exec.CommandContext(stream.Context(), "/bin/sh", "-c", command) //nolint:gosec // G204: intentional + cmd = exec.CommandContext(ctx, "/bin/sh", "-c", command) //nolint:gosec // G204: intentional } cmd.Dir = workDir cmd.Env = append(os.Environ(), firstMsg.GetEnv()...) 
diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 75a608ac..393aad64 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -36,6 +36,7 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/matrix" + "github.com/memohai/memoh/internal/channel/adapters/misskey" "github.com/memohai/memoh/internal/channel/adapters/qq" "github.com/memohai/memoh/internal/channel/adapters/telegram" "github.com/memohai/memoh/internal/channel/adapters/wechatoa" @@ -260,8 +261,8 @@ func provideWorkspaceManager(log *slog.Logger, service ctr.Service, cfg config.C return workspace.NewManager(log, service, cfg.Workspace, cfg.Containerd.Namespace, conn) } -func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memprovider.LLM { - return &lazyLLMClient{modelsService: modelsService, queries: queries, timeout: 30 * time.Second, logger: log} +func provideMemoryLLM(modelsService *models.Service, settingsService *settings.Service, queries *dbsqlc.Queries, log *slog.Logger) memprovider.LLM { + return &lazyLLMClient{modelsService: modelsService, settingsService: settingsService, queries: queries, timeout: 30 * time.Second, logger: log} } func provideMemoryProviderRegistry(log *slog.Logger, llm memprovider.LLM, chatService *conversation.Service, accountService *accounts.Service, manager *workspace.Manager, queries *dbsqlc.Queries, cfg config.Config) *memprovider.Registry { @@ -453,6 +454,10 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService weixinAdapter.SetAssetOpener(mediaService) registry.MustRegister(weixinAdapter) registry.MustRegister(local.NewWebAdapter(hub)) + + // Misskey + registry.MustRegister(misskey.NewMisskeyAdapter(log)) + return registry } @@ -1006,14 +1011,15 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer } type lazyLLMClient struct { - 
modelsService *models.Service - queries *dbsqlc.Queries - timeout time.Duration - logger *slog.Logger + modelsService *models.Service + settingsService *settings.Service + queries *dbsqlc.Queries + timeout time.Duration + logger *slog.Logger } func (c *lazyLLMClient) Extract(ctx context.Context, req memprovider.ExtractRequest) (memprovider.ExtractResponse, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, req.BotID) if err != nil { return memprovider.ExtractResponse{}, err } @@ -1021,7 +1027,7 @@ func (c *lazyLLMClient) Extract(ctx context.Context, req memprovider.ExtractRequ } func (c *lazyLLMClient) Decide(ctx context.Context, req memprovider.DecideRequest) (memprovider.DecideResponse, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, req.BotID) if err != nil { return memprovider.DecideResponse{}, err } @@ -1029,7 +1035,7 @@ func (c *lazyLLMClient) Decide(ctx context.Context, req memprovider.DecideReques } func (c *lazyLLMClient) Compact(ctx context.Context, req memprovider.CompactRequest) (memprovider.CompactResponse, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, "") if err != nil { return memprovider.CompactResponse{}, err } @@ -1037,18 +1043,32 @@ func (c *lazyLLMClient) Compact(ctx context.Context, req memprovider.CompactRequ } func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) { - client, err := c.resolve(ctx) + client, err := c.resolve(ctx, "") if err != nil { return "", err } return client.DetectLanguage(ctx, text) } -func (c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) { +func (c *lazyLLMClient) resolve(ctx context.Context, botID string) (memprovider.LLM, error) { if c.modelsService == nil || c.queries == nil { return nil, errors.New("models service not configured") } - memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, "") + + // Try to use the bot's configured chat model 
for memory operations. + chatModelID := "" + if c.settingsService != nil && strings.TrimSpace(botID) != "" { + if botSettings, err := c.settingsService.GetBot(ctx, botID); err == nil { + // Prefer compaction model (smaller/cheaper), then chat model. + if id := strings.TrimSpace(botSettings.CompactionModelID); id != "" { + chatModelID = id + } else if id := strings.TrimSpace(botSettings.ChatModelID); id != "" { + chatModelID = id + } + } + } + + memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, chatModelID) if err != nil { return nil, err } diff --git a/conf/providers/openrouter.yaml b/conf/providers/openrouter.yaml index 224f80bb..d755d088 100644 --- a/conf/providers/openrouter.yaml +++ b/conf/providers/openrouter.yaml @@ -250,7 +250,7 @@ models: compatibilities: [reasoning] context_window: 131072 - model_id: baidu/ernie-4.5-300b-a47b - name: "Baidu: ERNIE 4.5 300B A47B " + name: "Baidu: ERNIE 4.5 300B A47B" type: chat config: context_window: 123000 @@ -261,7 +261,7 @@ models: compatibilities: [vision, tool-call, reasoning] context_window: 30000 - model_id: baidu/ernie-4.5-vl-424b-a47b - name: "Baidu: ERNIE 4.5 VL 424B A47B " + name: "Baidu: ERNIE 4.5 VL 424B A47B" type: chat config: compatibilities: [vision, reasoning] @@ -291,7 +291,7 @@ models: compatibilities: [vision, tool-call, reasoning] context_window: 262144 - model_id: bytedance/ui-tars-1.5-7b - name: "ByteDance: UI-TARS 7B " + name: "ByteDance: UI-TARS 7B" type: chat config: compatibilities: [vision] @@ -1512,7 +1512,7 @@ models: config: context_window: 32768 - model_id: qwen/qwen-max - name: "Qwen: Qwen-Max " + name: "Qwen: Qwen-Max" type: chat config: compatibilities: [tool-call] @@ -1989,7 +1989,7 @@ models: compatibilities: [tool-call, reasoning] context_window: 1048576 - model_id: z-ai/glm-4-32b - name: "Z.ai: GLM 4 32B " + name: "Z.ai: GLM 4 32B" type: chat config: compatibilities: [tool-call] diff --git a/conf/providers/xai.yaml 
b/conf/providers/xai.yaml index c267a84a..021d6f09 100644 --- a/conf/providers/xai.yaml +++ b/conf/providers/xai.yaml @@ -1,5 +1,5 @@ name: xAI (Grok) -client_type: openai-completions +client_type: openai-responses icon: xai base_url: https://api.x.ai/v1 @@ -8,7 +8,7 @@ models: name: Grok 4.20 Beta type: chat config: - compatibilities: [vision, tool-call, reasoning] + compatibilities: [vision, tool-call] context_window: 2000000 - model_id: grok-4.20-beta-0309-non-reasoning @@ -22,7 +22,7 @@ models: name: Grok 4.20 Multi-Agent Beta type: chat config: - compatibilities: [vision, reasoning] + compatibilities: [vision] context_window: 2000000 - model_id: grok-4-1-fast-non-reasoning @@ -36,7 +36,7 @@ models: name: Grok 4.1 Fast type: chat config: - compatibilities: [vision, tool-call, reasoning] + compatibilities: [vision, tool-call] context_window: 2000000 - model_id: grok-4-fast-non-reasoning @@ -50,21 +50,21 @@ models: name: Grok 4 Fast type: chat config: - compatibilities: [vision, tool-call, reasoning] + compatibilities: [vision, tool-call] context_window: 2000000 - model_id: grok-code-fast-1 name: Grok Code Fast 1 type: chat config: - compatibilities: [tool-call, reasoning] + compatibilities: [tool-call] context_window: 256000 - model_id: grok-4 name: Grok 4 0709 type: chat config: - compatibilities: [vision, tool-call, reasoning] + compatibilities: [vision, tool-call] context_window: 256000 - model_id: grok-3 diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 7f13aa17..4d016d11 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -155,6 +155,8 @@ CREATE TABLE IF NOT EXISTS bots ( discuss_probe_model_id UUID REFERENCES models(id) ON DELETE SET NULL, tts_model_id UUID REFERENCES models(id) ON DELETE SET NULL, browser_context_id UUID REFERENCES browser_contexts(id) ON DELETE SET NULL, + context_token_budget INTEGER, + persist_full_tool_results BOOLEAN NOT NULL DEFAULT false, metadata JSONB NOT NULL 
DEFAULT '{}'::jsonb, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), diff --git a/db/migrations/0063_add_task_tracking.down.sql b/db/migrations/0063_add_task_tracking.down.sql new file mode 100644 index 00000000..591cc789 --- /dev/null +++ b/db/migrations/0063_add_task_tracking.down.sql @@ -0,0 +1,7 @@ +-- 0063_add_task_tracking (rollback) +-- Remove exec_id and pid columns from tasks table. + +DROP INDEX IF EXISTS idx_tasks_pid; +DROP INDEX IF EXISTS idx_tasks_exec_id; +ALTER TABLE tasks DROP COLUMN IF EXISTS pid; +ALTER TABLE tasks DROP COLUMN IF EXISTS exec_id; diff --git a/db/migrations/0063_add_task_tracking.up.sql b/db/migrations/0063_add_task_tracking.up.sql new file mode 100644 index 00000000..b8747361 --- /dev/null +++ b/db/migrations/0063_add_task_tracking.up.sql @@ -0,0 +1,17 @@ +-- 0063_add_task_tracking +-- Add exec_id and pid columns to tasks table for process tracking. + +CREATE TABLE IF NOT EXISTS tasks ( + id VARCHAR(255) PRIMARY KEY, + bot_id VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + command TEXT NOT NULL, + status VARCHAR(50) NOT NULL DEFAULT 'pending', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +ALTER TABLE tasks ADD COLUMN IF NOT EXISTS exec_id VARCHAR(255) NULL; +ALTER TABLE tasks ADD COLUMN IF NOT EXISTS pid INTEGER NULL; +CREATE INDEX IF NOT EXISTS idx_tasks_exec_id ON tasks(exec_id); +CREATE INDEX IF NOT EXISTS idx_tasks_pid ON tasks(pid); diff --git a/db/migrations/0064_revert_local_to_web.down.sql b/db/migrations/0064_revert_local_to_web.down.sql new file mode 100644 index 00000000..9c138fbc --- /dev/null +++ b/db/migrations/0064_revert_local_to_web.down.sql @@ -0,0 +1,9 @@ +-- 0064_revert_local_to_web (rollback) +-- Re-apply the 'local' convention by converting 'web' back to 'local'. 
+ +UPDATE channel_identities SET channel_type = 'local' WHERE channel_type = 'web'; +UPDATE user_channel_bindings SET channel_type = 'local' WHERE channel_type = 'web'; +UPDATE bot_channel_configs SET channel_type = 'local' WHERE channel_type = 'web'; +UPDATE channel_identity_bind_codes SET channel_type = 'local' WHERE channel_type = 'web'; +UPDATE bot_channel_routes SET channel_type = 'local' WHERE channel_type = 'web'; +UPDATE bot_sessions SET channel_type = 'local' WHERE channel_type = 'web'; diff --git a/db/migrations/0064_revert_local_to_web.up.sql b/db/migrations/0064_revert_local_to_web.up.sql new file mode 100644 index 00000000..f3f0d65b --- /dev/null +++ b/db/migrations/0064_revert_local_to_web.up.sql @@ -0,0 +1,17 @@ +-- 0064_revert_local_to_web +-- Revert channel_type 'local' back to 'web' to match updated adapter constants. +-- The original 0056 migration merged web/cli → local; this undoes that change. + +-- For channel_identities with unique constraint on (channel_type, channel_subject_id): +-- delete 'local' rows that would conflict with existing 'web' rows, then update the rest. +DELETE FROM channel_identities WHERE channel_type = 'local' AND channel_subject_id IN ( + SELECT channel_subject_id FROM channel_identities WHERE channel_type = 'web' +); +UPDATE channel_identities SET channel_type = 'web' WHERE channel_type = 'local'; + +-- These tables don't have the same unique constraint, safe to update directly. 
+UPDATE user_channel_bindings SET channel_type = 'web' WHERE channel_type = 'local'; +UPDATE bot_channel_configs SET channel_type = 'web' WHERE channel_type = 'local'; +UPDATE channel_identity_bind_codes SET channel_type = 'web' WHERE channel_type = 'local'; +UPDATE bot_channel_routes SET channel_type = 'web' WHERE channel_type = 'local'; +UPDATE bot_sessions SET channel_type = 'web' WHERE channel_type = 'local'; diff --git a/db/migrations/0065_add_context_token_budget.down.sql b/db/migrations/0065_add_context_token_budget.down.sql new file mode 100644 index 00000000..ab7deceb --- /dev/null +++ b/db/migrations/0065_add_context_token_budget.down.sql @@ -0,0 +1,7 @@ +-- 0065_add_context_token_budget (down) +-- NOTE: After rolling back this migration, re-run `sqlc generate` to update the +-- generated Go code in internal/db/sqlc/. The Go structs will still contain the +-- new columns until regenerated. + +ALTER TABLE bots DROP COLUMN IF EXISTS persist_full_tool_results; +ALTER TABLE bots DROP COLUMN IF EXISTS context_token_budget; \ No newline at end of file diff --git a/db/migrations/0065_add_context_token_budget.up.sql b/db/migrations/0065_add_context_token_budget.up.sql new file mode 100644 index 00000000..4910745e --- /dev/null +++ b/db/migrations/0065_add_context_token_budget.up.sql @@ -0,0 +1,5 @@ +-- 0065_add_context_token_budget +-- Add context token budget and tool result persistence settings for large task optimization. 
+ +ALTER TABLE bots ADD COLUMN IF NOT EXISTS context_token_budget INTEGER; +ALTER TABLE bots ADD COLUMN IF NOT EXISTS persist_full_tool_results BOOLEAN NOT NULL DEFAULT false; diff --git a/db/queries/settings.sql b/db/queries/settings.sql index eff58da0..cc6674d0 100644 --- a/db/queries/settings.sql +++ b/db/queries/settings.sql @@ -10,6 +10,7 @@ SELECT bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, + bots.timezone, chat_models.id AS chat_model_id, heartbeat_models.id AS heartbeat_model_id, compaction_models.id AS compaction_model_id, @@ -18,7 +19,9 @@ SELECT memory_providers.id AS memory_provider_id, image_models.id AS image_model_id, tts_models.id AS tts_model_id, - browser_contexts.id AS browser_context_id + browser_contexts.id AS browser_context_id, + bots.context_token_budget, + bots.persist_full_tool_results FROM bots LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = bots.heartbeat_model_id @@ -43,6 +46,7 @@ WITH updated AS ( compaction_enabled = sqlc.arg(compaction_enabled), compaction_threshold = sqlc.arg(compaction_threshold), compaction_ratio = sqlc.arg(compaction_ratio), + timezone = COALESCE(sqlc.narg(timezone), bots.timezone), chat_model_id = COALESCE(sqlc.narg(chat_model_id)::uuid, bots.chat_model_id), heartbeat_model_id = COALESCE(sqlc.narg(heartbeat_model_id)::uuid, bots.heartbeat_model_id), compaction_model_id = COALESCE(sqlc.narg(compaction_model_id)::uuid, bots.compaction_model_id), @@ -52,9 +56,11 @@ WITH updated AS ( image_model_id = COALESCE(sqlc.narg(image_model_id)::uuid, bots.image_model_id), tts_model_id = COALESCE(sqlc.narg(tts_model_id)::uuid, bots.tts_model_id), browser_context_id = COALESCE(sqlc.narg(browser_context_id)::uuid, bots.browser_context_id), + context_token_budget = COALESCE(sqlc.narg(context_token_budget), bots.context_token_budget), + persist_full_tool_results = sqlc.arg(persist_full_tool_results), updated_at = 
now() WHERE bots.id = sqlc.arg(id) - RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id + RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.timezone, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id, bots.context_token_budget, bots.persist_full_tool_results ) SELECT updated.id AS bot_id, @@ -67,6 +73,7 @@ SELECT updated.compaction_enabled, updated.compaction_threshold, updated.compaction_ratio, + updated.timezone, chat_models.id AS chat_model_id, heartbeat_models.id AS heartbeat_model_id, compaction_models.id AS compaction_model_id, @@ -75,7 +82,9 @@ SELECT memory_providers.id AS memory_provider_id, image_models.id AS image_model_id, tts_models.id AS tts_model_id, - browser_contexts.id AS browser_context_id + browser_contexts.id AS browser_context_id, + updated.context_token_budget, + updated.persist_full_tool_results FROM updated LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = updated.heartbeat_model_id @@ -107,5 +116,7 @@ SET language = 'auto', memory_provider_id = NULL, tts_model_id = NULL, browser_context_id = NULL, + context_token_budget = NULL, + persist_full_tool_results = false, updated_at = now() WHERE id = $1; diff --git a/docker-compose.yml 
b/docker-compose.yml index 3fa6411f..c6fbdaa1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,7 +22,7 @@ services: - memoh-network migrate: - image: memohai/server:latest + image: memohai/server:local container_name: memoh-migrate entrypoint: ["/app/memoh-server", "migrate", "up"] volumes: @@ -35,7 +35,7 @@ services: - memoh-network server: - image: memohai/server:latest + image: memohai/server:local container_name: memoh-server privileged: true pid: host @@ -107,7 +107,7 @@ services: - memoh-network browser: - image: memohai/browser:${BROWSER_TAG:-latest} + image: memohai/browser:${BROWSER_TAG:-local} container_name: memoh-browser profiles: [browser] environment: diff --git a/docker/Dockerfile.browser b/docker/Dockerfile.browser index c67f49c9..2c2bcecc 100644 --- a/docker/Dockerfile.browser +++ b/docker/Dockerfile.browser @@ -39,6 +39,6 @@ RUN for core in $(echo "$BROWSER_CORES" | tr ',' ' '); do \ EXPOSE 8083 HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \ - CMD curl -sf http://127.0.0.1:8083/health || exit 1 + CMD curl -sf http://$(hostname):8083/health || exit 1 CMD ["bun", "run", "dist/index.js"] diff --git a/docs/task-mgmt-tools-PR.md b/docs/task-mgmt-tools-PR.md new file mode 100644 index 00000000..51d8a7e6 --- /dev/null +++ b/docs/task-mgmt-tools-PR.md @@ -0,0 +1,33 @@ +## Memoh Task Management Tools PR (feat-task-mgmt) + +### Protobuf (add to proto/bridge.proto) +```protobuf +message ListTasksRequest { string session_id = 1; } +message Task { string id = 1; string status = 3; int64 pid = 4; string command = 5; /* etc */ } +service TaskService { + rpc ListTasks(ListTasksRequest) returns (ListTasksResponse); + rpc KillTask(KillTaskRequest) returns (KillTaskResponse); + rpc TaskLogs(TaskLogsRequest) returns (stream TaskLogsResponse); +} +``` + +### DB Schema (migrations) +```sql +ALTER TABLE tasks ADD COLUMN exec_id VARCHAR(255) NULL, ADD COLUMN pid INTEGER NULL; +``` + +### server.go Diff (key impl) +```go +func (s 
*server) ListTasks(...) (*pb.ListTasksResponse, error) { + tasks, _ := listTasksBySession(s.db, req.SessionId) + // map + exec_status +} +func (s *server) KillTask(...) { killTask(s.db, req.TaskId) } +``` + +### TOOLS.md +Add: +- list_tasks(session_id?): List tasks. +示例 prompt: \"list_tasks 检查任务,kill_task 旧 exec。\" + +Full spawn details above. Ready for implement/merge. \ No newline at end of file diff --git a/go.mod b/go.mod index 9112f352..3cea9a48 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( go.uber.org/fx v1.24.0 golang.org/x/crypto v0.48.0 golang.org/x/oauth2 v0.35.0 + golang.org/x/time v0.14.0 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 @@ -144,7 +145,6 @@ require ( golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.42.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect sigs.k8s.io/yaml v1.6.0 // indirect diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f5603fc8..f77f1622 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "strings" + "time" sdk "github.com/memohai/twilight-ai/sdk" @@ -62,11 +63,24 @@ func (a *Agent) Generate(ctx context.Context, cfg RunConfig) (*GenerateResult, e return a.runGenerate(ctx, cfg) } +// sendEvent sends an event to the stream channel. It returns false if the +// context was cancelled (consumer stopped reading), allowing the caller to +// abort cleanly instead of leaking the goroutine on a blocked channel send. 
+func sendEvent(ctx context.Context, ch chan<- StreamEvent, evt StreamEvent) bool { + select { + case ch <- evt: + return true + case <-ctx.Done(): + return false + } +} + func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEvent) { // Stream emitter: tools targeting the current conversation push // side-effect events (attachments, reactions, speech) directly here. + // Uses sendEvent to avoid goroutine leaks when the consumer stops reading. streamEmitter := tools.StreamEmitter(func(evt tools.ToolStreamEvent) { - ch <- toolStreamEventToAgentEvent(evt) + sendEvent(ctx, ch, toolStreamEventToAgentEvent(evt)) }) var sdkTools []sdk.Tool @@ -74,7 +88,7 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv var err error sdkTools, err = a.assembleTools(ctx, cfg, streamEmitter) if err != nil { - ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)} + sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("assemble tools: %v", err)}) return } } @@ -157,16 +171,53 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv opts := a.buildGenerateOptions(cfg, sdkTools, prepareStep) - streamResult, err := a.client.StreamText(ctx, opts...) - if err != nil { - ch <- StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: %v", err)} - return + retryCfg := cfg.Retry + if retryCfg.MaxAttempts <= 0 { + retryCfg = DefaultRetryConfig() } - ch <- StreamEvent{Type: EventAgentStart} + var streamResult *sdk.StreamResult + for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ { + var err error + streamResult, err = a.client.StreamText(ctx, opts...) 
+ if err == nil { + break + } + if !isRetryableStreamError(err) { + sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: %v", err)}) + return + } + a.logger.Warn("stream start failed, retrying", + slog.Int("attempt", attempt+1), + slog.Int("max_attempts", retryCfg.MaxAttempts), + slog.String("error", err.Error()), + ) + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventRetry, + Attempt: attempt + 1, + MaxAttempt: retryCfg.MaxAttempts, + RetryError: err.Error(), + }) { + return + } + if attempt+1 >= retryCfg.MaxAttempts { + sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: all %d attempts failed (last: %v)", retryCfg.MaxAttempts, err)}) + return + } + delay := retryDelay(attempt, retryCfg) + if delay > 0 { + if err := sleepWithContext(ctx, delay); err != nil { + sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: fmt.Sprintf("stream start: context cancelled during retry: %v", err)}) + return + } + } + } + + sendEvent(ctx, ch, StreamEvent{Type: EventAgentStart}) var allText strings.Builder aborted := false + stepNumber := 0 for part := range streamResult.Stream { if ctx.Err() != nil { @@ -179,14 +230,18 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv _ = p // stream start already emitted case *sdk.TextStartPart: - ch <- StreamEvent{Type: EventTextStart} + if !sendEvent(ctx, ch, StreamEvent{Type: EventTextStart}) { + aborted = true + } case *sdk.TextDeltaPart: if p.Text != "" { if textLoopProbeBuffer != nil { textLoopProbeBuffer.Push(p.Text) } - ch <- StreamEvent{Type: EventTextDelta, Delta: p.Text} + if !sendEvent(ctx, ch, StreamEvent{Type: EventTextDelta, Delta: p.Text}) { + aborted = true + } allText.WriteString(p.Text) } @@ -194,26 +249,42 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } - ch <- StreamEvent{Type: EventTextEnd} + stepNumber++ + if !sendEvent(ctx, ch, 
StreamEvent{Type: EventTextEnd}) || + !sendEvent(ctx, ch, StreamEvent{ + Type: EventProgress, + StepNumber: stepNumber, + ProgressStatus: "text", + }) { + aborted = true + } case *sdk.ReasoningStartPart: - ch <- StreamEvent{Type: EventReasoningStart} + if !sendEvent(ctx, ch, StreamEvent{Type: EventReasoningStart}) { + aborted = true + } case *sdk.ReasoningDeltaPart: - ch <- StreamEvent{Type: EventReasoningDelta, Delta: p.Text} + if !sendEvent(ctx, ch, StreamEvent{Type: EventReasoningDelta, Delta: p.Text}) { + aborted = true + } case *sdk.ReasoningEndPart: - ch <- StreamEvent{Type: EventReasoningEnd} + if !sendEvent(ctx, ch, StreamEvent{Type: EventReasoningEnd}) { + aborted = true + } case *sdk.StreamToolCallPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() } - ch <- StreamEvent{ + if !sendEvent(ctx, ch, StreamEvent{ Type: EventToolCallStart, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Input: p.Input, + }) { + aborted = true } case *sdk.ToolProgressPart: @@ -230,12 +301,20 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv delete(toolLoopAbortCallIDs, p.ToolCallID) shouldAbort = true } - ch <- StreamEvent{ + stepNumber++ + if !sendEvent(ctx, ch, StreamEvent{ Type: EventToolCallEnd, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Input: p.Input, Result: p.Output, + }) || !sendEvent(ctx, ch, StreamEvent{ + Type: EventProgress, + StepNumber: stepNumber, + ToolName: p.ToolName, + ProgressStatus: "tool_result", + }) { + aborted = true } if shouldAbort { a.logger.Warn("tool loop abort triggered", slog.String("tool_call_id", p.ToolCallID)) @@ -243,11 +322,13 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv } case *sdk.StreamToolErrorPart: - ch <- StreamEvent{ + if !sendEvent(ctx, ch, StreamEvent{ Type: EventToolCallEnd, ToolName: p.ToolName, ToolCallID: p.ToolCallID, Error: p.Error.Error(), + }) { + aborted = true } case *sdk.StreamFilePart: @@ -255,18 +336,33 @@ func (a *Agent) 
runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv if mediaType == "" { mediaType = "image/png" } - ch <- StreamEvent{ + if !sendEvent(ctx, ch, StreamEvent{ Type: EventAttachment, Attachments: []FileAttachment{{ Type: "image", URL: fmt.Sprintf("data:%s;base64,%s", mediaType, p.File.Data), Mime: mediaType, }}, + }) { + aborted = true } case *sdk.ErrorPart: - ch <- StreamEvent{Type: EventError, Error: p.Error.Error()} - aborted = true + errMsg := p.Error.Error() + sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: errMsg}) + + // Mid-stream retry: if the error is retryable, attempt to continue + // the agent run from the accumulated state. This also handles + // errors at step 0 (e.g. timeout awaiting response headers) since + // no work has been completed yet and retrying from the start is safe. + if isRetryableStreamError(p.Error) { + streamResult, aborted = a.runMidStreamRetry( + ctx, ch, cfg, sdkTools, prepareStep, streamResult, + stepNumber, errMsg, &allText, textLoopProbeBuffer, + ) + } else { + aborted = true + } case *sdk.AbortPart: aborted = true @@ -311,8 +407,16 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv termEvent.Type = EventAgentAbort } else { termEvent.Type = EventAgentEnd + // Warn if LLM produced no text and no tool calls — likely a context overflow. 
+ if allText.Len() == 0 && stepNumber == 0 { + a.logger.Warn("agent produced empty response (no text, no tool calls)", + slog.String("bot_id", cfg.Identity.BotID), + slog.Int("input_messages", len(cfg.Messages)), + slog.Int("input_tokens", totalUsage.InputTokens), + ) + } } - ch <- termEvent + sendEvent(ctx, ch, termEvent) } func (a *Agent) runGenerate(ctx context.Context, cfg RunConfig) (*GenerateResult, error) { @@ -421,9 +525,29 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep if len(tools) > 0 && cfg.SupportsToolCall { opts = append(opts, sdk.WithTools(tools)) } - if prepareStep != nil { - opts = append(opts, sdk.WithPrepareStep(prepareStep)) + + // Wrap the existing prepareStep (if any) with mid-task context pruning. + // When the message array grows large during multi-tool runs, this prunes + // older tool results to keep the context window manageable. + basePrepare := prepareStep + keepSteps := cfg.MidTaskPruneKeepSteps + if keepSteps <= 0 { + keepSteps = MidTaskPruneKeepStepsDefault } + threshold := cfg.MidTaskPruneThreshold + if threshold <= 0 { + threshold = MidTaskPruneThresholdDefault + } + midTaskPrune := func(p *sdk.GenerateParams) *sdk.GenerateParams { + if basePrepare != nil { + if override := basePrepare(p); override != nil { + p = override + } + } + return pruneOldToolResults(p, keepSteps, threshold) + } + opts = append(opts, sdk.WithPrepareStep(midTaskPrune)) + opts = append(opts, models.BuildReasoningOptions(models.SDKModelConfig{ ClientType: models.ResolveClientType(cfg.Model), ReasoningConfig: &models.ReasoningConfig{ @@ -504,6 +628,8 @@ func toolStreamEventToAgentEvent(evt tools.ToolStreamEvent) StreamEvent { ss = append(ss, SpeechItem{Text: s.Text}) } return StreamEvent{Type: EventSpeech, Speeches: ss} + case tools.StreamEventSpawnHeartbeat: + return StreamEvent{Type: EventProgress, ProgressStatus: "spawn_running"} default: return StreamEvent{} } @@ -541,3 +667,249 @@ func wrapToolsWithLoopGuard(tools 
[]sdk.Tool, guard *ToolLoopGuard, abortCallIDs } return wrapped } + +const ( + // MidTaskPruneKeepStepsDefault is the number of recent tool-call steps to keep + // intact when pruning older tool results during a multi-step agent run. + MidTaskPruneKeepStepsDefault = 4 + // MidTaskPruneThresholdDefault is the minimum number of messages before pruning activates. + MidTaskPruneThresholdDefault = 20 +) + +// pruneOldToolResults prunes older tool result messages in the SDK params to +// keep the context window manageable during long multi-tool agent runs. It +// keeps the most recent keepSteps tool-call cycles intact and replaces older +// tool results with size summaries. +func pruneOldToolResults(p *sdk.GenerateParams, keepSteps, threshold int) *sdk.GenerateParams { + msgs := p.Messages + if len(msgs) < threshold { + return p + } + + // Count complete tool-call cycles (tool-result pair) from the end to find the cutoff. + toolResultCount := 0 + cutoffIdx := len(msgs) + for i := len(msgs) - 1; i >= 0; i-- { + if msgs[i].Role == sdk.MessageRoleTool { + // Check that the preceding assistant message contains the matching tool call + // to ensure we count complete cycles, not orphaned results. + hasMatchingCall := false + for j := i - 1; j >= 0; j-- { + if msgs[j].Role == sdk.MessageRoleAssistant { + // If there's another tool result between this and the assistant msg, + // it means this assistant message belongs to a different cycle. + if j+1 < i && msgs[j+1].Role == sdk.MessageRoleTool { + break + } + hasMatchingCall = true + break + } + if msgs[j].Role == sdk.MessageRoleUser { + break + } + } + if hasMatchingCall { + toolResultCount++ + if toolResultCount > keepSteps { + cutoffIdx = i + break + } + } + } + } + if cutoffIdx >= len(msgs) { + return p // not enough tool messages to prune + } + + // Build a new slice so the original messages can be GC'd. + pruned := make([]sdk.Message, 0, len(msgs)) + pruned = append(pruned, msgs[:cutoffIdx]...) 
+ for i := cutoffIdx; i < len(msgs); i++ { + if msgs[i].Role != sdk.MessageRoleTool { + pruned = append(pruned, msgs[i]) + continue + } + // Measure content size from ToolResultPart entries. + contentSize := 0 + for _, part := range msgs[i].Content { + if tr, ok := part.(sdk.ToolResultPart); ok { + contentSize += len(fmt.Sprintf("%v", tr.Result)) + } + } + if contentSize > 512 { // only prune if content is large enough + // Build replacement parts preserving ToolResultPart type so that + // provider serializers that validate part types per role stay happy. + replacementParts := make([]sdk.MessagePart, 0, len(msgs[i].Content)) + for _, part := range msgs[i].Content { + if tr, ok := part.(sdk.ToolResultPart); ok { + replacementParts = append(replacementParts, sdk.ToolResultPart{ + ToolCallID: tr.ToolCallID, + ToolName: tr.ToolName, + Result: fmt.Sprintf("[tool result pruned: %d bytes]", contentSize), + }) + } else { + replacementParts = append(replacementParts, part) + } + } + pruned = append(pruned, sdk.Message{ + Role: msgs[i].Role, + Content: replacementParts, + }) + } else { + pruned = append(pruned, msgs[i]) + } + } + + p.Messages = pruned + return p +} + +// runMidStreamRetry attempts to continue the agent stream after a retryable +// mid-stream error. It re-invokes StreamText with the accumulated messages +// and drains the new stream into the same output channel. 
+func (a *Agent) runMidStreamRetry( + ctx context.Context, + ch chan<- StreamEvent, + cfg RunConfig, + sdkTools []sdk.Tool, + prepareStep func(*sdk.GenerateParams) *sdk.GenerateParams, + prevResult *sdk.StreamResult, + stepNumber int, + errMsg string, + allText *strings.Builder, + textLoopProbeBuffer *TextLoopProbeBuffer, +) (*sdk.StreamResult, bool) { + retryCfg := DefaultRetryConfig() + for attempt := 0; attempt < retryCfg.MaxAttempts; attempt++ { + a.logger.Warn("mid-stream error, retrying", + slog.Int("step", stepNumber), + slog.Int("attempt", attempt+1), + slog.Int("max_attempts", retryCfg.MaxAttempts), + slog.String("error", errMsg), + ) + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventRetry, + Attempt: attempt + 1, + MaxAttempt: retryCfg.MaxAttempts, + RetryError: errMsg, + }) { + return prevResult, true + } + + delay := retryDelay(attempt, retryCfg) + if delay > 0 { + if err := sleepWithContext(ctx, delay); err != nil { + return prevResult, true // aborted + } + } + + // Re-invoke StreamText with accumulated messages. + // Use buildGenerateOptions so retry benefits from mid-task pruning, + // media resolution, and other prepare-step logic — same as initial stream. + retryCfgCopy := cfg + retryCfgCopy.Messages = prevResult.Messages + retryOpts := a.buildGenerateOptions(retryCfgCopy, sdkTools, prepareStep) + + retryResult, retryErr := a.client.StreamText(ctx, retryOpts...) + if retryErr != nil { + a.logger.Warn("mid-stream retry failed to start", + slog.Int("attempt", attempt+1), + slog.String("error", retryErr.Error()), + ) + // Update errMsg so the next retry event shows the latest error. 
+ errMsg = retryErr.Error() + continue + } + + // Drain the retry stream into the main event loop + aborted := false + for retryPart := range retryResult.Stream { + switch rp := retryPart.(type) { + case *sdk.TextStartPart: + if !sendEvent(ctx, ch, StreamEvent{Type: EventTextStart}) { + aborted = true + } + case *sdk.TextDeltaPart: + if rp.Text != "" { + if textLoopProbeBuffer != nil { + textLoopProbeBuffer.Push(rp.Text) + } + if !sendEvent(ctx, ch, StreamEvent{Type: EventTextDelta, Delta: rp.Text}) { + aborted = true + } + allText.WriteString(rp.Text) + } + case *sdk.TextEndPart: + if textLoopProbeBuffer != nil { + textLoopProbeBuffer.Flush() + } + stepNumber++ + if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) { + aborted = true + } + case *sdk.StreamToolCallPart: + if textLoopProbeBuffer != nil { + textLoopProbeBuffer.Flush() + } + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventToolCallStart, + ToolName: rp.ToolName, + ToolCallID: rp.ToolCallID, + Input: rp.Input, + }) { + aborted = true + } + case *sdk.StreamToolResultPart: + stepNumber++ + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventToolCallEnd, + ToolName: rp.ToolName, + ToolCallID: rp.ToolCallID, + Input: rp.Input, + Result: rp.Output, + }) || !sendEvent(ctx, ch, StreamEvent{ + Type: EventProgress, + StepNumber: stepNumber, + ToolName: rp.ToolName, + ProgressStatus: "tool_result", + }) { + aborted = true + } + case *sdk.StreamToolErrorPart: + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventToolCallEnd, + ToolName: rp.ToolName, + ToolCallID: rp.ToolCallID, + Error: rp.Error.Error(), + }) { + aborted = true + } + case *sdk.ErrorPart: + sendEvent(ctx, ch, StreamEvent{Type: EventError, Error: rp.Error.Error()}) + aborted = true + case *sdk.AbortPart: + aborted = true + case *sdk.FinishPart: + // handled after loop + } + if aborted { + break + } + } + return retryResult, aborted + } + // All retry attempts failed + return prevResult, true +} + +// sleepWithContext sleeps for the given duration or 
returns context error. +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/internal/agent/read_media_test.go b/internal/agent/read_media_test.go index f6074398..2c7b9878 100644 --- a/internal/agent/read_media_test.go +++ b/internal/agent/read_media_test.go @@ -175,7 +175,11 @@ func assertInjectedReadMediaMessage(t *testing.T, msg sdk.Message, expectedImage func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) { t.Parallel() - pngBytes := []byte("\x89PNG\r\n\x1a\npayload") + // The PNG data must contain a null byte (\x00) so that the execRead + // binary probe (bytes.IndexByte(probe, 0)) detects it as binary and + // delegates to ReadImageFromContainer. Real PNG files always contain + // null bytes in their IHDR and other chunks. + pngBytes := []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00payload") expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes) modelProvider := &agentReadMediaMockProvider{ @@ -289,7 +293,7 @@ func TestAgentGenerateReadMediaInjectsImageIntoNextStep(t *testing.T) { func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing.T) { t.Parallel() - pngBytes := []byte("\x89PNG\r\n\x1a\npayload") + pngBytes := []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00payload") expectedBase64 := base64.StdEncoding.EncodeToString(pngBytes) modelProvider := &agentReadMediaMockProvider{ @@ -356,7 +360,7 @@ func TestAgentGenerateReadMediaInjectsAnthropicSafeImageIntoNextStep(t *testing. 
func TestAgentStreamReadMediaPersistsInjectedImageInTerminalMessages(t *testing.T) { t.Parallel() - pngBytes := []byte("\x89PNG\r\n\x1a\npayload") + pngBytes := []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00payload") expectedDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(pngBytes) modelProvider := &agentReadMediaMockProvider{ diff --git a/internal/agent/retry.go b/internal/agent/retry.go new file mode 100644 index 00000000..e63bac41 --- /dev/null +++ b/internal/agent/retry.go @@ -0,0 +1,93 @@ +package agent + +import ( + "context" + "errors" + "math/rand/v2" + "net" + "regexp" + "strings" + "time" +) + +// RetryConfig controls retry behavior for stream failures. +type RetryConfig struct { + MaxAttempts int // total retry attempts + FastAttempts int // first N attempts with no delay + BaseDelay time.Duration // backoff base for non-fast attempts + MaxDelay time.Duration // backoff cap +} + +// err429Pattern matches HTTP 429 status codes in error strings. +// Requires a non-digit boundary to avoid matching "429" inside larger numbers. +var err429Pattern = regexp.MustCompile(`(^|[^0-9])429($|[^0-9])`) + +// errEOFPattern matches EOF or connection-level resets. +var errEOFPattern = regexp.MustCompile(`(?i)connection (reset|refused)|EOF$`) + +// serverErrPattern matches "api error 5XX" where XX is any two digits. +var serverErrPattern = regexp.MustCompile(`api error 5\d{2}`) + +// DefaultRetryConfig returns the default retry strategy: 10 attempts total, +// first 5 fast (no delay), last 5 with exponential backoff. +func DefaultRetryConfig() RetryConfig { + return RetryConfig{ + MaxAttempts: 10, + FastAttempts: 5, + BaseDelay: 1 * time.Second, + MaxDelay: 30 * time.Second, + } +} + +// isRetryableStreamError returns true for errors worth retrying. 
+func isRetryableStreamError(err error) bool { + if err == nil { + return false + } + // Context cancelled/expired — do NOT retry (check first since + // context.DeadlineExceeded also satisfies net.Error) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + // Network-level errors (connection refused, timeout, DNS) + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + // HTTP status errors: retry on 429 and 5xx + errStr := err.Error() + if err429Pattern.MatchString(errStr) { + return true + } + if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "rate_limit") { + return true + } + if serverErrPattern.MatchString(errStr) { + return true + } + // Connection reset / EOF + if errEOFPattern.MatchString(errStr) { + return true + } + return false +} + +// retryDelay returns the delay before the next retry attempt. +// For fast attempts (0-indexed < FastAttempts): no delay. +// For backoff attempts: exponential delay with jitter, capped at MaxDelay. 
+func retryDelay(attempt int, cfg RetryConfig) time.Duration { + if attempt < cfg.FastAttempts { + return 0 + } + // Exponential backoff: base * 2^(attempt - fastAttempts), capped to prevent overflow + backoffIdx := attempt - cfg.FastAttempts + if backoffIdx > 20 { + backoffIdx = 20 + } + delay := cfg.BaseDelay * time.Duration(1< 0 { payload["files"] = files } - return p.doGatewayAction(ctx, botID, contextID, payload) + return p.doGatewayAction(ctx, session, botID, contextID, payload) } func (p *BrowserProvider) execObserve(ctx context.Context, session SessionContext, args map[string]any) (any, error) { @@ -199,7 +199,7 @@ func (p *BrowserProvider) execObserve(ctx context.Context, session SessionContex if v, ok := args["full_page"].(bool); ok { payload["full_page"] = v } - return p.doGatewayAction(ctx, botID, contextID, payload) + return p.doGatewayAction(ctx, session, botID, contextID, payload) } func (p *BrowserProvider) ensureContext(ctx context.Context, botID, contextID string, bc browsercontexts.BrowserContext) error { @@ -242,7 +242,7 @@ func (p *BrowserProvider) ensureContext(ctx context.Context, botID, contextID st return nil } -func (p *BrowserProvider) doGatewayAction(ctx context.Context, botID, contextID string, payload map[string]any) (any, error) { +func (p *BrowserProvider) doGatewayAction(ctx context.Context, session SessionContext, botID, contextID string, payload map[string]any) (any, error) { body, _ := json.Marshal(payload) actionURL := fmt.Sprintf("%s/context/%s/action", p.gatewayBaseURL, contextID) req, err := http.NewRequestWithContext(ctx, http.MethodPost, actionURL, bytes.NewReader(body)) @@ -272,31 +272,33 @@ func (p *BrowserProvider) doGatewayAction(ctx context.Context, botID, contextID return nil, fmt.Errorf("%s", errMsg) } if b64, ok := gwResp.Data["screenshot"].(string); ok && b64 != "" { - return p.buildScreenshotResult(ctx, botID, b64), nil + return p.buildScreenshotResult(ctx, botID, b64, session), nil } return gwResp.Data, nil } 
const browserScreenshotDir = "/data/browser-screenshots" -func (p *BrowserProvider) buildScreenshotResult(ctx context.Context, botID, base64Data string) any { +func (p *BrowserProvider) buildScreenshotResult(ctx context.Context, botID, base64Data string, session SessionContext) any { mimeType := "image/png" imgBytes, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { return map[string]any{ "content": []map[string]any{ - {"type": "text", "text": "Screenshot captured (failed to decode for saving)"}, - {"type": "image", "data": base64Data, "mimeType": mimeType}, + {"type": "text", "text": "Screenshot captured (failed to decode image data)"}, }, } } + + // Emit the screenshot as attachment for user delivery (once). + p.emitScreenshotAttachment(session, base64Data, mimeType, int64(len(imgBytes))) + containerPath := fmt.Sprintf("%s/%d.png", browserScreenshotDir, time.Now().UnixMilli()) client, clientErr := p.containers.MCPClient(ctx, botID) if clientErr != nil { return map[string]any{ "content": []map[string]any{ - {"type": "text", "text": "Screenshot captured (container not reachable, not saved to disk)"}, - {"type": "image", "data": base64Data, "mimeType": mimeType}, + {"type": "text", "text": "Screenshot captured and sent to user (container not reachable, not saved to disk)"}, }, } } @@ -305,19 +307,35 @@ func (p *BrowserProvider) buildScreenshotResult(ctx context.Context, botID, base if writeErr := client.WriteFile(ctx, containerPath, imgBytes); writeErr != nil { return map[string]any{ "content": []map[string]any{ - {"type": "text", "text": fmt.Sprintf("Screenshot captured (failed to save: %s)", writeErr.Error())}, - {"type": "image", "data": base64Data, "mimeType": mimeType}, + {"type": "text", "text": fmt.Sprintf("Screenshot captured and sent to user (failed to save: %s)", writeErr.Error())}, }, } } return map[string]any{ "content": []map[string]any{ - {"type": "text", "text": fmt.Sprintf("Screenshot saved to %s", containerPath)}, - {"type": 
"image", "data": base64Data, "mimeType": mimeType}, + {"type": "text", "text": fmt.Sprintf("Screenshot saved to %s and sent to user", containerPath)}, }, } } +// emitScreenshotAttachment pushes the screenshot as an image attachment into the +// agent stream so it gets delivered to the user's chat. +func (*BrowserProvider) emitScreenshotAttachment(session SessionContext, base64Data, mimeType string, size int64) { + if session.Emitter == nil { + return + } + dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data) + session.Emitter(ToolStreamEvent{ + Type: StreamEventAttachment, + Attachments: []Attachment{{ + Type: "image", + URL: dataURL, + Mime: mimeType, + Size: size, + }}, + }) +} + func (p *BrowserProvider) execRemoteSession(ctx context.Context, session SessionContext, args map[string]any) (any, error) { botID := strings.TrimSpace(session.BotID) if botID == "" { diff --git a/internal/agent/tools/container.go b/internal/agent/tools/container.go index b8ef031c..b5fc76b0 100644 --- a/internal/agent/tools/container.go +++ b/internal/agent/tools/container.go @@ -1,6 +1,7 @@ package tools import ( + "bytes" "context" "errors" "fmt" @@ -8,6 +9,7 @@ import ( "log/slog" "math" "strings" + "time" sdk "github.com/memohai/twilight-ai/sdk" @@ -16,6 +18,15 @@ import ( const defaultContainerExecWorkDir = "/data" +// containerOpTimeout is the maximum time allowed for individual file +// operations (read, write, list, edit). Exec has its own timeout. +const containerOpTimeout = 30 * time.Second + +// largeFileThreshold defines the size above which file operations use +// streaming (async chunked I/O) instead of loading fully into memory. +// Files <= this threshold use the simpler synchronous gRPC calls. 
+const largeFileThreshold = 512 * 1024 // 512 KB + type ContainerProvider struct { clients bridge.Provider execWorkDir string @@ -37,7 +48,7 @@ func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([] wd := p.execWorkDir sess := session - readDesc := fmt.Sprintf("Read file content inside the bot container. Supports pagination for large files. Max %d lines / %d bytes per call.", readMaxLines, readMaxBytes) + readDesc := "Read file content inside the bot container. Reads the full file by default; use line_offset and n_lines for pagination. Files up to ~16 MB are supported." if sess.SupportsImageInput { readDesc += " Also supports reading image files (PNG, JPEG, GIF, WebP) — binary images are loaded into model context automatically." } @@ -51,7 +62,7 @@ func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([] "properties": map[string]any{ "path": map[string]any{"type": "string", "description": fmt.Sprintf("File path (relative to %s or absolute inside container)", wd)}, "line_offset": map[string]any{"type": "integer", "description": "Line number to start reading from (1-indexed). Default: 1.", "minimum": 1, "default": 1}, - "n_lines": map[string]any{"type": "integer", "description": fmt.Sprintf("Number of lines to read per call. Default: %d. Max: %d.", readMaxLines, readMaxLines), "minimum": 1, "maximum": readMaxLines, "default": readMaxLines}, + "n_lines": map[string]any{"type": "integer", "description": "Number of lines to read. Default: read entire file.", "minimum": 1}, }, "required": []string{"path"}, }, @@ -61,7 +72,7 @@ func (p *ContainerProvider) Tools(_ context.Context, session SessionContext) ([] }, { Name: "write", - Description: "Write file content inside the bot container.", + Description: "Write file content inside the bot container. Creates parent directories automatically. 
Handles files of any size.", Parameters: map[string]any{ "type": "object", "properties": map[string]any{ @@ -156,7 +167,10 @@ func (p *ContainerProvider) getClient(ctx context.Context, botID string) (*bridg } func (p *ContainerProvider) execRead(ctx context.Context, session SessionContext, args map[string]any) (any, error) { - client, err := p.getClient(ctx, session.BotID) + opCtx, opCancel := context.WithTimeout(ctx, containerOpTimeout) + defer opCancel() + + client, err := p.getClient(opCtx, session.BotID) if err != nil { return nil, err } @@ -164,46 +178,93 @@ func (p *ContainerProvider) execRead(ctx context.Context, session SessionContext if filePath == "" { return nil, errors.New("path is required") } - lineOffset := int32(1) + + lineOffset := 1 if offset, ok, err := IntArg(args, "line_offset"); err != nil { return nil, fmt.Errorf("invalid line_offset: %w", err) } else if ok { if offset < 1 { return nil, errors.New("line_offset must be >= 1") } - if offset > math.MaxInt32 { - return nil, errors.New("line_offset exceeds maximum") - } - lineOffset = int32(offset) + lineOffset = offset } - nLines := int32(readMaxLines) + nLines := 0 // 0 = read entire file if n, ok, err := IntArg(args, "n_lines"); err != nil { return nil, fmt.Errorf("invalid n_lines: %w", err) - } else if ok { - if n < 1 { - return nil, errors.New("n_lines must be >= 1") - } - if n > readMaxLines { - n = readMaxLines - } - nLines = int32(n) //nolint:gosec // bounded by readMaxLines + } else if ok && n > 0 { + nLines = n } - resp, err := client.ReadFile(ctx, filePath, lineOffset, nLines) + + // Pre-check file size to avoid loading excessively large files into + // memory. The gRPC transport is capped at 16 MB, so anything larger + // would fail anyway; reject early with a clear message. 
+ const maxReadBytes = 16 * 1024 * 1024 // 16 MB + if stat, err := client.Stat(opCtx, filePath); err == nil && stat != nil { + if stat.GetSize() > maxReadBytes { + return nil, fmt.Errorf("file is too large (%d bytes, limit %d bytes). Use exec with head/tail/sed for partial reads", stat.GetSize(), maxReadBytes) + } + } + + // Stream-read the full file content. + reader, err := client.ReadRaw(opCtx, filePath) if err != nil { return nil, err } - if resp.GetBinary() { + defer func() { _ = reader.Close() }() + + // Probe for binary content. + probe := make([]byte, 8*1024) + probeN, probeErr := reader.Read(probe) + if probeErr != nil && probeErr != io.EOF { + return nil, fmt.Errorf("read probe: %w", probeErr) + } + if bytes.IndexByte(probe[:probeN], 0) >= 0 { if !session.SupportsImageInput { return nil, errors.New("file appears to be binary. Read tool only supports text files (image reading not available for this model)") } - return ReadImageFromContainer(ctx, client, filePath, defaultReadMediaMaxBytes), nil + return ReadImageFromContainer(opCtx, client, filePath, defaultReadMediaMaxBytes), nil } - content := addLineNumbers(resp.GetContent(), lineOffset) - return map[string]any{"content": content, "total_lines": resp.GetTotalLines()}, nil + + // Read remaining content after probe. + var buf strings.Builder + buf.Write(probe[:probeN]) + if probeErr != io.EOF { + remaining, readErr := io.ReadAll(reader) + if readErr != nil { + return nil, fmt.Errorf("read file: %w", readErr) + } + buf.Write(remaining) + } + + fullContent := buf.String() + lines := strings.Split(fullContent, "\n") + totalLines := len(lines) + + // Apply line_offset and n_lines. 
+	start := lineOffset - 1 // convert to 0-based
+	if start > totalLines {
+		start = totalLines
+	}
+	end := totalLines
+	if nLines > 0 && start+nLines < end {
+		end = start + nLines
+	}
+
+	selectedLines := lines[start:end]
+	content := strings.Join(selectedLines, "\n")
+	if !strings.HasSuffix(content, "\n") && end < totalLines {
+		content += "\n"
+	}
+
+	content = addLineNumbers(content, int32(min(lineOffset, math.MaxInt32))) // clamp: line_offset is caller-supplied; unchecked int32() would wrap
+	return map[string]any{"content": content, "total_lines": totalLines}, nil
 }
 
 func (p *ContainerProvider) execWrite(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
-	client, err := p.getClient(ctx, session.BotID)
+	opCtx, opCancel := context.WithTimeout(ctx, containerOpTimeout)
+	defer opCancel()
+
+	client, err := p.getClient(opCtx, session.BotID)
 	if err != nil {
 		return nil, err
 	}
@@ -212,14 +273,28 @@ func (p *ContainerProvider) execWrite(ctx context.Context, session SessionContex
 	if filePath == "" {
 		return nil, errors.New("path is required")
 	}
-	if err := client.WriteFile(ctx, filePath, []byte(content)); err != nil {
-		return nil, err
+
+	data := []byte(content)
+	if len(data) > largeFileThreshold {
+		// Large content: use streaming WriteRaw to avoid loading everything
+		// into a single gRPC message and to allow incremental transfer.
+		if _, err := client.WriteRaw(opCtx, filePath, strings.NewReader(content)); err != nil {
+			return nil, err
+		}
+	} else {
+		// Small content: simple synchronous write.
+ if err := client.WriteFile(opCtx, filePath, data); err != nil { + return nil, err + } } return map[string]any{"ok": true}, nil } func (p *ContainerProvider) execList(ctx context.Context, session SessionContext, args map[string]any) (any, error) { - client, err := p.getClient(ctx, session.BotID) + opCtx, opCancel := context.WithTimeout(ctx, containerOpTimeout) + defer opCancel() + + client, err := p.getClient(opCtx, session.BotID) if err != nil { return nil, err } @@ -260,7 +335,7 @@ func (p *ContainerProvider) execList(ctx context.Context, session SessionContext collapseThreshold = listCollapseThreshold } - result, err := client.ListDir(ctx, dirPath, recursive, offset, limit, collapseThreshold) + result, err := client.ListDir(opCtx, dirPath, recursive, offset, limit, collapseThreshold) if err != nil { return nil, err } @@ -288,7 +363,10 @@ func (p *ContainerProvider) execList(ctx context.Context, session SessionContext } func (p *ContainerProvider) execEdit(ctx context.Context, session SessionContext, args map[string]any) (any, error) { - client, err := p.getClient(ctx, session.BotID) + opCtx, opCancel := context.WithTimeout(ctx, containerOpTimeout) + defer opCancel() + + client, err := p.getClient(opCtx, session.BotID) if err != nil { return nil, err } @@ -298,7 +376,9 @@ func (p *ContainerProvider) execEdit(ctx context.Context, session SessionContext if filePath == "" || oldText == "" { return nil, errors.New("path, old_text and new_text are required") } - reader, err := client.ReadRaw(ctx, filePath) + + // Read file content via streaming RPC. 
+ reader, err := client.ReadRaw(opCtx, filePath) if err != nil { return nil, err } @@ -307,12 +387,22 @@ func (p *ContainerProvider) execEdit(ctx context.Context, session SessionContext if err != nil { return nil, err } + updated, err := applyEdit(string(raw), filePath, oldText, newText) if err != nil { return nil, err } - if err := client.WriteFile(ctx, filePath, []byte(updated)); err != nil { - return nil, err + + updatedBytes := []byte(updated) + if len(updatedBytes) > largeFileThreshold { + // Large result: stream-write to avoid gRPC message size issues. + if _, err := client.WriteRaw(opCtx, filePath, strings.NewReader(updated)); err != nil { + return nil, err + } + } else { + if err := client.WriteFile(opCtx, filePath, updatedBytes); err != nil { + return nil, err + } } return map[string]any{"ok": true}, nil } diff --git a/internal/agent/tools/prune.go b/internal/agent/tools/prune.go index 0e74fa83..d460b0ac 100644 --- a/internal/agent/tools/prune.go +++ b/internal/agent/tools/prune.go @@ -5,18 +5,10 @@ import ( ) const ( - toolOutputHeadBytes = 4 * 1024 - toolOutputTailBytes = 1 * 1024 - toolOutputHeadLines = 150 - toolOutputTailLines = 50 - - readMaxLines = 200 - readMaxBytes = 5120 - readMaxLineLength = 1000 - readHeadBytes = 3072 - readTailBytes = 1024 - readHeadLines = 120 - readTailLines = 40 + toolOutputHeadBytes = 32 * 1024 + toolOutputTailBytes = 8 * 1024 + toolOutputHeadLines = 500 + toolOutputTailLines = 100 listMaxEntries = 200 listCollapseThreshold = 50 diff --git a/internal/agent/tools/subagent.go b/internal/agent/tools/subagent.go index cbb50257..960f90c0 100644 --- a/internal/agent/tools/subagent.go +++ b/internal/agent/tools/subagent.go @@ -6,9 +6,13 @@ import ( "errors" "fmt" "log/slog" + "net" "net/http" + "regexp" "strings" "sync" + "sync/atomic" + "time" sdk "github.com/memohai/twilight-ai/sdk" @@ -24,6 +28,7 @@ import ( // It is satisfied by *agent.Agent and avoids an import cycle. 
type SpawnAgent interface { Generate(ctx context.Context, cfg SpawnRunConfig) (*SpawnResult, error) + GenerateWithWatchdog(ctx context.Context, cfg SpawnRunConfig, touchFn func()) (*SpawnResult, error) } // SpawnRunConfig mirrors agent.RunConfig fields needed by spawn. @@ -61,6 +66,138 @@ type SpawnResult struct { Usage *sdk.Usage } +// subagentTimeout caps total execution time as a safety net per attempt. +// This prevents runaway subagent calls from blocking the parent agent forever, +// even if the watchdog keeps getting touched (e.g., tiny tokens but no convergence). +const subagentTimeout = 10 * time.Minute + +// spawnHeartbeatInterval controls how often a progress event is emitted during +// spawn execution to keep the parent stream's idle timeout from firing. +const spawnHeartbeatInterval = 30 * time.Second + +// subagentMaxRetries is the maximum number of retry attempts for a failed +// subagent task. Only transient errors (rate limits, network failures) are +// retried; fatal errors (bad config, invalid input) fail immediately. +const subagentMaxRetries = 3 + +// subagentRetryBaseDelay is the initial backoff delay between retry attempts. +const subagentRetryBaseDelay = 2 * time.Second + +// ErrWatchdogTimedOut is returned when the subagent watchdog fires +// (no activity within the timeout period). +var ErrWatchdogTimedOut = errors.New("subagent watchdog: no activity within timeout") + +// subagentWatchdogTimeout is the default inactivity timeout for the watchdog. +const subagentWatchdogTimeout = 3 * time.Minute + +var ( + // err429Pattern matches HTTP 429 status codes in error strings. + err429Pattern = regexp.MustCompile(`(^|[^0-9])429($|[^0-9])`) + // errEOFPattern matches EOF or connection-level resets. + errEOFPattern = regexp.MustCompile(`(?i)connection (reset|refused)|EOF$`) + // serverErrPattern matches "api error 5XX" where XX is any two digits. 
+	// serverErrPattern: inside a raw-string literal `\d` is already the regex
+	// digit class; `\\d` would match a literal backslash and never hit "api error 503".
+	serverErrPattern = regexp.MustCompile(`api error 5\d{2}`)
+)
+
+// SubagentWatchdog implements an activity-based timeout for subagent execution.
+// It is "touched" (fed/reset) on each activity signal from the LLM or tools.
+// If no touch occurs within the configured timeout, it fires by cancelling
+// its associated context.
+//
+// Lifecycle:
+//  1. Call NewSubagentWatchdog to create a watchdog context.
+//  2. Call Touch() on each activity signal.
+//  3. Call Stop() when the watched operation completes normally.
+//
+// The watchdog respects parent context cancellation: if the parent context
+// is cancelled, the watchdog's context is also cancelled immediately.
+type SubagentWatchdog struct {
+	timeout time.Duration
+	touchCh chan struct{}
+	cancel  context.CancelCauseFunc
+	done    chan struct{}
+	logger  *slog.Logger
+}
+
+// NewSubagentWatchdog creates a watchdog that cancels the returned context
+// after timeout of inactivity. The returned context is derived from parentCtx.
+// If parentCtx is cancelled, the watchdog context is also cancelled.
+func NewSubagentWatchdog(parentCtx context.Context, timeout time.Duration, logger *slog.Logger) (context.Context, *SubagentWatchdog) {
+	if timeout <= 0 {
+		timeout = subagentWatchdogTimeout
+	}
+	ctx, cancel := context.WithCancelCause(parentCtx)
+
+	wd := &SubagentWatchdog{
+		timeout: timeout,
+		touchCh: make(chan struct{}, 1),
+		cancel:  cancel,
+		done:    make(chan struct{}),
+		logger:  logger,
+	}
+
+	go wd.run(ctx)
+
+	return ctx, wd
+}
+
+// Touch resets the watchdog timer. It is non-blocking and safe to call
+// from any goroutine.
+func (w *SubagentWatchdog) Touch() {
+	select {
+	case w.touchCh <- struct{}{}:
+	default:
+		// Already a pending touch, no need to queue another.
+	}
+}
+
+// Stop terminates the watchdog goroutine and releases resources.
+// Call this when the watched operation completes normally.
+func (w *SubagentWatchdog) Stop() { + w.cancel(context.Canceled) + <-w.done +} + +// run is the watchdog loop. It watches for touches and fires if none arrive +// within the configured timeout. +func (w *SubagentWatchdog) run(ctx context.Context) { + defer close(w.done) + + timer := time.NewTimer(w.timeout) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + // Parent cancelled or Stop() called. + return + case <-w.touchCh: + // Activity detected; reset the timer. + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(w.timeout) + case <-timer.C: + // No activity within timeout -- fire! + w.logger.Warn("subagent watchdog fired", + slog.Duration("timeout", w.timeout), + ) + w.cancel(ErrWatchdogTimedOut) + return + } + } +} + +// maxTasksPerSpawn caps the number of tasks accepted in a single spawn call. +const maxTasksPerSpawn = 5 + +// maxSpawnCallsPerSession caps the total number of spawn tool calls within +// a single agent session to prevent subagent storms. +const maxSpawnCallsPerSession = 3 + // SpawnProvider exposes a "spawn" tool that runs one or more subagent tasks // concurrently and returns results to the parent agent. type SpawnProvider struct { @@ -118,23 +255,24 @@ func (p *SpawnProvider) Tools(_ context.Context, session SessionContext) ([]sdk. return nil, nil } sess := session + spawnCount := new(int32) return []sdk.Tool{ { Name: "spawn", - Description: "Spawn one or more subagents to work on tasks in parallel. Each task runs in its own context with file, exec, and web tools. All results are returned together.", + Description: fmt.Sprintf("Spawn one or more subagents to work on tasks in parallel. Each task runs in its own context with file, exec, and web tools. All results are returned together. 
Max %d tasks per call, max %d calls per session.", maxTasksPerSpawn, maxSpawnCallsPerSession), Parameters: map[string]any{ "type": "object", "properties": map[string]any{ "tasks": map[string]any{ "type": "array", - "description": "List of task instructions. Each string is a self-contained prompt for one subagent.", + "description": fmt.Sprintf("List of task instructions. Each string is a self-contained prompt for one subagent. Max %d tasks.", maxTasksPerSpawn), "items": map[string]any{"type": "string"}, }, }, "required": []string{"tasks"}, }, Execute: func(ctx *sdk.ToolExecContext, input any) (any, error) { - return p.execSpawn(ctx.Context, sess, inputAsMap(input)) + return p.execSpawn(ctx.Context, sess, inputAsMap(input), spawnCount) }, }, }, nil @@ -148,12 +286,24 @@ type spawnResult struct { Error string `json:"error,omitempty"` } -func (p *SpawnProvider) execSpawn(ctx context.Context, session SessionContext, args map[string]any) (any, error) { +func (p *SpawnProvider) execSpawn(ctx context.Context, session SessionContext, args map[string]any, spawnCount *int32) (any, error) { botID := strings.TrimSpace(session.BotID) if botID == "" { return nil, errors.New("bot_id is required") } + // Enforce per-session spawn call limit. + current := atomic.AddInt32(spawnCount, 1) + if current > maxSpawnCallsPerSession { + return map[string]any{ + "isError": true, + "content": []map[string]any{{ + "type": "text", + "text": fmt.Sprintf("Spawn limit reached: max %d spawn calls per session (already made %d). Consolidate your remaining work into the current agent context instead of spawning more subagents.", maxSpawnCallsPerSession, current-1), + }}, + }, nil + } + tasksRaw, ok := args["tasks"] if !ok { return nil, errors.New("tasks is required") @@ -165,8 +315,21 @@ func (p *SpawnProvider) execSpawn(ctx context.Context, session SessionContext, a if len(tasks) == 0 { return nil, errors.New("at least one task is required") } + // Cap tasks per call. 
+ if len(tasks) > maxTasksPerSpawn { + p.logger.Warn("spawn tasks capped", + slog.Int("requested", len(tasks)), + slog.Int("max", maxTasksPerSpawn), + ) + tasks = tasks[:maxTasksPerSpawn] + } - sdkModel, modelID, err := p.resolveModel(ctx, botID) + // Use a decoupled context for model resolution and subagent execution + // so that a parent stream cancellation (e.g. idle timeout) does not + // prevent the spawn from completing and returning its results. + sessionCtx := context.WithoutCancel(ctx) + + sdkModel, modelID, err := p.resolveModel(sessionCtx, botID) if err != nil { return nil, fmt.Errorf("resolve model: %w", err) } @@ -180,10 +343,17 @@ func (p *SpawnProvider) execSpawn(ctx context.Context, session SessionContext, a var wg sync.WaitGroup wg.Add(len(tasks)) + // Start a heartbeat goroutine that emits progress events into the + // parent stream at regular intervals. This keeps the stream's idle + // timeout from firing while subagents are running. + heartbeatCtx, heartbeatCancel := context.WithCancel(sessionCtx) + defer heartbeatCancel() + p.startSpawnHeartbeat(heartbeatCtx, session, len(tasks)) + for i, task := range tasks { go func(idx int, query string) { defer wg.Done() - results[idx] = p.runSubagentTask(ctx, session, sdkModel, modelID, systemPrompt, query) + results[idx] = p.runSubagentTask(sessionCtx, session, sdkModel, modelID, systemPrompt, query) }(i, task) } wg.Wait() @@ -191,6 +361,34 @@ func (p *SpawnProvider) execSpawn(ctx context.Context, session SessionContext, a return map[string]any{"results": results}, nil } +// startSpawnHeartbeat emits periodic progress events into the parent agent +// stream to prevent the idle timeout from firing while spawn tasks run. +// Each heartbeat carries a progress status so the frontend can display it. 
+func (*SpawnProvider) startSpawnHeartbeat(ctx context.Context, session SessionContext, _ int) { + emitter := session.Emitter + if emitter == nil { + return + } + go func() { + ticker := time.NewTicker(spawnHeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Emit a progress event through the agent's stream emitter. + // The agent framework converts ToolStreamEvent into the + // appropriate wire-level progress event, which resets the + // idle timeout timer in the resolver. + emitter(ToolStreamEvent{ + Type: StreamEventSpawnHeartbeat, + }) + } + } + }() +} + func (p *SpawnProvider) runSubagentTask( ctx context.Context, parentSession SessionContext, @@ -203,7 +401,7 @@ func (p *SpawnProvider) runSubagentTask( var sessionID string if p.sessionService != nil { - sess, err := p.sessionService.Create(ctx, sessionpkg.CreateInput{ + sess, err := p.sessionService.Create(context.WithoutCancel(ctx), sessionpkg.CreateInput{ BotID: parentSession.BotID, Type: sessionpkg.TypeSubagent, Title: truncateTitle(query, 100), @@ -234,22 +432,118 @@ func (p *SpawnProvider) runSubagentTask( LoopDetection: SpawnLoopConfig{Enabled: true}, } - genResult, err := p.agent.Generate(ctx, cfg) - if err != nil { - res.Error = err.Error() - return res - } - - res.Text = genResult.Text - res.Success = true - - if p.messageService != nil && sessionID != "" { - p.persistMessages(ctx, parentSession.BotID, sessionID, modelID, query, genResult) + var lastErr error + for attempt := 0; attempt <= subagentMaxRetries; attempt++ { + if attempt > 0 { + delay := subagentRetryBaseDelay * time.Duration(attempt) + p.logger.Info("subagent retry", + slog.String("session_id", sessionID), + slog.Int("attempt", attempt), + slog.Duration("delay", delay), + slog.String("error", lastErr.Error()), + ) + delayTimer := time.NewTimer(delay) + deadlineTimer := time.NewTimer(subagentTimeout) + select { + case <-delayTimer.C: + deadlineTimer.Stop() + case 
<-deadlineTimer.C: + delayTimer.Stop() + // Hard deadline: don't retry indefinitely. + res.Error = fmt.Sprintf("retry deadline exceeded (last error: %v)", lastErr) + return res + } + } + + // Create a two-layer timeout per attempt: + // 1. Safety net: wall-clock timeout (subagentTimeout) via context.WithTimeout. + // 2. Watchdog: activity-based timeout (subagentWatchdogTimeout) that fires + // when no stream events (tokens, tool output) are received. + // Use context.WithoutCancel so retries get a fresh timeout even if + // the parent stream was cancelled (e.g. by idle timeout). + safetyCtx, safetyCancel := context.WithTimeout(context.WithoutCancel(ctx), subagentTimeout) + wdCtx, wd := NewSubagentWatchdog(safetyCtx, subagentWatchdogTimeout, p.logger) + + genResult, err := p.agent.GenerateWithWatchdog(wdCtx, cfg, wd.Touch) + wd.Stop() + safetyCancel() + + if err == nil { + res.Text = genResult.Text + res.Success = true + if p.messageService != nil && sessionID != "" { + p.persistMessages(context.WithoutCancel(ctx), parentSession.BotID, sessionID, modelID, query, genResult) + } + return res + } + + lastErr = err + + // Check if the true parent context was cancelled (not watchdog, not safety timeout). + // If the parent is done, don't retry. + if ctx.Err() != nil && !errors.Is(err, ErrWatchdogTimedOut) { + res.Error = fmt.Sprintf("parent cancelled: %v", ctx.Err()) + return res + } + + // Watchdog timeouts are always retryable. 
+ if errors.Is(err, ErrWatchdogTimedOut) { + p.logger.Warn("subagent watchdog fired, will retry", + slog.String("session_id", sessionID), + slog.Int("attempt", attempt+1), + slog.Int("max_attempts", subagentMaxRetries+1), + ) + continue + } + + if !isRetryableSubagentError(err) { + res.Error = err.Error() + return res + } } + p.logger.Warn("subagent failed after all retries", + slog.String("session_id", sessionID), + slog.Int("attempts", subagentMaxRetries+1), + slog.String("error", lastErr.Error()), + ) + res.Error = fmt.Sprintf("all %d attempts failed (last: %v)", subagentMaxRetries+1, lastErr) return res } +// isRetryableSubagentError returns true for transient errors that warrant a retry. +// Fatal errors (invalid config, context cancelled by user) return false. +func isRetryableSubagentError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + // Rate limits + if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "rate_limit") { + return true + } + // HTTP 429 and 5xx + if err429Pattern.MatchString(errStr) || serverErrPattern.MatchString(errStr) { + return true + } + // Connection-level errors + if errEOFPattern.MatchString(errStr) { + return true + } + // Network timeouts + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + // Context cancellation from parent (idle timeout, etc.) IS retryable + // for subagents — they should complete their work even if the parent + // stream was interrupted. 
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + return false +} + func (p *SpawnProvider) persistMessages( ctx context.Context, botID, sessionID, modelID, query string, diff --git a/internal/agent/tools/types.go b/internal/agent/tools/types.go index 40edcd8c..c2cf2fd0 100644 --- a/internal/agent/tools/types.go +++ b/internal/agent/tools/types.go @@ -21,9 +21,10 @@ type SkillDetail struct { type StreamEventType string const ( - StreamEventAttachment StreamEventType = "attachment" - StreamEventReaction StreamEventType = "reaction" - StreamEventSpeech StreamEventType = "speech" + StreamEventAttachment StreamEventType = "attachment" + StreamEventReaction StreamEventType = "reaction" + StreamEventSpeech StreamEventType = "speech" + StreamEventSpawnHeartbeat StreamEventType = "spawn_heartbeat" ) // ToolStreamEvent is a side-effect event emitted by a tool targeting the diff --git a/internal/agent/tools/watchdog_test.go b/internal/agent/tools/watchdog_test.go new file mode 100644 index 00000000..01eaf716 --- /dev/null +++ b/internal/agent/tools/watchdog_test.go @@ -0,0 +1,530 @@ +package tools + +import ( + "context" + "errors" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" +) + +// --- Watchdog unit tests --- + +func TestWatchdogFiresAfterInactivity(t *testing.T) { + t.Parallel() + + timeout := 200 * time.Millisecond + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + defer wd.Stop() + + // Do not touch — watchdog should fire after timeout. 
+ select { + case <-ctx.Done(): + if !errors.Is(context.Cause(ctx), ErrWatchdogTimedOut) { + t.Fatalf("expected ErrWatchdogTimedOut, got: %v", context.Cause(ctx)) + } + case <-time.After(timeout + 200*time.Millisecond): + t.Fatal("watchdog did not fire within expected time") + } +} + +func TestWatchdogDoesNotFireWhenTouched(t *testing.T) { + t.Parallel() + + timeout := 150 * time.Millisecond + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + defer wd.Stop() + + // Touch repeatedly to keep the watchdog alive past the timeout. + deadline := time.After(timeout + 300*time.Millisecond) + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + +loop: + for { + select { + case <-ticker.C: + wd.Touch() + case <-deadline: + break loop + } + } + + // Context should NOT be cancelled — watchdog never fired. + if ctx.Err() != nil { + t.Fatalf("watchdog should not have fired, but context is done: %v", context.Cause(ctx)) + } +} + +func TestWatchdogFiresAfterTouchesStop(t *testing.T) { + t.Parallel() + + timeout := 200 * time.Millisecond + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + defer wd.Stop() + + // Touch a few times, then stop touching. + for i := 0; i < 3; i++ { + wd.Touch() + time.Sleep(50 * time.Millisecond) + } + + // Now wait for the watchdog to fire after we stop touching. 
+ select { + case <-ctx.Done(): + if !errors.Is(context.Cause(ctx), ErrWatchdogTimedOut) { + t.Fatalf("expected ErrWatchdogTimedOut, got: %v", context.Cause(ctx)) + } + case <-time.After(timeout + 500*time.Millisecond): + t.Fatal("watchdog did not fire after touches stopped") + } +} + +func TestWatchdogStopsCleanly(t *testing.T) { + t.Parallel() + + timeout := 5 * time.Second // long timeout, should not fire + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + + wd.Stop() + + // After Stop(), context should be cancelled (with benign Canceled cause). + if ctx.Err() == nil { + t.Fatal("context should be cancelled after Stop()") + } + if !errors.Is(context.Cause(ctx), context.Canceled) { + t.Fatalf("expected context.Canceled cause, got: %v", context.Cause(ctx)) + } +} + +func TestWatchdogRespectsParentCancellation(t *testing.T) { + t.Parallel() + + timeout := 5 * time.Second // long timeout, should not fire on its own + parentCtx, parentCancel := context.WithCancel(context.Background()) + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + defer wd.Stop() + + // Cancel the parent context. + parentCancel() + + // Watchdog context should be cancelled immediately (not after timeout). + select { + case <-ctx.Done(): + // Expected — immediate cancellation. + case <-time.After(200 * time.Millisecond): + t.Fatal("watchdog context was not cancelled when parent was cancelled") + } + + // The cause should be context.Canceled (from parent), not ErrWatchdogTimedOut. + if errors.Is(context.Cause(ctx), ErrWatchdogTimedOut) { + t.Fatal("watchdog should not have fired — parent was cancelled") + } +} + +func TestWatchdogTouchIsNonBlocking(t *testing.T) { + t.Parallel() + + timeout := 5 * time.Second + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + defer wd.Stop() + + // Rapidly call Touch many times — none should block. 
+ done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 10000; i++ { + wd.Touch() + } + }() + + select { + case <-done: + // All Touch calls completed without blocking. + case <-time.After(2 * time.Second): + t.Fatal("Touch calls blocked — should be non-blocking") + } + + // Context should not be cancelled (we've been touching). + if ctx.Err() != nil { + t.Fatalf("unexpected context cancellation: %v", context.Cause(ctx)) + } +} + +func TestWatchdogTouchFromMultipleGoroutines(t *testing.T) { + t.Parallel() + + timeout := 300 * time.Millisecond + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + defer wd.Stop() + + // Multiple goroutines touch concurrently. + var wg sync.WaitGroup + for g := 0; g < 10; g++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + wd.Touch() + } + }() + } + wg.Wait() + + // Give the watchdog a moment to process touches. + time.Sleep(50 * time.Millisecond) + + if ctx.Err() != nil { + t.Fatalf("unexpected context cancellation: %v", context.Cause(ctx)) + } +} + +func TestWatchdogDefaultTimeout(t *testing.T) { + t.Parallel() + + // Passing zero timeout should use the default (subagentWatchdogTimeout). + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, 0, slog.Default()) + defer wd.Stop() + + if wd.timeout != subagentWatchdogTimeout { + t.Fatalf("expected default timeout %v, got %v", subagentWatchdogTimeout, wd.timeout) + } + if ctx.Err() != nil { + t.Fatalf("unexpected context cancellation: %v", context.Cause(ctx)) + } +} + +func TestWatchdogTimerResetOnTouch(t *testing.T) { + t.Parallel() + + timeout := 200 * time.Millisecond + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + defer wd.Stop() + + // Wait almost until timeout, then touch to reset. 
+ time.Sleep(150 * time.Millisecond) + wd.Touch() + + // Now wait another 150ms — total 300ms > 200ms timeout, but we touched at 150ms. + // The timer should have been reset at 150ms, so it shouldn't fire until 350ms. + time.Sleep(150 * time.Millisecond) + + if ctx.Err() != nil { + t.Fatalf("watchdog fired too early — timer was not properly reset: %v", context.Cause(ctx)) + } + + // Now wait for it to actually fire (no more touches). + select { + case <-ctx.Done(): + if !errors.Is(context.Cause(ctx), ErrWatchdogTimedOut) { + t.Fatalf("expected ErrWatchdogTimedOut, got: %v", context.Cause(ctx)) + } + case <-time.After(timeout + 200*time.Millisecond): + t.Fatal("watchdog did not fire after expected time") + } +} + +func TestWatchdogFiresExactlyOnce(t *testing.T) { + t.Parallel() + + timeout := 100 * time.Millisecond + parentCtx := context.Background() + ctx, wd := NewSubagentWatchdog(parentCtx, timeout, slog.Default()) + + // Wait for watchdog to fire. + select { + case <-ctx.Done(): + case <-time.After(500 * time.Millisecond): + t.Fatal("watchdog did not fire") + } + + // Call Stop — should not panic or deadlock even after firing. + wd.Stop() + + // Verify context cause is still the original fire. + if !errors.Is(context.Cause(ctx), ErrWatchdogTimedOut) { + t.Fatalf("expected ErrWatchdogTimedOut, got: %v", context.Cause(ctx)) + } +} + +// --- Integration: watchdog + mock GenerateWithWatchdog --- + +// mockSpawnAgent implements SpawnAgent for testing. +type mockSpawnAgent struct { + // generateFunc is the function called by GenerateWithWatchdog. + // It receives the context and a touchFn. The implementation should + // call touchFn to simulate activity, or not call it to simulate a hang. + generateFunc func(ctx context.Context, cfg SpawnRunConfig, touchFn func()) (*SpawnResult, error) + + // generateCount tracks how many times GenerateWithWatchdog was called. 
+ generateCount atomic.Int32 +} + +func (m *mockSpawnAgent) Generate(_ context.Context, _ SpawnRunConfig) (*SpawnResult, error) { + _ = m // interface satisfaction only + return nil, errors.New("not implemented in mock") +} + +func (m *mockSpawnAgent) GenerateWithWatchdog(ctx context.Context, cfg SpawnRunConfig, touchFn func()) (*SpawnResult, error) { + m.generateCount.Add(1) + if m.generateFunc != nil { + return m.generateFunc(ctx, cfg, touchFn) + } + return &SpawnResult{Text: "ok"}, nil +} + +func TestWatchdogKillsStuckAgentAndRetries(t *testing.T) { + t.Parallel() + + timeout := 200 * time.Millisecond + callCount := atomic.Int32{} + + agent := &mockSpawnAgent{ + generateFunc: func(ctx context.Context, _ SpawnRunConfig, touchFn func()) (*SpawnResult, error) { + count := callCount.Add(1) + if count <= 2 { + // First 2 calls: don't touch — simulate a stuck agent. + <-ctx.Done() + return nil, context.Cause(ctx) + } + // 3rd call: touch repeatedly to simulate normal activity. + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + touchFn() + case <-ctx.Done(): + if errors.Is(context.Cause(ctx), ErrWatchdogTimedOut) { + return nil, context.Cause(ctx) + } + return &SpawnResult{Text: "completed"}, nil + } + } + }, + } + + // Simulate runSubagentTask's retry loop with short timeout. + safetyCtx, safetyCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer safetyCancel() + + var lastErr error + for attempt := 0; attempt <= 3; attempt++ { + wdCtx, wd := NewSubagentWatchdog(safetyCtx, timeout, slog.Default()) + _, err := agent.GenerateWithWatchdog(wdCtx, SpawnRunConfig{}, wd.Touch) + wd.Stop() + + if err == nil { + // Success. 
+ return + } + lastErr = err + if !errors.Is(err, ErrWatchdogTimedOut) { + t.Fatalf("unexpected error: %v", err) + } + } + t.Fatalf("agent never succeeded after retries, last error: %v", lastErr) +} + +func TestWatchdogParentCancelStopsImmediately(t *testing.T) { + t.Parallel() + + timeout := 5 * time.Second // long — should not fire + parentCtx, parentCancel := context.WithCancel(context.Background()) + + started := make(chan struct{}) + agent := &mockSpawnAgent{ + generateFunc: func(ctx context.Context, _ SpawnRunConfig, _ func()) (*SpawnResult, error) { + close(started) + <-ctx.Done() + return nil, context.Cause(ctx) + }, + } + + safetyCtx, safetyCancel := context.WithTimeout(parentCtx, 10*time.Second) + defer safetyCancel() + + wdCtx, wd := NewSubagentWatchdog(safetyCtx, timeout, slog.Default()) + + doneCh := make(chan error, 1) + go func() { + _, err := agent.GenerateWithWatchdog(wdCtx, SpawnRunConfig{}, wd.Touch) + wd.Stop() + doneCh <- err + }() + + <-started // wait for agent to start + parentCancel() + + select { + case err := <-doneCh: + if errors.Is(err, ErrWatchdogTimedOut) { + t.Fatal("should not be ErrWatchdogTimedOut — parent was cancelled") + } + // Expected: parent cancellation propagated. + case <-time.After(2 * time.Second): + t.Fatal("agent did not terminate after parent cancellation") + } +} + +func TestWatchdogKeepsActiveAgentAlive(t *testing.T) { + t.Parallel() + + timeout := 200 * time.Millisecond + agent := &mockSpawnAgent{ + generateFunc: func(ctx context.Context, _ SpawnRunConfig, touchFn func()) (*SpawnResult, error) { + // Simulate a long-running agent that keeps touching the watchdog. 
+ ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + elapsed := time.After(800 * time.Millisecond) // run longer than timeout + for { + select { + case <-ticker.C: + touchFn() + case <-elapsed: + return &SpawnResult{Text: "done"}, nil + case <-ctx.Done(): + return nil, context.Cause(ctx) + } + } + }, + } + + safetyCtx, safetyCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer safetyCancel() + + wdCtx, wd := NewSubagentWatchdog(safetyCtx, timeout, slog.Default()) + result, err := agent.GenerateWithWatchdog(wdCtx, SpawnRunConfig{}, wd.Touch) + wd.Stop() + + if err != nil { + t.Fatalf("expected success, got: %v", err) + } + if result.Text != "done" { + t.Fatalf("unexpected result text: %q", result.Text) + } +} + +func TestWatchdogSafetyNetFiresWhenAgentTouchesButNeverConverges(t *testing.T) { + t.Parallel() + + watchdogTimeout := 500 * time.Millisecond + safetyTimeout := 300 * time.Millisecond + + agent := &mockSpawnAgent{ + generateFunc: func(ctx context.Context, _ SpawnRunConfig, touchFn func()) (*SpawnResult, error) { + // Agent keeps touching but never completes — safety net should fire. + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + touchFn() + case <-ctx.Done(): + return nil, context.Cause(ctx) + } + } + }, + } + + safetyCtx, safetyCancel := context.WithTimeout(context.Background(), safetyTimeout) + defer safetyCancel() + + wdCtx, wd := NewSubagentWatchdog(safetyCtx, watchdogTimeout, slog.Default()) + _, err := agent.GenerateWithWatchdog(wdCtx, SpawnRunConfig{}, wd.Touch) + wd.Stop() + + if err == nil { + t.Fatal("expected error from safety net timeout") + } + // Safety net fires as context.DeadlineExceeded, not ErrWatchdogTimedOut. 
+ if errors.Is(err, ErrWatchdogTimedOut) { + t.Fatal("should be safety net (DeadlineExceeded), not watchdog timeout") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected DeadlineExceeded from safety net, got: %v", err) + } +} + +func TestWatchdogStoppedAfterSuccess(t *testing.T) { + t.Parallel() + + timeout := 100 * time.Millisecond + agent := &mockSpawnAgent{ + generateFunc: func(_ context.Context, _ SpawnRunConfig, _ func()) (*SpawnResult, error) { + return &SpawnResult{Text: "instant"}, nil + }, + } + + safetyCtx, safetyCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer safetyCancel() + + wdCtx, wd := NewSubagentWatchdog(safetyCtx, timeout, slog.Default()) + result, err := agent.GenerateWithWatchdog(wdCtx, SpawnRunConfig{}, wd.Touch) + wd.Stop() + + if err != nil { + t.Fatalf("expected success, got: %v", err) + } + if result.Text != "instant" { + t.Fatalf("unexpected text: %q", result.Text) + } + + // Verify context was cancelled by Stop() (benign Canceled), not by watchdog. + if !errors.Is(context.Cause(wdCtx), context.Canceled) { + t.Fatalf("expected Canceled cause from Stop(), got: %v", context.Cause(wdCtx)) + } +} + +func TestWatchdogRetryBudget(t *testing.T) { + t.Parallel() + + // Verify that after exhausting retries, the task fails with the correct error. 
+ timeout := 100 * time.Millisecond + maxRetries := 3 + + callCount := atomic.Int32{} + agent := &mockSpawnAgent{ + generateFunc: func(ctx context.Context, _ SpawnRunConfig, _ func()) (*SpawnResult, error) { + callCount.Add(1) + <-ctx.Done() // always stuck + return nil, context.Cause(ctx) + }, + } + + safetyCtx, safetyCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer safetyCancel() + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + wdCtx, wd := NewSubagentWatchdog(safetyCtx, timeout, slog.Default()) + _, err := agent.GenerateWithWatchdog(wdCtx, SpawnRunConfig{}, wd.Touch) + wd.Stop() + if err == nil { + t.Fatal("should not succeed") + } + lastErr = err + if !errors.Is(err, ErrWatchdogTimedOut) { + t.Fatalf("attempt %d: expected ErrWatchdogTimedOut, got: %v", attempt, err) + } + } + + calls := callCount.Load() + if calls != int32(maxRetries+1) { + t.Fatalf("expected %d calls (maxRetries+1), got %d", maxRetries+1, calls) + } + if lastErr == nil || !errors.Is(lastErr, ErrWatchdogTimedOut) { + t.Fatalf("last error should be ErrWatchdogTimedOut, got: %v", lastErr) + } +} diff --git a/internal/agent/types.go b/internal/agent/types.go index 0a034380..f4dc7fed 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -70,6 +70,18 @@ type RunConfig struct { Identity SessionContext Skills []SkillEntry LoopDetection LoopDetectionConfig + Retry RetryConfig + + // MidTaskPruneThreshold is the minimum number of messages before mid-task + // pruning kicks in. When the accumulated message count reaches this + // threshold, older tool-result pairs are pruned to keep the context + // within budget. Defaults to MidTaskPruneThresholdDefault (20). + MidTaskPruneThreshold int + + // MidTaskPruneKeepSteps is the number of recent tool-call cycles to + // preserve when mid-task pruning is triggered. Defaults to + // MidTaskPruneKeepStepsDefault (4). 
+ MidTaskPruneKeepSteps int // InjectCh receives user messages to inject between tool rounds. // When non-nil, a PrepareStep hook drains this channel and appends diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go index f59c3ffd..02e36ad3 100644 --- a/internal/auth/jwt_test.go +++ b/internal/auth/jwt_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "errors" "net/http" "net/http/httptest" @@ -15,7 +16,7 @@ import ( func TestRefreshTokenFromContext(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", nil) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -80,7 +81,7 @@ func TestRefreshTokenFromContext(t *testing.T) { func TestRefreshTokenFromContext_MissingUser(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", nil) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) diff --git a/internal/channel/adapters/feishu/webhook_handler_test.go b/internal/channel/adapters/feishu/webhook_handler_test.go index 3a85730f..837b2a09 100644 --- a/internal/channel/adapters/feishu/webhook_handler_test.go +++ b/internal/channel/adapters/feishu/webhook_handler_test.go @@ -73,7 +73,7 @@ func TestHandleWebhook_URLVerification(t *testing.T) { cfg := newWebhookConfig(tc.credentials) manager := &fakeWebhookManager{} - req := httptest.NewRequest(http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(tc.body)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(tc.body)) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() adapter := NewFeishuAdapter(nil) @@ -128,7 +128,7 @@ func TestHandleWebhook_URLVerificationWithEncryptKeyWithoutVerificationToken(t * 
t.Fatalf("failed to encrypt challenge payload: %v", err) } - req := httptest.NewRequest(http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(`{"encrypt":"`+encrypt+`"}`)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(`{"encrypt":"`+encrypt+`"}`)) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() adapter := NewFeishuAdapter(nil) @@ -156,7 +156,7 @@ func TestHandleWebhook_Probe(t *testing.T) { "verification_token": "verify-token", "inbound_mode": "webhook", }) - req := httptest.NewRequest(http.MethodGet, "/channels/feishu/webhook/"+testWebhookConfigID, nil) + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/channels/feishu/webhook/"+testWebhookConfigID, nil) rec := httptest.NewRecorder() adapter := NewFeishuAdapter(nil) @@ -183,7 +183,7 @@ func TestHandleWebhook_EventCallbackDispatchesInbound(t *testing.T) { cfg.SelfIdentity = map[string]any{"open_id": "ou_bot_1"} manager := &fakeWebhookManager{} body := `{"schema":"2.0","header":{"event_id":"evt_1","event_type":"im.message.receive_v1","token":"verify-token"},"event":{"sender":{"sender_id":{"open_id":"ou_user_1","user_id":"u_user_1"}},"message":{"message_id":"om_1","chat_id":"oc_1","chat_type":"p2p","message_type":"text","content":"{\"text\":\"hello\"}"}},"type":"event_callback"}` - req := httptest.NewRequest(http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(body)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(body)) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() adapter := NewFeishuAdapter(nil) @@ -218,7 +218,7 @@ func TestHandleWebhook_EventCallbackUsesExternalIdentityForMentionFilter(t *test cfg.ExternalIdentity = "open_id:ou_bot_1" 
manager := &fakeWebhookManager{} body := `{"schema":"2.0","header":{"event_id":"evt_2","event_type":"im.message.receive_v1","token":"verify-token"},"event":{"sender":{"sender_id":{"open_id":"ou_user_2","user_id":"u_user_2"}},"message":{"message_id":"om_2","chat_id":"oc_group_1","chat_type":"group","message_type":"text","content":"{\"text\":\" hello\"}"}},"type":"event_callback"}` - req := httptest.NewRequest(http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(body)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(body)) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() adapter := NewFeishuAdapter(nil) @@ -282,7 +282,7 @@ func TestHandleWebhook_EventCallbackRejectsInvalidTokenWhenEncryptKeyMissing(t * cfg := newWebhookConfig(tc.credentials) manager := &fakeWebhookManager{} - req := httptest.NewRequest(http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(tc.body)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(tc.body)) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() adapter := NewFeishuAdapter(nil) @@ -315,7 +315,7 @@ func TestHandleWebhook_RejectsOversizedBody(t *testing.T) { "inbound_mode": "webhook", }) manager := &fakeWebhookManager{} - req := httptest.NewRequest(http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(strings.Repeat("x", int(webhookMaxBodyBytes)+1))) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/feishu/webhook/"+testWebhookConfigID, strings.NewReader(strings.Repeat("x", int(webhookMaxBodyBytes)+1))) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() adapter := NewFeishuAdapter(nil) diff 
--git a/internal/channel/adapters/local/descriptor.go b/internal/channel/adapters/local/descriptor.go index 50fd6fac..fe262a71 100644 --- a/internal/channel/adapters/local/descriptor.go +++ b/internal/channel/adapters/local/descriptor.go @@ -1,9 +1,11 @@ -// Package local implements the local channel adapter for WebUI and API access. +// Package local implements the CLI and Web channel adapters for local development. package local import "github.com/memohai/memoh/internal/channel" const ( - // WebType is the registered ChannelType for the local adapter (WebUI / API). - WebType channel.ChannelType = "local" + // CLIType is the registered ChannelType for the CLI adapter. + CLIType channel.ChannelType = "cli" + // WebType is the registered ChannelType for the Web adapter. + WebType channel.ChannelType = "web" ) diff --git a/internal/channel/adapters/misskey/client.go b/internal/channel/adapters/misskey/client.go new file mode 100644 index 00000000..ede31f46 --- /dev/null +++ b/internal/channel/adapters/misskey/client.go @@ -0,0 +1,131 @@ +package misskey + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +var httpClient = &http.Client{Timeout: 30 * time.Second} + +// apiRequest sends a POST request to the Misskey API endpoint. 
+func apiRequest(ctx context.Context, cfg Config, endpoint string, payload any) (json.RawMessage, error) { + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("misskey api marshal: %w", err) + } + url := cfg.apiURL() + "/" + endpoint + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("misskey api request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) //nolint:gosec // G704: URL is user-configured, validated at config level + if err != nil { + return nil, fmt.Errorf("misskey api do: %w", err) + } + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + }() + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("misskey api read: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("misskey api %s: status %d: %s", endpoint, resp.StatusCode, string(respBody)) + } + return json.RawMessage(respBody), nil +} + +// --- API payloads --- + +// createNoteRequest is the request body for notes/create. +type createNoteRequest struct { + I string `json:"i"` + Text string `json:"text,omitempty"` + Visibility string `json:"visibility,omitempty"` + ReplyID string `json:"replyId,omitempty"` + CW string `json:"cw,omitempty"` + FileIDs []string `json:"fileIds,omitempty"` +} + +// createNoteResponse is the response from notes/create. +type createNoteResponse struct { + CreatedNote struct { + ID string `json:"id"` + Text string `json:"text"` + User struct { + ID string `json:"id"` + Username string `json:"username"` + Name string `json:"name"` + } `json:"user"` + } `json:"createdNote"` +} + +// meResponse is the response from i (self user info). 
+type meResponse struct { + ID string `json:"id"` + Username string `json:"username"` + Name string `json:"name"` + AvatarURL string `json:"avatarUrl"` +} + +// createNote creates a note on the Misskey instance. +func createNote(ctx context.Context, cfg Config, text, replyID, visibility string) (*createNoteResponse, error) { + if visibility == "" { + visibility = "public" + } + req := createNoteRequest{ + I: cfg.AccessToken, + Text: text, + Visibility: visibility, + ReplyID: replyID, + } + raw, err := apiRequest(ctx, cfg, "notes/create", req) + if err != nil { + return nil, err + } + var resp createNoteResponse + if err := json.Unmarshal(raw, &resp); err != nil { + return nil, fmt.Errorf("misskey notes/create unmarshal: %w", err) + } + return &resp, nil +} + +// getMe retrieves the authenticated user's info. +func getMe(ctx context.Context, cfg Config) (*meResponse, error) { + raw, err := apiRequest(ctx, cfg, "i", map[string]string{"i": cfg.AccessToken}) + if err != nil { + return nil, err + } + var resp meResponse + if err := json.Unmarshal(raw, &resp); err != nil { + return nil, fmt.Errorf("misskey i unmarshal: %w", err) + } + return &resp, nil +} + +// createReaction adds an emoji reaction to a note. +func createReaction(ctx context.Context, cfg Config, noteID, reaction string) error { + _, err := apiRequest(ctx, cfg, "notes/reactions/create", map[string]string{ + "i": cfg.AccessToken, + "noteId": noteID, + "reaction": reaction, + }) + return err +} + +// deleteReaction removes a reaction from a note. 
+func deleteReaction(ctx context.Context, cfg Config, noteID string) error { + _, err := apiRequest(ctx, cfg, "notes/reactions/delete", map[string]string{ + "i": cfg.AccessToken, + "noteId": noteID, + }) + return err +} diff --git a/internal/channel/adapters/misskey/config.go b/internal/channel/adapters/misskey/config.go new file mode 100644 index 00000000..cfae7cd9 --- /dev/null +++ b/internal/channel/adapters/misskey/config.go @@ -0,0 +1,137 @@ +package misskey + +import ( + "errors" + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +// Config holds the Misskey instance credentials extracted from a channel configuration. +type Config struct { + InstanceURL string // Misskey instance URL (e.g. https://misskey.io) + AccessToken string `json:"AccessToken"` //nolint:gosec // G117: token field, handled securely +} + +// apiURL returns the base API URL with trailing slashes removed. +func (c Config) apiURL() string { + return strings.TrimRight(c.InstanceURL, "/") + "/api" +} + +// streamURL returns the WebSocket streaming URL. +func (c Config) streamURL() string { + base := strings.TrimRight(c.InstanceURL, "/") + // Replace http(s) with ws(s) + base = strings.Replace(base, "https://", "wss://", 1) + base = strings.Replace(base, "http://", "ws://", 1) + return base + "/streaming?i=" + c.AccessToken +} + +// UserConfig holds the identifiers used to target a Misskey user. 
+type UserConfig struct { + Username string + UserID string +} + +func normalizeConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseConfig(raw) + if err != nil { + return nil, err + } + out := map[string]any{ + "instanceURL": cfg.InstanceURL, + "accessToken": cfg.AccessToken, + } + return out, nil +} + +func normalizeUserConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return nil, err + } + result := map[string]any{} + if cfg.Username != "" { + result["username"] = cfg.Username + } + if cfg.UserID != "" { + result["user_id"] = cfg.UserID + } + return result, nil +} + +func resolveTarget(raw map[string]any) (string, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return "", err + } + if cfg.UserID != "" { + return cfg.UserID, nil + } + if cfg.Username != "" { + return "@" + cfg.Username, nil + } + return "", errors.New("misskey binding is incomplete") +} + +func normalizeTarget(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + // Strip common prefixes + value = strings.TrimPrefix(value, "misskey:") + value = strings.TrimSpace(value) + return value +} + +func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { + cfg, err := parseUserConfig(raw) + if err != nil { + return false + } + if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) { + return true + } + if criteria.SubjectID != "" { + if criteria.SubjectID == cfg.UserID || strings.EqualFold(criteria.SubjectID, cfg.Username) { + return true + } + } + return false +} + +func buildUserConfig(identity channel.Identity) map[string]any { + result := map[string]any{} + if value := strings.TrimSpace(identity.Attribute("username")); value != "" { + result["username"] = value + } + if 
value := strings.TrimSpace(identity.Attribute("user_id")); value != "" { + result["user_id"] = value + } + return result +} + +func parseConfig(raw map[string]any) (Config, error) { + instanceURL := strings.TrimSpace(channel.ReadString(raw, "instanceURL", "instance_url")) + if instanceURL == "" { + return Config{}, errors.New("misskey instanceURL is required") + } + token := strings.TrimSpace(channel.ReadString(raw, "accessToken", "access_token")) + if token == "" { + return Config{}, errors.New("misskey accessToken is required") + } + return Config{InstanceURL: instanceURL, AccessToken: token}, nil +} + +func parseUserConfig(raw map[string]any) (UserConfig, error) { + username := strings.TrimSpace(channel.ReadString(raw, "username")) + userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) + if username == "" && userID == "" { + return UserConfig{}, errors.New("misskey user config requires username or user_id") + } + return UserConfig{Username: username, UserID: userID}, nil +} diff --git a/internal/channel/adapters/misskey/descriptor.go b/internal/channel/adapters/misskey/descriptor.go new file mode 100644 index 00000000..71388840 --- /dev/null +++ b/internal/channel/adapters/misskey/descriptor.go @@ -0,0 +1,7 @@ +// Package misskey implements the Misskey channel adapter. +package misskey + +import "github.com/memohai/memoh/internal/channel" + +// Type is the registered ChannelType identifier for Misskey. 
+const Type channel.ChannelType = "misskey" diff --git a/internal/channel/adapters/misskey/misskey.go b/internal/channel/adapters/misskey/misskey.go new file mode 100644 index 00000000..642ca487 --- /dev/null +++ b/internal/channel/adapters/misskey/misskey.go @@ -0,0 +1,643 @@ +package misskey + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channel/common" + "github.com/memohai/memoh/internal/textutil" +) + +const ( + misskeyMaxNoteLength = 3000 + misskeyReconnectDelay = 5 * time.Second + misskeyPingInterval = 30 * time.Second + misskeyWriteTimeout = 10 * time.Second + misskeyReadBufferSize = 1 << 16 + misskeyWriteBufferSize = 1 << 16 +) + +// MisskeyAdapter implements the channel.Adapter interfaces for Misskey. +type MisskeyAdapter struct { + logger *slog.Logger + mu sync.RWMutex + me map[string]*meResponse // keyed by config ID +} + +// NewMisskeyAdapter creates a MisskeyAdapter with the given logger. +func NewMisskeyAdapter(log *slog.Logger) *MisskeyAdapter { + if log == nil { + log = slog.Default() + } + return &MisskeyAdapter{ + logger: log.With(slog.String("adapter", "misskey")), + me: make(map[string]*meResponse), + } +} + +// Type returns the Misskey channel type. +func (*MisskeyAdapter) Type() channel.ChannelType { + return Type +} + +// Descriptor returns the Misskey channel metadata. 
+func (*MisskeyAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: Type, + DisplayName: "Misskey", + Capabilities: channel.ChannelCapabilities{ + Text: true, + Markdown: true, + Reply: true, + Reactions: true, + Attachments: false, + Media: false, + Streaming: false, + BlockStreaming: true, + Edit: false, + }, + OutboundPolicy: channel.OutboundPolicy{ + TextChunkLimit: misskeyMaxNoteLength, + ChunkerMode: channel.ChunkerModeMarkdown, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "instanceURL": { + Type: channel.FieldString, + Required: true, + Title: "Instance URL", + Description: "Misskey instance URL (e.g. https://misskey.io)", + Example: "https://misskey.io", + }, + "accessToken": { + Type: channel.FieldSecret, + Required: true, + Title: "Access Token", + }, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "username": {Type: channel.FieldString}, + "user_id": {Type: channel.FieldString}, + }, + }, + TargetSpec: channel.TargetSpec{ + Format: "user_id | @username", + Hints: []channel.TargetHint{ + {Label: "User ID", Example: "9abcdef123456789"}, + {Label: "Username", Example: "@alice"}, + }, + }, + } +} + +// --- ConfigNormalizer --- + +// NormalizeConfig validates and normalizes a Misskey channel configuration map. +func (*MisskeyAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { + return normalizeConfig(raw) +} + +// NormalizeUserConfig validates and normalizes a Misskey user-binding configuration map. +func (*MisskeyAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + return normalizeUserConfig(raw) +} + +// --- TargetResolver --- + +// NormalizeTarget normalizes a Misskey delivery target string. +func (*MisskeyAdapter) NormalizeTarget(raw string) string { + return normalizeTarget(raw) +} + +// ResolveTarget derives a delivery target from a Misskey user-binding configuration. 
+func (*MisskeyAdapter) ResolveTarget(userConfig map[string]any) (string, error) { + return resolveTarget(userConfig) +} + +// --- BindingMatcher --- + +// MatchBinding reports whether a Misskey user binding matches the given criteria. +func (*MisskeyAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { + return matchBinding(config, criteria) +} + +// BuildUserConfig constructs a Misskey user-binding config from an Identity. +func (*MisskeyAdapter) BuildUserConfig(identity channel.Identity) map[string]any { + return buildUserConfig(identity) +} + +// --- SelfDiscoverer --- + +// DiscoverSelf retrieves the bot's own identity from the Misskey platform. +func (*MisskeyAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { + cfg, err := parseConfig(credentials) + if err != nil { + return nil, "", err + } + me, err := getMe(ctx, cfg) + if err != nil { + return nil, "", fmt.Errorf("misskey discover self: %w", err) + } + identity := map[string]any{ + "user_id": me.ID, + "username": me.Username, + } + if me.Name != "" { + identity["name"] = me.Name + } + if me.AvatarURL != "" { + identity["avatar_url"] = me.AvatarURL + } + return identity, me.ID, nil +} + +// --- Receiver --- + +// Connect starts a WebSocket streaming connection to receive Misskey mentions. +func (a *MisskeyAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { + if a.logger != nil { + a.logger.Info("start", slog.String("config_id", cfg.ID)) + } + mkCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + channel.SetIMErrorSecrets("misskey:"+cfg.ID, mkCfg.AccessToken) + + // Fetch self info for mention detection. 
+ me, err := getMe(ctx, mkCfg) + if err != nil { + return nil, fmt.Errorf("misskey get self: %w", err) + } + a.mu.Lock() + a.me[cfg.ID] = me + a.mu.Unlock() + + connCtx, cancel := context.WithCancel(ctx) + go a.runStreamLoop(connCtx, cfg, mkCfg, me, handler) + + stop := func(_ context.Context) error { + if a.logger != nil { + a.logger.Info("stop", slog.String("config_id", cfg.ID)) + } + cancel() + return nil + } + return channel.NewConnection(cfg, stop), nil +} + +func (a *MisskeyAdapter) runStreamLoop(ctx context.Context, cfg channel.ChannelConfig, mkCfg Config, me *meResponse, handler channel.InboundHandler) { + for { + select { + case <-ctx.Done(): + return + default: + } + if err := a.runStream(ctx, cfg, mkCfg, me, handler); err != nil { + if a.logger != nil { + a.logger.Warn("stream disconnected", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + } + select { + case <-ctx.Done(): + return + case <-time.After(misskeyReconnectDelay): + } + } +} + +func (a *MisskeyAdapter) runStream(ctx context.Context, cfg channel.ChannelConfig, mkCfg Config, me *meResponse, handler channel.InboundHandler) error { + dialer := websocket.Dialer{ + ReadBufferSize: misskeyReadBufferSize, + WriteBufferSize: misskeyWriteBufferSize, + } + conn, resp, err := dialer.DialContext(ctx, mkCfg.streamURL(), nil) //nolint:bodyclose // resp.Body is closed below + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + return fmt.Errorf("misskey ws dial: %w", err) + } + defer func() { _ = conn.Close() }() + + if a.logger != nil { + a.logger.Info("stream connected", slog.String("config_id", cfg.ID)) + } + + // Subscribe to main channel to receive mentions. + connectMsg := map[string]any{ + "type": "connect", + "body": map[string]any{ + "channel": "main", + "id": "memoh-main", + }, + } + if err := conn.WriteJSON(connectMsg); err != nil { + return fmt.Errorf("misskey ws connect main: %w", err) + } + + // Start ping ticker. 
+ pingDone := make(chan struct{}) + go func() { + ticker := time.NewTicker(misskeyPingInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-pingDone: + return + case <-ticker.C: + _ = conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(misskeyWriteTimeout)) + } + } + }() + defer close(pingDone) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + _, msgBytes, readErr := conn.ReadMessage() + if readErr != nil { + return fmt.Errorf("misskey ws read: %w", readErr) + } + a.handleStreamMessage(ctx, cfg, me, handler, msgBytes) + } +} + +// streamMessage represents a message received from the Misskey streaming API. +type streamMessage struct { + Type string `json:"type"` + Body json.RawMessage `json:"body"` +} + +// streamChannelBody is the body of a channel event. +type streamChannelBody struct { + ID string `json:"id"` + Type string `json:"type"` + Body json.RawMessage `json:"body"` +} + +// misskeyNote represents a Misskey note (post). 
+type misskeyNote struct { + ID string `json:"id"` + Text string `json:"text"` + CW string `json:"cw"` + UserID string `json:"userId"` + User misskeyUser `json:"user"` + ReplyID string `json:"replyId"` + RenoteID string `json:"renoteId"` + CreatedAt string `json:"createdAt"` + Mentions []string `json:"mentions"` + Visibility string `json:"visibility"` + Reply *misskeyNote `json:"reply"` +} + +type misskeyUser struct { + ID string `json:"id"` + Username string `json:"username"` + Name string `json:"name"` + Host string `json:"host"` + AvatarURL string `json:"avatarUrl"` +} + +func (a *MisskeyAdapter) handleStreamMessage(ctx context.Context, cfg channel.ChannelConfig, me *meResponse, handler channel.InboundHandler, raw []byte) { + var msg streamMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return + } + + if msg.Type == "channel" { + var body streamChannelBody + if err := json.Unmarshal(msg.Body, &body); err != nil { + return + } + a.handleChannelEvent(ctx, cfg, me, handler, body) + } +} + +func (a *MisskeyAdapter) handleChannelEvent(ctx context.Context, cfg channel.ChannelConfig, me *meResponse, handler channel.InboundHandler, body streamChannelBody) { + switch body.Type { + case "mention", "reply": + var note misskeyNote + if err := json.Unmarshal(body.Body, ¬e); err != nil { + if a.logger != nil { + a.logger.Warn("parse note failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + return + } + // Skip notes from self. + if note.UserID == me.ID { + return + } + inbound, ok := a.buildInboundMessage(me, note) + if !ok { + return + } + a.logInbound(cfg.ID, inbound) + go func() { + if err := handler(ctx, cfg, inbound); err != nil && a.logger != nil { + a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + }() + + case "notification": + // Handle notification-based mentions. 
+ var notif struct { + Type string `json:"type"` + Note misskeyNote `json:"note"` + } + if err := json.Unmarshal(body.Body, ¬if); err != nil { + return + } + if notif.Type != "mention" && notif.Type != "reply" { + return + } + if notif.Note.UserID == me.ID { + return + } + inbound, ok := a.buildInboundMessage(me, notif.Note) + if !ok { + return + } + a.logInbound(cfg.ID, inbound) + go func() { + if err := handler(ctx, cfg, inbound); err != nil && a.logger != nil { + a.logger.Error("handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + }() + } +} + +func (*MisskeyAdapter) buildInboundMessage(me *meResponse, note misskeyNote) (channel.InboundMessage, bool) { + text := strings.TrimSpace(note.Text) + if text == "" { + return channel.InboundMessage{}, false + } + + // Strip the bot mention from the text. + if me != nil { + mention := "@" + me.Username + text = strings.TrimSpace(strings.Replace(text, mention, "", 1)) + } + if text == "" { + return channel.InboundMessage{}, false + } + + // Build quoted context for replies. + if note.Reply != nil && note.Reply.Text != "" { + quotedText := strings.TrimSpace(note.Reply.Text) + if len([]rune(quotedText)) > 200 { + quotedText = string([]rune(quotedText)[:200]) + "..." + } + senderName := note.Reply.User.Name + if senderName == "" { + senderName = note.Reply.User.Username + } + if senderName != "" { + text = fmt.Sprintf("[Reply to %s: %s]\n%s", senderName, quotedText, text) + } + } + + senderID := note.UserID + displayName := note.User.Name + if displayName == "" { + displayName = note.User.Username + } + attrs := map[string]string{ + "user_id": note.UserID, + "username": note.User.Username, + } + if note.User.Host != "" { + attrs["host"] = note.User.Host + } + + // Direct messages use "specified" visibility; others are group conversations. 
+ convType := channel.ConversationTypeGroup + if note.Visibility == "specified" { + convType = channel.ConversationTypePrivate + } + + var replyRef *channel.ReplyRef + if note.ReplyID != "" { + replyRef = &channel.ReplyRef{ + MessageID: note.ReplyID, + } + } + + receivedAt := time.Now().UTC() + if note.CreatedAt != "" { + if t, err := time.Parse(time.RFC3339, note.CreatedAt); err == nil { + receivedAt = t + } + } + + isMentioned := false + if me != nil { + for _, mid := range note.Mentions { + if mid == me.ID { + isMentioned = true + break + } + } + } + + return channel.InboundMessage{ + Channel: Type, + Message: channel.Message{ + ID: note.ID, + Format: channel.MessageFormatPlain, + Text: text, + Reply: replyRef, + }, + ReplyTarget: note.ID, + Sender: channel.Identity{ + SubjectID: senderID, + DisplayName: displayName, + Attributes: attrs, + }, + Conversation: channel.Conversation{ + ID: note.UserID, + Type: convType, + }, + ReceivedAt: receivedAt, + Source: "misskey", + Metadata: map[string]any{ + "is_mentioned": isMentioned, + "visibility": note.Visibility, + "note_id": note.ID, + }, + }, true +} + +func (a *MisskeyAdapter) logInbound(configID string, msg channel.InboundMessage) { + if a.logger == nil { + return + } + a.logger.Info("inbound received", + slog.String("config_id", configID), + slog.String("user_id", msg.Sender.Attribute("user_id")), + slog.String("username", msg.Sender.Attribute("username")), + slog.String("text", common.SummarizeText(msg.Message.Text)), + ) +} + +// --- Sender --- + +// Send delivers an outbound message to Misskey by creating a note. 
+func (a *MisskeyAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { + mkCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + text := strings.TrimSpace(msg.Message.PlainText()) + if text == "" { + return errors.New("message text is required") + } + text = textutil.TruncateRunesWithSuffix(text, misskeyMaxNoteLength, "...") + + // The target in Misskey is the note ID to reply to. + replyID := strings.TrimSpace(msg.Target) + + // Determine visibility: reply with "home" visibility. + visibility := "home" + + _, err = createNote(ctx, mkCfg, text, replyID, visibility) + if err != nil { + if a.logger != nil { + a.logger.Error("send note failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + return err + } + return nil +} + +// --- Reactor --- + +// React adds an emoji reaction to a message (implements channel.Reactor). +func (*MisskeyAdapter) React(ctx context.Context, cfg channel.ChannelConfig, _ string, messageID string, emoji string) error { + mkCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + // Misskey reactions use format like ":emoji:" or unicode emoji. + if !strings.HasPrefix(emoji, ":") { + emoji = ":" + emoji + ":" + } + return createReaction(ctx, mkCfg, messageID, emoji) +} + +// Unreact removes the bot's reaction from a message (implements channel.Reactor). +func (*MisskeyAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, _ string, messageID string, _ string) error { + mkCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + return deleteReaction(ctx, mkCfg, messageID) +} + +// --- ProcessingStatusNotifier --- + +// ProcessingStarted is a no-op for Misskey (no typing indicator API). 
+func (*MisskeyAdapter) ProcessingStarted(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, _ channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + return channel.ProcessingStatusHandle{}, nil +} + +// ProcessingCompleted is a no-op for Misskey. +func (*MisskeyAdapter) ProcessingCompleted(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle) error { + return nil +} + +// ProcessingFailed is a no-op for Misskey. +func (*MisskeyAdapter) ProcessingFailed(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle, _ error) error { + return nil +} + +// --- StreamSender (block-streaming: buffer deltas, send final as one message) --- + +// OpenStream opens a block-streaming session that buffers all deltas and sends +// the final message as a single note when the stream is closed. +func (a *MisskeyAdapter) OpenStream(_ context.Context, cfg channel.ChannelConfig, target string, _ channel.StreamOptions) (channel.OutboundStream, error) { + target = strings.TrimSpace(target) + if target == "" { + return nil, errors.New("misskey target is required") + } + return &misskeyBlockStream{ + adapter: a, + cfg: cfg, + target: target, + }, nil +} + +// misskeyBlockStream buffers streaming deltas and sends the final message as +// one Send call when the stream is closed. 
+type misskeyBlockStream struct { + adapter *MisskeyAdapter + cfg channel.ChannelConfig + target string + textBuilder strings.Builder + attachments []channel.Attachment + final *channel.Message + closed bool +} + +func (s *misskeyBlockStream) Push(_ context.Context, event channel.StreamEvent) error { + if s.closed { + return nil + } + switch event.Type { + case channel.StreamEventDelta: + if strings.TrimSpace(event.Delta) != "" && event.Phase != channel.StreamPhaseReasoning { + s.textBuilder.WriteString(event.Delta) + } + case channel.StreamEventAttachment: + s.attachments = append(s.attachments, event.Attachments...) + case channel.StreamEventFinal: + if event.Final != nil { + msg := event.Final.Message + s.final = &msg + } + } + return nil +} + +func (s *misskeyBlockStream) Close(ctx context.Context) error { + if s.closed { + return nil + } + s.closed = true + + msg := channel.Message{Format: channel.MessageFormatPlain} + if s.final != nil { + msg = *s.final + } + if strings.TrimSpace(msg.Text) == "" { + msg.Text = strings.TrimSpace(s.textBuilder.String()) + } + if len(msg.Attachments) == 0 && len(s.attachments) > 0 { + msg.Attachments = append(msg.Attachments, s.attachments...) 
+ } + if msg.IsEmpty() { + return nil + } + return s.adapter.Send(ctx, s.cfg, channel.OutboundMessage{ + Target: s.target, + Message: msg, + }) +} diff --git a/internal/channel/adapters/telegram/stream.go b/internal/channel/adapters/telegram/stream.go index 6ca70428..0bf85939 100644 --- a/internal/channel/adapters/telegram/stream.go +++ b/internal/channel/adapters/telegram/stream.go @@ -3,6 +3,7 @@ package telegram import ( "context" "errors" + "fmt" "log/slog" "strings" "sync" @@ -62,8 +63,9 @@ func (s *telegramOutboundStream) getBotAndReply(ctx context.Context) (bot *tgbot } func (s *telegramOutboundStream) refreshTypingAction(ctx context.Context) error { - // When ensureStreamMessage is called, always means that the message has not been completely generated - // so always refresh the "typing" action to improve the user experience + if err := s.adapter.waitStreamLimit(ctx); err != nil { + return err + } bot, err := s.getBot(ctx) if err != nil { return err @@ -132,6 +134,9 @@ func (s *telegramOutboundStream) editStreamMessage(ctx context.Context, text str if time.Since(lastEditedAt) < telegramStreamEditThrottle { return nil } + if err := s.adapter.waitStreamLimit(ctx); err != nil { + return err + } bot, _, err := s.getBotAndReply(ctx) if err != nil { return err @@ -162,7 +167,7 @@ func (s *telegramOutboundStream) editStreamMessage(ctx context.Context, text str return nil } -const telegramFinalEditMaxRetries = 3 +const telegramFinalEditMaxRetries = 5 // editStreamMessageFinal edits the streamed message for the final content. // Retries on 429 with server-provided backoff to ensure delivery. 
@@ -182,7 +187,11 @@ func (s *telegramOutboundStream) editStreamMessageFinal(ctx context.Context, tex if err != nil { return err } + var lastEditErr error for attempt := range telegramFinalEditMaxRetries { + if err := s.adapter.waitStreamLimit(ctx); err != nil { + return err + } editErr := error(nil) if testEditFunc != nil { editErr = testEditFunc(bot, chatID, msgID, text, s.parseMode) @@ -196,6 +205,7 @@ func (s *telegramOutboundStream) editStreamMessageFinal(ctx context.Context, tex s.mu.Unlock() return nil } + lastEditErr = editErr if !isTelegramTooManyRequests(editErr) { return editErr } @@ -209,7 +219,7 @@ func (s *telegramOutboundStream) editStreamMessageFinal(ctx context.Context, tex case <-time.After(d): } } - return nil + return fmt.Errorf("telegram: final edit failed after %d retries: %w", telegramFinalEditMaxRetries, lastEditErr) } // sendDraft sends a partial message via sendMessageDraft with throttling. @@ -226,6 +236,9 @@ func (s *telegramOutboundStream) sendDraft(ctx context.Context, text string) err return nil } + if err := s.adapter.waitStreamLimit(ctx); err != nil { + return err + } bot, err := s.getBot(ctx) if err != nil { return err @@ -258,6 +271,9 @@ func (s *telegramOutboundStream) sendPermanentMessage(ctx context.Context, text if strings.TrimSpace(text) == "" { return nil } + if err := s.adapter.waitStreamLimit(ctx); err != nil { + return err + } bot, replyTo, err := s.getBotAndReply(ctx) if err != nil { return err diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index 571394c6..b3e79e4a 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -16,6 +16,7 @@ import ( "unicode/utf8" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "golang.org/x/time/rate" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/common" @@ -50,6 +51,7 @@ type TelegramAdapter struct { bots 
map[string]*tgbotapi.BotAPI // keyed by effective bot config fileEndpoints map[*tgbotapi.BotAPI]string // bot instance → file endpoint format string assets assetOpener + streamLimiter *rate.Limiter // global rate limiter for all streaming API calls } // NewTelegramAdapter creates a TelegramAdapter with the given logger. @@ -61,6 +63,7 @@ func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { logger: log.With(slog.String("adapter", "telegram")), bots: make(map[string]*tgbotapi.BotAPI), fileEndpoints: make(map[*tgbotapi.BotAPI]string), + streamLimiter: rate.NewLimiter(rate.Every(time.Second), 3), // 1 req/s sustained, burst of 3 } initTelegramBotLogger(adapter.logger) return adapter @@ -73,6 +76,13 @@ func initTelegramBotLogger(log *slog.Logger) { telegramBotLogger.SetLogger(log) } +// waitStreamLimit waits for the global stream rate limiter to allow one request. +// All streams from the same adapter share this limiter to coordinate and avoid +// aggregate Telegram API rate limits across concurrent conversations. +func (a *TelegramAdapter) waitStreamLimit(ctx context.Context) error { + return a.streamLimiter.Wait(ctx) +} + // SetAssetOpener injects the media asset reader for storage-first file delivery. func (a *TelegramAdapter) SetAssetOpener(opener assetOpener) { a.assets = opener diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index f7081b40..dc3cc66a 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -360,17 +360,25 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel } // Resolve or auto-create the active session for this route. + // Retry up to 3 times with short backoff to avoid persisting messages with NULL session_id. 
sessionID := "" sessionType := "" if p.sessionEnsurer != nil { - sess, sessErr := p.sessionEnsurer.EnsureActiveSession(ctx, identity.BotID, resolved.RouteID, msg.Channel.String()) - if sessErr != nil { - if p.logger != nil { - p.logger.Warn("ensure active session failed", slog.Any("error", sessErr)) + for attempt := range 3 { + sess, sessErr := p.sessionEnsurer.EnsureActiveSession(ctx, identity.BotID, resolved.RouteID, msg.Channel.String()) + if sessErr == nil { + sessionID = sess.ID + sessionType = sess.Type + break + } + if p.logger != nil { + p.logger.Warn("ensure active session failed", + slog.Int("attempt", attempt+1), + slog.Any("error", sessErr)) + } + if attempt < 2 { + time.Sleep(time.Duration(attempt+1) * 200 * time.Millisecond) } - } else { - sessionID = sess.ID - sessionType = sess.Type } } @@ -1983,7 +1991,7 @@ func extractStorageKey(accessPath string, _ string) string { // natively (e.g. web). Wrapping these with a tee would cause duplicate events. func isLocalChannelType(ct channel.ChannelType) bool { s := strings.ToLower(strings.TrimSpace(string(ct))) - return s == "local" + return s == "web" || s == "cli" } // replayPipelineSession loads persisted events from the DB and replays them diff --git a/internal/channel/webhook_handler_test.go b/internal/channel/webhook_handler_test.go index c70143c6..0c136194 100644 --- a/internal/channel/webhook_handler_test.go +++ b/internal/channel/webhook_handler_test.go @@ -103,7 +103,7 @@ func TestGenericWebhookHandlerDispatchesToAdapter(t *testing.T) { h.registry = registry e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/channels/testhook/webhook/cfg-1", strings.NewReader(`{}`)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/testhook/webhook/cfg-1", strings.NewReader(`{}`)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) c.SetParamNames("platform", "config_id") @@ -135,7 +135,7 @@ func TestGenericWebhookHandlerRejectsUnknownConfig(t *testing.T) { 
h.registry = registry e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/channels/testhook/webhook/missing", strings.NewReader(`{}`)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/channels/testhook/webhook/missing", strings.NewReader(`{}`)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) c.SetParamNames("platform", "config_id") diff --git a/internal/compaction/prompt.go b/internal/compaction/prompt.go index 834d4246..3f34f616 100644 --- a/internal/compaction/prompt.go +++ b/internal/compaction/prompt.go @@ -14,6 +14,8 @@ const systemPrompt = `You are a conversation summarizer. Given a conversation hi If is provided, it contains summaries of earlier conversation segments. Use them ONLY to understand the conversation flow and maintain continuity. Do NOT include, repeat, or rephrase any content from in your output. +For tool results, only include key outcomes; ignore intermediate steps or errors. + Output ONLY the summary of the new conversation segment. No preamble, no headers.` type messageEntry struct { diff --git a/internal/compaction/service.go b/internal/compaction/service.go index ea4cb8a6..bebcfaa2 100644 --- a/internal/compaction/service.go +++ b/internal/compaction/service.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "log/slog" - "strings" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" @@ -44,6 +43,11 @@ func (s *Service) TriggerCompaction(ctx context.Context, cfg TriggerConfig) { }() } +// RunCompactionSync runs compaction synchronously and returns any error. 
+func (s *Service) RunCompactionSync(ctx context.Context, cfg TriggerConfig) error { + return s.runCompaction(ctx, cfg) +} + func (s *Service) runCompaction(ctx context.Context, cfg TriggerConfig) error { botUUID, err := db.ParseUUID(cfg.BotID) if err != nil { @@ -64,7 +68,7 @@ func (s *Service) runCompaction(ctx context.Context, cfg TriggerConfig) error { compactErr := s.doCompaction(ctx, logRow.ID, sessionUUID, cfg) if compactErr != nil { - s.completeLog(ctx, logRow.ID, "error", "", compactErr.Error(), nil, pgtype.UUID{}) + s.completeLog(ctx, logRow.ID, "error", "", compactErr.Error(), 0, nil, pgtype.UUID{}) } return compactErr } @@ -75,16 +79,41 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU return err } if len(messages) == 0 { - s.completeLog(ctx, logID, "ok", "", "", nil, pgtype.UUID{}) + s.completeLog(ctx, logID, "ok", "", "", 0, nil, pgtype.UUID{}) return nil } - toCompact := splitByRatio(messages, cfg.TotalInputTokens, cfg.Ratio) + var toCompact []sqlc.ListUncompactedMessagesBySessionRow + if cfg.TargetTokens > 0 { + // Sync compaction: compress enough messages to bring context + // down to TargetTokens. Calculate how many tokens to keep + // (newest messages) and compact everything older. + toCompact = splitByTarget(messages, cfg.TargetTokens) + } else { + toCompact = splitByRatio(messages, cfg.TotalInputTokens, cfg.Ratio) + } if len(toCompact) == 0 { - s.completeLog(ctx, logID, "ok", "", "", nil, pgtype.UUID{}) + s.completeLog(ctx, logID, "ok", "", "", 0, nil, pgtype.UUID{}) return nil } + // Cap the compaction input to avoid exceeding the compaction model's + // context window. MaxCompactTokens is typically set to 90% of the model's + // window. If not set, use a conservative default of 30K tokens. 
+ maxCompactTokens := cfg.MaxCompactTokens + if maxCompactTokens <= 0 { + maxCompactTokens = 30000 + } + s.logger.Info("compaction: before trim", + slog.Int("messages", len(toCompact)), + slog.Int("total_uncompacted", len(messages)), + slog.Int("max_compact_tokens", maxCompactTokens), + ) + toCompact = trimCompactMessages(toCompact, maxCompactTokens) + s.logger.Info("compaction: after trim", + slog.Int("messages", len(toCompact)), + ) + priorLogs, err := s.queries.ListCompactionLogsBySession(ctx, sessionUUID) if err != nil { return err @@ -101,7 +130,7 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU for _, m := range toCompact { entries = append(entries, messageEntry{ Role: m.Role, - Content: extractTextContent(m.Content), + Content: string(m.Content), }) messageIDs = append(messageIDs, m.ID) } @@ -137,16 +166,16 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU return err } - s.completeLog(ctx, logID, "ok", result.Text, "", usageJSON, modelUUID) + s.completeLog(ctx, logID, "ok", result.Text, "", len(messageIDs), usageJSON, modelUUID) return nil } -func (s *Service) completeLog(ctx context.Context, logID pgtype.UUID, status, summary, errMsg string, usage []byte, modelID pgtype.UUID) { +func (s *Service) completeLog(ctx context.Context, logID pgtype.UUID, status, summary, errMsg string, messageCount int, usage []byte, modelID pgtype.UUID) { if _, err := s.queries.CompleteCompactionLog(ctx, sqlc.CompleteCompactionLogParams{ ID: logID, Status: status, Summary: summary, - MessageCount: 0, + MessageCount: int32(messageCount), //nolint:gosec // count always small ErrorMessage: errMsg, Usage: usage, ModelID: modelID, @@ -231,45 +260,15 @@ func formatUUID(id pgtype.UUID) string { return uuid.UUID(id.Bytes).String() } -// extractTextContent extracts plain text from a message content JSONB field. -// The content may be a JSON string, an array of content parts, or raw bytes. 
-func extractTextContent(content []byte) string { - if len(content) == 0 { - return "" - } - - var s string - if json.Unmarshal(content, &s) == nil { - return s - } - - var parts []map[string]any - if json.Unmarshal(content, &parts) == nil { - var texts []string - for _, p := range parts { - if t, ok := p["type"].(string); ok && t == "text" { - if text, ok := p["text"].(string); ok { - texts = append(texts, text) - } - } - } - if len(texts) > 0 { - return joinTexts(texts) - } - } - - return string(content) -} - -func joinTexts(parts []string) string { - return strings.Join(parts, " ") -} - // splitByRatio splits messages so that roughly the first ratio% (by token weight) // are returned for compaction, and the rest are kept as-is. -// When ratio >= 100 or totalInputTokens <= 0, all messages are returned. +// When ratio >= 100, all messages are returned for compaction. +// When ratio <= 0 or totalInputTokens <= 0 or messages is empty, nil is returned (no compaction). func splitByRatio(messages []sqlc.ListUncompactedMessagesBySessionRow, totalInputTokens, ratio int) []sqlc.ListUncompactedMessagesBySessionRow { - if ratio >= 100 || ratio <= 0 || totalInputTokens <= 0 || len(messages) == 0 { + if ratio <= 0 || totalInputTokens <= 0 || len(messages) == 0 { + return nil + } + if ratio >= 100 { return messages } @@ -297,6 +296,29 @@ func splitByRatio(messages []sqlc.ListUncompactedMessagesBySessionRow, totalInpu return messages[:cutoff] } +// splitByTarget returns the oldest messages to compact so that the remaining +// newest messages fit within targetTokens. This is used for synchronous +// compaction where the goal is to reduce context to a specific size. +func splitByTarget(messages []sqlc.ListUncompactedMessagesBySessionRow, targetTokens int) []sqlc.ListUncompactedMessagesBySessionRow { + if targetTokens <= 0 || len(messages) == 0 { + return nil + } + // Scan from newest to oldest, keeping messages that fit within target. 
+ accumulated := 0 + cutoff := 0 + for i := len(messages) - 1; i >= 0; i-- { + accumulated += estimateRowTokens(messages[i]) + if accumulated > targetTokens { + cutoff = i + 1 + break + } + } + if cutoff <= 0 { + return nil + } + return messages[:cutoff] +} + type usagePayload struct { OutputTokens *int `json:"output_tokens"` } @@ -310,3 +332,32 @@ func estimateRowTokens(m sqlc.ListUncompactedMessagesBySessionRow) int { } return len(m.Content) / 4 } + +// trimCompactMessages trims the compaction input from the tail (oldest) +// so the total estimated tokens stay within maxTokens. +func trimCompactMessages(messages []sqlc.ListUncompactedMessagesBySessionRow, maxTokens int) []sqlc.ListUncompactedMessagesBySessionRow { + if len(messages) == 0 || maxTokens <= 0 { + return messages + } + total := 0 + for _, m := range messages { + total += estimateRowTokens(m) + } + if total <= maxTokens { + return messages + } + // Drop oldest messages from the tail until within budget. + accumulated := 0 + cutoff := len(messages) + for i := len(messages) - 1; i >= 0; i-- { + accumulated += estimateRowTokens(messages[i]) + if accumulated > maxTokens { + cutoff = i + 1 + break + } + } + if cutoff >= len(messages) { + return messages + } + return messages[cutoff:] +} diff --git a/internal/compaction/types.go b/internal/compaction/types.go index 63875d06..d0229e36 100644 --- a/internal/compaction/types.go +++ b/internal/compaction/types.go @@ -38,4 +38,6 @@ type TriggerConfig struct { HTTPClient *http.Client Ratio int TotalInputTokens int + MaxCompactTokens int // if > 0, cap compaction input to this many tokens (e.g. 
90% of model window) + TargetTokens int // if > 0, compaction goal: reduce context to this many tokens (used by sync compaction) } diff --git a/internal/conversation/flow/idle_timeout.go b/internal/conversation/flow/idle_timeout.go new file mode 100644 index 00000000..b1ac35a8 --- /dev/null +++ b/internal/conversation/flow/idle_timeout.go @@ -0,0 +1,91 @@ +package flow + +import ( + "context" + "sync" + "time" +) + +// idleCancel wraps a resettable idle timer. If Reset() is not called before +// the timer fires, the underlying context is cancelled. +type idleCancel struct { + cancel context.CancelFunc + timer *time.Timer + mu sync.Mutex + fired bool + baseTimeout time.Duration + toolCalls int +} + +func (ic *idleCancel) Reset() { + ic.mu.Lock() + defer ic.mu.Unlock() + if !ic.fired { + ic.timer.Stop() + ic.timer.Reset(ic.currentTimeout()) + } +} + +// RecordToolCall increments the tool call counter and extends the idle timeout. +func (ic *idleCancel) RecordToolCall() { + ic.mu.Lock() + defer ic.mu.Unlock() + ic.toolCalls++ +} + +func (ic *idleCancel) Stop() { + ic.mu.Lock() + defer ic.mu.Unlock() + ic.timer.Stop() +} + +func (ic *idleCancel) DidFire() bool { + ic.mu.Lock() + defer ic.mu.Unlock() + return ic.fired +} + +// ToolCalls returns the number of tool calls recorded. +func (ic *idleCancel) ToolCalls() int { + ic.mu.Lock() + defer ic.mu.Unlock() + return ic.toolCalls +} + +// currentTimeout returns the adaptive timeout: base + 60s per tool call, capped at 600s. +// Tool calls (especially spawn/subagent) can take minutes to complete, so the +// extension per tool call is generous to avoid interrupting active work. 
+func (ic *idleCancel) currentTimeout() time.Duration {
+	extra := time.Duration(ic.toolCalls) * 60 * time.Second
+	timeout := ic.baseTimeout + extra
+	if timeout > 600*time.Second {
+		timeout = 600 * time.Second
+	}
+	return timeout
+}
+
+const defaultIdleTimeout = 90 * time.Second
+
+// withIdleTimeout returns a context that is cancelled if no Reset() call is
+// made within the adaptive idle timeout. The returned idleCancel must have
+// Reset() called for each meaningful event to prevent the timeout from firing.
+// The timeout adapts: base + 60s per tool call, capped at 600s.
+func withIdleTimeout(parent context.Context, baseTimeout ...time.Duration) (context.Context, *idleCancel) {
+	bt := defaultIdleTimeout
+	if len(baseTimeout) > 0 && baseTimeout[0] > 0 {
+		bt = baseTimeout[0]
+	}
+
+	ctx, cancel := context.WithCancel(parent)
+	ic := &idleCancel{
+		cancel:      cancel,
+		baseTimeout: bt,
+	}
+	ic.timer = time.AfterFunc(bt, func() {
+		ic.mu.Lock()
+		ic.fired = true
+		ic.mu.Unlock()
+		cancel()
+	})
+	return ctx, ic
+}
diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go
index 6bca13a1..2c5305cb 100644
--- a/internal/conversation/flow/resolver.go
+++ b/internal/conversation/flow/resolver.go
@@ -9,6 +9,8 @@ import (
 	"io"
 	"log/slog"
 	"math"
+	"net"
+	"net/http"
 	"sort"
 	"strconv"
 	"strings"
@@ -74,6 +76,7 @@ type Resolver struct {
 	skillLoader SkillLoader
 	assetLoader gatewayAssetLoader
 	pipeline *pipelinepkg.Pipeline
+	streamHTTPClient *http.Client
 	timeout time.Duration
 	clockLocation *time.Location
 	logger *slog.Logger
@@ -98,17 +101,37 @@ func NewResolver(
 	if clockLocation == nil {
 		clockLocation = time.UTC
 	}
+	// HTTP client with timeouts for LLM provider streaming.
+ // - DialTimeout: fail fast on connection issues + // - ResponseHeaderTimeout: catch servers that accept TCP but never respond + // - Timeout: overall request lifetime cap (prevents stuck SSE body reads) + streamHTTPClient := &http.Client{ + Timeout: 10 * time.Minute, // overall cap, matches resolver timeout + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + }, + } + return &Resolver{ - agent: a, - modelsService: modelsService, - queries: queries, - conversationSvc: conversationSvc, - messageService: messageService, - settingsService: settingsService, - accountService: accountService, - timeout: timeout, - clockLocation: clockLocation, - logger: log.With(slog.String("service", "conversation_resolver")), + agent: a, + modelsService: modelsService, + queries: queries, + conversationSvc: conversationSvc, + messageService: messageService, + settingsService: settingsService, + accountService: accountService, + streamHTTPClient: streamHTTPClient, + timeout: timeout, + clockLocation: clockLocation, + logger: log.With(slog.String("service", "conversation_resolver")), } } @@ -189,6 +212,7 @@ type resolvedContext struct { provider sqlc.Provider query string // headerified query injectedRecords *[]conversation.InjectedMessageRecord + estimatedTokens int // estimated input token count for compaction } func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) { @@ -217,9 +241,12 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r ReasoningEffort: req.ReasoningEffort, }) if err != nil { + r.logger.Error("resolve: buildBaseRunConfig failed", + slog.String("bot_id", req.BotID), + slog.Any("error", err), + ) return resolvedContext{}, err } - memoryMsg 
:= r.loadMemoryContextMessage(ctx, req) reqMessages := pruneMessagesForGateway(nonNilModelMessages(req.Messages)) if memoryMsg != nil { @@ -238,18 +265,65 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r } } + botSettings, _ := r.loadBotSettings(ctx, req.BotID) + contextTokenBudget := 0 + if botSettings.ContextTokenBudget > 0 { + contextTokenBudget = botSettings.ContextTokenBudget + } + var messages []conversation.ModelMessage + var estimatedTokens int if usePipeline { - messages = r.buildMessagesFromPipeline(ctx, req) + messages = r.buildMessagesFromPipeline(ctx, req, contextTokenBudget) } else if r.conversationSvc != nil { loaded, loadErr := r.loadMessages(ctx, req.ChatID, req.SessionID, defaultMaxContextMinutes) if loadErr != nil { + r.logger.Error("resolve: loadMessages failed", + slog.String("bot_id", req.BotID), + slog.Any("error", loadErr), + ) return resolvedContext{}, loadErr } loaded = pruneHistoryForGateway(loaded) loaded = dedupePersistedCurrentUserMessage(loaded, req) loaded = r.replaceCompactedMessages(ctx, loaded) - messages = trimMessagesByTokens(r.logger, loaded, 0) + messages, estimatedTokens = trimMessagesByTokens(r.logger, loaded, contextTokenBudget) + // When context reaches 70% of the contextTokenBudget (the user-configured + // budget cap), run synchronous compaction before sending the request. + // contextTokenBudget is the authoritative limit for how much context + // the user wants to send to the LLM. We compact at 70% to keep the + // context healthy and avoid edge-case timeouts. 
+ compactionThreshold := 0 + if contextTokenBudget > 0 { + compactionThreshold = contextTokenBudget * 70 / 100 + } + if compactionThreshold > 0 && estimatedTokens >= compactionThreshold { + r.logger.Warn("resolve: context reached compaction threshold, running synchronous compaction", + slog.String("bot_id", req.BotID), + slog.Int("estimated_tokens", estimatedTokens), + slog.Int("context_token_budget", contextTokenBudget), + slog.Int("compaction_threshold", compactionThreshold), + ) + r.runCompactionSync(ctx, req, estimatedTokens) + // Reload messages after compaction. + loaded, loadErr = r.loadMessages(ctx, req.ChatID, req.SessionID, defaultMaxContextMinutes) + if loadErr != nil { + r.logger.Error("resolve: reload messages after compaction failed", + slog.String("bot_id", req.BotID), + slog.Any("error", loadErr), + ) + return resolvedContext{}, loadErr + } + loaded = pruneHistoryForGateway(loaded) + loaded = dedupePersistedCurrentUserMessage(loaded, req) + loaded = r.replaceCompactedMessages(ctx, loaded) + messages, estimatedTokens = trimMessagesByTokens(r.logger, loaded, contextTokenBudget) + // Remove tool messages from the recent context — they are large + // and unnecessary when we already have a summary. Keep only + // user/assistant conversation turns. + messages = stripToolMessages(messages) + } + _ = estimatedTokens } if memoryMsg != nil { messages = append(messages, *memoryMsg) @@ -258,6 +332,12 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r messages = append(messages, reqMessages...) } messages = sanitizeMessages(messages) + // Strip tool messages and tool-call-only assistant messages from context. + // Tool outputs are large and waste tokens; the LLM doesn't need raw tool + // results when summaries and memory tools are available for lookup. 
+ if len(messages) > 10 { + messages = stripToolMessages(messages) + } displayName := r.resolveDisplayName(ctx, req) mergedAttachments := r.routeAndMergeAttachments(ctx, chatModel, req) @@ -326,6 +406,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r provider: provider, query: headerifiedQuery, injectedRecords: injectedRecords, + estimatedTokens: estimatedTokens, }, nil } @@ -429,6 +510,7 @@ func (r *Resolver) buildBaseRunConfig(ctx context.Context, p baseRunConfigParams APIKey: creds.APIKey, CodexAccountID: creds.CodexAccountID, BaseURL: providers.ProviderConfigString(provider, "base_url"), + HTTPClient: r.streamHTTPClient, ReasoningConfig: reasoningConfig, }) diff --git a/internal/conversation/flow/resolver_compaction.go b/internal/conversation/flow/resolver_compaction.go index 2be6fa83..90221888 100644 --- a/internal/conversation/flow/resolver_compaction.go +++ b/internal/conversation/flow/resolver_compaction.go @@ -8,63 +8,140 @@ import ( "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/providers" + "github.com/memohai/memoh/internal/settings" ) -func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatRequest, rc resolvedContext, inputTokens int) { +func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatRequest, _ resolvedContext, inputTokens int) { if r.compactionService == nil || r.settingsService == nil { + r.logger.Info("compaction: skipped, service or settings nil") return } - settings, err := r.settingsService.GetBot(ctx, req.BotID) + botSettings, err := r.settingsService.GetBot(ctx, req.BotID) if err != nil { r.logger.Warn("compaction: failed to load settings", slog.Any("error", err)) return } - if !settings.CompactionEnabled || settings.CompactionThreshold <= 0 { + if !botSettings.CompactionEnabled || botSettings.CompactionThreshold <= 0 { + r.logger.Info("compaction: skipped, disabled or no threshold", + 
slog.Bool("enabled", botSettings.CompactionEnabled), + slog.Int("threshold", botSettings.CompactionThreshold), + ) return } - if !compaction.ShouldCompact(inputTokens, settings.CompactionThreshold) { + if !compaction.ShouldCompact(inputTokens, botSettings.CompactionThreshold) { + r.logger.Info("compaction: skipped, below threshold", + slog.Int("input_tokens", inputTokens), + slog.Int("threshold", botSettings.CompactionThreshold), + ) return } - modelID := settings.CompactionModelID + r.logger.Info("compaction: triggering", + slog.String("bot_id", req.BotID), + slog.String("session_id", req.SessionID), + slog.Int("input_tokens", inputTokens), + slog.Int("threshold", botSettings.CompactionThreshold), + slog.Int("ratio", botSettings.CompactionRatio), + ) + + cfg, err := r.buildCompactionConfig(ctx, req, botSettings, inputTokens) + if err != nil { + r.logger.Warn("compaction: failed to build config", slog.Any("error", err)) + return + } + r.compactionService.TriggerCompaction(ctx, cfg) +} + +// runCompactionSync runs compaction synchronously when context reaches +// 70% of the configured context token budget. It blocks until compaction completes. 
+func (r *Resolver) runCompactionSync(ctx context.Context, req conversation.ChatRequest, inputTokens int) { + if r.compactionService == nil || r.settingsService == nil { + r.logger.Warn("compaction sync: skipped, service or settings nil") + return + } + botSettings, err := r.settingsService.GetBot(ctx, req.BotID) + if err != nil { + r.logger.Warn("compaction sync: failed to load settings", slog.Any("error", err)) + return + } + if !botSettings.CompactionEnabled { + r.logger.Warn("compaction sync: compaction disabled, skipping") + return + } + + cfg, err := r.buildCompactionConfig(ctx, req, botSettings, inputTokens) + if err != nil { + r.logger.Warn("compaction sync: failed to build config", slog.Any("error", err)) + return + } + + r.logger.Info("compaction sync: running synchronously", + slog.String("bot_id", req.BotID), + slog.String("session_id", req.SessionID), + slog.Int("input_tokens", inputTokens), + slog.String("model_id", cfg.ModelID), + ) + + if err := r.compactionService.RunCompactionSync(ctx, cfg); err != nil { + r.logger.Warn("compaction sync: failed", slog.Any("error", err)) + } else { + r.logger.Info("compaction sync: completed successfully", + slog.String("bot_id", req.BotID), + slog.String("session_id", req.SessionID), + ) + } +} + +// buildCompactionConfig resolves the compaction model, provider credentials, +// and sets MaxCompactTokens to 90% of the compaction model's context window. 
+func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, inputTokens int) (compaction.TriggerConfig, error) { + modelID := botSettings.CompactionModelID if modelID == "" { - modelID = rc.model.ID + return compaction.TriggerConfig{}, nil } - ratio := settings.CompactionRatio + ratio := botSettings.CompactionRatio if ratio <= 0 || ratio > 100 { ratio = 80 } + compactModel, err := r.modelsService.GetByID(ctx, modelID) + if err != nil { + return compaction.TriggerConfig{}, err + } + + compactProvider, err := models.FetchProviderByID(ctx, r.queries, compactModel.ProviderID) + if err != nil { + return compaction.TriggerConfig{}, err + } + authResolver := providers.NewService(nil, r.queries, "") + creds, err := authResolver.ResolveModelCredentials(ctx, compactProvider) + if err != nil { + return compaction.TriggerConfig{}, err + } + cfg := compaction.TriggerConfig{ BotID: req.BotID, SessionID: req.SessionID, + ModelID: compactModel.ModelID, + ClientType: compactProvider.ClientType, + APIKey: creds.APIKey, + CodexAccountID: creds.CodexAccountID, + BaseURL: providers.ProviderConfigString(compactProvider, "base_url"), Ratio: ratio, TotalInputTokens: inputTokens, + HTTPClient: r.streamHTTPClient, } - model, err := r.modelsService.GetByID(ctx, modelID) - if err != nil { - r.logger.Warn("compaction: failed to resolve model", slog.Any("error", err)) - return + // Cap compaction input to 90% of the compaction model's context window. 
+ if compactModel.Config.ContextWindow != nil && *compactModel.Config.ContextWindow > 0 { + cfg.MaxCompactTokens = *compactModel.Config.ContextWindow * 90 / 100 } - cfg.ModelID = model.ModelID - provider, err := models.FetchProviderByID(ctx, r.queries, model.ProviderID) - if err != nil { - r.logger.Warn("compaction: failed to fetch provider", slog.Any("error", err)) - return - } - authResolver := providers.NewService(nil, r.queries, "") - creds, err := authResolver.ResolveModelCredentials(ctx, provider) - if err != nil { - r.logger.Warn("compaction: failed to resolve provider credentials", slog.Any("error", err)) - return - } - cfg.ClientType = provider.ClientType - cfg.APIKey = creds.APIKey - cfg.CodexAccountID = creds.CodexAccountID - cfg.BaseURL = providers.ProviderConfigString(provider, "base_url") + // For sync compaction: keep only the last few messages (~2000 tokens ≈ 3 messages). + // The summary provides reference context; if the LLM needs details, + // it will use tools (memory_read, search) to look them up. 
+ cfg.TargetTokens = 2000 - r.compactionService.TriggerCompaction(ctx, cfg) + return cfg, nil } diff --git a/internal/conversation/flow/resolver_history.go b/internal/conversation/flow/resolver_history.go index ae9398c4..56ca8f8d 100644 --- a/internal/conversation/flow/resolver_history.go +++ b/internal/conversation/flow/resolver_history.go @@ -117,28 +117,27 @@ func estimateMessageTokens(msg conversation.ModelMessage) int { return len(text) / 4 } -func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxTokens int) []conversation.ModelMessage { +func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxTokens int) ([]conversation.ModelMessage, int) { if maxTokens == 0 || len(messages) == 0 { result := make([]conversation.ModelMessage, len(messages)) for i, m := range messages { result[i] = m.Message } - return result + totalTokens := 0 + for _, m := range messages { + totalTokens += estimateMessageTokens(m.Message) + } + return result, totalTokens } - // Scan from newest to oldest, accumulating per-message token costs. - // Messages with stored usage data use that value; others fall back to a - // character-based estimate so that user/tool messages are not free-passed. + // Scan from newest to oldest, accumulating per-message estimated context + // token costs. Each message's cost represents the tokens it occupies in the + // context window (not the output tokens it generated). We use a character- + // based estimate for all messages since this measures context window impact. 
totalTokens := 0 cutoff := 0 - messagesWithUsage := 0 for i := len(messages) - 1; i >= 0; i-- { - if messages[i].UsageOutputTokens != nil { - totalTokens += *messages[i].UsageOutputTokens - messagesWithUsage++ - } else { - totalTokens += estimateMessageTokens(messages[i].Message) - } + totalTokens += estimateMessageTokens(messages[i].Message) if totalTokens > maxTokens { cutoff = i + 1 break @@ -152,11 +151,10 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke cutoff++ } - if log != nil { - log.Debug("trimMessagesByTokens", + if cutoff > 0 && log != nil { + log.Info("trimMessagesByTokens: context trimmed", slog.Int("total_messages", len(messages)), - slog.Int("messages_with_usage", messagesWithUsage), - slog.Int("accumulated_output_tokens", totalTokens), + slog.Int("estimated_tokens", totalTokens), slog.Int("max_tokens", maxTokens), slog.Int("cutoff_index", cutoff), slog.Int("kept_messages", len(messages)-cutoff), @@ -164,10 +162,23 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke } result := make([]conversation.ModelMessage, 0, len(messages)-cutoff) + if cutoff > 0 { + // Add a truncation notice at the beginning so the LLM knows earlier + // context was trimmed and it can use tools (memory, search) to look up + // past information if needed. + result = append(result, conversation.ModelMessage{ + Role: "system", + Content: conversation.NewTextContent( + "[System Notice] Earlier conversation history has been trimmed to fit the context window. 
" + + "If you need information from earlier in the conversation, use the available tools " + + "(such as memory_read or web search) to retrieve it.", + ), + }) + } for _, m := range messages[cutoff:] { result = append(result, m.Message) } - return result + return result, totalTokens } func (r *Resolver) replaceCompactedMessages(ctx context.Context, messages []messageWithUsage) []messageWithUsage { @@ -233,7 +244,7 @@ func (r *Resolver) replaceCompactedMessages(ctx context.Context, messages []mess // RenderedContext (RC) merged with assistant/tool turns (TR) from // bot_history_messages. This gives chat mode the same event-driven context // that discuss mode uses, replacing the legacy loadMessages path. -func (r *Resolver) buildMessagesFromPipeline(ctx context.Context, req conversation.ChatRequest) []conversation.ModelMessage { +func (r *Resolver) buildMessagesFromPipeline(ctx context.Context, req conversation.ChatRequest, contextTokenBudget int) []conversation.ModelMessage { sessionID := strings.TrimSpace(req.SessionID) if r.pipeline == nil || sessionID == "" { return nil @@ -261,9 +272,45 @@ func (r *Resolver) buildMessagesFromPipeline(ctx context.Context, req conversati Content: contentJSON, }) } + + // Apply context token budget trimming to pipeline path as well. + if contextTokenBudget > 0 && len(messages) > 0 { + messages = trimPipelineMessagesByTokens(r.logger, messages, contextTokenBudget) + } + return messages } +// trimPipelineMessagesByTokens trims pipeline-assembled messages to fit within +// the context token budget using character-based estimation. +func trimPipelineMessagesByTokens(log *slog.Logger, messages []conversation.ModelMessage, maxTokens int) []conversation.ModelMessage { + totalTokens := 0 + cutoff := 0 + for i := len(messages) - 1; i >= 0; i-- { + totalTokens += estimateMessageTokens(messages[i]) + if totalTokens > maxTokens { + cutoff = i + 1 + break + } + } + + // Avoid orphaned tool messages at the cutoff boundary. 
+ for cutoff < len(messages) && strings.EqualFold(strings.TrimSpace(messages[cutoff].Role), "tool") { + cutoff++ + } + + if cutoff > 0 && log != nil { + log.Info("trimPipelineMessagesByTokens: context trimmed", + slog.Int("total_messages", len(messages)), + slog.Int("estimated_tokens", totalTokens), + slog.Int("max_tokens", maxTokens), + slog.Int("kept_messages", len(messages)-cutoff), + ) + } + + return messages[cutoff:] +} + // loadTurnResponses loads recent assistant/tool messages from bot_history_messages // for use as the TR stream in pipeline-based context assembly. func (r *Resolver) loadTurnResponses(ctx context.Context, sessionID string) []pipelinepkg.TurnResponseEntry { @@ -297,3 +344,26 @@ func (r *Resolver) loadTurnResponses(ctx context.Context, sessionID string) []pi } return trs } + +// stripToolMessages removes tool messages and their associated assistant +// tool-call messages from the context. After synchronous compaction, the +// summary already captures the tool interactions — keeping raw tool output +// only wastes context tokens. +func stripToolMessages(messages []conversation.ModelMessage) []conversation.ModelMessage { + filtered := make([]conversation.ModelMessage, 0, len(messages)) + for _, m := range messages { + role := strings.TrimSpace(m.Role) + if strings.EqualFold(role, "tool") { + continue + } + // Remove assistant messages that contain tool calls (without text content). 
+ if strings.EqualFold(role, "assistant") && len(m.ToolCalls) > 0 { + text := m.TextContent() + if strings.TrimSpace(text) == "" { + continue + } + } + filtered = append(filtered, m) + } + return filtered +} diff --git a/internal/conversation/flow/resolver_store.go b/internal/conversation/flow/resolver_store.go index 4b6b1b9d..ea5f63d6 100644 --- a/internal/conversation/flow/resolver_store.go +++ b/internal/conversation/flow/resolver_store.go @@ -26,16 +26,47 @@ func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, } fullRound = append(fullRound, m) } - if len(fullRound) == 0 { + + // Filter out empty assistant messages (content: []) that result from LLM + // returning no useful output (e.g., context window overflow). These provide + // no value and pollute the conversation history, causing subsequent turns + // to also produce empty responses. + filtered := make([]conversation.ModelMessage, 0, len(fullRound)) + for _, m := range fullRound { + if m.Role == "assistant" && isEmptyAssistantMessage(m) { + r.logger.Warn("skipping empty assistant message in storeRound", + slog.String("bot_id", req.BotID), + ) + continue + } + filtered = append(filtered, m) + } + + if len(filtered) == 0 { return nil } - r.storeMessages(ctx, req, fullRound, modelID) - go r.storeMemory(context.WithoutCancel(ctx), req, fullRound) + r.storeMessages(ctx, req, filtered, modelID) + go r.storeMemory(context.WithoutCancel(ctx), req, filtered) return nil } +// isEmptyAssistantMessage returns true if an assistant message has no +// meaningful content: no text, no tool calls, and no attachments. 
+func isEmptyAssistantMessage(m conversation.ModelMessage) bool { + if len(m.ToolCalls) > 0 { + return false + } + text := strings.TrimSpace(m.TextContent()) + if text != "" { + return false + } + // Check if content is empty array "[]" or null/empty + content := strings.TrimSpace(string(m.Content)) + return content == "" || content == "[]" || content == "null" +} + // StoreRound persists SDK messages as a complete round (assistant + tool // output) into bot_history_messages with full metadata, usage tracking, // and memory extraction. Used by the discuss driver so it shares the same @@ -60,6 +91,12 @@ func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatReque if strings.TrimSpace(req.BotID) == "" { return } + + // Check bot setting for full tool result persistence. + pruneToolResults := true + if botSettings, err := r.loadBotSettings(ctx, req.BotID); err == nil { + pruneToolResults = !botSettings.PersistFullToolResults + } meta := buildRouteMetadata(req) senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) @@ -80,6 +117,15 @@ func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatReque for i, msg := range messages { msg = normalizeUserMessageContent(msg) + + // Prune tool results at store time to reduce DB bloat. + // This prevents ~10KB+ tool outputs from being stored verbatim. 
+ if pruneToolResults { + if pruned, changed := pruneMessageForGateway(msg); changed { + msg = pruned + } + } + content, err := json.Marshal(msg) if err != nil { r.logger.Warn("storeMessages: marshal failed", slog.Any("error", err)) diff --git a/internal/conversation/flow/resolver_stream.go b/internal/conversation/flow/resolver_stream.go index ef376815..3afdcd15 100644 --- a/internal/conversation/flow/resolver_stream.go +++ b/internal/conversation/flow/resolver_stream.go @@ -20,11 +20,6 @@ type WSStreamEvent = json.RawMessage func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) { chunkCh := make(chan conversation.StreamChunk) errCh := make(chan error, 1) - r.logger.Info("agent stream start", - slog.String("bot_id", req.BotID), - slog.String("chat_id", req.ChatID), - ) - go func() { defer close(chunkCh) defer close(errCh) @@ -50,9 +45,22 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) cfg := rc.runConfig cfg = r.prepareRunConfig(ctx, cfg) - eventCh := r.agent.Stream(ctx, cfg) + // Wrap with idle timeout: if no events arrive within the adaptive timeout, cancel the stream. 
+ idleCtx, idleCancel := withIdleTimeout(ctx) + defer idleCancel.Stop() + + eventCh := r.agent.Stream(idleCtx, cfg) stored := false + var toolCallCount int for event := range eventCh { + idleCancel.Reset() // each event resets the idle timer + + // Track tool calls for adaptive idle timeout and progress events + if event.Type == agentpkg.EventToolCallStart { + toolCallCount++ + idleCancel.RecordToolCall() + } + if event.Type == agentpkg.EventError { r.logger.Error("agent stream error", slog.String("bot_id", streamReq.BotID), @@ -73,7 +81,38 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) stored = true } } - chunkCh <- conversation.StreamChunk(data) + select { + case chunkCh <- conversation.StreamChunk(data): + case <-ctx.Done(): + return + } + } + + // Intermediate persistence on abort/error: if stream ended without + // storing results, persist a synthetic message so the user can see + // what happened and ask the bot to continue. + if !stored { + r.persistPartialResult(ctx, streamReq, rc, toolCallCount, idleCancel.DidFire()) + } + + if idleCancel.DidFire() { + r.logger.Warn("agent stream aborted: idle timeout (no events from provider)", + slog.String("bot_id", streamReq.BotID), + slog.String("chat_id", streamReq.ChatID), + slog.String("model_id", rc.model.ID), + slog.Int("tool_calls", toolCallCount), + ) + // Notify the client that the stream was terminated due to idle timeout. 
+ timeoutEvent := agentpkg.StreamEvent{ + Type: agentpkg.EventError, + Error: fmt.Sprintf("stream timeout: no response from model provider (after %d tool calls)", toolCallCount), + } + if data, err := json.Marshal(timeoutEvent); err == nil { + select { + case chunkCh <- conversation.StreamChunk(data): + case <-ctx.Done(): + } + } } }() return chunkCh, errCh @@ -89,6 +128,10 @@ func (r *Resolver) StreamChatWS( ) error { rc, err := r.resolve(ctx, req) if err != nil { + r.logger.Error("StreamChatWS: resolve failed", + slog.String("bot_id", req.BotID), + slog.Any("error", err), + ) return fmt.Errorf("resolve: %w", err) } if req.RawQuery == "" { @@ -112,10 +155,23 @@ func (r *Resolver) StreamChatWS( cfg := rc.runConfig cfg = r.prepareRunConfig(streamCtx, cfg) - agentEventCh := r.agent.Stream(streamCtx, cfg) + // Wrap with idle timeout: if no events arrive within the adaptive timeout, cancel the stream. + idleCtx, idleCancel := withIdleTimeout(streamCtx) + defer idleCancel.Stop() + + agentEventCh := r.agent.Stream(idleCtx, cfg) modelID := rc.model.ID stored := false + var toolCallCount int for event := range agentEventCh { + idleCancel.Reset() // each event resets the idle timer + + // Track tool calls for adaptive idle timeout + if event.Type == agentpkg.EventToolCallStart { + toolCallCount++ + idleCancel.RecordToolCall() + } + if event.Type == agentpkg.EventError { r.logger.Error("agent stream error", slog.String("bot_id", req.BotID), @@ -145,10 +201,34 @@ func (r *Resolver) StreamChatWS( } } + // Intermediate persistence on abort/error + if !stored { + r.persistPartialResult(ctx, req, rc, toolCallCount, idleCancel.DidFire()) + } + + if idleCancel.DidFire() { + r.logger.Warn("agent ws stream aborted: idle timeout (no events from provider)", + slog.String("bot_id", req.BotID), + slog.String("chat_id", req.ChatID), + slog.String("model_id", modelID), + slog.Int("tool_calls", toolCallCount), + ) + // Notify the client that the stream was terminated due to idle timeout. 
+ timeoutEvent := agentpkg.StreamEvent{ + Type: agentpkg.EventError, + Error: fmt.Sprintf("stream timeout: no response from model provider (after %d tool calls)", toolCallCount), + } + if data, err := json.Marshal(timeoutEvent); err == nil { + select { + case eventCh <- json.RawMessage(data): + case <-ctx.Done(): + } + } + } + return nil } -// tryStoreStream attempts to extract final messages from a stream event and persist them. func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string, rc resolvedContext) (bool, error) { var envelope struct { Type string `json:"type"` @@ -184,6 +264,36 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ return true, nil } +// persistPartialResult stores a synthetic assistant message when the agent +// stream was interrupted (error, abort, idle timeout) after completing tool +// calls but before producing a final response. This preserves intermediate +// progress so the user can see what was accomplished and ask the bot to continue. +func (r *Resolver) persistPartialResult(ctx context.Context, req conversation.ChatRequest, rc resolvedContext, toolCallCount int, wasIdleTimeout bool) { + reason := "provider error" + if wasIdleTimeout { + reason = "provider idle timeout" + } + syntheticMsg := fmt.Sprintf("[Agent interrupted after %d tool calls: %s. 
Partial results saved — ask the bot to continue.]", toolCallCount, reason) + + roundMessages := prependUserMessage(req.Query, []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent(syntheticMsg)}, + }) + + if err := r.storeRound(context.WithoutCancel(ctx), req, roundMessages, rc.model.ID); err != nil { + r.logger.Error("failed to persist partial result", + slog.String("bot_id", req.BotID), + slog.Any("error", err), + ) + } + + // Trigger compaction on failure path so that oversized contexts don't + // create a deadlock where the LLM can never succeed (and therefore + // compaction never fires). Use the estimated token count from resolve. + if rc.estimatedTokens > 0 { + r.maybeCompact(context.WithoutCancel(ctx), req, rc, rc.estimatedTokens) + } +} + // interleaveInjectedMessages inserts injected user messages at their correct // positions within the round. Each record's InsertAfter value indicates how // many output messages preceded the injection. diff --git a/internal/conversation/flow/resolver_trim_test.go b/internal/conversation/flow/resolver_trim_test.go index c30d2101..41016326 100644 --- a/internal/conversation/flow/resolver_trim_test.go +++ b/internal/conversation/flow/resolver_trim_test.go @@ -52,7 +52,10 @@ func TestTrimMessagesByTokens_DropsLeadingOrphanTool(t *testing.T) { // Budget 70: assistant(60) fits, adding assistant-tool-call(50) exceeds → // cutoff lands on the tool message which must be skipped. - trimmed := trimMessagesByTokens(nil, messages, 70) + // NOTE: estimateMessageTokens uses character-based estimation (not UsageOutputTokens), + // so all messages fit within budget=70. This test verifies the orphan-tool skip logic + // still works correctly when trimming does occur. 
+ trimmed, _ := trimMessagesByTokens(nil, messages, 70) if len(trimmed) == 0 { t.Fatal("expected non-empty trimmed messages") } @@ -90,7 +93,7 @@ func TestTrimMessagesByTokens_KeepsToolWhenPaired(t *testing.T) { }, } - trimmed := trimMessagesByTokens(nil, messages, 100) + trimmed, _ := trimMessagesByTokens(nil, messages, 100) if len(trimmed) != 2 { t.Fatalf("expected 2 messages, got %d", len(trimmed)) } @@ -107,7 +110,7 @@ func TestTrimMessagesByTokens_NoUsage_KeepsAll(t *testing.T) { {Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("hi")}}, } - trimmed := trimMessagesByTokens(nil, messages, 10) + trimmed, _ := trimMessagesByTokens(nil, messages, 10) if len(trimmed) != 2 { t.Fatalf("messages without outputTokens should all be kept, got %d", len(trimmed)) } @@ -122,7 +125,7 @@ func TestTrimMessagesByTokens_ZeroMeansNoLimit(t *testing.T) { } // maxTokens = 0 means "no limit configured", should keep all messages. - trimmed := trimMessagesByTokens(nil, messages, 0) + trimmed, _ := trimMessagesByTokens(nil, messages, 0) if len(trimmed) != 2 { t.Fatalf("maxTokens=0 should keep all messages, got %d", len(trimmed)) } @@ -139,7 +142,7 @@ func TestTrimMessagesByTokens_SmallBudgetTrims(t *testing.T) { } // Budget of 1: should trim aggressively, NOT return all messages. - trimmed := trimMessagesByTokens(nil, messages, 1) + trimmed, _ := trimMessagesByTokens(nil, messages, 1) if len(trimmed) >= len(messages) { t.Fatalf("maxTokens=1 should trim history, but got %d messages (same as input)", len(trimmed)) } @@ -159,8 +162,11 @@ func TestTrimMessagesByTokens_EstimatesFallback(t *testing.T) { } // Budget of 50: user message is ~100 estimated tokens (400/4), should be trimmed. 
- trimmed := trimMessagesByTokens(nil, messages, 50) - if len(trimmed) == 2 { - t.Fatalf("expected long user message without usage to be trimmed via estimation, got %d", len(trimmed)) + trimmed, _ := trimMessagesByTokens(nil, messages, 50) + // When trimming occurs, a system truncation notice is prepended. + // So we expect: 1 system notice + 1 assistant message (kept) = 2 total. + // The key check is that the long user message was removed. + if len(trimmed) != 2 || trimmed[0].Role != "system" || trimmed[1].Role != "assistant" { + t.Fatalf("expected [system notice, assistant message], got %d messages: %+v", len(trimmed), trimmed) } } diff --git a/internal/db/sqlc/settings.sql.go b/internal/db/sqlc/settings.sql.go index 305b8ae3..b5b44e3f 100644 --- a/internal/db/sqlc/settings.sql.go +++ b/internal/db/sqlc/settings.sql.go @@ -31,6 +31,8 @@ SET language = 'auto', memory_provider_id = NULL, tts_model_id = NULL, browser_context_id = NULL, + context_token_budget = NULL, + persist_full_tool_results = false, updated_at = now() WHERE id = $1 ` @@ -52,6 +54,7 @@ SELECT bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, + bots.timezone, chat_models.id AS chat_model_id, heartbeat_models.id AS heartbeat_model_id, compaction_models.id AS compaction_model_id, @@ -60,7 +63,9 @@ SELECT memory_providers.id AS memory_provider_id, image_models.id AS image_model_id, tts_models.id AS tts_model_id, - browser_contexts.id AS browser_context_id + browser_contexts.id AS browser_context_id, + bots.context_token_budget, + bots.persist_full_tool_results FROM bots LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = bots.heartbeat_model_id @@ -75,25 +80,28 @@ WHERE bots.id = $1 ` type GetSettingsByBotIDRow struct { - BotID pgtype.UUID `json:"bot_id"` - Language string `json:"language"` - ReasoningEnabled bool `json:"reasoning_enabled"` - ReasoningEffort string `json:"reasoning_effort"` - 
HeartbeatEnabled bool `json:"heartbeat_enabled"` - HeartbeatInterval int32 `json:"heartbeat_interval"` - HeartbeatPrompt string `json:"heartbeat_prompt"` - CompactionEnabled bool `json:"compaction_enabled"` - CompactionThreshold int32 `json:"compaction_threshold"` - CompactionRatio int32 `json:"compaction_ratio"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - HeartbeatModelID pgtype.UUID `json:"heartbeat_model_id"` - CompactionModelID pgtype.UUID `json:"compaction_model_id"` - TitleModelID pgtype.UUID `json:"title_model_id"` - SearchProviderID pgtype.UUID `json:"search_provider_id"` - MemoryProviderID pgtype.UUID `json:"memory_provider_id"` - ImageModelID pgtype.UUID `json:"image_model_id"` - TtsModelID pgtype.UUID `json:"tts_model_id"` - BrowserContextID pgtype.UUID `json:"browser_context_id"` + BotID pgtype.UUID `json:"bot_id"` + Language string `json:"language"` + ReasoningEnabled bool `json:"reasoning_enabled"` + ReasoningEffort string `json:"reasoning_effort"` + HeartbeatEnabled bool `json:"heartbeat_enabled"` + HeartbeatInterval int32 `json:"heartbeat_interval"` + HeartbeatPrompt string `json:"heartbeat_prompt"` + CompactionEnabled bool `json:"compaction_enabled"` + CompactionThreshold int32 `json:"compaction_threshold"` + CompactionRatio int32 `json:"compaction_ratio"` + Timezone pgtype.Text `json:"timezone"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + HeartbeatModelID pgtype.UUID `json:"heartbeat_model_id"` + CompactionModelID pgtype.UUID `json:"compaction_model_id"` + TitleModelID pgtype.UUID `json:"title_model_id"` + SearchProviderID pgtype.UUID `json:"search_provider_id"` + MemoryProviderID pgtype.UUID `json:"memory_provider_id"` + ImageModelID pgtype.UUID `json:"image_model_id"` + TtsModelID pgtype.UUID `json:"tts_model_id"` + BrowserContextID pgtype.UUID `json:"browser_context_id"` + ContextTokenBudget pgtype.Int4 `json:"context_token_budget"` + PersistFullToolResults bool `json:"persist_full_tool_results"` } func (q *Queries) 
GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSettingsByBotIDRow, error) { @@ -110,6 +118,7 @@ func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSe &i.CompactionEnabled, &i.CompactionThreshold, &i.CompactionRatio, + &i.Timezone, &i.ChatModelID, &i.HeartbeatModelID, &i.CompactionModelID, @@ -119,6 +128,8 @@ func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSe &i.ImageModelID, &i.TtsModelID, &i.BrowserContextID, + &i.ContextTokenBudget, + &i.PersistFullToolResults, ) return i, err } @@ -135,18 +146,21 @@ WITH updated AS ( compaction_enabled = $7, compaction_threshold = $8, compaction_ratio = $9, - chat_model_id = COALESCE($10::uuid, bots.chat_model_id), - heartbeat_model_id = COALESCE($11::uuid, bots.heartbeat_model_id), - compaction_model_id = COALESCE($12::uuid, bots.compaction_model_id), - title_model_id = COALESCE($13::uuid, bots.title_model_id), - search_provider_id = COALESCE($14::uuid, bots.search_provider_id), - memory_provider_id = COALESCE($15::uuid, bots.memory_provider_id), - image_model_id = COALESCE($16::uuid, bots.image_model_id), - tts_model_id = COALESCE($17::uuid, bots.tts_model_id), - browser_context_id = COALESCE($18::uuid, bots.browser_context_id), + timezone = COALESCE($10, bots.timezone), + chat_model_id = COALESCE($11::uuid, bots.chat_model_id), + heartbeat_model_id = COALESCE($12::uuid, bots.heartbeat_model_id), + compaction_model_id = COALESCE($13::uuid, bots.compaction_model_id), + title_model_id = COALESCE($14::uuid, bots.title_model_id), + search_provider_id = COALESCE($15::uuid, bots.search_provider_id), + memory_provider_id = COALESCE($16::uuid, bots.memory_provider_id), + image_model_id = COALESCE($17::uuid, bots.image_model_id), + tts_model_id = COALESCE($18::uuid, bots.tts_model_id), + browser_context_id = COALESCE($19::uuid, bots.browser_context_id), + context_token_budget = COALESCE($20, bots.context_token_budget), + persist_full_tool_results = $21, updated_at = 
now() - WHERE bots.id = $19 - RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id + WHERE bots.id = $22 + RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.timezone, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id, bots.context_token_budget, bots.persist_full_tool_results ) SELECT updated.id AS bot_id, @@ -159,6 +173,7 @@ SELECT updated.compaction_enabled, updated.compaction_threshold, updated.compaction_ratio, + updated.timezone, chat_models.id AS chat_model_id, heartbeat_models.id AS heartbeat_model_id, compaction_models.id AS compaction_model_id, @@ -167,7 +182,9 @@ SELECT memory_providers.id AS memory_provider_id, image_models.id AS image_model_id, tts_models.id AS tts_model_id, - browser_contexts.id AS browser_context_id + browser_contexts.id AS browser_context_id, + updated.context_token_budget, + updated.persist_full_tool_results FROM updated LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = updated.heartbeat_model_id @@ -181,47 +198,53 @@ LEFT JOIN browser_contexts ON browser_contexts.id = updated.browser_context_id ` type UpsertBotSettingsParams struct { - Language string `json:"language"` - ReasoningEnabled bool `json:"reasoning_enabled"` - ReasoningEffort 
string `json:"reasoning_effort"` - HeartbeatEnabled bool `json:"heartbeat_enabled"` - HeartbeatInterval int32 `json:"heartbeat_interval"` - HeartbeatPrompt string `json:"heartbeat_prompt"` - CompactionEnabled bool `json:"compaction_enabled"` - CompactionThreshold int32 `json:"compaction_threshold"` - CompactionRatio int32 `json:"compaction_ratio"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - HeartbeatModelID pgtype.UUID `json:"heartbeat_model_id"` - CompactionModelID pgtype.UUID `json:"compaction_model_id"` - TitleModelID pgtype.UUID `json:"title_model_id"` - SearchProviderID pgtype.UUID `json:"search_provider_id"` - MemoryProviderID pgtype.UUID `json:"memory_provider_id"` - ImageModelID pgtype.UUID `json:"image_model_id"` - TtsModelID pgtype.UUID `json:"tts_model_id"` - BrowserContextID pgtype.UUID `json:"browser_context_id"` - ID pgtype.UUID `json:"id"` + Language string `json:"language"` + ReasoningEnabled bool `json:"reasoning_enabled"` + ReasoningEffort string `json:"reasoning_effort"` + HeartbeatEnabled bool `json:"heartbeat_enabled"` + HeartbeatInterval int32 `json:"heartbeat_interval"` + HeartbeatPrompt string `json:"heartbeat_prompt"` + CompactionEnabled bool `json:"compaction_enabled"` + CompactionThreshold int32 `json:"compaction_threshold"` + CompactionRatio int32 `json:"compaction_ratio"` + Timezone pgtype.Text `json:"timezone"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + HeartbeatModelID pgtype.UUID `json:"heartbeat_model_id"` + CompactionModelID pgtype.UUID `json:"compaction_model_id"` + TitleModelID pgtype.UUID `json:"title_model_id"` + SearchProviderID pgtype.UUID `json:"search_provider_id"` + MemoryProviderID pgtype.UUID `json:"memory_provider_id"` + ImageModelID pgtype.UUID `json:"image_model_id"` + TtsModelID pgtype.UUID `json:"tts_model_id"` + BrowserContextID pgtype.UUID `json:"browser_context_id"` + ContextTokenBudget pgtype.Int4 `json:"context_token_budget"` + PersistFullToolResults bool `json:"persist_full_tool_results"` + ID 
pgtype.UUID `json:"id"` } type UpsertBotSettingsRow struct { - BotID pgtype.UUID `json:"bot_id"` - Language string `json:"language"` - ReasoningEnabled bool `json:"reasoning_enabled"` - ReasoningEffort string `json:"reasoning_effort"` - HeartbeatEnabled bool `json:"heartbeat_enabled"` - HeartbeatInterval int32 `json:"heartbeat_interval"` - HeartbeatPrompt string `json:"heartbeat_prompt"` - CompactionEnabled bool `json:"compaction_enabled"` - CompactionThreshold int32 `json:"compaction_threshold"` - CompactionRatio int32 `json:"compaction_ratio"` - ChatModelID pgtype.UUID `json:"chat_model_id"` - HeartbeatModelID pgtype.UUID `json:"heartbeat_model_id"` - CompactionModelID pgtype.UUID `json:"compaction_model_id"` - TitleModelID pgtype.UUID `json:"title_model_id"` - SearchProviderID pgtype.UUID `json:"search_provider_id"` - MemoryProviderID pgtype.UUID `json:"memory_provider_id"` - ImageModelID pgtype.UUID `json:"image_model_id"` - TtsModelID pgtype.UUID `json:"tts_model_id"` - BrowserContextID pgtype.UUID `json:"browser_context_id"` + BotID pgtype.UUID `json:"bot_id"` + Language string `json:"language"` + ReasoningEnabled bool `json:"reasoning_enabled"` + ReasoningEffort string `json:"reasoning_effort"` + HeartbeatEnabled bool `json:"heartbeat_enabled"` + HeartbeatInterval int32 `json:"heartbeat_interval"` + HeartbeatPrompt string `json:"heartbeat_prompt"` + CompactionEnabled bool `json:"compaction_enabled"` + CompactionThreshold int32 `json:"compaction_threshold"` + CompactionRatio int32 `json:"compaction_ratio"` + Timezone pgtype.Text `json:"timezone"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + HeartbeatModelID pgtype.UUID `json:"heartbeat_model_id"` + CompactionModelID pgtype.UUID `json:"compaction_model_id"` + TitleModelID pgtype.UUID `json:"title_model_id"` + SearchProviderID pgtype.UUID `json:"search_provider_id"` + MemoryProviderID pgtype.UUID `json:"memory_provider_id"` + ImageModelID pgtype.UUID `json:"image_model_id"` + TtsModelID pgtype.UUID 
`json:"tts_model_id"` + BrowserContextID pgtype.UUID `json:"browser_context_id"` + ContextTokenBudget pgtype.Int4 `json:"context_token_budget"` + PersistFullToolResults bool `json:"persist_full_tool_results"` } func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsParams) (UpsertBotSettingsRow, error) { @@ -235,6 +258,7 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa arg.CompactionEnabled, arg.CompactionThreshold, arg.CompactionRatio, + arg.Timezone, arg.ChatModelID, arg.HeartbeatModelID, arg.CompactionModelID, @@ -244,6 +268,8 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa arg.ImageModelID, arg.TtsModelID, arg.BrowserContextID, + arg.ContextTokenBudget, + arg.PersistFullToolResults, arg.ID, ) var i UpsertBotSettingsRow @@ -258,6 +284,7 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa &i.CompactionEnabled, &i.CompactionThreshold, &i.CompactionRatio, + &i.Timezone, &i.ChatModelID, &i.HeartbeatModelID, &i.CompactionModelID, @@ -267,6 +294,8 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa &i.ImageModelID, &i.TtsModelID, &i.BrowserContextID, + &i.ContextTokenBudget, + &i.PersistFullToolResults, ) return i, err } diff --git a/internal/email/adapters/generic/adapter.go b/internal/email/adapters/generic/adapter.go index 18ca8073..17748680 100644 --- a/internal/email/adapters/generic/adapter.go +++ b/internal/email/adapters/generic/adapter.go @@ -135,7 +135,7 @@ func (a *Adapter) StartReceiving(ctx context.Context, config map[string]any, han providerID, _ := config["_provider_id"].(string) - rctx, cancel := context.WithCancel(ctx) + rctx, cancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel is stored in conn.cancel and called by Stop() conn := &imapConn{ logger: a.logger, host: host, diff --git a/internal/email/adapters/gmail/adapter.go b/internal/email/adapters/gmail/adapter.go index 
1f15e917..6b459adc 100644 --- a/internal/email/adapters/gmail/adapter.go +++ b/internal/email/adapters/gmail/adapter.go @@ -146,7 +146,7 @@ func (a *Adapter) Send(ctx context.Context, config map[string]any, msg email.Out func (a *Adapter) StartReceiving(ctx context.Context, config map[string]any, handler email.InboundHandler) (email.Stopper, error) { providerID, _ := config["_provider_id"].(string) - rctx, cancel := context.WithCancel(ctx) + rctx, cancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel is stored in conn.cancel and called by Stop() conn := &gmailImapConn{ adapter: a, config: config, diff --git a/internal/email/adapters/mailgun/adapter.go b/internal/email/adapters/mailgun/adapter.go index 5cc45cf6..2f591d84 100644 --- a/internal/email/adapters/mailgun/adapter.go +++ b/internal/email/adapters/mailgun/adapter.go @@ -122,7 +122,7 @@ func (a *Adapter) StartReceiving(ctx context.Context, config map[string]any, han providerID, _ := config["_provider_id"].(string) domain, _ := config["domain"].(string) - rctx, cancel := context.WithCancel(ctx) + rctx, cancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel is stored in conn.cancel and called by Stop() conn := &pollConn{ logger: a.logger, client: newClient(config), diff --git a/internal/handlers/containerd_terminal.go b/internal/handlers/containerd_terminal.go index f2d11d0a..4f8f8958 100644 --- a/internal/handlers/containerd_terminal.go +++ b/internal/handlers/containerd_terminal.go @@ -6,6 +6,8 @@ import ( "log/slog" "net/http" "strconv" + "sync" + "time" "github.com/gorilla/websocket" "github.com/labstack/echo/v4" @@ -14,6 +16,10 @@ import ( pb "github.com/memohai/memoh/internal/workspace/bridgepb" ) +// terminalIdleTimeout closes inactive terminal WebSocket sessions to +// prevent leaked PTY processes. Reset on every inbound WebSocket message. 
+const terminalIdleTimeout = 30 * time.Minute + var terminalUpgrader = websocket.Upgrader{ CheckOrigin: func(_ *http.Request) bool { return true }, } @@ -106,6 +112,22 @@ func (h *ContainerdHandler) HandleTerminalWS(c echo.Context) error { done := make(chan struct{}) + // Idle timer: closes the connection if no client activity for terminalIdleTimeout. + var idleMu sync.Mutex + idleTimer := time.AfterFunc(terminalIdleTimeout, func() { + h.logger.Info("terminal idle timeout reached, closing", slog.String("bot_id", botID)) + _ = conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseGoingAway, "idle timeout"), + time.Now().Add(5*time.Second)) + _ = conn.Close() + }) + defer idleTimer.Stop() + resetIdle := func() { + idleMu.Lock() + idleTimer.Reset(terminalIdleTimeout) + idleMu.Unlock() + } + // gRPC output -> WebSocket go func() { defer close(done) @@ -135,6 +157,7 @@ func (h *ContainerdHandler) HandleTerminalWS(c echo.Context) error { _ = execStream.Close() return } + resetIdle() // client is active switch msgType { case websocket.BinaryMessage: if len(data) > 0 { diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index 337f9217..64b96b32 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -377,6 +377,10 @@ func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error { } var msg wsClientMessage if err := json.Unmarshal(raw, &msg); err != nil { + h.logger.Warn("ws: unmarshal failed", + slog.String("bot_id", botID), + slog.Any("error", err), + ) writer.SendJSON(map[string]string{"type": "error", "message": "invalid message format"}) continue } @@ -390,6 +394,7 @@ func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error { case "message": text := strings.TrimSpace(msg.Text) + sessionID := strings.TrimSpace(msg.SessionID) chatAttachments := make([]conversation.ChatAttachment, 0, len(msg.Attachments)) for _, rawAtt := range msg.Attachments { @@ 
-414,7 +419,6 @@ func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error { activeCancel = streamCancel eventCh := make(chan flow.WSStreamEvent, 64) - sessionID := strings.TrimSpace(msg.SessionID) var ( outboundAssetMu sync.Mutex outboundAssetRefs []messagepkg.AssetRef @@ -440,7 +444,7 @@ func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error { } if streamErr := h.resolver.StreamChatWS(streamCtx, req, eventCh, abortCh); streamErr != nil { if ctx.Err() == nil { - h.logger.Error("ws stream error", slog.Any("error", streamErr)) + h.logger.Error("ws stream error", slog.Any("error", streamErr), slog.String("bot_id", botID), slog.String("session_id", sessionID)) writer.SendJSON(map[string]string{"type": "error", "message": streamErr.Error()}) } } diff --git a/internal/handlers/mcp_session_test.go b/internal/handlers/mcp_session_test.go index 376fe2ca..2ca55e27 100644 --- a/internal/handlers/mcp_session_test.go +++ b/internal/handlers/mcp_session_test.go @@ -108,7 +108,7 @@ func jsonRPCSuccessResponse(id sdkjsonrpc.ID, payload map[string]any) *sdkjsonrp } func newTestMCPSession(conn *fakeMCPConnection) *mcpSession { - readCtx, cancelRead := context.WithCancel(context.Background()) + readCtx, cancelRead := context.WithCancel(context.Background()) //nolint:gosec // G118: cancelRead is stored in mcpSession.cancelRead return &mcpSession{ pending: map[string]chan *sdkjsonrpc.Response{}, conn: conn, diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go index ac37e83b..9b8b2a12 100644 --- a/internal/handlers/mcp_stdio.go +++ b/internal/handlers/mcp_stdio.go @@ -706,7 +706,7 @@ func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context stdoutR, stdoutW := io.Pipe() stderrR, stderrW := io.Pipe() - readCtx, cancelRead := context.WithCancel(context.Background()) + readCtx, cancelRead := context.WithCancel(context.Background()) //nolint:gosec // G118: cancelRead is stored in sess.cancelRead sess := &mcpSession{ 
stdin: stdinW, stdout: stdoutR, diff --git a/internal/handlers/mcp_tools_test.go b/internal/handlers/mcp_tools_test.go index 9b93a77a..814eedd1 100644 --- a/internal/handlers/mcp_tools_test.go +++ b/internal/handlers/mcp_tools_test.go @@ -43,7 +43,7 @@ func TestBuildToolCallPayloadFromRaw(t *testing.T) { func TestHandleMCPToolsWithoutGateway(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -108,7 +108,7 @@ func TestHandleMCPToolsWithGatewayAcceptCompatibility(t *testing.T) { toolGateway: toolGateway, } - listReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) + listReq := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/list"}`)) listReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) listReq.Header.Set("Accept", "application/json") listReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") @@ -135,7 +135,7 @@ func TestHandleMCPToolsWithGatewayAcceptCompatibility(t *testing.T) { t.Fatalf("expected one tool, got: %#v", result["tools"]) } - callReq := httptest.NewRequest(http.MethodPost, "/bots/bot-1/tools", strings.NewReader(`{"jsonrpc":"2.0","id":"2","method":"tools/call","params":{"name":"echo_tool","arguments":{"input":"hello"}}}`)) + callReq := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/bots/bot-1/tools", 
strings.NewReader(`{"jsonrpc":"2.0","id":"2","method":"tools/call","params":{"name":"echo_tool","arguments":{"input":"hello"}}}`)) callReq.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) callReq.Header.Set("Accept", "application/json") callReq.Header.Set("X-Memoh-Channel-Identity-Id", "user-1") diff --git a/internal/handlers/message.go b/internal/handlers/message.go index 47b2b024..369f649e 100644 --- a/internal/handlers/message.go +++ b/internal/handlers/message.go @@ -188,7 +188,7 @@ func (h *MessageHandler) fillAssetMimeFromStorage(ctx context.Context, botID str } for i := range messages { for j := range messages[i].Assets { - a := &messages[i].Assets[j] + a := &messages[i].Assets[j] //nolint:gosec // G602: j is bounded by range loop if strings.TrimSpace(a.ContentHash) == "" { continue } diff --git a/internal/heartbeat/service.go b/internal/heartbeat/service.go index d19ae56f..7113595e 100644 --- a/internal/heartbeat/service.go +++ b/internal/heartbeat/service.go @@ -21,6 +21,10 @@ import ( const heartbeatTokenTTL = 10 * time.Minute +// heartbeatRunTimeout caps how long a single heartbeat execution may take. +// This prevents unbounded Generate() calls from hanging forever. +const heartbeatRunTimeout = 5 * time.Minute + // SessionCreator creates sessions for heartbeat runs. 
type SessionCreator interface { CreateSession(ctx context.Context, botID, sessionType string) (string, error) @@ -239,7 +243,9 @@ func (s *Service) scheduleJob(ctx context.Context, cfg Config) error { } spec := fmt.Sprintf("@every %dm", cfg.Interval) job := func() { - s.runHeartbeat(context.WithoutCancel(ctx), cfg) + runCtx, runCancel := context.WithTimeout(context.WithoutCancel(ctx), heartbeatRunTimeout) + defer runCancel() + s.runHeartbeat(runCtx, cfg) } entryID, err := s.cron.AddFunc(spec, job) if err != nil { diff --git a/internal/mcp/sources/federation/source.go b/internal/mcp/sources/federation/source.go index 84afaa3e..5bd4671e 100644 --- a/internal/mcp/sources/federation/source.go +++ b/internal/mcp/sources/federation/source.go @@ -15,6 +15,10 @@ import ( const cacheTTL = 5 * time.Second +// mcpCallTimeout caps individual MCP tool calls and tool listing to prevent +// stuck external MCP servers from blocking the agent indefinitely. +const mcpCallTimeout = 60 * time.Second + type ConnectionLister interface { ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) } @@ -104,17 +108,20 @@ func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, arguments = map[string]any{} } + callCtx, callCancel := context.WithTimeout(ctx, mcpCallTimeout) + defer callCancel() + var ( payload map[string]any err error ) switch route.sourceType { case "http": - payload, err = s.gateway.CallHTTPConnectionTool(ctx, route.connection, route.originalName, arguments) + payload, err = s.gateway.CallHTTPConnectionTool(callCtx, route.connection, route.originalName, arguments) case "sse": - payload, err = s.gateway.CallSSEConnectionTool(ctx, route.connection, route.originalName, arguments) + payload, err = s.gateway.CallSSEConnectionTool(callCtx, route.connection, route.originalName, arguments) case "stdio": - payload, err = s.gateway.CallStdioConnectionTool(ctx, botID, route.connection, route.originalName, arguments) + payload, err = 
s.gateway.CallStdioConnectionTool(callCtx, botID, route.connection, route.originalName, arguments) default: return mcpgw.BuildToolErrorResult("unsupported federated source"), nil } @@ -172,17 +179,20 @@ func (s *Source) buildToolsAndRoutes(ctx context.Context, botID string) ([]mcpgw }) for _, connection := range items { var connTools []mcpgw.ToolDescriptor + listCtx, listCancel := context.WithTimeout(ctx, mcpCallTimeout) switch strings.ToLower(strings.TrimSpace(connection.Type)) { case "http": - connTools, err = s.gateway.ListHTTPConnectionTools(ctx, connection) + connTools, err = s.gateway.ListHTTPConnectionTools(listCtx, connection) case "sse": - connTools, err = s.gateway.ListSSEConnectionTools(ctx, connection) + connTools, err = s.gateway.ListSSEConnectionTools(listCtx, connection) case "stdio": - connTools, err = s.gateway.ListStdioConnectionTools(ctx, botID, connection) + connTools, err = s.gateway.ListStdioConnectionTools(listCtx, botID, connection) default: + listCancel() s.logger.Warn("unsupported mcp connection type", slog.String("connection_id", connection.ID), slog.String("type", connection.Type)) continue } + listCancel() if err != nil { s.logger.Warn("list tools from connection failed", slog.String("connection_id", connection.ID), slog.String("name", connection.Name), slog.Any("error", err)) continue diff --git a/internal/memory/adapters/builtin/formation.go b/internal/memory/adapters/builtin/formation.go index fadd966e..5edbb5ea 100644 --- a/internal/memory/adapters/builtin/formation.go +++ b/internal/memory/adapters/builtin/formation.go @@ -39,6 +39,7 @@ func runFormation(ctx context.Context, logger *slog.Logger, llm adapters.LLM, ru result := formationResult{} extracted, err := llm.Extract(ctx, adapters.ExtractRequest{ + BotID: botID, Messages: req.Messages, TimezoneLocation: req.TimezoneLocation, }) @@ -55,6 +56,7 @@ func runFormation(ctx context.Context, logger *slog.Logger, llm adapters.LLM, ru candidates := gatherCandidates(ctx, logger, 
runtime, botID, facts) decided, err := llm.Decide(ctx, adapters.DecideRequest{ + BotID: botID, Facts: facts, Candidates: candidates, }) diff --git a/internal/memory/adapters/types.go b/internal/memory/adapters/types.go index 46520112..e16f595b 100644 --- a/internal/memory/adapters/types.go +++ b/internal/memory/adapters/types.go @@ -149,6 +149,7 @@ type DeleteResponse struct { } type ExtractRequest struct { + BotID string `json:"bot_id,omitempty"` Messages []Message `json:"messages"` Filters map[string]any `json:"filters,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` @@ -167,6 +168,7 @@ type CandidateMemory struct { } type DecideRequest struct { + BotID string `json:"bot_id,omitempty"` Facts []string `json:"facts"` Candidates []CandidateMemory `json:"candidates"` Filters map[string]any `json:"filters,omitempty"` diff --git a/internal/models/models.go b/internal/models/models.go index f4cfddff..555a2fe7 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -437,6 +437,7 @@ func IsValidClientType(clientType ClientType) bool { } // SelectMemoryModel selects a chat model for memory operations. +// It only considers models from enabled providers. 
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.Provider, error) { if modelsService == nil { return GetResponse{}, sqlc.Provider{}, errors.New("models service not configured") @@ -444,9 +445,9 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql if queries == nil { return GetResponse{}, sqlc.Provider{}, errors.New("queries not configured") } - candidates, err := modelsService.ListByType(ctx, ModelTypeChat) + candidates, err := modelsService.ListEnabledByType(ctx, ModelTypeChat) if err != nil || len(candidates) == 0 { - return GetResponse{}, sqlc.Provider{}, errors.New("no chat models available for memory operations") + return GetResponse{}, sqlc.Provider{}, errors.New("no enabled chat models available for memory operations") } selected := candidates[0] provider, err := FetchProviderByID(ctx, queries, selected.ProviderID) @@ -456,8 +457,29 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql return selected, provider, nil } -// SelectMemoryModelForBot delegates to SelectMemoryModel. -func SelectMemoryModelForBot(ctx context.Context, modelsService *Service, queries *sqlc.Queries, _ string) (GetResponse, sqlc.Provider, error) { +// SelectMemoryModelForBot selects a chat model for memory operations. +// If botID is provided, it attempts to use the bot's configured chat model first, +// falling back to the first enabled chat model globally. +func SelectMemoryModelForBot(ctx context.Context, modelsService *Service, queries *sqlc.Queries, chatModelID string) (GetResponse, sqlc.Provider, error) { + // If a specific model is configured (e.g. bot's chat_model_id), try to use it. 
+ if chatModelID = strings.TrimSpace(chatModelID); chatModelID != "" { + model, err := modelsService.GetByModelID(ctx, chatModelID) + if err == nil && model.Type == ModelTypeChat { + provider, pErr := FetchProviderByID(ctx, queries, model.ProviderID) + if pErr == nil && provider.Enable { + return model, provider, nil + } + } + // UUID-based lookup fallback + model, err = modelsService.GetByID(ctx, chatModelID) + if err == nil && model.Type == ModelTypeChat { + provider, pErr := FetchProviderByID(ctx, queries, model.ProviderID) + if pErr == nil && provider.Enable { + return model, provider, nil + } + } + } + // Fallback: pick first enabled chat model globally. return SelectMemoryModel(ctx, modelsService, queries) } diff --git a/internal/pipeline/driver.go b/internal/pipeline/driver.go index f791af28..276b0b6a 100644 --- a/internal/pipeline/driver.go +++ b/internal/pipeline/driver.go @@ -111,7 +111,7 @@ func (d *DiscussDriver) NotifyRC(_ context.Context, sessionID string, rc Rendere d.mu.Lock() sess, ok := d.sessions[sessionID] if !ok { - sessCtx, cancel := context.WithCancel(context.Background()) + sessCtx, cancel := context.WithCancel(context.Background()) //nolint:gosec // G118: cancel is stored in sess.cancel sess = &discussSession{ config: config, rcCh: make(chan RenderedContext, 16), diff --git a/internal/prune/text.go b/internal/prune/text.go index c6eb6b36..07813629 100644 --- a/internal/prune/text.go +++ b/internal/prune/text.go @@ -8,8 +8,8 @@ import ( const ( DefaultMarker = "[memoh pruned]" - DefaultMaxBytes = 10 * 1024 - DefaultMaxLines = 250 + DefaultMaxBytes = 64 * 1024 + DefaultMaxLines = 2000 ) type Config struct { diff --git a/internal/schedule/service.go b/internal/schedule/service.go index 97441f77..cc13a4a8 100644 --- a/internal/schedule/service.go +++ b/internal/schedule/service.go @@ -27,15 +27,16 @@ type SessionCreator interface { } type Service struct { - queries *sqlc.Queries - cron *cron.Cron - parser cron.Parser - triggerer Triggerer - 
sessionCreator SessionCreator - jwtSecret string - logger *slog.Logger - mu sync.Mutex - jobs map[string]cron.EntryID + queries *sqlc.Queries + cron *cron.Cron + parser cron.Parser + triggerer Triggerer + sessionCreator SessionCreator + jwtSecret string + logger *slog.Logger + defaultLocation *time.Location + mu sync.Mutex + jobs map[string]cron.EntryID } func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, sessionCreator SessionCreator, runtimeConfig *boot.RuntimeConfig) *Service { @@ -46,14 +47,15 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, se } c := cron.New(cron.WithParser(parser), cron.WithLocation(location)) service := &Service{ - queries: queries, - cron: c, - parser: parser, - triggerer: triggerer, - sessionCreator: sessionCreator, - jwtSecret: runtimeConfig.JwtSecret, - logger: log.With(slog.String("service", "schedule")), - jobs: map[string]cron.EntryID{}, + queries: queries, + cron: c, + parser: parser, + triggerer: triggerer, + sessionCreator: sessionCreator, + jwtSecret: runtimeConfig.JwtSecret, + logger: log.With(slog.String("service", "schedule")), + defaultLocation: location, + jobs: map[string]cron.EntryID{}, } c.Start() return service @@ -240,6 +242,10 @@ func (s *Service) Trigger(ctx context.Context, scheduleID string) error { const scheduleTokenTTL = 10 * time.Minute +// scheduleRunTimeout caps how long a single schedule execution may take. +// This prevents unbounded Generate() calls from hanging forever. 
+const scheduleRunTimeout = 5 * time.Minute + func (s *Service) runSchedule(ctx context.Context, sched Schedule) error { if s.triggerer == nil { return errors.New("schedule triggerer not configured") @@ -484,14 +490,21 @@ func (s *Service) scheduleJob(ctx context.Context, schedule sqlc.Schedule) error return errors.New("schedule id missing") } job := func() { - if err := s.runSchedule(context.WithoutCancel(ctx), toSchedule(schedule)); err != nil { + runCtx, runCancel := context.WithTimeout(context.WithoutCancel(ctx), scheduleRunTimeout) + defer runCancel() + if err := s.runSchedule(runCtx, toSchedule(schedule)); err != nil { s.logger.Error("scheduled job failed", slog.String("schedule_id", schedule.ID.String()), slog.Any("error", err)) } } - entryID, err := s.cron.AddFunc(schedule.Pattern, job) + + // Resolve bot timezone so cron expressions are interpreted in the bot's + // configured timezone rather than the system default. + loc := s.resolveBotLocation(ctx, schedule.BotID) + sched, err := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor).Parse(schedule.Pattern) if err != nil { return err } + entryID := s.cron.Schedule(newLocationSchedule(sched, loc), cron.FuncJob(job)) s.mu.Lock() s.jobs[id] = entryID s.mu.Unlock() @@ -551,3 +564,51 @@ func toUUID(id string) pgtype.UUID { } return pgID } + +// resolveBotLocation returns the bot's configured timezone location, falling +// back to the system default when the bot has no timezone set or the value is +// invalid. 
+func (s *Service) resolveBotLocation(ctx context.Context, botID pgtype.UUID) *time.Location { + if s.queries == nil || !botID.Valid { + return s.defaultLocation + } + row, err := s.queries.GetBotByID(ctx, botID) + if err != nil { + return s.defaultLocation + } + if !row.Timezone.Valid { + return s.defaultLocation + } + tz := strings.TrimSpace(row.Timezone.String) + if tz == "" { + return s.defaultLocation + } + loc, err := time.LoadLocation(tz) + if err != nil { + s.logger.Warn("invalid bot timezone for schedule, using default", + slog.String("bot_id", botID.String()), + slog.String("timezone", tz), + slog.Any("error", err), + ) + return s.defaultLocation + } + return loc +} + +// locationSchedule wraps a cron.Schedule to evaluate Next() in a specific +// timezone, regardless of the global cron location. +type locationSchedule struct { + inner cron.Schedule + loc *time.Location +} + +func newLocationSchedule(inner cron.Schedule, loc *time.Location) cron.Schedule { + if loc == nil { + return inner + } + return &locationSchedule{inner: inner, loc: loc} +} + +func (s *locationSchedule) Next(t time.Time) time.Time { + return s.inner.Next(t.In(s.loc)) +} diff --git a/internal/settings/service.go b/internal/settings/service.go index df2f9635..4d3fa1c4 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "math" "strings" "github.com/google/uuid" @@ -14,6 +15,7 @@ import ( "github.com/memohai/memoh/internal/acl" "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" + tzutil "github.com/memohai/memoh/internal/timezone" ) type Service struct { @@ -97,6 +99,17 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest if req.CompactionRatio != nil && *req.CompactionRatio >= 1 && *req.CompactionRatio <= 100 { current.CompactionRatio = *req.CompactionRatio } + if req.PersistFullToolResults != nil { + current.PersistFullToolResults = 
*req.PersistFullToolResults + } + timezoneValue := pgtype.Text{} + if req.Timezone != nil { + normalized, err := normalizeOptionalTimezone(*req.Timezone) + if err != nil { + return Settings{}, err + } + timezoneValue = normalized + } chatModelUUID := pgtype.UUID{} if value := strings.TrimSpace(req.ChatModelID); value != "" { modelID, err := s.resolveModelUUID(ctx, value) @@ -171,26 +184,38 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest } browserContextUUID = ctxID } + contextTokenBudgetValue := pgtype.Int4{} + if req.ContextTokenBudget != nil && *req.ContextTokenBudget >= 0 { + v := *req.ContextTokenBudget + if v > math.MaxInt32 { + v = math.MaxInt32 + } + contextTokenBudgetValue = pgtype.Int4{Int32: int32(v), Valid: true} //nolint:gosec // G115: clamped above + } + updated, err := s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{ - ID: pgID, - Language: current.Language, - ReasoningEnabled: current.ReasoningEnabled, - ReasoningEffort: current.ReasoningEffort, - HeartbeatEnabled: current.HeartbeatEnabled, - HeartbeatInterval: int32(current.HeartbeatInterval), //nolint:gosec // bounded by positive-only setter above - HeartbeatPrompt: "", - CompactionEnabled: current.CompactionEnabled, - CompactionThreshold: int32(current.CompactionThreshold), //nolint:gosec // bounded by non-negative setter above - CompactionRatio: int32(current.CompactionRatio), //nolint:gosec // bounded 1-100 above - ChatModelID: chatModelUUID, - HeartbeatModelID: heartbeatModelUUID, - CompactionModelID: compactionModelUUID, - TitleModelID: titleModelUUID, - ImageModelID: imageModelUUID, - SearchProviderID: searchProviderUUID, - MemoryProviderID: memoryProviderUUID, - TtsModelID: ttsModelUUID, - BrowserContextID: browserContextUUID, + ID: pgID, + Timezone: timezoneValue, + Language: current.Language, + ReasoningEnabled: current.ReasoningEnabled, + ReasoningEffort: current.ReasoningEffort, + HeartbeatEnabled: current.HeartbeatEnabled, + 
HeartbeatInterval: int32(current.HeartbeatInterval), //nolint:gosec // bounded by positive-only setter above + HeartbeatPrompt: "", + CompactionEnabled: current.CompactionEnabled, + CompactionThreshold: int32(current.CompactionThreshold), //nolint:gosec // bounded by non-negative setter above + CompactionRatio: int32(current.CompactionRatio), //nolint:gosec // bounded 1-100 above + ChatModelID: chatModelUUID, + HeartbeatModelID: heartbeatModelUUID, + CompactionModelID: compactionModelUUID, + TitleModelID: titleModelUUID, + ImageModelID: imageModelUUID, + SearchProviderID: searchProviderUUID, + MemoryProviderID: memoryProviderUUID, + TtsModelID: ttsModelUUID, + BrowserContextID: browserContextUUID, + ContextTokenBudget: contextTokenBudgetValue, + PersistFullToolResults: current.PersistFullToolResults, }) if err != nil { return Settings{}, err @@ -274,6 +299,7 @@ func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings { row.CompactionEnabled, row.CompactionThreshold, row.CompactionRatio, + row.Timezone, row.ChatModelID, row.HeartbeatModelID, row.CompactionModelID, @@ -283,6 +309,8 @@ func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings { row.MemoryProviderID, row.TtsModelID, row.BrowserContextID, + row.ContextTokenBudget, + row.PersistFullToolResults, ) } @@ -296,6 +324,7 @@ func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings { row.CompactionEnabled, row.CompactionThreshold, row.CompactionRatio, + row.Timezone, row.ChatModelID, row.HeartbeatModelID, row.CompactionModelID, @@ -305,6 +334,8 @@ func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings { row.MemoryProviderID, row.TtsModelID, row.BrowserContextID, + row.ContextTokenBudget, + row.PersistFullToolResults, ) } @@ -317,6 +348,7 @@ func normalizeBotSettingsFields( compactionEnabled bool, compactionThreshold int32, compactionRatio int32, + timezone pgtype.Text, chatModelID pgtype.UUID, heartbeatModelID pgtype.UUID, compactionModelID 
pgtype.UUID, @@ -326,8 +358,13 @@ func normalizeBotSettingsFields( memoryProviderID pgtype.UUID, ttsModelID pgtype.UUID, browserContextID pgtype.UUID, + contextTokenBudget pgtype.Int4, + persistFullToolResults bool, ) Settings { settings := normalizeBotSetting(language, "", reasoningEnabled, reasoningEffort, heartbeatEnabled, heartbeatInterval, compactionEnabled, compactionThreshold, compactionRatio) + if timezone.Valid { + settings.Timezone = timezone.String + } if chatModelID.Valid { settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String() } @@ -355,6 +392,10 @@ func normalizeBotSettingsFields( if browserContextID.Valid { settings.BrowserContextID = uuid.UUID(browserContextID.Bytes).String() } + if contextTokenBudget.Valid { + settings.ContextTokenBudget = int(contextTokenBudget.Int32) + } + settings.PersistFullToolResults = persistFullToolResults return settings } @@ -402,3 +443,15 @@ func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype. } return rows[0].ID, nil } + +func normalizeOptionalTimezone(raw string) (pgtype.Text, error) { + normalized := strings.TrimSpace(raw) + if normalized == "" { + return pgtype.Text{}, nil + } + loc, _, err := tzutil.Resolve(normalized) + if err != nil { + return pgtype.Text{}, fmt.Errorf("invalid timezone: %w", err) + } + return pgtype.Text{String: loc.String(), Valid: true}, nil +} diff --git a/internal/settings/types.go b/internal/settings/types.go index 1efc09a5..ecb46a6f 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -7,45 +7,51 @@ const ( ) type Settings struct { - ChatModelID string `json:"chat_model_id"` - ImageModelID string `json:"image_model_id"` - SearchProviderID string `json:"search_provider_id"` - MemoryProviderID string `json:"memory_provider_id"` - TtsModelID string `json:"tts_model_id"` - BrowserContextID string `json:"browser_context_id"` - Language string `json:"language"` - AclDefaultEffect string `json:"acl_default_effect"` - ReasoningEnabled bool 
`json:"reasoning_enabled"` - ReasoningEffort string `json:"reasoning_effort"` - HeartbeatEnabled bool `json:"heartbeat_enabled"` - HeartbeatInterval int `json:"heartbeat_interval"` - HeartbeatModelID string `json:"heartbeat_model_id"` - TitleModelID string `json:"title_model_id"` - CompactionEnabled bool `json:"compaction_enabled"` - CompactionThreshold int `json:"compaction_threshold"` - CompactionRatio int `json:"compaction_ratio"` - CompactionModelID string `json:"compaction_model_id,omitempty"` - DiscussProbeModelID string `json:"discuss_probe_model_id,omitempty"` + ChatModelID string `json:"chat_model_id"` + ImageModelID string `json:"image_model_id"` + SearchProviderID string `json:"search_provider_id"` + MemoryProviderID string `json:"memory_provider_id"` + TtsModelID string `json:"tts_model_id"` + BrowserContextID string `json:"browser_context_id"` + Language string `json:"language"` + AclDefaultEffect string `json:"acl_default_effect"` + Timezone string `json:"timezone"` + ReasoningEnabled bool `json:"reasoning_enabled"` + ReasoningEffort string `json:"reasoning_effort"` + HeartbeatEnabled bool `json:"heartbeat_enabled"` + HeartbeatInterval int `json:"heartbeat_interval"` + HeartbeatModelID string `json:"heartbeat_model_id"` + TitleModelID string `json:"title_model_id"` + CompactionEnabled bool `json:"compaction_enabled"` + CompactionThreshold int `json:"compaction_threshold"` + CompactionRatio int `json:"compaction_ratio"` + CompactionModelID string `json:"compaction_model_id,omitempty"` + DiscussProbeModelID string `json:"discuss_probe_model_id,omitempty"` + ContextTokenBudget int `json:"context_token_budget"` + PersistFullToolResults bool `json:"persist_full_tool_results"` } type UpsertRequest struct { - ChatModelID string `json:"chat_model_id,omitempty"` - ImageModelID string `json:"image_model_id,omitempty"` - SearchProviderID string `json:"search_provider_id,omitempty"` - MemoryProviderID string `json:"memory_provider_id,omitempty"` - TtsModelID 
string `json:"tts_model_id,omitempty"` - BrowserContextID string `json:"browser_context_id,omitempty"` - Language string `json:"language,omitempty"` - AclDefaultEffect string `json:"acl_default_effect,omitempty"` - ReasoningEnabled *bool `json:"reasoning_enabled,omitempty"` - ReasoningEffort *string `json:"reasoning_effort,omitempty"` - HeartbeatEnabled *bool `json:"heartbeat_enabled,omitempty"` - HeartbeatInterval *int `json:"heartbeat_interval,omitempty"` - HeartbeatModelID string `json:"heartbeat_model_id,omitempty"` - TitleModelID string `json:"title_model_id,omitempty"` - CompactionEnabled *bool `json:"compaction_enabled,omitempty"` - CompactionThreshold *int `json:"compaction_threshold,omitempty"` - CompactionRatio *int `json:"compaction_ratio,omitempty"` - CompactionModelID *string `json:"compaction_model_id,omitempty"` - DiscussProbeModelID string `json:"discuss_probe_model_id,omitempty"` + ChatModelID string `json:"chat_model_id,omitempty"` + ImageModelID string `json:"image_model_id,omitempty"` + SearchProviderID string `json:"search_provider_id,omitempty"` + MemoryProviderID string `json:"memory_provider_id,omitempty"` + TtsModelID string `json:"tts_model_id,omitempty"` + BrowserContextID string `json:"browser_context_id,omitempty"` + Language string `json:"language,omitempty"` + AclDefaultEffect string `json:"acl_default_effect,omitempty"` + Timezone *string `json:"timezone,omitempty"` + ReasoningEnabled *bool `json:"reasoning_enabled,omitempty"` + ReasoningEffort *string `json:"reasoning_effort,omitempty"` + HeartbeatEnabled *bool `json:"heartbeat_enabled,omitempty"` + HeartbeatInterval *int `json:"heartbeat_interval,omitempty"` + HeartbeatModelID string `json:"heartbeat_model_id,omitempty"` + TitleModelID string `json:"title_model_id,omitempty"` + CompactionEnabled *bool `json:"compaction_enabled,omitempty"` + CompactionThreshold *int `json:"compaction_threshold,omitempty"` + CompactionRatio *int `json:"compaction_ratio,omitempty"` + CompactionModelID 
*string `json:"compaction_model_id,omitempty"` + DiscussProbeModelID string `json:"discuss_probe_model_id,omitempty"` + ContextTokenBudget *int `json:"context_token_budget,omitempty"` + PersistFullToolResults *bool `json:"persist_full_tool_results,omitempty"` } diff --git a/internal/storage/providers/fallback/provider.go b/internal/storage/providers/fallback/provider.go index 3c450a6a..78879693 100644 --- a/internal/storage/providers/fallback/provider.go +++ b/internal/storage/providers/fallback/provider.go @@ -5,6 +5,7 @@ package fallback import ( "context" + "fmt" "io" "github.com/memohai/memoh/internal/storage" @@ -84,12 +85,16 @@ func tryListPrefix(ctx context.Context, p storage.Provider, prefix string) ([]st // OpenContainerFile delegates to whichever inner provider implements // storage.ContainerFileOpener, trying the primary first. +// If the primary implements the interface but returns an error, that error +// is propagated rather than silently swallowed — the secondary is only tried +// when the primary does not implement ContainerFileOpener at all. 
func (p *Provider) OpenContainerFile(ctx context.Context, botID, containerPath string) (io.ReadCloser, error) { if opener, ok := p.primary.(storage.ContainerFileOpener); ok { rc, err := opener.OpenContainerFile(ctx, botID, containerPath) - if err == nil { - return rc, nil + if err != nil { + return nil, fmt.Errorf("primary provider: %w", err) } + return rc, nil } if opener, ok := p.secondary.(storage.ContainerFileOpener); ok { return opener.OpenContainerFile(ctx, botID, containerPath) diff --git a/internal/workspace/bridge/client.go b/internal/workspace/bridge/client.go index e98b599e..840fa951 100644 --- a/internal/workspace/bridge/client.go +++ b/internal/workspace/bridge/client.go @@ -16,6 +16,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" pb "github.com/memohai/memoh/internal/workspace/bridgepb" ) @@ -45,6 +46,15 @@ func NewClientFromConn(conn *grpc.ClientConn) *Client { func Dial(_ context.Context, target string) (*Client, error) { conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 30 * time.Second, // ping every 30s if idle + Timeout: 10 * time.Second, // wait 10s for ping ack + PermitWithoutStream: true, // ping even with no active RPC + }), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(16*1024*1024), + grpc.MaxCallSendMsgSize(16*1024*1024), + ), ) if err != nil { return nil, fmt.Errorf("grpc dial %s: %w", target, err) diff --git a/packages/icons/src/icons/Misskey.vue b/packages/icons/src/icons/Misskey.vue new file mode 100644 index 00000000..13c5fb89 --- /dev/null +++ b/packages/icons/src/icons/Misskey.vue @@ -0,0 +1,17 @@ + + + diff --git a/packages/icons/src/index.ts b/packages/icons/src/index.ts index 5ffb6024..a53faf1b 100644 --- a/packages/icons/src/index.ts +++ b/packages/icons/src/index.ts @@ -48,6 +48,7 @@ export { default 
as Microsoft } from './icons/Microsoft.vue' export { default as MicrosoftColor } from './icons/MicrosoftColor.vue' export { default as Minimax } from './icons/Minimax.vue' export { default as MinimaxColor } from './icons/MinimaxColor.vue' +export { default as Misskey } from './icons/Misskey.vue' export { default as Mistral } from './icons/Mistral.vue' export { default as MistralColor } from './icons/MistralColor.vue' export { default as Moonshot } from './icons/Moonshot.vue' diff --git a/packages/sdk/src/types.gen.ts b/packages/sdk/src/types.gen.ts index 9a9faaff..aaa3f5dd 100644 --- a/packages/sdk/src/types.gen.ts +++ b/packages/sdk/src/types.gen.ts @@ -1590,6 +1590,7 @@ export type SettingsUpsertRequest = { reasoning_effort?: string; reasoning_enabled?: boolean; search_provider_id?: string; + timezone?: string; title_model_id?: string; tts_model_id?: string; };