fix: enforce speech/LLM isolation in providers and models

SQL queries (CountProviders, CountModels, ListModels, ListEnabledModels,
ListModelsByProviderID) now exclude speech types. Added IsLLMClientType
guard to prevent cross-domain queries via /models?client_type and
/providers/:id/import-models. Frontend provider forms no longer offer
edge-speech as a client type option.

Also fixed pre-existing SA5011 staticcheck warnings in proxy_test.go
and executor_test.go.
This commit is contained in:
Acbox
2026-04-14 21:07:27 +08:00
parent 84f1d0612a
commit 6328281fc2
10 changed files with 50 additions and 7 deletions
@@ -167,7 +167,7 @@ import { Plus } from 'lucide-vue-next'
import FormDialogShell from '@/components/form-dialog-shell/index.vue'
import { useDialogMutation } from '@/composables/useDialogMutation'
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { LLM_CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { toast } from 'vue-sonner'
import { computed, watch } from 'vue'
@@ -176,7 +176,7 @@ const { t } = useI18n()
const { run } = useDialogMutation()
const clientTypeOptions = computed(() =>
CLIENT_TYPE_LIST.map((ct) => ({
LLM_CLIENT_TYPE_LIST.map((ct) => ({
value: ct.value,
label: ct.label,
description: ct.hint,
+3
View File
@@ -43,3 +43,6 @@ export const CLIENT_TYPE_META: Record<string, ClientTypeMeta> = {
}
// Every client type the UI knows about, in declaration order of CLIENT_TYPE_META.
export const CLIENT_TYPE_LIST: ClientTypeMeta[] = Object.values(CLIENT_TYPE_META)

// Only LLM-domain client types: speech-only entries (edge-speech) are filtered
// out so provider forms never offer them as a selectable client type.
const isLLMClientType = (meta: ClientTypeMeta): boolean => meta.value !== 'edge-speech'
export const LLM_CLIENT_TYPE_LIST: ClientTypeMeta[] = CLIENT_TYPE_LIST.filter(isLLMClientType)
@@ -311,7 +311,7 @@ import StatusDot from '@/components/status-dot/index.vue'
import LoadingButton from '@/components/loading-button/index.vue'
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
import { useClipboard } from '@/composables/useClipboard'
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { LLM_CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { computed, onBeforeUnmount, ref, watch } from 'vue'
import { toTypedSchema } from '@vee-validate/zod'
import z from 'zod'
@@ -407,7 +407,7 @@ watch(() => props.provider?.id, () => {
})
const clientTypeOptions = computed(() =>
CLIENT_TYPE_LIST.map((ct) => ({
LLM_CLIENT_TYPE_LIST.map((ct) => ({
value: ct.value,
label: ct.label,
description: ct.hint,
+7 -2
View File
@@ -38,7 +38,8 @@ RETURNING *;
DELETE FROM providers WHERE id = sqlc.arg(id);
-- name: CountProviders :one
SELECT COUNT(*) FROM providers;
SELECT COUNT(*) FROM providers
WHERE client_type NOT IN ('edge-speech');
-- name: CreateModel :one
INSERT INTO models (model_id, name, provider_id, type, config)
@@ -64,6 +65,7 @@ ORDER BY created_at DESC;
-- name: ListModels :many
SELECT * FROM models
WHERE type != 'speech'
ORDER BY created_at DESC;
-- name: ListModelsByType :many
@@ -74,6 +76,7 @@ ORDER BY created_at DESC;
-- name: ListModelsByProviderID :many
SELECT * FROM models
WHERE provider_id = sqlc.arg(provider_id)
AND type != 'speech'
ORDER BY created_at DESC;
-- name: ListModelsByProviderIDAndType :many
@@ -108,7 +111,8 @@ DELETE FROM models WHERE id = sqlc.arg(id);
DELETE FROM models WHERE model_id = sqlc.arg(model_id);
-- name: CountModels :one
SELECT COUNT(*) FROM models;
SELECT COUNT(*) FROM models
WHERE type != 'speech';
-- name: CountModelsByType :one
SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type);
@@ -143,6 +147,7 @@ SELECT m.*
FROM models m
JOIN providers p ON m.provider_id = p.id
WHERE p.enable = true
AND m.type != 'speech'
ORDER BY m.created_at DESC;
-- name: ListEnabledModelsByType :many
+1
View File
@@ -43,6 +43,7 @@ func TestNewHTTPClientExplicitProxyOverridesEnvironment(t *testing.T) {
}
if proxyURL == nil {
t.Fatal("expected explicit proxy URL")
return
}
if proxyURL.Host != "config-proxy:3128" {
t.Fatalf("unexpected proxy host: %q", proxyURL.Host)
+5
View File
@@ -13,6 +13,7 @@ import (
const countModels = `-- name: CountModels :one
SELECT COUNT(*) FROM models
WHERE type != 'speech'
`
func (q *Queries) CountModels(ctx context.Context) (int64, error) {
@@ -35,6 +36,7 @@ func (q *Queries) CountModelsByType(ctx context.Context, type_ string) (int64, e
const countProviders = `-- name: CountProviders :one
SELECT COUNT(*) FROM providers
WHERE client_type NOT IN ('edge-speech')
`
func (q *Queries) CountProviders(ctx context.Context) (int64, error) {
@@ -351,6 +353,7 @@ SELECT m.id, m.model_id, m.name, m.provider_id, m.type, m.config, m.created_at,
FROM models m
JOIN providers p ON m.provider_id = p.id
WHERE p.enable = true
AND m.type != 'speech'
ORDER BY m.created_at DESC
`
@@ -495,6 +498,7 @@ func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pg
const listModels = `-- name: ListModels :many
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
WHERE type != 'speech'
ORDER BY created_at DESC
`
@@ -602,6 +606,7 @@ func (q *Queries) ListModelsByProviderClientType(ctx context.Context, clientType
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
WHERE provider_id = $1
AND type != 'speech'
ORDER BY created_at DESC
`
+5 -1
View File
@@ -87,7 +87,11 @@ func (h *ModelsHandler) List(c echo.Context) error {
case modelType != "":
resp, err = h.service.ListEnabledByType(c.Request().Context(), models.ModelType(modelType))
case clientType != "":
resp, err = h.service.ListEnabledByProviderClientType(c.Request().Context(), models.ClientType(clientType))
ct := models.ClientType(clientType)
if !models.IsLLMClientType(ct) {
return echo.NewHTTPError(http.StatusBadRequest, "invalid client type for LLM models endpoint")
}
resp, err = h.service.ListEnabledByProviderClientType(c.Request().Context(), ct)
default:
resp, err = h.service.ListEnabled(c.Request().Context())
}
+8
View File
@@ -313,6 +313,14 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
ctx = oauthctx.WithUserID(ctx, userID)
}
provider, err := h.service.Get(ctx, id)
if err != nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("provider not found: %v", err))
}
if !models.IsLLMClientType(models.ClientType(provider.ClientType)) {
return echo.NewHTTPError(http.StatusBadRequest, "import models is not supported for speech providers")
}
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
+1
View File
@@ -127,6 +127,7 @@ func TestSendSameConversationWithAttachmentsUsesLocalResult(t *testing.T) {
}
if result == nil {
t.Fatal("expected non-nil result")
return
}
if !result.Local {
t.Fatal("expected local result for same-conversation send")
+16
View File
@@ -437,6 +437,22 @@ func IsValidClientType(clientType ClientType) bool {
}
}
// IsLLMClientType reports whether clientType belongs to the LLM domain
// (chat/embedding). Speech-only types such as edge-speech yield false,
// keeping LLM endpoints isolated from speech providers.
func IsLLMClientType(clientType ClientType) bool {
	llmClientTypes := []ClientType{
		ClientTypeOpenAIResponses,
		ClientTypeOpenAICompletions,
		ClientTypeAnthropicMessages,
		ClientTypeGoogleGenerativeAI,
		ClientTypeOpenAICodex,
		ClientTypeGitHubCopilot,
	}
	for _, llm := range llmClientTypes {
		if clientType == llm {
			return true
		}
	}
	return false
}
// SelectMemoryModel selects a chat model for memory operations.
// It only considers models from enabled providers.
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.Provider, error) {