mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
9461f923df
* fix(agent): emit chunked SSE data fix(flow): reassemble chunked SSE and prune tool payloads fix: avoid whitespace prune bypass; optimize chunked SSE builder * refactor: LLM provider pruning use shared textprune library * chore: smaller range
320 lines
7.5 KiB
Go
320 lines
7.5 KiB
Go
package flow
|
|
|
|
import (
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/memohai/memoh/internal/conversation"
|
|
textprune "github.com/memohai/memoh/internal/prune"
|
|
)
|
|
|
|
const (
|
|
// Prune long tool payloads per message to keep gateway requests within provider limits,
|
|
// while preserving as much surrounding context as possible.
|
|
gatewayToolPayloadMaxBytes = textprune.DefaultMaxBytes
|
|
gatewayToolPayloadMaxLines = textprune.DefaultMaxLines
|
|
|
|
gatewayToolResultHeadBytes = 6 * 1024
|
|
gatewayToolResultTailBytes = 2 * 1024
|
|
gatewayToolResultHeadLines = 180
|
|
gatewayToolResultTailLines = 50
|
|
|
|
gatewayToolArgsHeadBytes = 4 * 1024
|
|
gatewayToolArgsTailBytes = 2 * 1024
|
|
gatewayToolArgsHeadLines = 180
|
|
gatewayToolArgsTailLines = 50
|
|
|
|
gatewayToolPayloadPrunedMarker = textprune.DefaultMarker
|
|
)
|
|
|
|
func pruneHistoryForGateway(messages []messageWithUsage) []messageWithUsage {
|
|
if len(messages) == 0 {
|
|
return messages
|
|
}
|
|
out := make([]messageWithUsage, 0, len(messages))
|
|
staleUsage := false
|
|
for _, item := range messages {
|
|
msg, changed := pruneMessageForGateway(item.Message)
|
|
if changed {
|
|
item.Message = msg
|
|
staleUsage = true
|
|
}
|
|
if staleUsage {
|
|
item.UsageInputTokens = nil
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func pruneMessagesForGateway(messages []conversation.ModelMessage) []conversation.ModelMessage {
|
|
if len(messages) == 0 {
|
|
return messages
|
|
}
|
|
out := make([]conversation.ModelMessage, 0, len(messages))
|
|
for _, msg := range messages {
|
|
pruned, _ := pruneMessageForGateway(msg)
|
|
out = append(out, pruned)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func pruneMessageForGateway(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
|
|
changed := false
|
|
if strings.EqualFold(strings.TrimSpace(msg.Role), "tool") {
|
|
msg2, did := pruneToolMessage(msg)
|
|
if did {
|
|
msg = msg2
|
|
changed = true
|
|
}
|
|
}
|
|
if len(msg.ToolCalls) > 0 {
|
|
calls, did := pruneToolCalls(msg.ToolCalls)
|
|
if did {
|
|
msg.ToolCalls = calls
|
|
changed = true
|
|
}
|
|
}
|
|
return msg, changed
|
|
}
|
|
|
|
func pruneToolCalls(calls []conversation.ToolCall) ([]conversation.ToolCall, bool) {
|
|
changed := false
|
|
out := make([]conversation.ToolCall, len(calls))
|
|
for i, call := range calls {
|
|
out[i] = call
|
|
args := call.Function.Arguments
|
|
if args == "" || !exceedsTextBudget(args) {
|
|
continue
|
|
}
|
|
out[i].Function.Arguments = pruneStringEdges(
|
|
args,
|
|
gatewayToolArgsHeadBytes,
|
|
gatewayToolArgsTailBytes,
|
|
gatewayToolArgsHeadLines,
|
|
gatewayToolArgsTailLines,
|
|
"tool arguments",
|
|
)
|
|
changed = true
|
|
}
|
|
return out, changed
|
|
}
|
|
|
|
func pruneToolMessage(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
|
|
// Vercel AI SDK schema requires tool messages to carry an array of tool-result parts.
|
|
// Prune outputs inside those parts (preserving shape) so the gateway prompt remains valid.
|
|
if pruned, ok := pruneToolResultParts(msg.Content); ok {
|
|
msg.Content = pruned
|
|
return msg, true
|
|
}
|
|
|
|
// Backward-compat: tool messages may have been persisted as plain strings.
|
|
text := msg.TextContent()
|
|
if !exceedsTextBudget(text) {
|
|
return msg, false
|
|
}
|
|
msg.Content = conversation.NewTextContent(pruneStringEdges(
|
|
text,
|
|
gatewayToolResultHeadBytes,
|
|
gatewayToolResultTailBytes,
|
|
gatewayToolResultHeadLines,
|
|
gatewayToolResultTailLines,
|
|
"tool result",
|
|
))
|
|
return msg, true
|
|
}
|
|
|
|
func pruneToolResultParts(content json.RawMessage) (json.RawMessage, bool) {
|
|
if len(content) == 0 {
|
|
return nil, false
|
|
}
|
|
var parts []json.RawMessage
|
|
if err := json.Unmarshal(content, &parts); err != nil || len(parts) == 0 {
|
|
return nil, false
|
|
}
|
|
|
|
changed := false
|
|
out := make([]json.RawMessage, 0, len(parts))
|
|
for _, raw := range parts {
|
|
var part map[string]json.RawMessage
|
|
if err := json.Unmarshal(raw, &part); err != nil {
|
|
out = append(out, raw)
|
|
continue
|
|
}
|
|
|
|
partTypeRaw, ok := part["type"]
|
|
if !ok {
|
|
out = append(out, raw)
|
|
continue
|
|
}
|
|
var partType string
|
|
if err := json.Unmarshal(partTypeRaw, &partType); err != nil || partType != "tool-result" {
|
|
out = append(out, raw)
|
|
continue
|
|
}
|
|
|
|
outputRaw, ok := part["output"]
|
|
if !ok {
|
|
out = append(out, raw)
|
|
continue
|
|
}
|
|
pruned, didPrune := pruneToolOutput(outputRaw)
|
|
if !didPrune {
|
|
out = append(out, raw)
|
|
continue
|
|
}
|
|
|
|
part["output"] = pruned
|
|
rebuilt, err := json.Marshal(part)
|
|
if err != nil {
|
|
out = append(out, raw)
|
|
continue
|
|
}
|
|
out = append(out, json.RawMessage(rebuilt))
|
|
changed = true
|
|
}
|
|
|
|
if !changed {
|
|
return nil, false
|
|
}
|
|
rebuilt, err := json.Marshal(out)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
return json.RawMessage(rebuilt), true
|
|
}
|
|
|
|
func pruneToolOutput(raw json.RawMessage) (json.RawMessage, bool) {
|
|
var output map[string]json.RawMessage
|
|
if err := json.Unmarshal(raw, &output); err != nil {
|
|
return nil, false
|
|
}
|
|
outputTypeRaw, ok := output["type"]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
var outputType string
|
|
if err := json.Unmarshal(outputTypeRaw, &outputType); err != nil {
|
|
return nil, false
|
|
}
|
|
valueRaw, hasValue := output["value"]
|
|
|
|
switch outputType {
|
|
case "text", "error-text":
|
|
if !hasValue {
|
|
return nil, false
|
|
}
|
|
var s string
|
|
if err := json.Unmarshal(valueRaw, &s); err != nil || !exceedsTextBudget(s) {
|
|
return nil, false
|
|
}
|
|
s = pruneStringEdges(
|
|
s,
|
|
gatewayToolResultHeadBytes,
|
|
gatewayToolResultTailBytes,
|
|
gatewayToolResultHeadLines,
|
|
gatewayToolResultTailLines,
|
|
"tool result",
|
|
)
|
|
data, err := json.Marshal(s)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
output["value"] = data
|
|
rebuilt, err := json.Marshal(output)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
return json.RawMessage(rebuilt), true
|
|
|
|
case "json", "error-json":
|
|
if !hasValue || !exceedsTextBudget(string(valueRaw)) {
|
|
return nil, false
|
|
}
|
|
pruned := pruneStringEdges(
|
|
string(valueRaw),
|
|
gatewayToolResultHeadBytes,
|
|
gatewayToolResultTailBytes,
|
|
gatewayToolResultHeadLines,
|
|
gatewayToolResultTailLines,
|
|
"tool result (json)",
|
|
)
|
|
data, err := json.Marshal(pruned)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
output["value"] = data
|
|
rebuilt, err := json.Marshal(output)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
return json.RawMessage(rebuilt), true
|
|
|
|
case "content":
|
|
// Best-effort: prune any large text items inside the content array.
|
|
// If parsing fails, keep the original output to avoid breaking schema.
|
|
if !hasValue {
|
|
return nil, false
|
|
}
|
|
var items []map[string]any
|
|
if err := json.Unmarshal(valueRaw, &items); err != nil {
|
|
return nil, false
|
|
}
|
|
didPrune := false
|
|
for i := range items {
|
|
if items[i]["type"] != "text" {
|
|
continue
|
|
}
|
|
textAny, ok := items[i]["text"]
|
|
if !ok {
|
|
continue
|
|
}
|
|
text, ok := textAny.(string)
|
|
if !ok || !exceedsTextBudget(text) {
|
|
continue
|
|
}
|
|
items[i]["text"] = pruneStringEdges(
|
|
text,
|
|
gatewayToolResultHeadBytes,
|
|
gatewayToolResultTailBytes,
|
|
gatewayToolResultHeadLines,
|
|
gatewayToolResultTailLines,
|
|
"tool result (content)",
|
|
)
|
|
didPrune = true
|
|
}
|
|
if !didPrune {
|
|
return nil, false
|
|
}
|
|
data, err := json.Marshal(items)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
output["value"] = data
|
|
rebuilt, err := json.Marshal(output)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
return json.RawMessage(rebuilt), true
|
|
|
|
default:
|
|
return nil, false
|
|
}
|
|
}
|
|
|
|
func pruneStringEdges(s string, headBytes, tailBytes, headLines, tailLines int, label string) string {
|
|
return textprune.PruneWithEdges(s, label, textprune.Config{
|
|
MaxBytes: gatewayToolPayloadMaxBytes,
|
|
MaxLines: gatewayToolPayloadMaxLines,
|
|
HeadBytes: headBytes,
|
|
TailBytes: tailBytes,
|
|
HeadLines: headLines,
|
|
TailLines: tailLines,
|
|
Marker: gatewayToolPayloadPrunedMarker,
|
|
})
|
|
}
|
|
|
|
func exceedsTextBudget(s string) bool {
|
|
return textprune.Exceeds(s, gatewayToolPayloadMaxBytes, gatewayToolPayloadMaxLines)
|
|
}
|