Files
Memoh/internal/audio/service.go
T
Acbox c9dcfe287f Feat/speech support (#392)
* feat: expand speech provider support with new client types and configuration schema

* feat: add icon support for speech providers and update related configurations

* feat: add SVG support for Deepgram and Elevenlabs with Vue components

* feat: except *-speech client type in llm provider

* feat: enhance speech provider functionality with advanced settings and model import capabilities

* chore: remove go.mod replace

* feat: enhance speech provider functionality with advanced settings and model import capabilities

* chore: update go module dependencies

* feat: Ear and Mouth

* fix: separate ear/mouth page

* fix: separate audio domain and restore transcription templates

Move speech and transcription internals into the audio domain, restore template-driven transcription providers, and regenerate Swagger/SDK so the frontend can stop hand-calling /transcription-* APIs.

---------

Co-authored-by: aki <arisu@ieee.org>
2026-04-22 00:09:46 +08:00

770 lines
22 KiB
Go

package audio
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"github.com/jackc/pgx/v5/pgtype"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
// Service implements the audio domain: speech (text-to-speech) and
// transcription providers and models. Provider and model rows are read and
// written through queries; provider clients and model templates come from
// the provider definitions held in registry.
type Service struct {
	// queries is the sqlc query set used for every database access below.
	queries *sqlc.Queries
	// logger is scoped with service=audio by NewService.
	logger *slog.Logger
	// registry maps client types to provider definitions (factories, model
	// templates, listing capabilities).
	registry *Registry
}
// NewService builds an audio Service. The given logger is narrowed with a
// service=audio attribute so records logged through the service carry it.
func NewService(log *slog.Logger, queries *sqlc.Queries, registry *Registry) *Service {
	scoped := log.With(slog.String("service", "audio"))
	svc := &Service{registry: registry}
	svc.queries = queries
	svc.logger = scoped
	return svc
}
// Registry returns the provider-definition registry the service was built
// with.
func (s *Service) Registry() *Registry {
	reg := s.registry
	return reg
}
// ListMeta returns the provider metadata reported by the registry's
// ListMeta. The context is accepted for API uniformity but unused; no
// database query is made.
func (s *Service) ListMeta(_ context.Context) []ProviderMetaResponse {
	reg := s.registry
	return reg.ListMeta()
}

// ListSpeechMeta returns the speech-provider metadata reported by the
// registry's ListSpeechMeta. The context is unused.
func (s *Service) ListSpeechMeta(_ context.Context) []ProviderMetaResponse {
	reg := s.registry
	return reg.ListSpeechMeta()
}

// ListTranscriptionMeta returns the transcription-provider metadata reported
// by the registry's ListTranscriptionMeta. The context is unused.
func (s *Service) ListTranscriptionMeta(_ context.Context) []ProviderMetaResponse {
	reg := s.registry
	return reg.ListTranscriptionMeta()
}
// ListSpeechProviders loads all speech provider rows and converts each one
// with toSpeechProviderResponse, which masks secret config values. The
// result is never nil on success.
func (s *Service) ListSpeechProviders(ctx context.Context) ([]SpeechProviderResponse, error) {
	providers, err := s.queries.ListSpeechProviders(ctx)
	if err != nil {
		return nil, fmt.Errorf("list speech providers: %w", err)
	}
	out := make([]SpeechProviderResponse, len(providers))
	for i := range providers {
		out[i] = toSpeechProviderResponse(providers[i])
	}
	return out, nil
}
// ListTranscriptionProviders loads all transcription provider rows and
// converts each one with toSpeechProviderResponse (the same response shape
// as speech providers, secret config values masked). The result is never
// nil on success.
func (s *Service) ListTranscriptionProviders(ctx context.Context) ([]SpeechProviderResponse, error) {
	providers, err := s.queries.ListTranscriptionProviders(ctx)
	if err != nil {
		return nil, fmt.Errorf("list transcription providers: %w", err)
	}
	out := make([]SpeechProviderResponse, len(providers))
	for i := range providers {
		out[i] = toSpeechProviderResponse(providers[i])
	}
	return out, nil
}
// GetSpeechProvider fetches one provider row by its UUID string and returns
// it in response form (secret config values masked). A malformed id yields
// the db.ParseUUID error unwrapped.
//
// NOTE(review): the row is looked up with GetProviderByID and its client
// type is not checked, so any provider id is accepted here — confirm.
func (s *Service) GetSpeechProvider(ctx context.Context, id string) (SpeechProviderResponse, error) {
	var resp SpeechProviderResponse
	providerUUID, err := db.ParseUUID(id)
	if err != nil {
		return resp, err
	}
	provider, err := s.queries.GetProviderByID(ctx, providerUUID)
	if err != nil {
		return resp, fmt.Errorf("get speech provider: %w", err)
	}
	resp = toSpeechProviderResponse(provider)
	return resp, nil
}
// ListSpeechModels returns every stored speech model, skipping those that
// shouldHideModel reports as template-only for their provider type.
func (s *Service) ListSpeechModels(ctx context.Context) ([]SpeechModelResponse, error) {
	rows, err := s.queries.ListSpeechModels(ctx)
	if err != nil {
		return nil, fmt.Errorf("list speech models: %w", err)
	}
	visible := make([]SpeechModelResponse, 0, len(rows))
	for i := range rows {
		r := rows[i]
		if hidden := s.shouldHideModel(r.ProviderType, models.ModelTypeSpeech, r.ModelID); !hidden {
			visible = append(visible, toSpeechModelFromListRow(r))
		}
	}
	return visible, nil
}
// ListTranscriptionModels returns every stored transcription model,
// skipping those that shouldHideModel reports as template-only for their
// provider type.
func (s *Service) ListTranscriptionModels(ctx context.Context) ([]TranscriptionModelResponse, error) {
	rows, err := s.queries.ListTranscriptionModels(ctx)
	if err != nil {
		return nil, fmt.Errorf("list transcription models: %w", err)
	}
	visible := make([]TranscriptionModelResponse, 0, len(rows))
	for i := range rows {
		r := rows[i]
		if hidden := s.shouldHideModel(r.ProviderType, models.ModelTypeTranscription, r.ModelID); !hidden {
			visible = append(visible, toTranscriptionModelFromListRow(r))
		}
	}
	return visible, nil
}
// ListSpeechModelsByProvider lists the speech models attached to the
// provider identified by providerID. Models that the provider's registry
// definition marks as template-only are omitted (see shouldHideTemplateModel).
//
// Each response's ProviderType is set to the provider's client type. This
// matches the other speech-model endpoints, whose joined ProviderType is what
// this file passes to registry.Get as a client type.
//
// Errors: a malformed providerID returns the db.ParseUUID error unwrapped. A
// client type unknown to the registry returns the registry error unwrapped.
// Query failures are wrapped.
func (s *Service) ListSpeechModelsByProvider(ctx context.Context, providerID string) ([]SpeechModelResponse, error) {
	pgID, err := db.ParseUUID(providerID)
	if err != nil {
		return nil, err
	}
	providerRow, err := s.queries.GetProviderByID(ctx, pgID)
	if err != nil {
		return nil, fmt.Errorf("get speech provider: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(providerRow.ClientType))
	if err != nil {
		return nil, err
	}
	rows, err := s.queries.ListSpeechModelsByProviderID(ctx, pgID)
	if err != nil {
		return nil, fmt.Errorf("list speech models by provider: %w", err)
	}
	// These are plain model rows without a joined provider type. Take it
	// from the provider row loaded above; previously "" was passed, so
	// ProviderType was always empty here.
	providerType := string(providerRow.ClientType)
	items := make([]SpeechModelResponse, 0, len(rows))
	for _, row := range rows {
		if shouldHideTemplateModel(def, models.ModelTypeSpeech, row.ModelID) {
			continue
		}
		items = append(items, toSpeechModelFromModel(row, providerType))
	}
	return items, nil
}
// ListTranscriptionModelsByProvider lists the transcription models attached
// to the provider identified by providerID. Models that the provider's
// registry definition marks as template-only are omitted (see
// shouldHideTemplateModel).
//
// Each response's ProviderType is set to the provider's client type. This
// matches the other transcription-model endpoints, whose joined ProviderType
// is what this file passes to registry.Get as a client type.
//
// Errors: a malformed providerID returns the db.ParseUUID error unwrapped. A
// client type unknown to the registry returns the registry error unwrapped.
// Query failures are wrapped.
func (s *Service) ListTranscriptionModelsByProvider(ctx context.Context, providerID string) ([]TranscriptionModelResponse, error) {
	pgID, err := db.ParseUUID(providerID)
	if err != nil {
		return nil, err
	}
	providerRow, err := s.queries.GetProviderByID(ctx, pgID)
	if err != nil {
		return nil, fmt.Errorf("get speech provider: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(providerRow.ClientType))
	if err != nil {
		return nil, err
	}
	rows, err := s.queries.ListTranscriptionModelsByProviderID(ctx, pgID)
	if err != nil {
		return nil, fmt.Errorf("list transcription models by provider: %w", err)
	}
	// These are plain model rows without a joined provider type. Take it
	// from the provider row loaded above; previously "" was passed, so
	// ProviderType was always empty here.
	providerType := string(providerRow.ClientType)
	items := make([]TranscriptionModelResponse, 0, len(rows))
	for _, row := range rows {
		if shouldHideTemplateModel(def, models.ModelTypeTranscription, row.ModelID) {
			continue
		}
		items = append(items, toTranscriptionModelFromModel(row, providerType))
	}
	return items, nil
}
// GetSpeechModel fetches one speech model, joined with its provider type, by
// its UUID string. A malformed id yields the db.ParseUUID error unwrapped.
func (s *Service) GetSpeechModel(ctx context.Context, id string) (SpeechModelResponse, error) {
	var empty SpeechModelResponse
	modelUUID, err := db.ParseUUID(id)
	if err != nil {
		return empty, err
	}
	joined, err := s.queries.GetSpeechModelWithProvider(ctx, modelUUID)
	if err != nil {
		return empty, fmt.Errorf("get speech model: %w", err)
	}
	return toSpeechModelWithProviderResponse(joined), nil
}
// GetTranscriptionModel fetches one transcription model, joined with its
// provider type, by its UUID string. A malformed id yields the db.ParseUUID
// error unwrapped.
func (s *Service) GetTranscriptionModel(ctx context.Context, id string) (TranscriptionModelResponse, error) {
	var empty TranscriptionModelResponse
	modelUUID, err := db.ParseUUID(id)
	if err != nil {
		return empty, err
	}
	joined, err := s.queries.GetTranscriptionModelWithProvider(ctx, modelUUID)
	if err != nil {
		return empty, fmt.Errorf("get transcription model: %w", err)
	}
	return toTranscriptionModelWithProviderResponse(joined), nil
}
// UpdateSpeechModel updates the name and config of the speech model
// identified by id. ModelID and ProviderID are carried over from the stored
// row, and the type is written as the speech model type.
//
// Name: a nil req.Name keeps the current name. A non-nil empty string stores
// an invalid (NULL) pgtype.Text. Any other value replaces the name.
//
// NOTE(review): Config is always replaced with json.Marshal(req.Config). If
// the request carries no config, this presumably stores JSON null and clears
// the model's config — confirm that is intended.
func (s *Service) UpdateSpeechModel(ctx context.Context, id string, req UpdateSpeechModelRequest) (SpeechModelResponse, error) {
	var zero SpeechModelResponse
	modelUUID, err := db.ParseUUID(id)
	if err != nil {
		return zero, err
	}
	current, err := s.queries.GetSpeechModelWithProvider(ctx, modelUUID)
	if err != nil {
		return zero, fmt.Errorf("get speech model: %w", err)
	}
	encodedConfig, err := json.Marshal(req.Config)
	if err != nil {
		return zero, fmt.Errorf("marshal speech config: %w", err)
	}
	displayName := current.Name
	if newName := req.Name; newName != nil {
		displayName = pgtype.Text{String: *newName, Valid: len(*newName) > 0}
	}
	params := sqlc.UpdateModelParams{
		ID:         modelUUID,
		ModelID:    current.ModelID,
		Name:       displayName,
		ProviderID: current.ProviderID,
		Type:       string(models.ModelTypeSpeech),
		Config:     encodedConfig,
	}
	saved, err := s.queries.UpdateModel(ctx, params)
	if err != nil {
		return zero, fmt.Errorf("update speech model: %w", err)
	}
	return toSpeechModelFromModel(saved, current.ProviderType), nil
}
// UpdateTranscriptionModel updates the name and config of the transcription
// model identified by id. ModelID and ProviderID are carried over from the
// stored row, and the type is written as the transcription model type. It
// accepts the same request shape as UpdateSpeechModel.
//
// Name: a nil req.Name keeps the current name. A non-nil empty string stores
// an invalid (NULL) pgtype.Text. Any other value replaces the name.
//
// NOTE(review): Config is always replaced with json.Marshal(req.Config). If
// the request carries no config, this presumably stores JSON null and clears
// the model's config — confirm that is intended.
func (s *Service) UpdateTranscriptionModel(ctx context.Context, id string, req UpdateSpeechModelRequest) (TranscriptionModelResponse, error) {
	var zero TranscriptionModelResponse
	modelUUID, err := db.ParseUUID(id)
	if err != nil {
		return zero, err
	}
	current, err := s.queries.GetTranscriptionModelWithProvider(ctx, modelUUID)
	if err != nil {
		return zero, fmt.Errorf("get transcription model: %w", err)
	}
	encodedConfig, err := json.Marshal(req.Config)
	if err != nil {
		return zero, fmt.Errorf("marshal transcription config: %w", err)
	}
	displayName := current.Name
	if newName := req.Name; newName != nil {
		displayName = pgtype.Text{String: *newName, Valid: len(*newName) > 0}
	}
	params := sqlc.UpdateModelParams{
		ID:         modelUUID,
		ModelID:    current.ModelID,
		Name:       displayName,
		ProviderID: current.ProviderID,
		Type:       string(models.ModelTypeTranscription),
		Config:     encodedConfig,
	}
	saved, err := s.queries.UpdateModel(ctx, params)
	if err != nil {
		return zero, fmt.Errorf("update transcription model: %w", err)
	}
	return toTranscriptionModelFromModel(saved, current.ProviderType), nil
}
// Synthesize converts text to audio with the speech model identified by
// modelID. overrideCfg takes precedence over the model and provider config
// (see resolveSpeechParams). It returns the audio bytes and the content type
// reported by the SDK.
func (s *Service) Synthesize(ctx context.Context, modelID string, text string, overrideCfg map[string]any) ([]byte, string, error) {
	params, err := s.resolveSpeechParams(ctx, modelID, text, overrideCfg)
	if err != nil {
		return nil, "", err
	}
	speech, err := sdk.GenerateSpeech(ctx,
		sdk.WithSpeechModel(params.model),
		sdk.WithText(text),
		sdk.WithSpeechConfig(params.config),
	)
	if err != nil {
		return nil, "", fmt.Errorf("synthesize: %w", err)
	}
	audio, contentType := speech.Audio, speech.ContentType
	return audio, contentType, nil
}
// StreamToFile synthesizes text with the speech model identified by modelID
// through the SDK's streaming API and writes the audio to w. It returns the
// stream's content type. No config overrides are applied.
//
// Despite the streaming call, the whole audio is collected in memory with
// Bytes() and then handed to w in a single Write.
func (s *Service) StreamToFile(ctx context.Context, modelID string, text string, w io.Writer) (string, error) {
	params, err := s.resolveSpeechParams(ctx, modelID, text, nil)
	if err != nil {
		return "", err
	}
	stream, err := sdk.StreamSpeech(ctx,
		sdk.WithSpeechModel(params.model),
		sdk.WithText(text),
		sdk.WithSpeechConfig(params.config),
	)
	if err != nil {
		return "", fmt.Errorf("stream: %w", err)
	}
	buffered, err := stream.Bytes()
	if err != nil {
		return "", fmt.Errorf("stream: %w", err)
	}
	if _, err := w.Write(buffered); err != nil {
		return "", fmt.Errorf("write chunk: %w", err)
	}
	return stream.ContentType, nil
}
// GetModelCapabilities returns the capabilities that the provider's registry
// template declares for the speech model identified by modelID.
//
// findModelTemplate picks the template from def.Models, using def.DefaultModel
// as a fallback. If the template's Capabilities have no config-schema
// fields, the template-level ConfigSchema is used instead. The result points
// to a copy of the Capabilities value; slices inside it may still be shared
// with the registry.
//
// NOTE(review): findModelTemplate falls back to the default or first
// template when the model ID is not listed, so an unlisted model reports
// another model's capabilities — confirm that is intended.
func (s *Service) GetModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) {
	modelUUID, err := db.ParseUUID(modelID)
	if err != nil {
		return nil, err
	}
	stored, err := s.queries.GetSpeechModelWithProvider(ctx, modelUUID)
	if err != nil {
		return nil, fmt.Errorf("get speech model: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(stored.ProviderType))
	if err != nil {
		return nil, err
	}
	tmpl := findModelTemplate(def.Models, def.DefaultModel, stored.ModelID)
	if tmpl == nil {
		return nil, fmt.Errorf("speech model capabilities not found: %s", stored.ModelID)
	}
	result := tmpl.Capabilities
	if fields := result.ConfigSchema.Fields; len(fields) == 0 {
		result.ConfigSchema = tmpl.ConfigSchema
	}
	return &result, nil
}
// GetSpeechModelCapabilities is the speech-specific name for
// GetModelCapabilities; it forwards both results unchanged.
func (s *Service) GetSpeechModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) {
	caps, err := s.GetModelCapabilities(ctx, modelID)
	return caps, err
}
// GetTranscriptionModelCapabilities returns the capabilities that the
// provider's registry template declares for the transcription model
// identified by modelID.
//
// findModelTemplate picks the template from def.TranscriptionModels, using
// def.DefaultTranscriptionModel as a fallback. If the template's
// Capabilities have no config-schema fields, the template-level ConfigSchema
// is used instead. The result points to a copy of the Capabilities value.
//
// NOTE(review): as with speech, an unlisted model ID falls back to the
// default or first template — confirm that is intended.
func (s *Service) GetTranscriptionModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) {
	modelUUID, err := db.ParseUUID(modelID)
	if err != nil {
		return nil, err
	}
	stored, err := s.queries.GetTranscriptionModelWithProvider(ctx, modelUUID)
	if err != nil {
		return nil, fmt.Errorf("get transcription model: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(stored.ProviderType))
	if err != nil {
		return nil, err
	}
	tmpl := findModelTemplate(def.TranscriptionModels, def.DefaultTranscriptionModel, stored.ModelID)
	if tmpl == nil {
		return nil, fmt.Errorf("transcription model capabilities not found: %s", stored.ModelID)
	}
	result := tmpl.Capabilities
	if fields := result.ConfigSchema.Fields; len(fields) == 0 {
		result.ConfigSchema = tmpl.ConfigSchema
	}
	return &result, nil
}
// FetchRemoteModels builds a speech client for the provider identified by
// providerID (through def.Factory) and asks it for its model list.
//
// Each reported ID is mapped to the matching registry template, or to a bare
// ModelInfo{ID, Name: ID} when no template matches (mergeRemoteModelInfo).
// Nil entries and empty IDs are skipped. An error is returned when the
// definition does not support listing or has no Factory.
func (s *Service) FetchRemoteModels(ctx context.Context, providerID string) ([]ModelInfo, error) {
	providerUUID, err := db.ParseUUID(providerID)
	if err != nil {
		return nil, err
	}
	row, err := s.queries.GetProviderByID(ctx, providerUUID)
	if err != nil {
		return nil, fmt.Errorf("get speech provider: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(row.ClientType))
	if err != nil {
		return nil, err
	}
	if canList := def.SupportsList && def.Factory != nil; !canList {
		return nil, fmt.Errorf("speech provider does not support model discovery: %s", row.ClientType)
	}
	client, err := def.Factory(parseConfig(row.Config))
	if err != nil {
		return nil, fmt.Errorf("build speech provider: %w", err)
	}
	listed, err := client.ListModels(ctx)
	if err != nil {
		return nil, fmt.Errorf("list speech models: %w", err)
	}
	result := make([]ModelInfo, 0, len(listed))
	for _, entry := range listed {
		if entry != nil && entry.ID != "" {
			result = append(result, mergeRemoteModelInfo(entry.ID, def.Models))
		}
	}
	return result, nil
}
// FetchRemoteTranscriptionModels builds a transcription client for the
// provider identified by providerID (through def.TranscriptionFactory) and
// asks it for its model list.
//
// Each reported ID is mapped to the matching transcription template, or to a
// bare ModelInfo{ID, Name: ID} (mergeRemoteModelInfo). Nil entries and empty
// IDs are skipped. An error is returned when the definition does not support
// transcription listing or has no TranscriptionFactory.
func (s *Service) FetchRemoteTranscriptionModels(ctx context.Context, providerID string) ([]ModelInfo, error) {
	providerUUID, err := db.ParseUUID(providerID)
	if err != nil {
		return nil, err
	}
	row, err := s.queries.GetProviderByID(ctx, providerUUID)
	if err != nil {
		return nil, fmt.Errorf("get speech provider: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(row.ClientType))
	if err != nil {
		return nil, err
	}
	if canList := def.SupportsTranscriptionList && def.TranscriptionFactory != nil; !canList {
		return nil, fmt.Errorf("speech provider does not support transcription model discovery: %s", row.ClientType)
	}
	client, err := def.TranscriptionFactory(parseConfig(row.Config))
	if err != nil {
		return nil, fmt.Errorf("build transcription provider: %w", err)
	}
	listed, err := client.ListModels(ctx)
	if err != nil {
		return nil, fmt.Errorf("list transcription models: %w", err)
	}
	result := make([]ModelInfo, 0, len(listed))
	for _, entry := range listed {
		if entry != nil && entry.ID != "" {
			result = append(result, mergeRemoteModelInfo(entry.ID, def.TranscriptionModels))
		}
	}
	return result, nil
}
// Transcribe sends audio (with its filename and content type) to the
// transcription model identified by modelID. overrideCfg takes precedence
// over the model and provider config (see resolveTranscriptionParams).
func (s *Service) Transcribe(ctx context.Context, modelID string, audio []byte, filename string, contentType string, overrideCfg map[string]any) (*sdk.TranscriptionResult, error) {
	params, err := s.resolveTranscriptionParams(ctx, modelID, audio, filename, contentType, overrideCfg)
	if err != nil {
		return nil, err
	}
	transcript, err := sdk.Transcribe(ctx,
		sdk.WithTranscriptionModel(params.model),
		sdk.WithAudio(audio, filename, contentType),
		sdk.WithTranscriptionConfig(params.config),
	)
	if err != nil {
		return nil, fmt.Errorf("transcribe: %w", err)
	}
	return transcript, nil
}
// Inputs resolved from the database and registry before an SDK call: the
// model (carrying a constructed provider client) and the merged config.
type (
	// resolvedSpeechParams feeds sdk.GenerateSpeech / sdk.StreamSpeech.
	resolvedSpeechParams struct {
		model  *sdk.SpeechModel
		config map[string]any
	}

	// resolvedTranscriptionParams feeds sdk.Transcribe.
	resolvedTranscriptionParams struct {
		model  *sdk.TranscriptionModel
		config map[string]any
	}
)
// resolveSpeechParams loads the speech model identified by modelID and its
// provider, builds the provider client through the registry definition's
// Factory, and merges config. Later sources win: provider config, then model
// config, then overrideCfg.
//
// text is part of the signature but is not used.
//
// If the provider's registry definition has no speech Factory, an error is
// returned instead of panicking on a nil function call. FetchRemoteModels
// guards the same field.
func (s *Service) resolveSpeechParams(ctx context.Context, modelID string, text string, overrideCfg map[string]any) (*resolvedSpeechParams, error) {
	_ = text
	pgID, err := db.ParseUUID(modelID)
	if err != nil {
		return nil, err
	}
	modelRow, err := s.queries.GetSpeechModelWithProvider(ctx, pgID)
	if err != nil {
		return nil, fmt.Errorf("get speech model: %w", err)
	}
	providerRow, err := s.queries.GetProviderByID(ctx, modelRow.ProviderID)
	if err != nil {
		return nil, fmt.Errorf("get speech provider: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(providerRow.ClientType))
	if err != nil {
		return nil, err
	}
	if def.Factory == nil {
		return nil, fmt.Errorf("speech provider does not support speech synthesis: %s", providerRow.ClientType)
	}
	// The provider config is parsed separately for the factory and for the
	// merge, so the two never share a map.
	provider, err := def.Factory(parseConfig(providerRow.Config))
	if err != nil {
		return nil, fmt.Errorf("build speech provider: %w", err)
	}
	cfg := mergeConfig(parseConfig(providerRow.Config), parseConfig(modelRow.Config), overrideCfg)
	return &resolvedSpeechParams{
		model:  &sdk.SpeechModel{ID: modelRow.ModelID, Provider: provider},
		config: cfg,
	}, nil
}
// resolveTranscriptionParams loads the transcription model identified by
// modelID and its provider, builds the provider client through the registry
// definition's TranscriptionFactory, and merges config. Later sources win:
// provider config, then model config, then overrideCfg.
//
// audio, filename and contentType are part of the signature but are not used.
//
// If the provider's registry definition has no TranscriptionFactory, an
// error is returned instead of panicking on a nil function call.
// FetchRemoteTranscriptionModels guards the same field.
func (s *Service) resolveTranscriptionParams(ctx context.Context, modelID string, audio []byte, filename string, contentType string, overrideCfg map[string]any) (*resolvedTranscriptionParams, error) {
	_ = audio
	_ = filename
	_ = contentType
	pgID, err := db.ParseUUID(modelID)
	if err != nil {
		return nil, err
	}
	modelRow, err := s.queries.GetTranscriptionModelWithProvider(ctx, pgID)
	if err != nil {
		return nil, fmt.Errorf("get transcription model: %w", err)
	}
	providerRow, err := s.queries.GetProviderByID(ctx, modelRow.ProviderID)
	if err != nil {
		return nil, fmt.Errorf("get speech provider: %w", err)
	}
	def, err := s.registry.Get(models.ClientType(providerRow.ClientType))
	if err != nil {
		return nil, err
	}
	if def.TranscriptionFactory == nil {
		return nil, fmt.Errorf("speech provider does not support transcription: %s", providerRow.ClientType)
	}
	// The provider config is parsed separately for the factory and for the
	// merge, so the two never share a map.
	provider, err := def.TranscriptionFactory(parseConfig(providerRow.Config))
	if err != nil {
		return nil, fmt.Errorf("build transcription provider: %w", err)
	}
	cfg := mergeConfig(parseConfig(providerRow.Config), parseConfig(modelRow.Config), overrideCfg)
	return &resolvedTranscriptionParams{
		model:  &sdk.TranscriptionModel{ID: modelRow.ModelID, Provider: provider},
		config: cfg,
	}, nil
}
// parseConfig decodes a stored JSON object into a map. It never returns nil.
// Empty input, malformed JSON, and JSON that does not decode to a non-nil
// object (such as null or an array) all yield a fresh empty map.
func parseConfig(raw []byte) map[string]any {
	fallback := map[string]any{}
	if len(raw) == 0 {
		return fallback
	}
	var decoded map[string]any
	if err := json.Unmarshal(raw, &decoded); err == nil && decoded != nil {
		return decoded
	}
	return fallback
}
// mergeConfig shallow-merges parts into a new map. Later parts overwrite
// earlier ones key by key, and nil parts are skipped. The inputs are never
// modified, and the result is never nil.
func mergeConfig(parts ...map[string]any) map[string]any {
	merged := map[string]any{}
	for i := 0; i < len(parts); i++ {
		for k, v := range parts[i] {
			merged[k] = v
		}
	}
	return merged
}
// mergeRemoteModelInfo returns the first template in defaults whose ID equals
// modelID, unchanged. When none matches, it returns a minimal ModelInfo whose
// ID and Name are both modelID. Despite the name, no fields are combined.
func mergeRemoteModelInfo(modelID string, defaults []ModelInfo) ModelInfo {
	for i := range defaults {
		if defaults[i].ID != modelID {
			continue
		}
		return defaults[i]
	}
	return ModelInfo{ID: modelID, Name: modelID}
}
// shouldHideModel looks up the provider definition for clientType and
// delegates to shouldHideTemplateModel. If the registry lookup fails, the
// model is not hidden.
func (s *Service) shouldHideModel(clientType string, modelType models.ModelType, modelID string) bool {
	if def, err := s.registry.Get(models.ClientType(clientType)); err == nil {
		return shouldHideTemplateModel(def, modelType, modelID)
	}
	return false
}
// shouldHideTemplateModel reports whether modelID matches a template flagged
// TemplateOnly in def's template list for modelType.
//
// Hiding applies only when def supports listing for that model type
// (SupportsList for speech, SupportsTranscriptionList for transcription).
// Other model types and unlisted IDs are never hidden.
func shouldHideTemplateModel(def ProviderDefinition, modelType models.ModelType, modelID string) bool {
	var (
		listable  bool
		templates []ModelInfo
	)
	switch modelType {
	case models.ModelTypeSpeech:
		listable, templates = def.SupportsList, def.Models
	case models.ModelTypeTranscription:
		listable, templates = def.SupportsTranscriptionList, def.TranscriptionModels
	default:
		return false
	}
	if !listable {
		return false
	}
	for i := range templates {
		if templates[i].ID == modelID {
			return templates[i].TemplateOnly
		}
	}
	return false
}
// findModelTemplate returns a pointer into modelsList for the template
// describing modelID. If there is no exact match, it falls back to the
// template named defaultModel (when non-empty), then to the first template.
// It returns nil only when modelsList is empty.
func findModelTemplate(modelsList []ModelInfo, defaultModel string, modelID string) *ModelInfo {
	if len(modelsList) == 0 {
		return nil
	}
	indexOf := func(id string) int {
		for i := range modelsList {
			if modelsList[i].ID == id {
				return i
			}
		}
		return -1
	}
	if i := indexOf(modelID); i >= 0 {
		return &modelsList[i]
	}
	if defaultModel != "" {
		if i := indexOf(defaultModel); i >= 0 {
			return &modelsList[i]
		}
	}
	return &modelsList[0]
}
// toSpeechProviderResponse converts a provider row into its API form. A NULL
// icon becomes "", and secret config values are masked by
// maskSpeechProviderConfig.
func toSpeechProviderResponse(row sqlc.Provider) SpeechProviderResponse {
	resp := SpeechProviderResponse{
		ID:         row.ID.String(),
		Name:       row.Name,
		ClientType: row.ClientType,
		Enable:     row.Enable,
		Config:     maskSpeechProviderConfig(parseConfig(row.Config)),
		CreatedAt:  row.CreatedAt.Time,
		UpdatedAt:  row.UpdatedAt.Time,
	}
	if row.Icon.Valid {
		resp.Icon = row.Icon.String
	}
	return resp
}
// maskSpeechProviderConfig returns a shallow copy of cfg in which every
// non-empty string under a secret key (see isSpeechSecretKey) is replaced by
// maskSpeechSecret's output. Other entries are copied as-is. The result is
// never nil, and cfg itself is not modified.
func maskSpeechProviderConfig(cfg map[string]any) map[string]any {
	masked := make(map[string]any, len(cfg))
	for key, value := range cfg {
		masked[key] = value
		if !isSpeechSecretKey(key) {
			continue
		}
		if str, ok := value.(string); ok && str != "" {
			masked[key] = maskSpeechSecret(str)
		}
	}
	return masked
}
// isSpeechSecretKey reports whether a provider config key holds a credential
// that must be masked in API responses. Matching is exact and
// case-sensitive.
func isSpeechSecretKey(key string) bool {
	secretKeys := [...]string{"api_key", "access_key", "secret_key", "app_key"}
	for _, candidate := range secretKeys {
		if key == candidate {
			return true
		}
	}
	return false
}
// maskSpeechSecret hides a secret for display. Secrets of at most 8
// characters become "********". Longer ones keep their first and last 4
// characters around "****".
//
// Length and slicing are counted in runes, not bytes. Byte slicing could cut
// a multi-byte UTF-8 character in half and emit invalid UTF-8 (JSON encoding
// would then replace those bytes with U+FFFD). For ASCII secrets the output
// is unchanged.
func maskSpeechSecret(value string) string {
	runes := []rune(value)
	if len(runes) <= 8 {
		return "********"
	}
	return string(runes[:4]) + "****" + string(runes[len(runes)-4:])
}
// toSpeechModelFromListRow converts a ListSpeechModels row into the API
// response. A NULL name becomes "". Config stays nil when the stored config
// is empty or does not decode to a JSON object; decode errors are ignored.
func toSpeechModelFromListRow(row sqlc.ListSpeechModelsRow) SpeechModelResponse {
	resp := SpeechModelResponse{
		ID:           row.ID.String(),
		ModelID:      row.ModelID,
		ProviderID:   row.ProviderID.String(),
		ProviderType: row.ProviderType,
		CreatedAt:    row.CreatedAt.Time,
		UpdatedAt:    row.UpdatedAt.Time,
	}
	if row.Name.Valid {
		resp.Name = row.Name.String
	}
	if len(row.Config) > 0 {
		_ = json.Unmarshal(row.Config, &resp.Config)
	}
	return resp
}
// toSpeechModelFromModel converts a plain model row into a speech model
// response. providerType is supplied by the caller because the row does not
// carry it. A NULL name becomes "". Config stays nil when the stored config
// is empty or does not decode to a JSON object; decode errors are ignored.
func toSpeechModelFromModel(row sqlc.Model, providerType string) SpeechModelResponse {
	resp := SpeechModelResponse{
		ID:           row.ID.String(),
		ModelID:      row.ModelID,
		ProviderID:   row.ProviderID.String(),
		ProviderType: providerType,
		CreatedAt:    row.CreatedAt.Time,
		UpdatedAt:    row.UpdatedAt.Time,
	}
	if row.Name.Valid {
		resp.Name = row.Name.String
	}
	if len(row.Config) > 0 {
		_ = json.Unmarshal(row.Config, &resp.Config)
	}
	return resp
}
// toSpeechModelWithProviderResponse converts a GetSpeechModelWithProvider row
// into the API response. A NULL name becomes "". Config stays nil when the
// stored config is empty or does not decode to a JSON object; decode errors
// are ignored.
func toSpeechModelWithProviderResponse(row sqlc.GetSpeechModelWithProviderRow) SpeechModelResponse {
	resp := SpeechModelResponse{
		ID:           row.ID.String(),
		ModelID:      row.ModelID,
		ProviderID:   row.ProviderID.String(),
		ProviderType: row.ProviderType,
		CreatedAt:    row.CreatedAt.Time,
		UpdatedAt:    row.UpdatedAt.Time,
	}
	if row.Name.Valid {
		resp.Name = row.Name.String
	}
	if len(row.Config) > 0 {
		_ = json.Unmarshal(row.Config, &resp.Config)
	}
	return resp
}
// toTranscriptionModelFromListRow converts a ListTranscriptionModels row into
// the API response. A NULL name becomes "". Config stays nil when the stored
// config is empty or does not decode to a JSON object; decode errors are
// ignored.
func toTranscriptionModelFromListRow(row sqlc.ListTranscriptionModelsRow) TranscriptionModelResponse {
	resp := TranscriptionModelResponse{
		ID:           row.ID.String(),
		ModelID:      row.ModelID,
		ProviderID:   row.ProviderID.String(),
		ProviderType: row.ProviderType,
		CreatedAt:    row.CreatedAt.Time,
		UpdatedAt:    row.UpdatedAt.Time,
	}
	if row.Name.Valid {
		resp.Name = row.Name.String
	}
	if len(row.Config) > 0 {
		_ = json.Unmarshal(row.Config, &resp.Config)
	}
	return resp
}
// toTranscriptionModelFromModel converts a plain model row into a
// transcription model response. providerType is supplied by the caller
// because the row does not carry it. A NULL name becomes "". Config stays nil
// when the stored config is empty or does not decode to a JSON object;
// decode errors are ignored.
func toTranscriptionModelFromModel(row sqlc.Model, providerType string) TranscriptionModelResponse {
	resp := TranscriptionModelResponse{
		ID:           row.ID.String(),
		ModelID:      row.ModelID,
		ProviderID:   row.ProviderID.String(),
		ProviderType: providerType,
		CreatedAt:    row.CreatedAt.Time,
		UpdatedAt:    row.UpdatedAt.Time,
	}
	if row.Name.Valid {
		resp.Name = row.Name.String
	}
	if len(row.Config) > 0 {
		_ = json.Unmarshal(row.Config, &resp.Config)
	}
	return resp
}
// toTranscriptionModelWithProviderResponse converts a
// GetTranscriptionModelWithProvider row into the API response. A NULL name
// becomes "". Config stays nil when the stored config is empty or does not
// decode to a JSON object; decode errors are ignored.
func toTranscriptionModelWithProviderResponse(row sqlc.GetTranscriptionModelWithProviderRow) TranscriptionModelResponse {
	resp := TranscriptionModelResponse{
		ID:           row.ID.String(),
		ModelID:      row.ModelID,
		ProviderID:   row.ProviderID.String(),
		ProviderType: row.ProviderType,
		CreatedAt:    row.CreatedAt.Time,
		UpdatedAt:    row.UpdatedAt.Time,
	}
	if row.Name.Valid {
		resp.Name = row.Name.String
	}
	if len(row.Config) > 0 {
		_ = json.Unmarshal(row.Config, &resp.Config)
	}
	return resp
}