mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor: bind container lifecycle to bot and improve schedule trigger flow
- Add SetupBotContainer to ContainerLifecycle interface so containers are automatically created when a bot is created, matching the existing cleanup-on-delete behavior. - Refactor schedule tools to use bot-scoped API paths and pass identity context for proper authorization. - Introduce dedicated trigger-schedule endpoint in chat resolver with explicit schedule payload instead of reusing the generic chat path. - Generate short-lived JWT tokens for schedule trigger callbacks with resolved bot owner identity. - Validate required parameters in NewLLMClient and NewOpenAIEmbedder constructors, returning errors instead of falling back to defaults. - Add unit tests for schedule token generation and chat resolver.
This commit is contained in:
@@ -25,7 +25,7 @@ export const getTools = (
|
||||
Object.assign(tools, webTools)
|
||||
}
|
||||
if (actions.includes(AgentAction.Schedule)) {
|
||||
const scheduleTools = getScheduleTools({ fetch })
|
||||
const scheduleTools = getScheduleTools({ fetch, identity })
|
||||
Object.assign(tools, scheduleTools)
|
||||
}
|
||||
if (actions.includes(AgentAction.Memory)) {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import { tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
import { AuthFetcher } from '..'
|
||||
import type { IdentityContext } from '../types'
|
||||
|
||||
export type ScheduleToolParams = {
|
||||
fetch: AuthFetcher
|
||||
identity: IdentityContext
|
||||
}
|
||||
|
||||
const ScheduleSchema = z.object({
|
||||
@@ -15,12 +17,15 @@ const ScheduleSchema = z.object({
|
||||
command: z.string(),
|
||||
})
|
||||
|
||||
export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
|
||||
export const getScheduleTools = ({ fetch, identity }: ScheduleToolParams) => {
|
||||
const botId = identity.botId.trim()
|
||||
const base = `/bots/${botId}/schedule`
|
||||
|
||||
const listSchedules = tool({
|
||||
description: 'List schedules for current user',
|
||||
inputSchema: z.object({}),
|
||||
execute: async () => {
|
||||
const response = await fetch('/schedule', { method: 'GET' })
|
||||
const response = await fetch(base, { method: 'GET' })
|
||||
return response.json()
|
||||
},
|
||||
})
|
||||
@@ -31,7 +36,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
|
||||
id: z.string().describe('Schedule ID'),
|
||||
}),
|
||||
execute: async ({ id }) => {
|
||||
const response = await fetch(`/schedule/${id}`, { method: 'GET' })
|
||||
const response = await fetch(`${base}/${id}`, { method: 'GET' })
|
||||
return response.json()
|
||||
},
|
||||
})
|
||||
@@ -47,7 +52,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
|
||||
command: z.string(),
|
||||
}),
|
||||
execute: async (payload) => {
|
||||
const response = await fetch('/schedule', {
|
||||
const response = await fetch(base, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload),
|
||||
@@ -63,7 +68,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
|
||||
}),
|
||||
execute: async (payload) => {
|
||||
const { id, ...body } = payload
|
||||
const response = await fetch(`/schedule/${id}`, {
|
||||
const response = await fetch(`${base}/${id}`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body),
|
||||
@@ -78,7 +83,7 @@ export const getScheduleTools = ({ fetch }: ScheduleToolParams) => {
|
||||
id: z.string(),
|
||||
}),
|
||||
execute: async ({ id }) => {
|
||||
const response = await fetch(`/schedule/${id}`, { method: 'DELETE' })
|
||||
const response = await fetch(`${base}/${id}`, { method: 'DELETE' })
|
||||
return response.status === 204 ? { success: true } : response.json()
|
||||
},
|
||||
})
|
||||
|
||||
+1
-1
@@ -336,5 +336,5 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
|
||||
if clientType != "openai" && clientType != "openai-compat" {
|
||||
return nil, fmt.Errorf("memory provider client type not supported: %s", memoryProvider.ClientType)
|
||||
}
|
||||
return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout), nil
|
||||
return memory.NewLLMClient(c.logger, memoryProvider.BaseUrl, memoryProvider.ApiKey, memoryModel.ModelID, c.timeout)
|
||||
}
|
||||
|
||||
@@ -112,7 +112,19 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR
|
||||
if err != nil {
|
||||
return Bot{}, err
|
||||
}
|
||||
return toBot(row)
|
||||
bot, err := toBot(row)
|
||||
if err != nil {
|
||||
return Bot{}, err
|
||||
}
|
||||
if s.containerLifecycle != nil {
|
||||
if err := s.containerLifecycle.SetupBotContainer(ctx, bot.ID); err != nil {
|
||||
s.logger.Error("failed to setup bot container",
|
||||
slog.String("bot_id", bot.ID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
}
|
||||
return bot, nil
|
||||
}
|
||||
|
||||
func (s *Service) Get(ctx context.Context, botID string) (Bot, error) {
|
||||
|
||||
@@ -58,6 +58,7 @@ type ListMembersResponse struct {
|
||||
|
||||
// ContainerLifecycle handles container lifecycle events bound to bot operations.
|
||||
type ContainerLifecycle interface {
|
||||
SetupBotContainer(ctx context.Context, botID string) error
|
||||
CleanupBotContainer(ctx context.Context, botID string) error
|
||||
}
|
||||
|
||||
|
||||
+112
-25
@@ -111,6 +111,30 @@ type gatewayResponse struct {
|
||||
Skills []string `json:"skills"`
|
||||
}
|
||||
|
||||
// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule.
|
||||
type gatewaySchedule struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Pattern string `json:"pattern"`
|
||||
MaxCalls *int `json:"maxCalls,omitempty"`
|
||||
Command string `json:"command"`
|
||||
}
|
||||
|
||||
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
|
||||
type triggerScheduleRequest struct {
|
||||
Model gatewayModelConfig `json:"model"`
|
||||
ActiveContextTime int `json:"activeContextTime"`
|
||||
Channels []string `json:"channels"`
|
||||
CurrentChannel string `json:"currentChannel"`
|
||||
AllowedActions []string `json:"allowedActions,omitempty"`
|
||||
Messages []ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Identity gatewayIdentity `json:"identity"`
|
||||
Attachments []any `json:"attachments"`
|
||||
Schedule gatewaySchedule `json:"schedule"`
|
||||
}
|
||||
|
||||
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
|
||||
|
||||
type resolvedContext struct {
|
||||
@@ -226,7 +250,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
|
||||
|
||||
// --- TriggerSchedule ---
|
||||
|
||||
// TriggerSchedule executes a scheduled command through the chat gateway.
|
||||
// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint.
|
||||
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error {
|
||||
if strings.TrimSpace(botID) == "" {
|
||||
return fmt.Errorf("bot id is required")
|
||||
@@ -234,24 +258,52 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
if strings.TrimSpace(payload.Command) == "" {
|
||||
return fmt.Errorf("schedule command is required")
|
||||
}
|
||||
|
||||
sessionID := "schedule:" + payload.ID
|
||||
req := ChatRequest{
|
||||
BotID: botID,
|
||||
SessionID: "schedule:" + payload.ID,
|
||||
SessionID: sessionID,
|
||||
Query: payload.Command,
|
||||
UserID: payload.OwnerUserID,
|
||||
Token: token,
|
||||
}
|
||||
rc, err := r.resolve(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc.payload.Identity.ContactID = botID
|
||||
rc.payload.Identity.ContactName = "Scheduler"
|
||||
|
||||
resp, err := r.postChat(ctx, rc.payload, token)
|
||||
triggerReq := triggerScheduleRequest{
|
||||
Model: rc.payload.Model,
|
||||
ActiveContextTime: rc.payload.ActiveContextTime,
|
||||
Channels: rc.payload.Channels,
|
||||
CurrentChannel: rc.payload.CurrentChannel,
|
||||
AllowedActions: rc.payload.AllowedActions,
|
||||
Messages: rc.payload.Messages,
|
||||
Skills: rc.payload.Skills,
|
||||
Identity: gatewayIdentity{
|
||||
BotID: rc.payload.Identity.BotID,
|
||||
SessionID: rc.payload.Identity.SessionID,
|
||||
ContainerID: rc.payload.Identity.ContainerID,
|
||||
ContactID: botID,
|
||||
ContactName: "Scheduler",
|
||||
UserID: payload.OwnerUserID,
|
||||
},
|
||||
Attachments: rc.payload.Attachments,
|
||||
Schedule: gatewaySchedule{
|
||||
ID: payload.ID,
|
||||
Name: payload.Name,
|
||||
Description: payload.Description,
|
||||
Pattern: payload.Pattern,
|
||||
MaxCalls: payload.MaxCalls,
|
||||
Command: payload.Command,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := r.postTriggerSchedule(ctx, triggerReq, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.storeRound(ctx, botID, req.SessionID, payload.Command, resp.Messages, resp.Skills)
|
||||
return r.storeRound(ctx, botID, sessionID, payload.Command, resp.Messages, resp.Skills)
|
||||
}
|
||||
|
||||
// --- StreamChat ---
|
||||
@@ -319,6 +371,47 @@ func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token s
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// postTriggerSchedule sends a trigger-schedule request to the agent gateway.
|
||||
func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
url := r.gatewayBaseURL + "/chat/trigger-schedule"
|
||||
r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID))
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if strings.TrimSpace(token) != "" {
|
||||
httpReq.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
resp, err := r.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return gatewayResponse{}, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
|
||||
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
|
||||
var parsed gatewayResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
|
||||
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
@@ -567,12 +660,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
|
||||
modelID := strings.TrimSpace(req.Model)
|
||||
providerFilter := strings.TrimSpace(req.Provider)
|
||||
|
||||
// Priority: request model > user settings > first available.
|
||||
// Priority: request model > user settings. No implicit fallback.
|
||||
if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" {
|
||||
modelID = us.ChatModelID
|
||||
}
|
||||
|
||||
if modelID != "" && providerFilter == "" {
|
||||
if modelID == "" {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or user settings")
|
||||
}
|
||||
|
||||
if providerFilter == "" {
|
||||
return r.fetchChatModel(ctx, modelID)
|
||||
}
|
||||
|
||||
@@ -580,26 +677,16 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us reso
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
if modelID != "" {
|
||||
for _, m := range candidates {
|
||||
if m.ModelID == modelID {
|
||||
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return m, prov, nil
|
||||
for _, m := range candidates {
|
||||
if m.ModelID == modelID {
|
||||
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return m, prov, nil
|
||||
}
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found")
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available")
|
||||
}
|
||||
prov, err := models.FetchProviderByID(ctx, r.queries, candidates[0].LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return candidates[0], prov, nil
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter)
|
||||
}
|
||||
|
||||
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPostTriggerSchedule_Endpoint(t *testing.T) {
|
||||
var capturedPath string
|
||||
var capturedBody []byte
|
||||
var capturedAuth string
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
capturedBody, _ = io.ReadAll(r.Body)
|
||||
resp := gatewayResponse{
|
||||
Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
resolver := &Resolver{
|
||||
gatewayBaseURL: srv.URL,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
maxCalls := 5
|
||||
req := triggerScheduleRequest{
|
||||
Model: gatewayModelConfig{
|
||||
ModelID: "gpt-4",
|
||||
ClientType: "openai",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com",
|
||||
},
|
||||
ActiveContextTime: 1440,
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Identity: gatewayIdentity{
|
||||
BotID: "bot-123",
|
||||
SessionID: "schedule:sched-1",
|
||||
ContainerID: "mcp-bot-123",
|
||||
ContactID: "bot-123",
|
||||
ContactName: "Scheduler",
|
||||
UserID: "owner-user-1",
|
||||
},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{
|
||||
ID: "sched-1",
|
||||
Name: "daily report",
|
||||
Description: "generate daily report",
|
||||
Pattern: "0 9 * * *",
|
||||
MaxCalls: &maxCalls,
|
||||
Command: "generate the daily report",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("postTriggerSchedule returned error: %v", err)
|
||||
}
|
||||
|
||||
if capturedPath != "/chat/trigger-schedule" {
|
||||
t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath)
|
||||
}
|
||||
if capturedAuth != "Bearer test-token" {
|
||||
t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth)
|
||||
}
|
||||
if len(resp.Messages) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(resp.Messages))
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(capturedBody, &body); err != nil {
|
||||
t.Fatalf("failed to parse captured body: %v", err)
|
||||
}
|
||||
schedule, ok := body["schedule"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'schedule' field in request body")
|
||||
}
|
||||
if schedule["id"] != "sched-1" {
|
||||
t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"])
|
||||
}
|
||||
if schedule["command"] != "generate the daily report" {
|
||||
t.Errorf("expected schedule.command, got %v", schedule["command"])
|
||||
}
|
||||
if _, hasQuery := body["query"]; hasQuery {
|
||||
t.Error("trigger-schedule request should not contain 'query' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTriggerSchedule_NoAuth(t *testing.T) {
|
||||
var capturedAuth string
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
resp := gatewayResponse{Messages: []ModelMessage{}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
resolver := &Resolver{
|
||||
gatewayBaseURL: srv.URL,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
req := triggerScheduleRequest{
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
}
|
||||
|
||||
_, err := resolver.postTriggerSchedule(context.Background(), req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("postTriggerSchedule returned error: %v", err)
|
||||
}
|
||||
if capturedAuth != "" {
|
||||
t.Errorf("expected no Authorization header, got %s", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTriggerSchedule_GatewayError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("internal error"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
resolver := &Resolver{
|
||||
gatewayBaseURL: srv.URL,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
req := triggerScheduleRequest{
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
}
|
||||
|
||||
_, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 500 response")
|
||||
}
|
||||
}
|
||||
@@ -37,15 +37,18 @@ type openAIEmbeddingResponse struct {
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int, timeout time.Duration) *OpenAIEmbedder {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int, timeout time.Duration) (*OpenAIEmbedder, error) {
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
return nil, fmt.Errorf("openai embedder: base url is required")
|
||||
}
|
||||
if model == "" {
|
||||
model = "text-embedding-3-small"
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return nil, fmt.Errorf("openai embedder: api key is required")
|
||||
}
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return nil, fmt.Errorf("openai embedder: model is required")
|
||||
}
|
||||
if dims <= 0 {
|
||||
dims = 1536
|
||||
return nil, fmt.Errorf("openai embedder: dimensions must be positive")
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
@@ -59,7 +62,7 @@ func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int
|
||||
http: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAIEmbedder) Dimensions() int {
|
||||
|
||||
@@ -130,10 +130,10 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
|
||||
if req.Provider != ProviderOpenAI {
|
||||
return Result{}, errors.New("provider not implemented")
|
||||
}
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return Result{}, errors.New("openai api key is required")
|
||||
embedder, err := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
embedder := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout)
|
||||
vector, err := embedder.Embed(ctx, req.Input.Text)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
|
||||
@@ -641,6 +641,112 @@ func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, bot
|
||||
return bot, nil
|
||||
}
|
||||
|
||||
// SetupBotContainer creates and starts the MCP container for a bot.
|
||||
func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) error {
|
||||
containerID := mcp.ContainerPrefix + botID
|
||||
|
||||
image := strings.TrimSpace(h.cfg.BusyboxImage)
|
||||
if image == "" {
|
||||
image = config.DefaultBusyboxImg
|
||||
}
|
||||
snapshotter := strings.TrimSpace(h.cfg.Snapshotter)
|
||||
|
||||
if strings.TrimSpace(h.namespace) != "" {
|
||||
ctx = namespaces.WithNamespace(ctx, h.namespace)
|
||||
}
|
||||
|
||||
dataRoot := strings.TrimSpace(h.cfg.DataRoot)
|
||||
if dataRoot == "" {
|
||||
dataRoot = config.DefaultDataRoot
|
||||
}
|
||||
dataRoot, _ = filepath.Abs(dataRoot)
|
||||
dataMount := strings.TrimSpace(h.cfg.DataMount)
|
||||
if dataMount == "" {
|
||||
dataMount = config.DefaultDataMount
|
||||
}
|
||||
dataDir := filepath.Join(dataRoot, "bots", botID)
|
||||
if err := os.MkdirAll(dataDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(dataDir, ".skills"), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
specOpts := []oci.SpecOpts{
|
||||
oci.WithMounts([]specs.Mount{
|
||||
{
|
||||
Destination: dataMount,
|
||||
Type: "bind",
|
||||
Source: dataDir,
|
||||
Options: []string{"rbind", "rw"},
|
||||
},
|
||||
{
|
||||
Destination: "/app",
|
||||
Type: "bind",
|
||||
Source: dataDir,
|
||||
Options: []string{"rbind", "rw"},
|
||||
},
|
||||
}),
|
||||
oci.WithProcessArgs("/bin/sh", "-lc", "sleep 2147483647"),
|
||||
}
|
||||
|
||||
_, err := h.service.CreateContainer(ctx, ctr.CreateContainerRequest{
|
||||
ID: containerID,
|
||||
ImageRef: image,
|
||||
Snapshotter: snapshotter,
|
||||
Labels: map[string]string{
|
||||
mcp.BotLabelKey: botID,
|
||||
},
|
||||
SpecOpts: specOpts,
|
||||
})
|
||||
if err != nil && !errdefs.IsAlreadyExists(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.queries != nil {
|
||||
pgBotID, parseErr := parsePgUUID(botID)
|
||||
if parseErr == nil {
|
||||
ns := strings.TrimSpace(h.namespace)
|
||||
if ns == "" {
|
||||
ns = "default"
|
||||
}
|
||||
_ = h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{
|
||||
BotID: pgBotID,
|
||||
ContainerID: containerID,
|
||||
ContainerName: containerID,
|
||||
Image: image,
|
||||
Status: "created",
|
||||
Namespace: ns,
|
||||
AutoStart: true,
|
||||
HostPath: pgtype.Text{String: dataDir, Valid: true},
|
||||
ContainerPath: dataMount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fifoDir, err := h.taskFIFODir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := h.service.StartTask(ctx, containerID, &ctr.StartTaskOptions{
|
||||
UseStdio: false,
|
||||
FIFODir: fifoDir,
|
||||
}); err == nil {
|
||||
if h.queries != nil {
|
||||
if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil {
|
||||
_ = h.queries.UpdateContainerStarted(ctx, pgBotID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
h.logger.Error("setup bot container: task start failed",
|
||||
slog.String("bot_id", botID),
|
||||
slog.String("container_id", containerID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupBotContainer removes the containerd container and DB record for a bot.
|
||||
func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID string) error {
|
||||
containerID, err := h.botContainerID(ctx, botID)
|
||||
|
||||
@@ -20,29 +20,31 @@ type LLMClient struct {
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) *LLMClient {
|
||||
func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) (*LLMClient, error) {
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
return nil, fmt.Errorf("llm client: base url is required")
|
||||
}
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return nil, fmt.Errorf("llm client: api key is required")
|
||||
}
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return nil, fmt.Errorf("llm client: model is required")
|
||||
}
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
if model == "" {
|
||||
model = "gpt-4.1-nano"
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
return &LLMClient{
|
||||
baseURL: baseURL,
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
apiKey: apiKey,
|
||||
model: model,
|
||||
logger: log.With(slog.String("client", "llm")),
|
||||
http: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) {
|
||||
|
||||
@@ -20,7 +20,10 @@ func TestLLMClientExtract(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewLLMClient(nil, server.URL, "test-key", "gpt-4.1-nano-2025-04-14", 0)
|
||||
client, err := NewLLMClient(nil, server.URL, "test-key", "gpt-4.1-nano-2025-04-14", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("new llm client: %v", err)
|
||||
}
|
||||
resp, err := client.Extract(context.Background(), ExtractRequest{
|
||||
Messages: []Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
|
||||
+10
-11
@@ -135,16 +135,16 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
desc, _ = p.registry.GetDescriptor(msg.Channel)
|
||||
}
|
||||
resp, err := p.chat.Chat(ctx, chat.ChatRequest{
|
||||
BotID: identity.BotID,
|
||||
SessionID: identity.SessionID,
|
||||
Token: token,
|
||||
UserID: identity.UserID,
|
||||
ContactID: identity.ContactID,
|
||||
ContactName: strings.TrimSpace(identity.Contact.DisplayName),
|
||||
ContactAlias: strings.TrimSpace(identity.Contact.Alias),
|
||||
ReplyTarget: strings.TrimSpace(msg.ReplyTarget),
|
||||
SessionToken: sessionToken,
|
||||
Query: text,
|
||||
BotID: identity.BotID,
|
||||
SessionID: identity.SessionID,
|
||||
Token: token,
|
||||
UserID: identity.UserID,
|
||||
ContactID: identity.ContactID,
|
||||
ContactName: strings.TrimSpace(identity.Contact.DisplayName),
|
||||
ContactAlias: strings.TrimSpace(identity.Contact.Alias),
|
||||
ReplyTarget: strings.TrimSpace(msg.ReplyTarget),
|
||||
SessionToken: sessionToken,
|
||||
Query: text,
|
||||
CurrentChannel: msg.Channel.String(),
|
||||
Channels: []string{msg.Channel.String()},
|
||||
})
|
||||
@@ -526,7 +526,6 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) {
|
||||
if state, ok := IdentityStateFromContext(ctx); ok {
|
||||
return state, nil
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
@@ -214,6 +216,8 @@ func (s *Service) Trigger(ctx context.Context, scheduleID string) error {
|
||||
return s.runSchedule(ctx, schedule)
|
||||
}
|
||||
|
||||
const scheduleTokenTTL = 10 * time.Minute
|
||||
|
||||
func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error {
|
||||
if s.triggerer == nil {
|
||||
return fmt.Errorf("schedule triggerer not configured")
|
||||
@@ -225,18 +229,55 @@ func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error {
|
||||
if !updated.Enabled {
|
||||
s.removeJob(schedule.ID)
|
||||
}
|
||||
token := ""
|
||||
if err := s.triggerer.TriggerSchedule(ctx, schedule.BotID, TriggerPayload{
|
||||
|
||||
ownerUserID, err := s.resolveBotOwner(ctx, schedule.BotID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve bot owner: %w", err)
|
||||
}
|
||||
|
||||
token, err := s.generateTriggerToken(ownerUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate trigger token: %w", err)
|
||||
}
|
||||
|
||||
return s.triggerer.TriggerSchedule(ctx, schedule.BotID, TriggerPayload{
|
||||
ID: schedule.ID,
|
||||
Name: schedule.Name,
|
||||
Description: schedule.Description,
|
||||
Pattern: schedule.Pattern,
|
||||
MaxCalls: schedule.MaxCalls,
|
||||
Command: schedule.Command,
|
||||
}, token); err != nil {
|
||||
return err
|
||||
OwnerUserID: ownerUserID,
|
||||
}, token)
|
||||
}
|
||||
|
||||
// resolveBotOwner returns the owner user ID for the given bot.
|
||||
func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, error) {
|
||||
pgBotID, err := parseUUID(botID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return nil
|
||||
bot, err := s.queries.GetBotByID(ctx, pgBotID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get bot: %w", err)
|
||||
}
|
||||
ownerID := toUUIDString(bot.OwnerUserID)
|
||||
if ownerID == "" {
|
||||
return "", fmt.Errorf("bot owner not found")
|
||||
}
|
||||
return ownerID, nil
|
||||
}
|
||||
|
||||
// generateTriggerToken creates a short-lived JWT for schedule trigger callbacks.
|
||||
func (s *Service) generateTriggerToken(userID string) (string, error) {
|
||||
if strings.TrimSpace(s.jwtSecret) == "" {
|
||||
return "", fmt.Errorf("jwt secret not configured")
|
||||
}
|
||||
signed, _, err := auth.GenerateToken(userID, s.jwtSecret, scheduleTokenTTL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "Bearer " + signed, nil
|
||||
}
|
||||
|
||||
func (s *Service) scheduleJob(schedule sqlc.Schedule) error {
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package schedule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type mockTriggerer struct {
|
||||
called bool
|
||||
botID string
|
||||
payload TriggerPayload
|
||||
token string
|
||||
}
|
||||
|
||||
func (m *mockTriggerer) TriggerSchedule(_ context.Context, botID string, payload TriggerPayload, token string) error {
|
||||
m.called = true
|
||||
m.botID = botID
|
||||
m.payload = payload
|
||||
m.token = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGenerateTriggerToken(t *testing.T) {
|
||||
secret := "test-secret-key-for-schedule"
|
||||
svc := &Service{
|
||||
jwtSecret: secret,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
userID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
|
||||
tok, err := svc.generateTriggerToken(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("generateTriggerToken returned error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(tok, "Bearer ") {
|
||||
t.Fatalf("expected Bearer prefix, got: %s", tok)
|
||||
}
|
||||
|
||||
raw := strings.TrimPrefix(tok, "Bearer ")
|
||||
parsed, err := jwt.Parse(raw, func(token *jwt.Token) (any, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse JWT: %v", err)
|
||||
}
|
||||
claims, ok := parsed.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
t.Fatal("expected MapClaims")
|
||||
}
|
||||
if sub, _ := claims["sub"].(string); sub != userID {
|
||||
t.Errorf("expected sub=%s, got=%s", userID, sub)
|
||||
}
|
||||
if uid, _ := claims["user_id"].(string); uid != userID {
|
||||
t.Errorf("expected user_id=%s, got=%s", userID, uid)
|
||||
}
|
||||
exp, _ := claims["exp"].(float64)
|
||||
if exp == 0 {
|
||||
t.Fatal("expected non-zero exp")
|
||||
}
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
if expTime.Before(time.Now().Add(9 * time.Minute)) {
|
||||
t.Error("token expires too soon")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTriggerToken_EmptySecret(t *testing.T) {
|
||||
svc := &Service{
|
||||
jwtSecret: "",
|
||||
logger: slog.Default(),
|
||||
}
|
||||
_, err := svc.generateTriggerToken("user-123")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTriggerToken_EmptyUserID(t *testing.T) {
|
||||
svc := &Service{
|
||||
jwtSecret: "some-secret",
|
||||
logger: slog.Default(),
|
||||
}
|
||||
_, err := svc.generateTriggerToken("")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty user ID")
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package schedule
|
||||
|
||||
import "context"
|
||||
|
||||
// TriggerPayload 描述触发定时任务时传递给聊天侧的参数。
|
||||
// TriggerPayload describes the parameters passed to the chat side when a schedule triggers.
|
||||
type TriggerPayload struct {
|
||||
ID string
|
||||
Name string
|
||||
@@ -10,6 +10,7 @@ type TriggerPayload struct {
|
||||
Pattern string
|
||||
MaxCalls *int
|
||||
Command string
|
||||
OwnerUserID string
|
||||
}
|
||||
|
||||
// Triggerer 负责触发与聊天相关的调度执行。
|
||||
|
||||
Reference in New Issue
Block a user