mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(tts): introduce TTS system (#195)
This commit is contained in:
@@ -2,10 +2,13 @@ package handlers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -14,24 +17,39 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
attachmentpkg "github.com/memohai/memoh/internal/attachment"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/channel/adapters/local"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/conversation/flow"
|
||||
"github.com/memohai/memoh/internal/media"
|
||||
)
|
||||
|
||||
// localTtsSynthesizer synthesizes text to speech audio.
|
||||
type localTtsSynthesizer interface {
|
||||
Synthesize(ctx context.Context, modelID string, text string, overrideCfg map[string]any) ([]byte, string, error)
|
||||
}
|
||||
|
||||
// localTtsModelResolver resolves TTS model IDs for bots.
|
||||
type localTtsModelResolver interface {
|
||||
ResolveTtsModelID(ctx context.Context, botID string) (string, error)
|
||||
}
|
||||
|
||||
// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history.
|
||||
type LocalChannelHandler struct {
|
||||
channelType channel.ChannelType
|
||||
channelManager *channel.Manager
|
||||
channelStore *channel.Store
|
||||
chatService *conversation.Service
|
||||
routeHub *local.RouteHub
|
||||
botService *bots.Service
|
||||
accountService *accounts.Service
|
||||
resolver *flow.Resolver
|
||||
logger *slog.Logger
|
||||
channelType channel.ChannelType
|
||||
channelManager *channel.Manager
|
||||
channelStore *channel.Store
|
||||
chatService *conversation.Service
|
||||
routeHub *local.RouteHub
|
||||
botService *bots.Service
|
||||
accountService *accounts.Service
|
||||
resolver *flow.Resolver
|
||||
mediaService *media.Service
|
||||
ttsService localTtsSynthesizer
|
||||
ttsModelResolver localTtsModelResolver
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewLocalChannelHandler creates a local channel handler.
|
||||
@@ -53,6 +71,17 @@ func (h *LocalChannelHandler) SetResolver(resolver *flow.Resolver) {
|
||||
h.resolver = resolver
|
||||
}
|
||||
|
||||
// SetMediaService sets the media service for WebSocket attachment ingestion.
|
||||
func (h *LocalChannelHandler) SetMediaService(svc *media.Service) {
|
||||
h.mediaService = svc
|
||||
}
|
||||
|
||||
// SetTtsService configures TTS synthesis for handling speech_delta events.
|
||||
func (h *LocalChannelHandler) SetTtsService(synth localTtsSynthesizer, resolver localTtsModelResolver) {
|
||||
h.ttsService = synth
|
||||
h.ttsModelResolver = resolver
|
||||
}
|
||||
|
||||
// Register registers the local channel routes.
|
||||
func (h *LocalChannelHandler) Register(e *echo.Echo) {
|
||||
prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String())
|
||||
@@ -391,7 +420,9 @@ func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error {
|
||||
|
||||
go func() {
|
||||
for event := range eventCh {
|
||||
writer.Send(event)
|
||||
for _, processed := range h.processWSEvent(streamCtx, botID, event) {
|
||||
writer.Send(processed)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -424,3 +455,217 @@ func (*LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, er
|
||||
func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: true})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket event processing — attachment ingestion + TTS extraction
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type wsEventEnvelope struct {
|
||||
Type string `json:"type"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// processWSEvent transforms a raw WS event, ingesting attachments and
|
||||
// extracting TTS audio so the web frontend receives content_hash references.
|
||||
func (h *LocalChannelHandler) processWSEvent(ctx context.Context, botID string, event json.RawMessage) []json.RawMessage {
|
||||
var envelope wsEventEnvelope
|
||||
if err := json.Unmarshal(event, &envelope); err != nil {
|
||||
return []json.RawMessage{event}
|
||||
}
|
||||
|
||||
h.logger.Debug("ws event", slog.String("type", envelope.Type), slog.String("bot_id", botID))
|
||||
|
||||
switch envelope.Type {
|
||||
case "attachment_delta":
|
||||
h.logger.Info("ws processing attachment_delta", slog.String("bot_id", botID))
|
||||
return h.wsIngestAttachments(ctx, botID, event)
|
||||
case "speech_delta":
|
||||
h.logger.Info("ws processing speech_delta", slog.String("bot_id", botID))
|
||||
return h.wsSynthesizeSpeech(ctx, botID, event)
|
||||
default:
|
||||
return []json.RawMessage{event}
|
||||
}
|
||||
}
|
||||
|
||||
// wsIngestAttachments persists attachment data (container paths / data URLs)
|
||||
// and rewrites them with content_hash so the web frontend can resolve them.
|
||||
func (h *LocalChannelHandler) wsIngestAttachments(ctx context.Context, botID string, original json.RawMessage) []json.RawMessage {
|
||||
if h.mediaService == nil {
|
||||
return []json.RawMessage{original}
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal(original, &event); err != nil {
|
||||
return []json.RawMessage{original}
|
||||
}
|
||||
|
||||
rawItems, _ := event["attachments"].([]any)
|
||||
if len(rawItems) == 0 {
|
||||
return []json.RawMessage{original}
|
||||
}
|
||||
|
||||
changed := false
|
||||
for i, raw := range rawItems {
|
||||
item, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if ch, _ := item["content_hash"].(string); strings.TrimSpace(ch) != "" {
|
||||
continue
|
||||
}
|
||||
rawURL, _ := item["url"].(string)
|
||||
if rawURL == "" {
|
||||
rawURL, _ = item["path"].(string)
|
||||
}
|
||||
if rawURL = strings.TrimSpace(rawURL); rawURL == "" {
|
||||
continue
|
||||
}
|
||||
if ingested := h.ingestSingleAttachment(ctx, botID, rawURL, item); ingested != nil {
|
||||
rawItems[i] = ingested
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
h.logger.Debug("ws attachment_delta: no items needed ingestion", slog.String("bot_id", botID))
|
||||
return []json.RawMessage{original}
|
||||
}
|
||||
|
||||
h.logger.Info("ws attachment_delta: ingested attachments", slog.String("bot_id", botID), slog.Int("count", len(rawItems)))
|
||||
|
||||
out, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return []json.RawMessage{original}
|
||||
}
|
||||
return []json.RawMessage{out}
|
||||
}
|
||||
|
||||
func (h *LocalChannelHandler) ingestSingleAttachment(ctx context.Context, botID, rawURL string, item map[string]any) map[string]any {
|
||||
lower := strings.ToLower(rawURL)
|
||||
|
||||
if !strings.HasPrefix(lower, "data:") && !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
|
||||
asset, err := h.mediaService.IngestContainerFile(ctx, botID, rawURL)
|
||||
if err != nil {
|
||||
h.logger.Warn("ws ingest container file failed", slog.String("path", rawURL), slog.Any("error", err))
|
||||
return nil
|
||||
}
|
||||
return applyAssetToItem(item, botID, asset)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(lower, "data:") {
|
||||
mimeType := attachmentpkg.MimeFromDataURL(rawURL)
|
||||
decoded, err := attachmentpkg.DecodeBase64(rawURL, media.MaxAssetBytes)
|
||||
if err != nil {
|
||||
h.logger.Warn("ws decode data url failed", slog.Any("error", err))
|
||||
return nil
|
||||
}
|
||||
asset, err := h.mediaService.Ingest(ctx, media.IngestInput{
|
||||
BotID: botID,
|
||||
Mime: mimeType,
|
||||
Reader: decoded,
|
||||
MaxBytes: media.MaxAssetBytes,
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Warn("ws ingest data url failed", slog.Any("error", err))
|
||||
return nil
|
||||
}
|
||||
return applyAssetToItem(item, botID, asset)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// wsSynthesizeSpeech handles speech_delta events by synthesizing audio and
|
||||
// injecting attachment_delta events with the resulting voice attachments.
|
||||
func (h *LocalChannelHandler) wsSynthesizeSpeech(ctx context.Context, botID string, original json.RawMessage) []json.RawMessage {
|
||||
if h.ttsService == nil || h.ttsModelResolver == nil {
|
||||
h.logger.Warn("speech_delta received but TTS service not configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
modelID, err := h.ttsModelResolver.ResolveTtsModelID(ctx, botID)
|
||||
if err != nil || strings.TrimSpace(modelID) == "" {
|
||||
h.logger.Warn("speech_delta: bot has no TTS model configured", slog.String("bot_id", botID))
|
||||
return nil
|
||||
}
|
||||
|
||||
var event struct {
|
||||
Speeches []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"speeches"`
|
||||
}
|
||||
if err := json.Unmarshal(original, &event); err != nil || len(event.Speeches) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results []json.RawMessage
|
||||
for _, speech := range event.Speeches {
|
||||
text := strings.TrimSpace(speech.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
audioData, contentType, synthErr := h.ttsService.Synthesize(ctx, modelID, text, nil)
|
||||
if synthErr != nil {
|
||||
h.logger.Warn("speech synthesis failed", slog.String("bot_id", botID), slog.Any("error", synthErr))
|
||||
continue
|
||||
}
|
||||
|
||||
att := h.buildTtsAttachment(ctx, botID, contentType, audioData)
|
||||
attachmentEvent, _ := json.Marshal(map[string]any{
|
||||
"type": "attachment_delta",
|
||||
"attachments": []any{att},
|
||||
})
|
||||
results = append(results, attachmentEvent)
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (h *LocalChannelHandler) buildTtsAttachment(ctx context.Context, botID, contentType string, audioData []byte) map[string]any {
|
||||
att := map[string]any{
|
||||
"type": "voice",
|
||||
"mime": contentType,
|
||||
"size": len(audioData),
|
||||
}
|
||||
|
||||
mimeType := attachmentpkg.NormalizeMime(contentType)
|
||||
if h.mediaService != nil {
|
||||
asset, err := h.mediaService.Ingest(ctx, media.IngestInput{
|
||||
BotID: botID,
|
||||
Mime: mimeType,
|
||||
Reader: bytes.NewReader(audioData),
|
||||
MaxBytes: media.MaxAssetBytes,
|
||||
})
|
||||
if err == nil {
|
||||
applyAssetToMap(att, botID, asset)
|
||||
return att
|
||||
}
|
||||
h.logger.Warn("ws tts ingest failed", slog.Any("error", err))
|
||||
}
|
||||
|
||||
att["url"] = "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(audioData)
|
||||
return att
|
||||
}
|
||||
|
||||
func applyAssetToItem(item map[string]any, botID string, asset media.Asset) map[string]any {
|
||||
result := maps.Clone(item)
|
||||
delete(result, "path")
|
||||
result["url"] = ""
|
||||
applyAssetToMap(result, botID, asset)
|
||||
return result
|
||||
}
|
||||
|
||||
func applyAssetToMap(m map[string]any, botID string, asset media.Asset) {
|
||||
m["content_hash"] = asset.ContentHash
|
||||
m["metadata"] = map[string]any{
|
||||
"bot_id": botID,
|
||||
"storage_key": asset.StorageKey,
|
||||
}
|
||||
if mime, _ := m["mime"].(string); strings.TrimSpace(mime) == "" && asset.Mime != "" {
|
||||
m["mime"] = asset.Mime
|
||||
}
|
||||
if size, _ := m["size"].(float64); size == 0 && asset.SizeBytes > 0 {
|
||||
m["size"] = asset.SizeBytes
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user