mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor: bind container lifecycle to bot and improve schedule trigger flow
- Add SetupBotContainer to ContainerLifecycle interface so containers are automatically created when a bot is created, matching the existing cleanup-on-delete behavior. - Refactor schedule tools to use bot-scoped API paths and pass identity context for proper authorization. - Introduce dedicated trigger-schedule endpoint in chat resolver with explicit schedule payload instead of reusing the generic chat path. - Generate short-lived JWT tokens for schedule trigger callbacks with resolved bot owner identity. - Validate required parameters in NewLLMClient and NewOpenAIEmbedder constructors, returning errors instead of falling back to defaults. - Add unit tests for schedule token generation and chat resolver.
This commit is contained in:
+112
-25
@@ -111,6 +111,30 @@ type gatewayResponse struct {
|
||||
Skills []string `json:"skills"`
|
||||
}
|
||||
|
||||
// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule.
|
||||
type gatewaySchedule struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Pattern string `json:"pattern"`
|
||||
MaxCalls *int `json:"maxCalls,omitempty"`
|
||||
Command string `json:"command"`
|
||||
}
|
||||
|
||||
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
|
||||
type triggerScheduleRequest struct {
|
||||
Model gatewayModelConfig `json:"model"`
|
||||
ActiveContextTime int `json:"activeContextTime"`
|
||||
Channels []string `json:"channels"`
|
||||
CurrentChannel string `json:"currentChannel"`
|
||||
AllowedActions []string `json:"allowedActions,omitempty"`
|
||||
Messages []ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Identity gatewayIdentity `json:"identity"`
|
||||
Attachments []any `json:"attachments"`
|
||||
Schedule gatewaySchedule `json:"schedule"`
|
||||
}
|
||||
|
||||
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
|
||||
|
||||
type resolvedContext struct {
|
||||
@@ -226,7 +250,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
|
||||
|
||||
// --- TriggerSchedule ---
|
||||
|
||||
// TriggerSchedule executes a scheduled command through the chat gateway.
|
||||
// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint.
|
||||
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error {
|
||||
if strings.TrimSpace(botID) == "" {
|
||||
return fmt.Errorf("bot id is required")
|
||||
@@ -234,24 +258,52 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
if strings.TrimSpace(payload.Command) == "" {
|
||||
return fmt.Errorf("schedule command is required")
|
||||
}
|
||||
|
||||
sessionID := "schedule:" + payload.ID
|
||||
req := ChatRequest{
|
||||
BotID: botID,
|
||||
SessionID: "schedule:" + payload.ID,
|
||||
SessionID: sessionID,
|
||||
Query: payload.Command,
|
||||
UserID: payload.OwnerUserID,
|
||||
Token: token,
|
||||
}
|
||||
rc, err := r.resolve(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc.payload.Identity.ContactID = botID
|
||||
rc.payload.Identity.ContactName = "Scheduler"
|
||||
|
||||
resp, err := r.postChat(ctx, rc.payload, token)
|
||||
triggerReq := triggerScheduleRequest{
|
||||
Model: rc.payload.Model,
|
||||
ActiveContextTime: rc.payload.ActiveContextTime,
|
||||
Channels: rc.payload.Channels,
|
||||
CurrentChannel: rc.payload.CurrentChannel,
|
||||
AllowedActions: rc.payload.AllowedActions,
|
||||
Messages: rc.payload.Messages,
|
||||
Skills: rc.payload.Skills,
|
||||
Identity: gatewayIdentity{
|
||||
BotID: rc.payload.Identity.BotID,
|
||||
SessionID: rc.payload.Identity.SessionID,
|
||||
ContainerID: rc.payload.Identity.ContainerID,
|
||||
ContactID: botID,
|
||||
ContactName: "Scheduler",
|
||||
UserID: payload.OwnerUserID,
|
||||
},
|
||||
Attachments: rc.payload.Attachments,
|
||||
Schedule: gatewaySchedule{
|
||||
ID: payload.ID,
|
||||
Name: payload.Name,
|
||||
Description: payload.Description,
|
||||
Pattern: payload.Pattern,
|
||||
MaxCalls: payload.MaxCalls,
|
||||
Command: payload.Command,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := r.postTriggerSchedule(ctx, triggerReq, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.storeRound(ctx, botID, req.SessionID, payload.Command, resp.Messages, resp.Skills)
|
||||
return r.storeRound(ctx, botID, sessionID, payload.Command, resp.Messages, resp.Skills)
|
||||
}
|
||||
|
||||
// --- StreamChat ---
|
||||
@@ -319,6 +371,47 @@ func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token s
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// postTriggerSchedule sends a trigger-schedule request to the agent gateway.
|
||||
func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
url := r.gatewayBaseURL + "/chat/trigger-schedule"
|
||||
r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID))
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if strings.TrimSpace(token) != "" {
|
||||
httpReq.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
resp, err := r.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
|
||||
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
|
||||
var parsed gatewayResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
|
||||
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
@@ -567,12 +660,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
|
||||
modelID := strings.TrimSpace(req.Model)
|
||||
providerFilter := strings.TrimSpace(req.Provider)
|
||||
|
||||
// Priority: request model > user settings > first available.
|
||||
// Priority: request model > user settings. No implicit fallback.
|
||||
if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" {
|
||||
modelID = us.ChatModelID
|
||||
}
|
||||
|
||||
if modelID != "" && providerFilter == "" {
|
||||
if modelID == "" {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or user settings")
|
||||
}
|
||||
|
||||
if providerFilter == "" {
|
||||
return r.fetchChatModel(ctx, modelID)
|
||||
}
|
||||
|
||||
@@ -580,26 +677,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
if modelID != "" {
|
||||
for _, m := range candidates {
|
||||
if m.ModelID == modelID {
|
||||
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return m, prov, nil
|
||||
for _, m := range candidates {
|
||||
if m.ModelID == modelID {
|
||||
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return m, prov, nil
|
||||
}
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found")
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available")
|
||||
}
|
||||
prov, err := models.FetchProviderByID(ctx, r.queries, candidates[0].LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return candidates[0], prov, nil
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter)
|
||||
}
|
||||
|
||||
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPostTriggerSchedule_Endpoint(t *testing.T) {
|
||||
var capturedPath string
|
||||
var capturedBody []byte
|
||||
var capturedAuth string
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
capturedBody, _ = io.ReadAll(r.Body)
|
||||
resp := gatewayResponse{
|
||||
Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
resolver := &Resolver{
|
||||
gatewayBaseURL: srv.URL,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
maxCalls := 5
|
||||
req := triggerScheduleRequest{
|
||||
Model: gatewayModelConfig{
|
||||
ModelID: "gpt-4",
|
||||
ClientType: "openai",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com",
|
||||
},
|
||||
ActiveContextTime: 1440,
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Identity: gatewayIdentity{
|
||||
BotID: "bot-123",
|
||||
SessionID: "schedule:sched-1",
|
||||
ContainerID: "mcp-bot-123",
|
||||
ContactID: "bot-123",
|
||||
ContactName: "Scheduler",
|
||||
UserID: "owner-user-1",
|
||||
},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{
|
||||
ID: "sched-1",
|
||||
Name: "daily report",
|
||||
Description: "generate daily report",
|
||||
Pattern: "0 9 * * *",
|
||||
MaxCalls: &maxCalls,
|
||||
Command: "generate the daily report",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("postTriggerSchedule returned error: %v", err)
|
||||
}
|
||||
|
||||
if capturedPath != "/chat/trigger-schedule" {
|
||||
t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath)
|
||||
}
|
||||
if capturedAuth != "Bearer test-token" {
|
||||
t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth)
|
||||
}
|
||||
if len(resp.Messages) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(resp.Messages))
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(capturedBody, &body); err != nil {
|
||||
t.Fatalf("failed to parse captured body: %v", err)
|
||||
}
|
||||
schedule, ok := body["schedule"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'schedule' field in request body")
|
||||
}
|
||||
if schedule["id"] != "sched-1" {
|
||||
t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"])
|
||||
}
|
||||
if schedule["command"] != "generate the daily report" {
|
||||
t.Errorf("expected schedule.command, got %v", schedule["command"])
|
||||
}
|
||||
if _, hasQuery := body["query"]; hasQuery {
|
||||
t.Error("trigger-schedule request should not contain 'query' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTriggerSchedule_NoAuth(t *testing.T) {
|
||||
var capturedAuth string
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
resp := gatewayResponse{Messages: []ModelMessage{}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
resolver := &Resolver{
|
||||
gatewayBaseURL: srv.URL,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
req := triggerScheduleRequest{
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
}
|
||||
|
||||
_, err := resolver.postTriggerSchedule(context.Background(), req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("postTriggerSchedule returned error: %v", err)
|
||||
}
|
||||
if capturedAuth != "" {
|
||||
t.Errorf("expected no Authorization header, got %s", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTriggerSchedule_GatewayError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("internal error"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
resolver := &Resolver{
|
||||
gatewayBaseURL: srv.URL,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
req := triggerScheduleRequest{
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
}
|
||||
|
||||
_, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 500 response")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user