mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(flow): drop leading orphan tool messages after token trimming (#68)
This commit is contained in:
@@ -750,6 +750,13 @@ func trimMessagesByTokens(messages []messageWithUsage, maxTokens int) []conversa
|
||||
}
|
||||
}
|
||||
|
||||
// Keep provider-valid message order: a "tool" message must follow a preceding
|
||||
// assistant tool call. When history is head-trimmed, a leading tool message
|
||||
// may become orphaned and cause provider 400 errors.
|
||||
for cutoff < len(messages) && strings.EqualFold(strings.TrimSpace(messages[cutoff].Message.Role), "tool") {
|
||||
cutoff++
|
||||
}
|
||||
|
||||
result := make([]conversation.ModelMessage, 0, len(messages)-cutoff)
|
||||
for _, m := range messages[cutoff:] {
|
||||
result = append(result, m.Message)
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
func TestTrimMessagesByTokens_DropsLeadingOrphanTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []messageWithUsage{
|
||||
{
|
||||
Message: conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: conversation.NewTextContent("1111"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Message: conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
ToolCalls: []conversation.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Type: "function",
|
||||
Function: conversation.ToolCallFunction{
|
||||
Name: "calc",
|
||||
Arguments: `{"x":1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Message: conversation.ModelMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: "call-1",
|
||||
Content: conversation.NewTextContent("2"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Message: conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
Content: conversation.NewTextContent("done"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
trimmed := trimMessagesByTokens(messages, 2)
|
||||
if len(trimmed) == 0 {
|
||||
t.Fatal("expected non-empty trimmed messages")
|
||||
}
|
||||
if trimmed[0].Role == "tool" {
|
||||
t.Fatal("expected first trimmed message not to be tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimMessagesByTokens_KeepsToolWhenPaired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []messageWithUsage{
|
||||
{
|
||||
Message: conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
ToolCalls: []conversation.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Type: "function",
|
||||
Function: conversation.ToolCallFunction{
|
||||
Name: "calc",
|
||||
Arguments: `{"x":1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Message: conversation.ModelMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: "call-1",
|
||||
Content: conversation.NewTextContent("2"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
trimmed := trimMessagesByTokens(messages, 100)
|
||||
if len(trimmed) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(trimmed))
|
||||
}
|
||||
if trimmed[0].Role != "assistant" || trimmed[1].Role != "tool" {
|
||||
t.Fatalf("unexpected role order: %q -> %q", trimmed[0].Role, trimmed[1].Role)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user