mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
Feat/speech support (#392)
* feat: expand speech provider support with new client types and configuration schema * feat: add icon support for speech providers and update related configurations * feat: add SVG support for Deepgram and Elevenlabs with Vue components * feat: except *-speech client type in llm provider * feat: enhance speech provider functionality with advanced settings and model import capabilities * chore: remove go.mod replace * feat: enhance speech provider functionality with advanced settings and model import capabilities * chore: update go module dependencies * feat: Ear and Mouth * fix: separate ear/mouth page * fix: separate audio domain and restore transcription templates Move speech and transcription internals into the audio domain, restore template-driven transcription providers, and regenerate Swagger/SDK so the frontend can stop hand-calling /transcription-* APIs. --------- Co-authored-by: aki <arisu@ieee.org>
This commit is contained in:
@@ -72,8 +72,7 @@ func TestSpawnAndNotify(t *testing.T) {
|
||||
task := mgr.Get(taskID)
|
||||
if task == nil {
|
||||
t.Fatal("task not found after completion")
|
||||
}
|
||||
if task.Status != TaskCompleted {
|
||||
} else if task.Status != TaskCompleted {
|
||||
t.Errorf("expected task status completed, got %s", task.Status)
|
||||
}
|
||||
}
|
||||
@@ -130,8 +129,7 @@ func TestKillTask(t *testing.T) {
|
||||
task := mgr.Get(taskID)
|
||||
if task == nil {
|
||||
t.Fatal("task not found")
|
||||
}
|
||||
if task.Status != TaskKilled {
|
||||
} else if task.Status != TaskKilled {
|
||||
t.Errorf("expected status killed, got %s", task.Status)
|
||||
}
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ func retryDelay(attempt int, cfg RetryConfig) time.Duration {
|
||||
if backoffIdx > 20 {
|
||||
backoffIdx = 20
|
||||
}
|
||||
delay := cfg.BaseDelay * time.Duration(1<<uint(backoffIdx))
|
||||
delay := cfg.BaseDelay * time.Duration(1<<backoffIdx)
|
||||
delay = min(delay, cfg.MaxDelay)
|
||||
// Add jitter: random value in [0, delay/2), so final delay is in [delay/2, delay).
|
||||
// math/rand is intentional here — cryptographic randomness is not needed for backoff jitter.
|
||||
|
||||
@@ -295,7 +295,7 @@ func (p *ContainerProvider) execRead(ctx context.Context, session SessionContext
|
||||
content += "\n"
|
||||
}
|
||||
|
||||
content = addLineNumbers(content, int32(lineOffset))
|
||||
content = addLineNumbers(content, lineOffset)
|
||||
return map[string]any{"content": content, "total_lines": totalLines}, nil
|
||||
}
|
||||
|
||||
@@ -757,7 +757,7 @@ func truncateStr(s string, n int) string {
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
func addLineNumbers(content string, startLine int32) string {
|
||||
func addLineNumbers(content string, startLine int) string {
|
||||
if content == "" {
|
||||
return content
|
||||
}
|
||||
@@ -765,7 +765,7 @@ func addLineNumbers(content string, startLine int32) string {
|
||||
var out strings.Builder
|
||||
out.Grow(len(content) + len(lines)*8)
|
||||
for i, line := range lines {
|
||||
fmt.Fprintf(&out, "%6d\t%s\n", int(startLine)+i, line)
|
||||
fmt.Fprintf(&out, "%6d\t%s\n", startLine+i, line)
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
//nolint:gosec
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
audiopkg "github.com/memohai/memoh/internal/audio"
|
||||
"github.com/memohai/memoh/internal/media"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
)
|
||||
|
||||
const mediaDataPrefix = "/data/media/"
|
||||
|
||||
type TranscriptionProvider struct {
|
||||
logger *slog.Logger
|
||||
settings *settings.Service
|
||||
audio *audiopkg.Service
|
||||
media *media.Service
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
func NewTranscriptionProvider(log *slog.Logger, settingsSvc *settings.Service, audioSvc *audiopkg.Service, mediaSvc *media.Service) *TranscriptionProvider {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
return &TranscriptionProvider{
|
||||
logger: log.With(slog.String("tool", "transcribe_audio")),
|
||||
settings: settingsSvc,
|
||||
audio: audioSvc,
|
||||
media: mediaSvc,
|
||||
http: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
if _, err := validateURL(req.Context(), req.URL.String()); err != nil {
|
||||
return fmt.Errorf("redirect to non-public address is not allowed: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TranscriptionProvider) Tools(ctx context.Context, session SessionContext) ([]sdk.Tool, error) {
|
||||
if session.IsSubagent || p.settings == nil || p.audio == nil || p.media == nil {
|
||||
return nil, nil
|
||||
}
|
||||
botID := strings.TrimSpace(session.BotID)
|
||||
if botID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
botSettings, err := p.settings.GetBot(ctx, botID)
|
||||
if err != nil || strings.TrimSpace(botSettings.TranscriptionModelID) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
sess := session
|
||||
return []sdk.Tool{{
|
||||
Name: "transcribe_audio",
|
||||
Description: "Transcribe an audio or voice message into text. Use this when the user sent a voice message and you need to understand its contents. Accepts a bot media path such as /data/media/... or a direct URL.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string", "description": "Audio file path from the message context, usually under /data/media/..."},
|
||||
"url": map[string]any{"type": "string", "description": "Direct audio URL when a path is unavailable"},
|
||||
"language": map[string]any{"type": "string", "description": "Optional language hint"},
|
||||
"prompt": map[string]any{"type": "string", "description": "Optional transcription prompt"},
|
||||
"contentType": map[string]any{"type": "string", "description": "Optional MIME type override"},
|
||||
},
|
||||
"required": []string{},
|
||||
},
|
||||
Execute: func(execCtx *sdk.ToolExecContext, input any) (any, error) {
|
||||
return p.execTranscribe(execCtx.Context, sess, inputAsMap(input))
|
||||
},
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (p *TranscriptionProvider) execTranscribe(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
|
||||
botID := strings.TrimSpace(session.BotID)
|
||||
if botID == "" {
|
||||
return nil, errors.New("bot_id is required")
|
||||
}
|
||||
botSettings, err := p.settings.GetBot(ctx, botID)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to load bot settings")
|
||||
}
|
||||
modelID := strings.TrimSpace(botSettings.TranscriptionModelID)
|
||||
if modelID == "" {
|
||||
return nil, errors.New("bot has no transcription model configured")
|
||||
}
|
||||
|
||||
path := FirstStringArg(args, "path", "audio_path", "file_path")
|
||||
rawURL := FirstStringArg(args, "url", "audio_url")
|
||||
if path == "" && rawURL == "" {
|
||||
return nil, errors.New("path or url is required")
|
||||
}
|
||||
|
||||
audio, filename, contentType, err := p.loadAudio(ctx, botID, path, rawURL, FirstStringArg(args, "contentType", "content_type"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
override := map[string]any{}
|
||||
if language := FirstStringArg(args, "language"); language != "" {
|
||||
override["language"] = language
|
||||
}
|
||||
if prompt := FirstStringArg(args, "prompt"); prompt != "" {
|
||||
override["prompt"] = prompt
|
||||
}
|
||||
result, err := p.audio.Transcribe(ctx, modelID, audio, filename, contentType, override)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"ok": true,
|
||||
"text": result.Text,
|
||||
"language": result.Language,
|
||||
"duration_seconds": result.DurationSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *TranscriptionProvider) loadAudio(ctx context.Context, botID, pathValue, rawURL, contentTypeOverride string) ([]byte, string, string, error) {
|
||||
if pathValue != "" {
|
||||
return p.loadAudioFromPath(ctx, botID, pathValue, contentTypeOverride)
|
||||
}
|
||||
u, err := validateURL(ctx, rawURL)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
resp, err := p.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
_ = resp.Body.Close()
|
||||
return nil, "", "", fmt.Errorf("download audio: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
defer func(body io.ReadCloser) {
|
||||
if closeErr := body.Close(); closeErr != nil {
|
||||
p.logger.Warn("failed to close audio response body", slog.Any("error", closeErr))
|
||||
}
|
||||
}(resp.Body)
|
||||
audio, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
contentType := strings.TrimSpace(contentTypeOverride)
|
||||
if contentType == "" {
|
||||
contentType = strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
return audio, filepath.Base(strings.TrimSpace(req.URL.Path)), contentType, nil
|
||||
}
|
||||
|
||||
func (p *TranscriptionProvider) loadAudioFromPath(ctx context.Context, botID, pathValue, contentTypeOverride string) ([]byte, string, string, error) {
|
||||
storageKey := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(pathValue), mediaDataPrefix))
|
||||
if storageKey == "" || storageKey == strings.TrimSpace(pathValue) {
|
||||
return nil, "", "", fmt.Errorf("unsupported media path: %s", pathValue)
|
||||
}
|
||||
asset, err := p.media.GetByStorageKey(ctx, botID, storageKey)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
reader, _, err := p.media.Open(ctx, botID, asset.ContentHash)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
defer func(reader io.ReadCloser) {
|
||||
if closeErr := reader.Close(); closeErr != nil {
|
||||
p.logger.Warn("failed to close media reader", slog.Any("error", closeErr))
|
||||
}
|
||||
}(reader)
|
||||
audio, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
contentType := strings.TrimSpace(contentTypeOverride)
|
||||
if contentType == "" {
|
||||
contentType = strings.TrimSpace(asset.Mime)
|
||||
}
|
||||
return audio, filepath.Base(storageKey), contentType, nil
|
||||
}
|
||||
|
||||
func validateURL(ctx context.Context, rawURL string) (*url.URL, error) {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid url: %w", err)
|
||||
}
|
||||
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return nil, fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
|
||||
hostname := u.Hostname()
|
||||
if hostname == "" {
|
||||
return nil, errors.New("missing hostname in url")
|
||||
}
|
||||
|
||||
resolver := net.Resolver{}
|
||||
ips, err := resolver.LookupIPAddr(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dns lookup failed for %s: %w", hostname, err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no ip addresses found for %s", hostname)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() {
|
||||
return nil, fmt.Errorf("url resolves to a non-public ip address: %s", ip.IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
audiopkg "github.com/memohai/memoh/internal/audio"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
ttspkg "github.com/memohai/memoh/internal/tts"
|
||||
)
|
||||
|
||||
const ttsMaxTextLen = 500
|
||||
@@ -30,26 +30,26 @@ type TTSChannelResolver interface {
|
||||
type TTSProvider struct {
|
||||
logger *slog.Logger
|
||||
settings *settings.Service
|
||||
tts *ttspkg.Service
|
||||
audio *audiopkg.Service
|
||||
sender TTSSender
|
||||
resolver TTSChannelResolver
|
||||
}
|
||||
|
||||
func NewTTSProvider(log *slog.Logger, settingsSvc *settings.Service, ttsSvc *ttspkg.Service, sender TTSSender, resolver TTSChannelResolver) *TTSProvider {
|
||||
func NewTTSProvider(log *slog.Logger, settingsSvc *settings.Service, audioSvc *audiopkg.Service, sender TTSSender, resolver TTSChannelResolver) *TTSProvider {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
return &TTSProvider{
|
||||
logger: log.With(slog.String("tool", "tts")),
|
||||
settings: settingsSvc,
|
||||
tts: ttsSvc,
|
||||
audio: audioSvc,
|
||||
sender: sender,
|
||||
resolver: resolver,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TTSProvider) Tools(ctx context.Context, session SessionContext) ([]sdk.Tool, error) {
|
||||
if session.IsSubagent || p.settings == nil || p.tts == nil || p.sender == nil || p.resolver == nil {
|
||||
if session.IsSubagent || p.settings == nil || p.audio == nil || p.sender == nil || p.resolver == nil {
|
||||
return nil, nil
|
||||
}
|
||||
botID := strings.TrimSpace(session.BotID)
|
||||
@@ -115,7 +115,7 @@ func (p *TTSProvider) execSpeak(ctx context.Context, session SessionContext, arg
|
||||
if botSettings.TtsModelID == "" {
|
||||
return nil, errors.New("bot has no TTS model configured")
|
||||
}
|
||||
audioData, contentType, synthErr := p.tts.Synthesize(ctx, botSettings.TtsModelID, text, nil)
|
||||
audioData, contentType, synthErr := p.audio.Synthesize(ctx, botSettings.TtsModelID, text, nil)
|
||||
if synthErr != nil {
|
||||
return nil, fmt.Errorf("speech synthesis failed: %s", synthErr.Error())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user