package flow

import (
	"context"
	"errors"
	"fmt"
	"strings"

	"github.com/jackc/pgx/v5"
	"github.com/memohai/memoh/internal/conversation"
	"github.com/memohai/memoh/internal/db"
	"github.com/memohai/memoh/internal/db/sqlc"
	"github.com/memohai/memoh/internal/models"
	"github.com/memohai/memoh/internal/settings"
)

// selectChatModel picks the chat model and its provider for a request.
//
// The model reference is taken from req.Model. Only when the request names
// neither a model nor a provider does it fall back to cs.ModelID and then to
// botSettings.ChatModelID. An empty reference after that step is an error.
//
// Without a provider filter the lookup is delegated to fetchChatModel. With
// one, the reference is matched against the candidates returned by
// listCandidates for that provider, and the first match wins.
func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.Provider, error) {
	var (
		noModel    models.GetResponse
		noProvider sqlc.Provider
	)

	if r.modelsService == nil {
		return noModel, noProvider, errors.New("models service not configured")
	}

	wantModel := strings.TrimSpace(req.Model)
	wantProvider := strings.TrimSpace(req.Provider)

	// Precedence: request model, then chat settings, then bot settings.
	// The settings are consulted only when the request pins nothing at all.
	if wantModel == "" && wantProvider == "" {
		wantModel = strings.TrimSpace(cs.ModelID)
		if wantModel == "" {
			wantModel = strings.TrimSpace(botSettings.ChatModelID)
		}
	}

	if wantModel == "" {
		return noModel, noProvider, errors.New("chat model not configured: specify model in request or bot settings")
	}

	if wantProvider == "" {
		return r.fetchChatModel(ctx, wantModel)
	}

	candidates, err := r.listCandidates(ctx, wantProvider)
	if err != nil {
		return noModel, noProvider, err
	}

	for _, candidate := range candidates {
		if !matchesModelReference(candidate, wantModel) {
			continue
		}
		provider, provErr := models.FetchProviderByID(ctx, r.queries, candidate.ProviderID)
		if provErr != nil {
			return noModel, noProvider, provErr
		}
		return candidate, provider, nil
	}

	return noModel, noProvider, fmt.Errorf("chat model %q not found for provider %q", wantModel, wantProvider)
}

// fetchChatModel resolves a model reference to a chat model and its provider.
//
// The reference may be either a model UUID or a model_id slug. It returns an
// error when the reference is blank, when a lookup fails, when the resolved
// model is not of type models.ModelTypeChat, or when the provider cannot be
// fetched.
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.Provider, error) {
	var (
		noModel    models.GetResponse
		noProvider sqlc.Provider
	)

	modelRef := strings.TrimSpace(modelID)
	if modelRef == "" {
		return noModel, noProvider, errors.New("model id is required")
	}

	var (
		model    models.GetResponse
		err      error
		resolved bool
	)

	// A reference that parses as a UUID is tried as a primary key first.
	// A pgx.ErrNoRows miss is not fatal: slugs that merely look like UUIDs
	// still resolve through the GetByModelID lookup below.
	if _, parseErr := db.ParseUUID(modelRef); parseErr == nil {
		model, err = r.modelsService.GetByID(ctx, modelRef)
		switch {
		case err == nil:
			resolved = true
		case !errors.Is(err, pgx.ErrNoRows):
			return noModel, noProvider, err
		}
	}

	if !resolved {
		model, err = r.modelsService.GetByModelID(ctx, modelRef)
		if err != nil {
			return noModel, noProvider, err
		}
	}

	if model.Type != models.ModelTypeChat {
		return noModel, noProvider, errors.New("model is not a chat model")
	}

	provider, err := models.FetchProviderByID(ctx, r.queries, model.ProviderID)
	if err != nil {
		return noModel, noProvider, err
	}
	return model, provider, nil
}

// matchesModelReference reports whether modelRef, after trimming surrounding
// whitespace, equals the model's ID or its ModelID. A blank reference never
// matches.
func matchesModelReference(model models.GetResponse, modelRef string) bool {
	switch ref := strings.TrimSpace(modelRef); ref {
	case "":
		return false
	case model.ID, model.ModelID:
		return true
	default:
		return false
	}
}

// listCandidates returns the enabled models of type models.ModelTypeChat.
//
// A non-empty providerFilter is converted to a models.ClientType and used to
// list enabled models for that client type; an empty filter lists enabled
// models by chat type instead. Either result is then filtered to chat models,
// keeping the order in which the service returned them.
func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) {
	var (
		enabled []models.GetResponse
		err     error
	)

	switch providerFilter {
	case "":
		enabled, err = r.modelsService.ListEnabledByType(ctx, models.ModelTypeChat)
	default:
		enabled, err = r.modelsService.ListEnabledByProviderClientType(ctx, models.ClientType(providerFilter))
	}
	if err != nil {
		return nil, err
	}

	chatOnly := make([]models.GetResponse, 0, len(enabled))
	for i := range enabled {
		if enabled[i].Type != models.ModelTypeChat {
			continue
		}
		chatOnly = append(chatOnly, enabled[i])
	}
	return chatOnly, nil
}