diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 91f5ba2d..a80fa76c 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -1585,7 +1585,7 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq return models.GetResponse{}, sqlc.LlmProvider{}, err } for _, m := range candidates { - if m.ModelID == modelID { + if matchesModelReference(m, modelID) { prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err @@ -1631,6 +1631,14 @@ resolved: return model, prov, nil } +func matchesModelReference(model models.GetResponse, modelRef string) bool { + ref := strings.TrimSpace(modelRef) + if ref == "" { + return false + } + return model.ID == ref || model.ModelID == ref +} + func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) { var all []models.GetResponse var err error diff --git a/internal/conversation/flow/resolver_model_selection_test.go b/internal/conversation/flow/resolver_model_selection_test.go new file mode 100644 index 00000000..2c588595 --- /dev/null +++ b/internal/conversation/flow/resolver_model_selection_test.go @@ -0,0 +1,59 @@ +package flow + +import ( + "testing" + + "github.com/memohai/memoh/internal/models" +) + +func TestMatchesModelReference_ModelID(t *testing.T) { + t.Parallel() + + model := models.GetResponse{ + ID: "a55f0d2d-1547-49a0-b085-ec4ab778f4b8", + ModelID: "gpt-4o", + } + + if !matchesModelReference(model, "gpt-4o") { + t.Fatal("expected model slug to match") + } +} + +func TestMatchesModelReference_UUID(t *testing.T) { + t.Parallel() + + model := models.GetResponse{ + ID: "a55f0d2d-1547-49a0-b085-ec4ab778f4b8", + ModelID: "gpt-4o", + } + + if !matchesModelReference(model, "a55f0d2d-1547-49a0-b085-ec4ab778f4b8") { + t.Fatal("expected model UUID to match") + } +} + +func TestMatchesModelReference_NoMatch(t *testing.T) { + t.Parallel() + + model := models.GetResponse{ + ID: "a55f0d2d-1547-49a0-b085-ec4ab778f4b8", + ModelID: "gpt-4o", + } + + if matchesModelReference(model, "gpt-4.1") { + t.Fatal("expected non-matching model reference to fail") + } +} + +func TestMatchesModelReference_TrimmedInput(t *testing.T) { + t.Parallel() + + model := models.GetResponse{ + ID: "a55f0d2d-1547-49a0-b085-ec4ab778f4b8", + ModelID: "gpt-4o", + } + + if !matchesModelReference(model, " gpt-4o ") { + t.Fatal("expected trimmed model slug to match") + } +}