mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
63fe03cfff
This reverts commit c9dcfe287f.
169 lines
5.5 KiB
Go
169 lines
5.5 KiB
Go
package tools
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
|
|
"github.com/memohai/memoh/internal/channel"
|
|
"github.com/memohai/memoh/internal/settings"
|
|
ttspkg "github.com/memohai/memoh/internal/tts"
|
|
)
|
|
|
|
const ttsMaxTextLen = 500
|
|
|
|
// TTSSender sends outbound messages through the channel manager.
|
|
type TTSSender interface {
|
|
Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error
|
|
}
|
|
|
|
// TTSChannelResolver parses platform name to channel type.
|
|
type TTSChannelResolver interface {
|
|
ParseChannelType(raw string) (channel.ChannelType, error)
|
|
}
|
|
|
|
type TTSProvider struct {
|
|
logger *slog.Logger
|
|
settings *settings.Service
|
|
tts *ttspkg.Service
|
|
sender TTSSender
|
|
resolver TTSChannelResolver
|
|
}
|
|
|
|
func NewTTSProvider(log *slog.Logger, settingsSvc *settings.Service, ttsSvc *ttspkg.Service, sender TTSSender, resolver TTSChannelResolver) *TTSProvider {
|
|
if log == nil {
|
|
log = slog.Default()
|
|
}
|
|
return &TTSProvider{
|
|
logger: log.With(slog.String("tool", "tts")),
|
|
settings: settingsSvc,
|
|
tts: ttsSvc,
|
|
sender: sender,
|
|
resolver: resolver,
|
|
}
|
|
}
|
|
|
|
func (p *TTSProvider) Tools(ctx context.Context, session SessionContext) ([]sdk.Tool, error) {
|
|
if session.IsSubagent || p.settings == nil || p.tts == nil || p.sender == nil || p.resolver == nil {
|
|
return nil, nil
|
|
}
|
|
botID := strings.TrimSpace(session.BotID)
|
|
if botID == "" {
|
|
return nil, nil
|
|
}
|
|
botSettings, err := p.settings.GetBot(ctx, botID)
|
|
if err != nil {
|
|
return nil, nil
|
|
}
|
|
if strings.TrimSpace(botSettings.TtsModelID) == "" {
|
|
return nil, nil
|
|
}
|
|
sess := session
|
|
return []sdk.Tool{
|
|
{
|
|
Name: "speak",
|
|
Description: "Send a voice message. When target is omitted, speaks in the current conversation. When target is specified, sends to that channel/person. Synthesizes text to speech and delivers as audio.",
|
|
Parameters: map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"text": map[string]any{"type": "string", "description": "The text to convert to speech (max 500 characters)"},
|
|
"platform": map[string]any{"type": "string", "description": "Channel platform name. Defaults to current session platform."},
|
|
"target": map[string]any{"type": "string", "description": "Channel target (chat/group/thread ID). Optional — omit to speak in the current conversation. Use get_contacts to find targets for other conversations."},
|
|
"reply_to": map[string]any{"type": "string", "description": "Message ID to reply to. The voice message will reference this message on the platform."},
|
|
},
|
|
"required": []string{"text"},
|
|
},
|
|
Execute: func(execCtx *sdk.ToolExecContext, input any) (any, error) {
|
|
return p.execSpeak(execCtx.Context, sess, inputAsMap(input))
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (p *TTSProvider) execSpeak(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
|
|
botID := strings.TrimSpace(session.BotID)
|
|
if botID == "" {
|
|
return nil, errors.New("bot_id is required")
|
|
}
|
|
text := strings.TrimSpace(StringArg(args, "text"))
|
|
if text == "" {
|
|
return nil, errors.New("text is required")
|
|
}
|
|
if len([]rune(text)) > ttsMaxTextLen {
|
|
return nil, errors.New("text too long, max 500 characters")
|
|
}
|
|
channelType, err := p.resolvePlatform(args, session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
target := FirstStringArg(args, "target")
|
|
if target == "" {
|
|
target = strings.TrimSpace(session.ReplyTarget)
|
|
}
|
|
|
|
isSameConv := target == "" || session.IsSameConversation(channelType.String(), target)
|
|
|
|
botSettings, err := p.settings.GetBot(ctx, botID)
|
|
if err != nil {
|
|
return nil, errors.New("failed to load bot settings")
|
|
}
|
|
if botSettings.TtsModelID == "" {
|
|
return nil, errors.New("bot has no TTS model configured")
|
|
}
|
|
audioData, contentType, synthErr := p.tts.Synthesize(ctx, botSettings.TtsModelID, text, nil)
|
|
if synthErr != nil {
|
|
return nil, fmt.Errorf("speech synthesis failed: %s", synthErr.Error())
|
|
}
|
|
|
|
dataURL := fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(audioData))
|
|
|
|
// Same-conversation: emit the synthesized audio as a voice attachment.
|
|
if isSameConv && session.Emitter != nil {
|
|
session.Emitter(ToolStreamEvent{
|
|
Type: StreamEventAttachment,
|
|
Attachments: []Attachment{{
|
|
Type: "voice",
|
|
URL: dataURL,
|
|
Mime: contentType,
|
|
Size: int64(len(audioData)),
|
|
}},
|
|
})
|
|
return map[string]any{
|
|
"ok": true,
|
|
"delivered": "current_conversation",
|
|
}, nil
|
|
}
|
|
if target == "" {
|
|
return nil, errors.New("target is required for cross-conversation speak")
|
|
}
|
|
msg := channel.Message{
|
|
Attachments: []channel.Attachment{{Type: channel.AttachmentVoice, URL: dataURL, Mime: contentType, Size: int64(len(audioData))}},
|
|
}
|
|
if replyTo := FirstStringArg(args, "reply_to"); replyTo != "" {
|
|
msg.Reply = &channel.ReplyRef{MessageID: replyTo}
|
|
}
|
|
if err := p.sender.Send(ctx, botID, channelType, channel.SendRequest{Target: target, Message: msg}); err != nil {
|
|
return nil, err
|
|
}
|
|
return map[string]any{
|
|
"ok": true, "bot_id": botID, "platform": channelType.String(), "target": target,
|
|
"instruction": "Voice message delivered successfully. You have completed your response. Please STOP now and do not call any more tools.",
|
|
}, nil
|
|
}
|
|
|
|
func (p *TTSProvider) resolvePlatform(args map[string]any, session SessionContext) (channel.ChannelType, error) {
|
|
platform := FirstStringArg(args, "platform")
|
|
if platform == "" {
|
|
platform = strings.TrimSpace(session.CurrentPlatform)
|
|
}
|
|
if platform == "" {
|
|
return "", errors.New("platform is required")
|
|
}
|
|
return p.resolver.ParseChannelType(platform)
|
|
}
|