From e0fc2f514e018f266d325d817c42af969cfdbb59 Mon Sep 17 00:00:00 2001 From: Acbox Date: Thu, 16 Apr 2026 16:42:07 +0800 Subject: [PATCH] fix(agent): surface tool calls before input completes Emit tool-call placeholders as soon as tool input streaming starts so long writes appear immediately in chat. Reuse the same UI tool message when full input arrives to avoid duplicate cards, and keep the hook-required test suite green. --- internal/agent/agent.go | 23 ++++++ internal/agent/stream_test.go | 89 +++++++++++++++++++++++ internal/conversation/uimessage_stream.go | 35 ++++++--- internal/conversation/uimessage_test.go | 55 ++++++++++++++ internal/models/http_client_test.go | 2 + internal/providers/service_test.go | 1 + 6 files changed, 193 insertions(+), 12 deletions(-) create mode 100644 internal/agent/stream_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 9f47c01c..5c30204c 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -293,6 +293,18 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv aborted = true } + case *sdk.ToolInputStartPart: + if textLoopProbeBuffer != nil { + textLoopProbeBuffer.Flush() + } + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventToolCallStart, + ToolName: p.ToolName, + ToolCallID: p.ID, + }) { + aborted = true + } + case *sdk.StreamToolCallPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() @@ -928,6 +940,17 @@ func (a *Agent) runMidStreamRetry( if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) { aborted = true } + case *sdk.ToolInputStartPart: + if textLoopProbeBuffer != nil { + textLoopProbeBuffer.Flush() + } + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventToolCallStart, + ToolName: rp.ToolName, + ToolCallID: rp.ID, + }) { + aborted = true + } case *sdk.StreamToolCallPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() diff --git a/internal/agent/stream_test.go b/internal/agent/stream_test.go new file mode 100644 index 00000000..acbef0fb --- /dev/null +++ b/internal/agent/stream_test.go @@ -0,0 +1,89 @@ +package agent + +import ( + "context" + "reflect" + "testing" + + sdk "github.com/memohai/twilight-ai/sdk" +) + +type agentToolPlaceholderProvider struct{} + +func (*agentToolPlaceholderProvider) Name() string { return "tool-placeholder-mock" } + +func (*agentToolPlaceholderProvider) ListModels(context.Context) ([]sdk.Model, error) { + return nil, nil +} + +func (*agentToolPlaceholderProvider) Test(context.Context) *sdk.ProviderTestResult { + return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"} +} + +func (*agentToolPlaceholderProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) { + return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil +} + +func (*agentToolPlaceholderProvider) DoGenerate(context.Context, sdk.GenerateParams) (*sdk.GenerateResult, error) { + return &sdk.GenerateResult{FinishReason: sdk.FinishReasonStop}, nil +} + +func (*agentToolPlaceholderProvider) DoStream(_ context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) { + ch := make(chan sdk.StreamPart, 8) + go func() { + defer close(ch) + ch <- &sdk.StartPart{} + ch <- &sdk.StartStepPart{} + ch <- &sdk.ToolInputStartPart{ID: "call-1", ToolName: "write"} + ch <- &sdk.StreamToolCallPart{ + ToolCallID: "call-1", + ToolName: "write", + Input: map[string]any{"path": "/tmp/long.txt"}, + } + ch <- &sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop} + ch <- &sdk.FinishPart{FinishReason: sdk.FinishReasonStop} + }() + return &sdk.StreamResult{Stream: ch}, nil +} + +func TestAgentStreamEmitsEarlyToolPlaceholderBeforeFullInput(t *testing.T) { + t.Parallel() + + a := New(Deps{}) + + var events []StreamEvent + for event := range a.Stream(context.Background(), RunConfig{ + Model: &sdk.Model{ + ID: "mock-model", + Provider: &agentToolPlaceholderProvider{}, + }, + Messages: []sdk.Message{sdk.UserMessage("write a long file")}, + SupportsToolCall: false, + Identity: SessionContext{BotID: "bot-1"}, + }) { + events = append(events, event) + } + + if len(events) != 4 { + t.Fatalf("expected 4 events, got %d: %#v", len(events), events) + } + if events[0].Type != EventAgentStart { + t.Fatalf("expected first event %q, got %#v", EventAgentStart, events[0]) + } + if events[1].Type != EventToolCallStart || events[1].ToolCallID != "call-1" || events[1].ToolName != "write" { + t.Fatalf("unexpected placeholder tool event: %#v", events[1]) + } + if events[1].Input != nil { + t.Fatalf("expected placeholder tool event to have nil input, got %#v", events[1].Input) + } + if events[2].Type != EventToolCallStart || events[2].ToolCallID != "call-1" { + t.Fatalf("unexpected full tool event: %#v", events[2]) + } + expectedInput := map[string]any{"path": "/tmp/long.txt"} + if !reflect.DeepEqual(events[2].Input, expectedInput) { + t.Fatalf("expected full tool event input %#v, got %#v", expectedInput, events[2].Input) + } + if events[3].Type != EventAgentEnd { + t.Fatalf("expected terminal event %q, got %#v", EventAgentEnd, events[3]) + } +} diff --git a/internal/conversation/uimessage_stream.go b/internal/conversation/uimessage_stream.go index 90764edd..f64b367e 100644 --- a/internal/conversation/uimessage_stream.go +++ b/internal/conversation/uimessage_stream.go @@ -68,21 +68,32 @@ func (c *UIMessageStreamConverter) HandleEvent(event UIMessageStreamEvent) []UIM return nil case "tool_call_start": - state := &uiToolStreamState{ - Message: UIMessage{ - ID: c.nextMessageID(), - Type: UIMessageTool, - Name: strings.TrimSpace(event.ToolName), - Input: event.Input, - ToolCallID: strings.TrimSpace(event.ToolCallID), - Running: uiBoolPtr(true), - }, + state := c.findToolState(event.ToolCallID, event.ToolName) + if state == nil { + state = &uiToolStreamState{ + Message: UIMessage{ + ID: c.nextMessageID(), + Type: UIMessageTool, + Name: strings.TrimSpace(event.ToolName), + Input: event.Input, + ToolCallID: strings.TrimSpace(event.ToolCallID), + Running: uiBoolPtr(true), + }, + } } - if state.Message.ToolCallID != "" { - c.tools[state.Message.ToolCallID] = state + if trimmed := strings.TrimSpace(event.ToolName); trimmed != "" { + state.Message.Name = trimmed } + if event.Input != nil { + state.Message.Input = event.Input + } + if trimmed := strings.TrimSpace(event.ToolCallID); trimmed != "" { + state.Message.ToolCallID = trimmed + c.tools[trimmed] = state + } + state.Message.Running = uiBoolPtr(true) c.text = nil - return []UIMessage{state.Message} + return []UIMessage{cloneToolStreamMessage(state.Message)} case "tool_call_progress": state := c.findToolState(event.ToolCallID, event.ToolName) diff --git a/internal/conversation/uimessage_test.go b/internal/conversation/uimessage_test.go index 386ae9b3..9bf836e3 100644 --- a/internal/conversation/uimessage_test.go +++ b/internal/conversation/uimessage_test.go @@ -2,6 +2,7 @@ package conversation import ( "encoding/json" + "reflect" "testing" "time" @@ -201,6 +202,60 @@ func TestUIMessageStreamConverterAccumulatesToolProgress(t *testing.T) { } } +func TestUIMessageStreamConverterMergesRepeatedToolCallStart(t *testing.T) { + t.Parallel() + + converter := NewUIMessageStreamConverter() + + start := converter.HandleEvent(UIMessageStreamEvent{ + Type: "tool_call_start", + ToolName: "write", + ToolCallID: "call-1", + }) + if len(start) != 1 || start[0].Type != UIMessageTool { + t.Fatalf("unexpected initial tool placeholder: %#v", start) + } + if start[0].Input != nil { + t.Fatalf("expected initial tool placeholder to have nil input, got %#v", start[0].Input) + } + + fullInput := map[string]any{"path": "/tmp/long.txt"} + update := converter.HandleEvent(UIMessageStreamEvent{ + Type: "tool_call_start", + ToolName: "write", + ToolCallID: "call-1", + Input: fullInput, + }) + if len(update) != 1 { + t.Fatalf("expected one updated tool snapshot, got %#v", update) + } + if update[0].ID != start[0].ID { + t.Fatalf("expected repeated tool start to reuse message id, got start=%d update=%d", start[0].ID, update[0].ID) + } + if !reflect.DeepEqual(update[0].Input, fullInput) { + t.Fatalf("expected repeated tool start to backfill input, got %#v", update[0].Input) + } + if update[0].Running == nil || !*update[0].Running { + t.Fatalf("expected merged tool message to stay running, got %#v", update[0]) + } + + end := converter.HandleEvent(UIMessageStreamEvent{ + Type: "tool_call_end", + ToolName: "write", + ToolCallID: "call-1", + Output: map[string]any{"ok": true}, + }) + if len(end) != 1 || end[0].ID != start[0].ID { + t.Fatalf("expected tool end to reuse merged message id, got %#v", end) + } + if !reflect.DeepEqual(end[0].Input, fullInput) { + t.Fatalf("expected tool end to preserve merged input, got %#v", end[0].Input) + } + if end[0].Running == nil || *end[0].Running { + t.Fatalf("expected tool end to mark message complete, got %#v", end[0]) + } +} + func TestUIMessageStreamConverterStartsNewTextBlockAfterTool(t *testing.T) { converter := NewUIMessageStreamConverter() diff --git a/internal/models/http_client_test.go b/internal/models/http_client_test.go index b9ac1db7..0a2ce56f 100644 --- a/internal/models/http_client_test.go +++ b/internal/models/http_client_test.go @@ -10,6 +10,7 @@ func TestNewProviderHTTPClientWithoutTimeoutKeepsStreamingFriendlyBehavior(t *te client := NewProviderHTTPClient(0) if client == nil { t.Fatal("expected client") + return } if client.Timeout != 0 { t.Fatalf("expected no client timeout, got %s", client.Timeout) @@ -29,6 +30,7 @@ func TestNewProviderHTTPClientWithTimeout(t *testing.T) { client := NewProviderHTTPClient(timeout) if client == nil { t.Fatal("expected client") + return } if client.Timeout != timeout { t.Fatalf("expected timeout %s, got %s", timeout, client.Timeout) diff --git a/internal/providers/service_test.go b/internal/providers/service_test.go index 3c6069c4..89f85104 100644 --- a/internal/providers/service_test.go +++ b/internal/providers/service_test.go @@ -177,6 +177,7 @@ func TestAccountMetadataRoundTrip(t *testing.T) { status := parsed.toStatus() if status == nil { t.Fatal("expected account status") + return } if status.Label != account.Label { t.Fatalf("expected status label %q, got %q", account.Label, status.Label)