mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(conversation): correct token trimming edge cases (#207)
- Treat maxTokens=0 as "unconfigured/unlimited" instead of disabling trimming for any non-positive value (which masked exhausted budgets)
- Set historyBudget=1 when maxTokens>0 but overhead exceeds the limit, ensuring aggressive trimming instead of no trimming
- Estimate token cost for messages without usage data (len/4 fallback) so user/tool messages are not free-passed during budget accounting
This commit is contained in:
@@ -342,7 +342,9 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
overhead += systemPromptReserve
|
||||
|
||||
historyBudget := maxTokens - overhead
|
||||
if historyBudget < 0 {
|
||||
if maxTokens > 0 && historyBudget <= 0 {
|
||||
historyBudget = 1
|
||||
} else if historyBudget < 0 {
|
||||
historyBudget = 0
|
||||
}
|
||||
|
||||
@@ -1279,7 +1281,7 @@ func estimateMessageTokens(msg conversation.ModelMessage) int {
|
||||
}
|
||||
|
||||
func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxTokens int) []conversation.ModelMessage {
|
||||
if maxTokens <= 0 || len(messages) == 0 {
|
||||
if maxTokens == 0 || len(messages) == 0 {
|
||||
result := make([]conversation.ModelMessage, len(messages))
|
||||
for i, m := range messages {
|
||||
result[i] = m.Message
|
||||
@@ -1287,10 +1289,9 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke
|
||||
return result
|
||||
}
|
||||
|
||||
// Scan from newest to oldest, accumulating per-message outputTokens from
|
||||
// stored usage data. Messages without usage (user / tool) are included for
|
||||
// free — the outputTokens of surrounding assistant turns already account
|
||||
// for the context they consumed.
|
||||
// Scan from newest to oldest, accumulating per-message token costs.
|
||||
// Messages with stored usage data use that value; others fall back to a
|
||||
// character-based estimate so that user/tool messages are not free-passed.
|
||||
totalTokens := 0
|
||||
cutoff := 0
|
||||
messagesWithUsage := 0
|
||||
@@ -1298,6 +1299,8 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke
|
||||
if messages[i].UsageOutputTokens != nil {
|
||||
totalTokens += *messages[i].UsageOutputTokens
|
||||
messagesWithUsage++
|
||||
} else {
|
||||
totalTokens += estimateMessageTokens(messages[i].Message)
|
||||
}
|
||||
if totalTokens > maxTokens {
|
||||
cutoff = i + 1
|
||||
|
||||
@@ -112,3 +112,55 @@ func TestTrimMessagesByTokens_NoUsage_KeepsAll(t *testing.T) {
|
||||
t.Fatalf("messages without outputTokens should all be kept, got %d", len(trimmed))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimMessagesByTokens_ZeroMeansNoLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []messageWithUsage{
|
||||
{Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("hello")}, UsageOutputTokens: intPtr(10000)},
|
||||
{Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("world")}, UsageOutputTokens: intPtr(10000)},
|
||||
}
|
||||
|
||||
// maxTokens = 0 means "no limit configured", should keep all messages.
|
||||
trimmed := trimMessagesByTokens(nil, messages, 0)
|
||||
if len(trimmed) != 2 {
|
||||
t.Fatalf("maxTokens=0 should keep all messages, got %d", len(trimmed))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimMessagesByTokens_SmallBudgetTrims(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []messageWithUsage{
|
||||
{Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("old message")}, UsageOutputTokens: intPtr(100)},
|
||||
{Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("old reply")}, UsageOutputTokens: intPtr(200)},
|
||||
{Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("new message")}, UsageOutputTokens: intPtr(50)},
|
||||
{Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("new reply")}, UsageOutputTokens: intPtr(60)},
|
||||
}
|
||||
|
||||
// Budget of 1: should trim aggressively, NOT return all messages.
|
||||
trimmed := trimMessagesByTokens(nil, messages, 1)
|
||||
if len(trimmed) >= len(messages) {
|
||||
t.Fatalf("maxTokens=1 should trim history, but got %d messages (same as input)", len(trimmed))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimMessagesByTokens_EstimatesFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Long user message without usage data — should be estimated.
|
||||
longText := make([]byte, 400)
|
||||
for i := range longText {
|
||||
longText[i] = 'x'
|
||||
}
|
||||
messages := []messageWithUsage{
|
||||
{Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent(string(longText))}},
|
||||
{Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("ok")}, UsageOutputTokens: intPtr(10)},
|
||||
}
|
||||
|
||||
// Budget of 50: user message is ~100 estimated tokens (400/4), should be trimmed.
|
||||
trimmed := trimMessagesByTokens(nil, messages, 50)
|
||||
if len(trimmed) == 2 {
|
||||
t.Fatalf("expected long user message without usage to be trimmed via estimation, got %d", len(trimmed))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user