refactor: bind container lifecycle to bot and improve schedule trigger flow

- Add SetupBotContainer to ContainerLifecycle interface so containers
  are automatically created when a bot is created, matching the existing
  cleanup-on-delete behavior.
- Refactor schedule tools to use bot-scoped API paths and pass identity
  context for proper authorization.
- Introduce dedicated trigger-schedule endpoint in chat resolver with
  explicit schedule payload instead of reusing the generic chat path.
- Generate short-lived JWT tokens for schedule trigger callbacks with
  resolved bot owner identity.
- Validate required parameters in NewLLMClient and NewOpenAIEmbedder
  constructors, returning errors instead of falling back to defaults.
- Add unit tests for schedule token generation and chat resolver.
This commit is contained in:
BBQ
2026-02-07 12:03:24 +08:00
parent a9596ab3a8
commit 83b6ee608c
16 changed files with 583 additions and 72 deletions
+1 -1
View File
@@ -25,7 +25,7 @@ export const getTools = (
Object.assign(tools, webTools)
}
if (actions.includes(AgentAction.Schedule)) {
const scheduleTools = getScheduleTools({ fetch })
const scheduleTools = getScheduleTools({ fetch, identity })
Object.assign(tools, scheduleTools)
}
if (actions.includes(AgentAction.Memory)) {
+11 -6
View File
@@ -1,9 +1,11 @@
import { tool } from 'ai'
import { z } from 'zod'
import { AuthFetcher } from '..'
import type { IdentityContext } from '../types'
export type ScheduleToolParams = {
fetch: AuthFetcher
identity: IdentityContext
}
const ScheduleSchema = z.object({
@@ -15,12 +17,15 @@ const ScheduleSchema = z.object({
command: z.string(),
})
export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
export const getScheduleTools = ({ fetch, identity }: ScheduleToolParams) => {
const botId = identity.botId.trim()
const base = `/bots/${botId}/schedule`
const listSchedules = tool({
description: 'List schedules for current user',
inputSchema: z.object({}),
execute: async () => {
const response = await fetch('/schedule', { method: 'GET' })
const response = await fetch(base, { method: 'GET' })
return response.json()
},
})
@@ -31,7 +36,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
id: z.string().describe('Schedule ID'),
}),
execute: async ({ id }) => {
const response = await fetch(`/schedule/${id}`, { method: 'GET' })
const response = await fetch(`${base}/${id}`, { method: 'GET' })
return response.json()
},
})
@@ -47,7 +52,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
command: z.string(),
}),
execute: async (payload) => {
const response = await fetch('/schedule', {
const response = await fetch(base, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
@@ -63,7 +68,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
}),
execute: async (payload) => {
const { id, ...body } = payload
const response = await fetch(`/schedule/${id}`, {
const response = await fetch(`${base}/${id}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(body),
@@ -78,7 +83,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
id: z.string(),
}),
execute: async ({ id }) => {
const response = await fetch(`/schedule/${id}`, { method: 'DELETE' })
const response = await fetch(`${base}/${id}`, { method: 'DELETE' })
return response.status === 204 ? { success: true } : response.json()
},
})
+1 -1
View File
@@ -336,5 +336,5 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
if clientType != "openai" && clientType != "openai-compat" {
return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
}
return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout), nil
return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout)
}
+13 -1
View File
@@ -112,7 +112,19 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR
if err != nil {
return Bot{}, err
}
return toBot(row)
bot, err := toBot(row)
if err != nil {
return Bot{}, err
}
if s.containerLifecycle != nil {
if err := s.containerLifecycle.SetupBotContainer(ctx, bot.ID); err != nil {
s.logger.Error("failed to setup bot container",
slog.String("bot_id", bot.ID),
slog.Any("error", err),
)
}
}
return bot, nil
}
func (s *Service) Get(ctx context.Context, botID string) (Bot, error) {
+1
View File
@@ -58,6 +58,7 @@ type ListMembersResponse struct {
// ContainerLifecycle handles container lifecycle events bound to bot operations.
type ContainerLifecycle interface {
SetupBotContainer(ctx context.Context, botID string) error
CleanupBotContainer(ctx context.Context, botID string) error
}
+112 -25
View File
@@ -111,6 +111,30 @@ type gatewayResponse struct {
Skills []string `json:"skills"`
}
// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule.
type gatewaySchedule struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Pattern string `json:"pattern"`
MaxCalls *int `json:"maxCalls,omitempty"`
Command string `json:"command"`
}
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
type triggerScheduleRequest struct {
Model gatewayModelConfig `json:"model"`
ActiveContextTime int `json:"activeContextTime"`
Channels []string `json:"channels"`
CurrentChannel string `json:"currentChannel"`
AllowedActions []string `json:"allowedActions,omitempty"`
Messages []ModelMessage `json:"messages"`
Skills []string `json:"skills"`
Identity gatewayIdentity `json:"identity"`
Attachments []any `json:"attachments"`
Schedule gatewaySchedule `json:"schedule"`
}
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
type resolvedContext struct {
@@ -226,7 +250,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
// --- TriggerSchedule ---
// TriggerSchedule executes a scheduled command through the chat gateway.
// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint.
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error {
if strings.TrimSpace(botID) == "" {
return fmt.Errorf("bot id is required")
@@ -234,24 +258,52 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
if strings.TrimSpace(payload.Command) == "" {
return fmt.Errorf("schedule command is required")
}
sessionID := "schedule:" + payload.ID
req := ChatRequest{
BotID: botID,
SessionID: "schedule:" + payload.ID,
SessionID: sessionID,
Query: payload.Command,
UserID: payload.OwnerUserID,
Token: token,
}
rc, err := r.resolve(ctx, req)
if err != nil {
return err
}
rc.payload.Identity.ContactID = botID
rc.payload.Identity.ContactName = "Scheduler"
resp, err := r.postChat(ctx, rc.payload, token)
triggerReq := triggerScheduleRequest{
Model: rc.payload.Model,
ActiveContextTime: rc.payload.ActiveContextTime,
Channels: rc.payload.Channels,
CurrentChannel: rc.payload.CurrentChannel,
AllowedActions: rc.payload.AllowedActions,
Messages: rc.payload.Messages,
Skills: rc.payload.Skills,
Identity: gatewayIdentity{
BotID: rc.payload.Identity.BotID,
SessionID: rc.payload.Identity.SessionID,
ContainerID: rc.payload.Identity.ContainerID,
ContactID: botID,
ContactName: "Scheduler",
UserID: payload.OwnerUserID,
},
Attachments: rc.payload.Attachments,
Schedule: gatewaySchedule{
ID: payload.ID,
Name: payload.Name,
Description: payload.Description,
Pattern: payload.Pattern,
MaxCalls: payload.MaxCalls,
Command: payload.Command,
},
}
resp, err := r.postTriggerSchedule(ctx, triggerReq, token)
if err != nil {
return err
}
return r.storeRound(ctx, botID, req.SessionID, payload.Command, resp.Messages, resp.Skills)
return r.storeRound(ctx, botID, sessionID, payload.Command, resp.Messages, resp.Skills)
}
// --- StreamChat ---
@@ -319,6 +371,47 @@ func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token s
return parsed, nil
}
// postTriggerSchedule sends a trigger-schedule request to the agent gateway.
func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) {
body, err := json.Marshal(payload)
if err != nil {
return gatewayResponse{}, err
}
url := r.gatewayBaseURL + "/chat/trigger-schedule"
r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return gatewayResponse{}, err
}
httpReq.Header.Set("Content-Type", "application/json")
if strings.TrimSpace(token) != "" {
httpReq.Header.Set("Authorization", token)
}
resp, err := r.httpClient.Do(httpReq)
if err != nil {
return gatewayResponse{}, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return gatewayResponse{}, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
}
var parsed gatewayResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
}
return parsed, nil
}
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error {
body, err := json.Marshal(payload)
if err != nil {
@@ -567,12 +660,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
modelID := strings.TrimSpace(req.Model)
providerFilter := strings.TrimSpace(req.Provider)
// Priority: request model > user settings > first available.
// Priority: request model > user settings. No implicit fallback.
if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" {
modelID = us.ChatModelID
}
if modelID != "" && providerFilter == "" {
if modelID == "" {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or user settings")
}
if providerFilter == "" {
return r.fetchChatModel(ctx, modelID)
}
@@ -580,26 +677,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
if modelID != "" {
for _, m := range candidates {
if m.ModelID == modelID {
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return m, prov, nil
for _, m := range candidates {
if m.ModelID == modelID {
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return m, prov, nil
}
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found")
}
if len(candidates) == 0 {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available")
}
prov, err := models.FetchProviderByID(ctx, r.queries, candidates[0].LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return candidates[0], prov, nil
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter)
}
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
+160
View File
@@ -0,0 +1,160 @@
package chat
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestPostTriggerSchedule_Endpoint(t *testing.T) {
var capturedPath string
var capturedBody []byte
var capturedAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
capturedAuth = r.Header.Get("Authorization")
capturedBody, _ = io.ReadAll(r.Body)
resp := gatewayResponse{
Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
maxCalls := 5
req := triggerScheduleRequest{
Model: gatewayModelConfig{
ModelID: "gpt-4",
ClientType: "openai",
APIKey: "sk-test",
BaseURL: "https://api.openai.com",
},
ActiveContextTime: 1440,
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Identity: gatewayIdentity{
BotID: "bot-123",
SessionID: "schedule:sched-1",
ContainerID: "mcp-bot-123",
ContactID: "bot-123",
ContactName: "Scheduler",
UserID: "owner-user-1",
},
Attachments: []any{},
Schedule: gatewaySchedule{
ID: "sched-1",
Name: "daily report",
Description: "generate daily report",
Pattern: "0 9 * * *",
MaxCalls: &maxCalls,
Command: "generate the daily report",
},
}
resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token")
if err != nil {
t.Fatalf("postTriggerSchedule returned error: %v", err)
}
if capturedPath != "/chat/trigger-schedule" {
t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath)
}
if capturedAuth != "Bearer test-token" {
t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth)
}
if len(resp.Messages) != 1 {
t.Errorf("expected 1 message, got %d", len(resp.Messages))
}
var body map[string]any
if err := json.Unmarshal(capturedBody, &body); err != nil {
t.Fatalf("failed to parse captured body: %v", err)
}
schedule, ok := body["schedule"].(map[string]any)
if !ok {
t.Fatal("expected 'schedule' field in request body")
}
if schedule["id"] != "sched-1" {
t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"])
}
if schedule["command"] != "generate the daily report" {
t.Errorf("expected schedule.command, got %v", schedule["command"])
}
if _, hasQuery := body["query"]; hasQuery {
t.Error("trigger-schedule request should not contain 'query' field")
}
}
func TestPostTriggerSchedule_NoAuth(t *testing.T) {
var capturedAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
resp := gatewayResponse{Messages: []ModelMessage{}}
json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
req := triggerScheduleRequest{
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Attachments: []any{},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "")
if err != nil {
t.Fatalf("postTriggerSchedule returned error: %v", err)
}
if capturedAuth != "" {
t.Errorf("expected no Authorization header, got %s", capturedAuth)
}
}
func TestPostTriggerSchedule_GatewayError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("internal error"))
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
req := triggerScheduleRequest{
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Attachments: []any{},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok")
if err == nil {
t.Fatal("expected error for 500 response")
}
}
+10 -7
View File
@@ -37,15 +37,18 @@ type openAIEmbeddingResponse struct {
} `json:"data"`
}
func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int, timeout time.Duration) *OpenAIEmbedder {
if baseURL == "" {
baseURL = "https://api.openai.com"
func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int, timeout time.Duration) (*OpenAIEmbedder, error) {
if strings.TrimSpace(baseURL) == "" {
return nil, fmt.Errorf("openai embedder: base url is required")
}
if model == "" {
model = "text-embedding-3-small"
if strings.TrimSpace(apiKey) == "" {
return nil, fmt.Errorf("openai embedder: api key is required")
}
if strings.TrimSpace(model) == "" {
return nil, fmt.Errorf("openai embedder: model is required")
}
if dims <= 0 {
dims = 1536
return nil, fmt.Errorf("openai embedder: dimensions must be positive")
}
if timeout <= 0 {
timeout = 10 * time.Second
@@ -59,7 +62,7 @@ func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int
http: &http.Client{
Timeout: timeout,
},
}
}, nil
}
func (e *OpenAIEmbedder) Dimensions() int {
+3 -3
View File
@@ -130,10 +130,10 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
if req.Provider != ProviderOpenAI {
return Result{}, errors.New("provider not implemented")
}
if strings.TrimSpace(provider.ApiKey) == "" {
return Result{}, errors.New("openai api key is required")
embedder, err := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout)
if err != nil {
return Result{}, err
}
embedder := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout)
vector, err := embedder.Embed(ctx, req.Input.Text)
if err != nil {
return Result{}, err
+106
View File
@@ -641,6 +641,112 @@ func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, bot
return bot, nil
}
// SetupBotContainer creates and starts the MCP container for a bot.
func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) error {
containerID := mcp.ContainerPrefix + botID
image := strings.TrimSpace(h.cfg.BusyboxImage)
if image == "" {
image = config.DefaultBusyboxImg
}
snapshotter := strings.TrimSpace(h.cfg.Snapshotter)
if strings.TrimSpace(h.namespace) != "" {
ctx = namespaces.WithNamespace(ctx, h.namespace)
}
dataRoot := strings.TrimSpace(h.cfg.DataRoot)
if dataRoot == "" {
dataRoot = config.DefaultDataRoot
}
dataRoot, _ = filepath.Abs(dataRoot)
dataMount := strings.TrimSpace(h.cfg.DataMount)
if dataMount == "" {
dataMount = config.DefaultDataMount
}
dataDir := filepath.Join(dataRoot, "bots", botID)
if err := os.MkdirAll(dataDir, 0o755); err != nil {
return err
}
if err := os.MkdirAll(filepath.Join(dataDir, ".skills"), 0o755); err != nil {
return err
}
specOpts := []oci.SpecOpts{
oci.WithMounts([]specs.Mount{
{
Destination: dataMount,
Type: "bind",
Source: dataDir,
Options: []string{"rbind", "rw"},
},
{
Destination: "/app",
Type: "bind",
Source: dataDir,
Options: []string{"rbind", "rw"},
},
}),
oci.WithProcessArgs("/bin/sh", "-lc", "sleep 2147483647"),
}
_, err := h.service.CreateContainer(ctx, ctr.CreateContainerRequest{
ID: containerID,
ImageRef: image,
Snapshotter: snapshotter,
Labels: map[string]string{
mcp.BotLabelKey: botID,
},
SpecOpts: specOpts,
})
if err != nil && !errdefs.IsAlreadyExists(err) {
return err
}
if h.queries != nil {
pgBotID, parseErr := parsePgUUID(botID)
if parseErr == nil {
ns := strings.TrimSpace(h.namespace)
if ns == "" {
ns = "default"
}
_ = h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{
BotID: pgBotID,
ContainerID: containerID,
ContainerName: containerID,
Image: image,
Status: "created",
Namespace: ns,
AutoStart: true,
HostPath: pgtype.Text{String: dataDir, Valid: true},
ContainerPath: dataMount,
})
}
}
fifoDir, err := h.taskFIFODir()
if err != nil {
return err
}
if _, err := h.service.StartTask(ctx, containerID, &ctr.StartTaskOptions{
UseStdio: false,
FIFODir: fifoDir,
}); err == nil {
if h.queries != nil {
if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil {
_ = h.queries.UpdateContainerStarted(ctx, pgBotID)
}
}
} else {
h.logger.Error("setup bot container: task start failed",
slog.String("bot_id", botID),
slog.String("container_id", containerID),
slog.Any("error", err),
)
}
return nil
}
// CleanupBotContainer removes the containerd container and DB record for a bot.
func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID string) error {
containerID, err := h.botContainerID(ctx, botID)
+12 -10
View File
@@ -20,29 +20,31 @@ type LLMClient struct {
http *http.Client
}
func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) *LLMClient {
func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) (*LLMClient, error) {
if strings.TrimSpace(baseURL) == "" {
return nil, fmt.Errorf("llm client: base url is required")
}
if strings.TrimSpace(apiKey) == "" {
return nil, fmt.Errorf("llm client: api key is required")
}
if strings.TrimSpace(model) == "" {
return nil, fmt.Errorf("llm client: model is required")
}
if log == nil {
log = slog.Default()
}
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
baseURL = strings.TrimRight(baseURL, "/")
if model == "" {
model = "gpt-4.1-nano"
}
if timeout <= 0 {
timeout = 10 * time.Second
}
return &LLMClient{
baseURL: baseURL,
baseURL: strings.TrimRight(baseURL, "/"),
apiKey: apiKey,
model: model,
logger: log.With(slog.String("client", "llm")),
http: &http.Client{
Timeout: timeout,
},
}
}, nil
}
func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
+4 -1
View File
@@ -20,7 +20,10 @@ func TestLLMClientExtract(t *testing.T) {
}))
defer server.Close()
client := NewLLMClient(nil, server.URL, "test-key", "gpt-4.1-nano-2025-04-14", 0)
client, err := NewLLMClient(nil, server.URL, "test-key", "gpt-4.1-nano-2025-04-14", 0)
if err != nil {
t.Fatalf("new llm client: %v", err)
}
resp, err := client.Extract(context.Background(), ExtractRequest{
Messages: []Message{{Role: "user", Content: "hi"}},
})
+10 -11
View File
@@ -135,16 +135,16 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
desc, _ = p.registry.GetDescriptor(msg.Channel)
}
resp, err := p.chat.Chat(ctx, chat.ChatRequest{
BotID: identity.BotID,
SessionID: identity.SessionID,
Token: token,
UserID: identity.UserID,
ContactID: identity.ContactID,
ContactName: strings.TrimSpace(identity.Contact.DisplayName),
ContactAlias: strings.TrimSpace(identity.Contact.Alias),
ReplyTarget: strings.TrimSpace(msg.ReplyTarget),
SessionToken: sessionToken,
Query: text,
BotID: identity.BotID,
SessionID: identity.SessionID,
Token: token,
UserID: identity.UserID,
ContactID: identity.ContactID,
ContactName: strings.TrimSpace(identity.Contact.DisplayName),
ContactAlias: strings.TrimSpace(identity.Contact.Alias),
ReplyTarget: strings.TrimSpace(msg.ReplyTarget),
SessionToken: sessionToken,
Query: text,
CurrentChannel: msg.Channel.String(),
Channels: []string{msg.Channel.String()},
})
@@ -526,7 +526,6 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool {
return false
}
func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) {
if state, ok := IdentityStateFromContext(ctx); ok {
return state, nil
+46 -5
View File
@@ -7,12 +7,14 @@ import (
"log/slog"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/robfig/cron/v3"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/db/sqlc"
)
@@ -214,6 +216,8 @@ func (s *Service) Trigger(ctx context.Context, scheduleID string) error {
return s.runSchedule(ctx, schedule)
}
const scheduleTokenTTL = 10 * time.Minute
func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error {
if s.triggerer == nil {
return fmt.Errorf("schedule triggerer not configured")
@@ -225,18 +229,55 @@ func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error {
if !updated.Enabled {
s.removeJob(schedule.ID)
}
token := ""
if err := s.triggerer.TriggerSchedule(ctx, schedule.BotID, TriggerPayload{
ownerUserID, err := s.resolveBotOwner(ctx, schedule.BotID)
if err != nil {
return fmt.Errorf("resolve bot owner: %w", err)
}
token, err := s.generateTriggerToken(ownerUserID)
if err != nil {
return fmt.Errorf("generate trigger token: %w", err)
}
return s.triggerer.TriggerSchedule(ctx, schedule.BotID, TriggerPayload{
ID: schedule.ID,
Name: schedule.Name,
Description: schedule.Description,
Pattern: schedule.Pattern,
MaxCalls: schedule.MaxCalls,
Command: schedule.Command,
}, token); err != nil {
return err
OwnerUserID: ownerUserID,
}, token)
}
// resolveBotOwner returns the owner user ID for the given bot.
func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, error) {
pgBotID, err := parseUUID(botID)
if err != nil {
return "", err
}
return nil
bot, err := s.queries.GetBotByID(ctx, pgBotID)
if err != nil {
return "", fmt.Errorf("get bot: %w", err)
}
ownerID := toUUIDString(bot.OwnerUserID)
if ownerID == "" {
return "", fmt.Errorf("bot owner not found")
}
return ownerID, nil
}
// generateTriggerToken creates a short-lived JWT for schedule trigger callbacks.
func (s *Service) generateTriggerToken(userID string) (string, error) {
if strings.TrimSpace(s.jwtSecret) == "" {
return "", fmt.Errorf("jwt secret not configured")
}
signed, _, err := auth.GenerateToken(userID, s.jwtSecret, scheduleTokenTTL)
if err != nil {
return "", err
}
return "Bearer " + signed, nil
}
func (s *Service) scheduleJob(schedule sqlc.Schedule) error {
+91
View File
@@ -0,0 +1,91 @@
package schedule
import (
"context"
"log/slog"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
type mockTriggerer struct {
called bool
botID string
payload TriggerPayload
token string
}
func (m *mockTriggerer) TriggerSchedule(_ context.Context, botID string, payload TriggerPayload, token string) error {
m.called = true
m.botID = botID
m.payload = payload
m.token = token
return nil
}
func TestGenerateTriggerToken(t *testing.T) {
secret := "test-secret-key-for-schedule"
svc := &Service{
jwtSecret: secret,
logger: slog.Default(),
}
userID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
tok, err := svc.generateTriggerToken(userID)
if err != nil {
t.Fatalf("generateTriggerToken returned error: %v", err)
}
if !strings.HasPrefix(tok, "Bearer ") {
t.Fatalf("expected Bearer prefix, got: %s", tok)
}
raw := strings.TrimPrefix(tok, "Bearer ")
parsed, err := jwt.Parse(raw, func(token *jwt.Token) (any, error) {
return []byte(secret), nil
})
if err != nil {
t.Fatalf("failed to parse JWT: %v", err)
}
claims, ok := parsed.Claims.(jwt.MapClaims)
if !ok {
t.Fatal("expected MapClaims")
}
if sub, _ := claims["sub"].(string); sub != userID {
t.Errorf("expected sub=%s, got=%s", userID, sub)
}
if uid, _ := claims["user_id"].(string); uid != userID {
t.Errorf("expected user_id=%s, got=%s", userID, uid)
}
exp, _ := claims["exp"].(float64)
if exp == 0 {
t.Fatal("expected non-zero exp")
}
expTime := time.Unix(int64(exp), 0)
if expTime.Before(time.Now().Add(9 * time.Minute)) {
t.Error("token expires too soon")
}
}
func TestGenerateTriggerToken_EmptySecret(t *testing.T) {
svc := &Service{
jwtSecret: "",
logger: slog.Default(),
}
_, err := svc.generateTriggerToken("user-123")
if err == nil {
t.Fatal("expected error for empty secret")
}
}
func TestGenerateTriggerToken_EmptyUserID(t *testing.T) {
svc := &Service{
jwtSecret: "some-secret",
logger: slog.Default(),
}
_, err := svc.generateTriggerToken("")
if err == nil {
t.Fatal("expected error for empty user ID")
}
}
+2 -1
View File
@@ -2,7 +2,7 @@ package schedule
import "context"
// TriggerPayload 描述触发定时任务时传递给聊天侧的参数。
// TriggerPayload describes the parameters passed to the chat side when a schedule triggers.
type TriggerPayload struct {
ID string
Name string
@@ -10,6 +10,7 @@ type TriggerPayload struct {
Pattern string
MaxCalls *int
Command string
OwnerUserID string
}
// Triggerer 负责触发与聊天相关的调度执行。