mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(tool): paginated file read with safety limits for container (#119)
* feat(tool): paginated file read with safety limits for container provider * fix(tool): harden container read pagination and binary safety
This commit is contained in:
@@ -62,11 +62,27 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte
|
||||
return []mcpgw.ToolDescriptor{
|
||||
{
|
||||
Name: toolRead,
|
||||
Description: "Read file content inside the bot container.",
|
||||
Description: fmt.Sprintf("Read file content inside the bot container. Supports pagination for large files. Max %d lines / %d bytes per call.", readMaxLines, readMaxBytes),
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string", "description": fmt.Sprintf("file path (relative to %s or absolute inside container)", wd)},
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": fmt.Sprintf("File path (relative to %s or absolute inside container)", wd),
|
||||
},
|
||||
"line_offset": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed). Default: 1.",
|
||||
"minimum": 1,
|
||||
"default": 1,
|
||||
},
|
||||
"n_lines": map[string]any{
|
||||
"type": "integer",
|
||||
"description": fmt.Sprintf("Number of lines to read per call. Default: %d (the per-call maximum). Use a smaller value with line_offset for finer pagination. Max: %d.", readMaxLines, readMaxLines),
|
||||
"minimum": 1,
|
||||
"maximum": readMaxLines,
|
||||
"default": readMaxLines,
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
@@ -77,8 +93,8 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string", "description": fmt.Sprintf("file path (relative to %s or absolute inside container)", wd)},
|
||||
"content": map[string]any{"type": "string", "description": "file content"},
|
||||
"path": map[string]any{"type": "string", "description": fmt.Sprintf("File path (relative to %s or absolute inside container)", wd)},
|
||||
"content": map[string]any{"type": "string", "description": "File content"},
|
||||
},
|
||||
"required": []string{"path", "content"},
|
||||
},
|
||||
@@ -89,8 +105,8 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string", "description": fmt.Sprintf("directory path (relative to %s or absolute inside container)", wd)},
|
||||
"recursive": map[string]any{"type": "boolean", "description": "list recursively"},
|
||||
"path": map[string]any{"type": "string", "description": fmt.Sprintf("Directory path (relative to %s or absolute inside container)", wd)},
|
||||
"recursive": map[string]any{"type": "boolean", "description": "List recursively"},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
@@ -101,9 +117,9 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string", "description": fmt.Sprintf("file path (relative to %s or absolute inside container)", wd)},
|
||||
"old_text": map[string]any{"type": "string", "description": "exact text to find"},
|
||||
"new_text": map[string]any{"type": "string", "description": "replacement text"},
|
||||
"path": map[string]any{"type": "string", "description": fmt.Sprintf("File path (relative to %s or absolute inside container)", wd)},
|
||||
"old_text": map[string]any{"type": "string", "description": "Exact text to find"},
|
||||
"new_text": map[string]any{"type": "string", "description": "Replacement text"},
|
||||
},
|
||||
"required": []string{"path", "old_text", "new_text"},
|
||||
},
|
||||
@@ -162,12 +178,44 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
|
||||
if filePath == "" {
|
||||
return mcpgw.BuildToolErrorResult("path is required"), nil
|
||||
}
|
||||
content, err := ExecRead(ctx, p.execRunner, botID, p.execWorkDir, filePath)
|
||||
|
||||
// Parse optional pagination params.
|
||||
lineOffset := 1
|
||||
offset, ok, err := mcpgw.IntArg(arguments, "line_offset")
|
||||
if err != nil {
|
||||
return mcpgw.BuildToolErrorResult(fmt.Sprintf("invalid line_offset: %v", err)), nil
|
||||
}
|
||||
if ok {
|
||||
if offset < 1 {
|
||||
return mcpgw.BuildToolErrorResult("line_offset must be >= 1"), nil
|
||||
}
|
||||
lineOffset = offset
|
||||
}
|
||||
|
||||
nLines := readMaxLines
|
||||
n, ok, err := mcpgw.IntArg(arguments, "n_lines")
|
||||
if err != nil {
|
||||
return mcpgw.BuildToolErrorResult(fmt.Sprintf("invalid n_lines: %v", err)), nil
|
||||
}
|
||||
if ok {
|
||||
if n < 1 {
|
||||
return mcpgw.BuildToolErrorResult("n_lines must be >= 1"), nil
|
||||
}
|
||||
if n > readMaxLines {
|
||||
n = readMaxLines
|
||||
}
|
||||
nLines = n
|
||||
}
|
||||
|
||||
result, err := ReadFile(ctx, p.execRunner, botID, p.execWorkDir, filePath, lineOffset, nLines)
|
||||
if err != nil {
|
||||
return mcpgw.BuildToolErrorResult(err.Error()), nil
|
||||
}
|
||||
|
||||
output := FormatReadResult(result)
|
||||
|
||||
return mcpgw.BuildToolSuccessResult(map[string]any{
|
||||
"content": pruneToolOutputText(content, "tool result (read content)"),
|
||||
"content": output,
|
||||
}), nil
|
||||
|
||||
case toolWrite:
|
||||
|
||||
@@ -50,8 +50,25 @@ func TestExecutor_ListTools(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExecutor_CallTool_Read(t *testing.T) {
|
||||
callCount := 0
|
||||
runner := &fakeExecRunner{
|
||||
result: &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0},
|
||||
handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) {
|
||||
callCount++
|
||||
cmd := strings.Join(req.Command, " ")
|
||||
switch callCount {
|
||||
case 1:
|
||||
if !strings.Contains(cmd, "head -c 8192") {
|
||||
t.Errorf("expected bounded binary probe, got %q", cmd)
|
||||
}
|
||||
case 2:
|
||||
if !strings.Contains(cmd, "sed -n") {
|
||||
t.Errorf("expected sed command, got %q", cmd)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected extra call #%d: %q", callCount, cmd)
|
||||
}
|
||||
return &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0}, nil
|
||||
},
|
||||
}
|
||||
exec := NewExecutor(nil, runner, "/data")
|
||||
ctx := context.Background()
|
||||
@@ -65,13 +82,69 @@ func TestExecutor_CallTool_Read(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
content, _ := result["structuredContent"].(map[string]any)
|
||||
if content["content"] != "hello world" {
|
||||
t.Errorf("content = %v", content["content"])
|
||||
if content["content"] == "" {
|
||||
t.Errorf("content should not be empty, got %v", content["content"])
|
||||
}
|
||||
// Verify the exec command contains cat.
|
||||
cmd := strings.Join(runner.lastReq.Command, " ")
|
||||
if !strings.Contains(cmd, "cat") {
|
||||
t.Errorf("expected cat command, got %q", cmd)
|
||||
if callCount != 2 {
|
||||
t.Errorf("expected 2 exec calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_CallTool_Read_InvalidPaginationArgs(t *testing.T) {
|
||||
runner := &fakeExecRunner{
|
||||
handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) {
|
||||
t.Fatalf("unexpected exec call: %v", req.Command)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
exec := NewExecutor(nil, runner, "/data")
|
||||
ctx := context.Background()
|
||||
session := mcpgw.ToolSessionContext{BotID: "bot1"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invalid line_offset type",
|
||||
args: map[string]any{"path": "test.txt", "line_offset": "abc"},
|
||||
want: "invalid line_offset",
|
||||
},
|
||||
{
|
||||
name: "invalid n_lines type",
|
||||
args: map[string]any{"path": "test.txt", "n_lines": "abc"},
|
||||
want: "invalid n_lines",
|
||||
},
|
||||
{
|
||||
name: "line_offset below minimum",
|
||||
args: map[string]any{"path": "test.txt", "line_offset": 0},
|
||||
want: "line_offset must be >= 1",
|
||||
},
|
||||
{
|
||||
name: "n_lines below minimum",
|
||||
args: map[string]any{"path": "test.txt", "n_lines": 0},
|
||||
want: "n_lines must be >= 1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := exec.CallTool(ctx, session, "read", tt.args)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isErr, _ := result["isError"].(bool); !isErr {
|
||||
t.Fatalf("expected tool error containing %q", tt.want)
|
||||
}
|
||||
msg := ""
|
||||
if content, ok := result["content"].([]map[string]any); ok && len(content) > 0 {
|
||||
msg, _ = content[0]["text"].(string)
|
||||
}
|
||||
if !strings.Contains(msg, tt.want) {
|
||||
t.Fatalf("error = %q, want substring %q", msg, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package container
|
||||
|
||||
import textprune "github.com/memohai/memoh/internal/prune"
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
textprune "github.com/memohai/memoh/internal/prune"
|
||||
)
|
||||
|
||||
// Output pruning limits for tool results.
|
||||
const (
|
||||
toolOutputHeadBytes = 4 * 1024
|
||||
toolOutputTailBytes = 1 * 1024
|
||||
@@ -9,6 +15,19 @@ const (
|
||||
toolOutputTailLines = 50
|
||||
)
|
||||
|
||||
// Read tool limits - single conservative budget.
|
||||
// AI can paginate via line_offset/n_lines if file is larger.
|
||||
const (
|
||||
readMaxLines = 200 // Max lines per read
|
||||
readMaxBytes = 5120 // 5KB per read
|
||||
readMaxLineLength = 1000 // Max characters per line (runes)
|
||||
readHeadBytes = 3072 // 3KB head when pruning
|
||||
readTailBytes = 1024 // 1KB tail when pruning
|
||||
readHeadLines = 120 // 120 lines head when pruning
|
||||
readTailLines = 40 // 40 lines tail when pruning
|
||||
)
|
||||
|
||||
// pruneToolOutputText prunes generic tool output (exec, etc.).
|
||||
func pruneToolOutputText(text, label string) string {
|
||||
return textprune.PruneWithEdges(text, label, textprune.Config{
|
||||
MaxBytes: textprune.DefaultMaxBytes,
|
||||
@@ -20,3 +39,87 @@ func pruneToolOutputText(text, label string) string {
|
||||
Marker: textprune.DefaultMarker,
|
||||
})
|
||||
}
|
||||
|
||||
// pruneReadOutput prunes read tool output to the read-specific budget,
// delegating to textprune.PruneWithEdges with the readMaxBytes/readMaxLines
// limits and head/tail byte and line budgets defined for this package.
func pruneReadOutput(text string) string {
	return textprune.PruneWithEdges(text, "read output", textprune.Config{
		MaxBytes:  readMaxBytes,
		MaxLines:  readMaxLines,
		HeadBytes: readHeadBytes,
		TailBytes: readTailBytes,
		HeadLines: readHeadLines,
		TailLines: readTailLines,
		Marker:    textprune.DefaultMarker,
	})
}
|
||||
|
||||
// truncateLine caps a line at maxLength runes, appending "..." when it was
// cut. A maxLength of zero or less disables truncation entirely. The limit is
// counted in runes rather than bytes so multi-byte UTF-8 characters are never
// split in half.
func truncateLine(line string, maxLength int) string {
	if maxLength <= 0 || utf8.RuneCountInString(line) <= maxLength {
		return line
	}

	// range over a string yields rune start offsets, so the first offset seen
	// after maxLength runes is exactly where the cut belongs.
	seen := 0
	for idx := range line {
		if seen == maxLength {
			return line[:idx] + "..."
		}
		seen++
	}
	// Unreachable: the rune-count guard above already returned short lines.
	return line
}
|
||||
|
||||
// formatTruncatedLines formats a list of line numbers for display, collapsing consecutive numbers.
|
||||
func formatTruncatedLines(lines []int) string {
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(lines) == 1 {
|
||||
return itoa(lines[0])
|
||||
}
|
||||
if len(lines) <= 3 {
|
||||
parts := make([]string, len(lines))
|
||||
for i, n := range lines {
|
||||
parts[i] = itoa(n)
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
// For many truncated lines, show count and examples.
|
||||
return itoa(lines[0]) + ", " + itoa(lines[1]) + ", " + itoa(lines[2]) + "... (" + itoa(len(lines)) + " total)"
|
||||
}
|
||||
|
||||
// itoa converts an int to its decimal string representation. It is a local
// stand-in for strconv.Itoa that keeps this package free of a strconv import.
// Note: the final string(buf[i:]) conversion does allocate the result string;
// only the digit-extraction loop itself is allocation-free.
func itoa(n int) string {
	if n == 0 {
		return "0"
	}
	// 20 bytes is enough for the 19 digits of MinInt64 plus a leading sign.
	var buf [20]byte
	i := len(buf)
	sign := n < 0
	var u uint64
	if sign {
		// Negate via -(n+1)+1 so n == MinInt does not overflow: -MinInt is
		// not representable as int, but the arithmetic below stays in range.
		u = uint64(-(n + 1))
		u++
	} else {
		u = uint64(n)
	}
	// Emit digits least-significant first into the tail of the buffer.
	for u > 0 {
		i--
		buf[i] = byte('0' + u%10)
		u /= 10
	}
	if sign {
		i--
		buf[i] = '-'
	}
	return string(buf[i:])
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package container
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestItoa_MatchesStrconv verifies the local itoa helper against strconv.Itoa
// on zero, small, large, positive, and negative values, including the
// platform's minimum int (the overflow-prone case).
func TestItoa_MatchesStrconv(t *testing.T) {
	// Smallest int computed portably: -(max int) - 1, without importing math.
	minInt := -int(^uint(0)>>1) - 1

	tests := []int{
		0,
		1,
		-1,
		42,
		-42,
		123456789,
		-123456789,
		minInt,
	}

	for _, n := range tests {
		got := itoa(n)
		want := strconv.Itoa(n)
		if got != want {
			t.Fatalf("itoa(%d) = %q, want %q", n, got, want)
		}
	}
}
|
||||
@@ -0,0 +1,251 @@
|
||||
package container
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
mcpgw "github.com/memohai/memoh/internal/mcp"
|
||||
)
|
||||
|
||||
// ReadResult contains the result of reading a file with pagination.
type ReadResult struct {
	// Content is the formatted page: one numbered line per source line.
	Content string
	// LinesRead is the number of source lines included in Content.
	LinesRead int
	// StartLine is the 1-indexed first line that was requested.
	StartLine int
	// EndLine is the 1-indexed last line actually included; it is
	// StartLine-1 when nothing was read.
	EndLine int
	TotalLinesAvailable int // -1 if unknown
	// MaxLinesReached reports that the per-call line budget stopped the read.
	MaxLinesReached bool
	// MaxBytesReached reports that the per-call byte budget stopped the read.
	MaxBytesReached bool
	// TruncatedLineNumbers lists the line numbers whose text exceeded the
	// per-line rune limit and was cut with an ellipsis.
	TruncatedLineNumbers []int
	// EndOfFile reports that the read reached (or started past) the end of
	// the file.
	EndOfFile bool
}

// readBinaryProbeBytes is how much of the file prefix is fetched to check for
// NUL bytes (binary detection) before the paged read is issued.
const readBinaryProbeBytes = 8 * 1024
|
||||
|
||||
// ReadFile reads a file inside the container with pagination support.
// It reads from line_offset (1-indexed) for up to n_lines lines.
// Limits: max 200 lines / 5KB per call (see readMaxLines and readMaxBytes constants).
//
// Exactly two container exec calls are made: a bounded "head -c" probe for
// binary detection, then a "sed -n" range read. The file is never scanned in
// full, so the result's TotalLinesAvailable is always -1 (unknown).
func ReadFile(ctx context.Context, runner ExecRunner, botID, workDir, filePath string, lineOffset, nLines int) (*ReadResult, error) {
	// Clamp arguments defensively; the tool layer validates too, but ReadFile
	// is also callable directly.
	if lineOffset < 1 {
		lineOffset = 1
	}
	if nLines < 1 {
		nLines = readMaxLines
	}
	if nLines > readMaxLines {
		nLines = readMaxLines
	}

	// Probe only the file prefix first to avoid streaming huge binary payloads via sed.
	probeCmd := fmt.Sprintf("head -c %d %s", readBinaryProbeBytes, ShellQuote(filePath))
	probe, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{
		BotID:   botID,
		Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, probeCmd)},
		WorkDir: workDir,
	})
	if err != nil {
		return nil, err
	}
	// Non-zero exit (e.g. missing file): surface the container's stderr.
	if probe.ExitCode != 0 {
		return nil, fmt.Errorf("%s", strings.TrimSpace(probe.Stderr))
	}
	// A NUL byte in the probed prefix marks the file as binary.
	if bytes.IndexByte([]byte(probe.Stdout), 0) >= 0 {
		return nil, fmt.Errorf("file appears to be binary. Read tool only supports text files")
	}

	// Use sed to read specific line range efficiently.
	// sed -n '10,110p' file -> reads lines 10-110 (inclusive)
	endLine := lineOffset
	if nLines > 1 {
		// Guard the end-line addition against int overflow for huge offsets.
		if lineOffset > math.MaxInt-(nLines-1) {
			endLine = math.MaxInt
		} else {
			endLine = lineOffset + nLines - 1
		}
	}
	sedCmd := fmt.Sprintf("sed -n '%d,%dp' %s", lineOffset, endLine, ShellQuote(filePath))

	result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{
		BotID:   botID,
		Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, sedCmd)},
		WorkDir: workDir,
	})
	if err != nil {
		return nil, err
	}
	if result.ExitCode != 0 {
		return nil, fmt.Errorf("%s", strings.TrimSpace(result.Stderr))
	}

	// Parse the output with line truncation.
	return parseReadOutput(result.Stdout, lineOffset, nLines, -1), nil
}
|
||||
|
||||
// parseReadOutput parses command output and applies line length limits.
//
// content is the raw text returned by the range read; startLine is the
// 1-indexed number of its first line; requestedLines is how many lines were
// asked for; totalLines is the file's total line count, or -1 when unknown.
// Lines longer than readMaxLineLength runes are truncated and recorded in
// TruncatedLineNumbers, and output stops early once readMaxLines lines or
// readMaxBytes of formatted output have been accumulated.
func parseReadOutput(content string, startLine, requestedLines, totalLines int) *ReadResult {
	result := &ReadResult{
		StartLine:            startLine,
		TruncatedLineNumbers: []int{},
	}

	if content == "" {
		// Nothing read: EndLine sits just before StartLine.
		result.EndLine = startLine - 1
		// Empty result from sed means we've reached EOF (empty file or offset past end).
		result.EndOfFile = true
		result.TotalLinesAvailable = totalLines
		return result
	}

	var lines []string
	var nBytes int
	currentLine := startLine

	// Walk the content newline by newline (rather than splitting it all up
	// front) so the loop can stop as soon as a budget is exhausted.
	for i := 0; i < len(content); {
		if len(lines) >= readMaxLines {
			break
		}

		nextNewline := strings.IndexByte(content[i:], '\n')
		var line string
		if nextNewline < 0 {
			// Final line without a trailing newline.
			line = content[i:]
			i = len(content)
		} else {
			line = content[i : i+nextNewline]
			i += nextNewline + 1
		}

		// Apply max line length limit.
		wasTruncated := utf8.RuneCountInString(line) > readMaxLineLength
		truncatedLine := truncateLine(line, readMaxLineLength)
		if wasTruncated {
			result.TruncatedLineNumbers = append(result.TruncatedLineNumbers, currentLine)
		}

		// Format with line number like `cat -n`: 6-digit width, right-aligned, tab separator.
		formattedLine := fmt.Sprintf("%6d\t%s\n", currentLine, truncatedLine)

		// Check if adding this line would exceed max bytes.
		if nBytes+len(formattedLine) > readMaxBytes {
			result.MaxBytesReached = true
			break
		}

		lines = append(lines, formattedLine)
		nBytes += len(formattedLine)
		currentLine++
	}

	result.Content = strings.Join(lines, "")
	result.LinesRead = len(lines)
	result.EndLine = startLine + len(lines) - 1
	if result.EndLine < startLine {
		// No lines kept (e.g. the very first line alone blew the byte budget).
		result.EndLine = startLine - 1
	}
	result.TotalLinesAvailable = totalLines
	if result.LinesRead >= readMaxLines {
		// Reaching max lines is only meaningful when there may be more data available.
		result.MaxLinesReached = totalLines < 0 || result.EndLine < totalLines
	}

	// Determine if we reached end of file.
	if totalLines >= 0 {
		result.EndOfFile = result.EndLine >= totalLines
	} else {
		// Without total lines info, assume EOF if we got fewer lines than requested.
		result.EndOfFile = len(lines) < requestedLines && !result.MaxBytesReached
	}

	return result
}
|
||||
|
||||
// FormatReadResult formats a ReadResult into the final output string.
//
// The numbered content comes first (with a guaranteed trailing newline),
// followed by one status line each for: how much was read, which limits were
// hit, EOF or a continuation hint with the next line_offset, and which lines
// were truncated.
func FormatReadResult(r *ReadResult) string {
	var buf bytes.Buffer

	if r.Content != "" {
		buf.WriteString(r.Content)
		// Ensure trailing newline if content doesn't end with one.
		if !strings.HasSuffix(r.Content, "\n") {
			buf.WriteByte('\n')
		}
	}

	// Build status message.
	var messages []string

	if r.LinesRead == 0 {
		if r.StartLine > 1 {
			messages = append(messages, fmt.Sprintf("No lines read from file (starting from line %d).", r.StartLine))
		} else {
			messages = append(messages, "File is empty.")
		}
	} else {
		if r.StartLine == r.EndLine {
			messages = append(messages, fmt.Sprintf("Read 1 line (line %d).", r.StartLine))
		} else {
			messages = append(messages, fmt.Sprintf("Read %d lines (%d-%d).",
				r.LinesRead, r.StartLine, r.EndLine))
		}
	}

	if r.MaxLinesReached {
		messages = append(messages, fmt.Sprintf("Limit %d lines reached.", readMaxLines))
	}
	if r.MaxBytesReached {
		messages = append(messages, fmt.Sprintf("Limit %d bytes reached.", readMaxBytes))
	}
	if r.EndOfFile {
		// A limit message already explains the stop, so "End of file." is
		// only emitted when no limit was hit.
		if !r.MaxLinesReached && !r.MaxBytesReached {
			messages = append(messages, "End of file.")
		}
	} else if r.EndLine >= r.StartLine {
		nextOffset := r.EndLine
		// Avoid overflowing int when EndLine is already at the maximum.
		if r.EndLine < math.MaxInt {
			nextOffset = r.EndLine + 1
		}
		if r.TotalLinesAvailable > 0 {
			messages = append(messages, fmt.Sprintf("Total %d lines. Continue with line_offset=%d.",
				r.TotalLinesAvailable, nextOffset))
		} else {
			// Unknown total but not EOF - suggest continue anyway.
			messages = append(messages, fmt.Sprintf("Continue with line_offset=%d if more content exists.", nextOffset))
		}
	}

	if len(r.TruncatedLineNumbers) > 0 {
		messages = append(messages, fmt.Sprintf("Truncated: %s.", formatTruncatedLines(r.TruncatedLineNumbers)))
	}

	// Write status messages on separate lines for readability.
	if len(messages) > 0 {
		buf.WriteString("\n")
		for _, msg := range messages {
			buf.WriteString(msg)
			buf.WriteString("\n")
		}
	}

	return buf.String()
}
|
||||
|
||||
// ReadFileSimple reads an entire file without pagination (for backward compatibility/internal use).
// Suitable for small files only; applies pruning.
//
// Unlike ReadFile there is no binary probe and no line numbering: the file is
// cat'ed in a single exec call and the raw output pruned via pruneReadOutput.
func ReadFileSimple(ctx context.Context, runner ExecRunner, botID, workDir, filePath string) (string, error) {
	result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{
		BotID:   botID,
		Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, "cat "+ShellQuote(filePath))},
		WorkDir: workDir,
	})
	if err != nil {
		return "", err
	}
	// Non-zero exit: surface the container's stderr as the error message.
	if result.ExitCode != 0 {
		return "", fmt.Errorf("%s", strings.TrimSpace(result.Stderr))
	}
	return pruneReadOutput(result.Stdout), nil
}
|
||||
@@ -0,0 +1,169 @@
|
||||
package container
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
mcpgw "github.com/memohai/memoh/internal/mcp"
|
||||
)
|
||||
|
||||
// scriptedReadRunner is a test double for the exec runner: it records every
// exec request it receives and delegates the response to a per-test handler.
type scriptedReadRunner struct {
	// handler produces the scripted result for each request.
	handler func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error)
	// calls records all requests in order, for post-hoc assertions.
	calls []mcpgw.ExecRequest
}

