mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(models,settings,conversation): scope model_id uniqueness per
provider and harden model reference resolution
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
@@ -20,6 +21,8 @@ type Service struct {
|
||||
}
|
||||
|
||||
var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access")
|
||||
var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
|
||||
var ErrInvalidModelRef = errors.New("invalid model reference")
|
||||
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||
return &Service{
|
||||
@@ -184,15 +187,21 @@ func normalizeBotSettingsFields(
|
||||
maxContextTokens int32,
|
||||
language string,
|
||||
allowGuest bool,
|
||||
chatModelID pgtype.Text,
|
||||
memoryModelID pgtype.Text,
|
||||
embeddingModelID pgtype.Text,
|
||||
chatModelID pgtype.UUID,
|
||||
memoryModelID pgtype.UUID,
|
||||
embeddingModelID pgtype.UUID,
|
||||
searchProviderID pgtype.UUID,
|
||||
) Settings {
|
||||
settings := normalizeBotSetting(maxContextLoadTime, maxContextTokens, language, allowGuest)
|
||||
settings.ChatModelID = strings.TrimSpace(chatModelID.String)
|
||||
settings.MemoryModelID = strings.TrimSpace(memoryModelID.String)
|
||||
settings.EmbeddingModelID = strings.TrimSpace(embeddingModelID.String)
|
||||
if chatModelID.Valid {
|
||||
settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String()
|
||||
}
|
||||
if memoryModelID.Valid {
|
||||
settings.MemoryModelID = uuid.UUID(memoryModelID.Bytes).String()
|
||||
}
|
||||
if embeddingModelID.Valid {
|
||||
settings.EmbeddingModelID = uuid.UUID(embeddingModelID.Bytes).String()
|
||||
}
|
||||
if searchProviderID.Valid {
|
||||
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
|
||||
}
|
||||
@@ -200,12 +209,29 @@ func normalizeBotSettingsFields(
|
||||
}
|
||||
|
||||
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("model_id is required")
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef)
|
||||
}
|
||||
row, err := s.queries.GetModelByModelID(ctx, modelID)
|
||||
|
||||
// Preferred path: when caller already passes the model UUID.
|
||||
if parsed, err := db.ParseUUID(modelID); err == nil {
|
||||
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
|
||||
return parsed, nil
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return pgtype.UUID{}, err
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, err
|
||||
}
|
||||
return row.ID, nil
|
||||
if len(rows) == 0 {
|
||||
return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID)
|
||||
}
|
||||
if len(rows) > 1 {
|
||||
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID)
|
||||
}
|
||||
return rows[0].ID, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user