fix(agent): surface tool calls before input completes

Emit tool-call placeholders as soon as tool input streaming starts so long writes appear immediately in chat. Reuse the same UI tool message when full input arrives to avoid duplicate cards, and keep the hook-required test suite green.
This commit is contained in:
Acbox
2026-04-16 16:42:07 +08:00
parent 1a5b1d6086
commit e0fc2f514e
6 changed files with 193 additions and 12 deletions
+23
View File
@@ -293,6 +293,18 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
aborted = true
}
case *sdk.ToolInputStartPart:
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
}
if !sendEvent(ctx, ch, StreamEvent{
Type: EventToolCallStart,
ToolName: p.ToolName,
ToolCallID: p.ID,
}) {
aborted = true
}
case *sdk.StreamToolCallPart:
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
@@ -928,6 +940,17 @@ func (a *Agent) runMidStreamRetry(
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) {
aborted = true
}
case *sdk.ToolInputStartPart:
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
}
if !sendEvent(ctx, ch, StreamEvent{
Type: EventToolCallStart,
ToolName: rp.ToolName,
ToolCallID: rp.ID,
}) {
aborted = true
}
case *sdk.StreamToolCallPart:
if textLoopProbeBuffer != nil {
textLoopProbeBuffer.Flush()
+89
View File
@@ -0,0 +1,89 @@
package agent
import (
"context"
"reflect"
"testing"
sdk "github.com/memohai/twilight-ai/sdk"
)
type agentToolPlaceholderProvider struct{}
func (*agentToolPlaceholderProvider) Name() string { return "tool-placeholder-mock" }
func (*agentToolPlaceholderProvider) ListModels(context.Context) ([]sdk.Model, error) {
return nil, nil
}
func (*agentToolPlaceholderProvider) Test(context.Context) *sdk.ProviderTestResult {
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
}
func (*agentToolPlaceholderProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
}
func (*agentToolPlaceholderProvider) DoGenerate(context.Context, sdk.GenerateParams) (*sdk.GenerateResult, error) {
return &sdk.GenerateResult{FinishReason: sdk.FinishReasonStop}, nil
}
func (*agentToolPlaceholderProvider) DoStream(_ context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
ch := make(chan sdk.StreamPart, 8)
go func() {
defer close(ch)
ch <- &sdk.StartPart{}
ch <- &sdk.StartStepPart{}
ch <- &sdk.ToolInputStartPart{ID: "call-1", ToolName: "write"}
ch <- &sdk.StreamToolCallPart{
ToolCallID: "call-1",
ToolName: "write",
Input: map[string]any{"path": "/tmp/long.txt"},
}
ch <- &sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}
ch <- &sdk.FinishPart{FinishReason: sdk.FinishReasonStop}
}()
return &sdk.StreamResult{Stream: ch}, nil
}
func TestAgentStreamEmitsEarlyToolPlaceholderBeforeFullInput(t *testing.T) {
t.Parallel()
a := New(Deps{})
var events []StreamEvent
for event := range a.Stream(context.Background(), RunConfig{
Model: &sdk.Model{
ID: "mock-model",
Provider: &agentToolPlaceholderProvider{},
},
Messages: []sdk.Message{sdk.UserMessage("write a long file")},
SupportsToolCall: false,
Identity: SessionContext{BotID: "bot-1"},
}) {
events = append(events, event)
}
if len(events) != 4 {
t.Fatalf("expected 4 events, got %d: %#v", len(events), events)
}
if events[0].Type != EventAgentStart {
t.Fatalf("expected first event %q, got %#v", EventAgentStart, events[0])
}
if events[1].Type != EventToolCallStart || events[1].ToolCallID != "call-1" || events[1].ToolName != "write" {
t.Fatalf("unexpected placeholder tool event: %#v", events[1])
}
if events[1].Input != nil {
t.Fatalf("expected placeholder tool event to have nil input, got %#v", events[1].Input)
}
if events[2].Type != EventToolCallStart || events[2].ToolCallID != "call-1" {
t.Fatalf("unexpected full tool event: %#v", events[2])
}
expectedInput := map[string]any{"path": "/tmp/long.txt"}
if !reflect.DeepEqual(events[2].Input, expectedInput) {
t.Fatalf("expected full tool event input %#v, got %#v", expectedInput, events[2].Input)
}
if events[3].Type != EventAgentEnd {
t.Fatalf("expected terminal event %q, got %#v", EventAgentEnd, events[3])
}
}
+23 -12
View File
@@ -68,21 +68,32 @@ func (c *UIMessageStreamConverter) HandleEvent(event UIMessageStreamEvent) []UIM
return nil
case "tool_call_start":
state := &uiToolStreamState{
Message: UIMessage{
ID: c.nextMessageID(),
Type: UIMessageTool,
Name: strings.TrimSpace(event.ToolName),
Input: event.Input,
ToolCallID: strings.TrimSpace(event.ToolCallID),
Running: uiBoolPtr(true),
},
state := c.findToolState(event.ToolCallID, event.ToolName)
if state == nil {
state = &uiToolStreamState{
Message: UIMessage{
ID: c.nextMessageID(),
Type: UIMessageTool,
Name: strings.TrimSpace(event.ToolName),
Input: event.Input,
ToolCallID: strings.TrimSpace(event.ToolCallID),
Running: uiBoolPtr(true),
},
}
}
if state.Message.ToolCallID != "" {
c.tools[state.Message.ToolCallID] = state
if trimmed := strings.TrimSpace(event.ToolName); trimmed != "" {
state.Message.Name = trimmed
}
if event.Input != nil {
state.Message.Input = event.Input
}
if trimmed := strings.TrimSpace(event.ToolCallID); trimmed != "" {
state.Message.ToolCallID = trimmed
c.tools[trimmed] = state
}
state.Message.Running = uiBoolPtr(true)
c.text = nil
return []UIMessage{state.Message}
return []UIMessage{cloneToolStreamMessage(state.Message)}
case "tool_call_progress":
state := c.findToolState(event.ToolCallID, event.ToolName)
+55
View File
@@ -2,6 +2,7 @@ package conversation
import (
"encoding/json"
"reflect"
"testing"
"time"
@@ -201,6 +202,60 @@ func TestUIMessageStreamConverterAccumulatesToolProgress(t *testing.T) {
}
}
func TestUIMessageStreamConverterMergesRepeatedToolCallStart(t *testing.T) {
t.Parallel()
converter := NewUIMessageStreamConverter()
start := converter.HandleEvent(UIMessageStreamEvent{
Type: "tool_call_start",
ToolName: "write",
ToolCallID: "call-1",
})
if len(start) != 1 || start[0].Type != UIMessageTool {
t.Fatalf("unexpected initial tool placeholder: %#v", start)
}
if start[0].Input != nil {
t.Fatalf("expected initial tool placeholder to have nil input, got %#v", start[0].Input)
}
fullInput := map[string]any{"path": "/tmp/long.txt"}
update := converter.HandleEvent(UIMessageStreamEvent{
Type: "tool_call_start",
ToolName: "write",
ToolCallID: "call-1",
Input: fullInput,
})
if len(update) != 1 {
t.Fatalf("expected one updated tool snapshot, got %#v", update)
}
if update[0].ID != start[0].ID {
t.Fatalf("expected repeated tool start to reuse message id, got start=%d update=%d", start[0].ID, update[0].ID)
}
if !reflect.DeepEqual(update[0].Input, fullInput) {
t.Fatalf("expected repeated tool start to backfill input, got %#v", update[0].Input)
}
if update[0].Running == nil || !*update[0].Running {
t.Fatalf("expected merged tool message to stay running, got %#v", update[0])
}
end := converter.HandleEvent(UIMessageStreamEvent{
Type: "tool_call_end",
ToolName: "write",
ToolCallID: "call-1",
Output: map[string]any{"ok": true},
})
if len(end) != 1 || end[0].ID != start[0].ID {
t.Fatalf("expected tool end to reuse merged message id, got %#v", end)
}
if !reflect.DeepEqual(end[0].Input, fullInput) {
t.Fatalf("expected tool end to preserve merged input, got %#v", end[0].Input)
}
if end[0].Running == nil || *end[0].Running {
t.Fatalf("expected tool end to mark message complete, got %#v", end[0])
}
}
func TestUIMessageStreamConverterStartsNewTextBlockAfterTool(t *testing.T) {
converter := NewUIMessageStreamConverter()
+2
View File
@@ -10,6 +10,7 @@ func TestNewProviderHTTPClientWithoutTimeoutKeepsStreamingFriendlyBehavior(t *te
client := NewProviderHTTPClient(0)
if client == nil {
t.Fatal("expected client")
return
}
if client.Timeout != 0 {
t.Fatalf("expected no client timeout, got %s", client.Timeout)
@@ -29,6 +30,7 @@ func TestNewProviderHTTPClientWithTimeout(t *testing.T) {
client := NewProviderHTTPClient(timeout)
if client == nil {
t.Fatal("expected client")
return
}
if client.Timeout != timeout {
t.Fatalf("expected timeout %s, got %s", timeout, client.Timeout)
+1
View File
@@ -177,6 +177,7 @@ func TestAccountMetadataRoundTrip(t *testing.T) {
status := parsed.toStatus()
if status == nil {
t.Fatal("expected account status")
return
}
if status.Label != account.Label {
t.Fatalf("expected status label %q, got %q", account.Label, status.Label)