From c741f2410bec19b700d17bdea7f4f87c19af02e2 Mon Sep 17 00:00:00 2001 From: Menci Date: Mon, 9 Mar 2026 13:06:19 +0800 Subject: [PATCH] fix(conversation): correct token trimming edge cases (#207) - Treat maxTokens=0 as "unconfigured/unlimited" instead of disabling trimming for any non-positive value (which masked exhausted budgets) - Set historyBudget=1 when maxTokens>0 but overhead exceeds the limit, ensuring aggressive trimming instead of no trimming - Estimate token cost for messages without usage data (len/4 fallback) so user/tool messages are not free-passed during budget accounting --- internal/conversation/flow/resolver.go | 15 +++--- .../conversation/flow/resolver_trim_test.go | 52 +++++++++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index b0909424..eea05939 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -342,7 +342,9 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r overhead += systemPromptReserve historyBudget := maxTokens - overhead - if historyBudget < 0 { + if maxTokens > 0 && historyBudget <= 0 { + historyBudget = 1 + } else if historyBudget < 0 { historyBudget = 0 } @@ -1279,7 +1281,7 @@ func estimateMessageTokens(msg conversation.ModelMessage) int { } func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxTokens int) []conversation.ModelMessage { - if maxTokens <= 0 || len(messages) == 0 { + if maxTokens == 0 || len(messages) == 0 { result := make([]conversation.ModelMessage, len(messages)) for i, m := range messages { result[i] = m.Message @@ -1287,10 +1289,9 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke return result } - // Scan from newest to oldest, accumulating per-message outputTokens from - // stored usage data. Messages without usage (user / tool) are included for - // free — the outputTokens of surrounding assistant turns already account - // for the context they consumed. + // Scan from newest to oldest, accumulating per-message token costs. + // Messages with stored usage data use that value; others fall back to a + // character-based estimate so that user/tool messages are not free-passed. totalTokens := 0 cutoff := 0 messagesWithUsage := 0 @@ -1298,6 +1299,8 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke if messages[i].UsageOutputTokens != nil { totalTokens += *messages[i].UsageOutputTokens messagesWithUsage++ + } else { + totalTokens += estimateMessageTokens(messages[i].Message) } if totalTokens > maxTokens { cutoff = i + 1 diff --git a/internal/conversation/flow/resolver_trim_test.go b/internal/conversation/flow/resolver_trim_test.go index 136aa233..c30d2101 100644 --- a/internal/conversation/flow/resolver_trim_test.go +++ b/internal/conversation/flow/resolver_trim_test.go @@ -112,3 +112,55 @@ func TestTrimMessagesByTokens_NoUsage_KeepsAll(t *testing.T) { t.Fatalf("messages without outputTokens should all be kept, got %d", len(trimmed)) } } + +func TestTrimMessagesByTokens_ZeroMeansNoLimit(t *testing.T) { + t.Parallel() + + messages := []messageWithUsage{ + {Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("hello")}, UsageOutputTokens: intPtr(10000)}, + {Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("world")}, UsageOutputTokens: intPtr(10000)}, + } + + // maxTokens = 0 means "no limit configured", should keep all messages. + trimmed := trimMessagesByTokens(nil, messages, 0) + if len(trimmed) != 2 { + t.Fatalf("maxTokens=0 should keep all messages, got %d", len(trimmed)) + } +} + +func TestTrimMessagesByTokens_SmallBudgetTrims(t *testing.T) { + t.Parallel() + + messages := []messageWithUsage{ + {Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("old message")}, UsageOutputTokens: intPtr(100)}, + {Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("old reply")}, UsageOutputTokens: intPtr(200)}, + {Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("new message")}, UsageOutputTokens: intPtr(50)}, + {Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("new reply")}, UsageOutputTokens: intPtr(60)}, + } + + // Budget of 1: should trim aggressively, NOT return all messages. + trimmed := trimMessagesByTokens(nil, messages, 1) + if len(trimmed) >= len(messages) { + t.Fatalf("maxTokens=1 should trim history, but got %d messages (same as input)", len(trimmed)) + } +} + +func TestTrimMessagesByTokens_EstimatesFallback(t *testing.T) { + t.Parallel() + + // Long user message without usage data — should be estimated. + longText := make([]byte, 400) + for i := range longText { + longText[i] = 'x' + } + messages := []messageWithUsage{ + {Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent(string(longText))}}, + {Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("ok")}, UsageOutputTokens: intPtr(10)}, + } + + // Budget of 50: user message is ~100 estimated tokens (400/4), should be trimmed. + trimmed := trimMessagesByTokens(nil, messages, 50) + if len(trimmed) == 2 { + t.Fatalf("expected long user message without usage to be trimmed via estimation, got %d", len(trimmed)) + } +}