mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix: use bot model configs first
This commit is contained in:
+33
-12
@@ -156,11 +156,15 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
|
||||
skipHistory := req.MaxContextLoadTime < 0
|
||||
|
||||
botSettings, err := r.loadBotSettings(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
userSettings, err := r.loadUserSettings(ctx, req.UserID)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
chatModel, provider, err := r.selectChatModel(ctx, req, userSettings)
|
||||
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
@@ -168,11 +172,6 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
|
||||
botSettings, err := r.loadBotSettings(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
|
||||
|
||||
var messages []ModelMessage
|
||||
@@ -312,6 +311,10 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) {
|
||||
chunkCh := make(chan StreamChunk)
|
||||
errCh := make(chan error, 1)
|
||||
r.logger.Info("gateway stream start",
|
||||
slog.String("bot_id", req.BotID),
|
||||
slog.String("session_id", req.SessionID),
|
||||
)
|
||||
|
||||
go func() {
|
||||
defer close(chunkCh)
|
||||
@@ -319,10 +322,20 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre
|
||||
|
||||
rc, err := r.resolve(ctx, req)
|
||||
if err != nil {
|
||||
r.logger.Error("gateway stream resolve failed",
|
||||
slog.String("bot_id", req.BotID),
|
||||
slog.String("session_id", req.SessionID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
if err := r.streamChat(ctx, rc.payload, req.BotID, req.SessionID, req.Query, req.Token, chunkCh); err != nil {
|
||||
r.logger.Error("gateway stream request failed",
|
||||
slog.String("bot_id", req.BotID),
|
||||
slog.String("session_id", req.SessionID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
@@ -417,7 +430,9 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, r.gatewayBaseURL+"/chat/stream", bytes.NewReader(body))
|
||||
url := r.gatewayBaseURL + "/chat/stream"
|
||||
r.logger.Info("gateway stream request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200)))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -429,12 +444,14 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID
|
||||
|
||||
resp, err := r.streamingClient.Do(httpReq)
|
||||
if err != nil {
|
||||
r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err))
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300)))
|
||||
return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody)))
|
||||
}
|
||||
|
||||
@@ -653,20 +670,24 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query stri
|
||||
|
||||
// --- model selection ---
|
||||
|
||||
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us resolvedUserSettings) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||
if r.modelsService == nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
|
||||
}
|
||||
modelID := strings.TrimSpace(req.Model)
|
||||
providerFilter := strings.TrimSpace(req.Provider)
|
||||
|
||||
// Priority: request model > user settings. No implicit fallback.
|
||||
if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" {
|
||||
modelID = us.ChatModelID
|
||||
// Priority: request model > bot settings > user settings.
|
||||
if modelID == "" && providerFilter == "" {
|
||||
if value := strings.TrimSpace(botSettings.ChatModelID); value != "" {
|
||||
modelID = value
|
||||
} else if value := strings.TrimSpace(us.ChatModelID); value != "" {
|
||||
modelID = value
|
||||
}
|
||||
}
|
||||
|
||||
if modelID == "" {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or user settings")
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings")
|
||||
}
|
||||
|
||||
if providerFilter == "" {
|
||||
|
||||
@@ -21,6 +21,38 @@ func (q *Queries) DeleteSettingsByBotID(ctx context.Context, botID pgtype.UUID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getBotModelConfigByBotID = `-- name: GetBotModelConfigByBotID :one
|
||||
SELECT
|
||||
bot_model_configs.bot_id,
|
||||
chat_models.model_id AS chat_model_id,
|
||||
memory_models.model_id AS memory_model_id,
|
||||
embedding_models.model_id AS embedding_model_id
|
||||
FROM bot_model_configs
|
||||
LEFT JOIN models AS chat_models ON chat_models.id = bot_model_configs.chat_model_id
|
||||
LEFT JOIN models AS memory_models ON memory_models.id = bot_model_configs.memory_model_id
|
||||
LEFT JOIN models AS embedding_models ON embedding_models.id = bot_model_configs.embedding_model_id
|
||||
WHERE bot_model_configs.bot_id = $1
|
||||
`
|
||||
|
||||
type GetBotModelConfigByBotIDRow struct {
|
||||
BotID pgtype.UUID `json:"bot_id"`
|
||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetBotModelConfigByBotID(ctx context.Context, botID pgtype.UUID) (GetBotModelConfigByBotIDRow, error) {
|
||||
row := q.db.QueryRow(ctx, getBotModelConfigByBotID, botID)
|
||||
var i GetBotModelConfigByBotIDRow
|
||||
err := row.Scan(
|
||||
&i.BotID,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSettingsByBotID = `-- name: GetSettingsByBotID :one
|
||||
SELECT bot_id, max_context_load_time, language, allow_guest
|
||||
FROM bot_settings
|
||||
@@ -59,6 +91,47 @@ func (q *Queries) GetSettingsByUserID(ctx context.Context, userID pgtype.UUID) (
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertBotModelConfig = `-- name: UpsertBotModelConfig :one
|
||||
INSERT INTO bot_model_configs (bot_id, chat_model_id, memory_model_id, embedding_model_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (bot_id) DO UPDATE SET
|
||||
chat_model_id = COALESCE(EXCLUDED.chat_model_id, bot_model_configs.chat_model_id),
|
||||
memory_model_id = COALESCE(EXCLUDED.memory_model_id, bot_model_configs.memory_model_id),
|
||||
embedding_model_id = COALESCE(EXCLUDED.embedding_model_id, bot_model_configs.embedding_model_id)
|
||||
RETURNING bot_id, chat_model_id, memory_model_id, embedding_model_id
|
||||
`
|
||||
|
||||
type UpsertBotModelConfigParams struct {
|
||||
BotID pgtype.UUID `json:"bot_id"`
|
||||
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.UUID `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
|
||||
}
|
||||
|
||||
type UpsertBotModelConfigRow struct {
|
||||
BotID pgtype.UUID `json:"bot_id"`
|
||||
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.UUID `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpsertBotModelConfig(ctx context.Context, arg UpsertBotModelConfigParams) (UpsertBotModelConfigRow, error) {
|
||||
row := q.db.QueryRow(ctx, upsertBotModelConfig,
|
||||
arg.BotID,
|
||||
arg.ChatModelID,
|
||||
arg.MemoryModelID,
|
||||
arg.EmbeddingModelID,
|
||||
)
|
||||
var i UpsertBotModelConfigRow
|
||||
err := row.Scan(
|
||||
&i.BotID,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertBotSettings = `-- name: UpsertBotSettings :one
|
||||
INSERT INTO bot_settings (bot_id, max_context_load_time, language, allow_guest)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
|
||||
@@ -113,6 +113,11 @@ func (h *ChatHandler) StreamChat(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
botID := strings.TrimSpace(c.Param("bot_id"))
|
||||
h.logger.Info("chat stream request received",
|
||||
slog.String("bot_id", botID),
|
||||
slog.String("session_id", c.QueryParam("session_id")),
|
||||
slog.String("user_id", userID),
|
||||
)
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
@@ -185,6 +190,7 @@ func (h *ChatHandler) StreamChat(c echo.Context) error {
|
||||
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
h.logger.Error("chat stream failed", slog.Any("error", err))
|
||||
// Send error as SSE event
|
||||
errData := map[string]string{"error": err.Error()}
|
||||
data, _ := json.Marshal(errData)
|
||||
|
||||
+50
-18
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -38,13 +39,13 @@ import (
|
||||
// @Description {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"fs.read","arguments":{"path":"notes.txt"}}}
|
||||
// @Tags containerd
|
||||
// @Param Authorization header string true "Bearer <token>"
|
||||
// @Param id path string true "Container ID"
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body object true "JSON-RPC request"
|
||||
// @Success 200 {object} object "JSON-RPC response: {jsonrpc,id,result|error}"
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /container/fs/{id} [post]
|
||||
// @Router /bots/{bot_id}/container/fs [post]
|
||||
func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error {
|
||||
botID, err := h.requireBotAccess(c)
|
||||
if err != nil {
|
||||
@@ -69,32 +70,39 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error {
|
||||
}
|
||||
|
||||
if err := h.validateMCPContainer(c.Request().Context(), containerID, botID); err != nil {
|
||||
return err
|
||||
h.logger.Error("mcp fs validate failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
||||
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: req.ID,
|
||||
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
|
||||
})
|
||||
}
|
||||
if err := h.ensureTaskRunning(c.Request().Context(), containerID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
h.logger.Error("mcp fs ensure task failed", slog.Any("error", err), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
||||
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: req.ID,
|
||||
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
|
||||
})
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case "tools/list":
|
||||
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(http.StatusOK, payload)
|
||||
case "tools/call":
|
||||
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(http.StatusOK, payload)
|
||||
default:
|
||||
if strings.TrimSpace(req.Method) == "" {
|
||||
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: req.ID,
|
||||
Error: &mcptools.JSONRPCError{Code: -32601, Message: "method not found"},
|
||||
})
|
||||
}
|
||||
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
|
||||
if err != nil {
|
||||
h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
||||
return c.JSON(http.StatusOK, mcptools.JSONRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: req.ID,
|
||||
Error: &mcptools.JSONRPCError{Code: -32603, Message: err.Error()},
|
||||
})
|
||||
}
|
||||
return c.JSON(http.StatusOK, payload)
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, botID string) error {
|
||||
@@ -198,10 +206,12 @@ func (h *ContainerdHandler) startContainerdMCPSession(ctx context.Context, conta
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
h.startMCPStderrLogger(execSession.Stderr, containerID)
|
||||
go sess.readLoop()
|
||||
go func() {
|
||||
_, err := execSession.Wait()
|
||||
if err != nil {
|
||||
h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID))
|
||||
sess.closeWithError(err)
|
||||
} else {
|
||||
sess.closeWithError(io.EOF)
|
||||
@@ -263,9 +273,11 @@ func (h *ContainerdHandler) startLimaMCPSession(containerID string) (*mcpSession
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
h.startMCPStderrLogger(stderr, containerID)
|
||||
go sess.readLoop()
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID))
|
||||
sess.closeWithError(err)
|
||||
} else {
|
||||
sess.closeWithError(io.EOF)
|
||||
@@ -297,6 +309,26 @@ func (s *mcpSession) closeWithError(err error) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) startMCPStderrLogger(stderr io.ReadCloser, containerID string) {
|
||||
if stderr == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
h.logger.Warn("mcp stderr", slog.String("container_id", containerID), slog.String("message", line))
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
h.logger.Error("mcp stderr read failed", slog.Any("error", err), slog.String("container_id", containerID))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *mcpSession) readLoop() {
|
||||
scanner := bufio.NewScanner(s.stdout)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024)
|
||||
|
||||
@@ -41,11 +41,12 @@ type skillsOpResponse struct {
|
||||
// ListSkills godoc
|
||||
// @Summary List skills from container
|
||||
// @Tags containerd
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Success 200 {object} SkillsResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /container/skills [get]
|
||||
// @Router /bots/{bot_id}/container/skills [get]
|
||||
func (h *ContainerdHandler) ListSkills(c echo.Context) error {
|
||||
botID, err := h.requireBotAccess(c)
|
||||
if err != nil {
|
||||
@@ -98,12 +99,13 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error {
|
||||
// UpsertSkills godoc
|
||||
// @Summary Upload skills into container
|
||||
// @Tags containerd
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body SkillsUpsertRequest true "Skills payload"
|
||||
// @Success 200 {object} skillsOpResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /container/skills [post]
|
||||
// @Router /bots/{bot_id}/container/skills [post]
|
||||
func (h *ContainerdHandler) UpsertSkills(c echo.Context) error {
|
||||
botID, err := h.requireBotAccess(c)
|
||||
if err != nil {
|
||||
@@ -149,12 +151,13 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error {
|
||||
// DeleteSkills godoc
|
||||
// @Summary Delete skills from container
|
||||
// @Tags containerd
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body SkillsDeleteRequest true "Delete skills payload"
|
||||
// @Success 200 {object} skillsOpResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /container/skills [delete]
|
||||
// @Router /bots/{bot_id}/container/skills [delete]
|
||||
func (h *ContainerdHandler) DeleteSkills(c echo.Context) error {
|
||||
botID, err := h.requireBotAccess(c)
|
||||
if err != nil {
|
||||
|
||||
@@ -109,15 +109,23 @@ func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) {
|
||||
row, err := s.queries.GetSettingsByBotID(ctx, pgID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return Settings{
|
||||
settings := Settings{
|
||||
MaxContextLoadTime: DefaultMaxContextLoadTime,
|
||||
Language: DefaultLanguage,
|
||||
AllowGuest: false,
|
||||
}, nil
|
||||
}
|
||||
if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
return Settings{}, err
|
||||
}
|
||||
return normalizeBotSetting(row), nil
|
||||
settings := normalizeBotSetting(row)
|
||||
if err := s.attachBotModelConfig(ctx, pgID, &settings); err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) {
|
||||
@@ -160,6 +168,12 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
if err := s.upsertBotModelConfig(ctx, pgID, req); err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
if err := s.attachBotModelConfig(ctx, pgID, ¤t); err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
|
||||
@@ -206,6 +220,73 @@ func normalizeBotSetting(row sqlc.BotSetting) Settings {
|
||||
return settings
|
||||
}
|
||||
|
||||
func (s *Service) attachBotModelConfig(ctx context.Context, botID pgtype.UUID, target *Settings) error {
|
||||
if s.queries == nil || target == nil {
|
||||
return nil
|
||||
}
|
||||
row, err := s.queries.GetBotModelConfigByBotID(ctx, botID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
target.ChatModelID = strings.TrimSpace(row.ChatModelID.String)
|
||||
target.MemoryModelID = strings.TrimSpace(row.MemoryModelID.String)
|
||||
target.EmbeddingModelID = strings.TrimSpace(row.EmbeddingModelID.String)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) upsertBotModelConfig(ctx context.Context, botID pgtype.UUID, req UpsertRequest) error {
|
||||
if s.queries == nil {
|
||||
return fmt.Errorf("settings queries not configured")
|
||||
}
|
||||
params := sqlc.UpsertBotModelConfigParams{
|
||||
BotID: botID,
|
||||
}
|
||||
hasUpdate := false
|
||||
if value := strings.TrimSpace(req.ChatModelID); value != "" {
|
||||
modelID, err := s.resolveModelUUID(ctx, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.ChatModelID = modelID
|
||||
hasUpdate = true
|
||||
}
|
||||
if value := strings.TrimSpace(req.MemoryModelID); value != "" {
|
||||
modelID, err := s.resolveModelUUID(ctx, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.MemoryModelID = modelID
|
||||
hasUpdate = true
|
||||
}
|
||||
if value := strings.TrimSpace(req.EmbeddingModelID); value != "" {
|
||||
modelID, err := s.resolveModelUUID(ctx, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.EmbeddingModelID = modelID
|
||||
hasUpdate = true
|
||||
}
|
||||
if !hasUpdate {
|
||||
return nil
|
||||
}
|
||||
_, err := s.queries.UpsertBotModelConfig(ctx, params)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("model_id is required")
|
||||
}
|
||||
row, err := s.queries.GetModelByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, err
|
||||
}
|
||||
return row.ID, nil
|
||||
}
|
||||
|
||||
func parseUUID(id string) (pgtype.UUID, error) {
|
||||
parsed, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user