feat: auto-create search/tts providers at startup with enable toggle

- Add `enable` column (default false) to search_providers and tts_providers tables
- Auto-create default entries for all provider types on startup (disabled by default)
- Add enable/disable Switch toggle in frontend for both search and TTS providers
- Show green status dot in sidebar for enabled providers, sort enabled first
- Filter bot settings dropdowns to only show enabled providers
This commit is contained in:
Acbox
2026-03-28 23:47:09 +08:00
parent c0057b5c54
commit 90ac222bc9
25 changed files with 420 additions and 63 deletions
@@ -477,10 +477,11 @@ const { mutateAsync: deleteBot, isLoading: deleteLoading } = useMutation({
const models = computed(() => modelData.value ?? [])
const providers = computed(() => providerData.value ?? [])
const searchProviders = computed(() => searchProviderData.value ?? [])
const searchProviders = computed(() => (searchProviderData.value ?? []).filter((p) => p.enable !== false))
const memoryProviders = computed(() => memoryProviderData.value ?? [])
const ttsProviders = computed(() => ttsProviderData.value ?? [])
const ttsModels = computed(() => ttsModelData.value ?? [])
const ttsProviders = computed(() => (ttsProviderData.value ?? []).filter((p) => p.enable !== false))
const enabledTtsProviderIds = computed(() => new Set(ttsProviders.value.map((p) => p.id)))
const ttsModels = computed(() => (ttsModelData.value ?? []).filter((m: Record<string, unknown>) => enabledTtsProviderIds.value.has(m.tts_provider_id as string)))
const browserContexts = computed(() => browserContextData.value ?? [])
// ---- Form ----
@@ -1,19 +1,27 @@
<template>
<div class="p-4">
<section class="flex justify-between items-center">
<div class="flex items-center gap-2">
<FontAwesomeIcon
:icon="['fas', 'volume-high']"
class="size-5"
<section class="flex items-center gap-3">
<FontAwesomeIcon
:icon="['fas', 'volume-high']"
class="size-5"
/>
<div class="min-w-0">
<h2 class="text-sm font-semibold truncate">
{{ curProvider?.name }}
</h2>
<p class="text-xs text-muted-foreground">
{{ currentMeta?.display_name ?? curProvider?.provider }}
</p>
</div>
<div class="ml-auto flex items-center gap-2">
<span class="text-xs text-muted-foreground">
{{ $t('common.enable') }}
</span>
<Switch
:model-value="curProvider?.enable ?? false"
:disabled="!curProvider?.id || enableLoading"
@update:model-value="handleToggleEnable"
/>
<div>
<h2 class="text-sm font-semibold">
{{ curProvider?.name }}
</h2>
<p class="text-xs text-muted-foreground">
{{ currentMeta?.display_name ?? curProvider?.provider }}
</p>
</div>
</div>
</section>
<Separator class="mt-4 mb-6" />
@@ -152,6 +160,7 @@ import {
FormItem,
Separator,
Label,
Switch,
} from '@memohai/ui'
import ConfirmPopover from '@/components/confirm-popover/index.vue'
import LoadingButton from '@/components/loading-button/index.vue'
@@ -170,6 +179,7 @@ import type { TtsProviderResponse, TtsProviderMetaResponse, TtsModelInfo } from
const { t } = useI18n()
const curProvider = inject('curTtsProvider', ref<TtsProviderResponse>())
const curProviderId = computed(() => curProvider.value?.id)
const enableLoading = ref(false)
const apiBase = import.meta.env.VITE_API_URL?.trim() || '/api'
function authHeaders(): Record<string, string> {
@@ -219,6 +229,28 @@ function toggleModel(id: string) {
const queryCache = useQueryCache()
async function handleToggleEnable(value: boolean) {
if (!curProviderId.value || !curProvider.value) return
const prev = curProvider.value.enable ?? false
curProvider.value = { ...curProvider.value, enable: value }
enableLoading.value = true
try {
await putTtsProvidersById({
path: { id: curProviderId.value },
body: { enable: value },
throwOnError: true,
})
queryCache.invalidateQueries({ key: ['tts-providers'] })
} catch {
curProvider.value = { ...curProvider.value, enable: prev }
toast.error(t('common.saveFailed'))
} finally {
enableLoading.value = false
}
}
const schema = toTypedSchema(z.object({
name: z.string().min(1),
}))
+18 -2
View File
@@ -37,7 +37,11 @@ const selectProvider = (name: string) => computed(() => {
const filteredProviders = computed(() => {
if (!Array.isArray(providerData.value)) return []
return providerData.value
return [...providerData.value].sort((a, b) => {
const ae = a.enable !== false ? 1 : 0
const be = b.enable !== false ? 1 : 0
return be - ae
})
})
watch(filteredProviders, (list) => {
@@ -76,7 +80,19 @@ const openStatus = reactive({ addOpen: false })
:model-value="selectProvider(item.name ?? '').value"
@update:model-value="(isSelect) => { if (isSelect) curProvider = item }"
>
{{ item.name }}
<span class="relative shrink-0">
<span class="flex size-7 items-center justify-center rounded-full bg-muted">
<FontAwesomeIcon
:icon="['fas', 'volume-high']"
class="size-3.5 text-muted-foreground"
/>
</span>
<span
v-if="item.enable !== false"
class="absolute -bottom-0.5 -right-0.5 size-2.5 rounded-full bg-green-500 ring-2 ring-background"
/>
</span>
<span class="truncate">{{ item.name }}</span>
</Toggle>
</SidebarMenuButton>
</SidebarMenuItem>
@@ -1,14 +1,22 @@
<template>
<div class="p-4">
<section class="flex justify-between items-center">
<div class="flex items-center gap-2">
<SearchProviderLogo
:provider="curProvider?.provider || ''"
size="lg"
<section class="flex items-center gap-3">
<SearchProviderLogo
:provider="curProvider?.provider || ''"
size="lg"
/>
<h2 class="scroll-m-20 text-sm font-semibold tracking-tight min-w-0 truncate">
{{ curProvider?.name }}
</h2>
<div class="ml-auto flex items-center gap-2">
<span class="text-xs text-muted-foreground">
{{ $t('common.enable') }}
</span>
<Switch
:model-value="curProvider?.enable ?? true"
:disabled="!curProvider?.id || enableLoading"
@update:model-value="handleToggleEnable"
/>
<h2 class="scroll-m-20 text-sm font-semibold tracking-tight">
{{ curProvider?.name }}
</h2>
</div>
</section>
<Separator class="mt-4 mb-6" />
@@ -123,6 +131,7 @@ import {
FormItem,
Separator,
Label,
Switch,
} from '@memohai/ui'
import ConfirmPopover from '@/components/confirm-popover/index.vue'
import LoadingButton from '@/components/loading-button/index.vue'
@@ -146,9 +155,13 @@ import { useForm } from 'vee-validate'
import { useMutation, useQueryCache } from '@pinia/colada'
import { putSearchProvidersById, deleteSearchProvidersById } from '@memohai/sdk'
import type { SearchprovidersGetResponse, SearchprovidersUpdateRequest } from '@memohai/sdk'
import { useI18n } from 'vue-i18n'
import { toast } from 'vue-sonner'
const { t } = useI18n()
const curProvider = inject('curSearchProvider', ref<SearchprovidersGetResponse>())
const curProviderId = computed(() => curProvider.value?.id)
const enableLoading = ref(false)
const queryCache = useQueryCache()
@@ -182,6 +195,28 @@ watch(curProvider, (newVal) => {
}
}, { immediate: true })
async function handleToggleEnable(value: boolean) {
if (!curProviderId.value || !curProvider.value) return
const prev = curProvider.value.enable ?? true
curProvider.value = { ...curProvider.value, enable: value }
enableLoading.value = true
try {
await putSearchProvidersById({
path: { id: curProviderId.value },
body: { enable: value },
throwOnError: true,
})
queryCache.invalidateQueries({ key: ['search-providers'] })
} catch {
curProvider.value = { ...curProvider.value, enable: prev }
toast.error(t('common.saveFailed'))
} finally {
enableLoading.value = false
}
}
// ---- mutations ----
const { mutate: submitUpdate, isLoading: editLoading } = useMutation({
mutation: async (data: SearchprovidersUpdateRequest) => {
+32 -12
View File
@@ -43,15 +43,27 @@ const curFilterProvider = computed(() => {
if (!Array.isArray(providerData.value)) {
return []
}
return providerData.value
return [...providerData.value].sort((a, b) => {
const ae = a.enable !== false ? 1 : 0
const be = b.enable !== false ? 1 : 0
return be - ae
})
})
watch(curFilterProvider, () => {
if (curFilterProvider.value.length > 0) {
curProvider.value = curFilterProvider.value[0]
} else {
watch(curFilterProvider, (providers) => {
if (providers.length === 0) {
curProvider.value = { id: '' }
return
}
const currentId = curProvider.value?.id
if (currentId) {
const stillExists = providers.find((p) => p.id === currentId)
if (stillExists) {
curProvider.value = stillExists
return
}
}
curProvider.value = providers[0]
}, {
immediate: true,
})
@@ -74,7 +86,10 @@ const openStatus = reactive({
class="justify-start py-5! px-4"
>
<Toggle
:class="`py-4 border border-transparent ${curProvider?.name === item.name ? 'border-inherit' : ''}`"
:class="[
'py-4 border',
curProvider?.id === item.id ? 'border-border' : 'border-transparent',
]"
:model-value="selectProvider(item.name as string).value"
@update:model-value="(isSelect) => {
if (isSelect) {
@@ -82,12 +97,17 @@ const openStatus = reactive({
}
}"
>
<SearchProviderLogo
:provider="item.provider || ''"
size="sm"
class="mr-2"
/>
{{ item.name }}
<span class="relative shrink-0">
<SearchProviderLogo
:provider="item.provider || ''"
size="sm"
/>
<span
v-if="item.enable !== false"
class="absolute -bottom-0.5 -right-0.5 size-2.5 rounded-full bg-green-500 ring-2 ring-background"
/>
</span>
<span class="truncate">{{ item.name }}</span>
</Toggle>
</SidebarMenuButton>
</SidebarMenuItem>
+24
View File
@@ -263,6 +263,8 @@ func runServe() {
injectToolProviders,
startRegistrySync,
startMemoryProviderBootstrap,
startSearchProviderBootstrap,
startTtsProviderBootstrap,
startScheduleService,
startHeartbeatService,
startChannelManager,
@@ -875,6 +877,28 @@ func startMemoryProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, mpService *
})
}
func startTtsProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, ttsService *ttspkg.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := ttsService.EnsureDefaults(ctx); err != nil {
log.Warn("failed to ensure default tts providers", slog.Any("error", err))
}
return nil
},
})
}
func startSearchProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, spService *searchproviders.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := spService.EnsureDefaults(ctx); err != nil {
log.Warn("failed to ensure default search providers", slog.Any("error", err))
}
return nil
},
})
}
func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
+24
View File
@@ -190,6 +190,8 @@ func runServe() {
injectToolProviders,
startRegistrySync,
startMemoryProviderBootstrap,
startSearchProviderBootstrap,
startTtsProviderBootstrap,
startScheduleService,
startHeartbeatService,
startChannelManager,
@@ -314,6 +316,28 @@ func startMemoryProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, mpService *
})
}
func startTtsProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, ttsService *ttspkg.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := ttsService.EnsureDefaults(ctx); err != nil {
log.Warn("failed to ensure default tts providers", slog.Any("error", err))
}
return nil
},
})
}
func startSearchProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, spService *searchproviders.Service) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
if err := spService.EnsureDefaults(ctx); err != nil {
log.Warn("failed to ensure default search providers", slog.Any("error", err))
}
return nil
},
})
}
func provideRouteService(log *slog.Logger, queries *dbsqlc.Queries, chatService *conversation.Service) *route.DBService {
return route.NewService(log, queries, chatService)
}
+2
View File
@@ -77,6 +77,7 @@ CREATE TABLE IF NOT EXISTS search_providers (
name TEXT NOT NULL,
provider TEXT NOT NULL,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
enable BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT search_providers_name_unique UNIQUE (name)
@@ -125,6 +126,7 @@ CREATE TABLE IF NOT EXISTS tts_providers (
name TEXT NOT NULL,
provider TEXT NOT NULL,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
enable BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT tts_providers_name_unique UNIQUE (name)
@@ -0,0 +1,4 @@
-- 0049_search_provider_enable (down)
-- Remove the enable column from search_providers.
ALTER TABLE search_providers DROP COLUMN IF EXISTS enable;
@@ -0,0 +1,4 @@
-- 0049_search_provider_enable
-- Add enable column to search_providers table for toggling providers on/off.
ALTER TABLE search_providers ADD COLUMN IF NOT EXISTS enable BOOLEAN NOT NULL DEFAULT false;
@@ -0,0 +1,4 @@
-- 0050_tts_provider_enable (down)
-- Remove the enable column from tts_providers.
ALTER TABLE tts_providers DROP COLUMN IF EXISTS enable;
@@ -0,0 +1,4 @@
-- 0050_tts_provider_enable
-- Add enable column to tts_providers table for toggling providers on/off.
ALTER TABLE tts_providers ADD COLUMN IF NOT EXISTS enable BOOLEAN NOT NULL DEFAULT false;
+4 -2
View File
@@ -1,9 +1,10 @@
-- name: CreateSearchProvider :one
INSERT INTO search_providers (name, provider, config)
INSERT INTO search_providers (name, provider, config, enable)
VALUES (
sqlc.arg(name),
sqlc.arg(provider),
sqlc.arg(config)
sqlc.arg(config),
sqlc.arg(enable)
)
RETURNING *;
@@ -28,6 +29,7 @@ SET
name = sqlc.arg(name),
provider = sqlc.arg(provider),
config = sqlc.arg(config),
enable = sqlc.arg(enable),
updated_at = now()
WHERE id = sqlc.arg(id)
RETURNING *;
+4 -2
View File
@@ -1,9 +1,10 @@
-- name: CreateTtsProvider :one
INSERT INTO tts_providers (name, provider, config)
INSERT INTO tts_providers (name, provider, config, enable)
VALUES (
sqlc.arg(name),
sqlc.arg(provider),
sqlc.arg(config)
sqlc.arg(config),
sqlc.arg(enable)
)
RETURNING *;
@@ -28,6 +29,7 @@ SET
name = sqlc.arg(name),
provider = sqlc.arg(provider),
config = sqlc.arg(config),
enable = sqlc.arg(enable),
updated_at = now()
WHERE id = sqlc.arg(id)
RETURNING *;
+2
View File
@@ -431,6 +431,7 @@ type SearchProvider struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
Enable bool `json:"enable"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
@@ -470,6 +471,7 @@ type TtsProvider struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
Enable bool `json:"enable"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
+26 -10
View File
@@ -12,29 +12,37 @@ import (
)
const createSearchProvider = `-- name: CreateSearchProvider :one
INSERT INTO search_providers (name, provider, config)
INSERT INTO search_providers (name, provider, config, enable)
VALUES (
$1,
$2,
$3
$3,
$4
)
RETURNING id, name, provider, config, created_at, updated_at
RETURNING id, name, provider, config, enable, created_at, updated_at
`
type CreateSearchProviderParams struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
Enable bool `json:"enable"`
}
func (q *Queries) CreateSearchProvider(ctx context.Context, arg CreateSearchProviderParams) (SearchProvider, error) {
row := q.db.QueryRow(ctx, createSearchProvider, arg.Name, arg.Provider, arg.Config)
row := q.db.QueryRow(ctx, createSearchProvider,
arg.Name,
arg.Provider,
arg.Config,
arg.Enable,
)
var i SearchProvider
err := row.Scan(
&i.ID,
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -51,7 +59,7 @@ func (q *Queries) DeleteSearchProvider(ctx context.Context, id pgtype.UUID) erro
}
const getSearchProviderByID = `-- name: GetSearchProviderByID :one
SELECT id, name, provider, config, created_at, updated_at FROM search_providers WHERE id = $1
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers WHERE id = $1
`
func (q *Queries) GetSearchProviderByID(ctx context.Context, id pgtype.UUID) (SearchProvider, error) {
@@ -62,6 +70,7 @@ func (q *Queries) GetSearchProviderByID(ctx context.Context, id pgtype.UUID) (Se
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -69,7 +78,7 @@ func (q *Queries) GetSearchProviderByID(ctx context.Context, id pgtype.UUID) (Se
}
const getSearchProviderByName = `-- name: GetSearchProviderByName :one
SELECT id, name, provider, config, created_at, updated_at FROM search_providers WHERE name = $1
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers WHERE name = $1
`
func (q *Queries) GetSearchProviderByName(ctx context.Context, name string) (SearchProvider, error) {
@@ -80,6 +89,7 @@ func (q *Queries) GetSearchProviderByName(ctx context.Context, name string) (Sea
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -87,7 +97,7 @@ func (q *Queries) GetSearchProviderByName(ctx context.Context, name string) (Sea
}
const listSearchProviders = `-- name: ListSearchProviders :many
SELECT id, name, provider, config, created_at, updated_at FROM search_providers
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers
ORDER BY created_at DESC
`
@@ -105,6 +115,7 @@ func (q *Queries) ListSearchProviders(ctx context.Context) ([]SearchProvider, er
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -119,7 +130,7 @@ func (q *Queries) ListSearchProviders(ctx context.Context) ([]SearchProvider, er
}
const listSearchProvidersByProvider = `-- name: ListSearchProvidersByProvider :many
SELECT id, name, provider, config, created_at, updated_at FROM search_providers
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers
WHERE provider = $1
ORDER BY created_at DESC
`
@@ -138,6 +149,7 @@ func (q *Queries) ListSearchProvidersByProvider(ctx context.Context, provider st
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -157,15 +169,17 @@ SET
name = $1,
provider = $2,
config = $3,
enable = $4,
updated_at = now()
WHERE id = $4
RETURNING id, name, provider, config, created_at, updated_at
WHERE id = $5
RETURNING id, name, provider, config, enable, created_at, updated_at
`
type UpdateSearchProviderParams struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
Enable bool `json:"enable"`
ID pgtype.UUID `json:"id"`
}
@@ -174,6 +188,7 @@ func (q *Queries) UpdateSearchProvider(ctx context.Context, arg UpdateSearchProv
arg.Name,
arg.Provider,
arg.Config,
arg.Enable,
arg.ID,
)
var i SearchProvider
@@ -182,6 +197,7 @@ func (q *Queries) UpdateSearchProvider(ctx context.Context, arg UpdateSearchProv
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
+26 -10
View File
@@ -12,29 +12,37 @@ import (
)
const createTtsProvider = `-- name: CreateTtsProvider :one
INSERT INTO tts_providers (name, provider, config)
INSERT INTO tts_providers (name, provider, config, enable)
VALUES (
$1,
$2,
$3
$3,
$4
)
RETURNING id, name, provider, config, created_at, updated_at
RETURNING id, name, provider, config, enable, created_at, updated_at
`
type CreateTtsProviderParams struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
Enable bool `json:"enable"`
}
func (q *Queries) CreateTtsProvider(ctx context.Context, arg CreateTtsProviderParams) (TtsProvider, error) {
row := q.db.QueryRow(ctx, createTtsProvider, arg.Name, arg.Provider, arg.Config)
row := q.db.QueryRow(ctx, createTtsProvider,
arg.Name,
arg.Provider,
arg.Config,
arg.Enable,
)
var i TtsProvider
err := row.Scan(
&i.ID,
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -51,7 +59,7 @@ func (q *Queries) DeleteTtsProvider(ctx context.Context, id pgtype.UUID) error {
}
const getTtsProviderByID = `-- name: GetTtsProviderByID :one
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers WHERE id = $1
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers WHERE id = $1
`
func (q *Queries) GetTtsProviderByID(ctx context.Context, id pgtype.UUID) (TtsProvider, error) {
@@ -62,6 +70,7 @@ func (q *Queries) GetTtsProviderByID(ctx context.Context, id pgtype.UUID) (TtsPr
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -69,7 +78,7 @@ func (q *Queries) GetTtsProviderByID(ctx context.Context, id pgtype.UUID) (TtsPr
}
const getTtsProviderByName = `-- name: GetTtsProviderByName :one
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers WHERE name = $1
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers WHERE name = $1
`
func (q *Queries) GetTtsProviderByName(ctx context.Context, name string) (TtsProvider, error) {
@@ -80,6 +89,7 @@ func (q *Queries) GetTtsProviderByName(ctx context.Context, name string) (TtsPro
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -87,7 +97,7 @@ func (q *Queries) GetTtsProviderByName(ctx context.Context, name string) (TtsPro
}
const listTtsProviders = `-- name: ListTtsProviders :many
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers
ORDER BY created_at DESC
`
@@ -105,6 +115,7 @@ func (q *Queries) ListTtsProviders(ctx context.Context) ([]TtsProvider, error) {
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -119,7 +130,7 @@ func (q *Queries) ListTtsProviders(ctx context.Context) ([]TtsProvider, error) {
}
const listTtsProvidersByProvider = `-- name: ListTtsProvidersByProvider :many
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers
WHERE provider = $1
ORDER BY created_at DESC
`
@@ -138,6 +149,7 @@ func (q *Queries) ListTtsProvidersByProvider(ctx context.Context, provider strin
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -157,15 +169,17 @@ SET
name = $1,
provider = $2,
config = $3,
enable = $4,
updated_at = now()
WHERE id = $4
RETURNING id, name, provider, config, created_at, updated_at
WHERE id = $5
RETURNING id, name, provider, config, enable, created_at, updated_at
`
type UpdateTtsProviderParams struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
Enable bool `json:"enable"`
ID pgtype.UUID `json:"id"`
}
@@ -174,6 +188,7 @@ func (q *Queries) UpdateTtsProvider(ctx context.Context, arg UpdateTtsProviderPa
arg.Name,
arg.Provider,
arg.Config,
arg.Enable,
arg.ID,
)
var i TtsProvider
@@ -182,6 +197,7 @@ func (q *Queries) UpdateTtsProvider(ctx context.Context, arg UpdateTtsProviderPa
&i.Name,
&i.Provider,
&i.Config,
&i.Enable,
&i.CreatedAt,
&i.UpdatedAt,
)
+58
View File
@@ -405,6 +405,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
Name: strings.TrimSpace(req.Name),
Provider: string(req.Provider),
Config: configJSON,
Enable: false,
})
if err != nil {
return GetResponse{}, fmt.Errorf("create search provider: %w", err)
@@ -481,11 +482,16 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
}
config = configJSON
}
enable := current.Enable
if req.Enable != nil {
enable = *req.Enable
}
updated, err := s.queries.UpdateSearchProvider(ctx, sqlc.UpdateSearchProviderParams{
ID: pgID,
Name: name,
Provider: provider,
Config: config,
Enable: enable,
})
if err != nil {
return GetResponse{}, fmt.Errorf("update search provider: %w", err)
@@ -513,11 +519,63 @@ func (s *Service) toGetResponse(row sqlc.SearchProvider) GetResponse {
Name: row.Name,
Provider: row.Provider,
Config: cfg,
Enable: row.Enable,
CreatedAt: row.CreatedAt.Time,
UpdatedAt: row.UpdatedAt.Time,
}
}
var defaultProviders = []struct {
Name ProviderName
DisplayName string
}{
{ProviderBrave, "Brave"},
{ProviderBing, "Bing"},
{ProviderGoogle, "Google"},
{ProviderTavily, "Tavily"},
{ProviderSogou, "Sogou"},
{ProviderSerper, "Serper"},
{ProviderSearXNG, "SearXNG"},
{ProviderJina, "Jina"},
{ProviderExa, "Exa"},
{ProviderBocha, "Bocha"},
{ProviderDuckDuckGo, "DuckDuckGo"},
{ProviderYandex, "Yandex"},
}
func (s *Service) EnsureDefaults(ctx context.Context) error {
rows, err := s.queries.ListSearchProviders(ctx)
if err != nil {
return fmt.Errorf("list search providers: %w", err)
}
existing := make(map[string]struct{}, len(rows))
for _, row := range rows {
existing[row.Provider] = struct{}{}
}
for _, dp := range defaultProviders {
if _, ok := existing[string(dp.Name)]; ok {
continue
}
_, err := s.queries.CreateSearchProvider(ctx, sqlc.CreateSearchProviderParams{
Name: dp.DisplayName,
Provider: string(dp.Name),
Config: []byte("{}"),
Enable: false,
})
if err != nil {
s.logger.Warn("failed to create default search provider",
slog.String("provider", string(dp.Name)),
slog.Any("error", err),
)
continue
}
s.logger.Info("created default search provider", slog.String("provider", string(dp.Name)))
}
return nil
}
func isValidProviderName(name ProviderName) bool {
switch name {
case ProviderBrave, ProviderBing, ProviderGoogle,
+2
View File
@@ -48,6 +48,7 @@ type UpdateRequest struct {
Name *string `json:"name,omitempty"`
Provider *ProviderName `json:"provider,omitempty"`
Config map[string]any `json:"config,omitempty"`
Enable *bool `json:"enable,omitempty"`
}
type GetResponse struct {
@@ -55,6 +56,7 @@ type GetResponse struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config map[string]any `json:"config,omitempty"`
Enable bool `json:"enable"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
+51
View File
@@ -48,6 +48,7 @@ func (s *Service) CreateProvider(ctx context.Context, req CreateProviderRequest)
Name: strings.TrimSpace(req.Name),
Provider: string(req.Provider),
Config: []byte("{}"),
Enable: false,
})
if err != nil {
return ProviderResponse{}, fmt.Errorf("create tts provider: %w", err)
@@ -106,11 +107,16 @@ func (s *Service) UpdateProvider(ctx context.Context, id string, req UpdateProvi
if req.Name != nil {
name = strings.TrimSpace(*req.Name)
}
enable := current.Enable
if req.Enable != nil {
enable = *req.Enable
}
updated, err := s.queries.UpdateTtsProvider(ctx, sqlc.UpdateTtsProviderParams{
ID: pgID,
Name: name,
Provider: current.Provider,
Config: current.Config,
Enable: enable,
})
if err != nil {
return ProviderResponse{}, fmt.Errorf("update tts provider: %w", err)
@@ -126,6 +132,50 @@ func (s *Service) DeleteProvider(ctx context.Context, id string) error {
return s.queries.DeleteTtsProvider(ctx, pgID)
}
// EnsureDefaults creates a default TTS provider for each registered adapter
// type that does not yet exist in the database.
func (s *Service) EnsureDefaults(ctx context.Context) error {
rows, err := s.queries.ListTtsProviders(ctx)
if err != nil {
return fmt.Errorf("list tts providers: %w", err)
}
existing := make(map[string]struct{}, len(rows))
for _, row := range rows {
existing[row.Provider] = struct{}{}
}
for _, meta := range s.registry.ListMeta() {
if _, ok := existing[meta.Provider]; ok {
continue
}
adapter, adapterErr := s.registry.Get(TtsType(meta.Provider))
if adapterErr != nil {
continue
}
row, createErr := s.queries.CreateTtsProvider(ctx, sqlc.CreateTtsProviderParams{
Name: meta.DisplayName,
Provider: meta.Provider,
Config: []byte("{}"),
Enable: false,
})
if createErr != nil {
s.logger.Warn("failed to create default tts provider",
slog.String("provider", meta.Provider),
slog.Any("error", createErr),
)
continue
}
if importErr := s.importModelsForProvider(ctx, row.ID, adapter); importErr != nil {
s.logger.Warn("auto-import models failed for default tts provider",
slog.String("provider", meta.Provider),
slog.Any("error", importErr),
)
}
s.logger.Info("created default tts provider", slog.String("provider", meta.Provider))
}
return nil
}
// ---------------------------------------------------------------------------
// Model CRUD
// ---------------------------------------------------------------------------
@@ -500,6 +550,7 @@ func (*Service) toProviderResponse(row sqlc.TtsProvider) ProviderResponse {
ID: row.ID.String(),
Name: row.Name,
Provider: row.Provider,
Enable: row.Enable,
CreatedAt: row.CreatedAt.Time,
UpdatedAt: row.UpdatedAt.Time,
}
+3 -1
View File
@@ -10,13 +10,15 @@ type CreateProviderRequest struct {
}
type UpdateProviderRequest struct {
Name *string `json:"name,omitempty"`
Name *string `json:"name,omitempty"`
Enable *bool `json:"enable,omitempty"`
}
type ProviderResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Enable bool `json:"enable"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
+4
View File
@@ -1378,6 +1378,7 @@ export type SearchprovidersGetResponse = {
[key: string]: unknown;
};
created_at?: string;
enable?: boolean;
id?: string;
name?: string;
provider?: string;
@@ -1411,6 +1412,7 @@ export type SearchprovidersUpdateRequest = {
config?: {
[key: string]: unknown;
};
enable?: boolean;
name?: string;
provider?: SearchprovidersProviderName;
};
@@ -1534,6 +1536,7 @@ export type TtsProviderMetaResponse = {
export type TtsProviderResponse = {
created_at?: string;
enable?: boolean;
id?: string;
name?: string;
provider?: string;
@@ -1555,6 +1558,7 @@ export type TtsUpdateModelRequest = {
};
export type TtsUpdateProviderRequest = {
enable?: boolean;
name?: string;
};
+12
View File
@@ -12206,6 +12206,9 @@ const docTemplate = `{
"created_at": {
"type": "string"
},
"enable": {
"type": "boolean"
},
"id": {
"type": "string"
},
@@ -12307,6 +12310,9 @@ const docTemplate = `{
"type": "object",
"additionalProperties": {}
},
"enable": {
"type": "boolean"
},
"name": {
"type": "string"
},
@@ -12624,6 +12630,9 @@ const docTemplate = `{
"created_at": {
"type": "string"
},
"enable": {
"type": "boolean"
},
"id": {
"type": "string"
},
@@ -12665,6 +12674,9 @@ const docTemplate = `{
"tts.UpdateProviderRequest": {
"type": "object",
"properties": {
"enable": {
"type": "boolean"
},
"name": {
"type": "string"
}
+12
View File
@@ -12197,6 +12197,9 @@
"created_at": {
"type": "string"
},
"enable": {
"type": "boolean"
},
"id": {
"type": "string"
},
@@ -12298,6 +12301,9 @@
"type": "object",
"additionalProperties": {}
},
"enable": {
"type": "boolean"
},
"name": {
"type": "string"
},
@@ -12615,6 +12621,9 @@
"created_at": {
"type": "string"
},
"enable": {
"type": "boolean"
},
"id": {
"type": "string"
},
@@ -12656,6 +12665,9 @@
"tts.UpdateProviderRequest": {
"type": "object",
"properties": {
"enable": {
"type": "boolean"
},
"name": {
"type": "string"
}
+8
View File
@@ -2264,6 +2264,8 @@ definitions:
type: object
created_at:
type: string
enable:
type: boolean
id:
type: string
name:
@@ -2338,6 +2340,8 @@ definitions:
config:
additionalProperties: {}
type: object
enable:
type: boolean
name:
type: string
provider:
@@ -2547,6 +2551,8 @@ definitions:
properties:
created_at:
type: string
enable:
type: boolean
id:
type: string
name:
@@ -2574,6 +2580,8 @@ definitions:
type: object
tts.UpdateProviderRequest:
properties:
enable:
type: boolean
name:
type: string
type: object