mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix: enforce speech/LLM isolation in providers and models
SQL queries (CountProviders, CountModels, ListModels, ListEnabledModels, ListModelsByProviderID) now exclude speech types. Added IsLLMClientType guard to prevent cross-domain queries via /models?client_type and /providers/:id/import-models. Frontend provider forms no longer offer edge-speech as a client type option. Also fixed pre-existing SA5011 staticcheck warnings in proxy_test.go and executor_test.go.
This commit is contained in:
@@ -167,7 +167,7 @@ import { Plus } from 'lucide-vue-next'
|
|||||||
import FormDialogShell from '@/components/form-dialog-shell/index.vue'
|
import FormDialogShell from '@/components/form-dialog-shell/index.vue'
|
||||||
import { useDialogMutation } from '@/composables/useDialogMutation'
|
import { useDialogMutation } from '@/composables/useDialogMutation'
|
||||||
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
||||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
import { LLM_CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||||
import { toast } from 'vue-sonner'
|
import { toast } from 'vue-sonner'
|
||||||
import { computed, watch } from 'vue'
|
import { computed, watch } from 'vue'
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ const { t } = useI18n()
|
|||||||
const { run } = useDialogMutation()
|
const { run } = useDialogMutation()
|
||||||
|
|
||||||
const clientTypeOptions = computed(() =>
|
const clientTypeOptions = computed(() =>
|
||||||
CLIENT_TYPE_LIST.map((ct) => ({
|
LLM_CLIENT_TYPE_LIST.map((ct) => ({
|
||||||
value: ct.value,
|
value: ct.value,
|
||||||
label: ct.label,
|
label: ct.label,
|
||||||
description: ct.hint,
|
description: ct.hint,
|
||||||
|
|||||||
@@ -43,3 +43,6 @@ export const CLIENT_TYPE_META: Record<string, ClientTypeMeta> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const CLIENT_TYPE_LIST: ClientTypeMeta[] = Object.values(CLIENT_TYPE_META)
|
export const CLIENT_TYPE_LIST: ClientTypeMeta[] = Object.values(CLIENT_TYPE_META)
|
||||||
|
|
||||||
|
export const LLM_CLIENT_TYPE_LIST: ClientTypeMeta[] = CLIENT_TYPE_LIST
|
||||||
|
.filter(ct => ct.value !== 'edge-speech')
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ import StatusDot from '@/components/status-dot/index.vue'
|
|||||||
import LoadingButton from '@/components/loading-button/index.vue'
|
import LoadingButton from '@/components/loading-button/index.vue'
|
||||||
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
import { LLM_CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||||
import { computed, onBeforeUnmount, ref, watch } from 'vue'
|
import { computed, onBeforeUnmount, ref, watch } from 'vue'
|
||||||
import { toTypedSchema } from '@vee-validate/zod'
|
import { toTypedSchema } from '@vee-validate/zod'
|
||||||
import z from 'zod'
|
import z from 'zod'
|
||||||
@@ -407,7 +407,7 @@ watch(() => props.provider?.id, () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
const clientTypeOptions = computed(() =>
|
const clientTypeOptions = computed(() =>
|
||||||
CLIENT_TYPE_LIST.map((ct) => ({
|
LLM_CLIENT_TYPE_LIST.map((ct) => ({
|
||||||
value: ct.value,
|
value: ct.value,
|
||||||
label: ct.label,
|
label: ct.label,
|
||||||
description: ct.hint,
|
description: ct.hint,
|
||||||
|
|||||||
@@ -38,7 +38,8 @@ RETURNING *;
|
|||||||
DELETE FROM providers WHERE id = sqlc.arg(id);
|
DELETE FROM providers WHERE id = sqlc.arg(id);
|
||||||
|
|
||||||
-- name: CountProviders :one
|
-- name: CountProviders :one
|
||||||
SELECT COUNT(*) FROM providers;
|
SELECT COUNT(*) FROM providers
|
||||||
|
WHERE client_type NOT IN ('edge-speech');
|
||||||
|
|
||||||
-- name: CreateModel :one
|
-- name: CreateModel :one
|
||||||
INSERT INTO models (model_id, name, provider_id, type, config)
|
INSERT INTO models (model_id, name, provider_id, type, config)
|
||||||
@@ -64,6 +65,7 @@ ORDER BY created_at DESC;
|
|||||||
|
|
||||||
-- name: ListModels :many
|
-- name: ListModels :many
|
||||||
SELECT * FROM models
|
SELECT * FROM models
|
||||||
|
WHERE type != 'speech'
|
||||||
ORDER BY created_at DESC;
|
ORDER BY created_at DESC;
|
||||||
|
|
||||||
-- name: ListModelsByType :many
|
-- name: ListModelsByType :many
|
||||||
@@ -74,6 +76,7 @@ ORDER BY created_at DESC;
|
|||||||
-- name: ListModelsByProviderID :many
|
-- name: ListModelsByProviderID :many
|
||||||
SELECT * FROM models
|
SELECT * FROM models
|
||||||
WHERE provider_id = sqlc.arg(provider_id)
|
WHERE provider_id = sqlc.arg(provider_id)
|
||||||
|
AND type != 'speech'
|
||||||
ORDER BY created_at DESC;
|
ORDER BY created_at DESC;
|
||||||
|
|
||||||
-- name: ListModelsByProviderIDAndType :many
|
-- name: ListModelsByProviderIDAndType :many
|
||||||
@@ -108,7 +111,8 @@ DELETE FROM models WHERE id = sqlc.arg(id);
|
|||||||
DELETE FROM models WHERE model_id = sqlc.arg(model_id);
|
DELETE FROM models WHERE model_id = sqlc.arg(model_id);
|
||||||
|
|
||||||
-- name: CountModels :one
|
-- name: CountModels :one
|
||||||
SELECT COUNT(*) FROM models;
|
SELECT COUNT(*) FROM models
|
||||||
|
WHERE type != 'speech';
|
||||||
|
|
||||||
-- name: CountModelsByType :one
|
-- name: CountModelsByType :one
|
||||||
SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type);
|
SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type);
|
||||||
@@ -143,6 +147,7 @@ SELECT m.*
|
|||||||
FROM models m
|
FROM models m
|
||||||
JOIN providers p ON m.provider_id = p.id
|
JOIN providers p ON m.provider_id = p.id
|
||||||
WHERE p.enable = true
|
WHERE p.enable = true
|
||||||
|
AND m.type != 'speech'
|
||||||
ORDER BY m.created_at DESC;
|
ORDER BY m.created_at DESC;
|
||||||
|
|
||||||
-- name: ListEnabledModelsByType :many
|
-- name: ListEnabledModelsByType :many
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ func TestNewHTTPClientExplicitProxyOverridesEnvironment(t *testing.T) {
|
|||||||
}
|
}
|
||||||
if proxyURL == nil {
|
if proxyURL == nil {
|
||||||
t.Fatal("expected explicit proxy URL")
|
t.Fatal("expected explicit proxy URL")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if proxyURL.Host != "config-proxy:3128" {
|
if proxyURL.Host != "config-proxy:3128" {
|
||||||
t.Fatalf("unexpected proxy host: %q", proxyURL.Host)
|
t.Fatalf("unexpected proxy host: %q", proxyURL.Host)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
const countModels = `-- name: CountModels :one
|
const countModels = `-- name: CountModels :one
|
||||||
SELECT COUNT(*) FROM models
|
SELECT COUNT(*) FROM models
|
||||||
|
WHERE type != 'speech'
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) CountModels(ctx context.Context) (int64, error) {
|
func (q *Queries) CountModels(ctx context.Context) (int64, error) {
|
||||||
@@ -35,6 +36,7 @@ func (q *Queries) CountModelsByType(ctx context.Context, type_ string) (int64, e
|
|||||||
|
|
||||||
const countProviders = `-- name: CountProviders :one
|
const countProviders = `-- name: CountProviders :one
|
||||||
SELECT COUNT(*) FROM providers
|
SELECT COUNT(*) FROM providers
|
||||||
|
WHERE client_type NOT IN ('edge-speech')
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) CountProviders(ctx context.Context) (int64, error) {
|
func (q *Queries) CountProviders(ctx context.Context) (int64, error) {
|
||||||
@@ -351,6 +353,7 @@ SELECT m.id, m.model_id, m.name, m.provider_id, m.type, m.config, m.created_at,
|
|||||||
FROM models m
|
FROM models m
|
||||||
JOIN providers p ON m.provider_id = p.id
|
JOIN providers p ON m.provider_id = p.id
|
||||||
WHERE p.enable = true
|
WHERE p.enable = true
|
||||||
|
AND m.type != 'speech'
|
||||||
ORDER BY m.created_at DESC
|
ORDER BY m.created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -495,6 +498,7 @@ func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pg
|
|||||||
|
|
||||||
const listModels = `-- name: ListModels :many
|
const listModels = `-- name: ListModels :many
|
||||||
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
|
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
|
||||||
|
WHERE type != 'speech'
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -602,6 +606,7 @@ func (q *Queries) ListModelsByProviderClientType(ctx context.Context, clientType
|
|||||||
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
||||||
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
|
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
|
||||||
WHERE provider_id = $1
|
WHERE provider_id = $1
|
||||||
|
AND type != 'speech'
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|||||||
@@ -87,7 +87,11 @@ func (h *ModelsHandler) List(c echo.Context) error {
|
|||||||
case modelType != "":
|
case modelType != "":
|
||||||
resp, err = h.service.ListEnabledByType(c.Request().Context(), models.ModelType(modelType))
|
resp, err = h.service.ListEnabledByType(c.Request().Context(), models.ModelType(modelType))
|
||||||
case clientType != "":
|
case clientType != "":
|
||||||
resp, err = h.service.ListEnabledByProviderClientType(c.Request().Context(), models.ClientType(clientType))
|
ct := models.ClientType(clientType)
|
||||||
|
if !models.IsLLMClientType(ct) {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "invalid client type for LLM models endpoint")
|
||||||
|
}
|
||||||
|
resp, err = h.service.ListEnabledByProviderClientType(c.Request().Context(), ct)
|
||||||
default:
|
default:
|
||||||
resp, err = h.service.ListEnabled(c.Request().Context())
|
resp, err = h.service.ListEnabled(c.Request().Context())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -313,6 +313,14 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
|
|||||||
ctx = oauthctx.WithUserID(ctx, userID)
|
ctx = oauthctx.WithUserID(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
provider, err := h.service.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("provider not found: %v", err))
|
||||||
|
}
|
||||||
|
if !models.IsLLMClientType(models.ClientType(provider.ClientType)) {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "import models is not supported for speech providers")
|
||||||
|
}
|
||||||
|
|
||||||
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
|
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
|
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ func TestSendSameConversationWithAttachmentsUsesLocalResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
if result == nil {
|
if result == nil {
|
||||||
t.Fatal("expected non-nil result")
|
t.Fatal("expected non-nil result")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if !result.Local {
|
if !result.Local {
|
||||||
t.Fatal("expected local result for same-conversation send")
|
t.Fatal("expected local result for same-conversation send")
|
||||||
|
|||||||
@@ -437,6 +437,22 @@ func IsValidClientType(clientType ClientType) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsLLMClientType returns true if the client type belongs to the LLM domain
|
||||||
|
// (chat/embedding), excluding speech-only types like edge-speech.
|
||||||
|
func IsLLMClientType(clientType ClientType) bool {
|
||||||
|
switch clientType {
|
||||||
|
case ClientTypeOpenAIResponses,
|
||||||
|
ClientTypeOpenAICompletions,
|
||||||
|
ClientTypeAnthropicMessages,
|
||||||
|
ClientTypeGoogleGenerativeAI,
|
||||||
|
ClientTypeOpenAICodex,
|
||||||
|
ClientTypeGitHubCopilot:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SelectMemoryModel selects a chat model for memory operations.
|
// SelectMemoryModel selects a chat model for memory operations.
|
||||||
// It only considers models from enabled providers.
|
// It only considers models from enabled providers.
|
||||||
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.Provider, error) {
|
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.Provider, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user