From 50bdbd519ca66bec2d3b5417648c8f33cd37c67b Mon Sep 17 00:00:00 2001 From: ringotypowriter Date: Sat, 21 Feb 2026 22:31:32 +0800 Subject: [PATCH] fix(models,settings,conversation): scope model_id uniqueness per provider and harden model reference resolution --- db/migrations/0001_init.up.sql | 3 +- ...0011_model_id_unique_per_provider.down.sql | 24 ++++++ .../0011_model_id_unique_per_provider.up.sql | 15 ++++ db/queries/conversations.sql | 13 +--- db/queries/models.sql | 6 ++ db/queries/settings.sql | 12 +-- internal/conversation/flow/resolver.go | 23 +++++- internal/conversation/service.go | 75 +++++++++++++------ internal/db/sqlc/conversations.sql.go | 23 +++--- internal/db/sqlc/models.sql.go | 54 +++++++++++-- internal/db/sqlc/settings.sql.go | 24 +++--- internal/handlers/models.go | 29 +++++++ internal/handlers/settings.go | 6 ++ internal/models/models.go | 51 +++++++++++-- internal/models/types.go | 1 + internal/settings/service.go | 46 +++++++++--- packages/sdk/src/types.gen.ts | 1 + .../web/src/components/create-model/index.vue | 37 ++++++++- .../pages/bots/components/model-select.vue | 13 ++-- .../pages/models/components/model-item.vue | 6 +- .../pages/models/components/model-list.vue | 6 +- .../web/src/pages/models/model-setting.vue | 7 +- spec/docs.go | 3 + spec/swagger.json | 3 + spec/swagger.yaml | 2 + 25 files changed, 376 insertions(+), 107 deletions(-) create mode 100644 db/migrations/0011_model_id_unique_per_provider.down.sql create mode 100644 db/migrations/0011_model_id_unique_per_provider.up.sql diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 421f812b..6e0ed770 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -89,7 +89,7 @@ CREATE TABLE IF NOT EXISTS models ( type TEXT NOT NULL DEFAULT 'chat', created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), - CONSTRAINT models_model_id_unique UNIQUE (model_id), + CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id), CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')), CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL), CONSTRAINT models_client_type_check CHECK (client_type IS NULL OR client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai')), @@ -389,4 +389,3 @@ CREATE TABLE IF NOT EXISTS bot_history_message_assets ( ); CREATE INDEX IF NOT EXISTS idx_message_assets_message_id ON bot_history_message_assets(message_id); - diff --git a/db/migrations/0011_model_id_unique_per_provider.down.sql b/db/migrations/0011_model_id_unique_per_provider.down.sql new file mode 100644 index 00000000..f5d26e44 --- /dev/null +++ b/db/migrations/0011_model_id_unique_per_provider.down.sql @@ -0,0 +1,24 @@ +-- 0011_model_id_unique_per_provider +-- Revert model_id uniqueness back to global uniqueness. + +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM models + GROUP BY model_id + HAVING COUNT(*) > 1 + ) THEN + RAISE EXCEPTION 'cannot rollback 0011_model_id_unique_per_provider: duplicate model_id values exist across providers'; + END IF; + + IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN + ALTER TABLE models DROP CONSTRAINT models_provider_model_id_unique; + END IF; + + IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN + ALTER TABLE models + ADD CONSTRAINT models_model_id_unique UNIQUE (model_id); + END IF; +END +$$; diff --git a/db/migrations/0011_model_id_unique_per_provider.up.sql b/db/migrations/0011_model_id_unique_per_provider.up.sql new file mode 100644 index 00000000..13056714 --- /dev/null +++ b/db/migrations/0011_model_id_unique_per_provider.up.sql @@ -0,0 +1,15 @@ +-- 0011_model_id_unique_per_provider +-- Change model_id uniqueness from global to per provider. + +DO $$ +BEGIN + IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN + ALTER TABLE models DROP CONSTRAINT models_model_id_unique; + END IF; + + IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN + ALTER TABLE models + ADD CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id); + END IF; +END +$$; diff --git a/db/queries/conversations.sql b/db/queries/conversations.sql index ef45658b..2a611613 100644 --- a/db/queries/conversations.sql +++ b/db/queries/conversations.sql @@ -205,22 +205,17 @@ ON CONFLICT (bot_id, user_id) DO NOTHING; -- chat_settings -- name: UpsertChatSettings :one -WITH resolved_model AS ( - SELECT id - FROM models - WHERE model_id = NULLIF(sqlc.narg(model_id)::text, '') - LIMIT 1 -), +WITH updated AS ( UPDATE bots - SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id), + SET chat_model_id = COALESCE(sqlc.narg(chat_model_id)::uuid, bots.chat_model_id), updated_at = now() WHERE bots.id = sqlc.arg(id) RETURNING bots.id, bots.chat_model_id, bots.updated_at ) SELECT updated.id AS chat_id, - chat_models.model_id AS model_id, + chat_models.id AS model_id, updated.updated_at FROM updated LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id; @@ -228,7 +223,7 @@ LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id; -- name: GetChatSettings :one SELECT b.id AS chat_id, - chat_models.model_id AS model_id, + chat_models.id AS model_id, b.updated_at FROM bots b LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id diff --git a/db/queries/models.sql b/db/queries/models.sql index 707d9ca6..cb30f43c 100644 --- a/db/queries/models.sql +++ b/db/queries/models.sql @@ -54,6 +54,11 @@ SELECT * FROM models WHERE id = sqlc.arg(id); -- name: GetModelByModelID :one SELECT * FROM models WHERE model_id = sqlc.arg(model_id); +-- name: ListModelsByModelID :many +SELECT * FROM models +WHERE model_id = sqlc.arg(model_id) +ORDER BY created_at DESC; + -- name: ListModels :many SELECT * FROM models ORDER BY created_at DESC; @@ -82,6 +87,7 @@ ORDER BY created_at DESC; -- name: UpdateModel :one UPDATE models SET + model_id = sqlc.arg(model_id), name = sqlc.arg(name), llm_provider_id = sqlc.arg(llm_provider_id), client_type = sqlc.narg(client_type), diff --git a/db/queries/settings.sql b/db/queries/settings.sql index 821c48eb..cbcfbca8 100644 --- a/db/queries/settings.sql +++ b/db/queries/settings.sql @@ -5,9 +5,9 @@ SELECT bots.max_context_tokens, bots.language, bots.allow_guest, - chat_models.model_id AS chat_model_id, - memory_models.model_id AS memory_model_id, - embedding_models.model_id AS embedding_model_id, + chat_models.id AS chat_model_id, + memory_models.id AS memory_model_id, + embedding_models.id AS embedding_model_id, search_providers.id AS search_provider_id FROM bots LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id @@ -37,9 +37,9 @@ SELECT updated.max_context_tokens, updated.language, updated.allow_guest, - chat_models.model_id AS chat_model_id, - memory_models.model_id AS memory_model_id, - embedding_models.model_id AS embedding_model_id, + chat_models.id AS chat_model_id, + memory_models.id AS memory_model_id, + embedding_models.id AS embedding_model_id, search_providers.id AS search_provider_id FROM updated LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 77de0eb2..159b6eae 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" attachmentpkg "github.com/memohai/memoh/internal/attachment" @@ -1535,10 +1536,30 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq } func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { - model, err := r.modelsService.GetByModelID(ctx, modelID) + modelRef := strings.TrimSpace(modelID) + if modelRef == "" { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model id is required") + } + + // Support both model UUID and model_id slug. UUID-formatted slugs still + // work because we fall back to GetByModelID when UUID lookup misses. + var model models.GetResponse + var err error + if _, parseErr := db.ParseUUID(modelRef); parseErr == nil { + model, err = r.modelsService.GetByID(ctx, modelRef) + if err == nil { + goto resolved + } + if !errors.Is(err, pgx.ErrNoRows) { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + } + model, err = r.modelsService.GetByModelID(ctx, modelRef) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } + +resolved: if model.Type != models.ModelTypeChat { return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") } diff --git a/internal/conversation/service.go b/internal/conversation/service.go index f85700c2..c319d815 100644 --- a/internal/conversation/service.go +++ b/internal/conversation/service.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" @@ -20,6 +21,7 @@ var ( ErrChatNotFound = errors.New("chat not found") ErrNotParticipant = errors.New("not a participant") ErrPermissionDenied = errors.New("permission denied") + ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers") ) // Service manages conversation lifecycle, participants, and settings. @@ -310,21 +312,26 @@ func (s *Service) GetSettings(ctx context.Context, conversationID string) (Setti // UpdateSettings updates conversation settings. func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req UpdateSettingsRequest) (Settings, error) { - current, err := s.GetSettings(ctx, conversationID) - if err != nil { - return Settings{}, err - } - if req.ModelID != nil { - current.ModelID = *req.ModelID - } - pgID, err := parseUUID(conversationID) if err != nil { return Settings{}, err } + + chatModelUUID := pgtype.UUID{} + if req.ModelID != nil { + modelRef := strings.TrimSpace(*req.ModelID) + if modelRef != "" { + resolved, err := s.resolveModelUUID(ctx, modelRef) + if err != nil { + return Settings{}, err + } + chatModelUUID = resolved + } + } + row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{ - ID: pgID, - ModelID: toPgText(current.ModelID), + ID: pgID, + ChatModelID: chatModelUUID, }) if err != nil { return Settings{}, err @@ -427,17 +434,23 @@ func toParticipantFields(conversationID, userID pgtype.UUID, role string, joined } func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings { - return Settings{ - ChatID: row.ChatID.String(), - ModelID: dbpkg.TextToString(row.ModelID), + settings := Settings{ + ChatID: row.ChatID.String(), } + if row.ModelID.Valid { + settings.ModelID = uuid.UUID(row.ModelID.Bytes).String() + } + return settings } func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings { - return Settings{ - ChatID: row.ChatID.String(), - ModelID: dbpkg.TextToString(row.ModelID), + settings := Settings{ + ChatID: row.ChatID.String(), } + if row.ModelID.Valid { + settings.ModelID = uuid.UUID(row.ModelID.Bytes).String() + } + return settings } func defaultSettings(conversationID string) Settings { @@ -450,12 +463,32 @@ func parseUUID(id string) (pgtype.UUID, error) { return dbpkg.ParseUUID(id) } -func toPgText(s string) pgtype.Text { - s = strings.TrimSpace(s) - if s == "" { - return pgtype.Text{} +func (s *Service) resolveModelUUID(ctx context.Context, modelRef string) (pgtype.UUID, error) { + modelRef = strings.TrimSpace(modelRef) + if modelRef == "" { + return pgtype.UUID{}, fmt.Errorf("model_id is required") } - return pgtype.Text{String: s, Valid: true} + + // Prefer UUID path; if not found, fall back to model_id slug. + if parsed, err := dbpkg.ParseUUID(modelRef); err == nil { + if _, err := s.queries.GetModelByID(ctx, parsed); err == nil { + return parsed, nil + } else if !errors.Is(err, pgx.ErrNoRows) { + return pgtype.UUID{}, err + } + } + + rows, err := s.queries.ListModelsByModelID(ctx, modelRef) + if err != nil { + return pgtype.UUID{}, err + } + if len(rows) == 0 { + return pgtype.UUID{}, fmt.Errorf("model not found: %s", modelRef) + } + if len(rows) > 1 { + return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelRef) + } + return rows[0].ID, nil } func pgTimePtr(ts pgtype.Timestamptz) *time.Time { diff --git a/internal/db/sqlc/conversations.sql.go b/internal/db/sqlc/conversations.sql.go index cbc6eaf1..24118046 100644 --- a/internal/db/sqlc/conversations.sql.go +++ b/internal/db/sqlc/conversations.sql.go @@ -271,7 +271,7 @@ func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAc const getChatSettings = `-- name: GetChatSettings :one SELECT b.id AS chat_id, - chat_models.model_id AS model_id, + chat_models.id AS model_id, b.updated_at FROM bots b LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id @@ -280,7 +280,7 @@ WHERE b.id = $1 type GetChatSettingsRow struct { ChatID pgtype.UUID `json:"chat_id"` - ModelID pgtype.Text `json:"model_id"` + ModelID pgtype.UUID `json:"model_id"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } @@ -645,41 +645,36 @@ func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams const upsertChatSettings = `-- name: UpsertChatSettings :one -WITH resolved_model AS ( - SELECT id - FROM models - WHERE model_id = NULLIF($1::text, '') - LIMIT 1 -), +WITH updated AS ( UPDATE bots - SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id), + SET chat_model_id = COALESCE($1::uuid, bots.chat_model_id), updated_at = now() WHERE bots.id = $2 RETURNING bots.id, bots.chat_model_id, bots.updated_at ) SELECT updated.id AS chat_id, - chat_models.model_id AS model_id, + chat_models.id AS model_id, updated.updated_at FROM updated LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id ` type UpsertChatSettingsParams struct { - ModelID pgtype.Text `json:"model_id"` - ID pgtype.UUID `json:"id"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + ID pgtype.UUID `json:"id"` } type UpsertChatSettingsRow struct { ChatID pgtype.UUID `json:"chat_id"` - ModelID pgtype.Text `json:"model_id"` + ModelID pgtype.UUID `json:"model_id"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } // chat_settings func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) { - row := q.db.QueryRow(ctx, upsertChatSettings, arg.ModelID, arg.ID) + row := q.db.QueryRow(ctx, upsertChatSettings, arg.ChatModelID, arg.ID) var i UpsertChatSettingsRow err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt) return i, err diff --git a/internal/db/sqlc/models.sql.go b/internal/db/sqlc/models.sql.go index 4c919ce6..bd84b6a3 100644 --- a/internal/db/sqlc/models.sql.go +++ b/internal/db/sqlc/models.sql.go @@ -419,6 +419,43 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType pgtype. return items, nil } +const listModelsByModelID = `-- name: ListModelsByModelID :many +SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models +WHERE model_id = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListModelsByModelID(ctx context.Context, modelID string) ([]Model, error) { + rows, err := q.db.Query(ctx, listModelsByModelID, modelID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Model + for rows.Next() { + var i Model + if err := rows.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.LlmProviderID, + &i.ClientType, + &i.Dimensions, + &i.InputModalities, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listModelsByProviderID = `-- name: ListModelsByProviderID :many SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE llm_provider_id = $1 @@ -580,18 +617,20 @@ func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderPa const updateModel = `-- name: UpdateModel :one UPDATE models SET - name = $1, - llm_provider_id = $2, - client_type = $3, - dimensions = $4, - input_modalities = $5, - type = $6, + model_id = $1, + name = $2, + llm_provider_id = $3, + client_type = $4, + dimensions = $5, + input_modalities = $6, + type = $7, updated_at = now() -WHERE id = $7 +WHERE id = $8 RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at ` type UpdateModelParams struct { + ModelID string `json:"model_id"` Name pgtype.Text `json:"name"` LlmProviderID pgtype.UUID `json:"llm_provider_id"` ClientType pgtype.Text `json:"client_type"` @@ -603,6 +642,7 @@ type UpdateModelParams struct { func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) { row := q.db.QueryRow(ctx, updateModel, + arg.ModelID, arg.Name, arg.LlmProviderID, arg.ClientType, diff --git a/internal/db/sqlc/settings.sql.go b/internal/db/sqlc/settings.sql.go index 7f3a612e..a5445117 100644 --- a/internal/db/sqlc/settings.sql.go +++ b/internal/db/sqlc/settings.sql.go @@ -37,9 +37,9 @@ SELECT bots.max_context_tokens, bots.language, bots.allow_guest, - chat_models.model_id AS chat_model_id, - memory_models.model_id AS memory_model_id, - embedding_models.model_id AS embedding_model_id, + chat_models.id AS chat_model_id, + memory_models.id AS memory_model_id, + embedding_models.id AS embedding_model_id, search_providers.id AS search_provider_id FROM bots LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id @@ -55,9 +55,9 @@ type GetSettingsByBotIDRow struct { MaxContextTokens int32 `json:"max_context_tokens"` Language string `json:"language"` AllowGuest bool `json:"allow_guest"` - ChatModelID pgtype.Text `json:"chat_model_id"` - MemoryModelID pgtype.Text `json:"memory_model_id"` - EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + MemoryModelID pgtype.UUID `json:"memory_model_id"` + EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` SearchProviderID pgtype.UUID `json:"search_provider_id"` } @@ -99,9 +99,9 @@ SELECT updated.max_context_tokens, updated.language, updated.allow_guest, - chat_models.model_id AS chat_model_id, - memory_models.model_id AS memory_model_id, - embedding_models.model_id AS embedding_model_id, + chat_models.id AS chat_model_id, + memory_models.id AS memory_model_id, + embedding_models.id AS embedding_model_id, search_providers.id AS search_provider_id FROM updated LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id @@ -128,9 +128,9 @@ type UpsertBotSettingsRow struct { MaxContextTokens int32 `json:"max_context_tokens"` Language string `json:"language"` AllowGuest bool `json:"allow_guest"` - ChatModelID pgtype.Text `json:"chat_model_id"` - MemoryModelID pgtype.Text `json:"memory_model_id"` - EmbeddingModelID pgtype.Text `json:"embedding_model_id"` + ChatModelID pgtype.UUID `json:"chat_model_id"` + MemoryModelID pgtype.UUID `json:"memory_model_id"` + EmbeddingModelID pgtype.UUID `json:"embedding_model_id"` SearchProviderID pgtype.UUID `json:"search_provider_id"` } diff --git a/internal/handlers/models.go b/internal/handlers/models.go index 51a6f402..5e205291 100644 --- a/internal/handlers/models.go +++ b/internal/handlers/models.go @@ -1,10 +1,12 @@ package handlers import ( + "errors" "log/slog" "net/http" "net/url" + "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" "github.com/memohai/memoh/internal/models" @@ -52,6 +54,9 @@ func (h *ModelsHandler) Create(c echo.Context) error { resp, err := h.service.Create(c.Request().Context(), req) if err != nil { + if errors.Is(err, models.ErrModelIDAlreadyExists) { + return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusCreated, resp) @@ -134,6 +139,12 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error { resp, err := h.service.GetByModelID(c.Request().Context(), modelID) if err != nil { + if errors.Is(err, models.ErrModelIDAmbiguous) { + return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead") + } + if errors.Is(err, pgx.ErrNoRows) { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } return echo.NewHTTPError(http.StatusNotFound, err.Error()) } return c.JSON(http.StatusOK, resp) @@ -163,6 +174,9 @@ func (h *ModelsHandler) UpdateByID(c echo.Context) error { resp, err := h.service.UpdateByID(c.Request().Context(), id, req) if err != nil { + if errors.Is(err, models.ErrModelIDAlreadyExists) { + return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) @@ -197,6 +211,15 @@ func (h *ModelsHandler) UpdateByModelID(c echo.Context) error { resp, err := h.service.UpdateByModelID(c.Request().Context(), modelID, req) if err != nil { + if errors.Is(err, models.ErrModelIDAlreadyExists) { + return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider") + } + if errors.Is(err, models.ErrModelIDAmbiguous) { + return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead") + } + if errors.Is(err, pgx.ErrNoRows) { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) @@ -246,6 +269,12 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { } if err := h.service.DeleteByModelID(c.Request().Context(), modelID); err != nil { + if errors.Is(err, models.ErrModelIDAmbiguous) { + return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead") + } + if errors.Is(err, pgx.ErrNoRows) { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.NoContent(http.StatusNoContent) diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index ff98c519..172c5b08 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -96,6 +96,12 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) { return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access") } + if errors.Is(err, settings.ErrInvalidModelRef) { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if errors.Is(err, settings.ErrModelIDAmbiguous) { + return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; select by model UUID") + } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, resp) diff --git a/internal/models/models.go b/internal/models/models.go index 007f46e8..2dc6ee99 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -2,16 +2,21 @@ package models import ( "context" + "errors" "fmt" "log/slog" "strings" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) +var ErrModelIDAlreadyExists = errors.New("model_id already exists") +var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers") + // Service provides CRUD operations for models type Service struct { queries *sqlc.Queries @@ -65,6 +70,9 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro created, err := s.queries.CreateModel(ctx, params) if err != nil { + if db.IsUniqueViolation(err) { + return AddResponse{}, ErrModelIDAlreadyExists + } return AddResponse{}, fmt.Errorf("failed to create model: %w", err) } @@ -105,7 +113,7 @@ func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse return GetResponse{}, fmt.Errorf("model_id is required") } - dbModel, err := s.queries.GetModelByModelID(ctx, modelID) + dbModel, err := s.findUniqueByModelID(ctx, modelID) if err != nil { return GetResponse{}, fmt.Errorf("failed to get model: %w", err) } @@ -207,6 +215,7 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) } params := sqlc.UpdateModelParams{ ID: uuid, + ModelID: model.ModelID, InputModalities: inputMod, Type: string(model.Type), } @@ -230,6 +239,9 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) updated, err := s.queries.UpdateModel(ctx, params) if err != nil { + if db.IsUniqueViolation(err) { + return GetResponse{}, ErrModelIDAlreadyExists + } return GetResponse{}, fmt.Errorf("failed to update model: %w", err) } @@ -241,6 +253,10 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat if modelID == "" { return GetResponse{}, fmt.Errorf("model_id is required") } + current, err := s.findUniqueByModelID(ctx, modelID) + if err != nil { + return GetResponse{}, fmt.Errorf("failed to update model: %w", err) + } model := Model(req) if err := model.Validate(); err != nil { @@ -251,9 +267,8 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat if model.Type == ModelTypeChat { inputMod = normalizeModalities(model.InputModalities, []string{ModelInputText}) } - params := sqlc.UpdateModelByModelIDParams{ - ModelID: modelID, - NewModelID: model.ModelID, + params := sqlc.UpdateModelParams{ + ID: current.ID, InputModalities: inputMod, Type: string(model.Type), } @@ -275,8 +290,13 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} } - updated, err := s.queries.UpdateModelByModelID(ctx, params) + params.ModelID = model.ModelID + + updated, err := s.queries.UpdateModel(ctx, params) if err != nil { + if db.IsUniqueViolation(err) { + return GetResponse{}, ErrModelIDAlreadyExists + } return GetResponse{}, fmt.Errorf("failed to update model: %w", err) } @@ -302,8 +322,12 @@ func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error { if modelID == "" { return fmt.Errorf("model_id is required") } + current, err := s.findUniqueByModelID(ctx, modelID) + if err != nil { + return fmt.Errorf("failed to delete model: %w", err) + } - if err := s.queries.DeleteModelByModelID(ctx, modelID); err != nil { + if err := s.queries.DeleteModel(ctx, current.ID); err != nil { return fmt.Errorf("failed to delete model: %w", err) } @@ -336,6 +360,7 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, func convertToGetResponse(dbModel sqlc.Model) GetResponse { resp := GetResponse{ + ID: dbModel.ID.String(), ModelID: dbModel.ModelID, Model: Model{ ModelID: dbModel.ModelID, @@ -372,6 +397,20 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse { return responses } +func (s *Service) findUniqueByModelID(ctx context.Context, modelID string) (sqlc.Model, error) { + rows, err := s.queries.ListModelsByModelID(ctx, modelID) + if err != nil { + return sqlc.Model{}, err + } + if len(rows) == 0 { + return sqlc.Model{}, pgx.ErrNoRows + } + if len(rows) > 1 { + return sqlc.Model{}, ErrModelIDAmbiguous + } + return rows[0], nil +} + // normalizeModalities returns modalities if non-empty, otherwise the provided fallback. func normalizeModalities(modalities []string, fallback []string) []string { if len(modalities) == 0 { diff --git a/internal/models/types.go b/internal/models/types.go index 984a3cde..6fd9e87e 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -113,6 +113,7 @@ type GetRequest struct { } type GetResponse struct { + ID string `json:"id"` ModelID string `json:"model_id"` Model } diff --git a/internal/settings/service.go b/internal/settings/service.go index a4464356..ec434bad 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/memohai/memoh/internal/db" @@ -20,6 +21,8 @@ type Service struct { } var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access") +var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers") +var ErrInvalidModelRef = errors.New("invalid model reference") func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ @@ -184,15 +187,21 @@ func normalizeBotSettingsFields( maxContextTokens int32, language string, allowGuest bool, - chatModelID pgtype.Text, - memoryModelID pgtype.Text, - embeddingModelID pgtype.Text, + chatModelID pgtype.UUID, + memoryModelID pgtype.UUID, + embeddingModelID pgtype.UUID, searchProviderID pgtype.UUID, ) Settings { settings := normalizeBotSetting(maxContextLoadTime, maxContextTokens, language, allowGuest) - settings.ChatModelID = strings.TrimSpace(chatModelID.String) - settings.MemoryModelID = strings.TrimSpace(memoryModelID.String) - settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String) + if chatModelID.Valid { + settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String() + } + if memoryModelID.Valid { + settings.MemoryModelID = uuid.UUID(memoryModelID.Bytes).String() + } + if embeddingModelID.Valid { + settings.EmbeddingModelID = uuid.UUID(embeddingModelID.Bytes).String() + } if searchProviderID.Valid { settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String() } @@ -200,12 +209,29 @@ func normalizeBotSettingsFields( } func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) { - if strings.TrimSpace(modelID) == "" { - return pgtype.UUID{}, fmt.Errorf("model_id is required") + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef) } - row, err := s.queries.GetModelByModelID(ctx, modelID) + + // Preferred path: when caller already passes the model UUID. + if parsed, err := db.ParseUUID(modelID); err == nil { + if _, err := s.queries.GetModelByID(ctx, parsed); err == nil { + return parsed, nil + } else if !errors.Is(err, pgx.ErrNoRows) { + return pgtype.UUID{}, err + } + } + + rows, err := s.queries.ListModelsByModelID(ctx, modelID) if err != nil { return pgtype.UUID{}, err } - return row.ID, nil + if len(rows) == 0 { + return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID) + } + if len(rows) > 1 { + return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID) + } + return rows[0].ID, nil } diff --git a/packages/sdk/src/types.gen.ts b/packages/sdk/src/types.gen.ts index 495ee9ed..b62d339d 100644 --- a/packages/sdk/src/types.gen.ts +++ b/packages/sdk/src/types.gen.ts @@ -718,6 +718,7 @@ export type ModelsCountResponse = { export type ModelsGetResponse = { client_type?: ModelsClientType; dimensions?: number; + id?: string; input_modalities?: Array; llm_provider_id?: string; model_id?: string; diff --git a/packages/web/src/components/create-model/index.vue b/packages/web/src/components/create-model/index.vue index 1f0d7147..07d275c9 100644 --- a/packages/web/src/components/create-model/index.vue +++ b/packages/web/src/components/create-model/index.vue @@ -229,12 +229,15 @@ import { inject, computed, watch, nextTick, type Ref, ref } from 'vue' import { toTypedSchema } from '@vee-validate/zod' import z from 'zod' import { useMutation, useQueryCache } from '@pinia/colada' -import { postModels, putModelsModelByModelId } from '@memoh/sdk' +import { postModels, putModelsById, putModelsModelByModelId } from '@memoh/sdk' import type { ModelsGetResponse } from '@memoh/sdk' import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types' +import { useI18n } from 'vue-i18n' +import { toast } from 'vue-sonner' const availableInputModalities = ['text', 'image', 'audio', 'video', 'file'] as const const selectedModalities = ref(['text']) +const { t } = useI18n() const formSchema = toTypedSchema(z.object({ type: z.string().min(1), @@ -313,6 +316,17 @@ const { mutateAsync: createModel, isLoading: createLoading } = useMutation({ onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }), }) const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({ + mutation: async ({ id, data }: { id: string; data: Record }) => { + const { data: result } = await putModelsById({ + path: { id }, + body: data as any, + throwOnError: true, + }) + return result + }, + onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }), +}) +const { mutateAsync: updateModelByLegacyModelID, isLoading: updateLegacyLoading } = useMutation({ mutation: async ({ modelId, data }: { modelId: string; data: Record }) => { const { data: result } = await putModelsModelByModelId({ path: { modelId }, @@ -323,7 +337,7 @@ const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({ }, onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }), }) -const isLoading = computed(() => createLoading.value || updateLoading.value) +const isLoading = computed(() => createLoading.value || updateLoading.value || updateLegacyLoading.value) async function addModel(e: Event) { e.preventDefault() @@ -366,16 +380,31 @@ async function addModel(e: Event) { } if (isEdit) { - await updateModel({ modelId: fallback!.model_id, data: payload as any }) + const modelUUID = fallback?.id + if (modelUUID) { + await updateModel({ id: modelUUID, data: payload as any }) + } else { + await updateModelByLegacyModelID({ modelId: fallback!.model_id, data: payload as any }) + } } else { await createModel(payload as any) } open.value = false - } catch { + } catch (error) { + toast.error(resolveErrorMessage(error, t('common.saveFailed'))) return } } +function resolveErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message.trim()) return error.message + if (error && typeof error === 'object' && 'message' in error) { + const msg = (error as { message?: string }).message + if (msg && msg.trim()) return msg + } + return fallback +} + watch(open, async () => { if (!open.value) { title.value = 'title' diff --git a/packages/web/src/pages/bots/components/model-select.vue b/packages/web/src/pages/bots/components/model-select.vue index a9f36bc0..e36e6b2b 100644 --- a/packages/web/src/pages/bots/components/model-select.vue +++ b/packages/web/src/pages/bots/components/model-select.vue @@ -52,13 +52,13 @@