mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix: enforce speech/LLM isolation in providers and models
SQL queries (CountProviders, CountModels, ListModels, ListEnabledModels, ListModelsByProviderID) now exclude speech types. Added IsLLMClientType guard to prevent cross-domain queries via /models?client_type and /providers/:id/import-models. Frontend provider forms no longer offer edge-speech as a client type option. Also fixed pre-existing SA5011 staticcheck warnings in proxy_test.go and executor_test.go.
This commit is contained in:
@@ -167,7 +167,7 @@ import { Plus } from 'lucide-vue-next'
|
||||
import FormDialogShell from '@/components/form-dialog-shell/index.vue'
|
||||
import { useDialogMutation } from '@/composables/useDialogMutation'
|
||||
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||
import { LLM_CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||
import { toast } from 'vue-sonner'
|
||||
import { computed, watch } from 'vue'
|
||||
|
||||
@@ -176,7 +176,7 @@ const { t } = useI18n()
|
||||
const { run } = useDialogMutation()
|
||||
|
||||
const clientTypeOptions = computed(() =>
|
||||
CLIENT_TYPE_LIST.map((ct) => ({
|
||||
LLM_CLIENT_TYPE_LIST.map((ct) => ({
|
||||
value: ct.value,
|
||||
label: ct.label,
|
||||
description: ct.hint,
|
||||
|
||||
@@ -43,3 +43,6 @@ export const CLIENT_TYPE_META: Record<string, ClientTypeMeta> = {
|
||||
}
|
||||
|
||||
export const CLIENT_TYPE_LIST: ClientTypeMeta[] = Object.values(CLIENT_TYPE_META)
|
||||
|
||||
export const LLM_CLIENT_TYPE_LIST: ClientTypeMeta[] = CLIENT_TYPE_LIST
|
||||
.filter(ct => ct.value !== 'edge-speech')
|
||||
|
||||
@@ -311,7 +311,7 @@ import StatusDot from '@/components/status-dot/index.vue'
|
||||
import LoadingButton from '@/components/loading-button/index.vue'
|
||||
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||
import { LLM_CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||
import { computed, onBeforeUnmount, ref, watch } from 'vue'
|
||||
import { toTypedSchema } from '@vee-validate/zod'
|
||||
import z from 'zod'
|
||||
@@ -407,7 +407,7 @@ watch(() => props.provider?.id, () => {
|
||||
})
|
||||
|
||||
const clientTypeOptions = computed(() =>
|
||||
CLIENT_TYPE_LIST.map((ct) => ({
|
||||
LLM_CLIENT_TYPE_LIST.map((ct) => ({
|
||||
value: ct.value,
|
||||
label: ct.label,
|
||||
description: ct.hint,
|
||||
|
||||
@@ -38,7 +38,8 @@ RETURNING *;
|
||||
DELETE FROM providers WHERE id = sqlc.arg(id);
|
||||
|
||||
-- name: CountProviders :one
|
||||
SELECT COUNT(*) FROM providers;
|
||||
SELECT COUNT(*) FROM providers
|
||||
WHERE client_type NOT IN ('edge-speech');
|
||||
|
||||
-- name: CreateModel :one
|
||||
INSERT INTO models (model_id, name, provider_id, type, config)
|
||||
@@ -64,6 +65,7 @@ ORDER BY created_at DESC;
|
||||
|
||||
-- name: ListModels :many
|
||||
SELECT * FROM models
|
||||
WHERE type != 'speech'
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: ListModelsByType :many
|
||||
@@ -74,6 +76,7 @@ ORDER BY created_at DESC;
|
||||
-- name: ListModelsByProviderID :many
|
||||
SELECT * FROM models
|
||||
WHERE provider_id = sqlc.arg(provider_id)
|
||||
AND type != 'speech'
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: ListModelsByProviderIDAndType :many
|
||||
@@ -108,7 +111,8 @@ DELETE FROM models WHERE id = sqlc.arg(id);
|
||||
DELETE FROM models WHERE model_id = sqlc.arg(model_id);
|
||||
|
||||
-- name: CountModels :one
|
||||
SELECT COUNT(*) FROM models;
|
||||
SELECT COUNT(*) FROM models
|
||||
WHERE type != 'speech';
|
||||
|
||||
-- name: CountModelsByType :one
|
||||
SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type);
|
||||
@@ -143,6 +147,7 @@ SELECT m.*
|
||||
FROM models m
|
||||
JOIN providers p ON m.provider_id = p.id
|
||||
WHERE p.enable = true
|
||||
AND m.type != 'speech'
|
||||
ORDER BY m.created_at DESC;
|
||||
|
||||
-- name: ListEnabledModelsByType :many
|
||||
|
||||
@@ -43,6 +43,7 @@ func TestNewHTTPClientExplicitProxyOverridesEnvironment(t *testing.T) {
|
||||
}
|
||||
if proxyURL == nil {
|
||||
t.Fatal("expected explicit proxy URL")
|
||||
return
|
||||
}
|
||||
if proxyURL.Host != "config-proxy:3128" {
|
||||
t.Fatalf("unexpected proxy host: %q", proxyURL.Host)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
const countModels = `-- name: CountModels :one
|
||||
SELECT COUNT(*) FROM models
|
||||
WHERE type != 'speech'
|
||||
`
|
||||
|
||||
func (q *Queries) CountModels(ctx context.Context) (int64, error) {
|
||||
@@ -35,6 +36,7 @@ func (q *Queries) CountModelsByType(ctx context.Context, type_ string) (int64, e
|
||||
|
||||
const countProviders = `-- name: CountProviders :one
|
||||
SELECT COUNT(*) FROM providers
|
||||
WHERE client_type NOT IN ('edge-speech')
|
||||
`
|
||||
|
||||
func (q *Queries) CountProviders(ctx context.Context) (int64, error) {
|
||||
@@ -351,6 +353,7 @@ SELECT m.id, m.model_id, m.name, m.provider_id, m.type, m.config, m.created_at,
|
||||
FROM models m
|
||||
JOIN providers p ON m.provider_id = p.id
|
||||
WHERE p.enable = true
|
||||
AND m.type != 'speech'
|
||||
ORDER BY m.created_at DESC
|
||||
`
|
||||
|
||||
@@ -495,6 +498,7 @@ func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pg
|
||||
|
||||
const listModels = `-- name: ListModels :many
|
||||
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
|
||||
WHERE type != 'speech'
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
@@ -602,6 +606,7 @@ func (q *Queries) ListModelsByProviderClientType(ctx context.Context, clientType
|
||||
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
||||
SELECT id, model_id, name, provider_id, type, config, created_at, updated_at FROM models
|
||||
WHERE provider_id = $1
|
||||
AND type != 'speech'
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
|
||||
@@ -87,7 +87,11 @@ func (h *ModelsHandler) List(c echo.Context) error {
|
||||
case modelType != "":
|
||||
resp, err = h.service.ListEnabledByType(c.Request().Context(), models.ModelType(modelType))
|
||||
case clientType != "":
|
||||
resp, err = h.service.ListEnabledByProviderClientType(c.Request().Context(), models.ClientType(clientType))
|
||||
ct := models.ClientType(clientType)
|
||||
if !models.IsLLMClientType(ct) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid client type for LLM models endpoint")
|
||||
}
|
||||
resp, err = h.service.ListEnabledByProviderClientType(c.Request().Context(), ct)
|
||||
default:
|
||||
resp, err = h.service.ListEnabled(c.Request().Context())
|
||||
}
|
||||
|
||||
@@ -313,6 +313,14 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
|
||||
provider, err := h.service.Get(ctx, id)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("provider not found: %v", err))
|
||||
}
|
||||
if !models.IsLLMClientType(models.ClientType(provider.ClientType)) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "import models is not supported for speech providers")
|
||||
}
|
||||
|
||||
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
|
||||
|
||||
@@ -127,6 +127,7 @@ func TestSendSameConversationWithAttachmentsUsesLocalResult(t *testing.T) {
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if !result.Local {
|
||||
t.Fatal("expected local result for same-conversation send")
|
||||
|
||||
@@ -437,6 +437,22 @@ func IsValidClientType(clientType ClientType) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// IsLLMClientType returns true if the client type belongs to the LLM domain
|
||||
// (chat/embedding), excluding speech-only types like edge-speech.
|
||||
func IsLLMClientType(clientType ClientType) bool {
|
||||
switch clientType {
|
||||
case ClientTypeOpenAIResponses,
|
||||
ClientTypeOpenAICompletions,
|
||||
ClientTypeAnthropicMessages,
|
||||
ClientTypeGoogleGenerativeAI,
|
||||
ClientTypeOpenAICodex,
|
||||
ClientTypeGitHubCopilot:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SelectMemoryModel selects a chat model for memory operations.
|
||||
// It only considers models from enabled providers.
|
||||
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.Provider, error) {
|
||||
|
||||
Reference in New Issue
Block a user