From 83b6ee608ca0ee18c5ce93f486975e9e66f64608 Mon Sep 17 00:00:00 2001 From: BBQ Date: Sat, 7 Feb 2026 12:03:24 +0800 Subject: [PATCH] refactor: bind container lifecycle to bot and improve schedule trigger flow - Add SetupBotContainer to ContainerLifecycle interface so containers are automatically created when a bot is created, matching the existing cleanup-on-delete behavior. - Refactor schedule tools to use bot-scoped API paths and pass identity context for proper authorization. - Introduce dedicated trigger-schedule endpoint in chat resolver with explicit schedule payload instead of reusing the generic chat path. - Generate short-lived JWT tokens for schedule trigger callbacks with resolved bot owner identity. - Validate required parameters in NewLLMClient and NewOpenAIEmbedder constructors, returning errors instead of falling back to defaults. - Add unit tests for schedule token generation and chat resolver. --- agent/src/tools/index.ts | 2 +- agent/src/tools/schedule.ts | 17 +-- cmd/agent/main.go | 2 +- internal/bots/service.go | 14 ++- internal/bots/types.go | 1 + internal/chat/resolver.go | 137 +++++++++++++++++++----- internal/chat/resolver_test.go | 160 +++++++++++++++++++++++++++++ internal/embeddings/embeddings.go | 17 +-- internal/embeddings/resolver.go | 6 +- internal/handlers/containerd.go | 106 +++++++++++++++++++ internal/memory/llm_client.go | 22 ++-- internal/memory/llm_client_test.go | 5 +- internal/router/channel.go | 21 ++-- internal/schedule/service.go | 51 ++++++++- internal/schedule/service_test.go | 91 ++++++++++++++++ internal/schedule/trigger.go | 3 +- 16 files changed, 583 insertions(+), 72 deletions(-) create mode 100644 internal/chat/resolver_test.go create mode 100644 internal/schedule/service_test.go diff --git a/agent/src/tools/index.ts b/agent/src/tools/index.ts index 77fbca3d..821d7c28 100644 --- a/agent/src/tools/index.ts +++ b/agent/src/tools/index.ts @@ -25,7 +25,7 @@ export const getTools = ( Object.assign(tools, webTools) } if (actions.includes(AgentAction.Schedule)) { - const scheduleTools = getScheduleTools({ fetch }) + const scheduleTools = getScheduleTools({ fetch, identity }) Object.assign(tools, scheduleTools) } if (actions.includes(AgentAction.Memory)) { diff --git a/agent/src/tools/schedule.ts b/agent/src/tools/schedule.ts index e3b1ad66..d44f4027 100644 --- a/agent/src/tools/schedule.ts +++ b/agent/src/tools/schedule.ts @@ -1,9 +1,11 @@ import { tool } from 'ai' import { z } from 'zod' import { AuthFetcher } from '..' +import type { IdentityContext } from '../types' export type ScheduleToolParams = { fetch: AuthFetcher + identity: IdentityContext } const ScheduleSchema = z.object({ @@ -15,12 +17,15 @@ const ScheduleSchema = z.object({ command: z.string(), }) -export const getScheduleTools = ({ fetch }: ScheduleToolParams) => { +export const getScheduleTools = ({ fetch, identity }: ScheduleToolParams) => { + const botId = identity.botId.trim() + const base = `/bots/${botId}/schedule` + const listSchedules = tool({ description: 'List schedules for current user', inputSchema: z.object({}), execute: async () => { - const response = await fetch('/schedule', { method: 'GET' }) + const response = await fetch(base, { method: 'GET' }) return response.json() }, }) @@ -31,7 +36,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => { id: z.string().describe('Schedule ID'), }), execute: async ({ id }) => { - const response = await fetch(`/schedule/${id}`, { method: 'GET' }) + const response = await fetch(`${base}/${id}`, { method: 'GET' }) return response.json() }, }) @@ -47,7 +52,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => { command: z.string(), }), execute: async (payload) => { - const response = await fetch('/schedule', { + const response = await fetch(base, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(payload), @@ -63,7 +68,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => { }), execute: async (payload) => { const { id, ...body } = payload - const response = await fetch(`/schedule/${id}`, { + const response = await fetch(`${base}/${id}`, { method: 'PUT', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(body), @@ -78,7 +83,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => { id: z.string(), }), execute: async ({ id }) => { - const response = await fetch(`/schedule/${id}`, { method: 'DELETE' }) + const response = await fetch(`${base}/${id}`, { method: 'DELETE' }) return response.status === 204 ? { success: true } : response.json() }, }) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 7ebf8dea..dc3beec0 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -336,5 +336,5 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) { if clientType != "openai" && clientType != "openai-compat" { return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType) } - return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout), nil + return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout) } diff --git a/internal/bots/service.go b/internal/bots/service.go index 89fed6cd..567115af 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -112,7 +112,19 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR if err != nil { return Bot{}, err } - return toBot(row) + bot, err := toBot(row) + if err != nil { + return Bot{}, err + } + if s.containerLifecycle != nil { + if err := s.containerLifecycle.SetupBotContainer(ctx, bot.ID); err != nil { + s.logger.Error("failed to setup bot container", + slog.String("bot_id", bot.ID), + slog.Any("error", err), + ) + } + } + return bot, nil } func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { diff --git a/internal/bots/types.go b/internal/bots/types.go index 8858c2e6..2ecc9fc8 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -58,6 +58,7 @@ type ListMembersResponse struct { // ContainerLifecycle handles container lifecycle events bound to bot operations. type ContainerLifecycle interface { + SetupBotContainer(ctx context.Context, botID string) error CleanupBotContainer(ctx context.Context, botID string) error } diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index b6bf8042..ad460fe4 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -111,6 +111,30 @@ type gatewayResponse struct { Skills []string `json:"skills"` } +// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule. +type gatewaySchedule struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Pattern string `json:"pattern"` + MaxCalls *int `json:"maxCalls,omitempty"` + Command string `json:"command"` +} + +// triggerScheduleRequest is the payload for POST /chat/trigger-schedule. +type triggerScheduleRequest struct { + Model gatewayModelConfig `json:"model"` + ActiveContextTime int `json:"activeContextTime"` + Channels []string `json:"channels"` + CurrentChannel string `json:"currentChannel"` + AllowedActions []string `json:"allowedActions,omitempty"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` + Schedule gatewaySchedule `json:"schedule"` +} + // --- resolved context (shared by Chat / StreamChat / TriggerSchedule) --- type resolvedContext struct { @@ -226,7 +250,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err // --- TriggerSchedule --- -// TriggerSchedule executes a scheduled command through the chat gateway. +// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint. func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { if strings.TrimSpace(botID) == "" { return fmt.Errorf("bot id is required") @@ -234,24 +258,52 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc if strings.TrimSpace(payload.Command) == "" { return fmt.Errorf("schedule command is required") } + + sessionID := "schedule:" + payload.ID req := ChatRequest{ BotID: botID, - SessionID: "schedule:" + payload.ID, + SessionID: sessionID, Query: payload.Command, + UserID: payload.OwnerUserID, Token: token, } rc, err := r.resolve(ctx, req) if err != nil { return err } - rc.payload.Identity.ContactID = botID - rc.payload.Identity.ContactName = "Scheduler" - resp, err := r.postChat(ctx, rc.payload, token) + triggerReq := triggerScheduleRequest{ + Model: rc.payload.Model, + ActiveContextTime: rc.payload.ActiveContextTime, + Channels: rc.payload.Channels, + CurrentChannel: rc.payload.CurrentChannel, + AllowedActions: rc.payload.AllowedActions, + Messages: rc.payload.Messages, + Skills: rc.payload.Skills, + Identity: gatewayIdentity{ + BotID: rc.payload.Identity.BotID, + SessionID: rc.payload.Identity.SessionID, + ContainerID: rc.payload.Identity.ContainerID, + ContactID: botID, + ContactName: "Scheduler", + UserID: payload.OwnerUserID, + }, + Attachments: rc.payload.Attachments, + Schedule: gatewaySchedule{ + ID: payload.ID, + Name: payload.Name, + Description: payload.Description, + Pattern: payload.Pattern, + MaxCalls: payload.MaxCalls, + Command: payload.Command, + }, + } + + resp, err := r.postTriggerSchedule(ctx, triggerReq, token) if err != nil { return err } - return r.storeRound(ctx, botID, req.SessionID, payload.Command, resp.Messages, resp.Skills) + return r.storeRound(ctx, botID, sessionID, payload.Command, resp.Messages, resp.Skills) } // --- StreamChat --- @@ -319,6 +371,47 @@ func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token s return parsed, nil } +// postTriggerSchedule sends a trigger-schedule request to the agent gateway. +func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return gatewayResponse{}, err + } + url := r.gatewayBaseURL + "/chat/trigger-schedule" + r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return gatewayResponse{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + httpReq.Header.Set("Authorization", token) + } + + resp, err := r.httpClient.Do(httpReq) + if err != nil { + return gatewayResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gatewayResponse{}, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + } + + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) + } + return parsed, nil +} + func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error { body, err := json.Marshal(payload) if err != nil { @@ -567,12 +660,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso modelID := strings.TrimSpace(req.Model) providerFilter := strings.TrimSpace(req.Provider) - // Priority: request model > user settings > first available. + // Priority: request model > user settings. No implicit fallback. if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" { modelID = us.ChatModelID } - if modelID != "" && providerFilter == "" { + if modelID == "" { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or user settings") + } + + if providerFilter == "" { return r.fetchChatModel(ctx, modelID) } @@ -580,26 +677,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } - if modelID != "" { - for _, m := range candidates { - if m.ModelID == modelID { - prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return m, prov, nil + for _, m := range candidates { + if m.ModelID == modelID { + prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err } + return m, prov, nil } - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found") } - if len(candidates) == 0 { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available") - } - prov, err := models.FetchProviderByID(ctx, r.queries, candidates[0].LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return candidates[0], prov, nil + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter) } func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { diff --git a/internal/chat/resolver_test.go b/internal/chat/resolver_test.go new file mode 100644 index 00000000..6fc40224 --- /dev/null +++ b/internal/chat/resolver_test.go @@ -0,0 +1,160 @@ +package chat + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestPostTriggerSchedule_Endpoint(t *testing.T) { + var capturedPath string + var capturedBody []byte + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedAuth = r.Header.Get("Authorization") + capturedBody, _ = io.ReadAll(r.Body) + resp := gatewayResponse{ + Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + maxCalls := 5 + req := triggerScheduleRequest{ + Model: gatewayModelConfig{ + ModelID: "gpt-4", + ClientType: "openai", + APIKey: "sk-test", + BaseURL: "https://api.openai.com", + }, + ActiveContextTime: 1440, + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Identity: gatewayIdentity{ + BotID: "bot-123", + SessionID: "schedule:sched-1", + ContainerID: "mcp-bot-123", + ContactID: "bot-123", + ContactName: "Scheduler", + UserID: "owner-user-1", + }, + Attachments: []any{}, + Schedule: gatewaySchedule{ + ID: "sched-1", + Name: "daily report", + Description: "generate daily report", + Pattern: "0 9 * * *", + MaxCalls: &maxCalls, + Command: "generate the daily report", + }, + } + + resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token") + if err != nil { + t.Fatalf("postTriggerSchedule returned error: %v", err) + } + + if capturedPath != "/chat/trigger-schedule" { + t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath) + } + if capturedAuth != "Bearer test-token" { + t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth) + } + if len(resp.Messages) != 1 { + t.Errorf("expected 1 message, got %d", len(resp.Messages)) + } + + var body map[string]any + if err := json.Unmarshal(capturedBody, &body); err != nil { + t.Fatalf("failed to parse captured body: %v", err) + } + schedule, ok := body["schedule"].(map[string]any) + if !ok { + t.Fatal("expected 'schedule' field in request body") + } + if schedule["id"] != "sched-1" { + t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"]) + } + if schedule["command"] != "generate the daily report" { + t.Errorf("expected schedule.command, got %v", schedule["command"]) + } + if _, hasQuery := body["query"]; hasQuery { + t.Error("trigger-schedule request should not contain 'query' field") + } +} + +func TestPostTriggerSchedule_NoAuth(t *testing.T) { + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + resp := gatewayResponse{Messages: []ModelMessage{}} + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + req := triggerScheduleRequest{ + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Attachments: []any{}, + Schedule: gatewaySchedule{ID: "s1", Command: "test"}, + } + + _, err := resolver.postTriggerSchedule(context.Background(), req, "") + if err != nil { + t.Fatalf("postTriggerSchedule returned error: %v", err) + } + if capturedAuth != "" { + t.Errorf("expected no Authorization header, got %s", capturedAuth) + } +} + +func TestPostTriggerSchedule_GatewayError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal error")) + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + httpClient: &http.Client{Timeout: 5 * time.Second}, + logger: slog.Default(), + } + + req := triggerScheduleRequest{ + Channels: []string{}, + Messages: []ModelMessage{}, + Skills: []string{}, + Attachments: []any{}, + Schedule: gatewaySchedule{ID: "s1", Command: "test"}, + } + + _, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok") + if err == nil { + t.Fatal("expected error for 500 response") + } +} diff --git a/internal/embeddings/embeddings.go b/internal/embeddings/embeddings.go index e907c5b4..0bc98fae 100644 --- a/internal/embeddings/embeddings.go +++ b/internal/embeddings/embeddings.go @@ -37,15 +37,18 @@ type openAIEmbeddingResponse struct { } `json:"data"` } -func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int, timeout time.Duration) *OpenAIEmbedder { - if baseURL == "" { - baseURL = "https://api.openai.com" +func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int, timeout time.Duration) (*OpenAIEmbedder, error) { + if strings.TrimSpace(baseURL) == "" { + return nil, fmt.Errorf("openai embedder: base url is required") } - if model == "" { - model = "text-embedding-3-small" + if strings.TrimSpace(apiKey) == "" { + return nil, fmt.Errorf("openai embedder: api key is required") + } + if strings.TrimSpace(model) == "" { + return nil, fmt.Errorf("openai embedder: model is required") } if dims <= 0 { - dims = 1536 + return nil, fmt.Errorf("openai embedder: dimensions must be positive") } if timeout <= 0 { timeout = 10 * time.Second @@ -59,7 +62,7 @@ func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int http: &http.Client{ Timeout: timeout, }, - } + }, nil } func (e *OpenAIEmbedder) Dimensions() int { diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go index 07e3bfbb..48d9d98c 100644 --- a/internal/embeddings/resolver.go +++ b/internal/embeddings/resolver.go @@ -130,10 +130,10 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) { if req.Provider != ProviderOpenAI { return Result{}, errors.New("provider not implemented") } - if strings.TrimSpace(provider.ApiKey) == "" { - return Result{}, errors.New("openai api key is required") + embedder, err := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout) + if err != nil { + return Result{}, err } - embedder := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout) vector, err := embedder.Embed(ctx, req.Input.Text) if err != nil { return Result{}, err diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index c6789315..bb0fb5ab 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -641,6 +641,112 @@ func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, bot return bot, nil } +// SetupBotContainer creates and starts the MCP container for a bot. +func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) error { + containerID := mcp.ContainerPrefix + botID + + image := strings.TrimSpace(h.cfg.BusyboxImage) + if image == "" { + image = config.DefaultBusyboxImg + } + snapshotter := strings.TrimSpace(h.cfg.Snapshotter) + + if strings.TrimSpace(h.namespace) != "" { + ctx = namespaces.WithNamespace(ctx, h.namespace) + } + + dataRoot := strings.TrimSpace(h.cfg.DataRoot) + if dataRoot == "" { + dataRoot = config.DefaultDataRoot + } + dataRoot, _ = filepath.Abs(dataRoot) + dataMount := strings.TrimSpace(h.cfg.DataMount) + if dataMount == "" { + dataMount = config.DefaultDataMount + } + dataDir := filepath.Join(dataRoot, "bots", botID) + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return err + } + if err := os.MkdirAll(filepath.Join(dataDir, ".skills"), 0o755); err != nil { + return err + } + + specOpts := []oci.SpecOpts{ + oci.WithMounts([]specs.Mount{ + { + Destination: dataMount, + Type: "bind", + Source: dataDir, + Options: []string{"rbind", "rw"}, + }, + { + Destination: "/app", + Type: "bind", + Source: dataDir, + Options: []string{"rbind", "rw"}, + }, + }), + oci.WithProcessArgs("/bin/sh", "-lc", "sleep 2147483647"), + } + + _, err := h.service.CreateContainer(ctx, ctr.CreateContainerRequest{ + ID: containerID, + ImageRef: image, + Snapshotter: snapshotter, + Labels: map[string]string{ + mcp.BotLabelKey: botID, + }, + SpecOpts: specOpts, + }) + if err != nil && !errdefs.IsAlreadyExists(err) { + return err + } + + if h.queries != nil { + pgBotID, parseErr := parsePgUUID(botID) + if parseErr == nil { + ns := strings.TrimSpace(h.namespace) + if ns == "" { + ns = "default" + } + _ = h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{ + BotID: pgBotID, + ContainerID: containerID, + ContainerName: containerID, + Image: image, + Status: "created", + Namespace: ns, + AutoStart: true, + HostPath: pgtype.Text{String: dataDir, Valid: true}, + ContainerPath: dataMount, + }) + } + } + + fifoDir, err := h.taskFIFODir() + if err != nil { + return err + } + if _, err := h.service.StartTask(ctx, containerID, &ctr.StartTaskOptions{ + UseStdio: false, + FIFODir: fifoDir, + }); err == nil { + if h.queries != nil { + if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { + _ = h.queries.UpdateContainerStarted(ctx, pgBotID) + } + } + } else { + h.logger.Error("setup bot container: task start failed", + slog.String("bot_id", botID), + slog.String("container_id", containerID), + slog.Any("error", err), + ) + } + return nil +} + // CleanupBotContainer removes the containerd container and DB record for a bot. func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID string) error { containerID, err := h.botContainerID(ctx, botID) diff --git a/internal/memory/llm_client.go b/internal/memory/llm_client.go index 2eb76d2f..3b91c65a 100644 --- a/internal/memory/llm_client.go +++ b/internal/memory/llm_client.go @@ -20,29 +20,31 @@ type LLMClient struct { http *http.Client } -func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) *LLMClient { +func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) (*LLMClient, error) { + if strings.TrimSpace(baseURL) == "" { + return nil, fmt.Errorf("llm client: base url is required") + } + if strings.TrimSpace(apiKey) == "" { + return nil, fmt.Errorf("llm client: api key is required") + } + if strings.TrimSpace(model) == "" { + return nil, fmt.Errorf("llm client: model is required") + } if log == nil { log = slog.Default() } - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - baseURL = strings.TrimRight(baseURL, "/") - if model == "" { - model = "gpt-4.1-nano" - } if timeout <= 0 { timeout = 10 * time.Second } return &LLMClient{ - baseURL: baseURL, + baseURL: strings.TrimRight(baseURL, "/"), apiKey: apiKey, model: model, logger: log.With(slog.String("client", "llm")), http: &http.Client{ Timeout: timeout, }, - } + }, nil } func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { diff --git a/internal/memory/llm_client_test.go b/internal/memory/llm_client_test.go index e7770a23..4633d4d5 100644 --- a/internal/memory/llm_client_test.go +++ b/internal/memory/llm_client_test.go @@ -20,7 +20,10 @@ func TestLLMClientExtract(t *testing.T) { })) defer server.Close() - client := NewLLMClient(nil, server.URL, "test-key", "gpt-4.1-nano-2025-04-14", 0) + client, err := NewLLMClient(nil, server.URL, "test-key", "gpt-4.1-nano-2025-04-14", 0) + if err != nil { + t.Fatalf("new llm client: %v", err) + } resp, err := client.Extract(context.Background(), ExtractRequest{ Messages: []Message{{Role: "user", Content: "hi"}}, }) diff --git a/internal/router/channel.go b/internal/router/channel.go index c652c05c..3fadd484 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -135,16 +135,16 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel desc, _ = p.registry.GetDescriptor(msg.Channel) } resp, err := p.chat.Chat(ctx, chat.ChatRequest{ - BotID: identity.BotID, - SessionID: identity.SessionID, - Token: token, - UserID: identity.UserID, - ContactID: identity.ContactID, - ContactName: strings.TrimSpace(identity.Contact.DisplayName), - ContactAlias: strings.TrimSpace(identity.Contact.Alias), - ReplyTarget: strings.TrimSpace(msg.ReplyTarget), - SessionToken: sessionToken, - Query: text, + BotID: identity.BotID, + SessionID: identity.SessionID, + Token: token, + UserID: identity.UserID, + ContactID: identity.ContactID, + ContactName: strings.TrimSpace(identity.Contact.DisplayName), + ContactAlias: strings.TrimSpace(identity.Contact.Alias), + ReplyTarget: strings.TrimSpace(msg.ReplyTarget), + SessionToken: sessionToken, + Query: text, CurrentChannel: msg.Channel.String(), Channels: []string{msg.Channel.String()}, }) @@ -526,7 +526,6 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool { return false } - func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { if state, ok := IdentityStateFromContext(ctx); ok { return state, nil diff --git a/internal/schedule/service.go b/internal/schedule/service.go index 035b7382..5af044b2 100644 --- a/internal/schedule/service.go +++ b/internal/schedule/service.go @@ -7,12 +7,14 @@ import ( "log/slog" "strings" "sync" + "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/robfig/cron/v3" + "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -214,6 +216,8 @@ func (s *Service) Trigger(ctx context.Context, scheduleID string) error { return s.runSchedule(ctx, schedule) } +const scheduleTokenTTL = 10 * time.Minute + func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error { if s.triggerer == nil { return fmt.Errorf("schedule triggerer not configured") @@ -225,18 +229,55 @@ func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error { if !updated.Enabled { s.removeJob(schedule.ID) } - token := "" - if err := s.triggerer.TriggerSchedule(ctx, schedule.BotID, TriggerPayload{ + + ownerUserID, err := s.resolveBotOwner(ctx, schedule.BotID) + if err != nil { + return fmt.Errorf("resolve bot owner: %w", err) + } + + token, err := s.generateTriggerToken(ownerUserID) + if err != nil { + return fmt.Errorf("generate trigger token: %w", err) + } + + return s.triggerer.TriggerSchedule(ctx, schedule.BotID, TriggerPayload{ ID: schedule.ID, Name: schedule.Name, Description: schedule.Description, Pattern: schedule.Pattern, MaxCalls: schedule.MaxCalls, Command: schedule.Command, - }, token); err != nil { - return err + OwnerUserID: ownerUserID, + }, token) +} + +// resolveBotOwner returns the owner user ID for the given bot. +func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, error) { + pgBotID, err := parseUUID(botID) + if err != nil { + return "", err } - return nil + bot, err := s.queries.GetBotByID(ctx, pgBotID) + if err != nil { + return "", fmt.Errorf("get bot: %w", err) + } + ownerID := toUUIDString(bot.OwnerUserID) + if ownerID == "" { + return "", fmt.Errorf("bot owner not found") + } + return ownerID, nil +} + +// generateTriggerToken creates a short-lived JWT for schedule trigger callbacks. +func (s *Service) generateTriggerToken(userID string) (string, error) { + if strings.TrimSpace(s.jwtSecret) == "" { + return "", fmt.Errorf("jwt secret not configured") + } + signed, _, err := auth.GenerateToken(userID, s.jwtSecret, scheduleTokenTTL) + if err != nil { + return "", err + } + return "Bearer " + signed, nil } func (s *Service) scheduleJob(schedule sqlc.Schedule) error { diff --git a/internal/schedule/service_test.go b/internal/schedule/service_test.go new file mode 100644 index 00000000..15ec87fc --- /dev/null +++ b/internal/schedule/service_test.go @@ -0,0 +1,91 @@ +package schedule + +import ( + "context" + "log/slog" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type mockTriggerer struct { + called bool + botID string + payload TriggerPayload + token string +} + +func (m *mockTriggerer) TriggerSchedule(_ context.Context, botID string, payload TriggerPayload, token string) error { + m.called = true + m.botID = botID + m.payload = payload + m.token = token + return nil +} + +func TestGenerateTriggerToken(t *testing.T) { + secret := "test-secret-key-for-schedule" + svc := &Service{ + jwtSecret: secret, + logger: slog.Default(), + } + userID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + tok, err := svc.generateTriggerToken(userID) + if err != nil { + t.Fatalf("generateTriggerToken returned error: %v", err) + } + if !strings.HasPrefix(tok, "Bearer ") { + t.Fatalf("expected Bearer prefix, got: %s", tok) + } + + raw := strings.TrimPrefix(tok, "Bearer ") + parsed, err := jwt.Parse(raw, func(token *jwt.Token) (any, error) { + return []byte(secret), nil + }) + if err != nil { + t.Fatalf("failed to parse JWT: %v", err) + } + claims, ok := parsed.Claims.(jwt.MapClaims) + if !ok { + t.Fatal("expected MapClaims") + } + if sub, _ := claims["sub"].(string); sub != userID { + t.Errorf("expected sub=%s, got=%s", userID, sub) + } + if uid, _ := claims["user_id"].(string); uid != userID { + t.Errorf("expected user_id=%s, got=%s", userID, uid) + } + exp, _ := claims["exp"].(float64) + if exp == 0 { + t.Fatal("expected non-zero exp") + } + expTime := time.Unix(int64(exp), 0) + if expTime.Before(time.Now().Add(9 * time.Minute)) { + t.Error("token expires too soon") + } +} + +func TestGenerateTriggerToken_EmptySecret(t *testing.T) { + svc := &Service{ + jwtSecret: "", + logger: slog.Default(), + } + _, err := svc.generateTriggerToken("user-123") + if err == nil { + t.Fatal("expected error for empty secret") + } +} + +func TestGenerateTriggerToken_EmptyUserID(t *testing.T) { + svc := &Service{ + jwtSecret: "some-secret", + logger: slog.Default(), + } + _, err := svc.generateTriggerToken("") + if err == nil { + t.Fatal("expected error for empty user ID") + } +} diff --git a/internal/schedule/trigger.go b/internal/schedule/trigger.go index f99c9a2e..e2d3b5ab 100644 --- a/internal/schedule/trigger.go +++ b/internal/schedule/trigger.go @@ -2,7 +2,7 @@ package schedule import "context" -// TriggerPayload 描述触发定时任务时传递给聊天侧的参数。 +// TriggerPayload describes the parameters passed to the chat side when a schedule triggers. type TriggerPayload struct { ID string Name string @@ -10,6 +10,7 @@ type TriggerPayload struct { Pattern string MaxCalls *int Command string + OwnerUserID string } // Triggerer 负责触发与聊天相关的调度执行。