mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(flow): stabilize chunked SSE and unify prune limits for read/exec/gateway (#71)
* fix(agent): emit chunked SSE data fix(flow): reassemble chunked SSE and prune tool payloads fix: avoid whitespace prune bypass; optimize chunked SSE builder * refactor: LLM provider pruning use shared textprune library * chore: smaller range
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { Elysia, sse } from 'elysia'
|
||||
import { Elysia } from 'elysia'
|
||||
import z from 'zod'
|
||||
import { createAgent } from '../agent'
|
||||
import { createAuthFetcher, getBaseUrl } from '../index'
|
||||
@@ -6,6 +6,7 @@ import { ModelConfig } from '../types'
|
||||
import { bearerMiddleware } from '../middlewares/bearer'
|
||||
import { AgentSkillModel, AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models'
|
||||
import { allActions } from '../types'
|
||||
import { sseChunked } from '../utils/sse'
|
||||
|
||||
const AgentModel = z.object({
|
||||
model: ModelConfigModel,
|
||||
@@ -75,14 +76,14 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
skills: body.skills,
|
||||
attachments: body.attachments,
|
||||
})) {
|
||||
yield sse(JSON.stringify(action))
|
||||
yield sseChunked(JSON.stringify(action))
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
const message = error instanceof Error && error.message.trim()
|
||||
? error.message
|
||||
: 'Internal server error'
|
||||
yield sse(JSON.stringify({
|
||||
yield sseChunked(JSON.stringify({
|
||||
type: 'error',
|
||||
message,
|
||||
}))
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import { describe, expect, test } from 'bun:test'
|
||||
import { sseChunked } from '../utils/sse'
|
||||
|
||||
function parseChunkedSSE(payload: string): string {
|
||||
const lines = payload.split('\n')
|
||||
const dataLines = lines.filter(line => line.startsWith('data:'))
|
||||
return dataLines.map(line => line.slice('data:'.length)).join('')
|
||||
}
|
||||
|
||||
describe('sseChunked', () => {
|
||||
test('reconstructs original payload losslessly', () => {
|
||||
const input = JSON.stringify({
|
||||
type: 'tool_call_end',
|
||||
toolName: 'big_tool',
|
||||
toolCallId: 'call-1',
|
||||
// include whitespace and unicode so trimming/surrogate splitting bugs show up
|
||||
result: ' leading spaces\tand tabs\nand unicode 😀😃😄 ',
|
||||
blob: 'x'.repeat(200_000),
|
||||
})
|
||||
|
||||
const chunked = sseChunked(input, 1024).toSSE()
|
||||
const reconstructed = parseChunkedSSE(chunked)
|
||||
|
||||
expect(reconstructed).toBe(input)
|
||||
})
|
||||
|
||||
test('chunkSize=1 does not produce invalid UTF-8 (surrogate pairs)', () => {
|
||||
const input = `😀${'x'.repeat(1000)}😃`
|
||||
const payload = sseChunked(input, 1).toSSE()
|
||||
|
||||
// Simulate the UTF-8 encode/decode step that happens over the network.
|
||||
const encoded = new TextEncoder().encode(payload)
|
||||
const decoded = new TextDecoder().decode(encoded)
|
||||
expect(decoded).toBe(payload)
|
||||
|
||||
const reconstructed = parseChunkedSSE(decoded)
|
||||
expect(reconstructed).toBe(input)
|
||||
})
|
||||
|
||||
test('does not inject an extra space after data:', () => {
|
||||
const input = ' abc'
|
||||
const chunked = sseChunked(input, 2).toSSE()
|
||||
expect(chunked.split('\n')[0]).toBe('data: a')
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,50 @@
|
||||
export const defaultSSEChunkSize = 16 * 1024
|
||||
|
||||
export function sseChunked(data: string, chunkSize: number = defaultSSEChunkSize) {
|
||||
return {
|
||||
sse: true as const,
|
||||
toSSE: () => {
|
||||
const out: string[] = []
|
||||
for (const chunk of chunkString(data, chunkSize)) {
|
||||
out.push(`data:${chunk}\n`)
|
||||
}
|
||||
out.push('\n')
|
||||
return out.join('')
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
export function* chunkString(input: string, maxLen: number): Generator<string> {
|
||||
if (maxLen <= 0) {
|
||||
yield input
|
||||
return
|
||||
}
|
||||
const isHighSurrogate = (c: number) => c >= 0xd800 && c <= 0xdbff
|
||||
const isLowSurrogate = (c: number) => c >= 0xdc00 && c <= 0xdfff
|
||||
let i = 0
|
||||
while (i < input.length) {
|
||||
let end = Math.min(i + maxLen, input.length)
|
||||
if (end < input.length) {
|
||||
const last = input.charCodeAt(end - 1)
|
||||
if (isHighSurrogate(last)) {
|
||||
const next = input.charCodeAt(end)
|
||||
if (isLowSurrogate(next)) {
|
||||
end += 1
|
||||
} else {
|
||||
end -= 1
|
||||
}
|
||||
}
|
||||
}
|
||||
if (end <= i) {
|
||||
const first = input.charCodeAt(i)
|
||||
const second = i+1 < input.length ? input.charCodeAt(i + 1) : -1
|
||||
if (isHighSurrogate(first) && isLowSurrogate(second)) {
|
||||
end = Math.min(i + 2, input.length)
|
||||
} else {
|
||||
end = Math.min(i + 1, input.length)
|
||||
}
|
||||
}
|
||||
yield input.slice(i, end)
|
||||
i = end
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
textprune "github.com/memohai/memoh/internal/prune"
|
||||
)
|
||||
|
||||
const (
	// Prune long tool payloads per message to keep gateway requests within provider limits,
	// while preserving as much surrounding context as possible. The overall
	// budgets come from the shared textprune library so read/exec/gateway
	// paths all agree on what "too large" means.
	gatewayToolPayloadMaxBytes = textprune.DefaultMaxBytes
	gatewayToolPayloadMaxLines = textprune.DefaultMaxLines

	// Edge budgets for tool results: how much of the head and tail of an
	// oversized result is retained around the pruned-out middle.
	gatewayToolResultHeadBytes = 6 * 1024
	gatewayToolResultTailBytes = 2 * 1024
	gatewayToolResultHeadLines = 180
	gatewayToolResultTailLines = 50

	// Edge budgets for tool call arguments (smaller head than results).
	gatewayToolArgsHeadBytes = 4 * 1024
	gatewayToolArgsTailBytes = 2 * 1024
	gatewayToolArgsHeadLines = 180
	gatewayToolArgsTailLines = 50

	// Marker inserted where pruned content was removed.
	gatewayToolPayloadPrunedMarker = textprune.DefaultMarker
)
|
||||
|
||||
func pruneHistoryForGateway(messages []messageWithUsage) []messageWithUsage {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
out := make([]messageWithUsage, 0, len(messages))
|
||||
staleUsage := false
|
||||
for _, item := range messages {
|
||||
msg, changed := pruneMessageForGateway(item.Message)
|
||||
if changed {
|
||||
item.Message = msg
|
||||
staleUsage = true
|
||||
}
|
||||
if staleUsage {
|
||||
item.UsageInputTokens = nil
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func pruneMessagesForGateway(messages []conversation.ModelMessage) []conversation.ModelMessage {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
out := make([]conversation.ModelMessage, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
pruned, _ := pruneMessageForGateway(msg)
|
||||
out = append(out, pruned)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func pruneMessageForGateway(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
|
||||
changed := false
|
||||
if strings.EqualFold(strings.TrimSpace(msg.Role), "tool") {
|
||||
msg2, did := pruneToolMessage(msg)
|
||||
if did {
|
||||
msg = msg2
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
calls, did := pruneToolCalls(msg.ToolCalls)
|
||||
if did {
|
||||
msg.ToolCalls = calls
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return msg, changed
|
||||
}
|
||||
|
||||
func pruneToolCalls(calls []conversation.ToolCall) ([]conversation.ToolCall, bool) {
|
||||
changed := false
|
||||
out := make([]conversation.ToolCall, len(calls))
|
||||
for i, call := range calls {
|
||||
out[i] = call
|
||||
args := call.Function.Arguments
|
||||
if args == "" || !exceedsTextBudget(args) {
|
||||
continue
|
||||
}
|
||||
out[i].Function.Arguments = pruneStringEdges(
|
||||
args,
|
||||
gatewayToolArgsHeadBytes,
|
||||
gatewayToolArgsTailBytes,
|
||||
gatewayToolArgsHeadLines,
|
||||
gatewayToolArgsTailLines,
|
||||
"tool arguments",
|
||||
)
|
||||
changed = true
|
||||
}
|
||||
return out, changed
|
||||
}
|
||||
|
||||
func pruneToolMessage(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
|
||||
// Vercel AI SDK schema requires tool messages to carry an array of tool-result parts.
|
||||
// Prune outputs inside those parts (preserving shape) so the gateway prompt remains valid.
|
||||
if pruned, ok := pruneToolResultParts(msg.Content); ok {
|
||||
msg.Content = pruned
|
||||
return msg, true
|
||||
}
|
||||
|
||||
// Backward-compat: tool messages may have been persisted as plain strings.
|
||||
text := msg.TextContent()
|
||||
if !exceedsTextBudget(text) {
|
||||
return msg, false
|
||||
}
|
||||
msg.Content = conversation.NewTextContent(pruneStringEdges(
|
||||
text,
|
||||
gatewayToolResultHeadBytes,
|
||||
gatewayToolResultTailBytes,
|
||||
gatewayToolResultHeadLines,
|
||||
gatewayToolResultTailLines,
|
||||
"tool result",
|
||||
))
|
||||
return msg, true
|
||||
}
|
||||
|
||||
// pruneToolResultParts prunes oversized outputs inside a tool message whose
// content is a JSON array of tool-result parts. It returns (pruned, true)
// only when at least one part was actually modified; any parse or marshal
// failure falls back to (nil, false) so the caller keeps the original bytes.
func pruneToolResultParts(content json.RawMessage) (json.RawMessage, bool) {
	if len(content) == 0 {
		return nil, false
	}
	// Content must be an array of parts; anything else (e.g. a plain string)
	// is handled by the caller's legacy fallback path.
	var parts []json.RawMessage
	if err := json.Unmarshal(content, &parts); err != nil || len(parts) == 0 {
		return nil, false
	}

	changed := false
	out := make([]json.RawMessage, 0, len(parts))
	for _, raw := range parts {
		// Decode each part into a generic map so unknown fields
		// (providerOptions, extras) survive the round-trip untouched.
		var part map[string]json.RawMessage
		if err := json.Unmarshal(raw, &part); err != nil {
			out = append(out, raw)
			continue
		}

		// Only "tool-result" parts are candidates for pruning.
		partTypeRaw, ok := part["type"]
		if !ok {
			out = append(out, raw)
			continue
		}
		var partType string
		if err := json.Unmarshal(partTypeRaw, &partType); err != nil || partType != "tool-result" {
			out = append(out, raw)
			continue
		}

		outputRaw, ok := part["output"]
		if !ok {
			out = append(out, raw)
			continue
		}
		pruned, didPrune := pruneToolOutput(outputRaw)
		if !didPrune {
			out = append(out, raw)
			continue
		}

		part["output"] = pruned
		rebuilt, err := json.Marshal(part)
		if err != nil {
			// Keep the original part rather than emit broken JSON.
			out = append(out, raw)
			continue
		}
		out = append(out, json.RawMessage(rebuilt))
		changed = true
	}

	if !changed {
		return nil, false
	}
	rebuilt, err := json.Marshal(out)
	if err != nil {
		return nil, false
	}
	return json.RawMessage(rebuilt), true
}
|
||||
|
||||
func pruneToolOutput(raw json.RawMessage) (json.RawMessage, bool) {
|
||||
var output map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &output); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
outputTypeRaw, ok := output["type"]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
var outputType string
|
||||
if err := json.Unmarshal(outputTypeRaw, &outputType); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
valueRaw, hasValue := output["value"]
|
||||
|
||||
switch outputType {
|
||||
case "text", "error-text":
|
||||
if !hasValue {
|
||||
return nil, false
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(valueRaw, &s); err != nil || !exceedsTextBudget(s) {
|
||||
return nil, false
|
||||
}
|
||||
s = pruneStringEdges(
|
||||
s,
|
||||
gatewayToolResultHeadBytes,
|
||||
gatewayToolResultTailBytes,
|
||||
gatewayToolResultHeadLines,
|
||||
gatewayToolResultTailLines,
|
||||
"tool result",
|
||||
)
|
||||
data, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
output["value"] = data
|
||||
rebuilt, err := json.Marshal(output)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return json.RawMessage(rebuilt), true
|
||||
|
||||
case "json", "error-json":
|
||||
if !hasValue || !exceedsTextBudget(string(valueRaw)) {
|
||||
return nil, false
|
||||
}
|
||||
pruned := pruneStringEdges(
|
||||
string(valueRaw),
|
||||
gatewayToolResultHeadBytes,
|
||||
gatewayToolResultTailBytes,
|
||||
gatewayToolResultHeadLines,
|
||||
gatewayToolResultTailLines,
|
||||
"tool result (json)",
|
||||
)
|
||||
data, err := json.Marshal(pruned)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
output["value"] = data
|
||||
rebuilt, err := json.Marshal(output)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return json.RawMessage(rebuilt), true
|
||||
|
||||
case "content":
|
||||
// Best-effort: prune any large text items inside the content array.
|
||||
// If parsing fails, keep the original output to avoid breaking schema.
|
||||
if !hasValue {
|
||||
return nil, false
|
||||
}
|
||||
var items []map[string]any
|
||||
if err := json.Unmarshal(valueRaw, &items); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
didPrune := false
|
||||
for i := range items {
|
||||
if items[i]["type"] != "text" {
|
||||
continue
|
||||
}
|
||||
textAny, ok := items[i]["text"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
text, ok := textAny.(string)
|
||||
if !ok || !exceedsTextBudget(text) {
|
||||
continue
|
||||
}
|
||||
items[i]["text"] = pruneStringEdges(
|
||||
text,
|
||||
gatewayToolResultHeadBytes,
|
||||
gatewayToolResultTailBytes,
|
||||
gatewayToolResultHeadLines,
|
||||
gatewayToolResultTailLines,
|
||||
"tool result (content)",
|
||||
)
|
||||
didPrune = true
|
||||
}
|
||||
if !didPrune {
|
||||
return nil, false
|
||||
}
|
||||
data, err := json.Marshal(items)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
output["value"] = data
|
||||
rebuilt, err := json.Marshal(output)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return json.RawMessage(rebuilt), true
|
||||
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// pruneStringEdges delegates to the shared textprune library: it keeps the
// given head/tail byte and line budgets around a pruned-out middle, labels
// the removal with the gateway marker, and applies the shared overall
// gateway payload limits.
func pruneStringEdges(s string, headBytes, tailBytes, headLines, tailLines int, label string) string {
	return textprune.PruneWithEdges(s, label, textprune.Config{
		MaxBytes:  gatewayToolPayloadMaxBytes,
		MaxLines:  gatewayToolPayloadMaxLines,
		HeadBytes: headBytes,
		TailBytes: tailBytes,
		HeadLines: headLines,
		TailLines: tailLines,
		Marker:    gatewayToolPayloadPrunedMarker,
	})
}
|
||||
|
||||
// exceedsTextBudget reports whether s is over the shared gateway payload
// byte or line limits and therefore needs pruning.
func exceedsTextBudget(s string) bool {
	return textprune.Exceeds(s, gatewayToolPayloadMaxBytes, gatewayToolPayloadMaxLines)
}
|
||||
@@ -2,9 +2,11 @@ package flow
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -34,6 +36,12 @@ const (
|
||||
sharedMemoryNamespace = "bot"
|
||||
// Keep gateway payload bounded when inlining binary attachments as data URLs.
|
||||
gatewayInlineAttachmentMaxBytes int64 = 20 * 1024 * 1024
|
||||
// SSE payloads (especially attachment/tool results) can be very large.
|
||||
// bufio.Scanner hard-fails with "token too long" if a single line exceeds its max token size.
|
||||
// Use a reader-based parser and enforce an explicit per-line cap here. The agent gateway
|
||||
// stream is expected to chunk large JSON payloads across multiple SSE "data:" lines, so
|
||||
// this limit should stay relatively small.
|
||||
gatewaySSEMaxLineBytes = 256 * 1024
|
||||
)
|
||||
|
||||
// SkillEntry represents a skill loaded from the container.
|
||||
@@ -255,11 +263,16 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
// Build non-history parts first so we can reserve their token cost before
|
||||
// trimming history messages.
|
||||
memoryMsg := r.loadMemoryContextMessage(ctx, req)
|
||||
reqMessages := pruneMessagesForGateway(nonNilModelMessages(req.Messages))
|
||||
if memoryMsg != nil {
|
||||
pruned, _ := pruneMessageForGateway(*memoryMsg)
|
||||
memoryMsg = &pruned
|
||||
}
|
||||
var overhead int
|
||||
if memoryMsg != nil {
|
||||
overhead += estimateMessageTokens(*memoryMsg)
|
||||
}
|
||||
for _, m := range req.Messages {
|
||||
for _, m := range reqMessages {
|
||||
overhead += estimateMessageTokens(m)
|
||||
}
|
||||
// Reserve space for the system prompt built by the agent gateway
|
||||
@@ -278,12 +291,13 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
if loadErr != nil {
|
||||
return resolvedContext{}, loadErr
|
||||
}
|
||||
loaded = pruneHistoryForGateway(loaded)
|
||||
messages = trimMessagesByTokens(loaded, historyBudget)
|
||||
}
|
||||
if memoryMsg != nil {
|
||||
messages = append(messages, *memoryMsg)
|
||||
}
|
||||
messages = append(messages, req.Messages...)
|
||||
messages = append(messages, reqMessages...)
|
||||
messages = sanitizeMessages(messages)
|
||||
skills := dedup(req.Skills)
|
||||
containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID)
|
||||
@@ -580,39 +594,66 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req c
|
||||
return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody)))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
|
||||
currentEvent := ""
|
||||
stored := false
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
chunkCh <- conversation.StreamChunk([]byte(data))
|
||||
var dataBuf bytes.Buffer
|
||||
|
||||
if stored {
|
||||
flushEvent := func() error {
|
||||
if dataBuf.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
out := append([]byte(nil), dataBuf.Bytes()...)
|
||||
dataBuf.Reset()
|
||||
if len(out) == 0 || bytes.Equal(bytes.TrimSpace(out), []byte("[DONE]")) {
|
||||
return nil
|
||||
}
|
||||
// Persist final messages before forwarding the "done"/"agent_end" event so the
|
||||
// next user turn can immediately see the assistant output in history.
|
||||
if !stored {
|
||||
if handled, storeErr := r.tryStoreStream(ctx, req, out); storeErr != nil {
|
||||
return storeErr
|
||||
} else if handled {
|
||||
stored = true
|
||||
}
|
||||
}
|
||||
chunkCh <- conversation.StreamChunk(out)
|
||||
return nil
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), gatewaySSEMaxLineBytes)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
if err := flushEvent(); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil {
|
||||
return storeErr
|
||||
} else if handled {
|
||||
stored = true
|
||||
if len(line) > 0 && line[0] == ':' {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
part := bytes.TrimPrefix(line, []byte("data:"))
|
||||
// Backward-compat: older SSE writers used "data: <payload>" (note the space).
|
||||
// Only strip the first leading space for the *first* fragment to avoid corrupting
|
||||
// chunked payloads split inside JSON string values.
|
||||
if dataBuf.Len() == 0 && len(part) > 0 && part[0] == ' ' {
|
||||
part = part[1:]
|
||||
}
|
||||
if len(part) == 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = dataBuf.Write(part)
|
||||
}
|
||||
return scanner.Err()
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
return fmt.Errorf("sse line too long (max %d bytes)", gatewaySSEMaxLineBytes)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return flushEvent()
|
||||
}
|
||||
|
||||
func newJSONRequestWithContext(ctx context.Context, method, url string, payload any) (*http.Request, error) {
|
||||
@@ -631,24 +672,15 @@ func newJSONRequestWithContext(ctx context.Context, method, url string, payload
|
||||
}
|
||||
|
||||
// tryStoreStream attempts to extract final messages from a stream event and persist them.
|
||||
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, eventType, data string) (bool, error) {
|
||||
// event: done + data: {messages: [...]}
|
||||
if eventType == "done" {
|
||||
var resp gatewayResponse
|
||||
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
|
||||
return true, r.storeRound(ctx, req, resp.Messages, resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte) (bool, error) {
|
||||
// data: {"type":"text_delta"|"agent_end"|"done", ...}
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Messages []conversation.ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &envelope); err == nil {
|
||||
if err := json.Unmarshal(data, &envelope); err == nil {
|
||||
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
|
||||
return true, r.storeRound(ctx, req, envelope.Messages, envelope.Usage)
|
||||
}
|
||||
@@ -662,7 +694,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
|
||||
|
||||
// fallback: data: {messages: [...]}
|
||||
var resp gatewayResponse
|
||||
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
|
||||
if err := json.Unmarshal(data, &resp); err == nil && len(resp.Messages) > 0 {
|
||||
return true, r.storeRound(ctx, req, resp.Messages, resp.Usage)
|
||||
}
|
||||
return false, nil
|
||||
@@ -763,19 +795,14 @@ func normalizeGatewayAttachmentPayload(item gatewayAttachment) gatewayAttachment
|
||||
if payload == "" {
|
||||
return item
|
||||
}
|
||||
lower := strings.ToLower(payload)
|
||||
if strings.HasPrefix(lower, "data:") {
|
||||
if strings.TrimSpace(item.Mime) == "" || strings.EqualFold(strings.TrimSpace(item.Mime), "application/octet-stream") {
|
||||
if start := strings.Index(payload, ":"); start >= 0 {
|
||||
rest := payload[start+1:]
|
||||
if end := strings.Index(rest, ";"); end > 0 {
|
||||
mime := strings.TrimSpace(rest[:end])
|
||||
if mime != "" {
|
||||
item.Mime = mime
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(payload), "data:") {
|
||||
mime := strings.TrimSpace(item.Mime)
|
||||
if mime == "" || strings.EqualFold(mime, "application/octet-stream") {
|
||||
if extracted := attachmentpkg.MimeFromDataURL(payload); extracted != "" {
|
||||
item.Mime = extracted
|
||||
}
|
||||
}
|
||||
item.Payload = payload
|
||||
return item
|
||||
}
|
||||
mime := strings.TrimSpace(item.Mime)
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
// Verifies that an oversized tool message's text content is pruned, carries
// the pruned marker, and stays valid UTF-8 even though the input is built
// from multi-byte runes (a naive byte-boundary cut would corrupt it).
func TestPruneMessagesForGateway_PrunesToolResultContent(t *testing.T) {
	t.Parallel()

	// Repeat a multi-byte unit until the payload is safely over budget.
	unit := "汉😀"
	huge := strings.Repeat(unit, (gatewayToolPayloadMaxBytes/len(unit))+20)
	msgs := []conversation.ModelMessage{
		{Role: "tool", Content: conversation.NewTextContent(huge), ToolCallID: "call-1"},
	}
	out := pruneMessagesForGateway(msgs)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	got := out[0].TextContent()
	if strings.Contains(got, huge) {
		t.Fatalf("expected tool content to be pruned")
	}
	if !strings.Contains(got, gatewayToolPayloadPrunedMarker) {
		t.Fatalf("expected pruned marker, got: %q", got[:minLen(len(got), 80)])
	}
	if !utf8.ValidString(got) {
		t.Fatalf("expected pruned tool content to remain valid UTF-8")
	}
}
|
||||
|
||||
// Verifies that oversized tool-call argument strings on an assistant message
// are pruned, carry the marker, and remain valid UTF-8.
func TestPruneMessagesForGateway_PrunesToolCallArguments(t *testing.T) {
	t.Parallel()

	// Build a JSON-shaped argument blob well over budget, using multi-byte
	// runes to surface any invalid byte-boundary cuts.
	repeated := strings.Repeat("猫😺", (gatewayToolPayloadMaxBytes/len("猫😺"))+20)
	hugeArgs := `{"a":"` + repeated + `"}`
	msgs := []conversation.ModelMessage{
		{
			Role: "assistant",
			ToolCalls: []conversation.ToolCall{
				{
					ID:   "call-1",
					Type: "function",
					Function: conversation.ToolCallFunction{
						Name:      "big_tool",
						Arguments: hugeArgs,
					},
				},
			},
		},
	}
	out := pruneMessagesForGateway(msgs)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	if len(out[0].ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(out[0].ToolCalls))
	}
	got := out[0].ToolCalls[0].Function.Arguments
	if strings.Contains(got, repeated) {
		t.Fatalf("expected tool arguments to be pruned")
	}
	if !strings.Contains(got, gatewayToolPayloadPrunedMarker) {
		t.Fatalf("expected pruned marker in args")
	}
	if !utf8.ValidString(got) {
		t.Fatalf("expected pruned tool arguments to remain valid UTF-8")
	}
}
|
||||
|
||||
// Verifies that pruning any history entry clears the cached UsageInputTokens
// of that entry AND every later entry (pruneHistoryForGateway's staleUsage
// flag is sticky once set).
func TestPruneHistoryForGateway_ClearsStaleUsageTokensAfterPrune(t *testing.T) {
	t.Parallel()

	huge := strings.Repeat("汉😀", (gatewayToolPayloadMaxBytes/len("汉😀"))+20)
	firstTokens := 123
	secondTokens := 456

	in := []messageWithUsage{
		{
			// Oversized tool message: will be pruned.
			Message:          conversation.ModelMessage{Role: "tool", Content: conversation.NewTextContent(huge), ToolCallID: "call-1"},
			UsageInputTokens: &firstTokens,
		},
		{
			// Small user message after the prune point.
			Message:          conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("hi")},
			UsageInputTokens: &secondTokens,
		},
	}

	out := pruneHistoryForGateway(in)
	if len(out) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(out))
	}
	if out[0].UsageInputTokens != nil {
		t.Fatalf("expected first UsageInputTokens to be cleared after prune")
	}
	if out[1].UsageInputTokens != nil {
		t.Fatalf("expected subsequent UsageInputTokens to be cleared after earlier prune")
	}
}
|
||||
|
||||
// Verifies that an unpruned history keeps its cached usage token counts.
func TestPruneHistoryForGateway_PreservesUsageTokensWhenUnchanged(t *testing.T) {
	t.Parallel()

	tokens := 321
	in := []messageWithUsage{
		{
			Message:          conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("short")},
			UsageInputTokens: &tokens,
		},
	}
	out := pruneHistoryForGateway(in)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	if out[0].UsageInputTokens == nil || *out[0].UsageInputTokens != tokens {
		t.Fatalf("expected UsageInputTokens to be preserved")
	}
}
|
||||
|
||||
// Verifies that pruning a structured tool message preserves the tool-result
// part schema end to end: the content stays a JSON array, unknown fields on
// both the part and the output object (providerOptions, extras) survive the
// round-trip, and only the oversized output value is shortened.
func TestPruneMessagesForGateway_ToolResultPartsRemainValidToolMessageSchema(t *testing.T) {
	t.Parallel()

	huge := strings.Repeat("a", gatewayToolPayloadMaxBytes+100)
	// A Vercel-AI-SDK-shaped tool-result part with extra fields that must
	// not be dropped by the prune round-trip.
	part := map[string]any{
		"type":       "tool-result",
		"toolCallId": "call-1",
		"toolName":   "big_tool",
		"providerOptions": map[string]any{
			"test-provider": map[string]any{"mode": "strict"},
		},
		"extraPart": "keep-part",
		"output": map[string]any{
			"type":  "text",
			"value": huge,
			"providerOptions": map[string]any{
				"test-provider": map[string]any{"cache": true},
			},
			"extraOutput": "keep-output",
		},
	}
	content, err := json.Marshal([]any{part})
	if err != nil {
		t.Fatalf("marshal tool content: %v", err)
	}
	msgs := []conversation.ModelMessage{
		{Role: "tool", Content: content, ToolCallID: "call-1"},
	}

	out := pruneMessagesForGateway(msgs)
	if len(out) != 1 {
		t.Fatalf("expected 1 message, got %d", len(out))
	}
	// Cheap shape checks on the raw bytes first.
	if !bytes.HasPrefix(bytes.TrimSpace(out[0].Content), []byte("[")) {
		t.Fatalf("expected tool content to remain an array, got: %q", string(out[0].Content[:minLen(len(out[0].Content), 80)]))
	}
	if !bytes.Contains(out[0].Content, []byte(`"type":"tool-result"`)) {
		t.Fatalf("expected tool-result part to be preserved")
	}
	if !bytes.Contains(out[0].Content, []byte(gatewayToolPayloadPrunedMarker)) {
		t.Fatalf("expected pruned marker in tool output")
	}

	// Then a full decode to assert field-level preservation.
	var parts []map[string]any
	if err := json.Unmarshal(out[0].Content, &parts); err != nil {
		t.Fatalf("unmarshal pruned tool content: %v", err)
	}
	if len(parts) != 1 {
		t.Fatalf("expected 1 part, got %d", len(parts))
	}
	if parts[0]["extraPart"] != "keep-part" {
		t.Fatalf("expected extra part field preserved")
	}
	if _, ok := parts[0]["providerOptions"]; !ok {
		t.Fatalf("expected part providerOptions preserved")
	}
	outputAny, ok := parts[0]["output"].(map[string]any)
	if !ok {
		t.Fatalf("expected output object")
	}
	if outputAny["extraOutput"] != "keep-output" {
		t.Fatalf("expected output extra field preserved")
	}
	if _, ok := outputAny["providerOptions"]; !ok {
		t.Fatalf("expected output providerOptions preserved")
	}
	if outputAny["type"] != "text" {
		t.Fatalf("expected output.type=text, got %v", outputAny["type"])
	}
	value, ok := outputAny["value"].(string)
	if !ok {
		t.Fatalf("expected output.value string")
	}
	if len(value) >= len(huge) {
		t.Fatalf("expected output.value to be pruned")
	}
}
|
||||
|
||||
// minLen returns the smaller of two ints; used to bound slice previews in
// test failure messages.
func minLen(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
)
|
||||
|
||||
// blockingMessageService is a messagepkg service stub whose Persist blocks
// until the test releases it, letting tests assert ordering between message
// persistence and downstream stream forwarding.
type blockingMessageService struct {
	persistCalled   chan struct{} // closed on the first Persist call
	persistContinue chan struct{} // closed by the test to let Persist return
}

// Persist closes persistCalled exactly once (the select guards against a
// double close on repeated calls), then blocks until persistContinue closes.
func (s *blockingMessageService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) {
	select {
	case <-s.persistCalled:
	default:
		close(s.persistCalled)
	}
	<-s.persistContinue
	return messagepkg.Message{}, nil
}

// List is a no-op stub.
func (s *blockingMessageService) List(ctx context.Context, botID string) ([]messagepkg.Message, error) {
	return nil, nil
}

// ListSince is a no-op stub.
func (s *blockingMessageService) ListSince(ctx context.Context, botID string, since time.Time) ([]messagepkg.Message, error) {
	return nil, nil
}

// ListLatest is a no-op stub.
func (s *blockingMessageService) ListLatest(ctx context.Context, botID string, limit int32) ([]messagepkg.Message, error) {
	return nil, nil
}

// ListBefore is a no-op stub.
func (s *blockingMessageService) ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]messagepkg.Message, error) {
	return nil, nil
}

// DeleteByBot is a no-op stub.
func (s *blockingMessageService) DeleteByBot(ctx context.Context, botID string) error {
	return nil
}
|
||||
|
||||
// TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent verifies
// ordering: streamChat must finish persisting the final messages before the
// gateway's "done" event is forwarded on the chunk channel.
func TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent(t *testing.T) {
	t.Parallel()

	msgSvc := &blockingMessageService{
		persistCalled:   make(chan struct{}),
		persistContinue: make(chan struct{}),
	}

	// Payload the fake gateway emits as its terminal "done" event.
	doneResp := gatewayResponse{
		Messages: []conversation.ModelMessage{
			{Role: "assistant", Content: conversation.NewTextContent("ok")},
		},
	}
	doneData, err := json.Marshal(doneResp)
	if err != nil {
		t.Fatalf("marshal done response: %v", err)
	}

	// Fake gateway: serves a single SSE "done" event on /chat/stream.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/stream" {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
		_, _ = w.Write([]byte("event: done\n"))
		_, _ = w.Write([]byte("data: "))
		_, _ = w.Write(doneData)
		_, _ = w.Write([]byte("\n\n"))
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
	}))
	t.Cleanup(srv.Close)

	r := &Resolver{
		messageService:  msgSvc,
		gatewayBaseURL:  srv.URL,
		logger:          slog.New(slog.NewTextHandler(io.Discard, nil)),
		streamingClient: srv.Client(),
		httpClient:      srv.Client(),
	}

	chunkCh := make(chan conversation.StreamChunk, 10)
	req := conversation.ChatRequest{BotID: "bot-test", ChatID: "chat-test"}
	payload := gatewayRequest{}

	// Run streamChat concurrently so the test can observe it mid-persist.
	streamDone := make(chan error, 1)
	go func() {
		streamDone <- r.streamChat(context.Background(), payload, req, chunkCh)
		close(chunkCh)
	}()

	// Wait until Persist has been entered (and is now blocked).
	select {
	case <-msgSvc.persistCalled:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for Persist to be called")
	}

	// While Persist is still blocked, nothing may have been forwarded yet.
	select {
	case got := <-chunkCh:
		t.Fatalf("done event forwarded before persistence finished: %s", string(got))
	default:
	}

	// Release Persist so streamChat can finish.
	close(msgSvc.persistContinue)

	select {
	case err := <-streamDone:
		if err != nil {
			t.Fatalf("streamChat returned error: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for streamChat to finish")
	}

	// After completion the done event must have been forwarded.
	select {
	case got := <-chunkCh:
		if len(got) == 0 {
			t.Fatal("expected forwarded done event data")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for forwarded done event data")
	}
}
|
||||
@@ -196,7 +196,7 @@ func TestPrepareGatewayAttachments_InlineAssetToBase64(t *testing.T) {
|
||||
BotID: "bot-1",
|
||||
Attachments: []conversation.ChatAttachment{
|
||||
{
|
||||
Type: "image",
|
||||
Type: "image",
|
||||
ContentHash: "asset-1",
|
||||
},
|
||||
},
|
||||
@@ -243,6 +243,107 @@ func TestPrepareGatewayAttachments_DataURLFromURLFieldIsNativeInline(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamChat_AllowsLargeSSEDataLines checks that a payload larger than the
// old per-line scanner limit survives when the gateway splits it across many
// SSE "data:" lines: streamChat must reassemble them into a single chunk.
func TestStreamChat_AllowsLargeSSEDataLines(t *testing.T) {
	// 3 MiB exceeds the previous scanner's per-line capacity.
	const overOldScannerLimit = 3 * 1024 * 1024
	hugeDelta := strings.Repeat("a", overOldScannerLimit)
	dataJSON, err := json.Marshal(map[string]any{
		"type":  "text_delta",
		"delta": hugeDelta,
	})
	if err != nil {
		t.Fatalf("failed to marshal test payload: %v", err)
	}
	// Split the JSON into 8 KiB pieces, one per SSE data line.
	dataStr := string(dataJSON)
	parts := make([]string, 0, (len(dataStr)/8192)+1)
	for i := 0; i < len(dataStr); i += 8192 {
		end := i + 8192
		if end > len(dataStr) {
			end = len(dataStr)
		}
		parts = append(parts, dataStr[i:end])
	}

	// Fake gateway emitting one event whose data spans many "data:" lines.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/stream" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = io.WriteString(w, "event: message\n")
		for _, part := range parts {
			_, _ = io.WriteString(w, "data:")
			_, _ = io.WriteString(w, part)
			_, _ = io.WriteString(w, "\n")
		}
		// Blank line terminates the SSE event.
		_, _ = io.WriteString(w, "\n")
	}))
	defer srv.Close()

	resolver := &Resolver{
		gatewayBaseURL:  srv.URL,
		streamingClient: srv.Client(),
		logger:          slog.Default(),
	}

	chunkCh := make(chan conversation.StreamChunk, 1)
	err = resolver.streamChat(
		context.Background(),
		gatewayRequest{},
		conversation.ChatRequest{},
		chunkCh,
	)
	if err != nil {
		t.Fatalf("streamChat returned error: %v", err)
	}

	// The reassembled chunk must be byte-identical to the original JSON.
	select {
	case chunk := <-chunkCh:
		if !bytes.Equal(chunk, dataJSON) {
			t.Fatalf("unexpected reconstructed payload: got prefix %q", string(chunk[:min(len(chunk), 80)]))
		}
	default:
		t.Fatalf("expected at least one streamed chunk")
	}
}
|
||||
|
||||
func TestStreamChat_RejectsOverLimitSSELine(t *testing.T) {
|
||||
tooLong := strings.Repeat("x", gatewaySSEMaxLineBytes+10)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/stream" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, "event: message\n")
|
||||
_, _ = io.WriteString(w, "data:")
|
||||
_, _ = io.WriteString(w, tooLong)
|
||||
_, _ = io.WriteString(w, "\n\n")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
resolver := &Resolver{
|
||||
gatewayBaseURL: srv.URL,
|
||||
streamingClient: srv.Client(),
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
chunkCh := make(chan conversation.StreamChunk, 1)
|
||||
err := resolver.streamChat(context.Background(), gatewayRequest{}, conversation.ChatRequest{}, chunkCh)
|
||||
if err == nil {
|
||||
t.Fatalf("expected streamChat to error on oversized SSE line")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "sse line too long") {
|
||||
t.Fatalf("expected line-too-long error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
|
||||
func TestPrepareGatewayAttachments_PublicURLFromURLFieldIsNativePublic(t *testing.T) {
|
||||
resolver := &Resolver{logger: slog.Default()}
|
||||
req := conversation.ChatRequest{
|
||||
@@ -321,7 +422,7 @@ func TestPrepareGatewayAttachments_DetectsImageMimeWhenOctetStream(t *testing.T)
|
||||
BotID: "bot-1",
|
||||
Attachments: []conversation.ChatAttachment{
|
||||
{
|
||||
Type: "image",
|
||||
Type: "image",
|
||||
ContentHash: "asset-2",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -158,7 +158,9 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
|
||||
if err != nil {
|
||||
return mcpgw.BuildToolErrorResult(err.Error()), nil
|
||||
}
|
||||
return mcpgw.BuildToolSuccessResult(map[string]any{"content": content}), nil
|
||||
return mcpgw.BuildToolSuccessResult(map[string]any{
|
||||
"content": pruneToolOutputText(content, "tool result (read content)"),
|
||||
}), nil
|
||||
|
||||
case toolWrite:
|
||||
filePath := normalizePath(mcpgw.StringArg(arguments, "path"))
|
||||
@@ -238,8 +240,10 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
|
||||
if result.ExitCode != 0 && strings.Contains(stderr, "no running task") {
|
||||
stderr = strings.TrimSpace(stderr) + "\n\nHint: Container exists but has no running task (main process exited). Start it first: POST /bots/" + botID + "/container/start or use the container start action in the UI."
|
||||
}
|
||||
stdout := pruneToolOutputText(result.Stdout, "tool result (exec stdout)")
|
||||
stderr = pruneToolOutputText(stderr, "tool result (exec stderr)")
|
||||
return mcpgw.BuildToolSuccessResult(map[string]any{
|
||||
"stdout": result.Stdout,
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"exit_code": result.ExitCode,
|
||||
}), nil
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package container
|
||||
|
||||
import textprune "github.com/memohai/memoh/internal/prune"
|
||||
|
||||
const (
	// Edge sizes kept from over-limit tool output: a generous head excerpt
	// plus a short tail, so both the start and the end of the output survive
	// pruning.
	toolOutputHeadBytes = 4 * 1024
	toolOutputTailBytes = 1 * 1024
	toolOutputHeadLines = 150
	toolOutputTailLines = 50
)
|
||||
|
||||
func pruneToolOutputText(text, label string) string {
|
||||
return textprune.PruneWithEdges(text, label, textprune.Config{
|
||||
MaxBytes: textprune.DefaultMaxBytes,
|
||||
MaxLines: textprune.DefaultMaxLines,
|
||||
HeadBytes: toolOutputHeadBytes,
|
||||
TailBytes: toolOutputTailBytes,
|
||||
HeadLines: toolOutputHeadLines,
|
||||
TailLines: toolOutputTailLines,
|
||||
Marker: textprune.DefaultMarker,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
package prune
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
	// DefaultMarker prefixes any pruned or omitted output.
	DefaultMarker = "[memoh pruned]"
	// DefaultMaxBytes and DefaultMaxLines are the overall output budgets
	// applied when a Config leaves them unset.
	DefaultMaxBytes = 10 * 1024
	DefaultMaxLines = 250
)

// Config controls pruning: MaxBytes/MaxLines cap the final result,
// Head*/Tail* bound the excerpts kept from each edge, and Marker labels
// pruned sections. Zero/negative values are normalized (see normalizeConfig).
type Config struct {
	MaxBytes  int    // overall byte budget for the result (<=0 uses DefaultMaxBytes)
	MaxLines  int    // overall line budget for the result (<=0 uses DefaultMaxLines)
	HeadBytes int    // bytes kept from the start (negative treated as 0)
	TailBytes int    // bytes kept from the end (negative treated as 0)
	HeadLines int    // lines kept from the start (negative treated as 0)
	TailLines int    // lines kept from the end (negative treated as 0)
	Marker    string // prune label ("" uses DefaultMarker)
}
|
||||
|
||||
func Exceeds(s string, maxBytes, maxLines int) bool {
|
||||
return len(s) > maxBytes || CountLines(s) > maxLines
|
||||
}
|
||||
|
||||
// CountLines returns the number of lines in s. The empty string has zero
// lines; a trailing newline starts a final (empty) line.
func CountLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return len(strings.Split(s, "\n"))
}
|
||||
|
||||
// PruneWithEdges returns s unchanged when it fits within cfg's byte/line
// budgets; otherwise it replaces s with a summary header plus bounded head
// and tail excerpts separated by a snip marker. When no edge budget is
// configured the content is fully omitted and only the summary line remains.
// Whatever is produced is forced back under the budget via fitBudget.
func PruneWithEdges(s, label string, cfg Config) string {
	cfg = normalizeConfig(cfg)
	if len(s) == 0 {
		return s
	}
	// No edge budget at all (no bytes or no lines on either edge): emit just
	// an omission summary.
	if cfg.HeadBytes+cfg.TailBytes <= 0 || cfg.HeadLines+cfg.TailLines <= 0 {
		return fitBudget(fmt.Sprintf(
			"%s %s omitted (bytes=%d, lines=%d)",
			cfg.Marker,
			label,
			len(s),
			CountLines(s),
		), cfg)
	}
	// Within budget: pass through untouched.
	if !Exceeds(s, cfg.MaxBytes, cfg.MaxLines) {
		return s
	}
	head := boundedPrefix(s, minInt(cfg.HeadBytes, len(s)), cfg.HeadLines)
	tail := ""
	// Tail is only kept when both its byte and line budgets are positive.
	if cfg.TailBytes > 0 && cfg.TailLines > 0 {
		tail = boundedSuffix(s, minInt(cfg.TailBytes, len(s)), cfg.TailLines)
	}
	return fitBudget(fmt.Sprintf(
		"%s %s too long (bytes=%d, lines=%d), showing head/tail\n\n%s\n\n[...snip...]\n\n%s",
		cfg.Marker,
		label,
		len(s),
		CountLines(s),
		head,
		tail,
	), cfg)
}
|
||||
|
||||
func normalizeConfig(cfg Config) Config {
|
||||
if cfg.MaxBytes <= 0 {
|
||||
cfg.MaxBytes = DefaultMaxBytes
|
||||
}
|
||||
if cfg.MaxLines <= 0 {
|
||||
cfg.MaxLines = DefaultMaxLines
|
||||
}
|
||||
if cfg.Marker == "" {
|
||||
cfg.Marker = DefaultMarker
|
||||
}
|
||||
if cfg.HeadBytes < 0 {
|
||||
cfg.HeadBytes = 0
|
||||
}
|
||||
if cfg.TailBytes < 0 {
|
||||
cfg.TailBytes = 0
|
||||
}
|
||||
if cfg.HeadLines < 0 {
|
||||
cfg.HeadLines = 0
|
||||
}
|
||||
if cfg.TailLines < 0 {
|
||||
cfg.TailLines = 0
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func fitBudget(s string, cfg Config) string {
|
||||
if !Exceeds(s, cfg.MaxBytes, cfg.MaxLines) {
|
||||
return s
|
||||
}
|
||||
trimmed := boundedPrefix(s, cfg.MaxBytes, cfg.MaxLines)
|
||||
if trimmed == "" {
|
||||
return cfg.Marker
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func boundedPrefix(s string, maxBytes, maxLines int) string {
|
||||
if len(s) == 0 || maxBytes <= 0 || maxLines <= 0 {
|
||||
return ""
|
||||
}
|
||||
prefix := safeUTF8Prefix(s, minInt(maxBytes, len(s)))
|
||||
return limitLinesPrefix(prefix, maxLines)
|
||||
}
|
||||
|
||||
func boundedSuffix(s string, maxBytes, maxLines int) string {
|
||||
if len(s) == 0 || maxBytes <= 0 || maxLines <= 0 {
|
||||
return ""
|
||||
}
|
||||
suffix := safeUTF8Suffix(s, minInt(maxBytes, len(s)))
|
||||
return limitLinesSuffix(suffix, maxLines)
|
||||
}
|
||||
|
||||
// safeUTF8Prefix returns at most maxBytes leading bytes of s, backing up so
// the cut never splits a multi-byte UTF-8 sequence.
func safeUTF8Prefix(s string, maxBytes int) string {
	switch {
	case maxBytes <= 0 || s == "":
		return ""
	case maxBytes >= len(s):
		return s
	}
	// maxBytes < len(s) here, so indexing s[end] is always in range.
	end := maxBytes
	for end > 0 && !utf8.RuneStart(s[end]) {
		end--
	}
	if end == 0 {
		return ""
	}
	return s[:end]
}
|
||||
|
||||
// safeUTF8Suffix returns at most maxBytes trailing bytes of s, advancing past
// any UTF-8 continuation bytes so the result starts on a rune boundary.
func safeUTF8Suffix(s string, maxBytes int) string {
	if maxBytes <= 0 || s == "" {
		return ""
	}
	if maxBytes >= len(s) {
		return s
	}
	// maxBytes < len(s) here, so begin starts strictly inside s.
	begin := len(s) - maxBytes
	for begin < len(s) && !utf8.RuneStart(s[begin]) {
		begin++
	}
	if begin == len(s) {
		return ""
	}
	return s[begin:]
}
|
||||
|
||||
// limitLinesPrefix keeps at most maxLines leading lines of s; the result
// carries no trailing newline when trimming occurs.
func limitLinesPrefix(s string, maxLines int) string {
	if maxLines <= 0 || s == "" {
		return ""
	}
	// SplitN leaves any overflow in the final element.
	pieces := strings.SplitN(s, "\n", maxLines+1)
	if len(pieces) <= maxLines {
		return s
	}
	// Drop the overflow element and the newline that preceded it.
	return s[:len(s)-len(pieces[maxLines])-1]
}
|
||||
|
||||
// limitLinesSuffix keeps at most maxLines trailing lines of s.
func limitLinesSuffix(s string, maxLines int) string {
	if maxLines <= 0 || s == "" {
		return ""
	}
	lines := strings.Split(s, "\n")
	drop := len(lines) - maxLines
	if drop <= 0 {
		return s
	}
	return strings.Join(lines[drop:], "\n")
}
|
||||
|
||||
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||||
Reference in New Issue
Block a user