fix: use bot model configs first

This commit is contained in:
Acbox
2026-02-07 20:45:26 +08:00
parent b237594495
commit 344b617423
14 changed files with 861 additions and 530 deletions
+33 -12
View File
@@ -156,11 +156,15 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
skipHistory := req.MaxContextLoadTime < 0
botSettings, err := r.loadBotSettings(ctx, req.BotID)
if err != nil {
return resolvedContext{}, err
}
userSettings, err := r.loadUserSettings(ctx, req.UserID)
if err != nil {
return resolvedContext{}, err
}
chatModel, provider, err := r.selectChatModel(ctx, req, userSettings)
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings)
if err != nil {
return resolvedContext{}, err
}
@@ -168,11 +172,6 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
if err != nil {
return resolvedContext{}, err
}
botSettings, err := r.loadBotSettings(ctx, req.BotID)
if err != nil {
return resolvedContext{}, err
}
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
var messages []ModelMessage
@@ -312,6 +311,10 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) {
chunkCh := make(chan StreamChunk)
errCh := make(chan error, 1)
r.logger.Info("gateway stream start",
slog.String("bot_id", req.BotID),
slog.String("session_id", req.SessionID),
)
go func() {
defer close(chunkCh)
@@ -319,10 +322,20 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre
rc, err := r.resolve(ctx, req)
if err != nil {
r.logger.Error("gateway stream resolve failed",
slog.String("bot_id", req.BotID),
slog.String("session_id", req.SessionID),
slog.Any("error", err),
)
errCh <- err
return
}
if err := r.streamChat(ctx, rc.payload, req.BotID, req.SessionID, req.Query, req.Token, chunkCh); err != nil {
r.logger.Error("gateway stream request failed",
slog.String("bot_id", req.BotID),
slog.String("session_id", req.SessionID),
slog.Any("error", err),
)
errCh <- err
}
}()
@@ -417,7 +430,9 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID
if err != nil {
return err
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, r.gatewayBaseURL+"/chat/stream", bytes.NewReader(body))
url := r.gatewayBaseURL + "/chat/stream"
r.logger.Info("gateway stream request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200)))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return err
}
@@ -429,12 +444,14 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID
resp, err := r.streamingClient.Do(httpReq)
if err != nil {
r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err))
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
errBody, _ := io.ReadAll(resp.Body)
r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300)))
return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody)))
}
@@ -653,20 +670,24 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query stri
// --- model selection ---
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us resolvedUserSettings) (models.GetResponse, sqlc.LlmProvider, error) {
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings) (models.GetResponse, sqlc.LlmProvider, error) {
if r.modelsService == nil {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
}
modelID := strings.TrimSpace(req.Model)
providerFilter := strings.TrimSpace(req.Provider)
// Priority: request model > user settings. No implicit fallback.
if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" {
modelID = us.ChatModelID
// Priority: request model > bot settings > user settings.
if modelID == "" && providerFilter == "" {
if value := strings.TrimSpace(botSettings.ChatModelID); value != "" {
modelID = value
} else if value := strings.TrimSpace(us.ChatModelID); value != "" {
modelID = value
}
}
if modelID == "" {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or user settings")
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings")
}
if providerFilter == "" {
+73
View File
@@ -21,6 +21,38 @@ func (q *Queries) DeleteSettingsByBotID(ctx context.Context, botID pgtype.UUID)
return err
}
const getBotModelConfigByBotID = `-- name: GetBotModelConfigByBotID :one
SELECT
bot_model_configs.bot_id,
chat_models.model_id AS chat_model_id,
memory_models.model_id AS memory_model_id,
embedding_models.model_id AS embedding_model_id
FROM bot_model_configs
LEFT JOIN models AS chat_models ON chat_models.id = bot_model_configs.chat_model_id
LEFT JOIN models AS memory_models ON memory_models.id = bot_model_configs.memory_model_id
LEFT JOIN models AS embedding_models ON embedding_models.id = bot_model_configs.embedding_model_id
WHERE bot_model_configs.bot_id = $1
`
type GetBotModelConfigByBotIDRow struct {
BotID pgtype.UUID `json:"bot_id"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
}
func (q *Queries) GetBotModelConfigByBotID(ctx context.Context, botID pgtype.UUID) (GetBotModelConfigByBotIDRow, error) {
row := q.db.QueryRow(ctx, getBotModelConfigByBotID, botID)
var i GetBotModelConfigByBotIDRow
err := row.Scan(
&i.BotID,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
)
return i, err
}
const getSettingsByBotID = `-- name: GetSettingsByBotID :one
SELECT bot_id, max_context_load_time, language, allow_guest
FROM bot_settings
@@ -59,6 +91,47 @@ func (q *Queries) GetSettingsByUserID(ctx context.Context, userID pgtype.UUID) (
return i, err
}
const upsertBotModelConfig = `-- name: UpsertBotModelConfig :one
INSERT INTO bot_model_configs (bot_id, chat_model_id, memory_model_id, embedding_model_id)
VALUES ($1, $2, $3, $4)
ON CONFLICT (bot_id) DO UPDATE SET
chat_model_id = COALESCE(EXCLUDED.chat_model_id, bot_model_configs.chat_model_id),
memory_model_id = COALESCE(EXCLUDED.memory_model_id, bot_model_configs.memory_model_id),
embedding_model_id = COALESCE(EXCLUDED.embedding_model_id, bot_model_configs.embedding_model_id)
RETURNING bot_id, chat_model_id, memory_model_id, embedding_model_id
`
type UpsertBotModelConfigParams struct {
BotID pgtype.UUID `json:"bot_id"`
ChatModelID pgtype.UUID `json:"chat_model_id"`
MemoryModelID pgtype.UUID `json:"memory_model_id"`
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
}
type UpsertBotModelConfigRow struct {
BotID pgtype.UUID `json:"bot_id"`
ChatModelID pgtype.UUID `json:"chat_model_id"`
MemoryModelID pgtype.UUID `json:"memory_model_id"`
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
}
func (q *Queries) UpsertBotModelConfig(ctx context.Context, arg UpsertBotModelConfigParams) (UpsertBotModelConfigRow, error) {
row := q.db.QueryRow(ctx, upsertBotModelConfig,
arg.BotID,
arg.ChatModelID,
arg.MemoryModelID,
arg.EmbeddingModelID,
)
var i UpsertBotModelConfigRow
err := row.Scan(
&i.BotID,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
)
return i, err
}
const upsertBotSettings = `-- name: UpsertBotSettings :one
INSERT INTO bot_settings (bot_id, max_context_load_time, language, allow_guest)
VALUES ($1, $2, $3, $4)
+6
View File
@@ -113,6 +113,11 @@ func (h *ChatHandler) StreamChat(c echo.Context) error {
return err
}
botID := strings.TrimSpace(c.Param("bot_id"))
h.logger.Info("chat stream request received",
slog.String("bot_id", botID),
slog.String("session_id", c.QueryParam("session_id")),
slog.String("user_id", userID),
)
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
@@ -185,6 +190,7 @@ func (h *ChatHandler) StreamChat(c echo.Context) error {
case err := <-errChan:
if err != nil {
h.logger.Error("chat stream failed", slog.Any("error", err))
// Send error as SSE event
errData := map[string]string{"error": err.Error()}
data, _ := json.Marshal(errData)
+50 -18
View File
@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os/exec"
"runtime"
@@ -38,13 +39,13 @@ import (
// @Description {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}}
// @Tags containerd
// @Param Authorization header string true "Bearer <token>"
// @Param id path string true "Container ID"
// @Param bot_id path string true "Bot ID"
// @Param payload body object true "JSON-RPC request"
// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}"
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /container/fs/{id} [post]
// @Router /bots/{bot_id}/container/fs [post]
func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
@@ -69,32 +70,39 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error {
}
if err := h.validateMCPContainer(c.Request().Context(), containerID, botID); err != nil {
return err
h.logger.Error("mcp fs validate failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID))
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
JSONRPC: "2.0",
ID: req.ID,
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
})
}
if err := h.ensureTaskRunning(c.Request().Context(), containerID); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
h.logger.Error("mcp fs ensure task failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID))
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
JSONRPC: "2.0",
ID: req.ID,
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
})
}
switch req.Method {
case "tools/list":
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
if err != nil {
return err
}
return c.JSON(http.StatusOK, payload)
case "tools/call":
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
if err != nil {
return err
}
return c.JSON(http.StatusOK, payload)
default:
if strings.TrimSpace(req.Method) == "" {
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
JSONRPC: "2.0",
ID: req.ID,
Error: &mcptools.JSONRPCError{Code: -32601, Message: "method not found"},
})
}
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
if err != nil {
h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID))
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
JSONRPC: "2.0",
ID: req.ID,
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
})
}
return c.JSON(http.StatusOK, payload)
}
func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, botID string) error {
@@ -198,10 +206,12 @@ func (h *ContainerdHandler) startContainerdMCPSession(ctx context.Context, conta
closed: make(chan struct{}),
}
h.startMCPStderrLogger(execSession.Stderr, containerID)
go sess.readLoop()
go func() {
_, err := execSession.Wait()
if err != nil {
h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID))
sess.closeWithError(err)
} else {
sess.closeWithError(io.EOF)
@@ -263,9 +273,11 @@ func (h *ContainerdHandler) startLimaMCPSession(containerID string) (*mcpSession
closed: make(chan struct{}),
}
h.startMCPStderrLogger(stderr, containerID)
go sess.readLoop()
go func() {
if err := cmd.Wait(); err != nil {
h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID))
sess.closeWithError(err)
} else {
sess.closeWithError(io.EOF)
@@ -297,6 +309,26 @@ func (s *mcpSession) closeWithError(err error) {
})
}
func (h *ContainerdHandler) startMCPStderrLogger(stderr io.ReadCloser, containerID string) {
if stderr == nil {
return
}
go func() {
scanner := bufio.NewScanner(stderr)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
h.logger.Warn("mcp stderr", slog.String("container_id", containerID), slog.String("message", line))
}
if err := scanner.Err(); err != nil {
h.logger.Error("mcp stderr read failed", slog.Any("error", err), slog.String("container_id", containerID))
}
}()
}
func (s *mcpSession) readLoop() {
scanner := bufio.NewScanner(s.stdout)
scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024)
+6 -3
View File
@@ -41,11 +41,12 @@ type skillsOpResponse struct {
// ListSkills godoc
// @Summary List skills from container
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Success 200 {object} SkillsResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /container/skills [get]
// @Router /bots/{bot_id}/container/skills [get]
func (h *ContainerdHandler) ListSkills(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
@@ -98,12 +99,13 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error {
// UpsertSkills godoc
// @Summary Upload skills into container
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param payload body SkillsUpsertRequest true "Skills payload"
// @Success 200 {object} skillsOpResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /container/skills [post]
// @Router /bots/{bot_id}/container/skills [post]
func (h *ContainerdHandler) UpsertSkills(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
@@ -149,12 +151,13 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error {
// DeleteSkills godoc
// @Summary Delete skills from container
// @Tags containerd
// @Param bot_id path string true "Bot ID"
// @Param payload body SkillsDeleteRequest true "Delete skills payload"
// @Success 200 {object} skillsOpResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /container/skills [delete]
// @Router /bots/{bot_id}/container/skills [delete]
func (h *ContainerdHandler) DeleteSkills(c echo.Context) error {
botID, err := h.requireBotAccess(c)
if err != nil {
+84 -3
View File
@@ -109,15 +109,23 @@ func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) {
row, err := s.queries.GetSettingsByBotID(ctx, pgID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return Settings{
settings := Settings{
MaxContextLoadTime: DefaultMaxContextLoadTime,
Language: DefaultLanguage,
AllowGuest: false,
}, nil
}
if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil {
return Settings{}, err
}
return settings, nil
}
return Settings{}, err
}
return normalizeBotSetting(row), nil
settings := normalizeBotSetting(row)
if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil {
return Settings{}, err
}
return settings, nil
}
func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) {
@@ -160,6 +168,12 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest
if err != nil {
return Settings{}, err
}
if err := s.upsertBotModelConfig(ctx, pgID, req); err != nil {
return Settings{}, err
}
if err := s.attachBotModelConfig(ctx, pgID, &current); err != nil {
return Settings{}, err
}
return current, nil
}
@@ -206,6 +220,73 @@ func normalizeBotSetting(row sqlc.BotSetting) Settings {
return settings
}
func (s *Service) attachBotModelConfig(ctx context.Context, botID pgtype.UUID, target *Settings) error {
if s.queries == nil || target == nil {
return nil
}
row, err := s.queries.GetBotModelConfigByBotID(ctx, botID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil
}
return err
}
target.ChatModelID = strings.TrimSpace(row.ChatModelID.String)
target.MemoryModelID = strings.TrimSpace(row.MemoryModelID.String)
target.EmbeddingModelID = strings.TrimSpace(row.EmbeddingModelID.String)
return nil
}
func (s *Service) upsertBotModelConfig(ctx context.Context, botID pgtype.UUID, req UpsertRequest) error {
if s.queries == nil {
return fmt.Errorf("settings queries not configured")
}
params := sqlc.UpsertBotModelConfigParams{
BotID: botID,
}
hasUpdate := false
if value := strings.TrimSpace(req.ChatModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return err
}
params.ChatModelID = modelID
hasUpdate = true
}
if value := strings.TrimSpace(req.MemoryModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return err
}
params.MemoryModelID = modelID
hasUpdate = true
}
if value := strings.TrimSpace(req.EmbeddingModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return err
}
params.EmbeddingModelID = modelID
hasUpdate = true
}
if !hasUpdate {
return nil
}
_, err := s.queries.UpsertBotModelConfig(ctx, params)
return err
}
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
if strings.TrimSpace(modelID) == "" {
return pgtype.UUID{}, fmt.Errorf("model_id is required")
}
row, err := s.queries.GetModelByModelID(ctx, modelID)
if err != nil {
return pgtype.UUID{}, err
}
return row.ID, nil
}
func parseUUID(id string) (pgtype.UUID, error) {
parsed, err := uuid.Parse(id)
if err != nil {