fix(models,settings,conversation): scope model_id uniqueness per

provider and harden model reference resolution
This commit is contained in:
ringotypowriter
2026-02-21 22:31:32 +08:00
parent 9461f923df
commit 50bdbd519c
25 changed files with 376 additions and 107 deletions
+1 -2
View File
@@ -89,7 +89,7 @@ CREATE TABLE IF NOT EXISTS models (
type TEXT NOT NULL DEFAULT 'chat',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT models_model_id_unique UNIQUE (model_id),
CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id),
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')),
CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL),
CONSTRAINT models_client_type_check CHECK (client_type IS NULL OR client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai')),
@@ -389,4 +389,3 @@ CREATE TABLE IF NOT EXISTS bot_history_message_assets (
);
CREATE INDEX IF NOT EXISTS idx_message_assets_message_id ON bot_history_message_assets(message_id);
@@ -0,0 +1,24 @@
-- 0011_model_id_unique_per_provider
-- Revert model_id uniqueness back to global uniqueness.
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM models
GROUP BY model_id
HAVING COUNT(*) > 1
) THEN
RAISE EXCEPTION 'cannot rollback 0011_model_id_unique_per_provider: duplicate model_id values exist across providers';
END IF;
IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN
ALTER TABLE models DROP CONSTRAINT models_provider_model_id_unique;
END IF;
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN
ALTER TABLE models
ADD CONSTRAINT models_model_id_unique UNIQUE (model_id);
END IF;
END
$$;
@@ -0,0 +1,15 @@
-- 0011_model_id_unique_per_provider
-- Change model_id uniqueness from global to per provider.
DO $$
BEGIN
IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN
ALTER TABLE models DROP CONSTRAINT models_model_id_unique;
END IF;
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN
ALTER TABLE models
ADD CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id);
END IF;
END
$$;
+4 -9
View File
@@ -205,22 +205,17 @@ ON CONFLICT (bot_id, user_id) DO NOTHING;
-- chat_settings
-- name: UpsertChatSettings :one
WITH resolved_model AS (
SELECT id
FROM models
WHERE model_id = NULLIF(sqlc.narg(model_id)::text, '')
LIMIT 1
),
WITH
updated AS (
UPDATE bots
SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id),
SET chat_model_id = COALESCE(sqlc.narg(chat_model_id)::uuid, bots.chat_model_id),
updated_at = now()
WHERE bots.id = sqlc.arg(id)
RETURNING bots.id, bots.chat_model_id, bots.updated_at
)
SELECT
updated.id AS chat_id,
chat_models.model_id AS model_id,
chat_models.id AS model_id,
updated.updated_at
FROM updated
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
@@ -228,7 +223,7 @@ LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
-- name: GetChatSettings :one
SELECT
b.id AS chat_id,
chat_models.model_id AS model_id,
chat_models.id AS model_id,
b.updated_at
FROM bots b
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
+6
View File
@@ -54,6 +54,11 @@ SELECT * FROM models WHERE id = sqlc.arg(id);
-- name: GetModelByModelID :one
SELECT * FROM models WHERE model_id = sqlc.arg(model_id);
-- name: ListModelsByModelID :many
SELECT * FROM models
WHERE model_id = sqlc.arg(model_id)
ORDER BY created_at DESC;
-- name: ListModels :many
SELECT * FROM models
ORDER BY created_at DESC;
@@ -82,6 +87,7 @@ ORDER BY created_at DESC;
-- name: UpdateModel :one
UPDATE models
SET
model_id = sqlc.arg(model_id),
name = sqlc.arg(name),
llm_provider_id = sqlc.arg(llm_provider_id),
client_type = sqlc.narg(client_type),
+6 -6
View File
@@ -5,9 +5,9 @@ SELECT
bots.max_context_tokens,
bots.language,
bots.allow_guest,
chat_models.model_id AS chat_model_id,
memory_models.model_id AS memory_model_id,
embedding_models.model_id AS embedding_model_id,
chat_models.id AS chat_model_id,
memory_models.id AS memory_model_id,
embedding_models.id AS embedding_model_id,
search_providers.id AS search_provider_id
FROM bots
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
@@ -37,9 +37,9 @@ SELECT
updated.max_context_tokens,
updated.language,
updated.allow_guest,
chat_models.model_id AS chat_model_id,
memory_models.model_id AS memory_model_id,
embedding_models.model_id AS embedding_model_id,
chat_models.id AS chat_model_id,
memory_models.id AS memory_model_id,
embedding_models.id AS embedding_model_id,
search_providers.id AS search_provider_id
FROM updated
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
+22 -1
View File
@@ -15,6 +15,7 @@ import (
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
attachmentpkg "github.com/memohai/memoh/internal/attachment"
@@ -1535,10 +1536,30 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq
}
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
model, err := r.modelsService.GetByModelID(ctx, modelID)
modelRef := strings.TrimSpace(modelID)
if modelRef == "" {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model id is required")
}
// Support both model UUID and model_id slug. UUID-formatted slugs still
// work because we fall back to GetByModelID when UUID lookup misses.
var model models.GetResponse
var err error
if _, parseErr := db.ParseUUID(modelRef); parseErr == nil {
model, err = r.modelsService.GetByID(ctx, modelRef)
if err == nil {
goto resolved
}
if !errors.Is(err, pgx.ErrNoRows) {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
}
model, err = r.modelsService.GetByModelID(ctx, modelRef)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
resolved:
if model.Type != models.ModelTypeChat {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
}
+54 -21
View File
@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
@@ -20,6 +21,7 @@ var (
ErrChatNotFound = errors.New("chat not found")
ErrNotParticipant = errors.New("not a participant")
ErrPermissionDenied = errors.New("permission denied")
ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
)
// Service manages conversation lifecycle, participants, and settings.
@@ -310,21 +312,26 @@ func (s *Service) GetSettings(ctx context.Context, conversationID string) (Setti
// UpdateSettings updates conversation settings.
func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req UpdateSettingsRequest) (Settings, error) {
current, err := s.GetSettings(ctx, conversationID)
if err != nil {
return Settings{}, err
}
if req.ModelID != nil {
current.ModelID = *req.ModelID
}
pgID, err := parseUUID(conversationID)
if err != nil {
return Settings{}, err
}
chatModelUUID := pgtype.UUID{}
if req.ModelID != nil {
modelRef := strings.TrimSpace(*req.ModelID)
if modelRef != "" {
resolved, err := s.resolveModelUUID(ctx, modelRef)
if err != nil {
return Settings{}, err
}
chatModelUUID = resolved
}
}
row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{
ID: pgID,
ModelID: toPgText(current.ModelID),
ID: pgID,
ChatModelID: chatModelUUID,
})
if err != nil {
return Settings{}, err
@@ -427,17 +434,23 @@ func toParticipantFields(conversationID, userID pgtype.UUID, role string, joined
}
func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings {
return Settings{
ChatID: row.ChatID.String(),
ModelID: dbpkg.TextToString(row.ModelID),
settings := Settings{
ChatID: row.ChatID.String(),
}
if row.ModelID.Valid {
settings.ModelID = uuid.UUID(row.ModelID.Bytes).String()
}
return settings
}
func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings {
return Settings{
ChatID: row.ChatID.String(),
ModelID: dbpkg.TextToString(row.ModelID),
settings := Settings{
ChatID: row.ChatID.String(),
}
if row.ModelID.Valid {
settings.ModelID = uuid.UUID(row.ModelID.Bytes).String()
}
return settings
}
func defaultSettings(conversationID string) Settings {
@@ -450,12 +463,32 @@ func parseUUID(id string) (pgtype.UUID, error) {
return dbpkg.ParseUUID(id)
}
func toPgText(s string) pgtype.Text {
s = strings.TrimSpace(s)
if s == "" {
return pgtype.Text{}
func (s *Service) resolveModelUUID(ctx context.Context, modelRef string) (pgtype.UUID, error) {
modelRef = strings.TrimSpace(modelRef)
if modelRef == "" {
return pgtype.UUID{}, fmt.Errorf("model_id is required")
}
return pgtype.Text{String: s, Valid: true}
// Prefer UUID path; if not found, fall back to model_id slug.
if parsed, err := dbpkg.ParseUUID(modelRef); err == nil {
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
return parsed, nil
} else if !errors.Is(err, pgx.ErrNoRows) {
return pgtype.UUID{}, err
}
}
rows, err := s.queries.ListModelsByModelID(ctx, modelRef)
if err != nil {
return pgtype.UUID{}, err
}
if len(rows) == 0 {
return pgtype.UUID{}, fmt.Errorf("model not found: %s", modelRef)
}
if len(rows) > 1 {
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelRef)
}
return rows[0].ID, nil
}
func pgTimePtr(ts pgtype.Timestamptz) *time.Time {
+9 -14
View File
@@ -271,7 +271,7 @@ func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAc
const getChatSettings = `-- name: GetChatSettings :one
SELECT
b.id AS chat_id,
chat_models.model_id AS model_id,
chat_models.id AS model_id,
b.updated_at
FROM bots b
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
@@ -280,7 +280,7 @@ WHERE b.id = $1
type GetChatSettingsRow struct {
ChatID pgtype.UUID `json:"chat_id"`
ModelID pgtype.Text `json:"model_id"`
ModelID pgtype.UUID `json:"model_id"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
@@ -645,41 +645,36 @@ func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams
const upsertChatSettings = `-- name: UpsertChatSettings :one
WITH resolved_model AS (
SELECT id
FROM models
WHERE model_id = NULLIF($1::text, '')
LIMIT 1
),
WITH
updated AS (
UPDATE bots
SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id),
SET chat_model_id = COALESCE($1::uuid, bots.chat_model_id),
updated_at = now()
WHERE bots.id = $2
RETURNING bots.id, bots.chat_model_id, bots.updated_at
)
SELECT
updated.id AS chat_id,
chat_models.model_id AS model_id,
chat_models.id AS model_id,
updated.updated_at
FROM updated
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id
`
type UpsertChatSettingsParams struct {
ModelID pgtype.Text `json:"model_id"`
ID pgtype.UUID `json:"id"`
ChatModelID pgtype.UUID `json:"chat_model_id"`
ID pgtype.UUID `json:"id"`
}
type UpsertChatSettingsRow struct {
ChatID pgtype.UUID `json:"chat_id"`
ModelID pgtype.Text `json:"model_id"`
ModelID pgtype.UUID `json:"model_id"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
// chat_settings
func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) {
row := q.db.QueryRow(ctx, upsertChatSettings, arg.ModelID, arg.ID)
row := q.db.QueryRow(ctx, upsertChatSettings, arg.ChatModelID, arg.ID)
var i UpsertChatSettingsRow
err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt)
return i, err
+47 -7
View File
@@ -419,6 +419,43 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType pgtype.
return items, nil
}
const listModelsByModelID = `-- name: ListModelsByModelID :many
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
WHERE model_id = $1
ORDER BY created_at DESC
`
func (q *Queries) ListModelsByModelID(ctx context.Context, modelID string) ([]Model, error) {
rows, err := q.db.Query(ctx, listModelsByModelID, modelID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Model
for rows.Next() {
var i Model
if err := rows.Scan(
&i.ID,
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
WHERE llm_provider_id = $1
@@ -580,18 +617,20 @@ func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderPa
const updateModel = `-- name: UpdateModel :one
UPDATE models
SET
name = $1,
llm_provider_id = $2,
client_type = $3,
dimensions = $4,
input_modalities = $5,
type = $6,
model_id = $1,
name = $2,
llm_provider_id = $3,
client_type = $4,
dimensions = $5,
input_modalities = $6,
type = $7,
updated_at = now()
WHERE id = $7
WHERE id = $8
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
`
type UpdateModelParams struct {
ModelID string `json:"model_id"`
Name pgtype.Text `json:"name"`
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
ClientType pgtype.Text `json:"client_type"`
@@ -603,6 +642,7 @@ type UpdateModelParams struct {
func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) {
row := q.db.QueryRow(ctx, updateModel,
arg.ModelID,
arg.Name,
arg.LlmProviderID,
arg.ClientType,
+12 -12
View File
@@ -37,9 +37,9 @@ SELECT
bots.max_context_tokens,
bots.language,
bots.allow_guest,
chat_models.model_id AS chat_model_id,
memory_models.model_id AS memory_model_id,
embedding_models.model_id AS embedding_model_id,
chat_models.id AS chat_model_id,
memory_models.id AS memory_model_id,
embedding_models.id AS embedding_model_id,
search_providers.id AS search_provider_id
FROM bots
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
@@ -55,9 +55,9 @@ type GetSettingsByBotIDRow struct {
MaxContextTokens int32 `json:"max_context_tokens"`
Language string `json:"language"`
AllowGuest bool `json:"allow_guest"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
ChatModelID pgtype.UUID `json:"chat_model_id"`
MemoryModelID pgtype.UUID `json:"memory_model_id"`
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
SearchProviderID pgtype.UUID `json:"search_provider_id"`
}
@@ -99,9 +99,9 @@ SELECT
updated.max_context_tokens,
updated.language,
updated.allow_guest,
chat_models.model_id AS chat_model_id,
memory_models.model_id AS memory_model_id,
embedding_models.model_id AS embedding_model_id,
chat_models.id AS chat_model_id,
memory_models.id AS memory_model_id,
embedding_models.id AS embedding_model_id,
search_providers.id AS search_provider_id
FROM updated
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
@@ -128,9 +128,9 @@ type UpsertBotSettingsRow struct {
MaxContextTokens int32 `json:"max_context_tokens"`
Language string `json:"language"`
AllowGuest bool `json:"allow_guest"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
ChatModelID pgtype.UUID `json:"chat_model_id"`
MemoryModelID pgtype.UUID `json:"memory_model_id"`
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
SearchProviderID pgtype.UUID `json:"search_provider_id"`
}
+29
View File
@@ -1,10 +1,12 @@
package handlers
import (
"errors"
"log/slog"
"net/http"
"net/url"
"github.com/jackc/pgx/v5"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/models"
@@ -52,6 +54,9 @@ func (h *ModelsHandler) Create(c echo.Context) error {
resp, err := h.service.Create(c.Request().Context(), req)
if err != nil {
if errors.Is(err, models.ErrModelIDAlreadyExists) {
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusCreated, resp)
@@ -134,6 +139,12 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error {
resp, err := h.service.GetByModelID(c.Request().Context(), modelID)
if err != nil {
if errors.Is(err, models.ErrModelIDAmbiguous) {
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
}
if errors.Is(err, pgx.ErrNoRows) {
return echo.NewHTTPError(http.StatusNotFound, err.Error())
}
return echo.NewHTTPError(http.StatusNotFound, err.Error())
}
return c.JSON(http.StatusOK, resp)
@@ -163,6 +174,9 @@ func (h *ModelsHandler) UpdateByID(c echo.Context) error {
resp, err := h.service.UpdateByID(c.Request().Context(), id, req)
if err != nil {
if errors.Is(err, models.ErrModelIDAlreadyExists) {
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, resp)
@@ -197,6 +211,15 @@ func (h *ModelsHandler) UpdateByModelID(c echo.Context) error {
resp, err := h.service.UpdateByModelID(c.Request().Context(), modelID, req)
if err != nil {
if errors.Is(err, models.ErrModelIDAlreadyExists) {
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
}
if errors.Is(err, models.ErrModelIDAmbiguous) {
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
}
if errors.Is(err, pgx.ErrNoRows) {
return echo.NewHTTPError(http.StatusNotFound, err.Error())
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, resp)
@@ -246,6 +269,12 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error {
}
if err := h.service.DeleteByModelID(c.Request().Context(), modelID); err != nil {
if errors.Is(err, models.ErrModelIDAmbiguous) {
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
}
if errors.Is(err, pgx.ErrNoRows) {
return echo.NewHTTPError(http.StatusNotFound, err.Error())
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusNoContent)
+6
View File
@@ -96,6 +96,12 @@ func (h *SettingsHandler) Upsert(c echo.Context) error {
if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) {
return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access")
}
if errors.Is(err, settings.ErrInvalidModelRef) {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if errors.Is(err, settings.ErrModelIDAmbiguous) {
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; select by model UUID")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, resp)
+45 -6
View File
@@ -2,16 +2,21 @@ package models
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
var ErrModelIDAlreadyExists = errors.New("model_id already exists")
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
// Service provides CRUD operations for models
type Service struct {
queries *sqlc.Queries
@@ -65,6 +70,9 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
created, err := s.queries.CreateModel(ctx, params)
if err != nil {
if db.IsUniqueViolation(err) {
return AddResponse{}, ErrModelIDAlreadyExists
}
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
}
@@ -105,7 +113,7 @@ func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse
return GetResponse{}, fmt.Errorf("model_id is required")
}
dbModel, err := s.queries.GetModelByModelID(ctx, modelID)
dbModel, err := s.findUniqueByModelID(ctx, modelID)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to get model: %w", err)
}
@@ -207,6 +215,7 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
}
params := sqlc.UpdateModelParams{
ID: uuid,
ModelID: model.ModelID,
InputModalities: inputMod,
Type: string(model.Type),
}
@@ -230,6 +239,9 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
updated, err := s.queries.UpdateModel(ctx, params)
if err != nil {
if db.IsUniqueViolation(err) {
return GetResponse{}, ErrModelIDAlreadyExists
}
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
}
@@ -241,6 +253,10 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
if modelID == "" {
return GetResponse{}, fmt.Errorf("model_id is required")
}
current, err := s.findUniqueByModelID(ctx, modelID)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
}
model := Model(req)
if err := model.Validate(); err != nil {
@@ -251,9 +267,8 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
if model.Type == ModelTypeChat {
inputMod = normalizeModalities(model.InputModalities, []string{ModelInputText})
}
params := sqlc.UpdateModelByModelIDParams{
ModelID: modelID,
NewModelID: model.ModelID,
params := sqlc.UpdateModelParams{
ID: current.ID,
InputModalities: inputMod,
Type: string(model.Type),
}
@@ -275,8 +290,13 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
updated, err := s.queries.UpdateModelByModelID(ctx, params)
params.ModelID = model.ModelID
updated, err := s.queries.UpdateModel(ctx, params)
if err != nil {
if db.IsUniqueViolation(err) {
return GetResponse{}, ErrModelIDAlreadyExists
}
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
}
@@ -302,8 +322,12 @@ func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error {
if modelID == "" {
return fmt.Errorf("model_id is required")
}
current, err := s.findUniqueByModelID(ctx, modelID)
if err != nil {
return fmt.Errorf("failed to delete model: %w", err)
}
if err := s.queries.DeleteModelByModelID(ctx, modelID); err != nil {
if err := s.queries.DeleteModel(ctx, current.ID); err != nil {
return fmt.Errorf("failed to delete model: %w", err)
}
@@ -336,6 +360,7 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
func convertToGetResponse(dbModel sqlc.Model) GetResponse {
resp := GetResponse{
ID: dbModel.ID.String(),
ModelID: dbModel.ModelID,
Model: Model{
ModelID: dbModel.ModelID,
@@ -372,6 +397,20 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse {
return responses
}
func (s *Service) findUniqueByModelID(ctx context.Context, modelID string) (sqlc.Model, error) {
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
if err != nil {
return sqlc.Model{}, err
}
if len(rows) == 0 {
return sqlc.Model{}, pgx.ErrNoRows
}
if len(rows) > 1 {
return sqlc.Model{}, ErrModelIDAmbiguous
}
return rows[0], nil
}
// normalizeModalities returns modalities if non-empty, otherwise the provided fallback.
func normalizeModalities(modalities []string, fallback []string) []string {
if len(modalities) == 0 {
+1
View File
@@ -113,6 +113,7 @@ type GetRequest struct {
}
type GetResponse struct {
ID string `json:"id"`
ModelID string `json:"model_id"`
Model
}
+36 -10
View File
@@ -8,6 +8,7 @@ import (
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db"
@@ -20,6 +21,8 @@ type Service struct {
}
var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access")
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
var ErrInvalidModelRef = errors.New("invalid model reference")
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
return &Service{
@@ -184,15 +187,21 @@ func normalizeBotSettingsFields(
maxContextTokens int32,
language string,
allowGuest bool,
chatModelID pgtype.Text,
memoryModelID pgtype.Text,
embeddingModelID pgtype.Text,
chatModelID pgtype.UUID,
memoryModelID pgtype.UUID,
embeddingModelID pgtype.UUID,
searchProviderID pgtype.UUID,
) Settings {
settings := normalizeBotSetting(maxContextLoadTime, maxContextTokens, language, allowGuest)
settings.ChatModelID = strings.TrimSpace(chatModelID.String)
settings.MemoryModelID = strings.TrimSpace(memoryModelID.String)
settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String)
if chatModelID.Valid {
settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String()
}
if memoryModelID.Valid {
settings.MemoryModelID = uuid.UUID(memoryModelID.Bytes).String()
}
if embeddingModelID.Valid {
settings.EmbeddingModelID = uuid.UUID(embeddingModelID.Bytes).String()
}
if searchProviderID.Valid {
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
}
@@ -200,12 +209,29 @@ func normalizeBotSettingsFields(
}
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
if strings.TrimSpace(modelID) == "" {
return pgtype.UUID{}, fmt.Errorf("model_id is required")
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef)
}
row, err := s.queries.GetModelByModelID(ctx, modelID)
// Preferred path: when caller already passes the model UUID.
if parsed, err := db.ParseUUID(modelID); err == nil {
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
return parsed, nil
} else if !errors.Is(err, pgx.ErrNoRows) {
return pgtype.UUID{}, err
}
}
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
if err != nil {
return pgtype.UUID{}, err
}
return row.ID, nil
if len(rows) == 0 {
return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID)
}
if len(rows) > 1 {
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID)
}
return rows[0].ID, nil
}
+1
View File
@@ -718,6 +718,7 @@ export type ModelsCountResponse = {
export type ModelsGetResponse = {
client_type?: ModelsClientType;
dimensions?: number;
id?: string;
input_modalities?: Array<string>;
llm_provider_id?: string;
model_id?: string;
@@ -229,12 +229,15 @@ import { inject, computed, watch, nextTick, type Ref, ref } from 'vue'
import { toTypedSchema } from '@vee-validate/zod'
import z from 'zod'
import { useMutation, useQueryCache } from '@pinia/colada'
import { postModels, putModelsModelByModelId } from '@memoh/sdk'
import { postModels, putModelsById, putModelsModelByModelId } from '@memoh/sdk'
import type { ModelsGetResponse } from '@memoh/sdk'
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { useI18n } from 'vue-i18n'
import { toast } from 'vue-sonner'
const availableInputModalities = ['text', 'image', 'audio', 'video', 'file'] as const
const selectedModalities = ref<string[]>(['text'])
const { t } = useI18n()
const formSchema = toTypedSchema(z.object({
type: z.string().min(1),
@@ -313,6 +316,17 @@ const { mutateAsync: createModel, isLoading: createLoading } = useMutation({
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
})
const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({
mutation: async ({ id, data }: { id: string; data: Record<string, unknown> }) => {
const { data: result } = await putModelsById({
path: { id },
body: data as any,
throwOnError: true,
})
return result
},
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
})
const { mutateAsync: updateModelByLegacyModelID, isLoading: updateLegacyLoading } = useMutation({
mutation: async ({ modelId, data }: { modelId: string; data: Record<string, unknown> }) => {
const { data: result } = await putModelsModelByModelId({
path: { modelId },
@@ -323,7 +337,7 @@ const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({
},
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
})
const isLoading = computed(() => createLoading.value || updateLoading.value)
const isLoading = computed(() => createLoading.value || updateLoading.value || updateLegacyLoading.value)
async function addModel(e: Event) {
e.preventDefault()
@@ -366,16 +380,31 @@ async function addModel(e: Event) {
}
if (isEdit) {
await updateModel({ modelId: fallback!.model_id, data: payload as any })
const modelUUID = fallback?.id
if (modelUUID) {
await updateModel({ id: modelUUID, data: payload as any })
} else {
await updateModelByLegacyModelID({ modelId: fallback!.model_id, data: payload as any })
}
} else {
await createModel(payload as any)
}
open.value = false
} catch {
} catch (error) {
toast.error(resolveErrorMessage(error, t('common.saveFailed')))
return
}
}
function resolveErrorMessage(error: unknown, fallback: string): string {
if (error instanceof Error && error.message.trim()) return error.message
if (error && typeof error === 'object' && 'message' in error) {
const msg = (error as { message?: string }).message
if (msg && msg.trim()) return msg
}
return fallback
}
watch(open, async () => {
if (!open.value) {
title.value = 'title'
@@ -52,13 +52,13 @@
</div>
<button
v-for="model in group.models"
:key="model.model_id"
:key="model.id || `${model.llm_provider_id}:${model.model_id}`"
class="relative flex w-full cursor-pointer items-center gap-2 rounded-md px-2 py-1.5 text-sm outline-none hover:bg-accent hover:text-accent-foreground"
:class="{ 'bg-accent': selected === model.model_id }"
@click="selectModel(model.model_id)"
:class="{ 'bg-accent': selected === model.id }"
@click="selectModel(model.id)"
>
<FontAwesomeIcon
v-if="selected === model.model_id"
v-if="selected === model.id"
:icon="['fas', 'check']"
class="size-3.5"
/>
@@ -145,11 +145,12 @@ const filteredGroups = computed(() => {
//
const displayLabel = computed(() => {
if (!selected.value) return ''
const model = typeFilteredModels.value.find((m) => m.model_id === selected.value)
const model = typeFilteredModels.value.find((m) => m.id === selected.value)
return model?.name || model?.model_id || selected.value
})
function selectModel(modelId: string) {
function selectModel(modelId?: string) {
if (!modelId) return
selected.value = modelId
open.value = false
}
@@ -1,7 +1,7 @@
<template>
<Item variant="outline">
<ItemContent>
<ItemTitle>{{ model.name }}</ItemTitle>
<ItemTitle>{{ model.name || model.model_id }}</ItemTitle>
<ItemDescription class="gap-2 flex flex-wrap items-center mt-3">
<Badge variant="outline">
{{ model.type }}
@@ -26,7 +26,7 @@
<ConfirmPopover
:message="$t('models.deleteModelConfirm')"
:loading="deleteLoading"
@confirm="$emit('delete', model.name)"
@confirm="$emit('delete', model.id ?? '')"
>
<template #trigger>
<Button variant="outline">
@@ -58,6 +58,6 @@ defineProps<{
defineEmits<{
edit: [model: ModelsGetResponse]
delete: [name: string]
delete: [id: string]
}>()
</script>
@@ -16,11 +16,11 @@
>
<ModelItem
v-for="model in models"
:key="model.model_id"
:key="model.id || `${model.llm_provider_id}:${model.model_id}`"
:model="model"
:delete-loading="deleteModelLoading"
@edit="(model) => $emit('edit', model)"
@delete="(name) => $emit('delete', name)"
@delete="(id) => $emit('delete', id)"
/>
</section>
@@ -61,6 +61,6 @@ defineProps<{
defineEmits<{
edit: [model: ModelsGetResponse]
delete: [name: string]
delete: [id: string]
}>()
</script>
@@ -33,7 +33,7 @@ import ProviderForm from './components/provider-form.vue'
import ModelList from './components/model-list.vue'
import { computed, inject, provide, reactive, ref, toRef, watch } from 'vue'
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
import { putProvidersById, deleteProvidersById, getProvidersByIdModels, deleteModelsModelByModelId } from '@memoh/sdk'
import { putProvidersById, deleteProvidersById, getProvidersByIdModels, deleteModelsById } from '@memoh/sdk'
import type { ModelsGetResponse, ProvidersGetResponse } from '@memoh/sdk'
// ---- Model provide CreateModel ----
@@ -86,8 +86,9 @@ const { mutate: changeProvider, isLoading: editLoading } = useMutation({
})
const { mutate: deleteModel, isLoading: deleteModelLoading } = useMutation({
mutation: async (modelName: string) => {
await deleteModelsModelByModelId({ path: { modelId: modelName }, throwOnError: true })
mutation: async (modelID: string) => {
if (!modelID) return
await deleteModelsById({ path: { id: modelID }, throwOnError: true })
},
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
})
+3
View File
@@ -6818,6 +6818,9 @@ const docTemplate = `{
"dimensions": {
"type": "integer"
},
"id": {
"type": "string"
},
"input_modalities": {
"type": "array",
"items": {
+3
View File
@@ -6809,6 +6809,9 @@
"dimensions": {
"type": "integer"
},
"id": {
"type": "string"
},
"input_modalities": {
"type": "array",
"items": {
+2
View File
@@ -1185,6 +1185,8 @@ definitions:
$ref: '#/definitions/models.ClientType'
dimensions:
type: integer
id:
type: string
input_modalities:
items:
type: string