Files
Memoh/internal/conversation/flow/gateway_prune.go
T
Ringo.Typowriter 9461f923df fix(flow): stabilize chunked SSE and unify prune limits for read/exec/gateway (#71)
* fix(agent): emit chunked SSE data

fix(flow): reassemble chunked SSE and prune tool payloads

fix: avoid whitespace prune bypass; optimize chunked SSE builder

* refactor: LLM provider pruning use shared textprune library

* chore: smaller range
2026-02-21 17:06:02 +08:00

320 lines
7.5 KiB
Go

package flow
import (
"encoding/json"
"strings"
"github.com/memohai/memoh/internal/conversation"
textprune "github.com/memohai/memoh/internal/prune"
)
// Limits for pruning long tool payloads before they are sent to the gateway.
// Pruning is applied per message so requests stay within provider limits while
// as much surrounding context as possible is preserved.
const (
	// Overall budget checked by exceedsTextBudget; taken from the shared
	// textprune defaults so every caller of that library prunes alike.
	gatewayToolPayloadMaxBytes = textprune.DefaultMaxBytes
	gatewayToolPayloadMaxLines = textprune.DefaultMaxLines

	// Marker handed to textprune as Config.Marker for pruned payloads.
	gatewayToolPayloadPrunedMarker = textprune.DefaultMarker
)

// Head/tail windows kept when pruning tool results (message content/outputs).
const (
	gatewayToolResultHeadBytes = 6 * 1024
	gatewayToolResultTailBytes = 2 * 1024
	gatewayToolResultHeadLines = 180
	gatewayToolResultTailLines = 50
)

// Head/tail windows kept when pruning tool-call arguments.
const (
	gatewayToolArgsHeadBytes = 4 * 1024
	gatewayToolArgsTailBytes = 2 * 1024
	gatewayToolArgsHeadLines = 180
	gatewayToolArgsTailLines = 50
)
// pruneHistoryForGateway returns a copy of messages with oversized tool
// payloads pruned via pruneMessageForGateway. An empty input is returned as is.
//
// The first message that gets pruned, and every message after it, has its
// UsageInputTokens cleared. Presumably this is because recorded input-token
// counts no longer match the pruned history (TODO confirm with usage consumers).
// Messages before the first pruned one keep their usage untouched.
func pruneHistoryForGateway(messages []messageWithUsage) []messageWithUsage {
	if len(messages) == 0 {
		return messages
	}
	result := make([]messageWithUsage, len(messages))
	invalidateUsage := false
	for i := range messages {
		entry := messages[i]
		if prunedMsg, modified := pruneMessageForGateway(entry.Message); modified {
			entry.Message = prunedMsg
			invalidateUsage = true
		}
		if invalidateUsage {
			entry.UsageInputTokens = nil
		}
		result[i] = entry
	}
	return result
}
// pruneMessagesForGateway returns a new slice in which every message has been
// passed through pruneMessageForGateway. The input slice is not modified; an
// empty input is returned unchanged.
func pruneMessagesForGateway(messages []conversation.ModelMessage) []conversation.ModelMessage {
	if len(messages) == 0 {
		return messages
	}
	pruned := make([]conversation.ModelMessage, len(messages))
	for i, m := range messages {
		// Whether anything changed is irrelevant here; only the message is kept.
		pruned[i], _ = pruneMessageForGateway(m)
	}
	return pruned
}
// pruneMessageForGateway prunes oversized tool payloads in a single message.
//
// Two independent passes are applied:
//   - messages whose role is "tool" (case-insensitive, surrounding whitespace
//     ignored) have their content pruned by pruneToolMessage;
//   - messages carrying tool calls have their call arguments pruned by
//     pruneToolCalls.
//
// It returns the (possibly) updated message and whether either pass changed it.
func pruneMessageForGateway(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
	modified := false

	isToolRole := strings.EqualFold(strings.TrimSpace(msg.Role), "tool")
	if isToolRole {
		if prunedMsg, ok := pruneToolMessage(msg); ok {
			msg, modified = prunedMsg, true
		}
	}

	if len(msg.ToolCalls) != 0 {
		if prunedCalls, ok := pruneToolCalls(msg.ToolCalls); ok {
			msg.ToolCalls = prunedCalls
			modified = true
		}
	}

	return msg, modified
}
// pruneToolCalls returns a shallow copy of calls in which every non-empty
// Function.Arguments string that exceeds the gateway text budget has been
// pruned down to its head/tail edges. The input slice is never modified.
// The boolean reports whether any arguments were pruned.
func pruneToolCalls(calls []conversation.ToolCall) ([]conversation.ToolCall, bool) {
	result := make([]conversation.ToolCall, len(calls))
	copy(result, calls)

	modified := false
	for i := range result {
		args := result[i].Function.Arguments
		if args != "" && exceedsTextBudget(args) {
			result[i].Function.Arguments = pruneStringEdges(
				args,
				gatewayToolArgsHeadBytes,
				gatewayToolArgsTailBytes,
				gatewayToolArgsHeadLines,
				gatewayToolArgsTailLines,
				"tool arguments",
			)
			modified = true
		}
	}
	return result, modified
}
// pruneToolMessage prunes the content of a tool-role message.
//
// The preferred path keeps the array-of-tool-result-parts shape (required by
// the Vercel AI SDK schema for tool messages) and only prunes the outputs
// inside those parts. If that path prunes nothing, the message's text content
// is checked against the budget and, when too large, replaced by a pruned
// plain-text content. That branch supports tool messages persisted as plain strings.
//
// NOTE(review): the fallback also runs when the content is a parts array that
// needed no pruning; if TextContent ever yields text for such an array, the
// array would be replaced by plain text — confirm TextContent's behavior.
func pruneToolMessage(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
	if parts, pruned := pruneToolResultParts(msg.Content); pruned {
		msg.Content = parts
		return msg, true
	}

	legacyText := msg.TextContent()
	if exceedsTextBudget(legacyText) {
		trimmed := pruneStringEdges(
			legacyText,
			gatewayToolResultHeadBytes,
			gatewayToolResultTailBytes,
			gatewayToolResultHeadLines,
			gatewayToolResultTailLines,
			"tool result",
		)
		msg.Content = conversation.NewTextContent(trimmed)
		return msg, true
	}
	return msg, false
}
// pruneToolResultParts treats content as a JSON array of message parts and
// prunes the "output" of every part whose "type" is "tool-result".
//
// Parts that cannot be decoded, are of another type, lack an output, or whose
// output needs no pruning are carried over verbatim. It returns the re-encoded
// array and true when at least one part was pruned. It returns nil and false
// when content is empty, is not a non-empty JSON array, or nothing changed.
func pruneToolResultParts(content json.RawMessage) (json.RawMessage, bool) {
	if len(content) == 0 {
		return nil, false
	}
	var rawParts []json.RawMessage
	if json.Unmarshal(content, &rawParts) != nil || len(rawParts) == 0 {
		return nil, false
	}

	// rewrite returns a replacement encoding for one part when it is a
	// tool-result whose output was pruned; otherwise it reports false.
	rewrite := func(raw json.RawMessage) (json.RawMessage, bool) {
		var fields map[string]json.RawMessage
		if json.Unmarshal(raw, &fields) != nil {
			return nil, false
		}
		typeRaw, found := fields["type"]
		if !found {
			return nil, false
		}
		var kind string
		if json.Unmarshal(typeRaw, &kind) != nil || kind != "tool-result" {
			return nil, false
		}
		outputRaw, found := fields["output"]
		if !found {
			return nil, false
		}
		prunedOutput, ok := pruneToolOutput(outputRaw)
		if !ok {
			return nil, false
		}
		fields["output"] = prunedOutput
		encoded, err := json.Marshal(fields)
		if err != nil {
			return nil, false
		}
		return encoded, true
	}

	anyPruned := false
	result := make([]json.RawMessage, len(rawParts))
	for i, raw := range rawParts {
		if replacement, ok := rewrite(raw); ok {
			result[i] = replacement
			anyPruned = true
		} else {
			result[i] = raw
		}
	}
	if !anyPruned {
		return nil, false
	}

	encoded, err := json.Marshal(result)
	if err != nil {
		return nil, false
	}
	return encoded, true
}
// pruneToolOutput prunes the "value" of a single tool-result output object
// when it exceeds the gateway text budget.
//
// Output shapes handled, keyed by the object's "type" field:
//   - "text", "error-text": value is a JSON string; the string itself is pruned.
//   - "json", "error-json": value is arbitrary JSON. Its raw encoding is pruned
//     and stored back as a JSON string, so the result is a truncated textual
//     rendering rather than the original structure.
//   - "content": value is an array of items. Each {"type":"text"} item whose
//     "text" exceeds the budget is pruned; all other items and fields are left
//     as they were.
//
// It returns the re-encoded output object and true when something was pruned.
// It returns nil and false when the object is malformed, has an unknown type,
// has no "value", is within budget, or cannot be re-encoded. Callers then keep
// the original output, so the schema is never broken.
func pruneToolOutput(raw json.RawMessage) (json.RawMessage, bool) {
	var output map[string]json.RawMessage
	if err := json.Unmarshal(raw, &output); err != nil {
		return nil, false
	}
	var outputType string
	if typeRaw, ok := output["type"]; !ok || json.Unmarshal(typeRaw, &outputType) != nil {
		return nil, false
	}
	// Every supported shape carries its payload under "value".
	valueRaw, hasValue := output["value"]
	if !hasValue {
		return nil, false
	}

	// replaceValue stores newValue under "value" and re-encodes the whole
	// output object, leaving every other field untouched.
	replaceValue := func(newValue any) (json.RawMessage, bool) {
		encodedValue, err := json.Marshal(newValue)
		if err != nil {
			return nil, false
		}
		output["value"] = encodedValue
		rebuilt, err := json.Marshal(output)
		if err != nil {
			return nil, false
		}
		return rebuilt, true
	}

	switch outputType {
	case "text", "error-text":
		var s string
		if err := json.Unmarshal(valueRaw, &s); err != nil || !exceedsTextBudget(s) {
			return nil, false
		}
		return replaceValue(pruneStringEdges(
			s,
			gatewayToolResultHeadBytes,
			gatewayToolResultTailBytes,
			gatewayToolResultHeadLines,
			gatewayToolResultTailLines,
			"tool result",
		))

	case "json", "error-json":
		if !exceedsTextBudget(string(valueRaw)) {
			return nil, false
		}
		return replaceValue(pruneStringEdges(
			string(valueRaw),
			gatewayToolResultHeadBytes,
			gatewayToolResultTailBytes,
			gatewayToolResultHeadLines,
			gatewayToolResultTailLines,
			"tool result (json)",
		))

	case "content":
		// Best-effort: prune large text items inside the content array; on any
		// decode failure keep the original output. Items are decoded field by
		// field as raw JSON so that everything except a pruned "text" round-trips
		// verbatim. Decoding into map[string]any would turn every number into
		// float64 and could silently corrupt large integers on re-encode.
		var items []map[string]json.RawMessage
		if err := json.Unmarshal(valueRaw, &items); err != nil {
			return nil, false
		}
		didPrune := false
		for _, item := range items {
			var itemType string
			if json.Unmarshal(item["type"], &itemType) != nil || itemType != "text" {
				continue
			}
			var text string
			if json.Unmarshal(item["text"], &text) != nil || !exceedsTextBudget(text) {
				continue
			}
			encodedText, err := json.Marshal(pruneStringEdges(
				text,
				gatewayToolResultHeadBytes,
				gatewayToolResultTailBytes,
				gatewayToolResultHeadLines,
				gatewayToolResultTailLines,
				"tool result (content)",
			))
			if err != nil {
				return nil, false
			}
			// item is a map, so this updates the element inside items.
			item["text"] = encodedText
			didPrune = true
		}
		if !didPrune {
			return nil, false
		}
		return replaceValue(items)

	default:
		return nil, false
	}
}
// pruneStringEdges prunes s with textprune.PruneWithEdges under the gateway
// payload budget and marker. It keeps the requested head/tail windows
// (in bytes and lines) and tags the result with label.
func pruneStringEdges(s string, headBytes, tailBytes, headLines, tailLines int, label string) string {
	cfg := textprune.Config{
		MaxBytes:  gatewayToolPayloadMaxBytes,
		MaxLines:  gatewayToolPayloadMaxLines,
		Marker:    gatewayToolPayloadPrunedMarker,
		HeadBytes: headBytes,
		HeadLines: headLines,
		TailBytes: tailBytes,
		TailLines: tailLines,
	}
	return textprune.PruneWithEdges(s, label, cfg)
}
func exceedsTextBudget(s string) bool {
return textprune.Exceeds(s, gatewayToolPayloadMaxBytes, gatewayToolPayloadMaxLines)
}