mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
Fix dangling tool call history
This commit is contained in:
@@ -360,6 +360,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
if len(messages) > 10 {
|
||||
messages = stripToolMessages(messages)
|
||||
}
|
||||
messages = repairToolCallClosures(messages, syntheticToolClosureError)
|
||||
|
||||
displayName := r.resolveDisplayName(ctx, req)
|
||||
mergedAttachments := r.routeAndMergeAttachments(ctx, chatModel, req)
|
||||
|
||||
@@ -26,6 +26,7 @@ func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest,
|
||||
}
|
||||
fullRound = append(fullRound, m)
|
||||
}
|
||||
fullRound = repairToolCallClosures(fullRound, syntheticToolClosureError)
|
||||
|
||||
// Filter out empty assistant messages (content: []) that result from LLM
|
||||
// returning no useful output (e.g., context window overflow). These provide
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
const syntheticToolClosureError = "tool execution interrupted before a response was recorded"
|
||||
|
||||
type pendingToolCall struct {
|
||||
ID string
|
||||
ToolName string
|
||||
}
|
||||
|
||||
func repairToolCallClosures(messages []conversation.ModelMessage, reason string) []conversation.ModelMessage {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
if strings.TrimSpace(reason) == "" {
|
||||
reason = syntheticToolClosureError
|
||||
}
|
||||
|
||||
repaired := make([]conversation.ModelMessage, 0, len(messages))
|
||||
pending := make(map[string]pendingToolCall)
|
||||
pendingOrder := make([]string, 0)
|
||||
|
||||
flushPending := func() {
|
||||
if len(pendingOrder) == 0 {
|
||||
return
|
||||
}
|
||||
for _, callID := range pendingOrder {
|
||||
call, ok := pending[callID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
repaired = append(repaired, syntheticToolResultMessage(call.ID, call.ToolName, reason))
|
||||
delete(pending, callID)
|
||||
}
|
||||
pendingOrder = pendingOrder[:0]
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
role := strings.ToLower(strings.TrimSpace(msg.Role))
|
||||
switch role {
|
||||
case "assistant":
|
||||
if len(pending) > 0 {
|
||||
flushPending()
|
||||
}
|
||||
repaired = append(repaired, msg)
|
||||
for _, call := range extractAssistantToolCallParts(msg) {
|
||||
callID := strings.TrimSpace(call.ToolCallID)
|
||||
if callID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := pending[callID]; exists {
|
||||
continue
|
||||
}
|
||||
pending[callID] = pendingToolCall{
|
||||
ID: callID,
|
||||
ToolName: strings.TrimSpace(call.ToolName),
|
||||
}
|
||||
pendingOrder = append(pendingOrder, callID)
|
||||
}
|
||||
|
||||
case "tool":
|
||||
filtered := filterToolMessageToPending(msg, pending)
|
||||
if filtered == nil {
|
||||
continue
|
||||
}
|
||||
repaired = append(repaired, *filtered)
|
||||
for _, result := range extractToolResultParts(*filtered) {
|
||||
delete(pending, strings.TrimSpace(result.ToolCallID))
|
||||
}
|
||||
if len(pending) == 0 && len(pendingOrder) > 0 {
|
||||
pendingOrder = pendingOrder[:0]
|
||||
continue
|
||||
}
|
||||
if len(pendingOrder) > 0 {
|
||||
nextOrder := pendingOrder[:0]
|
||||
for _, callID := range pendingOrder {
|
||||
if _, ok := pending[callID]; ok {
|
||||
nextOrder = append(nextOrder, callID)
|
||||
}
|
||||
}
|
||||
pendingOrder = nextOrder
|
||||
}
|
||||
|
||||
default:
|
||||
if len(pending) > 0 {
|
||||
flushPending()
|
||||
}
|
||||
repaired = append(repaired, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pending) > 0 {
|
||||
flushPending()
|
||||
}
|
||||
return repaired
|
||||
}
|
||||
|
||||
func extractAssistantToolCallParts(msg conversation.ModelMessage) []sdk.ToolCallPart {
|
||||
sdkMsg := modelMessageToSDKMessage(msg)
|
||||
if len(sdkMsg.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
calls := make([]sdk.ToolCallPart, 0, len(sdkMsg.Content))
|
||||
for _, part := range sdkMsg.Content {
|
||||
call, ok := part.(sdk.ToolCallPart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
calls = append(calls, call)
|
||||
}
|
||||
return calls
|
||||
}
|
||||
|
||||
func extractToolResultParts(msg conversation.ModelMessage) []sdk.ToolResultPart {
|
||||
sdkMsg := modelMessageToSDKMessage(msg)
|
||||
if len(sdkMsg.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
results := make([]sdk.ToolResultPart, 0, len(sdkMsg.Content))
|
||||
for _, part := range sdkMsg.Content {
|
||||
result, ok := part.(sdk.ToolResultPart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func filterToolMessageToPending(msg conversation.ModelMessage, pending map[string]pendingToolCall) *conversation.ModelMessage {
|
||||
results := extractToolResultParts(msg)
|
||||
if len(results) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make([]sdk.ToolResultPart, 0, len(results))
|
||||
for _, result := range results {
|
||||
if _, ok := pending[strings.TrimSpace(result.ToolCallID)]; !ok {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, result)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(filtered...)})
|
||||
if len(converted) == 0 {
|
||||
return nil
|
||||
}
|
||||
filteredMsg := converted[0]
|
||||
filteredMsg.Usage = msg.Usage
|
||||
return &filteredMsg
|
||||
}
|
||||
|
||||
func syntheticToolResultMessage(toolCallID, toolName, reason string) conversation.ModelMessage {
|
||||
converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(sdk.ToolResultPart{
|
||||
ToolCallID: strings.TrimSpace(toolCallID),
|
||||
ToolName: strings.TrimSpace(toolName),
|
||||
Result: strings.TrimSpace(reason),
|
||||
IsError: true,
|
||||
})})
|
||||
if len(converted) == 0 {
|
||||
return conversation.ModelMessage{
|
||||
Role: "tool",
|
||||
Content: conversation.NewTextContent(strings.TrimSpace(reason)),
|
||||
}
|
||||
}
|
||||
return converted[0]
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
func TestRepairToolCallClosures_AppendsSyntheticToolResultForDanglingAssistantCall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := sdkMessagesToModelMessages([]sdk.Message{
|
||||
sdk.UserMessage("fetch this"),
|
||||
{
|
||||
Role: sdk.MessageRoleAssistant,
|
||||
Content: []sdk.MessagePart{
|
||||
sdk.ToolCallPart{
|
||||
ToolCallID: "web_fetch:10",
|
||||
ToolName: "web_fetch",
|
||||
Input: map[string]any{"url": "https://example.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: sdk.MessageRoleAssistant,
|
||||
Content: []sdk.MessagePart{sdk.TextPart{Text: "interrupted"}},
|
||||
},
|
||||
})
|
||||
|
||||
repaired := repairToolCallClosures(messages, syntheticToolClosureError)
|
||||
if len(repaired) != 4 {
|
||||
t.Fatalf("expected 4 messages after repair, got %d", len(repaired))
|
||||
}
|
||||
|
||||
if repaired[2].Role != "tool" {
|
||||
t.Fatalf("expected synthetic tool message before trailing assistant, got role %q", repaired[2].Role)
|
||||
}
|
||||
|
||||
results := extractToolResultParts(repaired[2])
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected 1 tool result part, got %d", len(results))
|
||||
}
|
||||
if results[0].ToolCallID != "web_fetch:10" {
|
||||
t.Fatalf("expected tool call id web_fetch:10, got %q", results[0].ToolCallID)
|
||||
}
|
||||
if !results[0].IsError {
|
||||
t.Fatal("expected synthetic tool result to be marked as error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairToolCallClosures_DropsOrphanToolMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
orphanTool := sdkMessagesToModelMessages([]sdk.Message{
|
||||
sdk.ToolMessage(sdk.ToolResultPart{
|
||||
ToolCallID: "web_fetch:10",
|
||||
ToolName: "web_fetch",
|
||||
Result: "orphan",
|
||||
}),
|
||||
})[0]
|
||||
|
||||
messages := []conversation.ModelMessage{
|
||||
{Role: "user", Content: conversation.NewTextContent("hello")},
|
||||
orphanTool,
|
||||
{Role: "assistant", Content: conversation.NewTextContent("done")},
|
||||
}
|
||||
|
||||
repaired := repairToolCallClosures(messages, syntheticToolClosureError)
|
||||
if len(repaired) != 2 {
|
||||
t.Fatalf("expected orphan tool message to be removed, got %d messages", len(repaired))
|
||||
}
|
||||
if repaired[0].Role != "user" || repaired[1].Role != "assistant" {
|
||||
t.Fatalf("unexpected repaired role sequence: %q -> %q", repaired[0].Role, repaired[1].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairToolCallClosures_PreservesValidAssistantToolPair(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := sdkMessagesToModelMessages([]sdk.Message{
|
||||
{
|
||||
Role: sdk.MessageRoleAssistant,
|
||||
Content: []sdk.MessagePart{
|
||||
sdk.ToolCallPart{
|
||||
ToolCallID: "web_search:1",
|
||||
ToolName: "web_search",
|
||||
Input: map[string]any{"query": "memoh"},
|
||||
},
|
||||
},
|
||||
},
|
||||
sdk.ToolMessage(sdk.ToolResultPart{
|
||||
ToolCallID: "web_search:1",
|
||||
ToolName: "web_search",
|
||||
Result: map[string]any{"results": []any{}},
|
||||
}),
|
||||
})
|
||||
|
||||
repaired := repairToolCallClosures(messages, syntheticToolClosureError)
|
||||
if len(repaired) != 2 {
|
||||
t.Fatalf("expected valid tool pair to be preserved, got %d messages", len(repaired))
|
||||
}
|
||||
results := extractToolResultParts(repaired[1])
|
||||
if len(results) != 1 || results[0].ToolCallID != "web_search:1" {
|
||||
t.Fatalf("unexpected repaired tool results: %#v", results)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user