feat(channel): structured tool-call IM display with edit-in-place

Introduce a new `show_tool_calls_in_im` bot setting plus a full overhaul of
how tool calls are surfaced in IM channels:

- Add per-bot setting + migration (0072) and expose through settings API /
  handlers / frontend SDK.
- Introduce a `toolCallDroppingStream` wrapper that filters tool_call_* events
  when the setting is off, keeping the rest of the stream intact.
- Add a shared `ToolCallPresentation` model (Header / Body blocks / Footer)
  with plain and Markdown renderers, and a per-tool formatter registry that
  produces rich output (e.g. `web_search` link lists, `list` directory
  previews, `exec` stdout/stderr tails) instead of raw JSON dumps.
- High-capability adapters (Telegram, Feishu, Matrix, Slack, Discord) now
  flush pre-text and then send ONE tool-call message per call, editing it
  in-place from `running` to `completed` / `failed`; mapping from callID to
  platform message ID is tracked per stream, with a fallback to a new
  message if the edit fails. Low-capability adapters (WeCom, QQ, DingTalk)
  keep posting a single final message, but now benefit from the same rich
  per-tool formatting.
- Suppress the early duplicate `EventToolCallStart` (from
  `sdk.ToolInputStartPart`) so that the SDK's final `StreamToolCallPart`
  remains the single source of truth for tool call start, preventing
  duplicated "running" bubbles in IM.
- Stop auto-populating `InputSummary` / `ResultSummary` after a per-tool
  formatter runs, which previously leaked the raw JSON result as a
  fallback footer underneath the formatted body.

