mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(models,settings,conversation): scope model_id uniqueness per
provider and harden model reference resolution
This commit is contained in:
@@ -89,7 +89,7 @@ CREATE TABLE IF NOT EXISTS models (
|
||||
type TEXT NOT NULL DEFAULT 'chat',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT models_model_id_unique UNIQUE (model_id),
|
||||
CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id),
|
||||
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')),
|
||||
CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL),
|
||||
CONSTRAINT models_client_type_check CHECK (client_type IS NULL OR client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai')),
|
||||
@@ -389,4 +389,3 @@ CREATE TABLE IF NOT EXISTS bot_history_message_assets (
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_message_assets_message_id ON bot_history_message_assets(message_id);
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
-- 0011_model_id_unique_per_provider
|
||||
-- Revert model_id uniqueness back to global uniqueness.
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM models
|
||||
GROUP BY model_id
|
||||
HAVING COUNT(*) > 1
|
||||
) THEN
|
||||
RAISE EXCEPTION 'cannot rollback 0011_model_id_unique_per_provider: duplicate model_id values exist across providers';
|
||||
END IF;
|
||||
|
||||
IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN
|
||||
ALTER TABLE models DROP CONSTRAINT models_provider_model_id_unique;
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN
|
||||
ALTER TABLE models
|
||||
ADD CONSTRAINT models_model_id_unique UNIQUE (model_id);
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
@@ -0,0 +1,15 @@
|
||||
-- 0011_model_id_unique_per_provider
|
||||
-- Change model_id uniqueness from global to per provider.
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN
|
||||
ALTER TABLE models DROP CONSTRAINT models_model_id_unique;
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN
|
||||
ALTER TABLE models
|
||||
ADD CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id);
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
@@ -205,22 +205,17 @@ ON CONFLICT (bot_id, user_id) DO NOTHING;
|
||||
-- chat_settings
|
||||
|
||||
-- name: UpsertChatSettings :one
|
||||
WITH resolved_model AS (
|
||||
SELECT id
|
||||
FROM models
|
||||
WHERE model_id = NULLIF(sqlc.narg(model_id)::text, '')
|
||||
LIMIT 1
|
||||
),
|
||||
WITH
|
||||
updated AS (
|
||||
UPDATE bots
|
||||
SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id),
|
||||
SET chat_model_id = COALESCE(sqlc.narg(chat_model_id)::uuid, bots.chat_model_id),
|
||||
updated_at = now()
|
||||
WHERE bots.id = sqlc.arg(id)
|
||||
RETURNING bots.id, bots.chat_model_id, bots.updated_at
|
||||
)
|
||||
SELECT
|
||||
updated.id AS chat_id,
|
||||
chat_models.model_id AS model_id,
|
||||
chat_models.id AS model_id,
|
||||
updated.updated_at
|
||||
FROM updated
|
||||
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
|
||||
@@ -228,7 +223,7 @@ LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
|
||||
-- name: GetChatSettings :one
|
||||
SELECT
|
||||
b.id AS chat_id,
|
||||
chat_models.model_id AS model_id,
|
||||
chat_models.id AS model_id,
|
||||
b.updated_at
|
||||
FROM bots b
|
||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||
|
||||
@@ -54,6 +54,11 @@ SELECT * FROM models WHERE id = sqlc.arg(id);
|
||||
-- name: GetModelByModelID :one
|
||||
SELECT * FROM models WHERE model_id = sqlc.arg(model_id);
|
||||
|
||||
-- name: ListModelsByModelID :many
|
||||
SELECT * FROM models
|
||||
WHERE model_id = sqlc.arg(model_id)
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: ListModels :many
|
||||
SELECT * FROM models
|
||||
ORDER BY created_at DESC;
|
||||
@@ -82,6 +87,7 @@ ORDER BY created_at DESC;
|
||||
-- name: UpdateModel :one
|
||||
UPDATE models
|
||||
SET
|
||||
model_id = sqlc.arg(model_id),
|
||||
name = sqlc.arg(name),
|
||||
llm_provider_id = sqlc.arg(llm_provider_id),
|
||||
client_type = sqlc.narg(client_type),
|
||||
|
||||
@@ -5,9 +5,9 @@ SELECT
|
||||
bots.max_context_tokens,
|
||||
bots.language,
|
||||
bots.allow_guest,
|
||||
chat_models.model_id AS chat_model_id,
|
||||
memory_models.model_id AS memory_model_id,
|
||||
embedding_models.model_id AS embedding_model_id,
|
||||
chat_models.id AS chat_model_id,
|
||||
memory_models.id AS memory_model_id,
|
||||
embedding_models.id AS embedding_model_id,
|
||||
search_providers.id AS search_provider_id
|
||||
FROM bots
|
||||
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
|
||||
@@ -37,9 +37,9 @@ SELECT
|
||||
updated.max_context_tokens,
|
||||
updated.language,
|
||||
updated.allow_guest,
|
||||
chat_models.model_id AS chat_model_id,
|
||||
memory_models.model_id AS memory_model_id,
|
||||
embedding_models.model_id AS embedding_model_id,
|
||||
chat_models.id AS chat_model_id,
|
||||
memory_models.id AS memory_model_id,
|
||||
embedding_models.id AS embedding_model_id,
|
||||
search_providers.id AS search_provider_id
|
||||
FROM updated
|
||||
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
attachmentpkg "github.com/memohai/memoh/internal/attachment"
|
||||
@@ -1535,10 +1536,30 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq
|
||||
}
|
||||
|
||||
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||
model, err := r.modelsService.GetByModelID(ctx, modelID)
|
||||
modelRef := strings.TrimSpace(modelID)
|
||||
if modelRef == "" {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model id is required")
|
||||
}
|
||||
|
||||
// Support both model UUID and model_id slug. UUID-formatted slugs still
|
||||
// work because we fall back to GetByModelID when UUID lookup misses.
|
||||
var model models.GetResponse
|
||||
var err error
|
||||
if _, parseErr := db.ParseUUID(modelRef); parseErr == nil {
|
||||
model, err = r.modelsService.GetByID(ctx, modelRef)
|
||||
if err == nil {
|
||||
goto resolved
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
}
|
||||
model, err = r.modelsService.GetByModelID(ctx, modelRef)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
|
||||
resolved:
|
||||
if model.Type != models.ModelTypeChat {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
@@ -20,6 +21,7 @@ var (
|
||||
ErrChatNotFound = errors.New("chat not found")
|
||||
ErrNotParticipant = errors.New("not a participant")
|
||||
ErrPermissionDenied = errors.New("permission denied")
|
||||
ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
|
||||
)
|
||||
|
||||
// Service manages conversation lifecycle, participants, and settings.
|
||||
@@ -310,21 +312,26 @@ func (s *Service) GetSettings(ctx context.Context, conversationID string) (Setti
|
||||
|
||||
// UpdateSettings updates conversation settings.
|
||||
func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req UpdateSettingsRequest) (Settings, error) {
|
||||
current, err := s.GetSettings(ctx, conversationID)
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
if req.ModelID != nil {
|
||||
current.ModelID = *req.ModelID
|
||||
}
|
||||
|
||||
pgID, err := parseUUID(conversationID)
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
|
||||
chatModelUUID := pgtype.UUID{}
|
||||
if req.ModelID != nil {
|
||||
modelRef := strings.TrimSpace(*req.ModelID)
|
||||
if modelRef != "" {
|
||||
resolved, err := s.resolveModelUUID(ctx, modelRef)
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
chatModelUUID = resolved
|
||||
}
|
||||
}
|
||||
|
||||
row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{
|
||||
ID: pgID,
|
||||
ModelID: toPgText(current.ModelID),
|
||||
ID: pgID,
|
||||
ChatModelID: chatModelUUID,
|
||||
})
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
@@ -427,17 +434,23 @@ func toParticipantFields(conversationID, userID pgtype.UUID, role string, joined
|
||||
}
|
||||
|
||||
func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings {
|
||||
return Settings{
|
||||
ChatID: row.ChatID.String(),
|
||||
ModelID: dbpkg.TextToString(row.ModelID),
|
||||
settings := Settings{
|
||||
ChatID: row.ChatID.String(),
|
||||
}
|
||||
if row.ModelID.Valid {
|
||||
settings.ModelID = uuid.UUID(row.ModelID.Bytes).String()
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings {
|
||||
return Settings{
|
||||
ChatID: row.ChatID.String(),
|
||||
ModelID: dbpkg.TextToString(row.ModelID),
|
||||
settings := Settings{
|
||||
ChatID: row.ChatID.String(),
|
||||
}
|
||||
if row.ModelID.Valid {
|
||||
settings.ModelID = uuid.UUID(row.ModelID.Bytes).String()
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
func defaultSettings(conversationID string) Settings {
|
||||
@@ -450,12 +463,32 @@ func parseUUID(id string) (pgtype.UUID, error) {
|
||||
return dbpkg.ParseUUID(id)
|
||||
}
|
||||
|
||||
func toPgText(s string) pgtype.Text {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return pgtype.Text{}
|
||||
func (s *Service) resolveModelUUID(ctx context.Context, modelRef string) (pgtype.UUID, error) {
|
||||
modelRef = strings.TrimSpace(modelRef)
|
||||
if modelRef == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("model_id is required")
|
||||
}
|
||||
return pgtype.Text{String: s, Valid: true}
|
||||
|
||||
// Prefer UUID path; if not found, fall back to model_id slug.
|
||||
if parsed, err := dbpkg.ParseUUID(modelRef); err == nil {
|
||||
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
|
||||
return parsed, nil
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return pgtype.UUID{}, err
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := s.queries.ListModelsByModelID(ctx, modelRef)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, err
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return pgtype.UUID{}, fmt.Errorf("model not found: %s", modelRef)
|
||||
}
|
||||
if len(rows) > 1 {
|
||||
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelRef)
|
||||
}
|
||||
return rows[0].ID, nil
|
||||
}
|
||||
|
||||
func pgTimePtr(ts pgtype.Timestamptz) *time.Time {
|
||||
|
||||
@@ -271,7 +271,7 @@ func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAc
|
||||
const getChatSettings = `-- name: GetChatSettings :one
|
||||
SELECT
|
||||
b.id AS chat_id,
|
||||
chat_models.model_id AS model_id,
|
||||
chat_models.id AS model_id,
|
||||
b.updated_at
|
||||
FROM bots b
|
||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||
@@ -280,7 +280,7 @@ WHERE b.id = $1
|
||||
|
||||
type GetChatSettingsRow struct {
|
||||
ChatID pgtype.UUID `json:"chat_id"`
|
||||
ModelID pgtype.Text `json:"model_id"`
|
||||
ModelID pgtype.UUID `json:"model_id"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -645,41 +645,36 @@ func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams
|
||||
|
||||
const upsertChatSettings = `-- name: UpsertChatSettings :one
|
||||
|
||||
WITH resolved_model AS (
|
||||
SELECT id
|
||||
FROM models
|
||||
WHERE model_id = NULLIF($1::text, '')
|
||||
LIMIT 1
|
||||
),
|
||||
WITH
|
||||
updated AS (
|
||||
UPDATE bots
|
||||
SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id),
|
||||
SET chat_model_id = COALESCE($1::uuid, bots.chat_model_id),
|
||||
updated_at = now()
|
||||
WHERE bots.id = $2
|
||||
RETURNING bots.id, bots.chat_model_id, bots.updated_at
|
||||
)
|
||||
SELECT
|
||||
updated.id AS chat_id,
|
||||
chat_models.model_id AS model_id,
|
||||
chat_models.id AS model_id,
|
||||
updated.updated_at
|
||||
FROM updated
|
||||
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id
|
||||
`
|
||||
|
||||
type UpsertChatSettingsParams struct {
|
||||
ModelID pgtype.Text `json:"model_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
}
|
||||
|
||||
type UpsertChatSettingsRow struct {
|
||||
ChatID pgtype.UUID `json:"chat_id"`
|
||||
ModelID pgtype.Text `json:"model_id"`
|
||||
ModelID pgtype.UUID `json:"model_id"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
// chat_settings
|
||||
func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) {
|
||||
row := q.db.QueryRow(ctx, upsertChatSettings, arg.ModelID, arg.ID)
|
||||
row := q.db.QueryRow(ctx, upsertChatSettings, arg.ChatModelID, arg.ID)
|
||||
var i UpsertChatSettingsRow
|
||||
err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt)
|
||||
return i, err
|
||||
|
||||
@@ -419,6 +419,43 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType pgtype.
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listModelsByModelID = `-- name: ListModelsByModelID :many
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
WHERE model_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListModelsByModelID(ctx context.Context, modelID string) ([]Model, error) {
|
||||
rows, err := q.db.Query(ctx, listModelsByModelID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Model
|
||||
for rows.Next() {
|
||||
var i Model
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
WHERE llm_provider_id = $1
|
||||
@@ -580,18 +617,20 @@ func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderPa
|
||||
const updateModel = `-- name: UpdateModel :one
|
||||
UPDATE models
|
||||
SET
|
||||
name = $1,
|
||||
llm_provider_id = $2,
|
||||
client_type = $3,
|
||||
dimensions = $4,
|
||||
input_modalities = $5,
|
||||
type = $6,
|
||||
model_id = $1,
|
||||
name = $2,
|
||||
llm_provider_id = $3,
|
||||
client_type = $4,
|
||||
dimensions = $5,
|
||||
input_modalities = $6,
|
||||
type = $7,
|
||||
updated_at = now()
|
||||
WHERE id = $7
|
||||
WHERE id = $8
|
||||
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateModelParams struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Name pgtype.Text `json:"name"`
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
ClientType pgtype.Text `json:"client_type"`
|
||||
@@ -603,6 +642,7 @@ type UpdateModelParams struct {
|
||||
|
||||
func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) {
|
||||
row := q.db.QueryRow(ctx, updateModel,
|
||||
arg.ModelID,
|
||||
arg.Name,
|
||||
arg.LlmProviderID,
|
||||
arg.ClientType,
|
||||
|
||||
@@ -37,9 +37,9 @@ SELECT
|
||||
bots.max_context_tokens,
|
||||
bots.language,
|
||||
bots.allow_guest,
|
||||
chat_models.model_id AS chat_model_id,
|
||||
memory_models.model_id AS memory_model_id,
|
||||
embedding_models.model_id AS embedding_model_id,
|
||||
chat_models.id AS chat_model_id,
|
||||
memory_models.id AS memory_model_id,
|
||||
embedding_models.id AS embedding_model_id,
|
||||
search_providers.id AS search_provider_id
|
||||
FROM bots
|
||||
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
|
||||
@@ -55,9 +55,9 @@ type GetSettingsByBotIDRow struct {
|
||||
MaxContextTokens int32 `json:"max_context_tokens"`
|
||||
Language string `json:"language"`
|
||||
AllowGuest bool `json:"allow_guest"`
|
||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
||||
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.UUID `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
|
||||
SearchProviderID pgtype.UUID `json:"search_provider_id"`
|
||||
}
|
||||
|
||||
@@ -99,9 +99,9 @@ SELECT
|
||||
updated.max_context_tokens,
|
||||
updated.language,
|
||||
updated.allow_guest,
|
||||
chat_models.model_id AS chat_model_id,
|
||||
memory_models.model_id AS memory_model_id,
|
||||
embedding_models.model_id AS embedding_model_id,
|
||||
chat_models.id AS chat_model_id,
|
||||
memory_models.id AS memory_model_id,
|
||||
embedding_models.id AS embedding_model_id,
|
||||
search_providers.id AS search_provider_id
|
||||
FROM updated
|
||||
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
|
||||
@@ -128,9 +128,9 @@ type UpsertBotSettingsRow struct {
|
||||
MaxContextTokens int32 `json:"max_context_tokens"`
|
||||
Language string `json:"language"`
|
||||
AllowGuest bool `json:"allow_guest"`
|
||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
||||
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.UUID `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
|
||||
SearchProviderID pgtype.UUID `json:"search_provider_id"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
@@ -52,6 +54,9 @@ func (h *ModelsHandler) Create(c echo.Context) error {
|
||||
|
||||
resp, err := h.service.Create(c.Request().Context(), req)
|
||||
if err != nil {
|
||||
if errors.Is(err, models.ErrModelIDAlreadyExists) {
|
||||
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusCreated, resp)
|
||||
@@ -134,6 +139,12 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error {
|
||||
|
||||
resp, err := h.service.GetByModelID(c.Request().Context(), modelID)
|
||||
if err != nil {
|
||||
if errors.Is(err, models.ErrModelIDAmbiguous) {
|
||||
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
|
||||
}
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
@@ -163,6 +174,9 @@ func (h *ModelsHandler) UpdateByID(c echo.Context) error {
|
||||
|
||||
resp, err := h.service.UpdateByID(c.Request().Context(), id, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, models.ErrModelIDAlreadyExists) {
|
||||
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
@@ -197,6 +211,15 @@ func (h *ModelsHandler) UpdateByModelID(c echo.Context) error {
|
||||
|
||||
resp, err := h.service.UpdateByModelID(c.Request().Context(), modelID, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, models.ErrModelIDAlreadyExists) {
|
||||
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
|
||||
}
|
||||
if errors.Is(err, models.ErrModelIDAmbiguous) {
|
||||
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
|
||||
}
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
@@ -246,6 +269,12 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error {
|
||||
}
|
||||
|
||||
if err := h.service.DeleteByModelID(c.Request().Context(), modelID); err != nil {
|
||||
if errors.Is(err, models.ErrModelIDAmbiguous) {
|
||||
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
|
||||
}
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
|
||||
@@ -96,6 +96,12 @@ func (h *SettingsHandler) Upsert(c echo.Context) error {
|
||||
if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access")
|
||||
}
|
||||
if errors.Is(err, settings.ErrInvalidModelRef) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if errors.Is(err, settings.ErrModelIDAmbiguous) {
|
||||
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; select by model UUID")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
|
||||
@@ -2,16 +2,21 @@ package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
var ErrModelIDAlreadyExists = errors.New("model_id already exists")
|
||||
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
|
||||
|
||||
// Service provides CRUD operations for models
|
||||
type Service struct {
|
||||
queries *sqlc.Queries
|
||||
@@ -65,6 +70,9 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
|
||||
created, err := s.queries.CreateModel(ctx, params)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
return AddResponse{}, ErrModelIDAlreadyExists
|
||||
}
|
||||
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
|
||||
}
|
||||
|
||||
@@ -105,7 +113,7 @@ func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse
|
||||
return GetResponse{}, fmt.Errorf("model_id is required")
|
||||
}
|
||||
|
||||
dbModel, err := s.queries.GetModelByModelID(ctx, modelID)
|
||||
dbModel, err := s.findUniqueByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to get model: %w", err)
|
||||
}
|
||||
@@ -207,6 +215,7 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
}
|
||||
params := sqlc.UpdateModelParams{
|
||||
ID: uuid,
|
||||
ModelID: model.ModelID,
|
||||
InputModalities: inputMod,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
@@ -230,6 +239,9 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
|
||||
updated, err := s.queries.UpdateModel(ctx, params)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
return GetResponse{}, ErrModelIDAlreadyExists
|
||||
}
|
||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||
}
|
||||
|
||||
@@ -241,6 +253,10 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
if modelID == "" {
|
||||
return GetResponse{}, fmt.Errorf("model_id is required")
|
||||
}
|
||||
current, err := s.findUniqueByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||
}
|
||||
|
||||
model := Model(req)
|
||||
if err := model.Validate(); err != nil {
|
||||
@@ -251,9 +267,8 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
if model.Type == ModelTypeChat {
|
||||
inputMod = normalizeModalities(model.InputModalities, []string{ModelInputText})
|
||||
}
|
||||
params := sqlc.UpdateModelByModelIDParams{
|
||||
ModelID: modelID,
|
||||
NewModelID: model.ModelID,
|
||||
params := sqlc.UpdateModelParams{
|
||||
ID: current.ID,
|
||||
InputModalities: inputMod,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
@@ -275,8 +290,13 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||
}
|
||||
|
||||
updated, err := s.queries.UpdateModelByModelID(ctx, params)
|
||||
params.ModelID = model.ModelID
|
||||
|
||||
updated, err := s.queries.UpdateModel(ctx, params)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
return GetResponse{}, ErrModelIDAlreadyExists
|
||||
}
|
||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||
}
|
||||
|
||||
@@ -302,8 +322,12 @@ func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error {
|
||||
if modelID == "" {
|
||||
return fmt.Errorf("model_id is required")
|
||||
}
|
||||
current, err := s.findUniqueByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete model: %w", err)
|
||||
}
|
||||
|
||||
if err := s.queries.DeleteModelByModelID(ctx, modelID); err != nil {
|
||||
if err := s.queries.DeleteModel(ctx, current.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete model: %w", err)
|
||||
}
|
||||
|
||||
@@ -336,6 +360,7 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
|
||||
|
||||
func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||
resp := GetResponse{
|
||||
ID: dbModel.ID.String(),
|
||||
ModelID: dbModel.ModelID,
|
||||
Model: Model{
|
||||
ModelID: dbModel.ModelID,
|
||||
@@ -372,6 +397,20 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse {
|
||||
return responses
|
||||
}
|
||||
|
||||
func (s *Service) findUniqueByModelID(ctx context.Context, modelID string) (sqlc.Model, error) {
|
||||
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return sqlc.Model{}, err
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return sqlc.Model{}, pgx.ErrNoRows
|
||||
}
|
||||
if len(rows) > 1 {
|
||||
return sqlc.Model{}, ErrModelIDAmbiguous
|
||||
}
|
||||
return rows[0], nil
|
||||
}
|
||||
|
||||
// normalizeModalities returns modalities if non-empty, otherwise the provided fallback.
|
||||
func normalizeModalities(modalities []string, fallback []string) []string {
|
||||
if len(modalities) == 0 {
|
||||
|
||||
@@ -113,6 +113,7 @@ type GetRequest struct {
|
||||
}
|
||||
|
||||
type GetResponse struct {
|
||||
ID string `json:"id"`
|
||||
ModelID string `json:"model_id"`
|
||||
Model
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
@@ -20,6 +21,8 @@ type Service struct {
|
||||
}
|
||||
|
||||
var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access")
|
||||
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
|
||||
var ErrInvalidModelRef = errors.New("invalid model reference")
|
||||
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||
return &Service{
|
||||
@@ -184,15 +187,21 @@ func normalizeBotSettingsFields(
|
||||
maxContextTokens int32,
|
||||
language string,
|
||||
allowGuest bool,
|
||||
chatModelID pgtype.Text,
|
||||
memoryModelID pgtype.Text,
|
||||
embeddingModelID pgtype.Text,
|
||||
chatModelID pgtype.UUID,
|
||||
memoryModelID pgtype.UUID,
|
||||
embeddingModelID pgtype.UUID,
|
||||
searchProviderID pgtype.UUID,
|
||||
) Settings {
|
||||
settings := normalizeBotSetting(maxContextLoadTime, maxContextTokens, language, allowGuest)
|
||||
settings.ChatModelID = strings.TrimSpace(chatModelID.String)
|
||||
settings.MemoryModelID = strings.TrimSpace(memoryModelID.String)
|
||||
settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String)
|
||||
if chatModelID.Valid {
|
||||
settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String()
|
||||
}
|
||||
if memoryModelID.Valid {
|
||||
settings.MemoryModelID = uuid.UUID(memoryModelID.Bytes).String()
|
||||
}
|
||||
if embeddingModelID.Valid {
|
||||
settings.EmbeddingModelID = uuid.UUID(embeddingModelID.Bytes).String()
|
||||
}
|
||||
if searchProviderID.Valid {
|
||||
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
|
||||
}
|
||||
@@ -200,12 +209,29 @@ func normalizeBotSettingsFields(
|
||||
}
|
||||
|
||||
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("model_id is required")
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef)
|
||||
}
|
||||
row, err := s.queries.GetModelByModelID(ctx, modelID)
|
||||
|
||||
// Preferred path: when caller already passes the model UUID.
|
||||
if parsed, err := db.ParseUUID(modelID); err == nil {
|
||||
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
|
||||
return parsed, nil
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return pgtype.UUID{}, err
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, err
|
||||
}
|
||||
return row.ID, nil
|
||||
if len(rows) == 0 {
|
||||
return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID)
|
||||
}
|
||||
if len(rows) > 1 {
|
||||
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID)
|
||||
}
|
||||
return rows[0].ID, nil
|
||||
}
|
||||
|
||||
@@ -718,6 +718,7 @@ export type ModelsCountResponse = {
|
||||
export type ModelsGetResponse = {
|
||||
client_type?: ModelsClientType;
|
||||
dimensions?: number;
|
||||
id?: string;
|
||||
input_modalities?: Array<string>;
|
||||
llm_provider_id?: string;
|
||||
model_id?: string;
|
||||
|
||||
@@ -229,12 +229,15 @@ import { inject, computed, watch, nextTick, type Ref, ref } from 'vue'
|
||||
import { toTypedSchema } from '@vee-validate/zod'
|
||||
import z from 'zod'
|
||||
import { useMutation, useQueryCache } from '@pinia/colada'
|
||||
import { postModels, putModelsModelByModelId } from '@memoh/sdk'
|
||||
import { postModels, putModelsById, putModelsModelByModelId } from '@memoh/sdk'
|
||||
import type { ModelsGetResponse } from '@memoh/sdk'
|
||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { toast } from 'vue-sonner'
|
||||
|
||||
const availableInputModalities = ['text', 'image', 'audio', 'video', 'file'] as const
|
||||
const selectedModalities = ref<string[]>(['text'])
|
||||
const { t } = useI18n()
|
||||
|
||||
const formSchema = toTypedSchema(z.object({
|
||||
type: z.string().min(1),
|
||||
@@ -313,6 +316,17 @@ const { mutateAsync: createModel, isLoading: createLoading } = useMutation({
|
||||
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||
})
|
||||
const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({
|
||||
mutation: async ({ id, data }: { id: string; data: Record<string, unknown> }) => {
|
||||
const { data: result } = await putModelsById({
|
||||
path: { id },
|
||||
body: data as any,
|
||||
throwOnError: true,
|
||||
})
|
||||
return result
|
||||
},
|
||||
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||
})
|
||||
const { mutateAsync: updateModelByLegacyModelID, isLoading: updateLegacyLoading } = useMutation({
|
||||
mutation: async ({ modelId, data }: { modelId: string; data: Record<string, unknown> }) => {
|
||||
const { data: result } = await putModelsModelByModelId({
|
||||
path: { modelId },
|
||||
@@ -323,7 +337,7 @@ const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({
|
||||
},
|
||||
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||
})
|
||||
const isLoading = computed(() => createLoading.value || updateLoading.value)
|
||||
const isLoading = computed(() => createLoading.value || updateLoading.value || updateLegacyLoading.value)
|
||||
|
||||
async function addModel(e: Event) {
|
||||
e.preventDefault()
|
||||
@@ -366,16 +380,31 @@ async function addModel(e: Event) {
|
||||
}
|
||||
|
||||
if (isEdit) {
|
||||
await updateModel({ modelId: fallback!.model_id, data: payload as any })
|
||||
const modelUUID = fallback?.id
|
||||
if (modelUUID) {
|
||||
await updateModel({ id: modelUUID, data: payload as any })
|
||||
} else {
|
||||
await updateModelByLegacyModelID({ modelId: fallback!.model_id, data: payload as any })
|
||||
}
|
||||
} else {
|
||||
await createModel(payload as any)
|
||||
}
|
||||
open.value = false
|
||||
} catch {
|
||||
} catch (error) {
|
||||
toast.error(resolveErrorMessage(error, t('common.saveFailed')))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
function resolveErrorMessage(error: unknown, fallback: string): string {
|
||||
if (error instanceof Error && error.message.trim()) return error.message
|
||||
if (error && typeof error === 'object' && 'message' in error) {
|
||||
const msg = (error as { message?: string }).message
|
||||
if (msg && msg.trim()) return msg
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
watch(open, async () => {
|
||||
if (!open.value) {
|
||||
title.value = 'title'
|
||||
|
||||
@@ -52,13 +52,13 @@
|
||||
</div>
|
||||
<button
|
||||
v-for="model in group.models"
|
||||
:key="model.model_id"
|
||||
:key="model.id || `${model.llm_provider_id}:${model.model_id}`"
|
||||
class="relative flex w-full cursor-pointer items-center gap-2 rounded-md px-2 py-1.5 text-sm outline-none hover:bg-accent hover:text-accent-foreground"
|
||||
:class="{ 'bg-accent': selected === model.model_id }"
|
||||
@click="selectModel(model.model_id)"
|
||||
:class="{ 'bg-accent': selected === model.id }"
|
||||
@click="selectModel(model.id)"
|
||||
>
|
||||
<FontAwesomeIcon
|
||||
v-if="selected === model.model_id"
|
||||
v-if="selected === model.id"
|
||||
:icon="['fas', 'check']"
|
||||
class="size-3.5"
|
||||
/>
|
||||
@@ -145,11 +145,12 @@ const filteredGroups = computed(() => {
|
||||
// 显示选中模型的名称
|
||||
const displayLabel = computed(() => {
|
||||
if (!selected.value) return ''
|
||||
const model = typeFilteredModels.value.find((m) => m.model_id === selected.value)
|
||||
const model = typeFilteredModels.value.find((m) => m.id === selected.value)
|
||||
return model?.name || model?.model_id || selected.value
|
||||
})
|
||||
|
||||
function selectModel(modelId: string) {
|
||||
function selectModel(modelId?: string) {
|
||||
if (!modelId) return
|
||||
selected.value = modelId
|
||||
open.value = false
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<template>
|
||||
<Item variant="outline">
|
||||
<ItemContent>
|
||||
<ItemTitle>{{ model.name }}</ItemTitle>
|
||||
<ItemTitle>{{ model.name || model.model_id }}</ItemTitle>
|
||||
<ItemDescription class="gap-2 flex flex-wrap items-center mt-3">
|
||||
<Badge variant="outline">
|
||||
{{ model.type }}
|
||||
@@ -26,7 +26,7 @@
|
||||
<ConfirmPopover
|
||||
:message="$t('models.deleteModelConfirm')"
|
||||
:loading="deleteLoading"
|
||||
@confirm="$emit('delete', model.name)"
|
||||
@confirm="$emit('delete', model.id ?? '')"
|
||||
>
|
||||
<template #trigger>
|
||||
<Button variant="outline">
|
||||
@@ -58,6 +58,6 @@ defineProps<{
|
||||
|
||||
defineEmits<{
|
||||
edit: [model: ModelsGetResponse]
|
||||
delete: [name: string]
|
||||
delete: [id: string]
|
||||
}>()
|
||||
</script>
|
||||
|
||||
@@ -16,11 +16,11 @@
|
||||
>
|
||||
<ModelItem
|
||||
v-for="model in models"
|
||||
:key="model.model_id"
|
||||
:key="model.id || `${model.llm_provider_id}:${model.model_id}`"
|
||||
:model="model"
|
||||
:delete-loading="deleteModelLoading"
|
||||
@edit="(model) => $emit('edit', model)"
|
||||
@delete="(name) => $emit('delete', name)"
|
||||
@delete="(id) => $emit('delete', id)"
|
||||
/>
|
||||
</section>
|
||||
|
||||
@@ -61,6 +61,6 @@ defineProps<{
|
||||
|
||||
defineEmits<{
|
||||
edit: [model: ModelsGetResponse]
|
||||
delete: [name: string]
|
||||
delete: [id: string]
|
||||
}>()
|
||||
</script>
|
||||
|
||||
@@ -33,7 +33,7 @@ import ProviderForm from './components/provider-form.vue'
|
||||
import ModelList from './components/model-list.vue'
|
||||
import { computed, inject, provide, reactive, ref, toRef, watch } from 'vue'
|
||||
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
|
||||
import { putProvidersById, deleteProvidersById, getProvidersByIdModels, deleteModelsModelByModelId } from '@memoh/sdk'
|
||||
import { putProvidersById, deleteProvidersById, getProvidersByIdModels, deleteModelsById } from '@memoh/sdk'
|
||||
import type { ModelsGetResponse, ProvidersGetResponse } from '@memoh/sdk'
|
||||
|
||||
// ---- Model 编辑状态(provide 给 CreateModel) ----
|
||||
@@ -86,8 +86,9 @@ const { mutate: changeProvider, isLoading: editLoading } = useMutation({
|
||||
})
|
||||
|
||||
const { mutate: deleteModel, isLoading: deleteModelLoading } = useMutation({
|
||||
mutation: async (modelName: string) => {
|
||||
await deleteModelsModelByModelId({ path: { modelId: modelName }, throwOnError: true })
|
||||
mutation: async (modelID: string) => {
|
||||
if (!modelID) return
|
||||
await deleteModelsById({ path: { id: modelID }, throwOnError: true })
|
||||
},
|
||||
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||
})
|
||||
|
||||
@@ -6818,6 +6818,9 @@ const docTemplate = `{
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"input_modalities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
||||
@@ -6809,6 +6809,9 @@
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"input_modalities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
||||
@@ -1185,6 +1185,8 @@ definitions:
|
||||
$ref: '#/definitions/models.ClientType'
|
||||
dimensions:
|
||||
type: integer
|
||||
id:
|
||||
type: string
|
||||
input_modalities:
|
||||
items:
|
||||
type: string
|
||||
|
||||
Reference in New Issue
Block a user