diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 9f47c01c..5c30204c 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -293,6 +293,18 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv aborted = true } + case *sdk.ToolInputStartPart: + if textLoopProbeBuffer != nil { + textLoopProbeBuffer.Flush() + } + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventToolCallStart, + ToolName: p.ToolName, + ToolCallID: p.ID, + }) { + aborted = true + } + case *sdk.StreamToolCallPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() @@ -928,6 +940,17 @@ func (a *Agent) runMidStreamRetry( if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) { aborted = true } + case *sdk.ToolInputStartPart: + if textLoopProbeBuffer != nil { + textLoopProbeBuffer.Flush() + } + if !sendEvent(ctx, ch, StreamEvent{ + Type: EventToolCallStart, + ToolName: rp.ToolName, + ToolCallID: rp.ID, + }) { + aborted = true + } case *sdk.StreamToolCallPart: if textLoopProbeBuffer != nil { textLoopProbeBuffer.Flush() diff --git a/internal/agent/stream_test.go b/internal/agent/stream_test.go new file mode 100644 index 00000000..acbef0fb --- /dev/null +++ b/internal/agent/stream_test.go @@ -0,0 +1,89 @@ +package agent + +import ( + "context" + "reflect" + "testing" + + sdk "github.com/memohai/twilight-ai/sdk" +) + +type agentToolPlaceholderProvider struct{} + +func (*agentToolPlaceholderProvider) Name() string { return "tool-placeholder-mock" } + +func (*agentToolPlaceholderProvider) ListModels(context.Context) ([]sdk.Model, error) { + return nil, nil +} + +func (*agentToolPlaceholderProvider) Test(context.Context) *sdk.ProviderTestResult { + return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"} +} + +func (*agentToolPlaceholderProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) { + return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil +} + +func (*agentToolPlaceholderProvider) DoGenerate(context.Context, sdk.GenerateParams) (*sdk.GenerateResult, error) { + return &sdk.GenerateResult{FinishReason: sdk.FinishReasonStop}, nil +} + +func (*agentToolPlaceholderProvider) DoStream(_ context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) { + ch := make(chan sdk.StreamPart, 8) + go func() { + defer close(ch) + ch <- &sdk.StartPart{} + ch <- &sdk.StartStepPart{} + ch <- &sdk.ToolInputStartPart{ID: "call-1", ToolName: "write"} + ch <- &sdk.StreamToolCallPart{ + ToolCallID: "call-1", + ToolName: "write", + Input: map[string]any{"path": "/tmp/long.txt"}, + } + ch <- &sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop} + ch <- &sdk.FinishPart{FinishReason: sdk.FinishReasonStop} + }() + return &sdk.StreamResult{Stream: ch}, nil +} + +func TestAgentStreamEmitsEarlyToolPlaceholderBeforeFullInput(t *testing.T) { + t.Parallel() + + a := New(Deps{}) + + var events []StreamEvent + for event := range a.Stream(context.Background(), RunConfig{ + Model: &sdk.Model{ + ID: "mock-model", + Provider: &agentToolPlaceholderProvider{}, + }, + Messages: []sdk.Message{sdk.UserMessage("write a long file")}, + SupportsToolCall: false, + Identity: SessionContext{BotID: "bot-1"}, + }) { + events = append(events, event) + } + + if len(events) != 4 { + t.Fatalf("expected 4 events, got %d: %#v", len(events), events) + } + if events[0].Type != EventAgentStart { + t.Fatalf("expected first event %q, got %#v", EventAgentStart, events[0]) + } + if events[1].Type != EventToolCallStart || events[1].ToolCallID != "call-1" || events[1].ToolName != "write" { + t.Fatalf("unexpected placeholder tool event: %#v", events[1]) + } + if events[1].Input != nil { + t.Fatalf("expected placeholder tool event to have nil input, got %#v", events[1].Input) + } + if events[2].Type != EventToolCallStart || events[2].ToolCallID != "call-1" { + t.Fatalf("unexpected full tool event: %#v", events[2]) + } + expectedInput := map[string]any{"path": "/tmp/long.txt"} + if !reflect.DeepEqual(events[2].Input, expectedInput) { + t.Fatalf("expected full tool event input %#v, got %#v", expectedInput, events[2].Input) + } + if events[3].Type != EventAgentEnd { + t.Fatalf("expected terminal event %q, got %#v", EventAgentEnd, events[3]) + } +} diff --git a/internal/conversation/uimessage_stream.go b/internal/conversation/uimessage_stream.go index 90764edd..f64b367e 100644 --- a/internal/conversation/uimessage_stream.go +++ b/internal/conversation/uimessage_stream.go @@ -68,21 +68,32 @@ func (c *UIMessageStreamConverter) HandleEvent(event UIMessageStreamEvent) []UIM return nil case "tool_call_start": - state := &uiToolStreamState{ - Message: UIMessage{ - ID: c.nextMessageID(), - Type: UIMessageTool, - Name: strings.TrimSpace(event.ToolName), - Input: event.Input, - ToolCallID: strings.TrimSpace(event.ToolCallID), - Running: uiBoolPtr(true), - }, + state := c.findToolState(event.ToolCallID, event.ToolName) + if state == nil { + state = &uiToolStreamState{ + Message: UIMessage{ + ID: c.nextMessageID(), + Type: UIMessageTool, + Name: strings.TrimSpace(event.ToolName), + Input: event.Input, + ToolCallID: strings.TrimSpace(event.ToolCallID), + Running: uiBoolPtr(true), + }, + } } - if state.Message.ToolCallID != "" { - c.tools[state.Message.ToolCallID] = state + if trimmed := strings.TrimSpace(event.ToolName); trimmed != "" { + state.Message.Name = trimmed } + if event.Input != nil { + state.Message.Input = event.Input + } + if trimmed := strings.TrimSpace(event.ToolCallID); trimmed != "" { + state.Message.ToolCallID = trimmed + c.tools[trimmed] = state + } + state.Message.Running = uiBoolPtr(true) c.text = nil - return []UIMessage{state.Message} + return []UIMessage{cloneToolStreamMessage(state.Message)} case "tool_call_progress": state := c.findToolState(event.ToolCallID, event.ToolName) diff --git a/internal/conversation/uimessage_test.go b/internal/conversation/uimessage_test.go index 386ae9b3..9bf836e3 100644 --- a/internal/conversation/uimessage_test.go +++ b/internal/conversation/uimessage_test.go @@ -2,6 +2,7 @@ package conversation import ( "encoding/json" + "reflect" "testing" "time" @@ -201,6 +202,60 @@ func TestUIMessageStreamConverterAccumulatesToolProgress(t *testing.T) { } } +func TestUIMessageStreamConverterMergesRepeatedToolCallStart(t *testing.T) { + t.Parallel() + + converter := NewUIMessageStreamConverter() + + start := converter.HandleEvent(UIMessageStreamEvent{ + Type: "tool_call_start", + ToolName: "write", + ToolCallID: "call-1", + }) + if len(start) != 1 || start[0].Type != UIMessageTool { + t.Fatalf("unexpected initial tool placeholder: %#v", start) + } + if start[0].Input != nil { + t.Fatalf("expected initial tool placeholder to have nil input, got %#v", start[0].Input) + } + + fullInput := map[string]any{"path": "/tmp/long.txt"} + update := converter.HandleEvent(UIMessageStreamEvent{ + Type: "tool_call_start", + ToolName: "write", + ToolCallID: "call-1", + Input: fullInput, + }) + if len(update) != 1 { + t.Fatalf("expected one updated tool snapshot, got %#v", update) + } + if update[0].ID != start[0].ID { + t.Fatalf("expected repeated tool start to reuse message id, got start=%d update=%d", start[0].ID, update[0].ID) + } + if !reflect.DeepEqual(update[0].Input, fullInput) { + t.Fatalf("expected repeated tool start to backfill input, got %#v", update[0].Input) + } + if update[0].Running == nil || !*update[0].Running { + t.Fatalf("expected merged tool message to stay running, got %#v", update[0]) + } + + end := converter.HandleEvent(UIMessageStreamEvent{ + Type: "tool_call_end", + ToolName: "write", + ToolCallID: "call-1", + Output: map[string]any{"ok": true}, + }) + if len(end) != 1 || end[0].ID != start[0].ID { + t.Fatalf("expected tool end to reuse merged message id, got %#v", end) + } + if !reflect.DeepEqual(end[0].Input, fullInput) { + t.Fatalf("expected tool end to preserve merged input, got %#v", end[0].Input) + } + if end[0].Running == nil || *end[0].Running { + t.Fatalf("expected tool end to mark message complete, got %#v", end[0]) + } +} + func TestUIMessageStreamConverterStartsNewTextBlockAfterTool(t *testing.T) { converter := NewUIMessageStreamConverter() diff --git a/internal/models/http_client_test.go b/internal/models/http_client_test.go index b9ac1db7..0a2ce56f 100644 --- a/internal/models/http_client_test.go +++ b/internal/models/http_client_test.go @@ -10,6 +10,7 @@ func TestNewProviderHTTPClientWithoutTimeoutKeepsStreamingFriendlyBehavior(t *te client := NewProviderHTTPClient(0) if client == nil { t.Fatal("expected client") + return } if client.Timeout != 0 { t.Fatalf("expected no client timeout, got %s", client.Timeout) @@ -29,6 +30,7 @@ func TestNewProviderHTTPClientWithTimeout(t *testing.T) { client := NewProviderHTTPClient(timeout) if client == nil { t.Fatal("expected client") + return } if client.Timeout != timeout { t.Fatalf("expected timeout %s, got %s", timeout, client.Timeout) diff --git a/internal/providers/service_test.go b/internal/providers/service_test.go index 3c6069c4..89f85104 100644 --- a/internal/providers/service_test.go +++ b/internal/providers/service_test.go @@ -177,6 +177,7 @@ func TestAccountMetadataRoundTrip(t *testing.T) { status := parsed.toStatus() if status == nil { t.Fatal("expected account status") + return } if status.Label != account.Label { t.Fatalf("expected status label %q, got %q", account.Label, status.Label)