package main import ( "bufio" "bytes" "context" "fmt" "io" "io/fs" "os" "os/exec" "path/filepath" "strings" "time" "unicode/utf8" pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) const ( readMaxLines = 200 readMaxBytes = 5120 readMaxLineLen = 1000 binaryProbeBytes = 8 * 1024 rawChunkSize = 64 * 1024 defaultWorkDir = "/data" defaultTimeout = 30 ) type containerServer struct { pb.UnimplementedContainerServiceServer } func (s *containerServer) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) f, err := os.Open(path) if err != nil { return nil, status.Errorf(codes.NotFound, "open: %v", err) } defer f.Close() probe := make([]byte, binaryProbeBytes) n, _ := f.Read(probe) if bytes.IndexByte(probe[:n], 0) >= 0 { return &pb.ReadFileResponse{Binary: true}, nil } if _, err := f.Seek(0, io.SeekStart); err != nil { return nil, status.Errorf(codes.Internal, "seek: %v", err) } lineOffset := int(req.GetLineOffset()) if lineOffset < 1 { lineOffset = 1 } nLines := int(req.GetNLines()) if nLines < 1 || nLines > readMaxLines { nLines = readMaxLines } scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) currentLine := 0 totalLines := 0 var out strings.Builder linesRead := 0 bytesWritten := 0 for scanner.Scan() { currentLine++ totalLines = currentLine if currentLine < lineOffset { continue } if linesRead >= nLines { continue // keep scanning to count total lines } line := scanner.Text() if utf8.RuneCountInString(line) > readMaxLineLen { line = truncateRunes(line, readMaxLineLen) + "..." } formatted := fmt.Sprintf("%6d\t%s\n", currentLine, line) if bytesWritten+len(formatted) > readMaxBytes { break } out.WriteString(formatted) bytesWritten += len(formatted) linesRead++ } // Drain remaining lines for total count. for scanner.Scan() { totalLines++ } return &pb.ReadFileResponse{ Content: out.String(), TotalLines: int32(totalLines), }, nil } func (s *containerServer) WriteFile(_ context.Context, req *pb.WriteFileRequest) (*pb.WriteFileResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return nil, status.Errorf(codes.Internal, "mkdir: %v", err) } if err := os.WriteFile(path, req.GetContent(), 0o644); err != nil { return nil, status.Errorf(codes.Internal, "write: %v", err) } return &pb.WriteFileResponse{}, nil } func (s *containerServer) ListDir(_ context.Context, req *pb.ListDirRequest) (*pb.ListDirResponse, error) { dir := req.GetPath() if dir == "" { dir = "." } dir = resolvePath(dir) var entries []*pb.FileEntry if req.GetRecursive() { err := filepath.WalkDir(dir, func(p string, d fs.DirEntry, err error) error { if err != nil { return nil // skip errors } rel, _ := filepath.Rel(dir, p) if rel == "." { return nil } entry, _ := buildFileEntry(rel, p, d) if entry != nil { entries = append(entries, entry) } return nil }) if err != nil { return nil, status.Errorf(codes.NotFound, "walk: %v", err) } } else { dirEntries, err := os.ReadDir(dir) if err != nil { return nil, status.Errorf(codes.NotFound, "readdir: %v", err) } for _, d := range dirEntries { entry, _ := buildFileEntry(d.Name(), filepath.Join(dir, d.Name()), d) if entry != nil { entries = append(entries, entry) } } } return &pb.ListDirResponse{Entries: entries}, nil } func (s *containerServer) Exec(stream pb.ContainerService_ExecServer) error { // Receive first message to get command details firstMsg, err := stream.Recv() if err != nil { return status.Error(codes.InvalidArgument, "failed to receive exec config") } command := firstMsg.GetCommand() if command == "" { return status.Error(codes.InvalidArgument, "command is required") } workDir := firstMsg.GetWorkDir() if workDir == "" { workDir = defaultWorkDir } timeout := int(firstMsg.GetTimeoutSeconds()) if timeout <= 0 { timeout = defaultTimeout } ctx, cancel := context.WithTimeout(stream.Context(), time.Duration(timeout)*time.Second) defer cancel() cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) cmd.Dir = workDir if len(firstMsg.GetEnv()) > 0 { cmd.Env = append(os.Environ(), firstMsg.GetEnv()...) } // Setup stdin pipe for bidirectional streaming stdinPipe, err := cmd.StdinPipe() if err != nil { return status.Errorf(codes.Internal, "stdin pipe: %v", err) } stdoutPipe, err := cmd.StdoutPipe() if err != nil { return status.Errorf(codes.Internal, "stdout pipe: %v", err) } stderrPipe, err := cmd.StderrPipe() if err != nil { return status.Errorf(codes.Internal, "stderr pipe: %v", err) } if err := cmd.Start(); err != nil { return status.Errorf(codes.Internal, "start: %v", err) } // Handle stdin from stream go func() { for { msg, err := stream.Recv() if err != nil { _ = stdinPipe.Close() return } if data := msg.GetStdinData(); len(data) > 0 { _, _ = stdinPipe.Write(data) } } }() // Stream stdout/stderr to client done := make(chan struct{}) go func() { defer close(done) streamPipe(stream, stdoutPipe, pb.ExecOutput_STDOUT) }() streamPipe(stream, stderrPipe, pb.ExecOutput_STDERR) <-done exitCode := int32(0) if err := cmd.Wait(); err != nil { if exitErr, ok := err.(*exec.ExitError); ok { exitCode = int32(exitErr.ExitCode()) } else { exitCode = -1 } } return stream.Send(&pb.ExecOutput{ Stream: pb.ExecOutput_EXIT, ExitCode: exitCode, }) } func (s *containerServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { path := req.GetPath() if path == "" { return status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) f, err := os.Open(path) if err != nil { return status.Errorf(codes.NotFound, "open: %v", err) } defer f.Close() buf := make([]byte, rawChunkSize) for { n, err := f.Read(buf) if n > 0 { if sendErr := stream.Send(&pb.DataChunk{Data: buf[:n]}); sendErr != nil { return sendErr } } if err == io.EOF { break } if err != nil { return status.Errorf(codes.Internal, "read: %v", err) } } return nil } func (s *containerServer) WriteRaw(stream pb.ContainerService_WriteRawServer) error { var f *os.File var written int64 for { chunk, err := stream.Recv() if err == io.EOF { break } if err != nil { return err } if f == nil { path := chunk.GetPath() if path == "" { return status.Error(codes.InvalidArgument, "first chunk must include path") } path = resolvePath(path) if mkErr := os.MkdirAll(filepath.Dir(path), 0o755); mkErr != nil { return status.Errorf(codes.Internal, "mkdir: %v", mkErr) } f, err = os.Create(path) if err != nil { return status.Errorf(codes.Internal, "create: %v", err) } defer f.Close() } if len(chunk.GetData()) > 0 { n, err := f.Write(chunk.GetData()) written += int64(n) if err != nil { return status.Errorf(codes.Internal, "write: %v", err) } } } return stream.SendAndClose(&pb.WriteRawResponse{BytesWritten: written}) } func (s *containerServer) DeleteFile(_ context.Context, req *pb.DeleteFileRequest) (*pb.DeleteFileResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) var err error if req.GetRecursive() { err = os.RemoveAll(path) } else { err = os.Remove(path) } if err != nil && !os.IsNotExist(err) { return nil, status.Errorf(codes.Internal, "delete: %v", err) } return &pb.DeleteFileResponse{}, nil } func (s *containerServer) Stat(_ context.Context, req *pb.StatRequest) (*pb.StatResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) info, err := os.Stat(path) if err != nil { if os.IsNotExist(err) { return nil, status.Error(codes.NotFound, "not found") } return nil, status.Errorf(codes.Internal, "stat: %v", err) } return &pb.StatResponse{ Entry: &pb.FileEntry{ Path: filepath.Base(path), IsDir: info.IsDir(), Size: info.Size(), Mode: info.Mode().String(), ModTime: info.ModTime().Format(time.RFC3339), }, }, nil } func (s *containerServer) Mkdir(_ context.Context, req *pb.MkdirRequest) (*pb.MkdirResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) if err := os.MkdirAll(path, 0o755); err != nil { return nil, status.Errorf(codes.Internal, "mkdir: %v", err) } return &pb.MkdirResponse{}, nil } func (s *containerServer) Rename(_ context.Context, req *pb.RenameRequest) (*pb.RenameResponse, error) { oldPath := req.GetOldPath() newPath := req.GetNewPath() if oldPath == "" || newPath == "" { return nil, status.Error(codes.InvalidArgument, "old_path and new_path are required") } oldPath = resolvePath(oldPath) newPath = resolvePath(newPath) if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil { return nil, status.Errorf(codes.Internal, "mkdir parent: %v", err) } if err := os.Rename(oldPath, newPath); err != nil { return nil, status.Errorf(codes.Internal, "rename: %v", err) } return &pb.RenameResponse{}, nil } func streamPipe(stream pb.ContainerService_ExecServer, r io.Reader, st pb.ExecOutput_Stream) { buf := make([]byte, 4096) for { n, err := r.Read(buf) if n > 0 { _ = stream.Send(&pb.ExecOutput{ Stream: st, Data: buf[:n], }) } if err != nil { break } } } func buildFileEntry(name, fullPath string, d fs.DirEntry) (*pb.FileEntry, error) { info, err := d.Info() if err != nil { return nil, err } return &pb.FileEntry{ Path: name, IsDir: d.IsDir(), Size: info.Size(), Mode: info.Mode().String(), ModTime: info.ModTime().Format(time.RFC3339), }, nil } func resolvePath(path string) string { if filepath.IsAbs(path) { return filepath.Clean(path) } return filepath.Join(defaultWorkDir, path) } func truncateRunes(s string, max int) string { pos := 0 count := 0 for pos < len(s) && count < max { _, size := utf8.DecodeRuneInString(s[pos:]) pos += size count++ } return s[:pos] }