refactor: bind container lifecycle to bot and improve schedule trigger flow

- Add SetupBotContainer to ContainerLifecycle interface so containers
  are automatically created when a bot is created, matching the existing
  cleanup-on-delete behavior.
- Refactor schedule tools to use bot-scoped API paths and pass identity
  context for proper authorization.
- Introduce dedicated trigger-schedule endpoint in chat resolver with
  explicit schedule payload instead of reusing the generic chat path.
- Generate short-lived JWT tokens for schedule trigger callbacks with
  resolved bot owner identity.
- Validate required parameters in NewLLMClient and NewOpenAIEmbedder
  constructors, returning errors instead of falling back to defaults.
- Add unit tests for schedule token generation and chat resolver.
This commit is contained in:
BBQ
2026-02-07 12:03:24 +08:00
parent a9596ab3a8
commit 83b6ee608c
16 changed files with 583 additions and 72 deletions
+112 -25
View File
@@ -111,6 +111,30 @@ type gatewayResponse struct {
Skills []string `json:"skills"`
}
// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule.
type gatewaySchedule struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Pattern string `json:"pattern"`
MaxCalls *int `json:"maxCalls,omitempty"`
Command string `json:"command"`
}
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
type triggerScheduleRequest struct {
Model gatewayModelConfig `json:"model"`
ActiveContextTime int `json:"activeContextTime"`
Channels []string `json:"channels"`
CurrentChannel string `json:"currentChannel"`
AllowedActions []string `json:"allowedActions,omitempty"`
Messages []ModelMessage `json:"messages"`
Skills []string `json:"skills"`
Identity gatewayIdentity `json:"identity"`
Attachments []any `json:"attachments"`
Schedule gatewaySchedule `json:"schedule"`
}
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
type resolvedContext struct {
@@ -226,7 +250,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
// --- TriggerSchedule ---
// TriggerSchedule executes a scheduled command through the chat gateway.
// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint.
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error {
if strings.TrimSpace(botID) == "" {
return fmt.Errorf("bot id is required")
@@ -234,24 +258,52 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
if strings.TrimSpace(payload.Command) == "" {
return fmt.Errorf("schedule command is required")
}
sessionID := "schedule:" + payload.ID
req := ChatRequest{
BotID: botID,
SessionID: "schedule:" + payload.ID,
SessionID: sessionID,
Query: payload.Command,
UserID: payload.OwnerUserID,
Token: token,
}
rc, err := r.resolve(ctx, req)
if err != nil {
return err
}
rc.payload.Identity.ContactID = botID
rc.payload.Identity.ContactName = "Scheduler"
resp, err := r.postChat(ctx, rc.payload, token)
triggerReq := triggerScheduleRequest{
Model: rc.payload.Model,
ActiveContextTime: rc.payload.ActiveContextTime,
Channels: rc.payload.Channels,
CurrentChannel: rc.payload.CurrentChannel,
AllowedActions: rc.payload.AllowedActions,
Messages: rc.payload.Messages,
Skills: rc.payload.Skills,
Identity: gatewayIdentity{
BotID: rc.payload.Identity.BotID,
SessionID: rc.payload.Identity.SessionID,
ContainerID: rc.payload.Identity.ContainerID,
ContactID: botID,
ContactName: "Scheduler",
UserID: payload.OwnerUserID,
},
Attachments: rc.payload.Attachments,
Schedule: gatewaySchedule{
ID: payload.ID,
Name: payload.Name,
Description: payload.Description,
Pattern: payload.Pattern,
MaxCalls: payload.MaxCalls,
Command: payload.Command,
},
}
resp, err := r.postTriggerSchedule(ctx, triggerReq, token)
if err != nil {
return err
}
return r.storeRound(ctx, botID, req.SessionID, payload.Command, resp.Messages, resp.Skills)
return r.storeRound(ctx, botID, sessionID, payload.Command, resp.Messages, resp.Skills)
}
// --- StreamChat ---
@@ -319,6 +371,47 @@ func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token s
return parsed, nil
}
// postTriggerSchedule sends a trigger-schedule request to the agent gateway.
func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) {
body, err := json.Marshal(payload)
if err != nil {
return gatewayResponse{}, err
}
url := r.gatewayBaseURL + "/chat/trigger-schedule"
r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return gatewayResponse{}, err
}
httpReq.Header.Set("Content-Type", "application/json")
if strings.TrimSpace(token) != "" {
httpReq.Header.Set("Authorization", token)
}
resp, err := r.httpClient.Do(httpReq)
if err != nil {
return gatewayResponse{}, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return gatewayResponse{}, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
}
var parsed gatewayResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
}
return parsed, nil
}
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error {
body, err := json.Marshal(payload)
if err != nil {
@@ -567,12 +660,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
modelID := strings.TrimSpace(req.Model)
providerFilter := strings.TrimSpace(req.Provider)
// Priority: request model > user settings > first available.
// Priority: request model > user settings. No implicit fallback.
if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" {
modelID = us.ChatModelID
}
if modelID != "" && providerFilter == "" {
if modelID == "" {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or user settings")
}
if providerFilter == "" {
return r.fetchChatModel(ctx, modelID)
}
@@ -580,26 +677,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
if modelID != "" {
for _, m := range candidates {
if m.ModelID == modelID {
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return m, prov, nil
for _, m := range candidates {
if m.ModelID == modelID {
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return m, prov, nil
}
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found")
}
if len(candidates) == 0 {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available")
}
prov, err := models.FetchProviderByID(ctx, r.queries, candidates[0].LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return candidates[0], prov, nil
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter)
}
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
+160
View File
@@ -0,0 +1,160 @@
package chat
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestPostTriggerSchedule_Endpoint(t *testing.T) {
var capturedPath string
var capturedBody []byte
var capturedAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
capturedAuth = r.Header.Get("Authorization")
capturedBody, _ = io.ReadAll(r.Body)
resp := gatewayResponse{
Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
maxCalls := 5
req := triggerScheduleRequest{
Model: gatewayModelConfig{
ModelID: "gpt-4",
ClientType: "openai",
APIKey: "sk-test",
BaseURL: "https://api.openai.com",
},
ActiveContextTime: 1440,
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Identity: gatewayIdentity{
BotID: "bot-123",
SessionID: "schedule:sched-1",
ContainerID: "mcp-bot-123",
ContactID: "bot-123",
ContactName: "Scheduler",
UserID: "owner-user-1",
},
Attachments: []any{},
Schedule: gatewaySchedule{
ID: "sched-1",
Name: "daily report",
Description: "generate daily report",
Pattern: "0 9 * * *",
MaxCalls: &maxCalls,
Command: "generate the daily report",
},
}
resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token")
if err != nil {
t.Fatalf("postTriggerSchedule returned error: %v", err)
}
if capturedPath != "/chat/trigger-schedule" {
t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath)
}
if capturedAuth != "Bearer test-token" {
t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth)
}
if len(resp.Messages) != 1 {
t.Errorf("expected 1 message, got %d", len(resp.Messages))
}
var body map[string]any
if err := json.Unmarshal(capturedBody, &body); err != nil {
t.Fatalf("failed to parse captured body: %v", err)
}
schedule, ok := body["schedule"].(map[string]any)
if !ok {
t.Fatal("expected 'schedule' field in request body")
}
if schedule["id"] != "sched-1" {
t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"])
}
if schedule["command"] != "generate the daily report" {
t.Errorf("expected schedule.command, got %v", schedule["command"])
}
if _, hasQuery := body["query"]; hasQuery {
t.Error("trigger-schedule request should not contain 'query' field")
}
}
func TestPostTriggerSchedule_NoAuth(t *testing.T) {
var capturedAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
resp := gatewayResponse{Messages: []ModelMessage{}}
json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
req := triggerScheduleRequest{
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Attachments: []any{},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "")
if err != nil {
t.Fatalf("postTriggerSchedule returned error: %v", err)
}
if capturedAuth != "" {
t.Errorf("expected no Authorization header, got %s", capturedAuth)
}
}
func TestPostTriggerSchedule_GatewayError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("internal error"))
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
req := triggerScheduleRequest{
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Attachments: []any{},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok")
if err == nil {
t.Fatal("expected error for 500 response")
}
}