mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(models,settings,conversation): scope model_id uniqueness per
provider and harden model reference resolution
This commit is contained in:
@@ -89,7 +89,7 @@ CREATE TABLE IF NOT EXISTS models (
|
|||||||
type TEXT NOT NULL DEFAULT 'chat',
|
type TEXT NOT NULL DEFAULT 'chat',
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
CONSTRAINT models_model_id_unique UNIQUE (model_id),
|
CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id),
|
||||||
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')),
|
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')),
|
||||||
CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL),
|
CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL),
|
||||||
CONSTRAINT models_client_type_check CHECK (client_type IS NULL OR client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai')),
|
CONSTRAINT models_client_type_check CHECK (client_type IS NULL OR client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai')),
|
||||||
@@ -389,4 +389,3 @@ CREATE TABLE IF NOT EXISTS bot_history_message_assets (
|
|||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_assets_message_id ON bot_history_message_assets(message_id);
|
CREATE INDEX IF NOT EXISTS idx_message_assets_message_id ON bot_history_message_assets(message_id);
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
-- 0011_model_id_unique_per_provider
|
||||||
|
-- Revert model_id uniqueness back to global uniqueness.
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM models
|
||||||
|
GROUP BY model_id
|
||||||
|
HAVING COUNT(*) > 1
|
||||||
|
) THEN
|
||||||
|
RAISE EXCEPTION 'cannot rollback 0011_model_id_unique_per_provider: duplicate model_id values exist across providers';
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN
|
||||||
|
ALTER TABLE models DROP CONSTRAINT models_provider_model_id_unique;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN
|
||||||
|
ALTER TABLE models
|
||||||
|
ADD CONSTRAINT models_model_id_unique UNIQUE (model_id);
|
||||||
|
END IF;
|
||||||
|
END
|
||||||
|
$$;
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
-- 0011_model_id_unique_per_provider
|
||||||
|
-- Change model_id uniqueness from global to per provider.
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_model_id_unique') THEN
|
||||||
|
ALTER TABLE models DROP CONSTRAINT models_model_id_unique;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'models_provider_model_id_unique') THEN
|
||||||
|
ALTER TABLE models
|
||||||
|
ADD CONSTRAINT models_provider_model_id_unique UNIQUE (llm_provider_id, model_id);
|
||||||
|
END IF;
|
||||||
|
END
|
||||||
|
$$;
|
||||||
@@ -205,22 +205,17 @@ ON CONFLICT (bot_id, user_id) DO NOTHING;
|
|||||||
-- chat_settings
|
-- chat_settings
|
||||||
|
|
||||||
-- name: UpsertChatSettings :one
|
-- name: UpsertChatSettings :one
|
||||||
WITH resolved_model AS (
|
WITH
|
||||||
SELECT id
|
|
||||||
FROM models
|
|
||||||
WHERE model_id = NULLIF(sqlc.narg(model_id)::text, '')
|
|
||||||
LIMIT 1
|
|
||||||
),
|
|
||||||
updated AS (
|
updated AS (
|
||||||
UPDATE bots
|
UPDATE bots
|
||||||
SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id),
|
SET chat_model_id = COALESCE(sqlc.narg(chat_model_id)::uuid, bots.chat_model_id),
|
||||||
updated_at = now()
|
updated_at = now()
|
||||||
WHERE bots.id = sqlc.arg(id)
|
WHERE bots.id = sqlc.arg(id)
|
||||||
RETURNING bots.id, bots.chat_model_id, bots.updated_at
|
RETURNING bots.id, bots.chat_model_id, bots.updated_at
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
updated.id AS chat_id,
|
updated.id AS chat_id,
|
||||||
chat_models.model_id AS model_id,
|
chat_models.id AS model_id,
|
||||||
updated.updated_at
|
updated.updated_at
|
||||||
FROM updated
|
FROM updated
|
||||||
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
|
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
|
||||||
@@ -228,7 +223,7 @@ LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
|
|||||||
-- name: GetChatSettings :one
|
-- name: GetChatSettings :one
|
||||||
SELECT
|
SELECT
|
||||||
b.id AS chat_id,
|
b.id AS chat_id,
|
||||||
chat_models.model_id AS model_id,
|
chat_models.id AS model_id,
|
||||||
b.updated_at
|
b.updated_at
|
||||||
FROM bots b
|
FROM bots b
|
||||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||||
|
|||||||
@@ -54,6 +54,11 @@ SELECT * FROM models WHERE id = sqlc.arg(id);
|
|||||||
-- name: GetModelByModelID :one
|
-- name: GetModelByModelID :one
|
||||||
SELECT * FROM models WHERE model_id = sqlc.arg(model_id);
|
SELECT * FROM models WHERE model_id = sqlc.arg(model_id);
|
||||||
|
|
||||||
|
-- name: ListModelsByModelID :many
|
||||||
|
SELECT * FROM models
|
||||||
|
WHERE model_id = sqlc.arg(model_id)
|
||||||
|
ORDER BY created_at DESC;
|
||||||
|
|
||||||
-- name: ListModels :many
|
-- name: ListModels :many
|
||||||
SELECT * FROM models
|
SELECT * FROM models
|
||||||
ORDER BY created_at DESC;
|
ORDER BY created_at DESC;
|
||||||
@@ -82,6 +87,7 @@ ORDER BY created_at DESC;
|
|||||||
-- name: UpdateModel :one
|
-- name: UpdateModel :one
|
||||||
UPDATE models
|
UPDATE models
|
||||||
SET
|
SET
|
||||||
|
model_id = sqlc.arg(model_id),
|
||||||
name = sqlc.arg(name),
|
name = sqlc.arg(name),
|
||||||
llm_provider_id = sqlc.arg(llm_provider_id),
|
llm_provider_id = sqlc.arg(llm_provider_id),
|
||||||
client_type = sqlc.narg(client_type),
|
client_type = sqlc.narg(client_type),
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ SELECT
|
|||||||
bots.max_context_tokens,
|
bots.max_context_tokens,
|
||||||
bots.language,
|
bots.language,
|
||||||
bots.allow_guest,
|
bots.allow_guest,
|
||||||
chat_models.model_id AS chat_model_id,
|
chat_models.id AS chat_model_id,
|
||||||
memory_models.model_id AS memory_model_id,
|
memory_models.id AS memory_model_id,
|
||||||
embedding_models.model_id AS embedding_model_id,
|
embedding_models.id AS embedding_model_id,
|
||||||
search_providers.id AS search_provider_id
|
search_providers.id AS search_provider_id
|
||||||
FROM bots
|
FROM bots
|
||||||
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
|
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
|
||||||
@@ -37,9 +37,9 @@ SELECT
|
|||||||
updated.max_context_tokens,
|
updated.max_context_tokens,
|
||||||
updated.language,
|
updated.language,
|
||||||
updated.allow_guest,
|
updated.allow_guest,
|
||||||
chat_models.model_id AS chat_model_id,
|
chat_models.id AS chat_model_id,
|
||||||
memory_models.model_id AS memory_model_id,
|
memory_models.id AS memory_model_id,
|
||||||
embedding_models.model_id AS embedding_model_id,
|
embedding_models.id AS embedding_model_id,
|
||||||
search_providers.id AS search_provider_id
|
search_providers.id AS search_provider_id
|
||||||
FROM updated
|
FROM updated
|
||||||
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
|
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
attachmentpkg "github.com/memohai/memoh/internal/attachment"
|
attachmentpkg "github.com/memohai/memoh/internal/attachment"
|
||||||
@@ -1535,10 +1536,30 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
|
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||||
model, err := r.modelsService.GetByModelID(ctx, modelID)
|
modelRef := strings.TrimSpace(modelID)
|
||||||
|
if modelRef == "" {
|
||||||
|
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support both model UUID and model_id slug. UUID-formatted slugs still
|
||||||
|
// work because we fall back to GetByModelID when UUID lookup misses.
|
||||||
|
var model models.GetResponse
|
||||||
|
var err error
|
||||||
|
if _, parseErr := db.ParseUUID(modelRef); parseErr == nil {
|
||||||
|
model, err = r.modelsService.GetByID(ctx, modelRef)
|
||||||
|
if err == nil {
|
||||||
|
goto resolved
|
||||||
|
}
|
||||||
|
if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model, err = r.modelsService.GetByModelID(ctx, modelRef)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolved:
|
||||||
if model.Type != models.ModelTypeChat {
|
if model.Type != models.ModelTypeChat {
|
||||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
|
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
@@ -20,6 +21,7 @@ var (
|
|||||||
ErrChatNotFound = errors.New("chat not found")
|
ErrChatNotFound = errors.New("chat not found")
|
||||||
ErrNotParticipant = errors.New("not a participant")
|
ErrNotParticipant = errors.New("not a participant")
|
||||||
ErrPermissionDenied = errors.New("permission denied")
|
ErrPermissionDenied = errors.New("permission denied")
|
||||||
|
ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Service manages conversation lifecycle, participants, and settings.
|
// Service manages conversation lifecycle, participants, and settings.
|
||||||
@@ -310,21 +312,26 @@ func (s *Service) GetSettings(ctx context.Context, conversationID string) (Setti
|
|||||||
|
|
||||||
// UpdateSettings updates conversation settings.
|
// UpdateSettings updates conversation settings.
|
||||||
func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req UpdateSettingsRequest) (Settings, error) {
|
func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req UpdateSettingsRequest) (Settings, error) {
|
||||||
current, err := s.GetSettings(ctx, conversationID)
|
|
||||||
if err != nil {
|
|
||||||
return Settings{}, err
|
|
||||||
}
|
|
||||||
if req.ModelID != nil {
|
|
||||||
current.ModelID = *req.ModelID
|
|
||||||
}
|
|
||||||
|
|
||||||
pgID, err := parseUUID(conversationID)
|
pgID, err := parseUUID(conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Settings{}, err
|
return Settings{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chatModelUUID := pgtype.UUID{}
|
||||||
|
if req.ModelID != nil {
|
||||||
|
modelRef := strings.TrimSpace(*req.ModelID)
|
||||||
|
if modelRef != "" {
|
||||||
|
resolved, err := s.resolveModelUUID(ctx, modelRef)
|
||||||
|
if err != nil {
|
||||||
|
return Settings{}, err
|
||||||
|
}
|
||||||
|
chatModelUUID = resolved
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{
|
row, err := s.queries.UpsertChatSettings(ctx, sqlc.UpsertChatSettingsParams{
|
||||||
ID: pgID,
|
ID: pgID,
|
||||||
ModelID: toPgText(current.ModelID),
|
ChatModelID: chatModelUUID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Settings{}, err
|
return Settings{}, err
|
||||||
@@ -427,17 +434,23 @@ func toParticipantFields(conversationID, userID pgtype.UUID, role string, joined
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings {
|
func toSettingsFromRead(row sqlc.GetChatSettingsRow) Settings {
|
||||||
return Settings{
|
settings := Settings{
|
||||||
ChatID: row.ChatID.String(),
|
ChatID: row.ChatID.String(),
|
||||||
ModelID: dbpkg.TextToString(row.ModelID),
|
|
||||||
}
|
}
|
||||||
|
if row.ModelID.Valid {
|
||||||
|
settings.ModelID = uuid.UUID(row.ModelID.Bytes).String()
|
||||||
|
}
|
||||||
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings {
|
func toSettingsFromUpsert(row sqlc.UpsertChatSettingsRow) Settings {
|
||||||
return Settings{
|
settings := Settings{
|
||||||
ChatID: row.ChatID.String(),
|
ChatID: row.ChatID.String(),
|
||||||
ModelID: dbpkg.TextToString(row.ModelID),
|
|
||||||
}
|
}
|
||||||
|
if row.ModelID.Valid {
|
||||||
|
settings.ModelID = uuid.UUID(row.ModelID.Bytes).String()
|
||||||
|
}
|
||||||
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultSettings(conversationID string) Settings {
|
func defaultSettings(conversationID string) Settings {
|
||||||
@@ -450,12 +463,32 @@ func parseUUID(id string) (pgtype.UUID, error) {
|
|||||||
return dbpkg.ParseUUID(id)
|
return dbpkg.ParseUUID(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPgText(s string) pgtype.Text {
|
func (s *Service) resolveModelUUID(ctx context.Context, modelRef string) (pgtype.UUID, error) {
|
||||||
s = strings.TrimSpace(s)
|
modelRef = strings.TrimSpace(modelRef)
|
||||||
if s == "" {
|
if modelRef == "" {
|
||||||
return pgtype.Text{}
|
return pgtype.UUID{}, fmt.Errorf("model_id is required")
|
||||||
}
|
}
|
||||||
return pgtype.Text{String: s, Valid: true}
|
|
||||||
|
// Prefer UUID path; if not found, fall back to model_id slug.
|
||||||
|
if parsed, err := dbpkg.ParseUUID(modelRef); err == nil {
|
||||||
|
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
|
||||||
|
return parsed, nil
|
||||||
|
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return pgtype.UUID{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.queries.ListModelsByModelID(ctx, modelRef)
|
||||||
|
if err != nil {
|
||||||
|
return pgtype.UUID{}, err
|
||||||
|
}
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return pgtype.UUID{}, fmt.Errorf("model not found: %s", modelRef)
|
||||||
|
}
|
||||||
|
if len(rows) > 1 {
|
||||||
|
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelRef)
|
||||||
|
}
|
||||||
|
return rows[0].ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func pgTimePtr(ts pgtype.Timestamptz) *time.Time {
|
func pgTimePtr(ts pgtype.Timestamptz) *time.Time {
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ func (q *Queries) GetChatReadAccessByUser(ctx context.Context, arg GetChatReadAc
|
|||||||
const getChatSettings = `-- name: GetChatSettings :one
|
const getChatSettings = `-- name: GetChatSettings :one
|
||||||
SELECT
|
SELECT
|
||||||
b.id AS chat_id,
|
b.id AS chat_id,
|
||||||
chat_models.model_id AS model_id,
|
chat_models.id AS model_id,
|
||||||
b.updated_at
|
b.updated_at
|
||||||
FROM bots b
|
FROM bots b
|
||||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||||
@@ -280,7 +280,7 @@ WHERE b.id = $1
|
|||||||
|
|
||||||
type GetChatSettingsRow struct {
|
type GetChatSettingsRow struct {
|
||||||
ChatID pgtype.UUID `json:"chat_id"`
|
ChatID pgtype.UUID `json:"chat_id"`
|
||||||
ModelID pgtype.Text `json:"model_id"`
|
ModelID pgtype.UUID `json:"model_id"`
|
||||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -645,41 +645,36 @@ func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams
|
|||||||
|
|
||||||
const upsertChatSettings = `-- name: UpsertChatSettings :one
|
const upsertChatSettings = `-- name: UpsertChatSettings :one
|
||||||
|
|
||||||
WITH resolved_model AS (
|
WITH
|
||||||
SELECT id
|
|
||||||
FROM models
|
|
||||||
WHERE model_id = NULLIF($1::text, '')
|
|
||||||
LIMIT 1
|
|
||||||
),
|
|
||||||
updated AS (
|
updated AS (
|
||||||
UPDATE bots
|
UPDATE bots
|
||||||
SET chat_model_id = COALESCE((SELECT id FROM resolved_model), bots.chat_model_id),
|
SET chat_model_id = COALESCE($1::uuid, bots.chat_model_id),
|
||||||
updated_at = now()
|
updated_at = now()
|
||||||
WHERE bots.id = $2
|
WHERE bots.id = $2
|
||||||
RETURNING bots.id, bots.chat_model_id, bots.updated_at
|
RETURNING bots.id, bots.chat_model_id, bots.updated_at
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
updated.id AS chat_id,
|
updated.id AS chat_id,
|
||||||
chat_models.model_id AS model_id,
|
chat_models.id AS model_id,
|
||||||
updated.updated_at
|
updated.updated_at
|
||||||
FROM updated
|
FROM updated
|
||||||
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id
|
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id
|
||||||
`
|
`
|
||||||
|
|
||||||
type UpsertChatSettingsParams struct {
|
type UpsertChatSettingsParams struct {
|
||||||
ModelID pgtype.Text `json:"model_id"`
|
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||||
ID pgtype.UUID `json:"id"`
|
ID pgtype.UUID `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpsertChatSettingsRow struct {
|
type UpsertChatSettingsRow struct {
|
||||||
ChatID pgtype.UUID `json:"chat_id"`
|
ChatID pgtype.UUID `json:"chat_id"`
|
||||||
ModelID pgtype.Text `json:"model_id"`
|
ModelID pgtype.UUID `json:"model_id"`
|
||||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// chat_settings
|
// chat_settings
|
||||||
func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) {
|
func (q *Queries) UpsertChatSettings(ctx context.Context, arg UpsertChatSettingsParams) (UpsertChatSettingsRow, error) {
|
||||||
row := q.db.QueryRow(ctx, upsertChatSettings, arg.ModelID, arg.ID)
|
row := q.db.QueryRow(ctx, upsertChatSettings, arg.ChatModelID, arg.ID)
|
||||||
var i UpsertChatSettingsRow
|
var i UpsertChatSettingsRow
|
||||||
err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt)
|
err := row.Scan(&i.ChatID, &i.ModelID, &i.UpdatedAt)
|
||||||
return i, err
|
return i, err
|
||||||
|
|||||||
@@ -419,6 +419,43 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType pgtype.
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const listModelsByModelID = `-- name: ListModelsByModelID :many
|
||||||
|
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||||
|
WHERE model_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) ListModelsByModelID(ctx context.Context, modelID string) ([]Model, error) {
|
||||||
|
rows, err := q.db.Query(ctx, listModelsByModelID, modelID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []Model
|
||||||
|
for rows.Next() {
|
||||||
|
var i Model
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.ModelID,
|
||||||
|
&i.Name,
|
||||||
|
&i.LlmProviderID,
|
||||||
|
&i.ClientType,
|
||||||
|
&i.Dimensions,
|
||||||
|
&i.InputModalities,
|
||||||
|
&i.Type,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
||||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||||
WHERE llm_provider_id = $1
|
WHERE llm_provider_id = $1
|
||||||
@@ -580,18 +617,20 @@ func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderPa
|
|||||||
const updateModel = `-- name: UpdateModel :one
|
const updateModel = `-- name: UpdateModel :one
|
||||||
UPDATE models
|
UPDATE models
|
||||||
SET
|
SET
|
||||||
name = $1,
|
model_id = $1,
|
||||||
llm_provider_id = $2,
|
name = $2,
|
||||||
client_type = $3,
|
llm_provider_id = $3,
|
||||||
dimensions = $4,
|
client_type = $4,
|
||||||
input_modalities = $5,
|
dimensions = $5,
|
||||||
type = $6,
|
input_modalities = $6,
|
||||||
|
type = $7,
|
||||||
updated_at = now()
|
updated_at = now()
|
||||||
WHERE id = $7
|
WHERE id = $8
|
||||||
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
|
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type UpdateModelParams struct {
|
type UpdateModelParams struct {
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
Name pgtype.Text `json:"name"`
|
Name pgtype.Text `json:"name"`
|
||||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||||
ClientType pgtype.Text `json:"client_type"`
|
ClientType pgtype.Text `json:"client_type"`
|
||||||
@@ -603,6 +642,7 @@ type UpdateModelParams struct {
|
|||||||
|
|
||||||
func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) {
|
func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) {
|
||||||
row := q.db.QueryRow(ctx, updateModel,
|
row := q.db.QueryRow(ctx, updateModel,
|
||||||
|
arg.ModelID,
|
||||||
arg.Name,
|
arg.Name,
|
||||||
arg.LlmProviderID,
|
arg.LlmProviderID,
|
||||||
arg.ClientType,
|
arg.ClientType,
|
||||||
|
|||||||
@@ -37,9 +37,9 @@ SELECT
|
|||||||
bots.max_context_tokens,
|
bots.max_context_tokens,
|
||||||
bots.language,
|
bots.language,
|
||||||
bots.allow_guest,
|
bots.allow_guest,
|
||||||
chat_models.model_id AS chat_model_id,
|
chat_models.id AS chat_model_id,
|
||||||
memory_models.model_id AS memory_model_id,
|
memory_models.id AS memory_model_id,
|
||||||
embedding_models.model_id AS embedding_model_id,
|
embedding_models.id AS embedding_model_id,
|
||||||
search_providers.id AS search_provider_id
|
search_providers.id AS search_provider_id
|
||||||
FROM bots
|
FROM bots
|
||||||
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
|
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
|
||||||
@@ -55,9 +55,9 @@ type GetSettingsByBotIDRow struct {
|
|||||||
MaxContextTokens int32 `json:"max_context_tokens"`
|
MaxContextTokens int32 `json:"max_context_tokens"`
|
||||||
Language string `json:"language"`
|
Language string `json:"language"`
|
||||||
AllowGuest bool `json:"allow_guest"`
|
AllowGuest bool `json:"allow_guest"`
|
||||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
MemoryModelID pgtype.UUID `json:"memory_model_id"`
|
||||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
|
||||||
SearchProviderID pgtype.UUID `json:"search_provider_id"`
|
SearchProviderID pgtype.UUID `json:"search_provider_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,9 +99,9 @@ SELECT
|
|||||||
updated.max_context_tokens,
|
updated.max_context_tokens,
|
||||||
updated.language,
|
updated.language,
|
||||||
updated.allow_guest,
|
updated.allow_guest,
|
||||||
chat_models.model_id AS chat_model_id,
|
chat_models.id AS chat_model_id,
|
||||||
memory_models.model_id AS memory_model_id,
|
memory_models.id AS memory_model_id,
|
||||||
embedding_models.model_id AS embedding_model_id,
|
embedding_models.id AS embedding_model_id,
|
||||||
search_providers.id AS search_provider_id
|
search_providers.id AS search_provider_id
|
||||||
FROM updated
|
FROM updated
|
||||||
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
|
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
|
||||||
@@ -128,9 +128,9 @@ type UpsertBotSettingsRow struct {
|
|||||||
MaxContextTokens int32 `json:"max_context_tokens"`
|
MaxContextTokens int32 `json:"max_context_tokens"`
|
||||||
Language string `json:"language"`
|
Language string `json:"language"`
|
||||||
AllowGuest bool `json:"allow_guest"`
|
AllowGuest bool `json:"allow_guest"`
|
||||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
ChatModelID pgtype.UUID `json:"chat_model_id"`
|
||||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
MemoryModelID pgtype.UUID `json:"memory_model_id"`
|
||||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
EmbeddingModelID pgtype.UUID `json:"embedding_model_id"`
|
||||||
SearchProviderID pgtype.UUID `json:"search_provider_id"`
|
SearchProviderID pgtype.UUID `json:"search_provider_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/models"
|
"github.com/memohai/memoh/internal/models"
|
||||||
@@ -52,6 +54,9 @@ func (h *ModelsHandler) Create(c echo.Context) error {
|
|||||||
|
|
||||||
resp, err := h.service.Create(c.Request().Context(), req)
|
resp, err := h.service.Create(c.Request().Context(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, models.ErrModelIDAlreadyExists) {
|
||||||
|
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusCreated, resp)
|
return c.JSON(http.StatusCreated, resp)
|
||||||
@@ -134,6 +139,12 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error {
|
|||||||
|
|
||||||
resp, err := h.service.GetByModelID(c.Request().Context(), modelID)
|
resp, err := h.service.GetByModelID(c.Request().Context(), modelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, models.ErrModelIDAmbiguous) {
|
||||||
|
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
|
||||||
|
}
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, resp)
|
return c.JSON(http.StatusOK, resp)
|
||||||
@@ -163,6 +174,9 @@ func (h *ModelsHandler) UpdateByID(c echo.Context) error {
|
|||||||
|
|
||||||
resp, err := h.service.UpdateByID(c.Request().Context(), id, req)
|
resp, err := h.service.UpdateByID(c.Request().Context(), id, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, models.ErrModelIDAlreadyExists) {
|
||||||
|
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, resp)
|
return c.JSON(http.StatusOK, resp)
|
||||||
@@ -197,6 +211,15 @@ func (h *ModelsHandler) UpdateByModelID(c echo.Context) error {
|
|||||||
|
|
||||||
resp, err := h.service.UpdateByModelID(c.Request().Context(), modelID, req)
|
resp, err := h.service.UpdateByModelID(c.Request().Context(), modelID, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, models.ErrModelIDAlreadyExists) {
|
||||||
|
return echo.NewHTTPError(http.StatusConflict, "model_id already exists under the selected provider")
|
||||||
|
}
|
||||||
|
if errors.Is(err, models.ErrModelIDAmbiguous) {
|
||||||
|
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
|
||||||
|
}
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, resp)
|
return c.JSON(http.StatusOK, resp)
|
||||||
@@ -246,6 +269,12 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.service.DeleteByModelID(c.Request().Context(), modelID); err != nil {
|
if err := h.service.DeleteByModelID(c.Request().Context(), modelID); err != nil {
|
||||||
|
if errors.Is(err, models.ErrModelIDAmbiguous) {
|
||||||
|
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; use /models/{id} instead")
|
||||||
|
}
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
return c.NoContent(http.StatusNoContent)
|
return c.NoContent(http.StatusNoContent)
|
||||||
|
|||||||
@@ -96,6 +96,12 @@ func (h *SettingsHandler) Upsert(c echo.Context) error {
|
|||||||
if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) {
|
if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access")
|
return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access")
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, settings.ErrInvalidModelRef) {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if errors.Is(err, settings.ErrModelIDAmbiguous) {
|
||||||
|
return echo.NewHTTPError(http.StatusConflict, "model_id is duplicated across providers; select by model UUID")
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, resp)
|
return c.JSON(http.StatusOK, resp)
|
||||||
|
|||||||
@@ -2,16 +2,21 @@ package models
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/memohai/memoh/internal/db"
|
"github.com/memohai/memoh/internal/db"
|
||||||
"github.com/memohai/memoh/internal/db/sqlc"
|
"github.com/memohai/memoh/internal/db/sqlc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrModelIDAlreadyExists = errors.New("model_id already exists")
|
||||||
|
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
|
||||||
|
|
||||||
// Service provides CRUD operations for models
|
// Service provides CRUD operations for models
|
||||||
type Service struct {
|
type Service struct {
|
||||||
queries *sqlc.Queries
|
queries *sqlc.Queries
|
||||||
@@ -65,6 +70,9 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
|||||||
|
|
||||||
created, err := s.queries.CreateModel(ctx, params)
|
created, err := s.queries.CreateModel(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if db.IsUniqueViolation(err) {
|
||||||
|
return AddResponse{}, ErrModelIDAlreadyExists
|
||||||
|
}
|
||||||
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
|
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +113,7 @@ func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse
|
|||||||
return GetResponse{}, fmt.Errorf("model_id is required")
|
return GetResponse{}, fmt.Errorf("model_id is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
dbModel, err := s.queries.GetModelByModelID(ctx, modelID)
|
dbModel, err := s.findUniqueByModelID(ctx, modelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return GetResponse{}, fmt.Errorf("failed to get model: %w", err)
|
return GetResponse{}, fmt.Errorf("failed to get model: %w", err)
|
||||||
}
|
}
|
||||||
@@ -207,6 +215,7 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
|||||||
}
|
}
|
||||||
params := sqlc.UpdateModelParams{
|
params := sqlc.UpdateModelParams{
|
||||||
ID: uuid,
|
ID: uuid,
|
||||||
|
ModelID: model.ModelID,
|
||||||
InputModalities: inputMod,
|
InputModalities: inputMod,
|
||||||
Type: string(model.Type),
|
Type: string(model.Type),
|
||||||
}
|
}
|
||||||
@@ -230,6 +239,9 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
|||||||
|
|
||||||
updated, err := s.queries.UpdateModel(ctx, params)
|
updated, err := s.queries.UpdateModel(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if db.IsUniqueViolation(err) {
|
||||||
|
return GetResponse{}, ErrModelIDAlreadyExists
|
||||||
|
}
|
||||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,6 +253,10 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
|||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
return GetResponse{}, fmt.Errorf("model_id is required")
|
return GetResponse{}, fmt.Errorf("model_id is required")
|
||||||
}
|
}
|
||||||
|
current, err := s.findUniqueByModelID(ctx, modelID)
|
||||||
|
if err != nil {
|
||||||
|
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
model := Model(req)
|
model := Model(req)
|
||||||
if err := model.Validate(); err != nil {
|
if err := model.Validate(); err != nil {
|
||||||
@@ -251,9 +267,8 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
|||||||
if model.Type == ModelTypeChat {
|
if model.Type == ModelTypeChat {
|
||||||
inputMod = normalizeModalities(model.InputModalities, []string{ModelInputText})
|
inputMod = normalizeModalities(model.InputModalities, []string{ModelInputText})
|
||||||
}
|
}
|
||||||
params := sqlc.UpdateModelByModelIDParams{
|
params := sqlc.UpdateModelParams{
|
||||||
ModelID: modelID,
|
ID: current.ID,
|
||||||
NewModelID: model.ModelID,
|
|
||||||
InputModalities: inputMod,
|
InputModalities: inputMod,
|
||||||
Type: string(model.Type),
|
Type: string(model.Type),
|
||||||
}
|
}
|
||||||
@@ -275,8 +290,13 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
|||||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := s.queries.UpdateModelByModelID(ctx, params)
|
params.ModelID = model.ModelID
|
||||||
|
|
||||||
|
updated, err := s.queries.UpdateModel(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if db.IsUniqueViolation(err) {
|
||||||
|
return GetResponse{}, ErrModelIDAlreadyExists
|
||||||
|
}
|
||||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,8 +322,12 @@ func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error {
|
|||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
return fmt.Errorf("model_id is required")
|
return fmt.Errorf("model_id is required")
|
||||||
}
|
}
|
||||||
|
current, err := s.findUniqueByModelID(ctx, modelID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.queries.DeleteModelByModelID(ctx, modelID); err != nil {
|
if err := s.queries.DeleteModel(ctx, current.ID); err != nil {
|
||||||
return fmt.Errorf("failed to delete model: %w", err)
|
return fmt.Errorf("failed to delete model: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,6 +360,7 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
|
|||||||
|
|
||||||
func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||||
resp := GetResponse{
|
resp := GetResponse{
|
||||||
|
ID: dbModel.ID.String(),
|
||||||
ModelID: dbModel.ModelID,
|
ModelID: dbModel.ModelID,
|
||||||
Model: Model{
|
Model: Model{
|
||||||
ModelID: dbModel.ModelID,
|
ModelID: dbModel.ModelID,
|
||||||
@@ -372,6 +397,20 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse {
|
|||||||
return responses
|
return responses
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) findUniqueByModelID(ctx context.Context, modelID string) (sqlc.Model, error) {
|
||||||
|
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
|
||||||
|
if err != nil {
|
||||||
|
return sqlc.Model{}, err
|
||||||
|
}
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return sqlc.Model{}, pgx.ErrNoRows
|
||||||
|
}
|
||||||
|
if len(rows) > 1 {
|
||||||
|
return sqlc.Model{}, ErrModelIDAmbiguous
|
||||||
|
}
|
||||||
|
return rows[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
// normalizeModalities returns modalities if non-empty, otherwise the provided fallback.
|
// normalizeModalities returns modalities if non-empty, otherwise the provided fallback.
|
||||||
func normalizeModalities(modalities []string, fallback []string) []string {
|
func normalizeModalities(modalities []string, fallback []string) []string {
|
||||||
if len(modalities) == 0 {
|
if len(modalities) == 0 {
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ type GetRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetResponse struct {
|
type GetResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
ModelID string `json:"model_id"`
|
ModelID string `json:"model_id"`
|
||||||
Model
|
Model
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/db"
|
"github.com/memohai/memoh/internal/db"
|
||||||
@@ -20,6 +21,8 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access")
|
var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access")
|
||||||
|
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
|
||||||
|
var ErrInvalidModelRef = errors.New("invalid model reference")
|
||||||
|
|
||||||
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
@@ -184,15 +187,21 @@ func normalizeBotSettingsFields(
|
|||||||
maxContextTokens int32,
|
maxContextTokens int32,
|
||||||
language string,
|
language string,
|
||||||
allowGuest bool,
|
allowGuest bool,
|
||||||
chatModelID pgtype.Text,
|
chatModelID pgtype.UUID,
|
||||||
memoryModelID pgtype.Text,
|
memoryModelID pgtype.UUID,
|
||||||
embeddingModelID pgtype.Text,
|
embeddingModelID pgtype.UUID,
|
||||||
searchProviderID pgtype.UUID,
|
searchProviderID pgtype.UUID,
|
||||||
) Settings {
|
) Settings {
|
||||||
settings := normalizeBotSetting(maxContextLoadTime, maxContextTokens, language, allowGuest)
|
settings := normalizeBotSetting(maxContextLoadTime, maxContextTokens, language, allowGuest)
|
||||||
settings.ChatModelID = strings.TrimSpace(chatModelID.String)
|
if chatModelID.Valid {
|
||||||
settings.MemoryModelID = strings.TrimSpace(memoryModelID.String)
|
settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String()
|
||||||
settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String)
|
}
|
||||||
|
if memoryModelID.Valid {
|
||||||
|
settings.MemoryModelID = uuid.UUID(memoryModelID.Bytes).String()
|
||||||
|
}
|
||||||
|
if embeddingModelID.Valid {
|
||||||
|
settings.EmbeddingModelID = uuid.UUID(embeddingModelID.Bytes).String()
|
||||||
|
}
|
||||||
if searchProviderID.Valid {
|
if searchProviderID.Valid {
|
||||||
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
|
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
|
||||||
}
|
}
|
||||||
@@ -200,12 +209,29 @@ func normalizeBotSettingsFields(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
|
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
|
||||||
if strings.TrimSpace(modelID) == "" {
|
modelID = strings.TrimSpace(modelID)
|
||||||
return pgtype.UUID{}, fmt.Errorf("model_id is required")
|
if modelID == "" {
|
||||||
|
return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef)
|
||||||
}
|
}
|
||||||
row, err := s.queries.GetModelByModelID(ctx, modelID)
|
|
||||||
|
// Preferred path: when caller already passes the model UUID.
|
||||||
|
if parsed, err := db.ParseUUID(modelID); err == nil {
|
||||||
|
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
|
||||||
|
return parsed, nil
|
||||||
|
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return pgtype.UUID{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pgtype.UUID{}, err
|
return pgtype.UUID{}, err
|
||||||
}
|
}
|
||||||
return row.ID, nil
|
if len(rows) == 0 {
|
||||||
|
return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID)
|
||||||
|
}
|
||||||
|
if len(rows) > 1 {
|
||||||
|
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID)
|
||||||
|
}
|
||||||
|
return rows[0].ID, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -718,6 +718,7 @@ export type ModelsCountResponse = {
|
|||||||
export type ModelsGetResponse = {
|
export type ModelsGetResponse = {
|
||||||
client_type?: ModelsClientType;
|
client_type?: ModelsClientType;
|
||||||
dimensions?: number;
|
dimensions?: number;
|
||||||
|
id?: string;
|
||||||
input_modalities?: Array<string>;
|
input_modalities?: Array<string>;
|
||||||
llm_provider_id?: string;
|
llm_provider_id?: string;
|
||||||
model_id?: string;
|
model_id?: string;
|
||||||
|
|||||||
@@ -229,12 +229,15 @@ import { inject, computed, watch, nextTick, type Ref, ref } from 'vue'
|
|||||||
import { toTypedSchema } from '@vee-validate/zod'
|
import { toTypedSchema } from '@vee-validate/zod'
|
||||||
import z from 'zod'
|
import z from 'zod'
|
||||||
import { useMutation, useQueryCache } from '@pinia/colada'
|
import { useMutation, useQueryCache } from '@pinia/colada'
|
||||||
import { postModels, putModelsModelByModelId } from '@memoh/sdk'
|
import { postModels, putModelsById, putModelsModelByModelId } from '@memoh/sdk'
|
||||||
import type { ModelsGetResponse } from '@memoh/sdk'
|
import type { ModelsGetResponse } from '@memoh/sdk'
|
||||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||||
|
import { useI18n } from 'vue-i18n'
|
||||||
|
import { toast } from 'vue-sonner'
|
||||||
|
|
||||||
const availableInputModalities = ['text', 'image', 'audio', 'video', 'file'] as const
|
const availableInputModalities = ['text', 'image', 'audio', 'video', 'file'] as const
|
||||||
const selectedModalities = ref<string[]>(['text'])
|
const selectedModalities = ref<string[]>(['text'])
|
||||||
|
const { t } = useI18n()
|
||||||
|
|
||||||
const formSchema = toTypedSchema(z.object({
|
const formSchema = toTypedSchema(z.object({
|
||||||
type: z.string().min(1),
|
type: z.string().min(1),
|
||||||
@@ -313,6 +316,17 @@ const { mutateAsync: createModel, isLoading: createLoading } = useMutation({
|
|||||||
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||||
})
|
})
|
||||||
const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({
|
const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({
|
||||||
|
mutation: async ({ id, data }: { id: string; data: Record<string, unknown> }) => {
|
||||||
|
const { data: result } = await putModelsById({
|
||||||
|
path: { id },
|
||||||
|
body: data as any,
|
||||||
|
throwOnError: true,
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
},
|
||||||
|
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||||
|
})
|
||||||
|
const { mutateAsync: updateModelByLegacyModelID, isLoading: updateLegacyLoading } = useMutation({
|
||||||
mutation: async ({ modelId, data }: { modelId: string; data: Record<string, unknown> }) => {
|
mutation: async ({ modelId, data }: { modelId: string; data: Record<string, unknown> }) => {
|
||||||
const { data: result } = await putModelsModelByModelId({
|
const { data: result } = await putModelsModelByModelId({
|
||||||
path: { modelId },
|
path: { modelId },
|
||||||
@@ -323,7 +337,7 @@ const { mutateAsync: updateModel, isLoading: updateLoading } = useMutation({
|
|||||||
},
|
},
|
||||||
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||||
})
|
})
|
||||||
const isLoading = computed(() => createLoading.value || updateLoading.value)
|
const isLoading = computed(() => createLoading.value || updateLoading.value || updateLegacyLoading.value)
|
||||||
|
|
||||||
async function addModel(e: Event) {
|
async function addModel(e: Event) {
|
||||||
e.preventDefault()
|
e.preventDefault()
|
||||||
@@ -366,16 +380,31 @@ async function addModel(e: Event) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
await updateModel({ modelId: fallback!.model_id, data: payload as any })
|
const modelUUID = fallback?.id
|
||||||
|
if (modelUUID) {
|
||||||
|
await updateModel({ id: modelUUID, data: payload as any })
|
||||||
|
} else {
|
||||||
|
await updateModelByLegacyModelID({ modelId: fallback!.model_id, data: payload as any })
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
await createModel(payload as any)
|
await createModel(payload as any)
|
||||||
}
|
}
|
||||||
open.value = false
|
open.value = false
|
||||||
} catch {
|
} catch (error) {
|
||||||
|
toast.error(resolveErrorMessage(error, t('common.saveFailed')))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function resolveErrorMessage(error: unknown, fallback: string): string {
|
||||||
|
if (error instanceof Error && error.message.trim()) return error.message
|
||||||
|
if (error && typeof error === 'object' && 'message' in error) {
|
||||||
|
const msg = (error as { message?: string }).message
|
||||||
|
if (msg && msg.trim()) return msg
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
watch(open, async () => {
|
watch(open, async () => {
|
||||||
if (!open.value) {
|
if (!open.value) {
|
||||||
title.value = 'title'
|
title.value = 'title'
|
||||||
|
|||||||
@@ -52,13 +52,13 @@
|
|||||||
</div>
|
</div>
|
||||||
<button
|
<button
|
||||||
v-for="model in group.models"
|
v-for="model in group.models"
|
||||||
:key="model.model_id"
|
:key="model.id || `${model.llm_provider_id}:${model.model_id}`"
|
||||||
class="relative flex w-full cursor-pointer items-center gap-2 rounded-md px-2 py-1.5 text-sm outline-none hover:bg-accent hover:text-accent-foreground"
|
class="relative flex w-full cursor-pointer items-center gap-2 rounded-md px-2 py-1.5 text-sm outline-none hover:bg-accent hover:text-accent-foreground"
|
||||||
:class="{ 'bg-accent': selected === model.model_id }"
|
:class="{ 'bg-accent': selected === model.id }"
|
||||||
@click="selectModel(model.model_id)"
|
@click="selectModel(model.id)"
|
||||||
>
|
>
|
||||||
<FontAwesomeIcon
|
<FontAwesomeIcon
|
||||||
v-if="selected === model.model_id"
|
v-if="selected === model.id"
|
||||||
:icon="['fas', 'check']"
|
:icon="['fas', 'check']"
|
||||||
class="size-3.5"
|
class="size-3.5"
|
||||||
/>
|
/>
|
||||||
@@ -145,11 +145,12 @@ const filteredGroups = computed(() => {
|
|||||||
// 显示选中模型的名称
|
// 显示选中模型的名称
|
||||||
const displayLabel = computed(() => {
|
const displayLabel = computed(() => {
|
||||||
if (!selected.value) return ''
|
if (!selected.value) return ''
|
||||||
const model = typeFilteredModels.value.find((m) => m.model_id === selected.value)
|
const model = typeFilteredModels.value.find((m) => m.id === selected.value)
|
||||||
return model?.name || model?.model_id || selected.value
|
return model?.name || model?.model_id || selected.value
|
||||||
})
|
})
|
||||||
|
|
||||||
function selectModel(modelId: string) {
|
function selectModel(modelId?: string) {
|
||||||
|
if (!modelId) return
|
||||||
selected.value = modelId
|
selected.value = modelId
|
||||||
open.value = false
|
open.value = false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
<template>
|
<template>
|
||||||
<Item variant="outline">
|
<Item variant="outline">
|
||||||
<ItemContent>
|
<ItemContent>
|
||||||
<ItemTitle>{{ model.name }}</ItemTitle>
|
<ItemTitle>{{ model.name || model.model_id }}</ItemTitle>
|
||||||
<ItemDescription class="gap-2 flex flex-wrap items-center mt-3">
|
<ItemDescription class="gap-2 flex flex-wrap items-center mt-3">
|
||||||
<Badge variant="outline">
|
<Badge variant="outline">
|
||||||
{{ model.type }}
|
{{ model.type }}
|
||||||
@@ -26,7 +26,7 @@
|
|||||||
<ConfirmPopover
|
<ConfirmPopover
|
||||||
:message="$t('models.deleteModelConfirm')"
|
:message="$t('models.deleteModelConfirm')"
|
||||||
:loading="deleteLoading"
|
:loading="deleteLoading"
|
||||||
@confirm="$emit('delete', model.name)"
|
@confirm="$emit('delete', model.id ?? '')"
|
||||||
>
|
>
|
||||||
<template #trigger>
|
<template #trigger>
|
||||||
<Button variant="outline">
|
<Button variant="outline">
|
||||||
@@ -58,6 +58,6 @@ defineProps<{
|
|||||||
|
|
||||||
defineEmits<{
|
defineEmits<{
|
||||||
edit: [model: ModelsGetResponse]
|
edit: [model: ModelsGetResponse]
|
||||||
delete: [name: string]
|
delete: [id: string]
|
||||||
}>()
|
}>()
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -16,11 +16,11 @@
|
|||||||
>
|
>
|
||||||
<ModelItem
|
<ModelItem
|
||||||
v-for="model in models"
|
v-for="model in models"
|
||||||
:key="model.model_id"
|
:key="model.id || `${model.llm_provider_id}:${model.model_id}`"
|
||||||
:model="model"
|
:model="model"
|
||||||
:delete-loading="deleteModelLoading"
|
:delete-loading="deleteModelLoading"
|
||||||
@edit="(model) => $emit('edit', model)"
|
@edit="(model) => $emit('edit', model)"
|
||||||
@delete="(name) => $emit('delete', name)"
|
@delete="(id) => $emit('delete', id)"
|
||||||
/>
|
/>
|
||||||
</section>
|
</section>
|
||||||
|
|
||||||
@@ -61,6 +61,6 @@ defineProps<{
|
|||||||
|
|
||||||
defineEmits<{
|
defineEmits<{
|
||||||
edit: [model: ModelsGetResponse]
|
edit: [model: ModelsGetResponse]
|
||||||
delete: [name: string]
|
delete: [id: string]
|
||||||
}>()
|
}>()
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ import ProviderForm from './components/provider-form.vue'
|
|||||||
import ModelList from './components/model-list.vue'
|
import ModelList from './components/model-list.vue'
|
||||||
import { computed, inject, provide, reactive, ref, toRef, watch } from 'vue'
|
import { computed, inject, provide, reactive, ref, toRef, watch } from 'vue'
|
||||||
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
|
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
|
||||||
import { putProvidersById, deleteProvidersById, getProvidersByIdModels, deleteModelsModelByModelId } from '@memoh/sdk'
|
import { putProvidersById, deleteProvidersById, getProvidersByIdModels, deleteModelsById } from '@memoh/sdk'
|
||||||
import type { ModelsGetResponse, ProvidersGetResponse } from '@memoh/sdk'
|
import type { ModelsGetResponse, ProvidersGetResponse } from '@memoh/sdk'
|
||||||
|
|
||||||
// ---- Model 编辑状态(provide 给 CreateModel) ----
|
// ---- Model 编辑状态(provide 给 CreateModel) ----
|
||||||
@@ -86,8 +86,9 @@ const { mutate: changeProvider, isLoading: editLoading } = useMutation({
|
|||||||
})
|
})
|
||||||
|
|
||||||
const { mutate: deleteModel, isLoading: deleteModelLoading } = useMutation({
|
const { mutate: deleteModel, isLoading: deleteModelLoading } = useMutation({
|
||||||
mutation: async (modelName: string) => {
|
mutation: async (modelID: string) => {
|
||||||
await deleteModelsModelByModelId({ path: { modelId: modelName }, throwOnError: true })
|
if (!modelID) return
|
||||||
|
await deleteModelsById({ path: { id: modelID }, throwOnError: true })
|
||||||
},
|
},
|
||||||
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
onSettled: () => queryCache.invalidateQueries({ key: ['provider-models'] }),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -6818,6 +6818,9 @@ const docTemplate = `{
|
|||||||
"dimensions": {
|
"dimensions": {
|
||||||
"type": "integer"
|
"type": "integer"
|
||||||
},
|
},
|
||||||
|
"id": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
"input_modalities": {
|
"input_modalities": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
|
|||||||
@@ -6809,6 +6809,9 @@
|
|||||||
"dimensions": {
|
"dimensions": {
|
||||||
"type": "integer"
|
"type": "integer"
|
||||||
},
|
},
|
||||||
|
"id": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
"input_modalities": {
|
"input_modalities": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
|
|||||||
@@ -1185,6 +1185,8 @@ definitions:
|
|||||||
$ref: '#/definitions/models.ClientType'
|
$ref: '#/definitions/models.ClientType'
|
||||||
dimensions:
|
dimensions:
|
||||||
type: integer
|
type: integer
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
input_modalities:
|
input_modalities:
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
|
|||||||
Reference in New Issue
Block a user