feat(tts): introduce TTS system (#195)

This commit is contained in:
Fodesu
2026-03-13 02:49:52 +08:00
committed by GitHub
parent 7904de87bd
commit b46e494d3a
71 changed files with 8959 additions and 159 deletions
+255 -10
View File
@@ -2,10 +2,13 @@ package handlers
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"maps"
"net/http"
"strings"
"time"
@@ -14,24 +17,39 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
attachmentpkg "github.com/memohai/memoh/internal/attachment"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/adapters/local"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/conversation/flow"
"github.com/memohai/memoh/internal/media"
)
// localTtsSynthesizer synthesizes text to speech audio.
//
// Synthesize receives the TTS model ID, the text to speak and an optional
// per-call configuration override (the handler in this file always passes
// nil). It returns the raw audio bytes, a string the handler treats as the
// audio content type, and an error when synthesis fails.
type localTtsSynthesizer interface {
Synthesize(ctx context.Context, modelID string, text string, overrideCfg map[string]any) ([]byte, string, error)
}
// localTtsModelResolver resolves TTS model IDs for bots.
//
// ResolveTtsModelID returns the model ID to use for the given bot. The
// handler treats both an error and a blank ID as "no TTS available" and
// skips synthesis in that case.
type localTtsModelResolver interface {
ResolveTtsModelID(ctx context.Context, botID string) (string, error)
}
// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history.
//
// mediaService, ttsService and ttsModelResolver are optional and are wired in
// after construction through SetMediaService and SetTtsService; the WebSocket
// event processing checks each of them for nil before use.
type LocalChannelHandler struct {
	channelType    channel.ChannelType
	channelManager *channel.Manager
	channelStore   *channel.Store
	chatService    *conversation.Service
	routeHub       *local.RouteHub
	botService     *bots.Service
	accountService *accounts.Service
	resolver       *flow.Resolver
	// mediaService persists attachments carried by WebSocket events.
	mediaService *media.Service
	// ttsService turns speech_delta text into audio.
	ttsService localTtsSynthesizer
	// ttsModelResolver looks up the TTS model ID configured for a bot.
	ttsModelResolver localTtsModelResolver
	logger           *slog.Logger
}
// NewLocalChannelHandler creates a local channel handler.
@@ -53,6 +71,17 @@ func (h *LocalChannelHandler) SetResolver(resolver *flow.Resolver) {
h.resolver = resolver
}
// SetMediaService sets the media service for WebSocket attachment ingestion.
// When it is left unset (nil), attachment_delta events are forwarded
// unchanged and synthesized speech falls back to an inline data URL.
func (h *LocalChannelHandler) SetMediaService(svc *media.Service) {
h.mediaService = svc
}
// SetTtsService configures TTS synthesis for handling speech_delta events.
// Both the synthesizer and the model resolver must be non-nil for speech to
// be produced; if either is nil, speech_delta events are dropped with a
// warning.
func (h *LocalChannelHandler) SetTtsService(synth localTtsSynthesizer, resolver localTtsModelResolver) {
h.ttsService = synth
h.ttsModelResolver = resolver
}
// Register registers the local channel routes.
func (h *LocalChannelHandler) Register(e *echo.Echo) {
prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String())
@@ -391,7 +420,9 @@ func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error {
go func() {
for event := range eventCh {
writer.Send(event)
for _, processed := range h.processWSEvent(streamCtx, botID, event) {
writer.Send(processed)
}
}
}()
@@ -424,3 +455,217 @@ func (*LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, er
// authorizeBotAccess delegates to AuthorizeBotAccess using this handler's bot
// and account services, with an access policy that has AllowPublicMember
// enabled. It returns the authorized bot or the error from that check.
func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
	policy := bots.AccessPolicy{AllowPublicMember: true}
	return AuthorizeBotAccess(
		ctx,
		h.botService,
		h.accountService,
		channelIdentityID,
		botID,
		policy,
	)
}
// ---------------------------------------------------------------------------
// WebSocket event processing — attachment ingestion + TTS extraction
// ---------------------------------------------------------------------------
// wsEventEnvelope is the minimal shape decoded from a raw WebSocket event in
// order to route it. Only Type is read by processWSEvent; ToolName and Result
// are decoded but not used in the code visible here.
type wsEventEnvelope struct {
// Type is the event discriminator, e.g. "attachment_delta" or "speech_delta".
Type string `json:"type"`
ToolName string `json:"toolName,omitempty"`
// Result is kept as raw JSON so it is not decoded eagerly.
Result json.RawMessage `json:"result,omitempty"`
}
// processWSEvent inspects a raw WebSocket event and returns the events that
// should be written to the client in its place.
//
// attachment_delta events are routed through attachment ingestion and
// speech_delta events through TTS synthesis. Every other event, including
// one whose JSON cannot be decoded, is returned untouched as a single-element
// slice.
func (h *LocalChannelHandler) processWSEvent(ctx context.Context, botID string, event json.RawMessage) []json.RawMessage {
	var head wsEventEnvelope
	if decodeErr := json.Unmarshal(event, &head); decodeErr != nil {
		// Not something we can classify; forward it as-is.
		return []json.RawMessage{event}
	}

	botAttr := slog.String("bot_id", botID)
	h.logger.Debug("ws event", slog.String("type", head.Type), botAttr)

	if head.Type == "attachment_delta" {
		h.logger.Info("ws processing attachment_delta", botAttr)
		return h.wsIngestAttachments(ctx, botID, event)
	}
	if head.Type == "speech_delta" {
		h.logger.Info("ws processing speech_delta", botAttr)
		return h.wsSynthesizeSpeech(ctx, botID, event)
	}
	return []json.RawMessage{event}
}
// wsIngestAttachments persists attachment data (container paths / data URLs)
// and rewrites them with content_hash so the web frontend can resolve them.
//
// The original event is returned unchanged when no media service is
// configured, when the payload cannot be decoded, when it carries no
// attachments, when no attachment needed (or survived) ingestion, or when the
// rewritten event cannot be re-encoded. Otherwise a single rewritten event is
// returned.
func (h *LocalChannelHandler) wsIngestAttachments(ctx context.Context, botID string, original json.RawMessage) []json.RawMessage {
	if h.mediaService == nil {
		return []json.RawMessage{original}
	}
	var event map[string]any
	if err := json.Unmarshal(original, &event); err != nil {
		return []json.RawMessage{original}
	}
	rawItems, _ := event["attachments"].([]any)
	if len(rawItems) == 0 {
		return []json.RawMessage{original}
	}

	// Count only the items that were actually rewritten so the log below
	// reports real work rather than the size of the attachment list.
	ingestedCount := 0
	for i, raw := range rawItems {
		item, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		// Already resolved to stored media; nothing to do.
		if ch, _ := item["content_hash"].(string); strings.TrimSpace(ch) != "" {
			continue
		}
		// Prefer "url", fall back to "path".
		rawURL, _ := item["url"].(string)
		if rawURL == "" {
			rawURL, _ = item["path"].(string)
		}
		if rawURL = strings.TrimSpace(rawURL); rawURL == "" {
			continue
		}
		if ingested := h.ingestSingleAttachment(ctx, botID, rawURL, item); ingested != nil {
			// rawItems shares its backing array with event["attachments"],
			// so this updates the event that is re-encoded below.
			rawItems[i] = ingested
			ingestedCount++
		}
	}
	if ingestedCount == 0 {
		h.logger.Debug("ws attachment_delta: no items needed ingestion", slog.String("bot_id", botID))
		return []json.RawMessage{original}
	}
	h.logger.Info("ws attachment_delta: ingested attachments", slog.String("bot_id", botID), slog.Int("count", ingestedCount))
	out, err := json.Marshal(event)
	if err != nil {
		// Best effort: fall back to the untouched event, but leave a trace.
		h.logger.Warn("ws attachment_delta: re-encode failed", slog.String("bot_id", botID), slog.Any("error", err))
		return []json.RawMessage{original}
	}
	return []json.RawMessage{out}
}
// ingestSingleAttachment stores one attachment through the media service and
// returns a rewritten copy of item that references the stored asset.
//
// The source is chosen by the (case-insensitive) prefix of rawURL:
//   - "data:" URLs are base64-decoded and ingested from memory;
//   - "http://" and "https://" URLs are left alone (nil is returned);
//   - anything else is treated as a container file path.
//
// nil is returned whenever nothing was ingested, including on failure, which
// is logged as a warning.
func (h *LocalChannelHandler) ingestSingleAttachment(ctx context.Context, botID, rawURL string, item map[string]any) map[string]any {
	scheme := strings.ToLower(rawURL)

	switch {
	case strings.HasPrefix(scheme, "data:"):
		mimeType := attachmentpkg.MimeFromDataURL(rawURL)
		payload, decodeErr := attachmentpkg.DecodeBase64(rawURL, media.MaxAssetBytes)
		if decodeErr != nil {
			h.logger.Warn("ws decode data url failed", slog.Any("error", decodeErr))
			return nil
		}
		input := media.IngestInput{
			BotID:    botID,
			Mime:     mimeType,
			Reader:   payload,
			MaxBytes: media.MaxAssetBytes,
		}
		stored, ingestErr := h.mediaService.Ingest(ctx, input)
		if ingestErr != nil {
			h.logger.Warn("ws ingest data url failed", slog.Any("error", ingestErr))
			return nil
		}
		return applyAssetToItem(item, botID, stored)

	case strings.HasPrefix(scheme, "http://"), strings.HasPrefix(scheme, "https://"):
		// Remote URLs are not fetched here.
		return nil

	default:
		stored, ingestErr := h.mediaService.IngestContainerFile(ctx, botID, rawURL)
		if ingestErr != nil {
			h.logger.Warn("ws ingest container file failed", slog.String("path", rawURL), slog.Any("error", ingestErr))
			return nil
		}
		return applyAssetToItem(item, botID, stored)
	}
}
// wsSynthesizeSpeech handles speech_delta events by synthesizing audio and
// injecting attachment_delta events with the resulting voice attachments.
//
// The speech_delta event itself is never forwarded. One attachment_delta
// event is returned per non-blank speech entry that was synthesized and
// encoded successfully; entries that fail are logged and skipped. nil is
// returned when TTS is not configured, the bot has no usable TTS model, or
// the payload is malformed or carries no speeches.
func (h *LocalChannelHandler) wsSynthesizeSpeech(ctx context.Context, botID string, original json.RawMessage) []json.RawMessage {
	if h.ttsService == nil || h.ttsModelResolver == nil {
		h.logger.Warn("speech_delta received but TTS service not configured")
		return nil
	}
	modelID, err := h.ttsModelResolver.ResolveTtsModelID(ctx, botID)
	if err != nil {
		// Keep the resolver error visible instead of reporting it as
		// "no model configured".
		h.logger.Warn("speech_delta: resolve TTS model failed", slog.String("bot_id", botID), slog.Any("error", err))
		return nil
	}
	if strings.TrimSpace(modelID) == "" {
		h.logger.Warn("speech_delta: bot has no TTS model configured", slog.String("bot_id", botID))
		return nil
	}
	var event struct {
		Speeches []struct {
			Text string `json:"text"`
		} `json:"speeches"`
	}
	if err := json.Unmarshal(original, &event); err != nil {
		h.logger.Warn("speech_delta: malformed payload", slog.String("bot_id", botID), slog.Any("error", err))
		return nil
	}
	if len(event.Speeches) == 0 {
		return nil
	}
	var results []json.RawMessage
	for _, speech := range event.Speeches {
		text := strings.TrimSpace(speech.Text)
		if text == "" {
			continue
		}
		audioData, contentType, synthErr := h.ttsService.Synthesize(ctx, modelID, text, nil)
		if synthErr != nil {
			h.logger.Warn("speech synthesis failed", slog.String("bot_id", botID), slog.Any("error", synthErr))
			continue
		}
		att := h.buildTtsAttachment(ctx, botID, contentType, audioData)
		attachmentEvent, marshalErr := json.Marshal(map[string]any{
			"type":        "attachment_delta",
			"attachments": []any{att},
		})
		if marshalErr != nil {
			// Never hand an empty/invalid message to the socket writer.
			h.logger.Warn("speech_delta: encode attachment event failed", slog.String("bot_id", botID), slog.Any("error", marshalErr))
			continue
		}
		results = append(results, attachmentEvent)
	}
	return results
}
// buildTtsAttachment wraps synthesized audio in a "voice" attachment map.
//
// When a media service is configured the audio is ingested and the attachment
// is annotated with the stored asset's reference. If no media service is set,
// or ingestion fails (logged as a warning), the audio is embedded instead as
// a base64 data URL under "url".
func (h *LocalChannelHandler) buildTtsAttachment(ctx context.Context, botID, contentType string, audioData []byte) map[string]any {
	voice := map[string]any{
		"type": "voice",
		"mime": contentType,
		"size": len(audioData),
	}
	normalized := attachmentpkg.NormalizeMime(contentType)

	if h.mediaService != nil {
		stored, ingestErr := h.mediaService.Ingest(ctx, media.IngestInput{
			BotID:    botID,
			Mime:     normalized,
			Reader:   bytes.NewReader(audioData),
			MaxBytes: media.MaxAssetBytes,
		})
		if ingestErr != nil {
			h.logger.Warn("ws tts ingest failed", slog.Any("error", ingestErr))
		} else {
			applyAssetToMap(voice, botID, stored)
			return voice
		}
	}

	// Fallback: ship the audio inline.
	encoded := base64.StdEncoding.EncodeToString(audioData)
	voice["url"] = "data:" + contentType + ";base64," + encoded
	return voice
}
// applyAssetToItem returns a shallow copy of item rewritten to point at a
// stored asset: "path" is removed, "url" is blanked, and the asset reference
// fields are filled in by applyAssetToMap. The input map is not modified.
func applyAssetToItem(item map[string]any, botID string, asset media.Asset) map[string]any {
	rewritten := maps.Clone(item)
	rewritten["url"] = ""
	delete(rewritten, "path")
	applyAssetToMap(rewritten, botID, asset)
	return rewritten
}
// applyAssetToMap writes a stored asset's reference into the attachment map m
// in place.
//
// "content_hash" and "metadata" (bot_id, storage_key) are always overwritten.
// "mime" and "size" are only filled from the asset when the map does not
// already carry a usable value: a blank/non-string mime, or a size that is
// missing, zero, or not a recognised numeric type.
func applyAssetToMap(m map[string]any, botID string, asset media.Asset) {
	m["content_hash"] = asset.ContentHash
	m["metadata"] = map[string]any{
		"bot_id":      botID,
		"storage_key": asset.StorageKey,
	}
	if mime, _ := m["mime"].(string); strings.TrimSpace(mime) == "" && asset.Mime != "" {
		m["mime"] = asset.Mime
	}

	// Sizes decoded from JSON arrive as float64, but maps built in Go code
	// (e.g. with len(...)) hold int; treat both as an existing size so a
	// caller-provided value is not silently replaced.
	hasSize := false
	switch n := m["size"].(type) {
	case float64:
		hasSize = n != 0
	case int:
		hasSize = n != 0
	case int64:
		hasSize = n != 0
	}
	if !hasSize && asset.SizeBytes > 0 {
		m["size"] = asset.SizeBytes
	}
}