fix(flow): stabilize chunked SSE and unify prune limits for read/exec/gateway (#71)

* fix(agent): emit chunked SSE data

fix(flow): reassemble chunked SSE and prune tool payloads

fix: avoid whitespace prune bypass; optimize chunked SSE builder

* refactor: LLM provider pruning use shared textprune library

* chore: smaller range
This commit is contained in:
Ringo.Typowriter
2026-02-21 17:06:02 +08:00
committed by GitHub
parent 2de8095c75
commit 9461f923df
11 changed files with 1160 additions and 59 deletions
+4 -3
View File
@@ -1,4 +1,4 @@
import { Elysia, sse } from 'elysia'
import { Elysia } from 'elysia'
import z from 'zod'
import { createAgent } from '../agent'
import { createAuthFetcher, getBaseUrl } from '../index'
@@ -6,6 +6,7 @@ import { ModelConfig } from '../types'
import { bearerMiddleware } from '../middlewares/bearer'
import { AgentSkillModel, AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models'
import { allActions } from '../types'
import { sseChunked } from '../utils/sse'
const AgentModel = z.object({
model: ModelConfigModel,
@@ -75,14 +76,14 @@ export const chatModule = new Elysia({ prefix: '/chat' })
skills: body.skills,
attachments: body.attachments,
})) {
yield sse(JSON.stringify(action))
yield sseChunked(JSON.stringify(action))
}
} catch (error) {
console.error(error)
const message = error instanceof Error && error.message.trim()
? error.message
: 'Internal server error'
yield sse(JSON.stringify({
yield sseChunked(JSON.stringify({
type: 'error',
message,
}))
+45
View File
@@ -0,0 +1,45 @@
import { describe, expect, test } from 'bun:test'
import { sseChunked } from '../utils/sse'
// Reassemble the original payload from a chunked SSE event by concatenating
// the content of every `data:` line (everything after the prefix, unmodified —
// no trimming, so leading whitespace in the payload survives).
function parseChunkedSSE(payload: string): string {
  const prefix = 'data:'
  let rebuilt = ''
  for (const line of payload.split('\n')) {
    if (line.startsWith(prefix)) {
      rebuilt += line.slice(prefix.length)
    }
  }
  return rebuilt
}
describe('sseChunked', () => {
  test('reconstructs original payload losslessly', () => {
    // include whitespace and unicode so trimming/surrogate splitting bugs show up
    const original = JSON.stringify({
      type: 'tool_call_end',
      toolName: 'big_tool',
      toolCallId: 'call-1',
      result: ' leading spaces\tand tabs\nand unicode 😀😃😄 ',
      blob: 'x'.repeat(200_000),
    })
    const roundTripped = parseChunkedSSE(sseChunked(original, 1024).toSSE())
    expect(roundTripped).toBe(original)
  })

  test('chunkSize=1 does not produce invalid UTF-8 (surrogate pairs)', () => {
    const original = `😀${'x'.repeat(1000)}😃`
    const wire = sseChunked(original, 1).toSSE()
    // Simulate the UTF-8 encode/decode step that happens over the network.
    const decoded = new TextDecoder().decode(new TextEncoder().encode(wire))
    expect(decoded).toBe(wire)
    expect(parseChunkedSSE(decoded)).toBe(original)
  })

  test('does not inject an extra space after data:', () => {
    const firstLine = sseChunked(' abc', 2).toSSE().split('\n')[0]
    expect(firstLine).toBe('data: a')
  })
})
+50
View File
@@ -0,0 +1,50 @@
// Default per-line chunk size for SSE data payloads.
export const defaultSSEChunkSize = 16 * 1024

/**
 * Wrap a payload so it is emitted as a single SSE event whose data is split
 * across multiple `data:` lines of at most `chunkSize` code units each.
 * Receivers reconstruct the payload by concatenating the `data:` line
 * contents. Note the deliberate absence of a space after `data:` so leading
 * whitespace in the payload survives the round trip.
 */
export function sseChunked(data: string, chunkSize: number = defaultSSEChunkSize) {
  return {
    sse: true as const,
    toSSE: () => {
      let event = ''
      for (const piece of chunkString(data, chunkSize)) {
        event += `data:${piece}\n`
      }
      // Blank line terminates the SSE event.
      return `${event}\n`
    },
  }
}
/**
 * Split `input` into pieces of at most `maxLen` UTF-16 code units without ever
 * separating a surrogate pair (splitting one would produce invalid UTF-8 on
 * the wire). A non-positive `maxLen` yields the whole input as a single piece.
 * A piece may exceed `maxLen` by one unit when that is required to keep a
 * surrogate pair together (e.g. `maxLen === 1` landing on an emoji).
 */
export function* chunkString(input: string, maxLen: number): Generator<string> {
  if (maxLen <= 0) {
    yield input
    return
  }
  const highSurrogate = (code: number) => code >= 0xd800 && code <= 0xdbff
  const lowSurrogate = (code: number) => code >= 0xdc00 && code <= 0xdfff
  let start = 0
  while (start < input.length) {
    let cut = Math.min(start + maxLen, input.length)
    if (cut < input.length && highSurrogate(input.charCodeAt(cut - 1))) {
      // The boundary would split a pair: either pull the low surrogate in,
      // or push the lone high surrogate into the next piece.
      cut += lowSurrogate(input.charCodeAt(cut)) ? 1 : -1
    }
    if (cut <= start) {
      // The adjustment above collapsed the piece (maxLen=1 on a surrogate
      // pair): emit at least one full code point to guarantee progress.
      const isPair =
        highSurrogate(input.charCodeAt(start)) &&
        start + 1 < input.length &&
        lowSurrogate(input.charCodeAt(start + 1))
      cut = Math.min(start + (isPair ? 2 : 1), input.length)
    }
    yield input.slice(start, cut)
    start = cut
  }
}
+319
View File
@@ -0,0 +1,319 @@
package flow
import (
"encoding/json"
"strings"
"github.com/memohai/memoh/internal/conversation"
textprune "github.com/memohai/memoh/internal/prune"
)
const (
	// Prune long tool payloads per message to keep gateway requests within provider limits,
	// while preserving as much surrounding context as possible.
	gatewayToolPayloadMaxBytes = textprune.DefaultMaxBytes
	gatewayToolPayloadMaxLines = textprune.DefaultMaxLines

	// Head/tail excerpt budgets for tool results (message content and
	// structured tool-result outputs).
	gatewayToolResultHeadBytes = 6 * 1024
	gatewayToolResultTailBytes = 2 * 1024
	gatewayToolResultHeadLines = 180
	gatewayToolResultTailLines = 50

	// Head/tail excerpt budgets for tool-call arguments.
	gatewayToolArgsHeadBytes = 4 * 1024
	gatewayToolArgsTailBytes = 2 * 1024
	gatewayToolArgsHeadLines = 180
	gatewayToolArgsTailLines = 50

	// Marker prepended to pruned payloads so truncation is visible downstream.
	gatewayToolPayloadPrunedMarker = textprune.DefaultMarker
)
// pruneHistoryForGateway prunes oversized tool payloads in stored history.
// Once any message is pruned, cached input-token counts for that message and
// every later one are cleared: they were measured against the unpruned text
// and would misstate the real prompt size from that point on.
func pruneHistoryForGateway(messages []messageWithUsage) []messageWithUsage {
	if len(messages) == 0 {
		return messages
	}
	pruned := make([]messageWithUsage, 0, len(messages))
	stale := false
	for _, entry := range messages {
		if msg, changed := pruneMessageForGateway(entry.Message); changed {
			entry.Message = msg
			stale = true
		}
		if stale {
			entry.UsageInputTokens = nil
		}
		pruned = append(pruned, entry)
	}
	return pruned
}
// pruneMessagesForGateway applies per-message payload pruning to a slice of
// model messages, discarding the "changed" signal callers don't need here.
func pruneMessagesForGateway(messages []conversation.ModelMessage) []conversation.ModelMessage {
	if len(messages) == 0 {
		return messages
	}
	result := make([]conversation.ModelMessage, 0, len(messages))
	for _, m := range messages {
		prunedMsg, _ := pruneMessageForGateway(m)
		result = append(result, prunedMsg)
	}
	return result
}
// pruneMessageForGateway prunes a single message: tool-role content bodies and
// any oversized tool-call arguments. Reports whether anything was modified.
func pruneMessageForGateway(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
	changed := false
	if strings.EqualFold(strings.TrimSpace(msg.Role), "tool") {
		if pruned, did := pruneToolMessage(msg); did {
			msg = pruned
			changed = true
		}
	}
	if len(msg.ToolCalls) > 0 {
		if calls, did := pruneToolCalls(msg.ToolCalls); did {
			msg.ToolCalls = calls
			changed = true
		}
	}
	return msg, changed
}
// pruneToolCalls returns a copy of calls with oversized function arguments
// replaced by a head/tail excerpt. The second result reports whether any
// arguments were pruned. The input slice is never mutated.
func pruneToolCalls(calls []conversation.ToolCall) ([]conversation.ToolCall, bool) {
	pruned := make([]conversation.ToolCall, len(calls))
	copy(pruned, calls)
	changed := false
	for i := range pruned {
		args := pruned[i].Function.Arguments
		if args == "" || !exceedsTextBudget(args) {
			continue
		}
		pruned[i].Function.Arguments = pruneStringEdges(
			args,
			gatewayToolArgsHeadBytes,
			gatewayToolArgsTailBytes,
			gatewayToolArgsHeadLines,
			gatewayToolArgsTailLines,
			"tool arguments",
		)
		changed = true
	}
	return pruned, changed
}
// pruneToolMessage prunes the payload of a tool-role message.
//
// Vercel AI SDK schema requires tool messages to carry an array of tool-result
// parts, so structured content is pruned in place (shape preserved) to keep
// the gateway prompt valid. Tool messages persisted as plain strings
// (backward-compat) are pruned as raw text instead.
func pruneToolMessage(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
	if pruned, ok := pruneToolResultParts(msg.Content); ok {
		msg.Content = pruned
		return msg, true
	}
	text := msg.TextContent()
	if !exceedsTextBudget(text) {
		return msg, false
	}
	excerpt := pruneStringEdges(
		text,
		gatewayToolResultHeadBytes,
		gatewayToolResultTailBytes,
		gatewayToolResultHeadLines,
		gatewayToolResultTailLines,
		"tool result",
	)
	msg.Content = conversation.NewTextContent(excerpt)
	return msg, true
}
// pruneToolResultParts prunes oversized outputs inside a tool message whose
// content is a JSON array of parts, preserving the array shape and all
// unrelated fields. It returns (nil, false) when content is not a parts array
// or when no part needed pruning, so callers can fall back to plain-text
// pruning.
func pruneToolResultParts(content json.RawMessage) (json.RawMessage, bool) {
	if len(content) == 0 {
		return nil, false
	}
	var parts []json.RawMessage
	if err := json.Unmarshal(content, &parts); err != nil || len(parts) == 0 {
		return nil, false
	}
	changed := false
	out := make([]json.RawMessage, 0, len(parts))
	for _, raw := range parts {
		// Decode each part shallowly (field -> raw JSON) so sibling fields
		// like providerOptions survive re-marshaling untouched.
		var part map[string]json.RawMessage
		if err := json.Unmarshal(raw, &part); err != nil {
			out = append(out, raw)
			continue
		}
		partTypeRaw, ok := part["type"]
		if !ok {
			out = append(out, raw)
			continue
		}
		var partType string
		// Only "tool-result" parts carry prunable outputs; pass others through.
		if err := json.Unmarshal(partTypeRaw, &partType); err != nil || partType != "tool-result" {
			out = append(out, raw)
			continue
		}
		outputRaw, ok := part["output"]
		if !ok {
			out = append(out, raw)
			continue
		}
		pruned, didPrune := pruneToolOutput(outputRaw)
		if !didPrune {
			out = append(out, raw)
			continue
		}
		part["output"] = pruned
		rebuilt, err := json.Marshal(part)
		if err != nil {
			// On re-marshal failure keep the original bytes rather than drop
			// the part.
			out = append(out, raw)
			continue
		}
		out = append(out, json.RawMessage(rebuilt))
		changed = true
	}
	if !changed {
		return nil, false
	}
	rebuilt, err := json.Marshal(out)
	if err != nil {
		return nil, false
	}
	return json.RawMessage(rebuilt), true
}
// pruneToolOutput prunes the "value" of a tool-result output object when it
// exceeds the gateway text budget, keeping the object's shape and any extra
// fields intact. Returns (nil, false) whenever the output cannot be parsed,
// has an unknown type, or did not need pruning — callers then keep the
// original bytes.
func pruneToolOutput(raw json.RawMessage) (json.RawMessage, bool) {
	var output map[string]json.RawMessage
	if err := json.Unmarshal(raw, &output); err != nil {
		return nil, false
	}
	outputTypeRaw, ok := output["type"]
	if !ok {
		return nil, false
	}
	var outputType string
	if err := json.Unmarshal(outputTypeRaw, &outputType); err != nil {
		return nil, false
	}
	valueRaw, hasValue := output["value"]
	switch outputType {
	case "text", "error-text":
		// Plain string value: prune the decoded string, re-encode it.
		if !hasValue {
			return nil, false
		}
		var s string
		if err := json.Unmarshal(valueRaw, &s); err != nil || !exceedsTextBudget(s) {
			return nil, false
		}
		s = pruneStringEdges(
			s,
			gatewayToolResultHeadBytes,
			gatewayToolResultTailBytes,
			gatewayToolResultHeadLines,
			gatewayToolResultTailLines,
			"tool result",
		)
		data, err := json.Marshal(s)
		if err != nil {
			return nil, false
		}
		output["value"] = data
		rebuilt, err := json.Marshal(output)
		if err != nil {
			return nil, false
		}
		return json.RawMessage(rebuilt), true
	case "json", "error-json":
		// Prune the raw JSON text itself. The pruned excerpt is no longer
		// valid JSON, so it is re-encoded as a JSON *string* value.
		if !hasValue || !exceedsTextBudget(string(valueRaw)) {
			return nil, false
		}
		pruned := pruneStringEdges(
			string(valueRaw),
			gatewayToolResultHeadBytes,
			gatewayToolResultTailBytes,
			gatewayToolResultHeadLines,
			gatewayToolResultTailLines,
			"tool result (json)",
		)
		data, err := json.Marshal(pruned)
		if err != nil {
			return nil, false
		}
		output["value"] = data
		rebuilt, err := json.Marshal(output)
		if err != nil {
			return nil, false
		}
		return json.RawMessage(rebuilt), true
	case "content":
		// Best-effort: prune any large text items inside the content array.
		// If parsing fails, keep the original output to avoid breaking schema.
		if !hasValue {
			return nil, false
		}
		var items []map[string]any
		if err := json.Unmarshal(valueRaw, &items); err != nil {
			return nil, false
		}
		didPrune := false
		for i := range items {
			if items[i]["type"] != "text" {
				continue
			}
			textAny, ok := items[i]["text"]
			if !ok {
				continue
			}
			text, ok := textAny.(string)
			if !ok || !exceedsTextBudget(text) {
				continue
			}
			items[i]["text"] = pruneStringEdges(
				text,
				gatewayToolResultHeadBytes,
				gatewayToolResultTailBytes,
				gatewayToolResultHeadLines,
				gatewayToolResultTailLines,
				"tool result (content)",
			)
			didPrune = true
		}
		if !didPrune {
			return nil, false
		}
		data, err := json.Marshal(items)
		if err != nil {
			return nil, false
		}
		output["value"] = data
		rebuilt, err := json.Marshal(output)
		if err != nil {
			return nil, false
		}
		return json.RawMessage(rebuilt), true
	default:
		// Unknown output types are left untouched.
		return nil, false
	}
}
// pruneStringEdges produces a head/tail excerpt of s via the shared textprune
// helper, combining the gateway-wide byte/line budget and pruned marker with
// the per-payload head/tail limits supplied by the caller. label identifies
// the payload kind in the pruned-output banner.
func pruneStringEdges(s string, headBytes, tailBytes, headLines, tailLines int, label string) string {
	cfg := textprune.Config{
		MaxBytes:  gatewayToolPayloadMaxBytes,
		MaxLines:  gatewayToolPayloadMaxLines,
		HeadBytes: headBytes,
		TailBytes: tailBytes,
		HeadLines: headLines,
		TailLines: tailLines,
		Marker:    gatewayToolPayloadPrunedMarker,
	}
	return textprune.PruneWithEdges(s, label, cfg)
}
// exceedsTextBudget reports whether s is over the gateway-wide per-payload
// byte or line budget and therefore needs pruning.
func exceedsTextBudget(s string) bool {
	return textprune.Exceeds(s, gatewayToolPayloadMaxBytes, gatewayToolPayloadMaxLines)
}
+79 -52
View File
@@ -2,9 +2,11 @@ package flow
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -34,6 +36,12 @@ const (
sharedMemoryNamespace = "bot"
// Keep gateway payload bounded when inlining binary attachments as data URLs.
gatewayInlineAttachmentMaxBytes int64 = 20 * 1024 * 1024
// SSE payloads (especially attachment/tool results) can be very large.
// bufio.Scanner hard-fails with "token too long" if a single line exceeds its max token size.
// Use a reader-based parser and enforce an explicit per-line cap here. The agent gateway
// stream is expected to chunk large JSON payloads across multiple SSE "data:" lines, so
// this limit should stay relatively small.
gatewaySSEMaxLineBytes = 256 * 1024
)
// SkillEntry represents a skill loaded from the container.
@@ -255,11 +263,16 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
// Build non-history parts first so we can reserve their token cost before
// trimming history messages.
memoryMsg := r.loadMemoryContextMessage(ctx, req)
reqMessages := pruneMessagesForGateway(nonNilModelMessages(req.Messages))
if memoryMsg != nil {
pruned, _ := pruneMessageForGateway(*memoryMsg)
memoryMsg = &pruned
}
var overhead int
if memoryMsg != nil {
overhead += estimateMessageTokens(*memoryMsg)
}
for _, m := range req.Messages {
for _, m := range reqMessages {
overhead += estimateMessageTokens(m)
}
// Reserve space for the system prompt built by the agent gateway
@@ -278,12 +291,13 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
if loadErr != nil {
return resolvedContext{}, loadErr
}
loaded = pruneHistoryForGateway(loaded)
messages = trimMessagesByTokens(loaded, historyBudget)
}
if memoryMsg != nil {
messages = append(messages, *memoryMsg)
}
messages = append(messages, req.Messages...)
messages = append(messages, reqMessages...)
messages = sanitizeMessages(messages)
skills := dedup(req.Skills)
containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID)
@@ -580,39 +594,66 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req c
return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
currentEvent := ""
stored := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
if strings.HasPrefix(line, "event:") {
currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
continue
}
chunkCh <- conversation.StreamChunk([]byte(data))
var dataBuf bytes.Buffer
if stored {
flushEvent := func() error {
if dataBuf.Len() == 0 {
return nil
}
out := append([]byte(nil), dataBuf.Bytes()...)
dataBuf.Reset()
if len(out) == 0 || bytes.Equal(bytes.TrimSpace(out), []byte("[DONE]")) {
return nil
}
// Persist final messages before forwarding the "done"/"agent_end" event so the
// next user turn can immediately see the assistant output in history.
if !stored {
if handled, storeErr := r.tryStoreStream(ctx, req, out); storeErr != nil {
return storeErr
} else if handled {
stored = true
}
}
chunkCh <- conversation.StreamChunk(out)
return nil
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), gatewaySSEMaxLineBytes)
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
if err := flushEvent(); err != nil {
return err
}
continue
}
if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil {
return storeErr
} else if handled {
stored = true
if len(line) > 0 && line[0] == ':' {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
part := bytes.TrimPrefix(line, []byte("data:"))
// Backward-compat: older SSE writers used "data: <payload>" (note the space).
// Only strip the first leading space for the *first* fragment to avoid corrupting
// chunked payloads split inside JSON string values.
if dataBuf.Len() == 0 && len(part) > 0 && part[0] == ' ' {
part = part[1:]
}
if len(part) == 0 {
continue
}
_, _ = dataBuf.Write(part)
}
return scanner.Err()
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
return fmt.Errorf("sse line too long (max %d bytes)", gatewaySSEMaxLineBytes)
}
return err
}
return flushEvent()
}
func newJSONRequestWithContext(ctx context.Context, method, url string, payload any) (*http.Request, error) {
@@ -631,24 +672,15 @@ func newJSONRequestWithContext(ctx context.Context, method, url string, payload
}
// tryStoreStream attempts to extract final messages from a stream event and persist them.
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, eventType, data string) (bool, error) {
// event: done + data: {messages: [...]}
if eventType == "done" {
var resp gatewayResponse
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
return true, r.storeRound(ctx, req, resp.Messages, resp.Usage)
}
}
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte) (bool, error) {
// data: {"type":"text_delta"|"agent_end"|"done", ...}
var envelope struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
Messages []conversation.ModelMessage `json:"messages"`
Skills []string `json:"skills"`
Usage json.RawMessage `json:"usage,omitempty"`
}
if err := json.Unmarshal([]byte(data), &envelope); err == nil {
if err := json.Unmarshal(data, &envelope); err == nil {
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
return true, r.storeRound(ctx, req, envelope.Messages, envelope.Usage)
}
@@ -662,7 +694,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
// fallback: data: {messages: [...]}
var resp gatewayResponse
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
if err := json.Unmarshal(data, &resp); err == nil && len(resp.Messages) > 0 {
return true, r.storeRound(ctx, req, resp.Messages, resp.Usage)
}
return false, nil
@@ -763,19 +795,14 @@ func normalizeGatewayAttachmentPayload(item gatewayAttachment) gatewayAttachment
if payload == "" {
return item
}
lower := strings.ToLower(payload)
if strings.HasPrefix(lower, "data:") {
if strings.TrimSpace(item.Mime) == "" || strings.EqualFold(strings.TrimSpace(item.Mime), "application/octet-stream") {
if start := strings.Index(payload, ":"); start >= 0 {
rest := payload[start+1:]
if end := strings.Index(rest, ";"); end > 0 {
mime := strings.TrimSpace(rest[:end])
if mime != "" {
item.Mime = mime
}
}
if strings.HasPrefix(strings.ToLower(payload), "data:") {
mime := strings.TrimSpace(item.Mime)
if mime == "" || strings.EqualFold(mime, "application/octet-stream") {
if extracted := attachmentpkg.MimeFromDataURL(payload); extracted != "" {
item.Mime = extracted
}
}
item.Payload = payload
return item
}
mime := strings.TrimSpace(item.Mime)
@@ -0,0 +1,208 @@
package flow
import (
"bytes"
"encoding/json"
"strings"
"testing"
"unicode/utf8"
"github.com/memohai/memoh/internal/conversation"
)
// TestPruneMessagesForGateway_PrunesToolResultContent verifies that an
// oversized plain-string tool result is replaced with a marked excerpt and
// that pruning at byte boundaries does not corrupt multi-byte UTF-8.
func TestPruneMessagesForGateway_PrunesToolResultContent(t *testing.T) {
	t.Parallel()
	// Multi-byte unit so a byte-boundary cut could split a rune.
	unit := "汉😀"
	huge := strings.Repeat(unit, (gatewayToolPayloadMaxBytes/len(unit))+20)
	msgs := []conversation.ModelMessage{
		{Role: "tool", Content: conversation.NewTextContent(huge), ToolCallID: "call-1"},
	}
	out := pruneMessagesForGateway(msgs)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	got := out[0].TextContent()
	if strings.Contains(got, huge) {
		t.Fatalf("expected tool content to be pruned")
	}
	if !strings.Contains(got, gatewayToolPayloadPrunedMarker) {
		t.Fatalf("expected pruned marker, got: %q", got[:minLen(len(got), 80)])
	}
	if !utf8.ValidString(got) {
		t.Fatalf("expected pruned tool content to remain valid UTF-8")
	}
}
// TestPruneMessagesForGateway_PrunesToolCallArguments verifies that oversized
// tool-call arguments on an assistant message are pruned, marked, and remain
// valid UTF-8 despite the multi-byte payload.
func TestPruneMessagesForGateway_PrunesToolCallArguments(t *testing.T) {
	t.Parallel()
	repeated := strings.Repeat("猫😺", (gatewayToolPayloadMaxBytes/len("猫😺"))+20)
	hugeArgs := `{"a":"` + repeated + `"}`
	msgs := []conversation.ModelMessage{
		{
			Role: "assistant",
			ToolCalls: []conversation.ToolCall{
				{
					ID:   "call-1",
					Type: "function",
					Function: conversation.ToolCallFunction{
						Name:      "big_tool",
						Arguments: hugeArgs,
					},
				},
			},
		},
	}
	out := pruneMessagesForGateway(msgs)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	if len(out[0].ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(out[0].ToolCalls))
	}
	got := out[0].ToolCalls[0].Function.Arguments
	if strings.Contains(got, repeated) {
		t.Fatalf("expected tool arguments to be pruned")
	}
	if !strings.Contains(got, gatewayToolPayloadPrunedMarker) {
		t.Fatalf("expected pruned marker in args")
	}
	if !utf8.ValidString(got) {
		t.Fatalf("expected pruned tool arguments to remain valid UTF-8")
	}
}
// TestPruneHistoryForGateway_ClearsStaleUsageTokensAfterPrune verifies that
// once a history message is pruned, cached input-token counts are cleared for
// that message AND every later one (later counts were measured against the
// unpruned prefix).
func TestPruneHistoryForGateway_ClearsStaleUsageTokensAfterPrune(t *testing.T) {
	t.Parallel()
	huge := strings.Repeat("汉😀", (gatewayToolPayloadMaxBytes/len("汉😀"))+20)
	firstTokens := 123
	secondTokens := 456
	in := []messageWithUsage{
		{
			Message:          conversation.ModelMessage{Role: "tool", Content: conversation.NewTextContent(huge), ToolCallID: "call-1"},
			UsageInputTokens: &firstTokens,
		},
		{
			Message:          conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("hi")},
			UsageInputTokens: &secondTokens,
		},
	}
	out := pruneHistoryForGateway(in)
	if len(out) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(out))
	}
	if out[0].UsageInputTokens != nil {
		t.Fatalf("expected first UsageInputTokens to be cleared after prune")
	}
	if out[1].UsageInputTokens != nil {
		t.Fatalf("expected subsequent UsageInputTokens to be cleared after earlier prune")
	}
}
// TestPruneHistoryForGateway_PreservesUsageTokensWhenUnchanged verifies that
// cached input-token counts survive when no message needed pruning.
func TestPruneHistoryForGateway_PreservesUsageTokensWhenUnchanged(t *testing.T) {
	t.Parallel()
	tokens := 321
	in := []messageWithUsage{
		{
			Message:          conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("short")},
			UsageInputTokens: &tokens,
		},
	}
	out := pruneHistoryForGateway(in)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	if out[0].UsageInputTokens == nil || *out[0].UsageInputTokens != tokens {
		t.Fatalf("expected UsageInputTokens to be preserved")
	}
}
// TestPruneMessagesForGateway_ToolResultPartsRemainValidToolMessageSchema
// verifies that pruning a structured tool-result part keeps the parts-array
// shape and preserves every unrelated field (providerOptions and extra fields
// on both the part and its output) while shrinking only output.value.
func TestPruneMessagesForGateway_ToolResultPartsRemainValidToolMessageSchema(t *testing.T) {
	t.Parallel()
	huge := strings.Repeat("a", gatewayToolPayloadMaxBytes+100)
	part := map[string]any{
		"type":       "tool-result",
		"toolCallId": "call-1",
		"toolName":   "big_tool",
		"providerOptions": map[string]any{
			"test-provider": map[string]any{"mode": "strict"},
		},
		"extraPart": "keep-part",
		"output": map[string]any{
			"type":  "text",
			"value": huge,
			"providerOptions": map[string]any{
				"test-provider": map[string]any{"cache": true},
			},
			"extraOutput": "keep-output",
		},
	}
	content, err := json.Marshal([]any{part})
	if err != nil {
		t.Fatalf("marshal tool content: %v", err)
	}
	msgs := []conversation.ModelMessage{
		{Role: "tool", Content: content, ToolCallID: "call-1"},
	}
	out := pruneMessagesForGateway(msgs)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	// The content must still be a JSON array of parts.
	if !bytes.HasPrefix(bytes.TrimSpace(out[0].Content), []byte("[")) {
		t.Fatalf("expected tool content to remain an array, got: %q", string(out[0].Content[:minLen(len(out[0].Content), 80)]))
	}
	if !bytes.Contains(out[0].Content, []byte(`"type":"tool-result"`)) {
		t.Fatalf("expected tool-result part to be preserved")
	}
	if !bytes.Contains(out[0].Content, []byte(gatewayToolPayloadPrunedMarker)) {
		t.Fatalf("expected pruned marker in tool output")
	}
	var parts []map[string]any
	if err := json.Unmarshal(out[0].Content, &parts); err != nil {
		t.Fatalf("unmarshal pruned tool content: %v", err)
	}
	if len(parts) != 1 {
		t.Fatalf("expected 1 part, got %d", len(parts))
	}
	// Unrelated part-level fields must survive the prune round trip.
	if parts[0]["extraPart"] != "keep-part" {
		t.Fatalf("expected extra part field preserved")
	}
	if _, ok := parts[0]["providerOptions"]; !ok {
		t.Fatalf("expected part providerOptions preserved")
	}
	outputAny, ok := parts[0]["output"].(map[string]any)
	if !ok {
		t.Fatalf("expected output object")
	}
	// Unrelated output-level fields must survive as well.
	if outputAny["extraOutput"] != "keep-output" {
		t.Fatalf("expected output extra field preserved")
	}
	if _, ok := outputAny["providerOptions"]; !ok {
		t.Fatalf("expected output providerOptions preserved")
	}
	if outputAny["type"] != "text" {
		t.Fatalf("expected output.type=text, got %v", outputAny["type"])
	}
	value, ok := outputAny["value"].(string)
	if !ok {
		t.Fatalf("expected output.value string")
	}
	if len(value) >= len(huge) {
		t.Fatalf("expected output.value to be pruned")
	}
}
// minLen returns the smaller of a and b.
func minLen(a, b int) int {
	if b < a {
		return b
	}
	return a
}
@@ -0,0 +1,139 @@
package flow
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/memohai/memoh/internal/conversation"
messagepkg "github.com/memohai/memoh/internal/message"
)
// blockingMessageService is a test double for the message service whose
// Persist blocks until released, letting tests observe ordering between
// persistence and event forwarding.
type blockingMessageService struct {
	persistCalled   chan struct{} // closed on first Persist call
	persistContinue chan struct{} // Persist blocks until this is closed
}

