diff --git a/agent/src/modules/chat.ts b/agent/src/modules/chat.ts index c31892ff..06a3a5e3 100644 --- a/agent/src/modules/chat.ts +++ b/agent/src/modules/chat.ts @@ -1,4 +1,4 @@ -import { Elysia, sse } from 'elysia' +import { Elysia } from 'elysia' import z from 'zod' import { createAgent } from '../agent' import { createAuthFetcher, getBaseUrl } from '../index' @@ -6,6 +6,7 @@ import { ModelConfig } from '../types' import { bearerMiddleware } from '../middlewares/bearer' import { AgentSkillModel, AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models' import { allActions } from '../types' +import { sseChunked } from '../utils/sse' const AgentModel = z.object({ model: ModelConfigModel, @@ -75,14 +76,14 @@ export const chatModule = new Elysia({ prefix: '/chat' }) skills: body.skills, attachments: body.attachments, })) { - yield sse(JSON.stringify(action)) + yield sseChunked(JSON.stringify(action)) } } catch (error) { console.error(error) const message = error instanceof Error && error.message.trim() ? error.message : 'Internal server error' - yield sse(JSON.stringify({ + yield sseChunked(JSON.stringify({ type: 'error', message, })) diff --git a/agent/src/test/sse_chunked.test.ts b/agent/src/test/sse_chunked.test.ts new file mode 100644 index 00000000..85cfb580 --- /dev/null +++ b/agent/src/test/sse_chunked.test.ts @@ -0,0 +1,45 @@ +import { describe, expect, test } from 'bun:test' +import { sseChunked } from '../utils/sse' + +function parseChunkedSSE(payload: string): string { + const lines = payload.split('\n') + const dataLines = lines.filter(line => line.startsWith('data:')) + return dataLines.map(line => line.slice('data:'.length)).join('') +} + +describe('sseChunked', () => { + test('reconstructs original payload losslessly', () => { + const input = JSON.stringify({ + type: 'tool_call_end', + toolName: 'big_tool', + toolCallId: 'call-1', + // include whitespace and unicode so trimming/surrogate splitting bugs show up + result: ' leading spaces\tand tabs\nand unicode 😀😃😄 ', + blob: 'x'.repeat(200_000), + }) + + const chunked = sseChunked(input, 1024).toSSE() + const reconstructed = parseChunkedSSE(chunked) + + expect(reconstructed).toBe(input) + }) + + test('chunkSize=1 does not produce invalid UTF-8 (surrogate pairs)', () => { + const input = `😀${'x'.repeat(1000)}😃` + const payload = sseChunked(input, 1).toSSE() + + // Simulate the UTF-8 encode/decode step that happens over the network. + const encoded = new TextEncoder().encode(payload) + const decoded = new TextDecoder().decode(encoded) + expect(decoded).toBe(payload) + + const reconstructed = parseChunkedSSE(decoded) + expect(reconstructed).toBe(input) + }) + + test('does not inject an extra space after data:', () => { + const input = ' abc' + const chunked = sseChunked(input, 2).toSSE() + expect(chunked.split('\n')[0]).toBe('data: a') + }) +}) diff --git a/agent/src/utils/sse.ts b/agent/src/utils/sse.ts new file mode 100644 index 00000000..6a6c84d7 --- /dev/null +++ b/agent/src/utils/sse.ts @@ -0,0 +1,50 @@ +export const defaultSSEChunkSize = 16 * 1024 + +export function sseChunked(data: string, chunkSize: number = defaultSSEChunkSize) { + return { + sse: true as const, + toSSE: () => { + const out: string[] = [] + for (const chunk of chunkString(data, chunkSize)) { + out.push(`data:${chunk}\n`) + } + out.push('\n') + return out.join('') + }, + } +} + +export function* chunkString(input: string, maxLen: number): Generator { + if (maxLen <= 0) { + yield input + return + } + const isHighSurrogate = (c: number) => c >= 0xd800 && c <= 0xdbff + const isLowSurrogate = (c: number) => c >= 0xdc00 && c <= 0xdfff + let i = 0 + while (i < input.length) { + let end = Math.min(i + maxLen, input.length) + if (end < input.length) { + const last = input.charCodeAt(end - 1) + if (isHighSurrogate(last)) { + const next = input.charCodeAt(end) + if (isLowSurrogate(next)) { + end += 1 + } else { + end -= 1 + } + } + } + if (end <= i) { + const first = input.charCodeAt(i) + const second = i+1 < input.length ? input.charCodeAt(i + 1) : -1 + if (isHighSurrogate(first) && isLowSurrogate(second)) { + end = Math.min(i + 2, input.length) + } else { + end = Math.min(i + 1, input.length) + } + } + yield input.slice(i, end) + i = end + } +} diff --git a/internal/conversation/flow/gateway_prune.go b/internal/conversation/flow/gateway_prune.go new file mode 100644 index 00000000..97fb9688 --- /dev/null +++ b/internal/conversation/flow/gateway_prune.go @@ -0,0 +1,319 @@ +package flow + +import ( + "encoding/json" + "strings" + + "github.com/memohai/memoh/internal/conversation" + textprune "github.com/memohai/memoh/internal/prune" +) + +const ( + // Prune long tool payloads per message to keep gateway requests within provider limits, + // while preserving as much surrounding context as possible. + gatewayToolPayloadMaxBytes = textprune.DefaultMaxBytes + gatewayToolPayloadMaxLines = textprune.DefaultMaxLines + + gatewayToolResultHeadBytes = 6 * 1024 + gatewayToolResultTailBytes = 2 * 1024 + gatewayToolResultHeadLines = 180 + gatewayToolResultTailLines = 50 + + gatewayToolArgsHeadBytes = 4 * 1024 + gatewayToolArgsTailBytes = 2 * 1024 + gatewayToolArgsHeadLines = 180 + gatewayToolArgsTailLines = 50 + + gatewayToolPayloadPrunedMarker = textprune.DefaultMarker +) + +func pruneHistoryForGateway(messages []messageWithUsage) []messageWithUsage { + if len(messages) == 0 { + return messages + } + out := make([]messageWithUsage, 0, len(messages)) + staleUsage := false + for _, item := range messages { + msg, changed := pruneMessageForGateway(item.Message) + if changed { + item.Message = msg + staleUsage = true + } + if staleUsage { + item.UsageInputTokens = nil + } + out = append(out, item) + } + return out +} + +func pruneMessagesForGateway(messages []conversation.ModelMessage) []conversation.ModelMessage { + if len(messages) == 0 { + return messages + } + out := make([]conversation.ModelMessage, 0, len(messages)) + for _, msg := range messages { + pruned, _ := pruneMessageForGateway(msg) + out = append(out, pruned) + } + return out +} + +func pruneMessageForGateway(msg conversation.ModelMessage) (conversation.ModelMessage, bool) { + changed := false + if strings.EqualFold(strings.TrimSpace(msg.Role), "tool") { + msg2, did := pruneToolMessage(msg) + if did { + msg = msg2 + changed = true + } + } + if len(msg.ToolCalls) > 0 { + calls, did := pruneToolCalls(msg.ToolCalls) + if did { + msg.ToolCalls = calls + changed = true + } + } + return msg, changed +} + +func pruneToolCalls(calls []conversation.ToolCall) ([]conversation.ToolCall, bool) { + changed := false + out := make([]conversation.ToolCall, len(calls)) + for i, call := range calls { + out[i] = call + args := call.Function.Arguments + if args == "" || !exceedsTextBudget(args) { + continue + } + out[i].Function.Arguments = pruneStringEdges( + args, + gatewayToolArgsHeadBytes, + gatewayToolArgsTailBytes, + gatewayToolArgsHeadLines, + gatewayToolArgsTailLines, + "tool arguments", + ) + changed = true + } + return out, changed +} + +func pruneToolMessage(msg conversation.ModelMessage) (conversation.ModelMessage, bool) { + // Vercel AI SDK schema requires tool messages to carry an array of tool-result parts. + // Prune outputs inside those parts (preserving shape) so the gateway prompt remains valid. + if pruned, ok := pruneToolResultParts(msg.Content); ok { + msg.Content = pruned + return msg, true + } + + // Backward-compat: tool messages may have been persisted as plain strings. + text := msg.TextContent() + if !exceedsTextBudget(text) { + return msg, false + } + msg.Content = conversation.NewTextContent(pruneStringEdges( + text, + gatewayToolResultHeadBytes, + gatewayToolResultTailBytes, + gatewayToolResultHeadLines, + gatewayToolResultTailLines, + "tool result", + )) + return msg, true +} + +func pruneToolResultParts(content json.RawMessage) (json.RawMessage, bool) { + if len(content) == 0 { + return nil, false + } + var parts []json.RawMessage + if err := json.Unmarshal(content, &parts); err != nil || len(parts) == 0 { + return nil, false + } + + changed := false + out := make([]json.RawMessage, 0, len(parts)) + for _, raw := range parts { + var part map[string]json.RawMessage + if err := json.Unmarshal(raw, &part); err != nil { + out = append(out, raw) + continue + } + + partTypeRaw, ok := part["type"] + if !ok { + out = append(out, raw) + continue + } + var partType string + if err := json.Unmarshal(partTypeRaw, &partType); err != nil || partType != "tool-result" { + out = append(out, raw) + continue + } + + outputRaw, ok := part["output"] + if !ok { + out = append(out, raw) + continue + } + pruned, didPrune := pruneToolOutput(outputRaw) + if !didPrune { + out = append(out, raw) + continue + } + + part["output"] = pruned + rebuilt, err := json.Marshal(part) + if err != nil { + out = append(out, raw) + continue + } + out = append(out, json.RawMessage(rebuilt)) + changed = true + } + + if !changed { + return nil, false + } + rebuilt, err := json.Marshal(out) + if err != nil { + return nil, false + } + return json.RawMessage(rebuilt), true +} + +func pruneToolOutput(raw json.RawMessage) (json.RawMessage, bool) { + var output map[string]json.RawMessage + if err := json.Unmarshal(raw, &output); err != nil { + return nil, false + } + outputTypeRaw, ok := output["type"] + if !ok { + return nil, false + } + var outputType string + if err := json.Unmarshal(outputTypeRaw, &outputType); err != nil { + return nil, false + } + valueRaw, hasValue := output["value"] + + switch outputType { + case "text", "error-text": + if !hasValue { + return nil, false + } + var s string + if err := json.Unmarshal(valueRaw, &s); err != nil || !exceedsTextBudget(s) { + return nil, false + } + s = pruneStringEdges( + s, + gatewayToolResultHeadBytes, + gatewayToolResultTailBytes, + gatewayToolResultHeadLines, + gatewayToolResultTailLines, + "tool result", + ) + data, err := json.Marshal(s) + if err != nil { + return nil, false + } + output["value"] = data + rebuilt, err := json.Marshal(output) + if err != nil { + return nil, false + } + return json.RawMessage(rebuilt), true + + case "json", "error-json": + if !hasValue || !exceedsTextBudget(string(valueRaw)) { + return nil, false + } + pruned := pruneStringEdges( + string(valueRaw), + gatewayToolResultHeadBytes, + gatewayToolResultTailBytes, + gatewayToolResultHeadLines, + gatewayToolResultTailLines, + "tool result (json)", + ) + data, err := json.Marshal(pruned) + if err != nil { + return nil, false + } + output["value"] = data + rebuilt, err := json.Marshal(output) + if err != nil { + return nil, false + } + return json.RawMessage(rebuilt), true + + case "content": + // Best-effort: prune any large text items inside the content array. + // If parsing fails, keep the original output to avoid breaking schema. + if !hasValue { + return nil, false + } + var items []map[string]any + if err := json.Unmarshal(valueRaw, &items); err != nil { + return nil, false + } + didPrune := false + for i := range items { + if items[i]["type"] != "text" { + continue + } + textAny, ok := items[i]["text"] + if !ok { + continue + } + text, ok := textAny.(string) + if !ok || !exceedsTextBudget(text) { + continue + } + items[i]["text"] = pruneStringEdges( + text, + gatewayToolResultHeadBytes, + gatewayToolResultTailBytes, + gatewayToolResultHeadLines, + gatewayToolResultTailLines, + "tool result (content)", + ) + didPrune = true + } + if !didPrune { + return nil, false + } + data, err := json.Marshal(items) + if err != nil { + return nil, false + } + output["value"] = data + rebuilt, err := json.Marshal(output) + if err != nil { + return nil, false + } + return json.RawMessage(rebuilt), true + + default: + return nil, false + } +} + +func pruneStringEdges(s string, headBytes, tailBytes, headLines, tailLines int, label string) string { + return textprune.PruneWithEdges(s, label, textprune.Config{ + MaxBytes: gatewayToolPayloadMaxBytes, + MaxLines: gatewayToolPayloadMaxLines, + HeadBytes: headBytes, + TailBytes: tailBytes, + HeadLines: headLines, + TailLines: tailLines, + Marker: gatewayToolPayloadPrunedMarker, + }) +} + +func exceedsTextBudget(s string) bool { + return textprune.Exceeds(s, gatewayToolPayloadMaxBytes, gatewayToolPayloadMaxLines) +} diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 4ac441de..77de0eb2 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -2,9 +2,11 @@ package flow import ( "bufio" + "bytes" "context" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -34,6 +36,12 @@ const ( sharedMemoryNamespace = "bot" // Keep gateway payload bounded when inlining binary attachments as data URLs. gatewayInlineAttachmentMaxBytes int64 = 20 * 1024 * 1024 + // SSE payloads (especially attachment/tool results) can be very large. + // bufio.Scanner hard-fails with "token too long" if a single line exceeds its max token size. + // Use a reader-based parser and enforce an explicit per-line cap here. The agent gateway + // stream is expected to chunk large JSON payloads across multiple SSE "data:" lines, so + // this limit should stay relatively small. + gatewaySSEMaxLineBytes = 256 * 1024 ) // SkillEntry represents a skill loaded from the container. @@ -255,11 +263,16 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r // Build non-history parts first so we can reserve their token cost before // trimming history messages. memoryMsg := r.loadMemoryContextMessage(ctx, req) + reqMessages := pruneMessagesForGateway(nonNilModelMessages(req.Messages)) + if memoryMsg != nil { + pruned, _ := pruneMessageForGateway(*memoryMsg) + memoryMsg = &pruned + } var overhead int if memoryMsg != nil { overhead += estimateMessageTokens(*memoryMsg) } - for _, m := range req.Messages { + for _, m := range reqMessages { overhead += estimateMessageTokens(m) } // Reserve space for the system prompt built by the agent gateway @@ -278,12 +291,13 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r if loadErr != nil { return resolvedContext{}, loadErr } + loaded = pruneHistoryForGateway(loaded) messages = trimMessagesByTokens(loaded, historyBudget) } if memoryMsg != nil { messages = append(messages, *memoryMsg) } - messages = append(messages, req.Messages...) + messages = append(messages, reqMessages...) messages = sanitizeMessages(messages) skills := dedup(req.Skills) containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) @@ -580,39 +594,66 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req c return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody))) } - scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) - - currentEvent := "" stored := false - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - if strings.HasPrefix(line, "event:") { - currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - continue - } - if !strings.HasPrefix(line, "data:") { - continue - } - data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if data == "" || data == "[DONE]" { - continue - } - chunkCh <- conversation.StreamChunk([]byte(data)) + var dataBuf bytes.Buffer - if stored { + flushEvent := func() error { + if dataBuf.Len() == 0 { + return nil + } + out := append([]byte(nil), dataBuf.Bytes()...) + dataBuf.Reset() + if len(out) == 0 || bytes.Equal(bytes.TrimSpace(out), []byte("[DONE]")) { + return nil + } + // Persist final messages before forwarding the "done"/"agent_end" event so the + // next user turn can immediately see the assistant output in history. + if !stored { + if handled, storeErr := r.tryStoreStream(ctx, req, out); storeErr != nil { + return storeErr + } else if handled { + stored = true + } + } + chunkCh <- conversation.StreamChunk(out) + return nil + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), gatewaySSEMaxLineBytes) + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + if err := flushEvent(); err != nil { + return err + } continue } - if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil { - return storeErr - } else if handled { - stored = true + if len(line) > 0 && line[0] == ':' { + continue } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + part := bytes.TrimPrefix(line, []byte("data:")) + // Backward-compat: older SSE writers used "data: " (note the space). + // Only strip the first leading space for the *first* fragment to avoid corrupting + // chunked payloads split inside JSON string values. + if dataBuf.Len() == 0 && len(part) > 0 && part[0] == ' ' { + part = part[1:] + } + if len(part) == 0 { + continue + } + _, _ = dataBuf.Write(part) } - return scanner.Err() + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + return fmt.Errorf("sse line too long (max %d bytes)", gatewaySSEMaxLineBytes) + } + return err + } + return flushEvent() } func newJSONRequestWithContext(ctx context.Context, method, url string, payload any) (*http.Request, error) { @@ -631,24 +672,15 @@ func newJSONRequestWithContext(ctx context.Context, method, url string, payload } // tryStoreStream attempts to extract final messages from a stream event and persist them. -func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, eventType, data string) (bool, error) { - // event: done + data: {messages: [...]} - if eventType == "done" { - var resp gatewayResponse - if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { - return true, r.storeRound(ctx, req, resp.Messages, resp.Usage) - } - } - +func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte) (bool, error) { // data: {"type":"text_delta"|"agent_end"|"done", ...} var envelope struct { Type string `json:"type"` Data json.RawMessage `json:"data"` Messages []conversation.ModelMessage `json:"messages"` - Skills []string `json:"skills"` Usage json.RawMessage `json:"usage,omitempty"` } - if err := json.Unmarshal([]byte(data), &envelope); err == nil { + if err := json.Unmarshal(data, &envelope); err == nil { if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 { return true, r.storeRound(ctx, req, envelope.Messages, envelope.Usage) } @@ -662,7 +694,7 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ // fallback: data: {messages: [...]} var resp gatewayResponse - if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + if err := json.Unmarshal(data, &resp); err == nil && len(resp.Messages) > 0 { return true, r.storeRound(ctx, req, resp.Messages, resp.Usage) } return false, nil @@ -763,19 +795,14 @@ func normalizeGatewayAttachmentPayload(item gatewayAttachment) gatewayAttachment if payload == "" { return item } - lower := strings.ToLower(payload) - if strings.HasPrefix(lower, "data:") { - if strings.TrimSpace(item.Mime) == "" || strings.EqualFold(strings.TrimSpace(item.Mime), "application/octet-stream") { - if start := strings.Index(payload, ":"); start >= 0 { - rest := payload[start+1:] - if end := strings.Index(rest, ";"); end > 0 { - mime := strings.TrimSpace(rest[:end]) - if mime != "" { - item.Mime = mime - } - } + if strings.HasPrefix(strings.ToLower(payload), "data:") { + mime := strings.TrimSpace(item.Mime) + if mime == "" || strings.EqualFold(mime, "application/octet-stream") { + if extracted := attachmentpkg.MimeFromDataURL(payload); extracted != "" { + item.Mime = extracted } } + item.Payload = payload return item } mime := strings.TrimSpace(item.Mime) diff --git a/internal/conversation/flow/resolver_prune_test.go b/internal/conversation/flow/resolver_prune_test.go new file mode 100644 index 00000000..5dd2a05b --- /dev/null +++ b/internal/conversation/flow/resolver_prune_test.go @@ -0,0 +1,208 @@ +package flow + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + "unicode/utf8" + + "github.com/memohai/memoh/internal/conversation" +) + +func TestPruneMessagesForGateway_PrunesToolResultContent(t *testing.T) { + t.Parallel() + + unit := "汉😀" + huge := strings.Repeat(unit, (gatewayToolPayloadMaxBytes/len(unit))+20) + msgs := []conversation.ModelMessage{ + {Role: "tool", Content: conversation.NewTextContent(huge), ToolCallID: "call-1"}, + } + out := pruneMessagesForGateway(msgs) + if len(out) != 1 { + t.Fatalf("expected 1 message, got %d", len(out)) + } + got := out[0].TextContent() + if strings.Contains(got, huge) { + t.Fatalf("expected tool content to be pruned") + } + if !strings.Contains(got, gatewayToolPayloadPrunedMarker) { + t.Fatalf("expected pruned marker, got: %q", got[:minLen(len(got), 80)]) + } + if !utf8.ValidString(got) { + t.Fatalf("expected pruned tool content to remain valid UTF-8") + } +} + +func TestPruneMessagesForGateway_PrunesToolCallArguments(t *testing.T) { + t.Parallel() + + repeated := strings.Repeat("猫😺", (gatewayToolPayloadMaxBytes/len("猫😺"))+20) + hugeArgs := `{"a":"` + repeated + `"}` + msgs := []conversation.ModelMessage{ + { + Role: "assistant", + ToolCalls: []conversation.ToolCall{ + { + ID: "call-1", + Type: "function", + Function: conversation.ToolCallFunction{ + Name: "big_tool", + Arguments: hugeArgs, + }, + }, + }, + }, + } + out := pruneMessagesForGateway(msgs) + if len(out) != 1 { + t.Fatalf("expected 1 message, got %d", len(out)) + } + if len(out[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(out[0].ToolCalls)) + } + got := out[0].ToolCalls[0].Function.Arguments + if strings.Contains(got, repeated) { + t.Fatalf("expected tool arguments to be pruned") + } + if !strings.Contains(got, gatewayToolPayloadPrunedMarker) { + t.Fatalf("expected pruned marker in args") + } + if !utf8.ValidString(got) { + t.Fatalf("expected pruned tool arguments to remain valid UTF-8") + } +} + +func TestPruneHistoryForGateway_ClearsStaleUsageTokensAfterPrune(t *testing.T) { + t.Parallel() + + huge := strings.Repeat("汉😀", (gatewayToolPayloadMaxBytes/len("汉😀"))+20) + firstTokens := 123 + secondTokens := 456 + + in := []messageWithUsage{ + { + Message: conversation.ModelMessage{Role: "tool", Content: conversation.NewTextContent(huge), ToolCallID: "call-1"}, + UsageInputTokens: &firstTokens, + }, + { + Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("hi")}, + UsageInputTokens: &secondTokens, + }, + } + + out := pruneHistoryForGateway(in) + if len(out) != 2 { + t.Fatalf("expected 2 messages, got %d", len(out)) + } + if out[0].UsageInputTokens != nil { + t.Fatalf("expected first UsageInputTokens to be cleared after prune") + } + if out[1].UsageInputTokens != nil { + t.Fatalf("expected subsequent UsageInputTokens to be cleared after earlier prune") + } +} + +func TestPruneHistoryForGateway_PreservesUsageTokensWhenUnchanged(t *testing.T) { + t.Parallel() + + tokens := 321 + in := []messageWithUsage{ + { + Message: conversation.ModelMessage{Role: "user", Content: conversation.NewTextContent("short")}, + UsageInputTokens: &tokens, + }, + } + out := pruneHistoryForGateway(in) + if len(out) != 1 { + t.Fatalf("expected 1 message, got %d", len(out)) + } + if out[0].UsageInputTokens == nil || *out[0].UsageInputTokens != tokens { + t.Fatalf("expected UsageInputTokens to be preserved") + } +} + +func TestPruneMessagesForGateway_ToolResultPartsRemainValidToolMessageSchema(t *testing.T) { + t.Parallel() + + huge := strings.Repeat("a", gatewayToolPayloadMaxBytes+100) + part := map[string]any{ + "type": "tool-result", + "toolCallId": "call-1", + "toolName": "big_tool", + "providerOptions": map[string]any{ + "test-provider": map[string]any{"mode": "strict"}, + }, + "extraPart": "keep-part", + "output": map[string]any{ + "type": "text", + "value": huge, + "providerOptions": map[string]any{ + "test-provider": map[string]any{"cache": true}, + }, + "extraOutput": "keep-output", + }, + } + content, err := json.Marshal([]any{part}) + if err != nil { + t.Fatalf("marshal tool content: %v", err) + } + msgs := []conversation.ModelMessage{ + {Role: "tool", Content: content, ToolCallID: "call-1"}, + } + + out := pruneMessagesForGateway(msgs) + if len(out) != 1 { + t.Fatalf("expected 1 message, got %d", len(out)) + } + if !bytes.HasPrefix(bytes.TrimSpace(out[0].Content), []byte("[")) { + t.Fatalf("expected tool content to remain an array, got: %q", string(out[0].Content[:minLen(len(out[0].Content), 80)])) + } + if !bytes.Contains(out[0].Content, []byte(`"type":"tool-result"`)) { + t.Fatalf("expected tool-result part to be preserved") + } + if !bytes.Contains(out[0].Content, []byte(gatewayToolPayloadPrunedMarker)) { + t.Fatalf("expected pruned marker in tool output") + } + + var parts []map[string]any + if err := json.Unmarshal(out[0].Content, &parts); err != nil { + t.Fatalf("unmarshal pruned tool content: %v", err) + } + if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d", len(parts)) + } + if parts[0]["extraPart"] != "keep-part" { + t.Fatalf("expected extra part field preserved") + } + if _, ok := parts[0]["providerOptions"]; !ok { + t.Fatalf("expected part providerOptions preserved") + } + outputAny, ok := parts[0]["output"].(map[string]any) + if !ok { + t.Fatalf("expected output object") + } + if outputAny["extraOutput"] != "keep-output" { + t.Fatalf("expected output extra field preserved") + } + if _, ok := outputAny["providerOptions"]; !ok { + t.Fatalf("expected output providerOptions preserved") + } + if outputAny["type"] != "text" { + t.Fatalf("expected output.type=text, got %v", outputAny["type"]) + } + value, ok := outputAny["value"].(string) + if !ok { + t.Fatalf("expected output.value string") + } + if len(value) >= len(huge) { + t.Fatalf("expected output.value to be pruned") + } +} + +func minLen(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/conversation/flow/resolver_stream_order_test.go b/internal/conversation/flow/resolver_stream_order_test.go new file mode 100644 index 00000000..6e7646e5 --- /dev/null +++ b/internal/conversation/flow/resolver_stream_order_test.go @@ -0,0 +1,139 @@ +package flow + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/memohai/memoh/internal/conversation" + messagepkg "github.com/memohai/memoh/internal/message" +) + +type blockingMessageService struct { + persistCalled chan struct{} + persistContinue chan struct{} +} + +func (s *blockingMessageService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { + select { + case <-s.persistCalled: + default: + close(s.persistCalled) + } + <-s.persistContinue + return messagepkg.Message{}, nil +} + +func (s *blockingMessageService) List(ctx context.Context, botID string) ([]messagepkg.Message, error) { + return nil, nil +} + +func (s *blockingMessageService) ListSince(ctx context.Context, botID string, since time.Time) ([]messagepkg.Message, error) { + return nil, nil +} + +func (s *blockingMessageService) ListLatest(ctx context.Context, botID string, limit int32) ([]messagepkg.Message, error) { + return nil, nil +} + +func (s *blockingMessageService) ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]messagepkg.Message, error) { + return nil, nil +} + +func (s *blockingMessageService) DeleteByBot(ctx context.Context, botID string) error { + return nil +} + +func TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent(t *testing.T) { + t.Parallel() + + msgSvc := &blockingMessageService{ + persistCalled: make(chan struct{}), + persistContinue: make(chan struct{}), + } + + doneResp := gatewayResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("ok")}, + }, + } + doneData, err := json.Marshal(doneResp) + if err != nil { + t.Fatalf("marshal done response: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/stream" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + _, _ = w.Write([]byte("event: done\n")) + _, _ = w.Write([]byte("data: ")) + _, _ = w.Write(doneData) + _, _ = w.Write([]byte("\n\n")) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + t.Cleanup(srv.Close) + + r := &Resolver{ + messageService: msgSvc, + gatewayBaseURL: srv.URL, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + streamingClient: srv.Client(), + httpClient: srv.Client(), + } + + chunkCh := make(chan conversation.StreamChunk, 10) + req := conversation.ChatRequest{BotID: "bot-test", ChatID: "chat-test"} + payload := gatewayRequest{} + + streamDone := make(chan error, 1) + go func() { + streamDone <- r.streamChat(context.Background(), payload, req, chunkCh) + close(chunkCh) + }() + + select { + case <-msgSvc.persistCalled: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for Persist to be called") + } + + select { + case got := <-chunkCh: + t.Fatalf("done event forwarded before persistence finished: %s", string(got)) + default: + } + + close(msgSvc.persistContinue) + + select { + case err := <-streamDone: + if err != nil { + t.Fatalf("streamChat returned error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for streamChat to finish") + } + + select { + case got := <-chunkCh: + if len(got) == 0 { + t.Fatal("expected forwarded done event data") + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for forwarded done event data") + } +} diff --git a/internal/conversation/flow/resolver_test.go b/internal/conversation/flow/resolver_test.go index 85868064..1f22df2c 100644 --- a/internal/conversation/flow/resolver_test.go +++ b/internal/conversation/flow/resolver_test.go @@ -196,7 +196,7 @@ func TestPrepareGatewayAttachments_InlineAssetToBase64(t *testing.T) { BotID: "bot-1", Attachments: []conversation.ChatAttachment{ { - Type: "image", + Type: "image", ContentHash: "asset-1", }, }, @@ -243,6 +243,107 @@ func TestPrepareGatewayAttachments_DataURLFromURLFieldIsNativeInline(t *testing. } } +func TestStreamChat_AllowsLargeSSEDataLines(t *testing.T) { + const overOldScannerLimit = 3 * 1024 * 1024 + hugeDelta := strings.Repeat("a", overOldScannerLimit) + dataJSON, err := json.Marshal(map[string]any{ + "type": "text_delta", + "delta": hugeDelta, + }) + if err != nil { + t.Fatalf("failed to marshal test payload: %v", err) + } + dataStr := string(dataJSON) + parts := make([]string, 0, (len(dataStr)/8192)+1) + for i := 0; i < len(dataStr); i += 8192 { + end := i + 8192 + if end > len(dataStr) { + end = len(dataStr) + } + parts = append(parts, dataStr[i:end]) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/stream" { + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, "event: message\n") + for _, part := range parts { + _, _ = io.WriteString(w, "data:") + _, _ = io.WriteString(w, part) + _, _ = io.WriteString(w, "\n") + } + _, _ = io.WriteString(w, "\n") + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + streamingClient: srv.Client(), + logger: slog.Default(), + } + + chunkCh := make(chan conversation.StreamChunk, 1) + err = resolver.streamChat( + context.Background(), + gatewayRequest{}, + conversation.ChatRequest{}, + chunkCh, + ) + if err != nil { + t.Fatalf("streamChat returned error: %v", err) + } + + select { + case chunk := <-chunkCh: + if !bytes.Equal(chunk, dataJSON) { + t.Fatalf("unexpected reconstructed payload: got prefix %q", string(chunk[:min(len(chunk), 80)])) + } + default: + t.Fatalf("expected at least one streamed chunk") + } +} + +func TestStreamChat_RejectsOverLimitSSELine(t *testing.T) { + tooLong := strings.Repeat("x", gatewaySSEMaxLineBytes+10) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/stream" { + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, "event: message\n") + _, _ = io.WriteString(w, "data:") + _, _ = io.WriteString(w, tooLong) + _, _ = io.WriteString(w, "\n\n") + })) + defer srv.Close() + + resolver := &Resolver{ + gatewayBaseURL: srv.URL, + streamingClient: srv.Client(), + logger: slog.Default(), + } + + chunkCh := make(chan conversation.StreamChunk, 1) + err := resolver.streamChat(context.Background(), gatewayRequest{}, conversation.ChatRequest{}, chunkCh) + if err == nil { + t.Fatalf("expected streamChat to error on oversized SSE line") + } + if !strings.Contains(err.Error(), "sse line too long") { + t.Fatalf("expected line-too-long error, got: %v", err) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + func TestPrepareGatewayAttachments_PublicURLFromURLFieldIsNativePublic(t *testing.T) { resolver := &Resolver{logger: slog.Default()} req := conversation.ChatRequest{ @@ -321,7 +422,7 @@ func TestPrepareGatewayAttachments_DetectsImageMimeWhenOctetStream(t *testing.T) BotID: "bot-1", Attachments: []conversation.ChatAttachment{ { - Type: "image", + Type: "image", ContentHash: "asset-2", }, }, diff --git a/internal/mcp/providers/container/provider.go b/internal/mcp/providers/container/provider.go index fbb2427a..c65e4e55 100644 --- a/internal/mcp/providers/container/provider.go +++ b/internal/mcp/providers/container/provider.go @@ -158,7 +158,9 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - return mcpgw.BuildToolSuccessResult(map[string]any{"content": content}), nil + return mcpgw.BuildToolSuccessResult(map[string]any{ + "content": pruneToolOutputText(content, "tool result (read content)"), + }), nil case toolWrite: filePath := normalizePath(mcpgw.StringArg(arguments, "path")) @@ -238,8 +240,10 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex if result.ExitCode != 0 && strings.Contains(stderr, "no running task") { stderr = strings.TrimSpace(stderr) + "\n\nHint: Container exists but has no running task (main process exited). Start it first: POST /bots/" + botID + "/container/start or use the container start action in the UI." } + stdout := pruneToolOutputText(result.Stdout, "tool result (exec stdout)") + stderr = pruneToolOutputText(stderr, "tool result (exec stderr)") return mcpgw.BuildToolSuccessResult(map[string]any{ - "stdout": result.Stdout, + "stdout": stdout, "stderr": stderr, "exit_code": result.ExitCode, }), nil diff --git a/internal/mcp/providers/container/prune.go b/internal/mcp/providers/container/prune.go new file mode 100644 index 00000000..88996f5d --- /dev/null +++ b/internal/mcp/providers/container/prune.go @@ -0,0 +1,22 @@ +package container + +import textprune "github.com/memohai/memoh/internal/prune" + +const ( + toolOutputHeadBytes = 4 * 1024 + toolOutputTailBytes = 1 * 1024 + toolOutputHeadLines = 150 + toolOutputTailLines = 50 +) + +func pruneToolOutputText(text, label string) string { + return textprune.PruneWithEdges(text, label, textprune.Config{ + MaxBytes: textprune.DefaultMaxBytes, + MaxLines: textprune.DefaultMaxLines, + HeadBytes: toolOutputHeadBytes, + TailBytes: toolOutputTailBytes, + HeadLines: toolOutputHeadLines, + TailLines: toolOutputTailLines, + Marker: textprune.DefaultMarker, + }) +} diff --git a/internal/prune/text.go b/internal/prune/text.go new file mode 100644 index 00000000..c6eb6b36 --- /dev/null +++ b/internal/prune/text.go @@ -0,0 +1,185 @@ +package prune + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +const ( + DefaultMarker = "[memoh pruned]" + DefaultMaxBytes = 10 * 1024 + DefaultMaxLines = 250 +) + +type Config struct { + MaxBytes int + MaxLines int + HeadBytes int + TailBytes int + HeadLines int + TailLines int + Marker string +} + +func Exceeds(s string, maxBytes, maxLines int) bool { + return len(s) > maxBytes || CountLines(s) > maxLines +} + +func CountLines(s string) int { + if s == "" { + return 0 + } + return strings.Count(s, "\n") + 1 +} + +func PruneWithEdges(s, label string, cfg Config) string { + cfg = normalizeConfig(cfg) + if len(s) == 0 { + return s + } + if cfg.HeadBytes+cfg.TailBytes <= 0 || cfg.HeadLines+cfg.TailLines <= 0 { + return fitBudget(fmt.Sprintf( + "%s %s omitted (bytes=%d, lines=%d)", + cfg.Marker, + label, + len(s), + CountLines(s), + ), cfg) + } + if !Exceeds(s, cfg.MaxBytes, cfg.MaxLines) { + return s + } + head := boundedPrefix(s, minInt(cfg.HeadBytes, len(s)), cfg.HeadLines) + tail := "" + if cfg.TailBytes > 0 && cfg.TailLines > 0 { + tail = boundedSuffix(s, minInt(cfg.TailBytes, len(s)), cfg.TailLines) + } + return fitBudget(fmt.Sprintf( + "%s %s too long (bytes=%d, lines=%d), showing head/tail\n\n%s\n\n[...snip...]\n\n%s", + cfg.Marker, + label, + len(s), + CountLines(s), + head, + tail, + ), cfg) +} + +func normalizeConfig(cfg Config) Config { + if cfg.MaxBytes <= 0 { + cfg.MaxBytes = DefaultMaxBytes + } + if cfg.MaxLines <= 0 { + cfg.MaxLines = DefaultMaxLines + } + if cfg.Marker == "" { + cfg.Marker = DefaultMarker + } + if cfg.HeadBytes < 0 { + cfg.HeadBytes = 0 + } + if cfg.TailBytes < 0 { + cfg.TailBytes = 0 + } + if cfg.HeadLines < 0 { + cfg.HeadLines = 0 + } + if cfg.TailLines < 0 { + cfg.TailLines = 0 + } + return cfg +} + +func fitBudget(s string, cfg Config) string { + if !Exceeds(s, cfg.MaxBytes, cfg.MaxLines) { + return s + } + trimmed := boundedPrefix(s, cfg.MaxBytes, cfg.MaxLines) + if trimmed == "" { + return cfg.Marker + } + return trimmed +} + +func boundedPrefix(s string, maxBytes, maxLines int) string { + if len(s) == 0 || maxBytes <= 0 || maxLines <= 0 { + return "" + } + prefix := safeUTF8Prefix(s, minInt(maxBytes, len(s))) + return limitLinesPrefix(prefix, maxLines) +} + +func boundedSuffix(s string, maxBytes, maxLines int) string { + if len(s) == 0 || maxBytes <= 0 || maxLines <= 0 { + return "" + } + suffix := safeUTF8Suffix(s, minInt(maxBytes, len(s))) + return limitLinesSuffix(suffix, maxLines) +} + +func safeUTF8Prefix(s string, maxBytes int) string { + if maxBytes <= 0 || len(s) == 0 { + return "" + } + if maxBytes >= len(s) { + return s + } + cut := maxBytes + for cut > 0 && cut < len(s) && !utf8.RuneStart(s[cut]) { + cut-- + } + if cut <= 0 { + return "" + } + return s[:cut] +} + +func safeUTF8Suffix(s string, maxBytes int) string { + if maxBytes <= 0 || len(s) == 0 { + return "" + } + if maxBytes >= len(s) { + return s + } + start := len(s) - maxBytes + if start < 0 { + start = 0 + } + for start < len(s) && !utf8.RuneStart(s[start]) { + start++ + } + if start >= len(s) { + return "" + } + return s[start:] +} + +func limitLinesPrefix(s string, maxLines int) string { + if maxLines <= 0 || s == "" { + return "" + } + lines := strings.Split(s, "\n") + if len(lines) <= maxLines { + return s + } + return strings.Join(lines[:maxLines], "\n") +} + +func limitLinesSuffix(s string, maxLines int) string { + if maxLines <= 0 || s == "" { + return "" + } + lines := strings.Split(s, "\n") + if len(lines) <= maxLines { + return s + } + return strings.Join(lines[len(lines)-maxLines:], "\n") +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +}