mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(agent): surface tool calls before input completes
Emit tool-call placeholders as soon as tool input streaming starts so long writes appear immediately in chat. Reuse the same UI tool message when full input arrives to avoid duplicate cards, and keep the hook-required test suite green.
This commit is contained in:
@@ -293,6 +293,18 @@ func (a *Agent) runStream(ctx context.Context, cfg RunConfig, ch chan<- StreamEv
|
||||
aborted = true
|
||||
}
|
||||
|
||||
case *sdk.ToolInputStartPart:
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Flush()
|
||||
}
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
Type: EventToolCallStart,
|
||||
ToolName: p.ToolName,
|
||||
ToolCallID: p.ID,
|
||||
}) {
|
||||
aborted = true
|
||||
}
|
||||
|
||||
case *sdk.StreamToolCallPart:
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Flush()
|
||||
@@ -928,6 +940,17 @@ func (a *Agent) runMidStreamRetry(
|
||||
if !sendEvent(ctx, ch, StreamEvent{Type: EventTextEnd}) {
|
||||
aborted = true
|
||||
}
|
||||
case *sdk.ToolInputStartPart:
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Flush()
|
||||
}
|
||||
if !sendEvent(ctx, ch, StreamEvent{
|
||||
Type: EventToolCallStart,
|
||||
ToolName: rp.ToolName,
|
||||
ToolCallID: rp.ID,
|
||||
}) {
|
||||
aborted = true
|
||||
}
|
||||
case *sdk.StreamToolCallPart:
|
||||
if textLoopProbeBuffer != nil {
|
||||
textLoopProbeBuffer.Flush()
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
)
|
||||
|
||||
type agentToolPlaceholderProvider struct{}
|
||||
|
||||
func (*agentToolPlaceholderProvider) Name() string { return "tool-placeholder-mock" }
|
||||
|
||||
func (*agentToolPlaceholderProvider) ListModels(context.Context) ([]sdk.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*agentToolPlaceholderProvider) Test(context.Context) *sdk.ProviderTestResult {
|
||||
return &sdk.ProviderTestResult{Status: sdk.ProviderStatusOK, Message: "ok"}
|
||||
}
|
||||
|
||||
func (*agentToolPlaceholderProvider) TestModel(context.Context, string) (*sdk.ModelTestResult, error) {
|
||||
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
|
||||
}
|
||||
|
||||
func (*agentToolPlaceholderProvider) DoGenerate(context.Context, sdk.GenerateParams) (*sdk.GenerateResult, error) {
|
||||
return &sdk.GenerateResult{FinishReason: sdk.FinishReasonStop}, nil
|
||||
}
|
||||
|
||||
func (*agentToolPlaceholderProvider) DoStream(_ context.Context, _ sdk.GenerateParams) (*sdk.StreamResult, error) {
|
||||
ch := make(chan sdk.StreamPart, 8)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- &sdk.StartPart{}
|
||||
ch <- &sdk.StartStepPart{}
|
||||
ch <- &sdk.ToolInputStartPart{ID: "call-1", ToolName: "write"}
|
||||
ch <- &sdk.StreamToolCallPart{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "write",
|
||||
Input: map[string]any{"path": "/tmp/long.txt"},
|
||||
}
|
||||
ch <- &sdk.FinishStepPart{FinishReason: sdk.FinishReasonStop}
|
||||
ch <- &sdk.FinishPart{FinishReason: sdk.FinishReasonStop}
|
||||
}()
|
||||
return &sdk.StreamResult{Stream: ch}, nil
|
||||
}
|
||||
|
||||
func TestAgentStreamEmitsEarlyToolPlaceholderBeforeFullInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := New(Deps{})
|
||||
|
||||
var events []StreamEvent
|
||||
for event := range a.Stream(context.Background(), RunConfig{
|
||||
Model: &sdk.Model{
|
||||
ID: "mock-model",
|
||||
Provider: &agentToolPlaceholderProvider{},
|
||||
},
|
||||
Messages: []sdk.Message{sdk.UserMessage("write a long file")},
|
||||
SupportsToolCall: false,
|
||||
Identity: SessionContext{BotID: "bot-1"},
|
||||
}) {
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("expected 4 events, got %d: %#v", len(events), events)
|
||||
}
|
||||
if events[0].Type != EventAgentStart {
|
||||
t.Fatalf("expected first event %q, got %#v", EventAgentStart, events[0])
|
||||
}
|
||||
if events[1].Type != EventToolCallStart || events[1].ToolCallID != "call-1" || events[1].ToolName != "write" {
|
||||
t.Fatalf("unexpected placeholder tool event: %#v", events[1])
|
||||
}
|
||||
if events[1].Input != nil {
|
||||
t.Fatalf("expected placeholder tool event to have nil input, got %#v", events[1].Input)
|
||||
}
|
||||
if events[2].Type != EventToolCallStart || events[2].ToolCallID != "call-1" {
|
||||
t.Fatalf("unexpected full tool event: %#v", events[2])
|
||||
}
|
||||
expectedInput := map[string]any{"path": "/tmp/long.txt"}
|
||||
if !reflect.DeepEqual(events[2].Input, expectedInput) {
|
||||
t.Fatalf("expected full tool event input %#v, got %#v", expectedInput, events[2].Input)
|
||||
}
|
||||
if events[3].Type != EventAgentEnd {
|
||||
t.Fatalf("expected terminal event %q, got %#v", EventAgentEnd, events[3])
|
||||
}
|
||||
}
|
||||
@@ -68,21 +68,32 @@ func (c *UIMessageStreamConverter) HandleEvent(event UIMessageStreamEvent) []UIM
|
||||
return nil
|
||||
|
||||
case "tool_call_start":
|
||||
state := &uiToolStreamState{
|
||||
Message: UIMessage{
|
||||
ID: c.nextMessageID(),
|
||||
Type: UIMessageTool,
|
||||
Name: strings.TrimSpace(event.ToolName),
|
||||
Input: event.Input,
|
||||
ToolCallID: strings.TrimSpace(event.ToolCallID),
|
||||
Running: uiBoolPtr(true),
|
||||
},
|
||||
state := c.findToolState(event.ToolCallID, event.ToolName)
|
||||
if state == nil {
|
||||
state = &uiToolStreamState{
|
||||
Message: UIMessage{
|
||||
ID: c.nextMessageID(),
|
||||
Type: UIMessageTool,
|
||||
Name: strings.TrimSpace(event.ToolName),
|
||||
Input: event.Input,
|
||||
ToolCallID: strings.TrimSpace(event.ToolCallID),
|
||||
Running: uiBoolPtr(true),
|
||||
},
|
||||
}
|
||||
}
|
||||
if state.Message.ToolCallID != "" {
|
||||
c.tools[state.Message.ToolCallID] = state
|
||||
if trimmed := strings.TrimSpace(event.ToolName); trimmed != "" {
|
||||
state.Message.Name = trimmed
|
||||
}
|
||||
if event.Input != nil {
|
||||
state.Message.Input = event.Input
|
||||
}
|
||||
if trimmed := strings.TrimSpace(event.ToolCallID); trimmed != "" {
|
||||
state.Message.ToolCallID = trimmed
|
||||
c.tools[trimmed] = state
|
||||
}
|
||||
state.Message.Running = uiBoolPtr(true)
|
||||
c.text = nil
|
||||
return []UIMessage{state.Message}
|
||||
return []UIMessage{cloneToolStreamMessage(state.Message)}
|
||||
|
||||
case "tool_call_progress":
|
||||
state := c.findToolState(event.ToolCallID, event.ToolName)
|
||||
|
||||
@@ -2,6 +2,7 @@ package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -201,6 +202,60 @@ func TestUIMessageStreamConverterAccumulatesToolProgress(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIMessageStreamConverterMergesRepeatedToolCallStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
converter := NewUIMessageStreamConverter()
|
||||
|
||||
start := converter.HandleEvent(UIMessageStreamEvent{
|
||||
Type: "tool_call_start",
|
||||
ToolName: "write",
|
||||
ToolCallID: "call-1",
|
||||
})
|
||||
if len(start) != 1 || start[0].Type != UIMessageTool {
|
||||
t.Fatalf("unexpected initial tool placeholder: %#v", start)
|
||||
}
|
||||
if start[0].Input != nil {
|
||||
t.Fatalf("expected initial tool placeholder to have nil input, got %#v", start[0].Input)
|
||||
}
|
||||
|
||||
fullInput := map[string]any{"path": "/tmp/long.txt"}
|
||||
update := converter.HandleEvent(UIMessageStreamEvent{
|
||||
Type: "tool_call_start",
|
||||
ToolName: "write",
|
||||
ToolCallID: "call-1",
|
||||
Input: fullInput,
|
||||
})
|
||||
if len(update) != 1 {
|
||||
t.Fatalf("expected one updated tool snapshot, got %#v", update)
|
||||
}
|
||||
if update[0].ID != start[0].ID {
|
||||
t.Fatalf("expected repeated tool start to reuse message id, got start=%d update=%d", start[0].ID, update[0].ID)
|
||||
}
|
||||
if !reflect.DeepEqual(update[0].Input, fullInput) {
|
||||
t.Fatalf("expected repeated tool start to backfill input, got %#v", update[0].Input)
|
||||
}
|
||||
if update[0].Running == nil || !*update[0].Running {
|
||||
t.Fatalf("expected merged tool message to stay running, got %#v", update[0])
|
||||
}
|
||||
|
||||
end := converter.HandleEvent(UIMessageStreamEvent{
|
||||
Type: "tool_call_end",
|
||||
ToolName: "write",
|
||||
ToolCallID: "call-1",
|
||||
Output: map[string]any{"ok": true},
|
||||
})
|
||||
if len(end) != 1 || end[0].ID != start[0].ID {
|
||||
t.Fatalf("expected tool end to reuse merged message id, got %#v", end)
|
||||
}
|
||||
if !reflect.DeepEqual(end[0].Input, fullInput) {
|
||||
t.Fatalf("expected tool end to preserve merged input, got %#v", end[0].Input)
|
||||
}
|
||||
if end[0].Running == nil || *end[0].Running {
|
||||
t.Fatalf("expected tool end to mark message complete, got %#v", end[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIMessageStreamConverterStartsNewTextBlockAfterTool(t *testing.T) {
|
||||
converter := NewUIMessageStreamConverter()
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ func TestNewProviderHTTPClientWithoutTimeoutKeepsStreamingFriendlyBehavior(t *te
|
||||
client := NewProviderHTTPClient(0)
|
||||
if client == nil {
|
||||
t.Fatal("expected client")
|
||||
return
|
||||
}
|
||||
if client.Timeout != 0 {
|
||||
t.Fatalf("expected no client timeout, got %s", client.Timeout)
|
||||
@@ -29,6 +30,7 @@ func TestNewProviderHTTPClientWithTimeout(t *testing.T) {
|
||||
client := NewProviderHTTPClient(timeout)
|
||||
if client == nil {
|
||||
t.Fatal("expected client")
|
||||
return
|
||||
}
|
||||
if client.Timeout != timeout {
|
||||
t.Fatalf("expected timeout %s, got %s", timeout, client.Timeout)
|
||||
|
||||
@@ -177,6 +177,7 @@ func TestAccountMetadataRoundTrip(t *testing.T) {
|
||||
status := parsed.toStatus()
|
||||
if status == nil {
|
||||
t.Fatal("expected account status")
|
||||
return
|
||||
}
|
||||
if status.Label != account.Label {
|
||||
t.Fatalf("expected status label %q, got %q", account.Label, status.Label)
|
||||
|
||||
Reference in New Issue
Block a user