Add regression tests for the formatters, the Markdown renderer, the
edit-in-place flow on Telegram/Matrix, and the JSON-leak guard on `list`.
This commit is contained in:
Acbox
2026-04-23 20:49:44 +08:00
parent 35118a81ad
commit 473d559042
36 changed files with 3688 additions and 77 deletions
+13
View File
@@ -374,6 +374,7 @@ func provideChannelRouter(
processor.SetDispatcher(inbound.NewRouteDispatcher(log))
processor.SetSpeechService(audioService, &settingsSpeechModelResolver{settings: settingsService})
processor.SetTranscriptionService(&settingsTranscriptionAdapter{audio: audioService}, &settingsTranscriptionModelResolver{settings: settingsService})
processor.SetIMDisplayOptions(&settingsIMDisplayOptions{settings: settingsService})
cmdHandler := command.NewHandler(
log,
&command.BotMemberRoleAdapter{BotService: botService},
@@ -597,6 +598,18 @@ func (r *settingsSpeechModelResolver) ResolveSpeechModelID(ctx context.Context,
return s.TtsModelID, nil
}
type settingsIMDisplayOptions struct {
settings *settings.Service
}
func (r *settingsIMDisplayOptions) ShowToolCallsInIM(ctx context.Context, botID string) (bool, error) {
s, err := r.settings.GetBot(ctx, botID)
if err != nil {
return false, err
}
return s.ShowToolCallsInIM, nil
}
type settingsTranscriptionModelResolver struct {
settings *settings.Service
}
+1
View File
@@ -179,6 +179,7 @@ CREATE TABLE IF NOT EXISTS bots (
transcription_model_id UUID REFERENCES models(id) ON DELETE SET NULL,
browser_context_id UUID REFERENCES browser_contexts(id) ON DELETE SET NULL,
persist_full_tool_results BOOLEAN NOT NULL DEFAULT false,
show_tool_calls_in_im BOOLEAN NOT NULL DEFAULT false,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
@@ -0,0 +1,5 @@
-- 0072_add_show_tool_calls_in_im (down)
-- NOTE: After rolling back this migration, re-run `sqlc generate` to update the
-- generated Go code in internal/db/sqlc/.
ALTER TABLE bots DROP COLUMN IF EXISTS show_tool_calls_in_im;
@@ -0,0 +1,5 @@
-- 0072_add_show_tool_calls_in_im
-- Add show_tool_calls_in_im column to bots table to control whether tool call
-- status messages are surfaced in IM channels.
ALTER TABLE bots ADD COLUMN IF NOT EXISTS show_tool_calls_in_im BOOLEAN NOT NULL DEFAULT false;
+7 -3
View File
@@ -21,7 +21,8 @@ SELECT
tts_models.id AS tts_model_id,
transcription_models.id AS transcription_model_id,
browser_contexts.id AS browser_context_id,
bots.persist_full_tool_results
bots.persist_full_tool_results,
bots.show_tool_calls_in_im
FROM bots
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = bots.heartbeat_model_id
@@ -59,9 +60,10 @@ WITH updated AS (
transcription_model_id = COALESCE(sqlc.narg(transcription_model_id)::uuid, bots.transcription_model_id),
browser_context_id = COALESCE(sqlc.narg(browser_context_id)::uuid, bots.browser_context_id),
persist_full_tool_results = sqlc.arg(persist_full_tool_results),
show_tool_calls_in_im = sqlc.arg(show_tool_calls_in_im),
updated_at = now()
WHERE bots.id = sqlc.arg(id)
RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.timezone, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.transcription_model_id, bots.browser_context_id, bots.persist_full_tool_results
RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.timezone, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.transcription_model_id, bots.browser_context_id, bots.persist_full_tool_results, bots.show_tool_calls_in_im
)
SELECT
updated.id AS bot_id,
@@ -85,7 +87,8 @@ SELECT
tts_models.id AS tts_model_id,
transcription_models.id AS transcription_model_id,
browser_contexts.id AS browser_context_id,
updated.persist_full_tool_results
updated.persist_full_tool_results,
updated.show_tool_calls_in_im
FROM updated
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = updated.heartbeat_model_id
@@ -120,5 +123,6 @@ SET language = 'auto',
transcription_model_id = NULL,
browser_context_id = NULL,
persist_full_tool_results = false,
show_tool_calls_in_im = false,
updated_at = now()
WHERE id = $1;
+8 -14
View File
@@ -300,16 +300,14 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
}
case *sdk.ToolInputStartPart:
// ToolInputStartPart fires before tool input args have streamed.
// We suppress it here because downstream consumers (IM adapters and
// Web UI) only care about the fully-assembled call announced by
// StreamToolCallPart below. Emitting a start event twice for the
// same CallID would produce duplicate "running" messages in IMs.
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
}
if !sendEvent(ctx, ch, StreamEvent{
Type: EventToolCallStart,
ToolName: p.ToolName,
ToolCallID: p.ID,
}) {
aborted = true
}
case *sdk.StreamToolCallPart:
if textLoopProbeBuffer != nil {
@@ -981,16 +979,12 @@ func (a *Agent) runMidStreamRetry(
aborted = true
}
case *sdk.ToolInputStartPart:
// See ToolInputStartPart note above: suppress the early start
// and rely on StreamToolCallPart (which carries the fully
// assembled Input) as the single source of truth.
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
}
if !sendEvent(sendCtx, ch, StreamEvent{
Type: EventToolCallStart,
ToolName: rp.ToolName,
ToolCallID: rp.ID,
}) {
aborted = true
}
case *sdk.StreamToolCallPart:
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
+13 -14
View File
@@ -46,7 +46,12 @@ func (*agentToolPlaceholderProvider) DoStream(_ context.Context, _ sdk.GenerateP
return &sdk.StreamResult{Stream: ch}, nil
}
func TestAgentStreamEmitsEarlyToolPlaceholderBeforeFullInput(t *testing.T) {
// TestAgentStreamEmitsToolCallStartOnceWithInput asserts that each tool call
// produces exactly one EventToolCallStart with the fully-assembled Input, even
// though the underlying SDK emits a preliminary ToolInputStartPart (no input)
// followed by a StreamToolCallPart (with input). Emitting two start events per
// call caused duplicate "running" messages in IM adapters.
func TestAgentStreamEmitsToolCallStartOnceWithInput(t *testing.T) {
t.Parallel()
a := New(Deps{})
@@ -64,26 +69,20 @@ func TestAgentStreamEmitsEarlyToolPlaceholderBeforeFullInput(t *testing.T) {
events = append(events, event)
}
if len(events) != 4 {
t.Fatalf("expected 4 events, got %d: %#v", len(events), events)
if len(events) != 3 {
t.Fatalf("expected 3 events, got %d: %#v", len(events), events)
}
if events[0].Type != EventAgentStart {
t.Fatalf("expected first event %q, got %#v", EventAgentStart, events[0])
}
if events[1].Type != EventToolCallStart || events[1].ToolCallID != "call-1" || events[1].ToolName != "write" {
t.Fatalf("unexpected placeholder tool event: %#v", events[1])
}
if events[1].Input != nil {
t.Fatalf("expected placeholder tool event to have nil input, got %#v", events[1].Input)
}
if events[2].Type != EventToolCallStart || events[2].ToolCallID != "call-1" {
t.Fatalf("unexpected full tool event: %#v", events[2])
t.Fatalf("unexpected tool call start event: %#v", events[1])
}
expectedInput := map[string]any{"path": "/tmp/long.txt"}
if !reflect.DeepEqual(events[2].Input, expectedInput) {
t.Fatalf("expected full tool event input %#v, got %#v", expectedInput, events[2].Input)
if !reflect.DeepEqual(events[1].Input, expectedInput) {
t.Fatalf("expected tool call start input %#v, got %#v", expectedInput, events[1].Input)
}
if events[3].Type != EventAgentEnd {
t.Fatalf("expected terminal event %q, got %#v", EventAgentEnd, events[3])
if events[2].Type != EventAgentEnd {
t.Fatalf("expected terminal event %q, got %#v", EventAgentEnd, events[2])
}
}
+12 -1
View File
@@ -45,7 +45,6 @@ func (s *dingtalkOutboundStream) Push(ctx context.Context, event channel.Prepare
channel.StreamEventPhaseStart,
channel.StreamEventPhaseEnd,
channel.StreamEventToolCallStart,
channel.StreamEventToolCallEnd,
channel.StreamEventAgentStart,
channel.StreamEventAgentEnd,
channel.StreamEventProcessingStarted,
@@ -54,6 +53,18 @@ func (s *dingtalkOutboundStream) Push(ctx context.Context, event channel.Prepare
// Non-content events: no-op.
return nil
case channel.StreamEventToolCallEnd:
text := strings.TrimSpace(channel.RenderToolCallMessage(channel.BuildToolCallEnd(event.ToolCall)))
if text == "" {
return nil
}
return s.adapter.Send(ctx, s.cfg, channel.PreparedOutboundMessage{
Target: s.target,
Message: channel.PreparedMessage{
Message: channel.Message{Format: channel.MessageFormatPlain, Text: text, Reply: s.reply},
},
})
case channel.StreamEventDelta:
if strings.TrimSpace(event.Delta) == "" || event.Phase == channel.StreamPhaseReasoning {
return nil
+92 -1
View File
@@ -25,6 +25,7 @@ type discordOutboundStream struct {
msgID string
buffer strings.Builder
lastUpdate time.Time
toolMessages map[string]string
}
func (s *discordOutboundStream) Push(ctx context.Context, event channel.PreparedStreamEvent) error {
@@ -103,7 +104,21 @@ func (s *discordOutboundStream) Push(ctx context.Context, event channel.Prepared
}
return nil
case channel.StreamEventAgentStart, channel.StreamEventAgentEnd, channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd, channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted, channel.StreamEventProcessingFailed, channel.StreamEventToolCallStart, channel.StreamEventToolCallEnd:
case channel.StreamEventToolCallStart:
s.mu.Lock()
bufText := strings.TrimSpace(s.buffer.String())
s.mu.Unlock()
if bufText != "" {
if err := s.finalizeMessage(bufText); err != nil {
return err
}
}
s.resetStreamState()
return s.sendToolCallMessage(event.ToolCall, channel.BuildToolCallStart(event.ToolCall))
case channel.StreamEventToolCallEnd:
return s.sendToolCallMessage(event.ToolCall, channel.BuildToolCallEnd(event.ToolCall))
case channel.StreamEventAgentStart, channel.StreamEventAgentEnd, channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd, channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted, channel.StreamEventProcessingFailed:
// Status events - no action needed for Discord
return nil
@@ -207,6 +222,82 @@ func (s *discordOutboundStream) finalizeMessage(text string) error {
return err
}
// sendToolCallMessage posts a Discord message on tool_call_start and edits it
// on tool_call_end so the running → completed/failed transition is contained
// in one visible post. Falls back to a new message if the edit fails.
func (s *discordOutboundStream) sendToolCallMessage(tc *channel.StreamToolCall, p channel.ToolCallPresentation) error {
text := truncateDiscordText(strings.TrimSpace(channel.RenderToolCallMessageMarkdown(p)))
if text == "" {
return nil
}
callID := ""
if tc != nil {
callID = strings.TrimSpace(tc.CallID)
}
if p.Status != channel.ToolCallStatusRunning && callID != "" {
if msgID, ok := s.lookupToolCallMessage(callID); ok {
if _, err := s.session.ChannelMessageEdit(s.target, msgID, text); err == nil {
s.forgetToolCallMessage(callID)
return nil
}
s.forgetToolCallMessage(callID)
}
}
var msg *discordgo.Message
var err error
if s.reply != nil && s.reply.MessageID != "" {
msg, err = s.session.ChannelMessageSendReply(s.target, text, &discordgo.MessageReference{
ChannelID: s.target,
MessageID: s.reply.MessageID,
})
} else {
msg, err = s.session.ChannelMessageSend(s.target, text)
}
if err != nil {
return err
}
if p.Status == channel.ToolCallStatusRunning && callID != "" && msg != nil && msg.ID != "" {
s.storeToolCallMessage(callID, msg.ID)
}
return nil
}
func (s *discordOutboundStream) lookupToolCallMessage(callID string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return "", false
}
v, ok := s.toolMessages[callID]
return v, ok
}
func (s *discordOutboundStream) storeToolCallMessage(callID, msgID string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
s.toolMessages = make(map[string]string)
}
s.toolMessages[callID] = msgID
}
func (s *discordOutboundStream) forgetToolCallMessage(callID string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return
}
delete(s.toolMessages, callID)
}
func (s *discordOutboundStream) resetStreamState() {
s.mu.Lock()
s.msgID = ""
s.buffer.Reset()
s.lastUpdate = time.Time{}
s.mu.Unlock()
}
func (s *discordOutboundStream) sendAttachment(ctx context.Context, att channel.PreparedAttachment) error {
file, err := discordPreparedAttachmentToFile(ctx, att)
if err != nil {
+153 -2
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"regexp"
"strings"
"sync/atomic"
@@ -38,6 +39,7 @@ type feishuOutboundStream struct {
lastPatched string
patchInterval time.Duration
closed atomic.Bool
toolMessages map[string]string
}
func (s *feishuOutboundStream) Push(ctx context.Context, event channel.PreparedStreamEvent) error {
@@ -79,13 +81,13 @@ func (s *feishuOutboundStream) Push(ctx context.Context, event channel.PreparedS
s.lastPatched = ""
s.lastPatchedAt = time.Time{}
s.textBuffer.Reset()
return nil
return s.renderToolCallCard(ctx, event.ToolCall, channel.BuildToolCallStart(event.ToolCall))
case channel.StreamEventToolCallEnd:
s.cardMessageID = ""
s.lastPatched = ""
s.lastPatchedAt = time.Time{}
s.textBuffer.Reset()
return nil
return s.renderToolCallCard(ctx, event.ToolCall, channel.BuildToolCallEnd(event.ToolCall))
case channel.StreamEventAttachment:
if len(event.Attachments) == 0 {
return nil
@@ -367,6 +369,155 @@ func processFeishuCardMarkdown(s string) string {
return s
}
// renderToolCallCard posts a card for tool_call_start and patches the same
// card on tool_call_end, producing a single message whose status flips from
// "running" to "completed"/"failed". If no prior card is tracked (or patching
// fails), it falls back to creating a new card.
func (s *feishuOutboundStream) renderToolCallCard(
ctx context.Context,
tc *channel.StreamToolCall,
p channel.ToolCallPresentation,
) error {
text := strings.TrimSpace(channel.RenderToolCallMessageMarkdown(p))
if text == "" {
return nil
}
if s.client == nil {
return errors.New("feishu client not configured")
}
callID := ""
if tc != nil {
callID = strings.TrimSpace(tc.CallID)
}
if p.Status != channel.ToolCallStatusRunning && callID != "" {
if existing, ok := s.lookupToolCallMessage(callID); ok {
patchErr := s.patchToolCallCard(ctx, existing, text)
if patchErr == nil {
s.forgetToolCallMessage(callID)
return nil
}
if s.adapter != nil && s.adapter.logger != nil {
s.adapter.logger.Warn("feishu: tool-call end patch failed, falling back to new card",
slog.String("call_id", callID),
slog.Any("error", patchErr),
)
}
s.forgetToolCallMessage(callID)
}
}
msgID, err := s.sendToolCallCard(ctx, text)
if err != nil {
return err
}
if p.Status == channel.ToolCallStatusRunning && callID != "" && msgID != "" {
s.storeToolCallMessage(callID, msgID)
}
return nil
}
func (s *feishuOutboundStream) patchToolCallCard(ctx context.Context, messageID, text string) error {
content, err := buildFeishuCardContent(text)
if err != nil {
return err
}
patchReq := larkim.NewPatchMessageReqBuilder().
MessageId(messageID).
Body(larkim.NewPatchMessageReqBodyBuilder().
Content(content).
Build()).
Build()
patchResp, err := s.client.Im.Message.Patch(ctx, patchReq)
if err != nil {
return err
}
if patchResp == nil || !patchResp.Success() {
code, msg := 0, ""
if patchResp != nil {
code, msg = patchResp.Code, patchResp.Msg
}
return fmt.Errorf("feishu tool card patch failed: %s (code: %d)", msg, code)
}
return nil
}
func (s *feishuOutboundStream) lookupToolCallMessage(callID string) (string, bool) {
if s.toolMessages == nil {
return "", false
}
m, ok := s.toolMessages[callID]
return m, ok
}
func (s *feishuOutboundStream) storeToolCallMessage(callID, messageID string) {
if s.toolMessages == nil {
s.toolMessages = make(map[string]string)
}
s.toolMessages[callID] = messageID
}
func (s *feishuOutboundStream) forgetToolCallMessage(callID string) {
if s.toolMessages == nil {
return
}
delete(s.toolMessages, callID)
}
func (s *feishuOutboundStream) sendToolCallCard(ctx context.Context, text string) (string, error) {
content, err := buildFeishuCardContent(text)
if err != nil {
return "", err
}
if s.reply != nil && strings.TrimSpace(s.reply.MessageID) != "" {
replyReq := larkim.NewReplyMessageReqBuilder().
MessageId(strings.TrimSpace(s.reply.MessageID)).
Body(larkim.NewReplyMessageReqBodyBuilder().
Content(content).
MsgType(larkim.MsgTypeInteractive).
Uuid(uuid.NewString()).
Build()).
Build()
replyResp, err := s.client.Im.Message.Reply(ctx, replyReq)
if err != nil {
return "", err
}
if replyResp == nil || !replyResp.Success() {
code, msg := 0, ""
if replyResp != nil {
code, msg = replyResp.Code, replyResp.Msg
}
return "", fmt.Errorf("feishu tool card reply failed: %s (code: %d)", msg, code)
}
if replyResp.Data == nil || replyResp.Data.MessageId == nil {
return "", nil
}
return strings.TrimSpace(*replyResp.Data.MessageId), nil
}
createReq := larkim.NewCreateMessageReqBuilder().
ReceiveIdType(s.receiveType).
Body(larkim.NewCreateMessageReqBodyBuilder().
ReceiveId(s.receiveID).
MsgType(larkim.MsgTypeInteractive).
Content(content).
Uuid(uuid.NewString()).
Build()).
Build()
createResp, err := s.client.Im.Message.Create(ctx, createReq)
if err != nil {
return "", err
}
if createResp == nil || !createResp.Success() {
code, msg := 0, ""
if createResp != nil {
code, msg = createResp.Code, createResp.Msg
}
return "", fmt.Errorf("feishu tool card create failed: %s (code: %d)", msg, code)
}
if createResp.Data == nil || createResp.Data.MessageId == nil {
return "", nil
}
return strings.TrimSpace(*createResp.Data.MessageId), nil
}
func normalizeFeishuStreamText(text string) string {
trimmed := strings.TrimSpace(text)
if trimmed == "" {
+104 -2
View File
@@ -26,6 +26,7 @@ type matrixOutboundStream struct {
lastText string
lastFormat channel.MessageFormat
lastEditedAt time.Time
toolMessages map[string]string
}
func (s *matrixOutboundStream) Push(ctx context.Context, event channel.PreparedStreamEvent) error {
@@ -44,7 +45,6 @@ func (s *matrixOutboundStream) Push(ctx context.Context, event channel.PreparedS
switch event.Type {
case channel.StreamEventStatus,
channel.StreamEventPhaseStart,
channel.StreamEventToolCallEnd,
channel.StreamEventAgentStart,
channel.StreamEventAgentEnd,
channel.StreamEventProcessingStarted,
@@ -60,8 +60,18 @@ func (s *matrixOutboundStream) Push(ctx context.Context, event channel.PreparedS
s.mu.Unlock()
return s.upsertText(ctx, text, channel.MessageFormatPlain, true)
case channel.StreamEventToolCallStart:
s.mu.Lock()
bufText := strings.TrimSpace(s.rawBuffer.String())
s.mu.Unlock()
if bufText != "" {
if err := s.upsertText(ctx, bufText, channel.MessageFormatPlain, true); err != nil {
return err
}
}
s.resetMessageState()
return nil
return s.sendToolCallMessage(ctx, event.ToolCall, channel.BuildToolCallStart(event.ToolCall))
case channel.StreamEventToolCallEnd:
return s.sendToolCallMessage(ctx, event.ToolCall, channel.BuildToolCallEnd(event.ToolCall))
case channel.StreamEventDelta:
if event.Phase == channel.StreamPhaseReasoning || event.Delta == "" {
return nil
@@ -186,6 +196,98 @@ func (s *matrixOutboundStream) upsertText(ctx context.Context, text string, form
return nil
}
// sendToolCallMessage posts a room event on tool_call_start and sends an
// m.replace edit event on tool_call_end so both lifecycle states share a
// single visible message. If no prior event is tracked (or the edit fails),
// it falls back to creating a new event.
func (s *matrixOutboundStream) sendToolCallMessage(
ctx context.Context,
tc *channel.StreamToolCall,
p channel.ToolCallPresentation,
) error {
text := strings.TrimSpace(channel.RenderToolCallMessageMarkdown(p))
format := channel.MessageFormatMarkdown
if text == "" {
text = strings.TrimSpace(channel.RenderToolCallMessage(p))
format = channel.MessageFormatPlain
}
if text == "" {
return nil
}
s.mu.Lock()
roomID := s.roomID
reply := s.reply
s.mu.Unlock()
if roomID == "" {
resolved, err := s.adapter.resolveRoomTarget(ctx, s.cfg, s.target)
if err != nil {
return err
}
roomID = resolved
s.mu.Lock()
s.roomID = roomID
s.mu.Unlock()
}
callID := ""
if tc != nil {
callID = strings.TrimSpace(tc.CallID)
}
if p.Status != channel.ToolCallStatusRunning && callID != "" {
if eventID, ok := s.lookupToolCallMessage(callID); ok {
editMsg := channel.Message{Text: text, Format: format}
if _, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(editMsg, true, eventID)); err == nil {
s.forgetToolCallMessage(callID)
return nil
}
s.forgetToolCallMessage(callID)
}
}
msg := channel.Message{
Text: text,
Format: format,
Reply: reply,
}
eventID, err := s.adapter.sendTextEvent(ctx, s.cfg, roomID, buildMatrixMessageContent(msg, false, ""))
if err != nil {
return err
}
if p.Status == channel.ToolCallStatusRunning && callID != "" && eventID != "" {
s.storeToolCallMessage(callID, eventID)
}
return nil
}
func (s *matrixOutboundStream) lookupToolCallMessage(callID string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return "", false
}
v, ok := s.toolMessages[callID]
return v, ok
}
func (s *matrixOutboundStream) storeToolCallMessage(callID, eventID string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
s.toolMessages = make(map[string]string)
}
s.toolMessages[callID] = eventID
}
func (s *matrixOutboundStream) forgetToolCallMessage(callID string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return
}
delete(s.toolMessages, callID)
}
func (s *matrixOutboundStream) resetMessageState() {
s.mu.Lock()
s.originalEventID = ""
@@ -64,7 +64,7 @@ func TestMatrixStreamDoesNotSendDeltaBeforeTextPhaseEnds(t *testing.T) {
}
}
func TestMatrixStreamDropsBufferedTextWhenToolStarts(t *testing.T) {
func TestMatrixStreamFlushesBufferedTextWhenToolStarts(t *testing.T) {
requests := 0
adapter := NewMatrixAdapter(nil)
adapter.httpClient = &http.Client{Transport: roundTripFunc(func(_ *http.Request) (*http.Response, error) {
@@ -89,11 +89,19 @@ func TestMatrixStreamDropsBufferedTextWhenToolStarts(t *testing.T) {
if err := stream.Push(ctx, mustPreparedMatrixEvent(t, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "I will inspect first", Phase: channel.StreamPhaseText})); err != nil {
t.Fatalf("push delta: %v", err)
}
if err := stream.Push(ctx, mustPreparedMatrixEvent(t, channel.StreamEvent{Type: channel.StreamEventToolCallStart})); err != nil {
tcStart := &channel.StreamToolCall{Name: "read", CallID: "c1", Input: map[string]any{"path": "/tmp/a"}}
tcEnd := &channel.StreamToolCall{Name: "read", CallID: "c1", Input: map[string]any{"path": "/tmp/a"}, Result: map[string]any{"ok": true}}
if err := stream.Push(ctx, mustPreparedMatrixEvent(t, channel.StreamEvent{Type: channel.StreamEventToolCallStart, ToolCall: tcStart})); err != nil {
t.Fatalf("push tool call start: %v", err)
}
if requests != 0 {
t.Fatalf("expected no request for discarded pre-tool text, got %d", requests)
if requests != 2 {
t.Fatalf("expected flush + start messages, got %d", requests)
}
if err := stream.Push(ctx, mustPreparedMatrixEvent(t, channel.StreamEvent{Type: channel.StreamEventToolCallEnd, ToolCall: tcEnd})); err != nil {
t.Fatalf("push tool call end: %v", err)
}
if requests != 3 {
t.Fatalf("expected start + end tool messages plus flush, got %d", requests)
}
if err := stream.Push(ctx, mustPreparedMatrixEvent(t, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "Final answer", Phase: channel.StreamPhaseText})); err != nil {
t.Fatalf("push final delta: %v", err)
@@ -101,8 +109,8 @@ func TestMatrixStreamDropsBufferedTextWhenToolStarts(t *testing.T) {
if err := stream.Push(ctx, mustPreparedMatrixEvent(t, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{Message: channel.Message{Text: "Final answer"}}})); err != nil {
t.Fatalf("push final: %v", err)
}
if requests != 1 {
t.Fatalf("expected only final visible message to be sent, got %d", requests)
if requests != 4 {
t.Fatalf("expected final visible message after tool call, got %d", requests)
}
}
+11 -1
View File
@@ -62,13 +62,23 @@ func (s *qqOutboundStream) Push(ctx context.Context, event channel.PreparedStrea
channel.StreamEventPhaseStart,
channel.StreamEventPhaseEnd,
channel.StreamEventToolCallStart,
channel.StreamEventToolCallEnd,
channel.StreamEventAgentStart,
channel.StreamEventAgentEnd,
channel.StreamEventProcessingStarted,
channel.StreamEventProcessingCompleted,
channel.StreamEventProcessingFailed:
return nil
case channel.StreamEventToolCallEnd:
text := strings.TrimSpace(channel.RenderToolCallMessage(channel.BuildToolCallEnd(event.ToolCall)))
if text == "" {
return nil
}
return s.send(ctx, channel.PreparedOutboundMessage{
Target: s.target,
Message: channel.PreparedMessage{
Message: channel.Message{Format: channel.MessageFormatPlain, Text: text, Reply: s.reply},
},
})
case channel.StreamEventDelta:
if event.Phase == channel.StreamPhaseReasoning || event.Delta == "" {
return nil
+91 -1
View File
@@ -34,6 +34,7 @@ type slackOutboundStream struct {
lastSent string
lastUpdate time.Time
nextUpdate time.Time
toolMessages map[string]string
}
var _ channel.PreparedOutboundStream = (*slackOutboundStream)(nil)
@@ -122,11 +123,26 @@ func (s *slackOutboundStream) Push(ctx context.Context, event channel.PreparedSt
}
return nil
case channel.StreamEventToolCallStart:
s.mu.Lock()
bufText := strings.TrimSpace(s.buffer.String())
s.mu.Unlock()
if bufText != "" {
if err := s.finalizeMessage(ctx, bufText); err != nil {
return err
}
} else if err := s.clearPlaceholder(ctx); err != nil {
return err
}
s.resetStreamState()
return s.sendToolCallMessage(ctx, event.ToolCall, channel.BuildToolCallStart(event.ToolCall))
case channel.StreamEventToolCallEnd:
return s.sendToolCallMessage(ctx, event.ToolCall, channel.BuildToolCallEnd(event.ToolCall))
case channel.StreamEventAgentStart, channel.StreamEventAgentEnd,
channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd,
channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted,
channel.StreamEventProcessingFailed,
channel.StreamEventToolCallStart, channel.StreamEventToolCallEnd,
channel.StreamEventReaction, channel.StreamEventSpeech:
return nil
@@ -310,6 +326,80 @@ func (s *slackOutboundStream) sendAttachment(ctx context.Context, att channel.Pr
return s.adapter.uploadPreparedAttachment(ctx, s.api, s.target, threadTS, att)
}
// sendToolCallMessage posts a message for tool_call_start and updates the same
// message on tool_call_end via chat.update so the running → completed/failed
// transition shares one visible post. If the edit fails (or no prior message
// is tracked), it falls back to posting a new message.
func (s *slackOutboundStream) sendToolCallMessage(
ctx context.Context,
tc *channel.StreamToolCall,
p channel.ToolCallPresentation,
) error {
text := truncateSlackText(strings.TrimSpace(channel.RenderToolCallMessageMarkdown(p)))
if text == "" {
return nil
}
callID := ""
if tc != nil {
callID = strings.TrimSpace(tc.CallID)
}
if p.Status != channel.ToolCallStatusRunning && callID != "" {
if ts, ok := s.lookupToolCallMessage(callID); ok {
if err := s.updateMessageTextWithRetry(ctx, ts, text); err == nil {
s.forgetToolCallMessage(callID)
return nil
}
s.forgetToolCallMessage(callID)
}
}
ts, err := s.postMessageWithRetry(ctx, text)
if err != nil {
return err
}
if p.Status == channel.ToolCallStatusRunning && callID != "" && ts != "" {
s.storeToolCallMessage(callID, ts)
}
return nil
}
func (s *slackOutboundStream) lookupToolCallMessage(callID string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return "", false
}
v, ok := s.toolMessages[callID]
return v, ok
}
func (s *slackOutboundStream) storeToolCallMessage(callID, ts string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
s.toolMessages = make(map[string]string)
}
s.toolMessages[callID] = ts
}
func (s *slackOutboundStream) forgetToolCallMessage(callID string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return
}
delete(s.toolMessages, callID)
}
func (s *slackOutboundStream) resetStreamState() {
s.mu.Lock()
s.msgTS = ""
s.buffer.Reset()
s.lastSent = ""
s.lastUpdate = time.Time{}
s.nextUpdate = time.Time{}
s.mu.Unlock()
}
func (s *slackOutboundStream) postMessageWithRetry(ctx context.Context, text string) (string, error) {
opts := []slackapi.MsgOption{
slackapi.MsgOptionText(text, false),
+123 -4
View File
@@ -39,6 +39,15 @@ type telegramOutboundStream struct {
streamMsgID int
lastEdited string
lastEditedAt time.Time
toolMessages map[string]telegramToolCallMessage
}
// telegramToolCallMessage tracks the message posted for a tool call's
// "running" state so the matching tool_call_end event can edit the same
// message in-place to show the "completed" / "failed" state.
type telegramToolCallMessage struct {
chatID int64
msgID int
}
func (s *telegramOutboundStream) getBot(_ context.Context) (bot *tgbotapi.BotAPI, err error) {
@@ -306,7 +315,7 @@ func (s *telegramOutboundStream) deliverFinalText(ctx context.Context, text, par
return s.editStreamMessageFinal(ctx, text)
}
func (s *telegramOutboundStream) pushToolCallStart(ctx context.Context) error {
func (s *telegramOutboundStream) pushToolCallStart(ctx context.Context, tc *channel.StreamToolCall) error {
s.mu.Lock()
bufText := strings.TrimSpace(s.buf.String())
hasMsg := s.streamMsgID != 0
@@ -327,8 +336,119 @@ func (s *telegramOutboundStream) pushToolCallStart(ctx context.Context) error {
_ = s.editStreamMessageFinal(ctx, bufText)
}
s.resetStreamState()
return s.sendToolCallMessage(ctx, tc, channel.BuildToolCallStart(tc))
}
func (s *telegramOutboundStream) pushToolCallEnd(ctx context.Context, tc *channel.StreamToolCall) error {
s.resetStreamState()
return s.sendToolCallMessage(ctx, tc, channel.BuildToolCallEnd(tc))
}
// renderToolCallPresentation renders a tool-call presentation to IM-ready
// text and parseMode. It prefers Markdown→HTML; falls back to plain text when
// Markdown conversion yields an empty parseMode.
func renderToolCallPresentation(p channel.ToolCallPresentation) (string, string) {
rendered := strings.TrimSpace(channel.RenderToolCallMessageMarkdown(p))
if rendered == "" {
return "", ""
}
text, parseMode := formatTelegramOutput(rendered, channel.MessageFormatMarkdown)
if parseMode == "" {
text = strings.TrimSpace(channel.RenderToolCallMessage(p))
}
return text, parseMode
}
// sendToolCallMessage renders the tool-call presentation. For tool_call_start
// it posts a new message and records the callID→message mapping. For
// tool_call_end it edits the previously-posted message to flip the status to
// completed/failed. If no prior message is tracked (or editing fails), it
// falls back to sending a fresh message.
func (s *telegramOutboundStream) sendToolCallMessage(
ctx context.Context,
tc *channel.StreamToolCall,
p channel.ToolCallPresentation,
) error {
text, parseMode := renderToolCallPresentation(p)
if text == "" {
return nil
}
callID := ""
if tc != nil {
callID = strings.TrimSpace(tc.CallID)
}
if p.Status != channel.ToolCallStatusRunning && callID != "" {
if existing, ok := s.lookupToolCallMessage(callID); ok {
if err := s.adapter.waitStreamLimit(ctx); err != nil {
return err
}
bot, err := s.getBot(ctx)
if err != nil {
return err
}
editErr := error(nil)
if testEditFunc != nil {
editErr = testEditFunc(bot, existing.chatID, existing.msgID, text, parseMode)
} else {
editErr = editTelegramMessageText(bot, existing.chatID, existing.msgID, text, parseMode)
}
if editErr == nil {
s.forgetToolCallMessage(callID)
return nil
}
if s.adapter != nil && s.adapter.logger != nil {
s.adapter.logger.Warn("telegram: tool-call end edit failed, falling back to new message",
slog.String("call_id", callID),
slog.Any("error", editErr),
)
}
s.forgetToolCallMessage(callID)
}
}
if err := s.adapter.waitStreamLimit(ctx); err != nil {
return err
}
bot, replyTo, err := s.getBotAndReply(ctx)
if err != nil {
return err
}
chatID, msgID, sendErr := sendTelegramTextReturnMessage(bot, s.target, text, replyTo, parseMode)
if sendErr != nil {
return sendErr
}
if p.Status == channel.ToolCallStatusRunning && callID != "" {
s.storeToolCallMessage(callID, telegramToolCallMessage{chatID: chatID, msgID: msgID})
}
return nil
}
func (s *telegramOutboundStream) lookupToolCallMessage(callID string) (telegramToolCallMessage, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return telegramToolCallMessage{}, false
}
m, ok := s.toolMessages[callID]
return m, ok
}
func (s *telegramOutboundStream) storeToolCallMessage(callID string, m telegramToolCallMessage) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
s.toolMessages = make(map[string]telegramToolCallMessage)
}
s.toolMessages[callID] = m
}
func (s *telegramOutboundStream) forgetToolCallMessage(callID string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.toolMessages == nil {
return
}
delete(s.toolMessages, callID)
}
func (s *telegramOutboundStream) pushAttachment(ctx context.Context, event channel.PreparedStreamEvent) error {
if len(event.Attachments) == 0 {
@@ -489,10 +609,9 @@ func (s *telegramOutboundStream) Push(ctx context.Context, event channel.Prepare
}
switch event.Type {
case channel.StreamEventToolCallStart:
return s.pushToolCallStart(ctx)
return s.pushToolCallStart(ctx, event.ToolCall)
case channel.StreamEventToolCallEnd:
s.resetStreamState()
return nil
return s.pushToolCallEnd(ctx, event.ToolCall)
case channel.StreamEventAttachment:
return s.pushAttachment(ctx, event)
case channel.StreamEventPhaseEnd:
@@ -553,6 +553,149 @@ func TestDraftMode_PhaseEndTextIsNoOp(t *testing.T) {
}
}
// TestToolCallFlow_ThreeMessagesPerCall verifies that a single tool call
// combined with pre-existing streamed text produces three distinct messages:
// (1) flush of buffered pre-text, (2) running state for the tool call,
// (3) completed / failed state for the tool call. The streaming state must be
// reset between the flush and the start, then tool_call_end edits the running
// message in place so one tool call produces exactly one tool-call message.
func TestToolCallFlow_FlushPreTextAndEditRunning(t *testing.T) {
adapter := NewTelegramAdapter(nil)
s := &telegramOutboundStream{
adapter: adapter,
cfg: channel.ChannelConfig{ID: "test", Credentials: map[string]any{"bot_token": "fake"}},
target: "123",
streamChatID: 42,
streamMsgID: 7,
}
s.buf.WriteString("pre-tool text")
ctx := context.Background()
origGetBot := getOrCreateBotForTest
origSendText := sendTextForTest
origEdit := testEditFunc
getOrCreateBotForTest = func(_ *TelegramAdapter, _, _ string) (*tgbotapi.BotAPI, error) {
return &tgbotapi.BotAPI{Token: "fake"}, nil
}
var sentTexts []string
var msgIDCounter int
sendTextForTest = func(_ *tgbotapi.BotAPI, _ string, text string, _ int, _ string) (int64, int, error) {
sentTexts = append(sentTexts, text)
msgIDCounter++
return 42, msgIDCounter, nil
}
var editTexts []string
testEditFunc = func(_ *tgbotapi.BotAPI, _ int64, _ int, text string, _ string) error {
editTexts = append(editTexts, text)
return nil
}
defer func() {
getOrCreateBotForTest = origGetBot
sendTextForTest = origSendText
testEditFunc = origEdit
}()
tcStart := &channel.StreamToolCall{Name: "read", CallID: "call_1", Input: map[string]any{"path": "/tmp/a"}}
tcEnd := &channel.StreamToolCall{Name: "read", CallID: "call_1", Input: map[string]any{"path": "/tmp/a"}, Result: map[string]any{"ok": true}}
if err := s.Push(ctx, mustPreparedTelegramEvent(t, channel.StreamEvent{Type: channel.StreamEventToolCallStart, ToolCall: tcStart})); err != nil {
t.Fatalf("push start: %v", err)
}
s.mu.Lock()
streamMsgAfterStart := s.streamMsgID
bufAfterStart := s.buf.String()
s.mu.Unlock()
if streamMsgAfterStart != 0 {
t.Fatalf("streamMsgID should be reset after tool_call_start, got %d", streamMsgAfterStart)
}
if bufAfterStart != "" {
t.Fatalf("buf should be reset after tool_call_start, got %q", bufAfterStart)
}
if err := s.Push(ctx, mustPreparedTelegramEvent(t, channel.StreamEvent{Type: channel.StreamEventToolCallEnd, ToolCall: tcEnd})); err != nil {
t.Fatalf("push end: %v", err)
}
// Edits: 1 for pre-text flush, 1 for running → completed.
if len(editTexts) != 2 {
t.Fatalf("expected exactly 2 edits (pre-text flush + running→completed), got %d: %v", len(editTexts), editTexts)
}
if !strings.Contains(editTexts[0], "pre-tool text") {
t.Fatalf("first edit should be the pre-text flush: %q", editTexts[0])
}
if !strings.Contains(editTexts[1], "completed") {
t.Fatalf("second edit should flip the tool call to completed: %q", editTexts[1])
}
if len(sentTexts) != 1 {
t.Fatalf("expected exactly 1 send (running), got %d: %v", len(sentTexts), sentTexts)
}
if !strings.Contains(sentTexts[0], "running") {
t.Fatalf("only send should be the running state: %q", sentTexts[0])
}
}
// TestToolCallFlow_NoPreTextEditsRunningInPlace verifies that when no text
// stream is active, tool_call_start sends the running message and
// tool_call_end edits it in place — no pre-text flush, one send, one edit.
func TestToolCallFlow_NoPreTextEditsRunningInPlace(t *testing.T) {
adapter := NewTelegramAdapter(nil)
s := &telegramOutboundStream{
adapter: adapter,
cfg: channel.ChannelConfig{ID: "test", Credentials: map[string]any{"bot_token": "fake"}},
target: "123",
streamChatID: 42,
}
ctx := context.Background()
origGetBot := getOrCreateBotForTest
origSendText := sendTextForTest
origEdit := testEditFunc
getOrCreateBotForTest = func(_ *TelegramAdapter, _, _ string) (*tgbotapi.BotAPI, error) {
return &tgbotapi.BotAPI{Token: "fake"}, nil
}
var sentTexts []string
var msgIDCounter int
sendTextForTest = func(_ *tgbotapi.BotAPI, _ string, text string, _ int, _ string) (int64, int, error) {
sentTexts = append(sentTexts, text)
msgIDCounter++
return 42, msgIDCounter, nil
}
var editTexts []string
testEditFunc = func(_ *tgbotapi.BotAPI, _ int64, _ int, text string, _ string) error {
editTexts = append(editTexts, text)
return nil
}
defer func() {
getOrCreateBotForTest = origGetBot
sendTextForTest = origSendText
testEditFunc = origEdit
}()
tcStart := &channel.StreamToolCall{Name: "read", CallID: "call_1", Input: map[string]any{"path": "/tmp/a"}}
tcEnd := &channel.StreamToolCall{Name: "read", CallID: "call_1", Result: map[string]any{"ok": true}}
if err := s.Push(ctx, mustPreparedTelegramEvent(t, channel.StreamEvent{Type: channel.StreamEventToolCallStart, ToolCall: tcStart})); err != nil {
t.Fatalf("push start: %v", err)
}
if err := s.Push(ctx, mustPreparedTelegramEvent(t, channel.StreamEvent{Type: channel.StreamEventToolCallEnd, ToolCall: tcEnd})); err != nil {
t.Fatalf("push end: %v", err)
}
if len(sentTexts) != 1 {
t.Fatalf("expected exactly 1 send (running), got %d: %v", len(sentTexts), sentTexts)
}
if !strings.Contains(sentTexts[0], "running") {
t.Fatalf("only send should be the running state: %q", sentTexts[0])
}
if len(editTexts) != 1 {
t.Fatalf("expected exactly 1 edit (running→completed), got %d: %v", len(editTexts), editTexts)
}
if !strings.Contains(editTexts[0], "completed") {
t.Fatalf("edit should flip the tool call to completed: %q", editTexts[0])
}
}
func TestDraftMode_ToolCallStartSendsPermanentMessage(t *testing.T) {
adapter := NewTelegramAdapter(nil)
s := &telegramOutboundStream{
+23 -1
View File
@@ -278,13 +278,14 @@ func (s *wecomOutboundStream) Push(ctx context.Context, event channel.PreparedSt
channel.StreamEventPhaseStart,
channel.StreamEventPhaseEnd,
channel.StreamEventToolCallStart,
channel.StreamEventToolCallEnd,
channel.StreamEventAgentStart,
channel.StreamEventAgentEnd,
channel.StreamEventProcessingStarted,
channel.StreamEventProcessingCompleted,
channel.StreamEventProcessingFailed:
return nil
case channel.StreamEventToolCallEnd:
return s.sendToolCallSummary(ctx, event.ToolCall)
case channel.StreamEventDelta:
if strings.TrimSpace(event.Delta) == "" || event.Phase == channel.StreamPhaseReasoning {
return nil
@@ -363,6 +364,27 @@ func (s *wecomOutboundStream) flush(ctx context.Context) error {
return nil
}
// sendToolCallSummary emits a best-effort terminal summary of a tool call.
// WeCom lacks message-edit APIs in the one-shot send path, so only the
// completed / failed state is surfaced — the running state is intentionally
// suppressed to avoid duplicate messages.
func (s *wecomOutboundStream) sendToolCallSummary(ctx context.Context, tc *channel.StreamToolCall) error {
if s.finalSent.Load() {
return nil
}
text := strings.TrimSpace(channel.RenderToolCallMessage(channel.BuildToolCallEnd(tc)))
if text == "" {
return nil
}
msg := channel.PreparedMessage{
Message: channel.Message{Format: channel.MessageFormatPlain, Text: text},
}
return s.adapter.Send(ctx, s.cfg, channel.PreparedOutboundMessage{
Target: s.target,
Message: msg,
})
}
func (s *wecomOutboundStream) pushPreview(ctx context.Context) error {
if s.finalSent.Load() {
return nil
+54
View File
@@ -93,6 +93,15 @@ type SessionEnsurer interface {
CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (SessionResult, error)
}
// IMDisplayOptionsReader exposes bot-level IM display preferences.
// Implementations typically adapt the settings service.
type IMDisplayOptionsReader interface {
// ShowToolCallsInIM reports whether tool_call lifecycle events should
// reach IM adapters for the given bot. Returns false by default when the
// bot or its settings cannot be resolved.
ShowToolCallsInIM(ctx context.Context, botID string) (bool, error)
}
// SessionResult carries the minimum fields needed from a session.
type SessionResult struct {
ID string
@@ -124,6 +133,7 @@ type ChannelInboundProcessor struct {
pipeline *pipelinepkg.Pipeline
eventStore *pipelinepkg.EventStore
discussDriver *pipelinepkg.DiscussDriver
imDisplayOptions IMDisplayOptionsReader
// activeStreams maps "botID:routeID" to a context.CancelFunc for the
// currently running agent stream. Used by /stop to abort generation
@@ -259,6 +269,42 @@ func (p *ChannelInboundProcessor) SetDispatcher(dispatcher *RouteDispatcher) {
p.dispatcher = dispatcher
}
// SetIMDisplayOptions configures the reader used to gate IM-facing stream
// events (e.g. tool call lifecycle) on bot-level display preferences. When
// nil, tool call events are always dropped before reaching IM adapters.
func (p *ChannelInboundProcessor) SetIMDisplayOptions(reader IMDisplayOptionsReader) {
if p == nil {
return
}
p.imDisplayOptions = reader
}
// shouldShowToolCallsInIM reports whether tool_call_start / tool_call_end
// events should reach the IM adapter for the given bot. Failures and missing
// configuration default to false so tool calls remain hidden unless explicitly
// enabled.
func (p *ChannelInboundProcessor) shouldShowToolCallsInIM(ctx context.Context, botID string) bool {
if p == nil || p.imDisplayOptions == nil {
return false
}
botID = strings.TrimSpace(botID)
if botID == "" {
return false
}
show, err := p.imDisplayOptions.ShowToolCallsInIM(ctx, botID)
if err != nil {
if p.logger != nil {
p.logger.Debug(
"show_tool_calls_in_im lookup failed, defaulting to hidden",
slog.String("bot_id", botID),
slog.Any("error", err),
)
}
return false
}
return show
}
// HandleInbound processes an inbound channel message through identity resolution and chat gateway.
func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) (retErr error) {
if p.runner == nil {
@@ -721,6 +767,14 @@ startStream:
}
}()
// For non-local channels (IM adapters), optionally drop tool_call events
// before they reach the adapter when the bot's show_tool_calls_in_im
// setting is off. The filter sits inside the TeeStream so WebUI
// observers still receive the full event stream.
if !isLocalChannelType(msg.Channel) && !p.shouldShowToolCallsInIM(ctx, identity.BotID) {
stream = channel.NewToolCallDroppingStream(stream)
}
// For non-local channels, wrap the stream so events are mirrored to the
// RouteHub (and thus to Web UI and other local subscribers).
if p.observer != nil && !isLocalChannelType(msg.Channel) {
+39
View File
@@ -0,0 +1,39 @@
package channel
import "context"
// toolCallDroppingStream drops tool_call_start / tool_call_end events while
// forwarding every other event to the wrapped primary stream unchanged. This
// is used to gate IM-facing streams when a bot's show_tool_calls_in_im setting
// is off: the IM adapter stops receiving tool lifecycle events, but any
// upstream TeeStream observer (e.g. the WebUI hub) still sees them because
// the tee mirrors events independently.
type toolCallDroppingStream struct {
primary OutboundStream
}
// NewToolCallDroppingStream wraps primary and drops tool_call_start /
// tool_call_end events. When primary is nil the function returns nil.
func NewToolCallDroppingStream(primary OutboundStream) OutboundStream {
if primary == nil {
return nil
}
return &toolCallDroppingStream{primary: primary}
}
func (s *toolCallDroppingStream) Push(ctx context.Context, event StreamEvent) error {
if s == nil || s.primary == nil {
return nil
}
if event.Type == StreamEventToolCallStart || event.Type == StreamEventToolCallEnd {
return nil
}
return s.primary.Push(ctx, event)
}
func (s *toolCallDroppingStream) Close(ctx context.Context) error {
if s == nil || s.primary == nil {
return nil
}
return s.primary.Close(ctx)
}
+89
View File
@@ -0,0 +1,89 @@
package channel
import (
"context"
"errors"
"testing"
)
type recordingOutboundStream struct {
events []StreamEvent
closed bool
err error
}
func (r *recordingOutboundStream) Push(_ context.Context, event StreamEvent) error {
if r.err != nil {
return r.err
}
r.events = append(r.events, event)
return nil
}
func (r *recordingOutboundStream) Close(_ context.Context) error {
r.closed = true
return nil
}
func TestToolCallDroppingStreamFiltersToolEvents(t *testing.T) {
t.Parallel()
sink := &recordingOutboundStream{}
stream := NewToolCallDroppingStream(sink)
if stream == nil {
t.Fatalf("expected non-nil wrapper")
}
ctx := context.Background()
events := []StreamEvent{
{Type: StreamEventDelta, Delta: "hi"},
{Type: StreamEventToolCallStart, ToolCall: &StreamToolCall{Name: "read", CallID: "c1"}},
{Type: StreamEventToolCallEnd, ToolCall: &StreamToolCall{Name: "read", CallID: "c1"}},
{Type: StreamEventStatus, Status: StreamStatusCompleted},
}
for _, e := range events {
if err := stream.Push(ctx, e); err != nil {
t.Fatalf("push %s: %v", e.Type, err)
}
}
if len(sink.events) != 2 {
t.Fatalf("expected 2 forwarded events, got %d: %+v", len(sink.events), sink.events)
}
if sink.events[0].Type != StreamEventDelta {
t.Fatalf("expected delta first, got %s", sink.events[0].Type)
}
if sink.events[1].Type != StreamEventStatus {
t.Fatalf("expected status second, got %s", sink.events[1].Type)
}
if err := stream.Close(ctx); err != nil {
t.Fatalf("close: %v", err)
}
if !sink.closed {
t.Fatalf("expected primary close to be called")
}
}
func TestToolCallDroppingStreamForwardsPrimaryError(t *testing.T) {
t.Parallel()
boom := errors.New("boom")
stream := NewToolCallDroppingStream(&recordingOutboundStream{err: boom})
err := stream.Push(context.Background(), StreamEvent{Type: StreamEventDelta, Delta: "x"})
if !errors.Is(err, boom) {
t.Fatalf("expected primary error to surface, got %v", err)
}
// tool events should still be dropped silently without calling the primary
if err := stream.Push(context.Background(), StreamEvent{Type: StreamEventToolCallStart}); err != nil {
t.Fatalf("tool event should not propagate primary error, got %v", err)
}
}
func TestNewToolCallDroppingStreamNilPrimary(t *testing.T) {
t.Parallel()
if got := NewToolCallDroppingStream(nil); got != nil {
t.Fatalf("expected nil wrapper when primary is nil, got %T", got)
}
}
+333
View File
@@ -0,0 +1,333 @@
package channel
import (
"strings"
)
// ToolCallStatus is the lifecycle state of a single tool call as surfaced in IM.
type ToolCallStatus string
const (
ToolCallStatusRunning ToolCallStatus = "running"
ToolCallStatusCompleted ToolCallStatus = "completed"
ToolCallStatusFailed ToolCallStatus = "failed"
)
// ExternalToolCallEmoji is the emoji used for any tool not in the built-in
// whitelist (including MCP and federation tools).
const ExternalToolCallEmoji = "⚙️"
// builtinToolCallEmoji maps built-in tool names to their display emoji.
// Names are matched case-insensitively after trimming whitespace.
var builtinToolCallEmoji = map[string]string{
"list": "📂",
"read": "📖",
"write": "📝",
"edit": "📝",
"exec": "💻",
"bg_status": "💻",
"web_search": "🌐",
"web_fetch": "🌐",
"search_memory": "🧠",
"search_messages": "🧠",
"list_sessions": "🧠",
"list_schedule": "📅",
"get_schedule": "📅",
"create_schedule": "📅",
"update_schedule": "📅",
"delete_schedule": "📅",
"send": "💬",
"react": "💬",
"get_contacts": "👥",
"list_email_accounts": "📧",
"send_email": "📧",
"list_email": "📧",
"read_email": "📧",
"browser_action": "🧭",
"browser_observe": "🧭",
"browser_remote_session": "🧭",
"spawn": "🤖",
"use_skill": "🧩",
"generate_image": "🖼️",
"speak": "🔊",
"transcribe_audio": "🎧",
}
// ToolCallEmoji returns the emoji mapped for a tool name. Unknown / external
// tools fall back to ExternalToolCallEmoji.
func ToolCallEmoji(toolName string) string {
key := strings.ToLower(strings.TrimSpace(toolName))
if emoji, ok := builtinToolCallEmoji[key]; ok {
return emoji
}
return ExternalToolCallEmoji
}
// ToolCallBlockType distinguishes body block rendering semantics.
type ToolCallBlockType string
const (
ToolCallBlockText ToolCallBlockType = "text" // free-form line or paragraph
ToolCallBlockLink ToolCallBlockType = "link" // titled hyperlink, optional description
ToolCallBlockCode ToolCallBlockType = "code" // preformatted / code block
)
// ToolCallBlock is one rich element inside ToolCallPresentation.Body. Fields
// not applicable to the Type are ignored.
type ToolCallBlock struct {
Type ToolCallBlockType
Title string
URL string
Desc string
Text string
}
// ToolCallPresentation is the rendered single-message view of one tool call
// state. Adapters call RenderToolCallMessage (or their own renderer) against
// this struct to produce the final IM text body.
//
// The preferred fields are Header / Body / Footer, populated either by
// per-tool formatters (see toolcall_formatters.go) or by the generic builder.
// InputSummary / ResultSummary are retained so existing callers that expect
// two flat strings keep working.
type ToolCallPresentation struct {
Emoji string
ToolName string
Status ToolCallStatus
Header string
Body []ToolCallBlock
Footer string
InputSummary string
ResultSummary string
}
// BuildToolCallStart builds a presentation for a tool_call_start event.
// Returns a zero-value presentation when the payload is nil.
func BuildToolCallStart(tc *StreamToolCall) ToolCallPresentation {
if tc == nil {
return ToolCallPresentation{}
}
name := strings.TrimSpace(tc.Name)
if fn := lookupToolFormatter(name); fn != nil {
p := fn(tc, ToolCallStatusRunning)
fillBaseIdentity(&p, name, ToolCallStatusRunning)
return p
}
summary := SummarizeToolInput(name, tc.Input)
return ToolCallPresentation{
Emoji: ToolCallEmoji(name),
ToolName: name,
Status: ToolCallStatusRunning,
Header: summary,
InputSummary: summary,
}
}
// BuildToolCallEnd builds a presentation for a tool_call_end event. The
// completed / failed status is inferred from the tool result payload (ok=false,
// error fields, non-zero exit codes, etc.).
func BuildToolCallEnd(tc *StreamToolCall) ToolCallPresentation {
if tc == nil {
return ToolCallPresentation{}
}
name := strings.TrimSpace(tc.Name)
status := ToolCallStatusCompleted
if isToolResultFailure(tc.Result) {
status = ToolCallStatusFailed
}
if fn := lookupToolFormatter(name); fn != nil {
p := fn(tc, status)
fillBaseIdentity(&p, name, status)
return p
}
inputSummary := SummarizeToolInput(name, tc.Input)
resultSummary := SummarizeToolResult(name, tc.Result)
return ToolCallPresentation{
Emoji: ToolCallEmoji(name),
ToolName: name,
Status: status,
Header: inputSummary,
Footer: resultSummary,
InputSummary: inputSummary,
ResultSummary: resultSummary,
}
}
// fillBaseIdentity fills emoji / tool name / status after a per-tool
// formatter runs, without clobbering values set by the formatter itself.
// InputSummary / ResultSummary are intentionally NOT populated here: when a
// formatter is used, its Header / Body / Footer output is authoritative and
// we must not append raw JSON summaries as a fallback.
func fillBaseIdentity(p *ToolCallPresentation, name string, status ToolCallStatus) {
if p.Emoji == "" {
p.Emoji = ToolCallEmoji(name)
}
if p.ToolName == "" {
p.ToolName = name
}
if p.Status == "" {
p.Status = status
}
}
// RenderToolCallMessage renders a plain-text single-message view of a tool
// call state. Links are rendered as two lines: the title on one line and the
// URL on the next. Adapters that want Markdown link syntax should use
// RenderToolCallMessageMarkdown instead.
func RenderToolCallMessage(p ToolCallPresentation) string {
return renderToolCall(p, false)
}
// RenderToolCallMessageMarkdown renders a Markdown version of the tool call
// presentation. Links become [title](url), code blocks are fenced with triple
// backticks, and plain-text blocks are unchanged.
func RenderToolCallMessageMarkdown(p ToolCallPresentation) string {
return renderToolCall(p, true)
}
func renderToolCall(p ToolCallPresentation, markdown bool) string {
if !presentationHasContent(p) {
return ""
}
var b strings.Builder
emoji := p.Emoji
if emoji == "" {
emoji = ExternalToolCallEmoji
}
b.WriteString(emoji)
b.WriteString(" ")
if p.ToolName != "" {
b.WriteString(p.ToolName)
} else {
b.WriteString("tool")
}
if p.Status != "" {
b.WriteString(" · ")
b.WriteString(string(p.Status))
}
header := strings.TrimSpace(p.Header)
if header == "" {
header = strings.TrimSpace(p.InputSummary)
}
if header != "" {
b.WriteString("\n")
b.WriteString(header)
}
for _, block := range p.Body {
rendered := renderToolCallBlock(block, markdown)
if rendered == "" {
continue
}
b.WriteString("\n")
b.WriteString(rendered)
}
footer := strings.TrimSpace(p.Footer)
if footer == "" {
footer = strings.TrimSpace(p.ResultSummary)
}
if footer != "" {
b.WriteString("\n")
b.WriteString(footer)
}
return b.String()
}
func presentationHasContent(p ToolCallPresentation) bool {
if p.ToolName != "" || p.Emoji != "" {
return true
}
if strings.TrimSpace(p.Header) != "" {
return true
}
if strings.TrimSpace(p.Footer) != "" {
return true
}
if strings.TrimSpace(p.InputSummary) != "" {
return true
}
if strings.TrimSpace(p.ResultSummary) != "" {
return true
}
return len(p.Body) > 0
}
func renderToolCallBlock(block ToolCallBlock, markdown bool) string {
switch block.Type {
case ToolCallBlockLink:
return renderLinkBlock(block, markdown)
case ToolCallBlockCode:
return renderCodeBlock(block, markdown)
case ToolCallBlockText:
return strings.TrimRight(block.Text, "\n")
default:
// Unknown types: fall back to Text for resilience.
return strings.TrimRight(block.Text, "\n")
}
}
func renderLinkBlock(block ToolCallBlock, markdown bool) string {
title := strings.TrimSpace(block.Title)
url := strings.TrimSpace(block.URL)
desc := strings.TrimSpace(block.Desc)
var b strings.Builder
switch {
case markdown && url != "":
label := title
if label == "" {
label = url
}
b.WriteString("[")
b.WriteString(label)
b.WriteString("](")
b.WriteString(url)
b.WriteString(")")
case url != "" && title != "":
b.WriteString(title)
b.WriteString("\n")
b.WriteString(url)
case url != "":
b.WriteString(url)
case title != "":
b.WriteString(title)
}
if desc != "" {
if b.Len() > 0 {
b.WriteString("\n")
}
b.WriteString(desc)
}
return b.String()
}
func renderCodeBlock(block ToolCallBlock, markdown bool) string {
text := strings.TrimRight(block.Text, "\n")
if text == "" {
return ""
}
if !markdown {
return text
}
var b strings.Builder
b.WriteString("```")
b.WriteString("\n")
b.WriteString(text)
b.WriteString("\n```")
return b.String()
}
+213
View File
@@ -0,0 +1,213 @@
package channel
import (
"reflect"
"strings"
"testing"
)
func TestToolCallEmojiBuiltin(t *testing.T) {
t.Parallel()
cases := map[string]string{
"read": "📖",
"WRITE": "📝",
" edit ": "📝",
"exec": "💻",
"web_search": "🌐",
"search_memory": "🧠",
"list_schedule": "📅",
"send": "💬",
"get_contacts": "👥",
"send_email": "📧",
"browser_action": "🧭",
"spawn": "🤖",
"use_skill": "🧩",
"generate_image": "🖼️",
"speak": "🔊",
}
for name, want := range cases {
if got := ToolCallEmoji(name); got != want {
t.Fatalf("ToolCallEmoji(%q) = %q, want %q", name, got, want)
}
}
}
func TestToolCallEmojiExternalFallback(t *testing.T) {
t.Parallel()
for _, name := range []string{"", " ", "mcp.filesystem.read", "federation_foo", "unknown_tool"} {
if got := ToolCallEmoji(name); got != ExternalToolCallEmoji {
t.Fatalf("ToolCallEmoji(%q) = %q, want external %q", name, got, ExternalToolCallEmoji)
}
}
}
func TestBuildToolCallStartPopulatesRunning(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "read",
CallID: "call_1",
Input: map[string]any{"path": "/tmp/foo.txt"},
}
p := BuildToolCallStart(tc)
if p.Status != ToolCallStatusRunning {
t.Fatalf("unexpected status: %q", p.Status)
}
if p.Emoji != "📖" {
t.Fatalf("unexpected emoji: %q", p.Emoji)
}
if p.Header != "/tmp/foo.txt" {
t.Fatalf("unexpected header: %q", p.Header)
}
if p.ResultSummary != "" {
t.Fatalf("start presentation should not carry a result summary, got %q", p.ResultSummary)
}
if p.Footer != "" {
t.Fatalf("start presentation should not carry a footer, got %q", p.Footer)
}
}
func TestBuildToolCallEndInfersStatus(t *testing.T) {
t.Parallel()
ok := &StreamToolCall{Name: "exec", Input: map[string]any{"command": "ls -la"}, Result: map[string]any{"ok": true, "exit_code": 0}}
if got := BuildToolCallEnd(ok); got.Status != ToolCallStatusCompleted {
t.Fatalf("expected completed, got %q", got.Status)
}
fail := &StreamToolCall{Name: "exec", Input: map[string]any{"command": "false"}, Result: map[string]any{"exit_code": 2, "stderr": "boom"}}
if got := BuildToolCallEnd(fail); got.Status != ToolCallStatusFailed {
t.Fatalf("expected failed, got %q", got.Status)
}
errored := &StreamToolCall{Name: "read", Input: map[string]any{"path": "/missing"}, Result: map[string]any{"error": "ENOENT"}}
if got := BuildToolCallEnd(errored); got.Status != ToolCallStatusFailed {
t.Fatalf("expected failed on error, got %q", got.Status)
}
}
func TestBuildToolCallHandlesNil(t *testing.T) {
t.Parallel()
if got := BuildToolCallStart(nil); !reflect.DeepEqual(got, ToolCallPresentation{}) {
t.Fatalf("expected zero-value presentation for nil start, got %+v", got)
}
if got := BuildToolCallEnd(nil); !reflect.DeepEqual(got, ToolCallPresentation{}) {
t.Fatalf("expected zero-value presentation for nil end, got %+v", got)
}
}
func TestRenderToolCallMessageLayout(t *testing.T) {
t.Parallel()
msg := RenderToolCallMessage(ToolCallPresentation{
Emoji: "📖",
ToolName: "read",
Status: ToolCallStatusRunning,
InputSummary: "/tmp/foo.txt",
ResultSummary: "",
})
if !strings.HasPrefix(msg, "📖 read · running") {
t.Fatalf("unexpected header: %q", msg)
}
if !strings.Contains(msg, "/tmp/foo.txt") {
t.Fatalf("expected input summary in body: %q", msg)
}
done := RenderToolCallMessage(ToolCallPresentation{
Emoji: "💻",
ToolName: "exec",
Status: ToolCallStatusCompleted,
InputSummary: "ls -la",
ResultSummary: "exit=0 · stdout: total 0",
})
lines := strings.Split(done, "\n")
if len(lines) != 3 {
t.Fatalf("expected header+input+result lines, got %d: %q", len(lines), done)
}
if !strings.HasPrefix(lines[0], "💻 exec · completed") {
t.Fatalf("unexpected header: %q", lines[0])
}
}
func TestRenderToolCallMessageEmptyWhenNothingKnown(t *testing.T) {
t.Parallel()
if got := RenderToolCallMessage(ToolCallPresentation{}); got != "" {
t.Fatalf("expected empty render, got %q", got)
}
}
func TestRenderToolCallMessageMarkdownRendersLinks(t *testing.T) {
t.Parallel()
p := ToolCallPresentation{
Emoji: "🌐",
ToolName: "web_search",
Status: ToolCallStatusCompleted,
Header: `2 results for "golang generics"`,
Body: []ToolCallBlock{
{
Type: ToolCallBlockLink,
Title: "Tutorial: Getting started with generics",
URL: "https://go.dev/doc/tutorial/generics",
Desc: "A comprehensive walkthrough",
},
{
Type: ToolCallBlockLink,
Title: "Go 1.18 is released",
URL: "https://go.dev/blog/go1.18",
},
},
}
md := RenderToolCallMessageMarkdown(p)
if !strings.Contains(md, "[Tutorial: Getting started with generics](https://go.dev/doc/tutorial/generics)") {
t.Fatalf("expected markdown link for first item, got %q", md)
}
if !strings.Contains(md, "[Go 1.18 is released](https://go.dev/blog/go1.18)") {
t.Fatalf("expected markdown link for second item, got %q", md)
}
if !strings.Contains(md, "A comprehensive walkthrough") {
t.Fatalf("expected description to appear in markdown, got %q", md)
}
plain := RenderToolCallMessage(p)
if strings.Contains(plain, "](https://") {
t.Fatalf("plain render should not contain markdown link syntax, got %q", plain)
}
if !strings.Contains(plain, "Tutorial: Getting started with generics") || !strings.Contains(plain, "https://go.dev/doc/tutorial/generics") {
t.Fatalf("plain render should carry title and url lines, got %q", plain)
}
}
func TestRenderToolCallMessageMarkdownCodeBlocks(t *testing.T) {
t.Parallel()
p := ToolCallPresentation{
Emoji: "💻",
ToolName: "exec",
Status: ToolCallStatusCompleted,
Header: "$ ls -la",
Body: []ToolCallBlock{
{Type: ToolCallBlockCode, Text: "total 0\ndrwxr-xr-x 2 user"},
},
Footer: "exit=0",
}
md := RenderToolCallMessageMarkdown(p)
if !strings.Contains(md, "```\ntotal 0\ndrwxr-xr-x 2 user\n```") {
t.Fatalf("expected fenced code block, got %q", md)
}
plain := RenderToolCallMessage(p)
if strings.Contains(plain, "```") {
t.Fatalf("plain render should not fence code, got %q", plain)
}
if !strings.Contains(plain, "total 0") {
t.Fatalf("plain render should still include code body, got %q", plain)
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,358 @@
package channel
import (
"strings"
"testing"
)
func hasTextBlock(body []ToolCallBlock, needle string) bool {
for _, b := range body {
if b.Type != ToolCallBlockText {
continue
}
if strings.Contains(b.Text, needle) {
return true
}
}
return false
}
func hasLinkBlock(body []ToolCallBlock, url string) *ToolCallBlock {
for i := range body {
if body[i].Type != ToolCallBlockLink {
continue
}
if body[i].URL == url {
return &body[i]
}
}
return nil
}
func TestFormatListIncludesEntries(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "list",
Input: map[string]any{"path": "/var/log"},
Result: map[string]any{
"total_count": float64(12),
"entries": []any{
map[string]any{"path": "syslog", "is_dir": false, "size": float64(2300)},
map[string]any{"path": "auth.log", "is_dir": false, "size": float64(1100000)},
map[string]any{"path": "nginx", "is_dir": true},
},
},
}
p := BuildToolCallEnd(tc)
if !strings.Contains(p.Header, "/var/log") || !strings.Contains(p.Header, "12 entries") {
t.Fatalf("unexpected header: %q", p.Header)
}
if !hasTextBlock(p.Body, "syslog") {
t.Fatalf("expected syslog entry in body, got %+v", p.Body)
}
if !hasTextBlock(p.Body, "nginx/") {
t.Fatalf("expected nginx/ directory entry in body, got %+v", p.Body)
}
if !hasTextBlock(p.Body, "…and 9 more") {
t.Fatalf("expected ellipsis footer for remaining items, got %+v", p.Body)
}
if p.InputSummary != "" || p.ResultSummary != "" {
t.Fatalf("formatter output must not leak InputSummary/ResultSummary raw JSON, got in=%q res=%q", p.InputSummary, p.ResultSummary)
}
rendered := RenderToolCallMessage(p)
if strings.Contains(rendered, "\"entries\"") || strings.Contains(rendered, "{\"") {
t.Fatalf("rendered output leaked raw JSON result:\n%s", rendered)
}
}
func TestFormatExecForeground(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "exec",
Input: map[string]any{"command": "ls -la /tmp"},
Result: map[string]any{"exit_code": float64(0), "stdout": "total 16\nfoo bar"},
}
p := BuildToolCallEnd(tc)
if !strings.HasPrefix(p.Header, "$ ls -la /tmp") {
t.Fatalf("unexpected header: %q", p.Header)
}
if p.Footer != "exit=0" {
t.Fatalf("unexpected footer: %q", p.Footer)
}
if !hasTextBlock(p.Body, "stdout: total 16") {
t.Fatalf("expected stdout block, got %+v", p.Body)
}
}
func TestFormatExecBackground(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "exec",
Input: map[string]any{"command": "long_running.sh"},
Result: map[string]any{
"status": "background_started",
"task_id": "bg_abc",
"output_file": "/tmp/bg_abc.log",
},
}
p := BuildToolCallEnd(tc)
if !strings.Contains(p.Footer, "background_started") ||
!strings.Contains(p.Footer, "task_id=bg_abc") ||
!strings.Contains(p.Footer, "/tmp/bg_abc.log") {
t.Fatalf("unexpected footer: %q", p.Footer)
}
}
func TestFormatWebSearchEmitsLinks(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "web_search",
Input: map[string]any{"query": "golang generics tutorial"},
Result: map[string]any{
"results": []any{
map[string]any{
"title": "Tutorial: Getting started with generics",
"url": "https://go.dev/doc/tutorial/generics",
"description": "A comprehensive walkthrough",
},
map[string]any{
"title": "Go 1.18 is released",
"url": "https://go.dev/blog/go1.18",
},
},
},
}
p := BuildToolCallEnd(tc)
if !strings.Contains(p.Header, "2 results") || !strings.Contains(p.Header, `"golang generics tutorial"`) {
t.Fatalf("unexpected header: %q", p.Header)
}
link := hasLinkBlock(p.Body, "https://go.dev/doc/tutorial/generics")
if link == nil {
t.Fatalf("expected link block for tutorial, got %+v", p.Body)
}
if link.Title != "Tutorial: Getting started with generics" {
t.Fatalf("unexpected title: %q", link.Title)
}
if link.Desc != "A comprehensive walkthrough" {
t.Fatalf("unexpected desc: %q", link.Desc)
}
}
func TestFormatSendCarriesTargetAndBody(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "send",
Input: map[string]any{
"target": "chat:123",
"platform": "telegram",
"body": "Hello there",
},
Result: map[string]any{"delivered": "delivered", "message_id": "msg_456"},
}
p := BuildToolCallEnd(tc)
if p.Header != "→ chat:123 (telegram)" {
t.Fatalf("unexpected header: %q", p.Header)
}
if !hasTextBlock(p.Body, "Hello there") {
t.Fatalf("expected body text, got %+v", p.Body)
}
if !strings.Contains(p.Footer, "message_id=msg_456") {
t.Fatalf("unexpected footer: %q", p.Footer)
}
}
func TestFormatSendEmailCarriesSubject(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "send_email",
Input: map[string]any{
"to": "alice@example.com",
"subject": "Meeting notes",
},
Result: map[string]any{"status": "sent", "message_id": "mid1"},
}
p := BuildToolCallEnd(tc)
if p.Header != "→ alice@example.com" {
t.Fatalf("unexpected header: %q", p.Header)
}
if !hasTextBlock(p.Body, "Subject: Meeting notes") {
t.Fatalf("expected subject block, got %+v", p.Body)
}
if !strings.Contains(p.Footer, "status=sent") || !strings.Contains(p.Footer, "message_id=mid1") {
t.Fatalf("unexpected footer: %q", p.Footer)
}
}
func TestFormatSearchMemoryPrintsScoreAndTotal(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "search_memory",
Input: map[string]any{"query": "previous trip to Tokyo"},
Result: map[string]any{
"total": float64(10),
"results": []any{
map[string]any{"text": "Alice prefers sushi over ramen", "score": float64(0.91)},
map[string]any{"text": "Bob mentioned a trip to Tokyo in March", "score": float64(0.87)},
},
},
}
p := BuildToolCallEnd(tc)
if !strings.Contains(p.Header, "2 / 10 results") {
t.Fatalf("unexpected header: %q", p.Header)
}
if !hasTextBlock(p.Body, "(0.91)") {
t.Fatalf("expected score annotation, got %+v", p.Body)
}
}
func TestFormatSearchMessagesTruncatesPreview(t *testing.T) {
t.Parallel()
msgs := make([]any, 0, 6)
for i := 0; i < 6; i++ {
msgs = append(msgs, map[string]any{
"role": "user",
"text": "I need a kanban board",
"created_at": "2026-04-22 10:00",
})
}
tc := &StreamToolCall{
Name: "search_messages",
Input: map[string]any{"keyword": "kanban"},
Result: map[string]any{"messages": msgs},
}
p := BuildToolCallEnd(tc)
if !strings.Contains(p.Header, `keyword="kanban"`) || !strings.Contains(p.Header, "6 messages") {
t.Fatalf("unexpected header: %q", p.Header)
}
if !strings.Contains(p.Footer, "…and 3 more") {
t.Fatalf("expected ellipsis footer, got %q", p.Footer)
}
}
func TestFormatSpawnSuccessRatio(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "spawn",
Result: map[string]any{
"results": []any{
map[string]any{"success": true, "task": "analyze repo structure", "session_id": "sess_1"},
map[string]any{"success": true, "task": "summarize README", "session_id": "sess_2"},
},
},
}
p := BuildToolCallEnd(tc)
if !strings.Contains(p.Header, "2 / 2") {
t.Fatalf("unexpected header: %q", p.Header)
}
if !hasTextBlock(p.Body, "analyze repo structure") {
t.Fatalf("expected first task in body, got %+v", p.Body)
}
}
func TestFormatWebFetchShowsLink(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "web_fetch",
Input: map[string]any{"url": "https://example.com/article"},
Result: map[string]any{
"title": "Example Article Title",
"format": "markdown",
"length": float64(3421),
},
}
p := BuildToolCallEnd(tc)
link := hasLinkBlock(p.Body, "https://example.com/article")
if link == nil {
t.Fatalf("expected link block, got %+v", p.Body)
}
if link.Title != "Example Article Title" {
t.Fatalf("unexpected link title: %q", link.Title)
}
if !strings.Contains(p.Footer, "markdown") || !strings.Contains(p.Footer, "3421 chars") {
t.Fatalf("unexpected footer: %q", p.Footer)
}
}
func TestFormatCreateScheduleSuccess(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "create_schedule",
Input: map[string]any{
"name": "Daily report",
"pattern": "0 9 * * *",
},
Result: map[string]any{"id": "sch_42"},
}
p := BuildToolCallEnd(tc)
if !strings.Contains(p.Header, "[sch_42]") || !strings.Contains(p.Header, "Daily report") {
t.Fatalf("unexpected header: %q", p.Header)
}
}
func TestFormatFailureEmitsErrorFooter(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "read",
Input: map[string]any{"path": "/etc/shadow"},
Result: map[string]any{"error": "permission denied"},
}
p := BuildToolCallEnd(tc)
if p.Status != ToolCallStatusFailed {
t.Fatalf("expected failed status, got %q", p.Status)
}
if !strings.Contains(p.Footer, "permission denied") {
t.Fatalf("expected error footer, got %q", p.Footer)
}
}
func TestFormatExecFailedExitCode(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "exec",
Input: map[string]any{"command": "false"},
Result: map[string]any{
"exit_code": float64(2),
"stderr": "boom",
},
}
p := BuildToolCallEnd(tc)
if p.Status != ToolCallStatusFailed {
t.Fatalf("expected failed status, got %q", p.Status)
}
if !strings.Contains(p.Footer, "error") || !strings.Contains(p.Footer, "boom") {
t.Fatalf("expected stderr-based error footer, got %q", p.Footer)
}
}
func TestExternalToolFallsBackToGenericSummary(t *testing.T) {
t.Parallel()
tc := &StreamToolCall{
Name: "mcp.custom.do_thing",
Input: map[string]any{"foo": "bar"},
Result: map[string]any{"ok": true},
}
p := BuildToolCallEnd(tc)
if p.Emoji != ExternalToolCallEmoji {
t.Fatalf("expected external emoji, got %q", p.Emoji)
}
if p.Status != ToolCallStatusCompleted {
t.Fatalf("expected completed status, got %q", p.Status)
}
if p.InputSummary == "" {
t.Fatalf("expected generic input summary to be populated")
}
}
+302
View File
@@ -0,0 +1,302 @@
package channel
import (
"encoding/base64"
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/memohai/memoh/internal/textutil"
)
const (
toolCallSummaryMaxRunes = 200
toolCallSummaryTruncMark = "…"
)
// SummarizeToolInput returns a short human-readable representation of a
// tool call's input payload, prioritizing known key fields (path, command,
// query, url, target, to, id, cron, action) before falling back to a compact
// JSON projection.
func SummarizeToolInput(_ string, input any) string {
if input == nil {
return ""
}
m, ok := normalizeToMap(input)
if ok {
if s := pickStringField(m, "path", "file_path", "filepath"); s != "" {
return truncateSummary(s)
}
if s := pickStringField(m, "command", "cmd"); s != "" {
return truncateSummary(firstLine(s))
}
if s := pickStringField(m, "query"); s != "" {
return truncateSummary(s)
}
if s := pickStringField(m, "url"); s != "" {
return truncateSummary(s)
}
if s := combineTargetAndBody(m); s != "" {
return truncateSummary(s)
}
if s := pickStringField(m, "id"); s != "" {
if cron := strings.TrimSpace(fmt.Sprint(m["cron"])); cron != "" && cron != "<nil>" {
return truncateSummary(fmt.Sprintf("%s · %s", s, cron))
}
if action := strings.TrimSpace(fmt.Sprint(m["action"])); action != "" && action != "<nil>" {
return truncateSummary(fmt.Sprintf("%s · %s", s, action))
}
return truncateSummary(s)
}
if s := pickStringField(m, "cron"); s != "" {
return truncateSummary(s)
}
if s := pickStringField(m, "action"); s != "" {
return truncateSummary(s)
}
}
return compactJSONSummary(input)
}
// SummarizeToolResult returns a short representation of a tool call's result,
// surfacing status/error/count signals when present and otherwise falling
// back to trimmed text or a compact JSON projection.
func SummarizeToolResult(_ string, result any) string {
if result == nil {
return ""
}
if s, ok := result.(string); ok {
return truncateSummary(strings.TrimSpace(s))
}
m, ok := normalizeToMap(result)
if ok {
parts := make([]string, 0, 4)
if errStr := pickStringField(m, "error"); errStr != "" {
return truncateSummary("error: " + errStr)
}
if okVal, okFound := m["ok"]; okFound {
parts = append(parts, fmt.Sprintf("ok=%v", okVal))
}
if status := pickStringField(m, "status"); status != "" {
parts = append(parts, "status="+status)
}
if code, ok := numericField(m, "exit_code"); ok {
parts = append(parts, fmt.Sprintf("exit=%v", code))
}
if count, ok := numericField(m, "count"); ok {
parts = append(parts, fmt.Sprintf("count=%v", count))
}
if msg := pickStringField(m, "message"); msg != "" {
parts = append(parts, msg)
}
if stdout := pickStringField(m, "stdout"); stdout != "" {
parts = append(parts, "stdout: "+firstLine(stdout))
} else if stderr := pickStringField(m, "stderr"); stderr != "" {
parts = append(parts, "stderr: "+firstLine(stderr))
}
if len(parts) > 0 {
return truncateSummary(strings.Join(parts, " · "))
}
}
return compactJSONSummary(result)
}
// isToolResultFailure inspects a tool result payload and reports whether it
// represents a failure (ok=false, non-empty error, non-zero exit_code).
func isToolResultFailure(result any) bool {
if result == nil {
return false
}
m, ok := normalizeToMap(result)
if !ok {
return false
}
if errStr := pickStringField(m, "error"); errStr != "" {
return true
}
if okVal, okFound := m["ok"]; okFound {
if b, ok := okVal.(bool); ok && !b {
return true
}
}
if code, ok := numericField(m, "exit_code"); ok {
if code != 0 {
return true
}
}
return false
}
func normalizeToMap(v any) (map[string]any, bool) {
switch val := v.(type) {
case map[string]any:
return val, true
case json.RawMessage:
if len(val) == 0 {
return nil, false
}
var m map[string]any
if err := json.Unmarshal(val, &m); err == nil {
return m, true
}
case []byte:
if len(val) == 0 {
return nil, false
}
var m map[string]any
if err := json.Unmarshal(val, &m); err == nil {
return m, true
}
}
return nil, false
}
func pickStringField(m map[string]any, keys ...string) string {
for _, k := range keys {
if v, ok := m[k]; ok {
switch val := v.(type) {
case string:
if s := strings.TrimSpace(val); s != "" {
return s
}
case fmt.Stringer:
if s := strings.TrimSpace(val.String()); s != "" {
return s
}
}
}
}
return ""
}
func numericField(m map[string]any, key string) (float64, bool) {
v, ok := m[key]
if !ok {
return 0, false
}
switch val := v.(type) {
case float64:
return val, true
case int:
return float64(val), true
case int64:
return float64(val), true
case json.Number:
if f, err := val.Float64(); err == nil {
return f, true
}
}
return 0, false
}
func firstLine(s string) string {
s = strings.TrimSpace(s)
if idx := strings.IndexByte(s, '\n'); idx >= 0 {
return strings.TrimSpace(s[:idx])
}
return s
}
func combineTargetAndBody(m map[string]any) string {
target := pickStringField(m, "target", "to", "recipient")
body := pickStringField(m, "body", "content", "message", "text", "subject")
if target != "" && body != "" {
return fmt.Sprintf("→ %s: %s", target, body)
}
if target != "" {
return "→ " + target
}
return ""
}
func truncateSummary(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return ""
}
return textutil.TruncateRunesWithSuffix(s, toolCallSummaryMaxRunes, toolCallSummaryTruncMark)
}
// compactJSONSummary is a last-resort projection for values where we cannot
// extract known key fields. It omits binary / base64 content and large arrays.
func compactJSONSummary(v any) string {
if v == nil {
return ""
}
if raw, ok := v.(json.RawMessage); ok {
var decoded any
if err := json.Unmarshal(raw, &decoded); err == nil {
v = decoded
} else {
return truncateSummary(string(raw))
}
}
projected := projectForSummary(v)
bytes, err := json.Marshal(projected)
if err != nil {
return truncateSummary(fmt.Sprint(v))
}
return truncateSummary(string(bytes))
}
// projectForSummary reduces large / binary values before serialization so
// the summary stays short. It replaces base64-looking strings, truncates
// slices, and sorts map keys for stability.
func projectForSummary(v any) any {
switch val := v.(type) {
case map[string]any:
keys := make([]string, 0, len(val))
for k := range val {
keys = append(keys, k)
}
sort.Strings(keys)
out := make(map[string]any, len(keys))
for _, k := range keys {
out[k] = projectForSummary(val[k])
}
return out
case []any:
if len(val) == 0 {
return val
}
preview := 3
if len(val) < preview {
preview = len(val)
}
head := make([]any, 0, preview)
for i := 0; i < preview; i++ {
head = append(head, projectForSummary(val[i]))
}
if len(val) > preview {
return map[string]any{
"count": len(val),
"preview": head,
}
}
return head
case string:
if isLikelyBase64(val) {
return fmt.Sprintf("<binary %d bytes>", len(val))
}
if len(val) > 120 {
return textutil.TruncateRunesWithSuffix(val, 120, toolCallSummaryTruncMark)
}
return val
default:
return val
}
}
func isLikelyBase64(s string) bool {
if len(s) < 200 {
return false
}
if strings.ContainsAny(s, " \n\t") {
return false
}
if _, err := base64.StdEncoding.DecodeString(s); err == nil {
return true
}
return false
}
+147
View File
@@ -0,0 +1,147 @@
package channel
import (
"strings"
"testing"
)
func TestSummarizeToolInputFileField(t *testing.T) {
t.Parallel()
got := SummarizeToolInput("read", map[string]any{"path": "/var/log/syslog"})
if got != "/var/log/syslog" {
t.Fatalf("unexpected: %q", got)
}
}
func TestSummarizeToolInputCommandFirstLine(t *testing.T) {
t.Parallel()
got := SummarizeToolInput("exec", map[string]any{"command": "echo hi\nsleep 10"})
if got != "echo hi" {
t.Fatalf("unexpected first line: %q", got)
}
}
func TestSummarizeToolInputMessageTargetAndBody(t *testing.T) {
t.Parallel()
got := SummarizeToolInput("send", map[string]any{
"target": "chat:123",
"body": "Hello there",
})
if !strings.Contains(got, "chat:123") || !strings.Contains(got, "Hello there") {
t.Fatalf("unexpected target/body summary: %q", got)
}
}
func TestSummarizeToolInputScheduleID(t *testing.T) {
t.Parallel()
got := SummarizeToolInput("update_schedule", map[string]any{
"id": "sch_42",
"cron": "0 9 * * *",
})
if got != "sch_42 · 0 9 * * *" {
t.Fatalf("unexpected schedule summary: %q", got)
}
}
func TestSummarizeToolInputTruncatesLongValues(t *testing.T) {
t.Parallel()
long := strings.Repeat("x", 400)
got := SummarizeToolInput("web_fetch", map[string]any{"url": long})
if !strings.HasSuffix(got, "…") {
t.Fatalf("expected truncation suffix, got %q", got)
}
if len([]rune(got)) > 201 {
t.Fatalf("summary not truncated: rune len=%d", len([]rune(got)))
}
}
func TestSummarizeToolInputFallbackCompactJSON(t *testing.T) {
t.Parallel()
got := SummarizeToolInput("unknown", map[string]any{"alpha": 1, "beta": 2})
if !strings.Contains(got, "\"alpha\"") || !strings.Contains(got, "\"beta\"") {
t.Fatalf("expected JSON fallback: %q", got)
}
}
func TestSummarizeToolResultPrefersError(t *testing.T) {
t.Parallel()
got := SummarizeToolResult("read", map[string]any{"error": "ENOENT", "ok": false})
if !strings.HasPrefix(got, "error: ENOENT") {
t.Fatalf("unexpected result: %q", got)
}
}
func TestSummarizeToolResultCombinesSignals(t *testing.T) {
t.Parallel()
got := SummarizeToolResult("exec", map[string]any{
"ok": true,
"exit_code": 0,
"stdout": "line1\nline2",
})
if !strings.Contains(got, "ok=true") {
t.Fatalf("missing ok signal: %q", got)
}
if !strings.Contains(got, "exit=0") {
t.Fatalf("missing exit_code: %q", got)
}
if !strings.Contains(got, "stdout: line1") {
t.Fatalf("missing stdout first line: %q", got)
}
}
func TestSummarizeToolResultPlainString(t *testing.T) {
t.Parallel()
got := SummarizeToolResult("read", "hello world")
if got != "hello world" {
t.Fatalf("unexpected plain result: %q", got)
}
}
func TestSummarizeToolResultLargeJSONFallback(t *testing.T) {
t.Parallel()
items := make([]any, 0, 10)
for i := 0; i < 10; i++ {
items = append(items, map[string]any{"id": i})
}
got := SummarizeToolResult("list", map[string]any{"items": items})
if got == "" {
t.Fatalf("expected non-empty summary")
}
if len([]rune(got)) > 201 {
t.Fatalf("expected truncated: %d", len([]rune(got)))
}
}
func TestIsToolResultFailure(t *testing.T) {
t.Parallel()
cases := []struct {
name string
result any
want bool
}{
{"nil", nil, false},
{"ok_true", map[string]any{"ok": true}, false},
{"ok_false", map[string]any{"ok": false}, true},
{"error_present", map[string]any{"error": "bad"}, true},
{"empty_error", map[string]any{"error": ""}, false},
{"exit_zero", map[string]any{"exit_code": 0}, false},
{"exit_nonzero", map[string]any{"exit_code": 2}, true},
{"plain_string", "hello", false},
}
for _, tc := range cases {
if got := isToolResultFailure(tc.result); got != tc.want {
t.Fatalf("%s: isToolResultFailure = %v, want %v", tc.name, got, tc.want)
}
}
}
+1 -1
View File
@@ -511,7 +511,7 @@ WITH updated AS (
SET display_name = $1,
updated_at = now()
WHERE bots.id = $2
RETURNING id, owner_user_id, display_name, avatar_url, timezone, is_active, status, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, heartbeat_model_id, compaction_enabled, compaction_threshold, compaction_ratio, compaction_model_id, title_model_id, image_model_id, discuss_probe_model_id, tts_model_id, transcription_model_id, browser_context_id, persist_full_tool_results, metadata, created_at, updated_at, acl_default_effect
RETURNING id, owner_user_id, display_name, avatar_url, timezone, is_active, status, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, heartbeat_model_id, compaction_enabled, compaction_threshold, compaction_ratio, compaction_model_id, title_model_id, image_model_id, discuss_probe_model_id, tts_model_id, transcription_model_id, browser_context_id, persist_full_tool_results, show_tool_calls_in_im, metadata, created_at, updated_at, acl_default_effect
)
SELECT
updated.id AS id,
+1
View File
@@ -37,6 +37,7 @@ type Bot struct {
TranscriptionModelID pgtype.UUID `json:"transcription_model_id"`
BrowserContextID pgtype.UUID `json:"browser_context_id"`
PersistFullToolResults bool `json:"persist_full_tool_results"`
ShowToolCallsInIm bool `json:"show_tool_calls_in_im"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
+14 -4
View File
@@ -33,6 +33,7 @@ SET language = 'auto',
transcription_model_id = NULL,
browser_context_id = NULL,
persist_full_tool_results = false,
show_tool_calls_in_im = false,
updated_at = now()
WHERE id = $1
`
@@ -65,7 +66,8 @@ SELECT
tts_models.id AS tts_model_id,
transcription_models.id AS transcription_model_id,
browser_contexts.id AS browser_context_id,
bots.persist_full_tool_results
bots.persist_full_tool_results,
bots.show_tool_calls_in_im
FROM bots
LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id
LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = bots.heartbeat_model_id
@@ -103,6 +105,7 @@ type GetSettingsByBotIDRow struct {
TranscriptionModelID pgtype.UUID `json:"transcription_model_id"`
BrowserContextID pgtype.UUID `json:"browser_context_id"`
PersistFullToolResults bool `json:"persist_full_tool_results"`
ShowToolCallsInIm bool `json:"show_tool_calls_in_im"`
}
func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSettingsByBotIDRow, error) {
@@ -131,6 +134,7 @@ func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSe
&i.TranscriptionModelID,
&i.BrowserContextID,
&i.PersistFullToolResults,
&i.ShowToolCallsInIm,
)
return i, err
}
@@ -159,9 +163,10 @@ WITH updated AS (
transcription_model_id = COALESCE($19::uuid, bots.transcription_model_id),
browser_context_id = COALESCE($20::uuid, bots.browser_context_id),
persist_full_tool_results = $21,
show_tool_calls_in_im = $22,
updated_at = now()
WHERE bots.id = $22
RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.timezone, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.transcription_model_id, bots.browser_context_id, bots.persist_full_tool_results
WHERE bots.id = $23
RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.timezone, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.transcription_model_id, bots.browser_context_id, bots.persist_full_tool_results, bots.show_tool_calls_in_im
)
SELECT
updated.id AS bot_id,
@@ -185,7 +190,8 @@ SELECT
tts_models.id AS tts_model_id,
transcription_models.id AS transcription_model_id,
browser_contexts.id AS browser_context_id,
updated.persist_full_tool_results
updated.persist_full_tool_results,
updated.show_tool_calls_in_im
FROM updated
LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id
LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = updated.heartbeat_model_id
@@ -221,6 +227,7 @@ type UpsertBotSettingsParams struct {
TranscriptionModelID pgtype.UUID `json:"transcription_model_id"`
BrowserContextID pgtype.UUID `json:"browser_context_id"`
PersistFullToolResults bool `json:"persist_full_tool_results"`
ShowToolCallsInIm bool `json:"show_tool_calls_in_im"`
ID pgtype.UUID `json:"id"`
}
@@ -247,6 +254,7 @@ type UpsertBotSettingsRow struct {
TranscriptionModelID pgtype.UUID `json:"transcription_model_id"`
BrowserContextID pgtype.UUID `json:"browser_context_id"`
PersistFullToolResults bool `json:"persist_full_tool_results"`
ShowToolCallsInIm bool `json:"show_tool_calls_in_im"`
}
func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsParams) (UpsertBotSettingsRow, error) {
@@ -272,6 +280,7 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa
arg.TranscriptionModelID,
arg.BrowserContextID,
arg.PersistFullToolResults,
arg.ShowToolCallsInIm,
arg.ID,
)
var i UpsertBotSettingsRow
@@ -298,6 +307,7 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa
&i.TranscriptionModelID,
&i.BrowserContextID,
&i.PersistFullToolResults,
&i.ShowToolCallsInIm,
)
return i, err
}
+8
View File
@@ -101,6 +101,9 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest
if req.PersistFullToolResults != nil {
current.PersistFullToolResults = *req.PersistFullToolResults
}
if req.ShowToolCallsInIM != nil {
current.ShowToolCallsInIM = *req.ShowToolCallsInIM
}
timezoneValue := pgtype.Text{}
if req.Timezone != nil {
normalized, err := normalizeOptionalTimezone(*req.Timezone)
@@ -215,6 +218,7 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest
TranscriptionModelID: transcriptionModelUUID,
BrowserContextID: browserContextUUID,
PersistFullToolResults: current.PersistFullToolResults,
ShowToolCallsInIm: current.ShowToolCallsInIM,
})
if err != nil {
return Settings{}, err
@@ -310,6 +314,7 @@ func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings {
row.TranscriptionModelID,
row.BrowserContextID,
row.PersistFullToolResults,
row.ShowToolCallsInIm,
)
}
@@ -335,6 +340,7 @@ func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings {
row.TranscriptionModelID,
row.BrowserContextID,
row.PersistFullToolResults,
row.ShowToolCallsInIm,
)
}
@@ -359,6 +365,7 @@ func normalizeBotSettingsFields(
transcriptionModelID pgtype.UUID,
browserContextID pgtype.UUID,
persistFullToolResults bool,
showToolCallsInIM bool,
) Settings {
settings := normalizeBotSetting(language, "", reasoningEnabled, reasoningEffort, heartbeatEnabled, heartbeatInterval, compactionEnabled, compactionThreshold, compactionRatio)
if timezone.Valid {
@@ -395,6 +402,7 @@ func normalizeBotSettingsFields(
settings.BrowserContextID = uuid.UUID(browserContextID.Bytes).String()
}
settings.PersistFullToolResults = persistFullToolResults
settings.ShowToolCallsInIM = showToolCallsInIM
return settings
}
+69
View File
@@ -0,0 +1,69 @@
package settings
import (
"testing"
"github.com/memohai/memoh/internal/db/sqlc"
)
func TestNormalizeBotSettingsReadRow_ShowToolCallsInIMDefault(t *testing.T) {
t.Parallel()
row := sqlc.GetSettingsByBotIDRow{
Language: "en",
ReasoningEnabled: false,
ReasoningEffort: "medium",
HeartbeatEnabled: false,
HeartbeatInterval: 60,
CompactionEnabled: false,
CompactionThreshold: 0,
CompactionRatio: 80,
ShowToolCallsInIm: false,
}
got := normalizeBotSettingsReadRow(row)
if got.ShowToolCallsInIM {
t.Fatalf("expected default ShowToolCallsInIM=false, got true")
}
}
func TestNormalizeBotSettingsReadRow_ShowToolCallsInIMPropagates(t *testing.T) {
t.Parallel()
row := sqlc.GetSettingsByBotIDRow{
Language: "en",
ReasoningEffort: "medium",
HeartbeatInterval: 60,
CompactionRatio: 80,
ShowToolCallsInIm: true,
}
got := normalizeBotSettingsReadRow(row)
if !got.ShowToolCallsInIM {
t.Fatalf("expected ShowToolCallsInIM=true to propagate from row")
}
}
func TestUpsertRequestShowToolCallsInIM_PointerSemantics(t *testing.T) {
t.Parallel()
// When the field is nil, the UpsertRequest should not touch the current
// setting. When non-nil, the dereferenced value should win. We exercise
// the small gate block without hitting the database.
current := Settings{ShowToolCallsInIM: true}
var req UpsertRequest
if req.ShowToolCallsInIM != nil {
current.ShowToolCallsInIM = *req.ShowToolCallsInIM
}
if !current.ShowToolCallsInIM {
t.Fatalf("nil pointer must leave current value unchanged")
}
off := false
req.ShowToolCallsInIM = &off
if req.ShowToolCallsInIM != nil {
current.ShowToolCallsInIM = *req.ShowToolCallsInIM
}
if current.ShowToolCallsInIM {
t.Fatalf("explicit false pointer must clear the flag")
}
}
+2
View File
@@ -29,6 +29,7 @@ type Settings struct {
CompactionModelID string `json:"compaction_model_id,omitempty"`
DiscussProbeModelID string `json:"discuss_probe_model_id,omitempty"`
PersistFullToolResults bool `json:"persist_full_tool_results"`
ShowToolCallsInIM bool `json:"show_tool_calls_in_im"`
}
type UpsertRequest struct {
@@ -54,4 +55,5 @@ type UpsertRequest struct {
CompactionModelID *string `json:"compaction_model_id,omitempty"`
DiscussProbeModelID string `json:"discuss_probe_model_id,omitempty"`
PersistFullToolResults *bool `json:"persist_full_tool_results,omitempty"`
ShowToolCallsInIM *bool `json:"show_tool_calls_in_im,omitempty"`
}
+2
View File
@@ -1753,6 +1753,7 @@ export type SettingsSettings = {
reasoning_effort?: string;
reasoning_enabled?: boolean;
search_provider_id?: string;
show_tool_calls_in_im?: boolean;
timezone?: string;
title_model_id?: string;
transcription_model_id?: string;
@@ -1778,6 +1779,7 @@ export type SettingsUpsertRequest = {
reasoning_effort?: string;
reasoning_enabled?: boolean;
search_provider_id?: string;
show_tool_calls_in_im?: boolean;
timezone?: string;
title_model_id?: string;
transcription_model_id?: string;
+6
View File
@@ -13783,6 +13783,9 @@ const docTemplate = `{
"search_provider_id": {
"type": "string"
},
"show_tool_calls_in_im": {
"type": "boolean"
},
"timezone": {
"type": "string"
},
@@ -13854,6 +13857,9 @@ const docTemplate = `{
"search_provider_id": {
"type": "string"
},
"show_tool_calls_in_im": {
"type": "boolean"
},
"timezone": {
"type": "string"
},
+6
View File
@@ -13774,6 +13774,9 @@
"search_provider_id": {
"type": "string"
},
"show_tool_calls_in_im": {
"type": "boolean"
},
"timezone": {
"type": "string"
},
@@ -13845,6 +13848,9 @@
"search_provider_id": {
"type": "string"
},
"show_tool_calls_in_im": {
"type": "boolean"
},
"timezone": {
"type": "string"
},
+4
View File
@@ -2947,6 +2947,8 @@ definitions:
type: boolean
search_provider_id:
type: string
show_tool_calls_in_im:
type: boolean
timezone:
type: string
title_model_id:
@@ -2994,6 +2996,8 @@ definitions:
type: boolean
search_provider_id:
type: string
show_tool_calls_in_im:
type: boolean
timezone:
type: string
title_model_id: