Fix dangling tool call history

This commit is contained in:
aki
2026-04-14 05:50:04 +09:00
committed by 晨苒
parent 3945bd913d
commit 3509947bc0
4 changed files with 287 additions and 0 deletions
+1
View File
@@ -360,6 +360,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
if len(messages) > 10 {
messages = stripToolMessages(messages)
}
messages = repairToolCallClosures(messages, syntheticToolClosureError)
displayName := r.resolveDisplayName(ctx, req)
mergedAttachments := r.routeAndMergeAttachments(ctx, chatModel, req)
@@ -26,6 +26,7 @@ func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest,
}
fullRound = append(fullRound, m)
}
fullRound = repairToolCallClosures(fullRound, syntheticToolClosureError)
// Filter out empty assistant messages (content: []) that result from LLM
// returning no useful output (e.g., context window overflow). These provide
+177
View File
@@ -0,0 +1,177 @@
package flow
import (
"strings"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/conversation"
)
const syntheticToolClosureError = "tool execution interrupted before a response was recorded"
type pendingToolCall struct {
ID string
ToolName string
}
func repairToolCallClosures(messages []conversation.ModelMessage, reason string) []conversation.ModelMessage {
if len(messages) == 0 {
return messages
}
if strings.TrimSpace(reason) == "" {
reason = syntheticToolClosureError
}
repaired := make([]conversation.ModelMessage, 0, len(messages))
pending := make(map[string]pendingToolCall)
pendingOrder := make([]string, 0)
flushPending := func() {
if len(pendingOrder) == 0 {
return
}
for _, callID := range pendingOrder {
call, ok := pending[callID]
if !ok {
continue
}
repaired = append(repaired, syntheticToolResultMessage(call.ID, call.ToolName, reason))
delete(pending, callID)
}
pendingOrder = pendingOrder[:0]
}
for _, msg := range messages {
role := strings.ToLower(strings.TrimSpace(msg.Role))
switch role {
case "assistant":
if len(pending) > 0 {
flushPending()
}
repaired = append(repaired, msg)
for _, call := range extractAssistantToolCallParts(msg) {
callID := strings.TrimSpace(call.ToolCallID)
if callID == "" {
continue
}
if _, exists := pending[callID]; exists {
continue
}
pending[callID] = pendingToolCall{
ID: callID,
ToolName: strings.TrimSpace(call.ToolName),
}
pendingOrder = append(pendingOrder, callID)
}
case "tool":
filtered := filterToolMessageToPending(msg, pending)
if filtered == nil {
continue
}
repaired = append(repaired, *filtered)
for _, result := range extractToolResultParts(*filtered) {
delete(pending, strings.TrimSpace(result.ToolCallID))
}
if len(pending) == 0 && len(pendingOrder) > 0 {
pendingOrder = pendingOrder[:0]
continue
}
if len(pendingOrder) > 0 {
nextOrder := pendingOrder[:0]
for _, callID := range pendingOrder {
if _, ok := pending[callID]; ok {
nextOrder = append(nextOrder, callID)
}
}
pendingOrder = nextOrder
}
default:
if len(pending) > 0 {
flushPending()
}
repaired = append(repaired, msg)
}
}
if len(pending) > 0 {
flushPending()
}
return repaired
}
func extractAssistantToolCallParts(msg conversation.ModelMessage) []sdk.ToolCallPart {
sdkMsg := modelMessageToSDKMessage(msg)
if len(sdkMsg.Content) == 0 {
return nil
}
calls := make([]sdk.ToolCallPart, 0, len(sdkMsg.Content))
for _, part := range sdkMsg.Content {
call, ok := part.(sdk.ToolCallPart)
if !ok {
continue
}
calls = append(calls, call)
}
return calls
}
func extractToolResultParts(msg conversation.ModelMessage) []sdk.ToolResultPart {
sdkMsg := modelMessageToSDKMessage(msg)
if len(sdkMsg.Content) == 0 {
return nil
}
results := make([]sdk.ToolResultPart, 0, len(sdkMsg.Content))
for _, part := range sdkMsg.Content {
result, ok := part.(sdk.ToolResultPart)
if !ok {
continue
}
results = append(results, result)
}
return results
}
func filterToolMessageToPending(msg conversation.ModelMessage, pending map[string]pendingToolCall) *conversation.ModelMessage {
results := extractToolResultParts(msg)
if len(results) == 0 {
return nil
}
filtered := make([]sdk.ToolResultPart, 0, len(results))
for _, result := range results {
if _, ok := pending[strings.TrimSpace(result.ToolCallID)]; !ok {
continue
}
filtered = append(filtered, result)
}
if len(filtered) == 0 {
return nil
}
converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(filtered...)})
if len(converted) == 0 {
return nil
}
filteredMsg := converted[0]
filteredMsg.Usage = msg.Usage
return &filteredMsg
}
func syntheticToolResultMessage(toolCallID, toolName, reason string) conversation.ModelMessage {
converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(sdk.ToolResultPart{
ToolCallID: strings.TrimSpace(toolCallID),
ToolName: strings.TrimSpace(toolName),
Result: strings.TrimSpace(reason),
IsError: true,
})})
if len(converted) == 0 {
return conversation.ModelMessage{
Role: "tool",
Content: conversation.NewTextContent(strings.TrimSpace(reason)),
}
}
return converted[0]
}
@@ -0,0 +1,108 @@
package flow
import (
"testing"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/conversation"
)
func TestRepairToolCallClosures_AppendsSyntheticToolResultForDanglingAssistantCall(t *testing.T) {
t.Parallel()
messages := sdkMessagesToModelMessages([]sdk.Message{
sdk.UserMessage("fetch this"),
{
Role: sdk.MessageRoleAssistant,
Content: []sdk.MessagePart{
sdk.ToolCallPart{
ToolCallID: "web_fetch:10",
ToolName: "web_fetch",
Input: map[string]any{"url": "https://example.com"},
},
},
},
{
Role: sdk.MessageRoleAssistant,
Content: []sdk.MessagePart{sdk.TextPart{Text: "interrupted"}},
},
})
repaired := repairToolCallClosures(messages, syntheticToolClosureError)
if len(repaired) != 4 {
t.Fatalf("expected 4 messages after repair, got %d", len(repaired))
}
if repaired[2].Role != "tool" {
t.Fatalf("expected synthetic tool message before trailing assistant, got role %q", repaired[2].Role)
}
results := extractToolResultParts(repaired[2])
if len(results) != 1 {
t.Fatalf("expected 1 tool result part, got %d", len(results))
}
if results[0].ToolCallID != "web_fetch:10" {
t.Fatalf("expected tool call id web_fetch:10, got %q", results[0].ToolCallID)
}
if !results[0].IsError {
t.Fatal("expected synthetic tool result to be marked as error")
}
}
func TestRepairToolCallClosures_DropsOrphanToolMessage(t *testing.T) {
t.Parallel()
orphanTool := sdkMessagesToModelMessages([]sdk.Message{
sdk.ToolMessage(sdk.ToolResultPart{
ToolCallID: "web_fetch:10",
ToolName: "web_fetch",
Result: "orphan",
}),
})[0]
messages := []conversation.ModelMessage{
{Role: "user", Content: conversation.NewTextContent("hello")},
orphanTool,
{Role: "assistant", Content: conversation.NewTextContent("done")},
}
repaired := repairToolCallClosures(messages, syntheticToolClosureError)
if len(repaired) != 2 {
t.Fatalf("expected orphan tool message to be removed, got %d messages", len(repaired))
}
if repaired[0].Role != "user" || repaired[1].Role != "assistant" {
t.Fatalf("unexpected repaired role sequence: %q -> %q", repaired[0].Role, repaired[1].Role)
}
}
func TestRepairToolCallClosures_PreservesValidAssistantToolPair(t *testing.T) {
t.Parallel()
messages := sdkMessagesToModelMessages([]sdk.Message{
{
Role: sdk.MessageRoleAssistant,
Content: []sdk.MessagePart{
sdk.ToolCallPart{
ToolCallID: "web_search:1",
ToolName: "web_search",
Input: map[string]any{"query": "memoh"},
},
},
},
sdk.ToolMessage(sdk.ToolResultPart{
ToolCallID: "web_search:1",
ToolName: "web_search",
Result: map[string]any{"results": []any{}},
}),
})
repaired := repairToolCallClosures(messages, syntheticToolClosureError)
if len(repaired) != 2 {
t.Fatalf("expected valid tool pair to be preserved, got %d messages", len(repaired))
}
results := extractToolResultParts(repaired[1])
if len(results) != 1 || results[0].ToolCallID != "web_search:1" {
t.Fatalf("unexpected repaired tool results: %#v", results)
}
}