From 3509947bc02c50650001d57eacd358a6c4558a04 Mon Sep 17 00:00:00 2001
From: aki
Date: Tue, 14 Apr 2026 05:50:04 +0900
Subject: [PATCH] Fix dangling tool call history

---
 internal/conversation/flow/resolver.go        |   1 +
 internal/conversation/flow/resolver_store.go  |   1 +
 internal/conversation/flow/tool_closure.go    | 213 ++++++++++++++++++
 .../conversation/flow/tool_closure_test.go    | 138 ++++++++++++
 4 files changed, 353 insertions(+)
 create mode 100644 internal/conversation/flow/tool_closure.go
 create mode 100644 internal/conversation/flow/tool_closure_test.go

diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go
index 0a45887f..022830c0 100644
--- a/internal/conversation/flow/resolver.go
+++ b/internal/conversation/flow/resolver.go
@@ -360,6 +360,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
 	if len(messages) > 10 {
 		messages = stripToolMessages(messages)
 	}
+	messages = repairToolCallClosures(messages, syntheticToolClosureError)
 	displayName := r.resolveDisplayName(ctx, req)
 	mergedAttachments := r.routeAndMergeAttachments(ctx, chatModel, req)
 
diff --git a/internal/conversation/flow/resolver_store.go b/internal/conversation/flow/resolver_store.go
index ea5f63d6..58c453e5 100644
--- a/internal/conversation/flow/resolver_store.go
+++ b/internal/conversation/flow/resolver_store.go
@@ -26,6 +26,7 @@ func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest,
 		}
 		fullRound = append(fullRound, m)
 	}
+	fullRound = repairToolCallClosures(fullRound, syntheticToolClosureError)
 
 	// Filter out empty assistant messages (content: []) that result from LLM
 	// returning no useful output (e.g., context window overflow). These provide