// Persist signals that it was called (closing persistCalled exactly once),
// then blocks until persistContinue is closed.
func (s *blockingMessageService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) {
	select {
	case <-s.persistCalled:
	default:
		close(s.persistCalled)
	}
	<-s.persistContinue
	return messagepkg.Message{}, nil
}

// The remaining methods satisfy the message-service interface with no-ops.
func (s *blockingMessageService) List(ctx context.Context, botID string) ([]messagepkg.Message, error) {
	return nil, nil
}
func (s *blockingMessageService) ListSince(ctx context.Context, botID string, since time.Time) ([]messagepkg.Message, error) {
	return nil, nil
}
func (s *blockingMessageService) ListLatest(ctx context.Context, botID string, limit int32) ([]messagepkg.Message, error) {
	return nil, nil
}
func (s *blockingMessageService) ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]messagepkg.Message, error) {
	return nil, nil
}
func (s *blockingMessageService) DeleteByBot(ctx context.Context, botID string) error {
	return nil
}
// TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent verifies that
// streamChat persists the final messages BEFORE forwarding the done event to
// the chunk channel, so the next user turn can immediately see the assistant
// output in history.
func TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent(t *testing.T) {
	t.Parallel()
	msgSvc := &blockingMessageService{
		persistCalled:   make(chan struct{}),
		persistContinue: make(chan struct{}),
	}
	doneResp := gatewayResponse{
		Messages: []conversation.ModelMessage{
			{Role: "assistant", Content: conversation.NewTextContent("ok")},
		},
	}
	doneData, err := json.Marshal(doneResp)
	if err != nil {
		t.Fatalf("marshal done response: %v", err)
	}
	// Fake gateway: emits a single "done" SSE event carrying the final messages.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/stream" {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
		_, _ = w.Write([]byte("event: done\n"))
		_, _ = w.Write([]byte("data: "))
		_, _ = w.Write(doneData)
		_, _ = w.Write([]byte("\n\n"))
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
	}))
	t.Cleanup(srv.Close)
	r := &Resolver{
		messageService:  msgSvc,
		gatewayBaseURL:  srv.URL,
		logger:          slog.New(slog.NewTextHandler(io.Discard, nil)),
		streamingClient: srv.Client(),
		httpClient:      srv.Client(),
	}
	chunkCh := make(chan conversation.StreamChunk, 10)
	req := conversation.ChatRequest{BotID: "bot-test", ChatID: "chat-test"}
	payload := gatewayRequest{}
	streamDone := make(chan error, 1)
	go func() {
		streamDone <- r.streamChat(context.Background(), payload, req, chunkCh)
		close(chunkCh)
	}()
	// Wait until persistence has started...
	select {
	case <-msgSvc.persistCalled:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for Persist to be called")
	}
	// ...and assert the done event has NOT been forwarded yet.
	select {
	case got := <-chunkCh:
		t.Fatalf("done event forwarded before persistence finished: %s", string(got))
	default:
	}
	// Release Persist and let the stream finish.
	close(msgSvc.persistContinue)
	select {
	case err := <-streamDone:
		if err != nil {
			t.Fatalf("streamChat returned error: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for streamChat to finish")
	}
	// The done event must be forwarded after persistence completes.
	select {
	case got := <-chunkCh:
		if len(got) == 0 {
			t.Fatal("expected forwarded done event data")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for forwarded done event data")
	}
}
+103 -2
View File
@@ -196,7 +196,7 @@ func TestPrepareGatewayAttachments_InlineAssetToBase64(t *testing.T) {
BotID: "bot-1",
Attachments: []conversation.ChatAttachment{
{
Type: "image",
Type: "image",
ContentHash: "asset-1",
},
},
@@ -243,6 +243,107 @@ func TestPrepareGatewayAttachments_DataURLFromURLFieldIsNativeInline(t *testing.
}
}
// TestStreamChat_AllowsLargeSSEDataLines verifies that a payload larger than
// the old scanner buffer is accepted when the gateway splits it across many
// small SSE "data:" lines, and that streamChat reassembles it losslessly.
func TestStreamChat_AllowsLargeSSEDataLines(t *testing.T) {
	const overOldScannerLimit = 3 * 1024 * 1024
	hugeDelta := strings.Repeat("a", overOldScannerLimit)
	dataJSON, err := json.Marshal(map[string]any{
		"type":  "text_delta",
		"delta": hugeDelta,
	})
	if err != nil {
		t.Fatalf("failed to marshal test payload: %v", err)
	}
	// Split the payload into 8 KiB fragments, one per "data:" line.
	dataStr := string(dataJSON)
	parts := make([]string, 0, (len(dataStr)/8192)+1)
	for i := 0; i < len(dataStr); i += 8192 {
		end := i + 8192
		if end > len(dataStr) {
			end = len(dataStr)
		}
		parts = append(parts, dataStr[i:end])
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/stream" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = io.WriteString(w, "event: message\n")
		for _, part := range parts {
			// Note: no space after "data:" so payload bytes survive intact.
			_, _ = io.WriteString(w, "data:")
			_, _ = io.WriteString(w, part)
			_, _ = io.WriteString(w, "\n")
		}
		_, _ = io.WriteString(w, "\n")
	}))
	defer srv.Close()
	resolver := &Resolver{
		gatewayBaseURL:  srv.URL,
		streamingClient: srv.Client(),
		logger:          slog.Default(),
	}
	chunkCh := make(chan conversation.StreamChunk, 1)
	err = resolver.streamChat(
		context.Background(),
		gatewayRequest{},
		conversation.ChatRequest{},
		chunkCh,
	)
	if err != nil {
		t.Fatalf("streamChat returned error: %v", err)
	}
	select {
	case chunk := <-chunkCh:
		// The forwarded chunk must equal the original payload byte-for-byte.
		if !bytes.Equal(chunk, dataJSON) {
			t.Fatalf("unexpected reconstructed payload: got prefix %q", string(chunk[:min(len(chunk), 80)]))
		}
	default:
		t.Fatalf("expected at least one streamed chunk")
	}
}
// TestStreamChat_RejectsOverLimitSSELine verifies that a single SSE line
// exceeding gatewaySSEMaxLineBytes fails with an explicit "sse line too long"
// error instead of a silent scanner failure.
func TestStreamChat_RejectsOverLimitSSELine(t *testing.T) {
	tooLong := strings.Repeat("x", gatewaySSEMaxLineBytes+10)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/stream" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = io.WriteString(w, "event: message\n")
		_, _ = io.WriteString(w, "data:")
		_, _ = io.WriteString(w, tooLong)
		_, _ = io.WriteString(w, "\n\n")
	}))
	defer srv.Close()
	resolver := &Resolver{
		gatewayBaseURL:  srv.URL,
		streamingClient: srv.Client(),
		logger:          slog.Default(),
	}
	chunkCh := make(chan conversation.StreamChunk, 1)
	err := resolver.streamChat(context.Background(), gatewayRequest{}, conversation.ChatRequest{}, chunkCh)
	if err == nil {
		t.Fatalf("expected streamChat to error on oversized SSE line")
	}
	if !strings.Contains(err.Error(), "sse line too long") {
		t.Fatalf("expected line-too-long error, got: %v", err)
	}
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
func TestPrepareGatewayAttachments_PublicURLFromURLFieldIsNativePublic(t *testing.T) {
resolver := &Resolver{logger: slog.Default()}
req := conversation.ChatRequest{
@@ -321,7 +422,7 @@ func TestPrepareGatewayAttachments_DetectsImageMimeWhenOctetStream(t *testing.T)
BotID: "bot-1",
Attachments: []conversation.ChatAttachment{
{
Type: "image",
Type: "image",
ContentHash: "asset-2",
},
},
+6 -2
View File
@@ -158,7 +158,9 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
return mcpgw.BuildToolSuccessResult(map[string]any{"content": content}), nil
return mcpgw.BuildToolSuccessResult(map[string]any{
"content": pruneToolOutputText(content, "tool result (read content)"),
}), nil
case toolWrite:
filePath := normalizePath(mcpgw.StringArg(arguments, "path"))
@@ -238,8 +240,10 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
if result.ExitCode != 0 && strings.Contains(stderr, "no running task") {
stderr = strings.TrimSpace(stderr) + "\n\nHint: Container exists but has no running task (main process exited). Start it first: POST /bots/" + botID + "/container/start or use the container start action in the UI."
}
stdout := pruneToolOutputText(result.Stdout, "tool result (exec stdout)")
stderr = pruneToolOutputText(stderr, "tool result (exec stderr)")
return mcpgw.BuildToolSuccessResult(map[string]any{
"stdout": result.Stdout,
"stdout": stdout,
"stderr": stderr,
"exit_code": result.ExitCode,
}), nil
+22
View File
@@ -0,0 +1,22 @@
package container
import textprune "github.com/memohai/memoh/internal/prune"
const (
	// Head/tail excerpt budgets for pruned container tool output (file reads,
	// exec stdout/stderr).
	toolOutputHeadBytes = 4 * 1024
	toolOutputTailBytes = 1 * 1024
	toolOutputHeadLines = 150
	toolOutputTailLines = 50
)

// pruneToolOutputText clamps container tool output to the shared default
// byte/line budget, keeping a head/tail excerpt. label names the payload in
// the pruned-output banner.
func pruneToolOutputText(text, label string) string {
	return textprune.PruneWithEdges(text, label, textprune.Config{
		MaxBytes:  textprune.DefaultMaxBytes,
		MaxLines:  textprune.DefaultMaxLines,
		HeadBytes: toolOutputHeadBytes,
		TailBytes: toolOutputTailBytes,
		HeadLines: toolOutputHeadLines,
		TailLines: toolOutputTailLines,
		Marker:    textprune.DefaultMarker,
	})
}
+185
View File
@@ -0,0 +1,185 @@
package prune
import (
"fmt"
"strings"
"unicode/utf8"
)
const (
	// DefaultMarker prefixes pruned-output banners so downstream consumers can
	// recognize truncated payloads.
	DefaultMarker = "[memoh pruned]"
	// DefaultMaxBytes and DefaultMaxLines are the budgets applied when a
	// Config leaves MaxBytes/MaxLines unset (see normalizeConfig).
	DefaultMaxBytes = 10 * 1024
	DefaultMaxLines = 250
)

// Config controls how PruneWithEdges bounds a payload.
type Config struct {
	MaxBytes  int    // overall byte budget; <=0 falls back to DefaultMaxBytes
	MaxLines  int    // overall line budget; <=0 falls back to DefaultMaxLines
	HeadBytes int    // max bytes kept from the start of the payload
	TailBytes int    // max bytes kept from the end of the payload
	HeadLines int    // max lines kept from the start
	TailLines int    // max lines kept from the end
	Marker    string // banner prefix; empty falls back to DefaultMarker
}
// Exceeds reports whether s is over either the byte budget or the line budget.
func Exceeds(s string, maxBytes, maxLines int) bool {
	if len(s) > maxBytes {
		return true
	}
	return CountLines(s) > maxLines
}
// CountLines returns the number of newline-delimited lines in s. The empty
// string has zero lines; a trailing "\n" starts a final (empty) line.
func CountLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return 1 + strings.Count(s, "\n")
}
// PruneWithEdges returns s unchanged while it fits the configured byte/line
// budget; otherwise it replaces s with a banner plus a head/tail excerpt.
// label names the payload in the banner. Zero-valued Config fields are filled
// with defaults by normalizeConfig before use.
func PruneWithEdges(s, label string, cfg Config) string {
	cfg = normalizeConfig(cfg)
	if len(s) == 0 {
		return s
	}
	// With no head/tail budget at all there is nothing to excerpt: emit only
	// an omission banner (itself clamped to the overall budget).
	if cfg.HeadBytes+cfg.TailBytes <= 0 || cfg.HeadLines+cfg.TailLines <= 0 {
		return fitBudget(fmt.Sprintf(
			"%s %s omitted (bytes=%d, lines=%d)",
			cfg.Marker,
			label,
			len(s),
			CountLines(s),
		), cfg)
	}
	if !Exceeds(s, cfg.MaxBytes, cfg.MaxLines) {
		return s
	}
	head := boundedPrefix(s, minInt(cfg.HeadBytes, len(s)), cfg.HeadLines)
	tail := ""
	if cfg.TailBytes > 0 && cfg.TailLines > 0 {
		tail = boundedSuffix(s, minInt(cfg.TailBytes, len(s)), cfg.TailLines)
	}
	// The assembled excerpt (marker + banner + head + tail) could itself
	// exceed the budget; clamp it one more time.
	return fitBudget(fmt.Sprintf(
		"%s %s too long (bytes=%d, lines=%d), showing head/tail\n\n%s\n\n[...snip...]\n\n%s",
		cfg.Marker,
		label,
		len(s),
		CountLines(s),
		head,
		tail,
	), cfg)
}
// normalizeConfig fills in safe values: package defaults for the budgets and
// marker, and a floor of zero for the head/tail limits.
func normalizeConfig(cfg Config) Config {
	clampNonNegative := func(v *int) {
		if *v < 0 {
			*v = 0
		}
	}
	if cfg.MaxBytes <= 0 {
		cfg.MaxBytes = DefaultMaxBytes
	}
	if cfg.MaxLines <= 0 {
		cfg.MaxLines = DefaultMaxLines
	}
	if cfg.Marker == "" {
		cfg.Marker = DefaultMarker
	}
	clampNonNegative(&cfg.HeadBytes)
	clampNonNegative(&cfg.TailBytes)
	clampNonNegative(&cfg.HeadLines)
	clampNonNegative(&cfg.TailLines)
	return cfg
}
// fitBudget forces s itself under the configured budget, falling back to the
// bare marker when even a bounded prefix would be empty.
func fitBudget(s string, cfg Config) string {
	if Exceeds(s, cfg.MaxBytes, cfg.MaxLines) {
		if trimmed := boundedPrefix(s, cfg.MaxBytes, cfg.MaxLines); trimmed != "" {
			return trimmed
		}
		return cfg.Marker
	}
	return s
}
func boundedPrefix(s string, maxBytes, maxLines int) string {
if len(s) == 0 || maxBytes <= 0 || maxLines <= 0 {
return ""
}
prefix := safeUTF8Prefix(s, minInt(maxBytes, len(s)))
return limitLinesPrefix(prefix, maxLines)
}
func boundedSuffix(s string, maxBytes, maxLines int) string {
if len(s) == 0 || maxBytes <= 0 || maxLines <= 0 {
return ""
}
suffix := safeUTF8Suffix(s, minInt(maxBytes, len(s)))
return limitLinesSuffix(suffix, maxLines)
}
// safeUTF8Prefix returns the longest prefix of s that is at most maxBytes
// bytes and does not end in the middle of a multi-byte UTF-8 sequence.
func safeUTF8Prefix(s string, maxBytes int) string {
	if maxBytes <= 0 || s == "" {
		return ""
	}
	if maxBytes >= len(s) {
		return s
	}
	// Back off the cut point until it lands on a rune boundary.
	end := maxBytes
	for end > 0 && !utf8.RuneStart(s[end]) {
		end--
	}
	return s[:end]
}
// safeUTF8Suffix returns the longest suffix of s that is at most maxBytes
// bytes and does not begin in the middle of a multi-byte UTF-8 sequence.
func safeUTF8Suffix(s string, maxBytes int) string {
	if maxBytes <= 0 || s == "" {
		return ""
	}
	if maxBytes >= len(s) {
		return s
	}
	// Advance the start point until it lands on a rune boundary.
	begin := len(s) - maxBytes
	for begin < len(s) && !utf8.RuneStart(s[begin]) {
		begin++
	}
	return s[begin:]
}
// limitLinesPrefix keeps at most the first maxLines lines of s (a trailing
// newline is not included when the limit is hit).
func limitLinesPrefix(s string, maxLines int) string {
	if maxLines <= 0 || s == "" {
		return ""
	}
	// Walk forward one newline per kept line; fewer newlines means the whole
	// string already fits.
	end := 0
	for i := 0; i < maxLines; i++ {
		next := strings.IndexByte(s[end:], '\n')
		if next < 0 {
			return s
		}
		end += next + 1
	}
	return s[:end-1]
}
// limitLinesSuffix keeps at most the last maxLines lines of s.
func limitLinesSuffix(s string, maxLines int) string {
	if maxLines <= 0 || s == "" {
		return ""
	}
	// Walk backward one newline per kept line; fewer newlines means the whole
	// string already fits.
	start := len(s)
	for i := 0; i < maxLines; i++ {
		prev := strings.LastIndexByte(s[:start], '\n')
		if prev < 0 {
			return s
		}
		start = prev
	}
	return s[start+1:]
}
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b < a {
		return b
	}
	return a
}