fix(models,settings,conversation): scope model_id uniqueness per

provider and harden model reference resolution
This commit is contained in:
ringotypowriter
2026-02-21 22:31:32 +08:00
parent 9461f923df
commit 50bdbd519c
25 changed files with 376 additions and 107 deletions
+36 -10
View File
@@ -8,6 +8,7 @@ import (
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db"
@@ -20,6 +21,8 @@ type Service struct {
}
var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access")
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
var ErrInvalidModelRef = errors.New("invalid model reference")
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
return &Service{
@@ -184,15 +187,21 @@ func normalizeBotSettingsFields(
maxContextTokens int32,
language string,
allowGuest bool,
chatModelID pgtype.Text,
memoryModelID pgtype.Text,
embeddingModelID pgtype.Text,
chatModelID pgtype.UUID,
memoryModelID pgtype.UUID,
embeddingModelID pgtype.UUID,
searchProviderID pgtype.UUID,
) Settings {
settings := normalizeBotSetting(maxContextLoadTime, maxContextTokens, language, allowGuest)
settings.ChatModelID = strings.TrimSpace(chatModelID.String)
settings.MemoryModelID = strings.TrimSpace(memoryModelID.String)
settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String)
if chatModelID.Valid {
settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String()
}
if memoryModelID.Valid {
settings.MemoryModelID = uuid.UUID(memoryModelID.Bytes).String()
}
if embeddingModelID.Valid {
settings.EmbeddingModelID = uuid.UUID(embeddingModelID.Bytes).String()
}
if searchProviderID.Valid {
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
}
@@ -200,12 +209,29 @@ func normalizeBotSettingsFields(
}
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
if strings.TrimSpace(modelID) == "" {
return pgtype.UUID{}, fmt.Errorf("model_id is required")
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef)
}
row, err := s.queries.GetModelByModelID(ctx, modelID)
// Preferred path: when caller already passes the model UUID.
if parsed, err := db.ParseUUID(modelID); err == nil {
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
return parsed, nil
} else if !errors.Is(err, pgx.ErrNoRows) {
return pgtype.UUID{}, err
}
}
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
if err != nil {
return pgtype.UUID{}, err
}
return row.ID, nil
if len(rows) == 0 {
return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID)
}
if len(rows) > 1 {
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID)
}
return rows[0].ID, nil
}