mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: auto-create search/tts providers at startup with enable toggle
- Add `enable` column (default false) to search_providers and tts_providers tables - Auto-create default entries for all provider types on startup (disabled by default) - Add enable/disable Switch toggle in frontend for both search and TTS providers - Show green status dot in sidebar for enabled providers, sort enabled first - Filter bot settings dropdowns to only show enabled providers
This commit is contained in:
@@ -431,6 +431,7 @@ type SearchProvider struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
Enable bool `json:"enable"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
@@ -470,6 +471,7 @@ type TtsProvider struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
Enable bool `json:"enable"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -12,29 +12,37 @@ import (
|
||||
)
|
||||
|
||||
const createSearchProvider = `-- name: CreateSearchProvider :one
|
||||
INSERT INTO search_providers (name, provider, config)
|
||||
INSERT INTO search_providers (name, provider, config, enable)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3
|
||||
$3,
|
||||
$4
|
||||
)
|
||||
RETURNING id, name, provider, config, created_at, updated_at
|
||||
RETURNING id, name, provider, config, enable, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateSearchProviderParams struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
Enable bool `json:"enable"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSearchProvider(ctx context.Context, arg CreateSearchProviderParams) (SearchProvider, error) {
|
||||
row := q.db.QueryRow(ctx, createSearchProvider, arg.Name, arg.Provider, arg.Config)
|
||||
row := q.db.QueryRow(ctx, createSearchProvider,
|
||||
arg.Name,
|
||||
arg.Provider,
|
||||
arg.Config,
|
||||
arg.Enable,
|
||||
)
|
||||
var i SearchProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -51,7 +59,7 @@ func (q *Queries) DeleteSearchProvider(ctx context.Context, id pgtype.UUID) erro
|
||||
}
|
||||
|
||||
const getSearchProviderByID = `-- name: GetSearchProviderByID :one
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM search_providers WHERE id = $1
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetSearchProviderByID(ctx context.Context, id pgtype.UUID) (SearchProvider, error) {
|
||||
@@ -62,6 +70,7 @@ func (q *Queries) GetSearchProviderByID(ctx context.Context, id pgtype.UUID) (Se
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -69,7 +78,7 @@ func (q *Queries) GetSearchProviderByID(ctx context.Context, id pgtype.UUID) (Se
|
||||
}
|
||||
|
||||
const getSearchProviderByName = `-- name: GetSearchProviderByName :one
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM search_providers WHERE name = $1
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers WHERE name = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetSearchProviderByName(ctx context.Context, name string) (SearchProvider, error) {
|
||||
@@ -80,6 +89,7 @@ func (q *Queries) GetSearchProviderByName(ctx context.Context, name string) (Sea
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -87,7 +97,7 @@ func (q *Queries) GetSearchProviderByName(ctx context.Context, name string) (Sea
|
||||
}
|
||||
|
||||
const listSearchProviders = `-- name: ListSearchProviders :many
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM search_providers
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
@@ -105,6 +115,7 @@ func (q *Queries) ListSearchProviders(ctx context.Context) ([]SearchProvider, er
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -119,7 +130,7 @@ func (q *Queries) ListSearchProviders(ctx context.Context) ([]SearchProvider, er
|
||||
}
|
||||
|
||||
const listSearchProvidersByProvider = `-- name: ListSearchProvidersByProvider :many
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM search_providers
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM search_providers
|
||||
WHERE provider = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
@@ -138,6 +149,7 @@ func (q *Queries) ListSearchProvidersByProvider(ctx context.Context, provider st
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -157,15 +169,17 @@ SET
|
||||
name = $1,
|
||||
provider = $2,
|
||||
config = $3,
|
||||
enable = $4,
|
||||
updated_at = now()
|
||||
WHERE id = $4
|
||||
RETURNING id, name, provider, config, created_at, updated_at
|
||||
WHERE id = $5
|
||||
RETURNING id, name, provider, config, enable, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateSearchProviderParams struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
Enable bool `json:"enable"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
}
|
||||
|
||||
@@ -174,6 +188,7 @@ func (q *Queries) UpdateSearchProvider(ctx context.Context, arg UpdateSearchProv
|
||||
arg.Name,
|
||||
arg.Provider,
|
||||
arg.Config,
|
||||
arg.Enable,
|
||||
arg.ID,
|
||||
)
|
||||
var i SearchProvider
|
||||
@@ -182,6 +197,7 @@ func (q *Queries) UpdateSearchProvider(ctx context.Context, arg UpdateSearchProv
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
|
||||
@@ -12,29 +12,37 @@ import (
|
||||
)
|
||||
|
||||
const createTtsProvider = `-- name: CreateTtsProvider :one
|
||||
INSERT INTO tts_providers (name, provider, config)
|
||||
INSERT INTO tts_providers (name, provider, config, enable)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3
|
||||
$3,
|
||||
$4
|
||||
)
|
||||
RETURNING id, name, provider, config, created_at, updated_at
|
||||
RETURNING id, name, provider, config, enable, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateTtsProviderParams struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
Enable bool `json:"enable"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateTtsProvider(ctx context.Context, arg CreateTtsProviderParams) (TtsProvider, error) {
|
||||
row := q.db.QueryRow(ctx, createTtsProvider, arg.Name, arg.Provider, arg.Config)
|
||||
row := q.db.QueryRow(ctx, createTtsProvider,
|
||||
arg.Name,
|
||||
arg.Provider,
|
||||
arg.Config,
|
||||
arg.Enable,
|
||||
)
|
||||
var i TtsProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -51,7 +59,7 @@ func (q *Queries) DeleteTtsProvider(ctx context.Context, id pgtype.UUID) error {
|
||||
}
|
||||
|
||||
const getTtsProviderByID = `-- name: GetTtsProviderByID :one
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers WHERE id = $1
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTtsProviderByID(ctx context.Context, id pgtype.UUID) (TtsProvider, error) {
|
||||
@@ -62,6 +70,7 @@ func (q *Queries) GetTtsProviderByID(ctx context.Context, id pgtype.UUID) (TtsPr
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -69,7 +78,7 @@ func (q *Queries) GetTtsProviderByID(ctx context.Context, id pgtype.UUID) (TtsPr
|
||||
}
|
||||
|
||||
const getTtsProviderByName = `-- name: GetTtsProviderByName :one
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers WHERE name = $1
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers WHERE name = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTtsProviderByName(ctx context.Context, name string) (TtsProvider, error) {
|
||||
@@ -80,6 +89,7 @@ func (q *Queries) GetTtsProviderByName(ctx context.Context, name string) (TtsPro
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -87,7 +97,7 @@ func (q *Queries) GetTtsProviderByName(ctx context.Context, name string) (TtsPro
|
||||
}
|
||||
|
||||
const listTtsProviders = `-- name: ListTtsProviders :many
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
@@ -105,6 +115,7 @@ func (q *Queries) ListTtsProviders(ctx context.Context) ([]TtsProvider, error) {
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -119,7 +130,7 @@ func (q *Queries) ListTtsProviders(ctx context.Context) ([]TtsProvider, error) {
|
||||
}
|
||||
|
||||
const listTtsProvidersByProvider = `-- name: ListTtsProvidersByProvider :many
|
||||
SELECT id, name, provider, config, created_at, updated_at FROM tts_providers
|
||||
SELECT id, name, provider, config, enable, created_at, updated_at FROM tts_providers
|
||||
WHERE provider = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
@@ -138,6 +149,7 @@ func (q *Queries) ListTtsProvidersByProvider(ctx context.Context, provider strin
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -157,15 +169,17 @@ SET
|
||||
name = $1,
|
||||
provider = $2,
|
||||
config = $3,
|
||||
enable = $4,
|
||||
updated_at = now()
|
||||
WHERE id = $4
|
||||
RETURNING id, name, provider, config, created_at, updated_at
|
||||
WHERE id = $5
|
||||
RETURNING id, name, provider, config, enable, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateTtsProviderParams struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config []byte `json:"config"`
|
||||
Enable bool `json:"enable"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
}
|
||||
|
||||
@@ -174,6 +188,7 @@ func (q *Queries) UpdateTtsProvider(ctx context.Context, arg UpdateTtsProviderPa
|
||||
arg.Name,
|
||||
arg.Provider,
|
||||
arg.Config,
|
||||
arg.Enable,
|
||||
arg.ID,
|
||||
)
|
||||
var i TtsProvider
|
||||
@@ -182,6 +197,7 @@ func (q *Queries) UpdateTtsProvider(ctx context.Context, arg UpdateTtsProviderPa
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.Config,
|
||||
&i.Enable,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
|
||||
@@ -405,6 +405,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Provider: string(req.Provider),
|
||||
Config: configJSON,
|
||||
Enable: false,
|
||||
})
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("create search provider: %w", err)
|
||||
@@ -481,11 +482,16 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
}
|
||||
config = configJSON
|
||||
}
|
||||
enable := current.Enable
|
||||
if req.Enable != nil {
|
||||
enable = *req.Enable
|
||||
}
|
||||
updated, err := s.queries.UpdateSearchProvider(ctx, sqlc.UpdateSearchProviderParams{
|
||||
ID: pgID,
|
||||
Name: name,
|
||||
Provider: provider,
|
||||
Config: config,
|
||||
Enable: enable,
|
||||
})
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("update search provider: %w", err)
|
||||
@@ -513,11 +519,63 @@ func (s *Service) toGetResponse(row sqlc.SearchProvider) GetResponse {
|
||||
Name: row.Name,
|
||||
Provider: row.Provider,
|
||||
Config: cfg,
|
||||
Enable: row.Enable,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
UpdatedAt: row.UpdatedAt.Time,
|
||||
}
|
||||
}
|
||||
|
||||
var defaultProviders = []struct {
|
||||
Name ProviderName
|
||||
DisplayName string
|
||||
}{
|
||||
{ProviderBrave, "Brave"},
|
||||
{ProviderBing, "Bing"},
|
||||
{ProviderGoogle, "Google"},
|
||||
{ProviderTavily, "Tavily"},
|
||||
{ProviderSogou, "Sogou"},
|
||||
{ProviderSerper, "Serper"},
|
||||
{ProviderSearXNG, "SearXNG"},
|
||||
{ProviderJina, "Jina"},
|
||||
{ProviderExa, "Exa"},
|
||||
{ProviderBocha, "Bocha"},
|
||||
{ProviderDuckDuckGo, "DuckDuckGo"},
|
||||
{ProviderYandex, "Yandex"},
|
||||
}
|
||||
|
||||
func (s *Service) EnsureDefaults(ctx context.Context) error {
|
||||
rows, err := s.queries.ListSearchProviders(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list search providers: %w", err)
|
||||
}
|
||||
|
||||
existing := make(map[string]struct{}, len(rows))
|
||||
for _, row := range rows {
|
||||
existing[row.Provider] = struct{}{}
|
||||
}
|
||||
|
||||
for _, dp := range defaultProviders {
|
||||
if _, ok := existing[string(dp.Name)]; ok {
|
||||
continue
|
||||
}
|
||||
_, err := s.queries.CreateSearchProvider(ctx, sqlc.CreateSearchProviderParams{
|
||||
Name: dp.DisplayName,
|
||||
Provider: string(dp.Name),
|
||||
Config: []byte("{}"),
|
||||
Enable: false,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to create default search provider",
|
||||
slog.String("provider", string(dp.Name)),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
s.logger.Info("created default search provider", slog.String("provider", string(dp.Name)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isValidProviderName(name ProviderName) bool {
|
||||
switch name {
|
||||
case ProviderBrave, ProviderBing, ProviderGoogle,
|
||||
|
||||
@@ -48,6 +48,7 @@ type UpdateRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Provider *ProviderName `json:"provider,omitempty"`
|
||||
Config map[string]any `json:"config,omitempty"`
|
||||
Enable *bool `json:"enable,omitempty"`
|
||||
}
|
||||
|
||||
type GetResponse struct {
|
||||
@@ -55,6 +56,7 @@ type GetResponse struct {
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Config map[string]any `json:"config,omitempty"`
|
||||
Enable bool `json:"enable"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ func (s *Service) CreateProvider(ctx context.Context, req CreateProviderRequest)
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Provider: string(req.Provider),
|
||||
Config: []byte("{}"),
|
||||
Enable: false,
|
||||
})
|
||||
if err != nil {
|
||||
return ProviderResponse{}, fmt.Errorf("create tts provider: %w", err)
|
||||
@@ -106,11 +107,16 @@ func (s *Service) UpdateProvider(ctx context.Context, id string, req UpdateProvi
|
||||
if req.Name != nil {
|
||||
name = strings.TrimSpace(*req.Name)
|
||||
}
|
||||
enable := current.Enable
|
||||
if req.Enable != nil {
|
||||
enable = *req.Enable
|
||||
}
|
||||
updated, err := s.queries.UpdateTtsProvider(ctx, sqlc.UpdateTtsProviderParams{
|
||||
ID: pgID,
|
||||
Name: name,
|
||||
Provider: current.Provider,
|
||||
Config: current.Config,
|
||||
Enable: enable,
|
||||
})
|
||||
if err != nil {
|
||||
return ProviderResponse{}, fmt.Errorf("update tts provider: %w", err)
|
||||
@@ -126,6 +132,50 @@ func (s *Service) DeleteProvider(ctx context.Context, id string) error {
|
||||
return s.queries.DeleteTtsProvider(ctx, pgID)
|
||||
}
|
||||
|
||||
// EnsureDefaults creates a default TTS provider for each registered adapter
|
||||
// type that does not yet exist in the database.
|
||||
func (s *Service) EnsureDefaults(ctx context.Context) error {
|
||||
rows, err := s.queries.ListTtsProviders(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tts providers: %w", err)
|
||||
}
|
||||
existing := make(map[string]struct{}, len(rows))
|
||||
for _, row := range rows {
|
||||
existing[row.Provider] = struct{}{}
|
||||
}
|
||||
|
||||
for _, meta := range s.registry.ListMeta() {
|
||||
if _, ok := existing[meta.Provider]; ok {
|
||||
continue
|
||||
}
|
||||
adapter, adapterErr := s.registry.Get(TtsType(meta.Provider))
|
||||
if adapterErr != nil {
|
||||
continue
|
||||
}
|
||||
row, createErr := s.queries.CreateTtsProvider(ctx, sqlc.CreateTtsProviderParams{
|
||||
Name: meta.DisplayName,
|
||||
Provider: meta.Provider,
|
||||
Config: []byte("{}"),
|
||||
Enable: false,
|
||||
})
|
||||
if createErr != nil {
|
||||
s.logger.Warn("failed to create default tts provider",
|
||||
slog.String("provider", meta.Provider),
|
||||
slog.Any("error", createErr),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if importErr := s.importModelsForProvider(ctx, row.ID, adapter); importErr != nil {
|
||||
s.logger.Warn("auto-import models failed for default tts provider",
|
||||
slog.String("provider", meta.Provider),
|
||||
slog.Any("error", importErr),
|
||||
)
|
||||
}
|
||||
s.logger.Info("created default tts provider", slog.String("provider", meta.Provider))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model CRUD
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -500,6 +550,7 @@ func (*Service) toProviderResponse(row sqlc.TtsProvider) ProviderResponse {
|
||||
ID: row.ID.String(),
|
||||
Name: row.Name,
|
||||
Provider: row.Provider,
|
||||
Enable: row.Enable,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
UpdatedAt: row.UpdatedAt.Time,
|
||||
}
|
||||
|
||||
@@ -10,13 +10,15 @@ type CreateProviderRequest struct {
|
||||
}
|
||||
|
||||
type UpdateProviderRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Enable *bool `json:"enable,omitempty"`
|
||||
}
|
||||
|
||||
type ProviderResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enable bool `json:"enable"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user