// ExecWithCapture records the request and returns the handler's scripted response.
func (r *scriptedReadRunner) ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) {
	r.calls = append(r.calls, req)
	return r.handler(req)
}
|
||||
|
||||
func TestParseReadOutput_LongSingleLineIsTruncated(t *testing.T) {
|
||||
// 2MB single line without '\n' should still be readable in one page and truncated by rune limit.
|
||||
longLine := strings.Repeat("a", 2*1024*1024)
|
||||
|
||||
result := parseReadOutput(longLine, 1, readMaxLines, 1)
|
||||
|
||||
if result.LinesRead != 1 {
|
||||
t.Fatalf("LinesRead = %d, want 1", result.LinesRead)
|
||||
}
|
||||
if result.EndLine != 1 {
|
||||
t.Fatalf("EndLine = %d, want 1", result.EndLine)
|
||||
}
|
||||
if !result.EndOfFile {
|
||||
t.Fatalf("EndOfFile = false, want true")
|
||||
}
|
||||
if result.MaxBytesReached {
|
||||
t.Fatalf("MaxBytesReached = true, want false")
|
||||
}
|
||||
if len(result.TruncatedLineNumbers) != 1 || result.TruncatedLineNumbers[0] != 1 {
|
||||
t.Fatalf("TruncatedLineNumbers = %v, want [1]", result.TruncatedLineNumbers)
|
||||
}
|
||||
if !strings.Contains(result.Content, "\t"+strings.Repeat("a", readMaxLineLength)+"...\n") {
|
||||
t.Fatalf("content does not contain expected truncated output, got: %q", result.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseReadOutput_TruncationMarkerForNearThresholdLine(t *testing.T) {
|
||||
// 1001 ASCII chars: truncation happens, but output becomes 1003 chars due to "...".
|
||||
// This verifies truncation tracking doesn't rely on byte-length shrinkage.
|
||||
line := strings.Repeat("x", readMaxLineLength+1)
|
||||
|
||||
result := parseReadOutput(line, 1, readMaxLines, 1)
|
||||
|
||||
if len(result.TruncatedLineNumbers) != 1 || result.TruncatedLineNumbers[0] != 1 {
|
||||
t.Fatalf("TruncatedLineNumbers = %v, want [1]", result.TruncatedLineNumbers)
|
||||
}
|
||||
|
||||
formatted := FormatReadResult(result)
|
||||
if !strings.Contains(formatted, "Truncated: 1.") {
|
||||
t.Fatalf("formatted output missing truncation marker, got: %q", formatted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseReadOutput_EmptyContentWithoutTotalMarksEOF(t *testing.T) {
|
||||
result := parseReadOutput("", 401, readMaxLines, -1)
|
||||
|
||||
if !result.EndOfFile {
|
||||
t.Fatalf("EndOfFile = false, want true")
|
||||
}
|
||||
if result.LinesRead != 0 {
|
||||
t.Fatalf("LinesRead = %d, want 0", result.LinesRead)
|
||||
}
|
||||
|
||||
formatted := FormatReadResult(result)
|
||||
if strings.Contains(formatted, "Continue with line_offset=") {
|
||||
t.Fatalf("formatted output should not contain continuation hint, got: %q", formatted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFile_DoesNotScanWholeFileForTotalLines(t *testing.T) {
|
||||
runner := &scriptedReadRunner{}
|
||||
runner.handler = func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) {
|
||||
cmd := strings.Join(req.Command, " ")
|
||||
switch {
|
||||
case strings.Contains(cmd, "head -c 8192"):
|
||||
return &mcpgw.ExecWithCaptureResult{Stdout: "line\n", ExitCode: 0}, nil
|
||||
case strings.Contains(cmd, "sed -n"):
|
||||
return &mcpgw.ExecWithCaptureResult{Stdout: strings.Repeat("line\n", readMaxLines), ExitCode: 0}, nil
|
||||
default:
|
||||
t.Fatalf("unexpected command: %q", cmd)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err := ReadFile(context.Background(), runner, "bot-1", "/data", "test.txt", 201, 200)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result.TotalLinesAvailable != -1 {
|
||||
t.Fatalf("TotalLinesAvailable = %d, want -1", result.TotalLinesAvailable)
|
||||
}
|
||||
if result.EndOfFile {
|
||||
t.Fatalf("EndOfFile = true, want false")
|
||||
}
|
||||
if result.LinesRead != 200 {
|
||||
t.Fatalf("LinesRead = %d, want 200", result.LinesRead)
|
||||
}
|
||||
|
||||
for _, req := range runner.calls {
|
||||
cmd := strings.Join(req.Command, " ")
|
||||
if strings.Contains(cmd, "awk 'END {print NR}'") || strings.Contains(cmd, "wc -l") {
|
||||
t.Fatalf("unexpected full-file line-count command: %q", cmd)
|
||||
}
|
||||
}
|
||||
if len(runner.calls) != 2 {
|
||||
t.Fatalf("expected exactly 2 commands to be executed, got %d", len(runner.calls))
|
||||
}
|
||||
|
||||
formatted := FormatReadResult(result)
|
||||
if !strings.Contains(formatted, "Continue with line_offset=401 if more content exists.") {
|
||||
t.Fatalf("formatted output missing continuation hint, got: %q", formatted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFile_BinaryContentReturnsError(t *testing.T) {
|
||||
runner := &scriptedReadRunner{}
|
||||
runner.handler = func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) {
|
||||
cmd := strings.Join(req.Command, " ")
|
||||
if strings.Contains(cmd, "head -c 8192") {
|
||||
return &mcpgw.ExecWithCaptureResult{
|
||||
Stdout: string([]byte{'a', 0, 'b'}),
|
||||
ExitCode: 0,
|
||||
}, nil
|
||||
}
|
||||
t.Fatalf("unexpected command: %q", cmd)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, err := ReadFile(context.Background(), runner, "bot-1", "/data", "test.txt", 1, 10)
|
||||
if err == nil {
|
||||
t.Fatalf("expected binary-file error, got nil result=%v", result)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Read tool only supports text files") {
|
||||
t.Fatalf("error = %q, want binary-file message", err.Error())
|
||||
}
|
||||
if len(runner.calls) != 1 {
|
||||
t.Fatalf("expected binary detection to stop before sed, got %d calls", len(runner.calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatReadResult_ContinuationHintWhenMaxLinesReached(t *testing.T) {
|
||||
content := strings.Repeat("line\n", readMaxLines)
|
||||
result := parseReadOutput(content, 1, readMaxLines, -1)
|
||||
if !result.MaxLinesReached {
|
||||
t.Fatalf("MaxLinesReached = false, want true")
|
||||
}
|
||||
if result.EndOfFile {
|
||||
t.Fatalf("EndOfFile = true, want false")
|
||||
}
|
||||
|
||||
formatted := FormatReadResult(result)
|
||||
if !strings.Contains(formatted, "Limit 200 lines reached.\nContinue with line_offset=201 if more content exists.") {
|
||||
t.Fatalf("formatted output missing continuation after limit, got: %q", formatted)
|
||||
}
|
||||
if strings.Contains(formatted, "Limit 200 lines reached. Continue with line_offset=201 if more content exists.") {
|
||||
t.Fatalf("status messages should be on separate lines, got: %q", formatted)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user