diff --git a/internal/conversation/flow/tool_closure.go b/internal/conversation/flow/tool_closure.go
new file mode 100644
index 00000000..1e0b4c81
--- /dev/null
+++ b/internal/conversation/flow/tool_closure.go
@@ -0,0 +1,213 @@
+package flow
+
+import (
+	"strings"
+
+	sdk "github.com/memohai/twilight-ai/sdk"
+
+	"github.com/memohai/memoh/internal/conversation"
+)
+
+// syntheticToolClosureError is the default result text placed in synthetic
+// tool results created for tool calls that never received a response.
+const syntheticToolClosureError = "tool execution interrupted before a response was recorded"
+
+// pendingToolCall is an assistant tool call that has not yet been matched
+// with a tool result.
+type pendingToolCall struct {
+	ID       string
+	ToolName string
+}
+
+// repairToolCallClosures rewrites messages so that every assistant tool call
+// with a non-blank ID is answered by a tool result before the next non-tool
+// message.
+//
+// Calls still unanswered when a non-tool message or the end of the history
+// is reached receive a synthetic error result carrying reason; a blank reason
+// falls back to syntheticToolClosureError. Tool results that do not answer a
+// currently pending call are dropped.
+func repairToolCallClosures(messages []conversation.ModelMessage, reason string) []conversation.ModelMessage {
+	if len(messages) == 0 {
+		return messages
+	}
+	if strings.TrimSpace(reason) == "" {
+		reason = syntheticToolClosureError
+	}
+
+	repaired := make([]conversation.ModelMessage, 0, len(messages))
+	pending := make(map[string]pendingToolCall)
+	// pendingOrder remembers the order calls were issued in, because map
+	// iteration order is random and synthetic results should be stable.
+	pendingOrder := make([]string, 0)
+
+	// flushPending closes every outstanding call with a synthetic result.
+	flushPending := func() {
+		if len(pendingOrder) == 0 {
+			return
+		}
+		for _, callID := range pendingOrder {
+			call, ok := pending[callID]
+			if !ok {
+				continue
+			}
+			repaired = append(repaired, syntheticToolResultMessage(call.ID, call.ToolName, reason))
+			delete(pending, callID)
+		}
+		pendingOrder = pendingOrder[:0]
+	}
+
+	for _, msg := range messages {
+		role := strings.ToLower(strings.TrimSpace(msg.Role))
+		switch role {
+		case "assistant":
+			if len(pending) > 0 {
+				flushPending()
+			}
+			repaired = append(repaired, msg)
+			for _, call := range extractAssistantToolCallParts(msg) {
+				callID := strings.TrimSpace(call.ToolCallID)
+				if callID == "" {
+					continue
+				}
+				if _, exists := pending[callID]; exists {
+					continue
+				}
+				pending[callID] = pendingToolCall{
+					ID:       callID,
+					ToolName: strings.TrimSpace(call.ToolName),
+				}
+				pendingOrder = append(pendingOrder, callID)
+			}
+
+		case "tool":
+			filtered := filterToolMessageToPending(msg, pending)
+			if filtered == nil {
+				continue
+			}
+			repaired = append(repaired, *filtered)
+			for _, result := range extractToolResultParts(*filtered) {
+				delete(pending, strings.TrimSpace(result.ToolCallID))
+			}
+			if len(pending) == 0 && len(pendingOrder) > 0 {
+				pendingOrder = pendingOrder[:0]
+				continue
+			}
+			if len(pendingOrder) > 0 {
+				// Compact pendingOrder in place to the calls still open.
+				nextOrder := pendingOrder[:0]
+				for _, callID := range pendingOrder {
+					if _, ok := pending[callID]; ok {
+						nextOrder = append(nextOrder, callID)
+					}
+				}
+				pendingOrder = nextOrder
+			}
+
+		default:
+			if len(pending) > 0 {
+				flushPending()
+			}
+			repaired = append(repaired, msg)
+		}
+	}
+
+	if len(pending) > 0 {
+		flushPending()
+	}
+	return repaired
+}
+
+// extractAssistantToolCallParts returns the tool call parts found in msg after
+// converting it with modelMessageToSDKMessage, or nil if it has no content.
+func extractAssistantToolCallParts(msg conversation.ModelMessage) []sdk.ToolCallPart {
+	sdkMsg := modelMessageToSDKMessage(msg)
+	if len(sdkMsg.Content) == 0 {
+		return nil
+	}
+	calls := make([]sdk.ToolCallPart, 0, len(sdkMsg.Content))
+	for _, part := range sdkMsg.Content {
+		call, ok := part.(sdk.ToolCallPart)
+		if !ok {
+			continue
+		}
+		calls = append(calls, call)
+	}
+	return calls
+}
+
+// extractToolResultParts returns the tool result parts found in msg after
+// converting it with modelMessageToSDKMessage, or nil if it has no content.
+func extractToolResultParts(msg conversation.ModelMessage) []sdk.ToolResultPart {
+	sdkMsg := modelMessageToSDKMessage(msg)
+	if len(sdkMsg.Content) == 0 {
+		return nil
+	}
+	results := make([]sdk.ToolResultPart, 0, len(sdkMsg.Content))
+	for _, part := range sdkMsg.Content {
+		result, ok := part.(sdk.ToolResultPart)
+		if !ok {
+			continue
+		}
+		results = append(results, result)
+	}
+	return results
+}
+
+// filterToolMessageToPending returns a tool message holding only the results
+// in msg that answer a call in pending, keeping at most one result per call
+// ID. It returns nil when no result remains or the kept results cannot be
+// converted back into a model message. The returned message is rebuilt from
+// the kept results and carries over msg.Usage.
+func filterToolMessageToPending(msg conversation.ModelMessage, pending map[string]pendingToolCall) *conversation.ModelMessage {
+	results := extractToolResultParts(msg)
+	if len(results) == 0 {
+		return nil
+	}
+
+	filtered := make([]sdk.ToolResultPart, 0, len(results))
+	seen := make(map[string]struct{}, len(results))
+	for _, result := range results {
+		callID := strings.TrimSpace(result.ToolCallID)
+		if _, ok := pending[callID]; !ok {
+			continue
+		}
+		// Keep only the first result per call so a single message cannot
+		// answer the same call twice.
+		if _, dup := seen[callID]; dup {
+			continue
+		}
+		seen[callID] = struct{}{}
+		filtered = append(filtered, result)
+	}
+	if len(filtered) == 0 {
+		return nil
+	}
+
+	converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(filtered...)})
+	if len(converted) == 0 {
+		return nil
+	}
+	filteredMsg := converted[0]
+	filteredMsg.Usage = msg.Usage
+	return &filteredMsg
+}
+
+// syntheticToolResultMessage builds an error tool result for the given call
+// with reason as its result. If the SDK conversion yields no message it falls
+// back to a plain text tool message containing reason.
+func syntheticToolResultMessage(toolCallID, toolName, reason string) conversation.ModelMessage {
+	converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(sdk.ToolResultPart{
+		ToolCallID: strings.TrimSpace(toolCallID),
+		ToolName:   strings.TrimSpace(toolName),
+		Result:     strings.TrimSpace(reason),
+		IsError:    true,
+	})})
+	if len(converted) == 0 {
+		return conversation.ModelMessage{
+			Role:    "tool",
+			Content: conversation.NewTextContent(strings.TrimSpace(reason)),
+		}
+	}
+	return converted[0]
+}
diff --git a/internal/conversation/flow/tool_closure_test.go b/internal/conversation/flow/tool_closure_test.go
new file mode 100644
index 00000000..5fe9be49
--- /dev/null
+++ b/internal/conversation/flow/tool_closure_test.go
@@ -0,0 +1,138 @@
+package flow
+
+import (
+	"testing"
+
+	sdk "github.com/memohai/twilight-ai/sdk"
+
+	"github.com/memohai/memoh/internal/conversation"
+)
+
+func TestRepairToolCallClosures_AppendsSyntheticToolResultForDanglingAssistantCall(t *testing.T) {
+	t.Parallel()
+
+	messages := sdkMessagesToModelMessages([]sdk.Message{
+		sdk.UserMessage("fetch this"),
+		{
+			Role: sdk.MessageRoleAssistant,
+			Content: []sdk.MessagePart{
+				sdk.ToolCallPart{
+					ToolCallID: "web_fetch:10",
+					ToolName:   "web_fetch",
+					Input:      map[string]any{"url": "https://example.com"},
+				},
+			},
+		},
+		{
+			Role:    sdk.MessageRoleAssistant,
+			Content: []sdk.MessagePart{sdk.TextPart{Text: "interrupted"}},
+		},
+	})
+
+	repaired := repairToolCallClosures(messages, syntheticToolClosureError)
+	if len(repaired) != 4 {
+		t.Fatalf("expected 4 messages after repair, got %d", len(repaired))
+	}
+
+	if repaired[2].Role != "tool" {
+		t.Fatalf("expected synthetic tool message before trailing assistant, got role %q", repaired[2].Role)
+	}
+
+	results := extractToolResultParts(repaired[2])
+	if len(results) != 1 {
+		t.Fatalf("expected 1 tool result part, got %d", len(results))
+	}
+	if results[0].ToolCallID != "web_fetch:10" {
+		t.Fatalf("expected tool call id web_fetch:10, got %q", results[0].ToolCallID)
+	}
+	if !results[0].IsError {
+		t.Fatal("expected synthetic tool result to be marked as error")
+	}
+}
+
+func TestRepairToolCallClosures_DropsOrphanToolMessage(t *testing.T) {
+	t.Parallel()
+
+	orphanTool := sdkMessagesToModelMessages([]sdk.Message{
+		sdk.ToolMessage(sdk.ToolResultPart{
+			ToolCallID: "web_fetch:10",
+			ToolName:   "web_fetch",
+			Result:     "orphan",
+		}),
+	})[0]
+
+	messages := []conversation.ModelMessage{
+		{Role: "user", Content: conversation.NewTextContent("hello")},
+		orphanTool,
+		{Role: "assistant", Content: conversation.NewTextContent("done")},
+	}
+
+	repaired := repairToolCallClosures(messages, syntheticToolClosureError)
+	if len(repaired) != 2 {
+		t.Fatalf("expected orphan tool message to be removed, got %d messages", len(repaired))
+	}
+	if repaired[0].Role != "user" || repaired[1].Role != "assistant" {
+		t.Fatalf("unexpected repaired role sequence: %q -> %q", repaired[0].Role, repaired[1].Role)
+	}
+}
+
+func TestRepairToolCallClosures_PreservesValidAssistantToolPair(t *testing.T) {
+	t.Parallel()
+
+	messages := sdkMessagesToModelMessages([]sdk.Message{
+		{
+			Role: sdk.MessageRoleAssistant,
+			Content: []sdk.MessagePart{
+				sdk.ToolCallPart{
+					ToolCallID: "web_search:1",
+					ToolName:   "web_search",
+					Input:      map[string]any{"query": "memoh"},
+				},
+			},
+		},
+		sdk.ToolMessage(sdk.ToolResultPart{
+			ToolCallID: "web_search:1",
+			ToolName:   "web_search",
+			Result:     map[string]any{"results": []any{}},
+		}),
+	})
+
+	repaired := repairToolCallClosures(messages, syntheticToolClosureError)
+	if len(repaired) != 2 {
+		t.Fatalf("expected valid tool pair to be preserved, got %d messages", len(repaired))
+	}
+	results := extractToolResultParts(repaired[1])
+	if len(results) != 1 || results[0].ToolCallID != "web_search:1" {
+		t.Fatalf("unexpected repaired tool results: %#v", results)
+	}
+}
+
+func TestRepairToolCallClosures_DropsDuplicateToolResultInSameMessage(t *testing.T) {
+	t.Parallel()
+
+	messages := sdkMessagesToModelMessages([]sdk.Message{
+		{
+			Role: sdk.MessageRoleAssistant,
+			Content: []sdk.MessagePart{
+				sdk.ToolCallPart{
+					ToolCallID: "web_search:1",
+					ToolName:   "web_search",
+					Input:      map[string]any{"query": "memoh"},
+				},
+			},
+		},
+		sdk.ToolMessage(
+			sdk.ToolResultPart{ToolCallID: "web_search:1", ToolName: "web_search", Result: "first"},
+			sdk.ToolResultPart{ToolCallID: "web_search:1", ToolName: "web_search", Result: "second"},
+		),
+	})
+
+	repaired := repairToolCallClosures(messages, syntheticToolClosureError)
+	if len(repaired) != 2 {
+		t.Fatalf("expected assistant and tool messages, got %d messages", len(repaired))
+	}
+	results := extractToolResultParts(repaired[1])
+	if len(results) != 1 {
+		t.Fatalf("expected duplicate tool result to be dropped, got %d results", len(results))
+	}
+}