import contextlib
import glob
import logging
import os
import platform
import shutil
import subprocess
import sys
import tempfile
from textwrap import dedent
from concurrent.futures import ThreadPoolExecutor
from distutils import sysconfig
from distutils.ccompiler import CCompiler, new_compiler
from distutils.sysconfig import customize_compiler
from subprocess import PIPE, Popen
from sys import platform as _platform
import ewah_bool_utils
from setuptools.command.build_ext import build_ext as _build_ext
from setuptools.command.sdist import sdist as _sdist
from setuptools.errors import CompileError, LinkError
import importlib.resources as importlib_resources
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
# Logger used by every helper in this module.
log = logging.getLogger("setupext")

# Limited-API builds are opt-in (YT_LIMITED_API=1), need Python >= 3.11 and are
# switched off when the interpreter reports Py_GIL_DISABLED.  The environment
# check comes first so the later conditions are only evaluated on opt-in.
USE_PY_LIMITED_API = not (
    os.environ.get("YT_LIMITED_API", "0") != "1"
    or sys.version_info < (3, 11)
    or sysconfig.get_config_var("Py_GIL_DISABLED")
)

# "<major><minor>" of the running interpreter, e.g. "311".
ABI3_TARGET_VERSION = "{}{}".format(*sys.version_info[:2])

# sys.hexversion with the micro-version and serial fields masked out.
ABI3_TARGET_HEX = f"{sys.hexversion & 0xFFFF00F0:#x}"
@contextlib.contextmanager
def stdchannel_redirected(stdchannel, dest_filename):
    """
    A context manager to temporarily redirect stdout or stderr

    The redirection happens at the file-descriptor level, so it also affects
    output written by child processes that inherit the descriptor.

    e.g.:

    with stdchannel_redirected(sys.stderr, os.devnull):
        if compiler.has_function('clock_gettime', libraries=['rt']):
            libraries.append('rt')

    Code adapted from https://stackoverflow.com/a/17752455/1382869

    Parameters
    ----------
    stdchannel
        File-like object exposing ``fileno()`` (e.g. ``sys.stderr``).
    dest_filename
        Path of the file that receives the output while the context is active.
    """
    # Bind both names up front so the cleanup below can tell which resources
    # were actually acquired if os.dup() or open() raises.
    oldstdchannel = None
    dest_file = None
    try:
        oldstdchannel = os.dup(stdchannel.fileno())
        dest_file = open(dest_filename, "w")
        os.dup2(dest_file.fileno(), stdchannel.fileno())
        yield
    finally:
        try:
            if oldstdchannel is not None:
                os.dup2(oldstdchannel, stdchannel.fileno())
                # The saved duplicate is no longer needed once restored;
                # closing it avoids leaking one descriptor per use.
                os.close(oldstdchannel)
        finally:
            if dest_file is not None:
                dest_file.close()
def check_for_openmp():
    """Returns OpenMP compiler and linker flags if local setup supports
    OpenMP or [], [] otherwise

    A small OpenMP program is compiled, linked and run inside a temporary
    directory; OpenMP is considered usable only when the program prints one
    ``nthreads=`` line per thread it reports.

    Code adapted from astropy_helpers, originally written by Tom
    Robitaille and Curtis McCully.

    Returns
    -------
    tuple of (list, list)
        ``(compile_flags, link_flags)``, or ``([], [])`` when the probe fails.
    """
    ccompiler = new_compiler()
    customize_compiler(ccompiler)

    tmp_dir = tempfile.mkdtemp()
    start_dir = os.path.abspath(".")

    CCODE = dedent("""\
        #include <omp.h>
        #include <stdio.h>
        int main() {
            omp_set_num_threads(2);
            #pragma omp parallel
            printf("nthreads=%d\\n", omp_get_num_threads());
            return 0;
        }"""
    )

    if os.name == "nt":
        compile_flags = ["-openmp"]
        link_flags = [""]
    else:
        compile_flags = ["-fopenmp"]
        link_flags = ["-fopenmp"]

    using_openmp = False
    try:
        os.chdir(tmp_dir)

        with open("test_openmp.c", "w") as f:
            f.write(CCODE)

        os.mkdir("objects")

        # Compile, link, and run test program; compiler chatter is discarded.
        with stdchannel_redirected(sys.stderr, os.devnull):
            ccompiler.compile(
                ["test_openmp.c"], output_dir="objects", extra_postargs=compile_flags
            )
            ccompiler.link_executable(
                glob.glob(os.path.join("objects", "*")),
                "test_openmp",
                extra_postargs=link_flags,
            )
            output = (
                subprocess.check_output("./test_openmp")
                .decode(sys.stdout.encoding or "utf-8")
                .splitlines()
            )

        # `output` may be empty if the program printed nothing at all.
        if output and "nthreads=" in output[0]:
            nthreads = int(output[0].strip().split("=")[1])
            if len(output) == nthreads:
                using_openmp = True
            else:
                log.warning(
                    "Unexpected number of lines from output of test "
                    "OpenMP program (output was %s)",
                    output,
                )
        else:
            log.warning(
                "Unexpected output from test OpenMP program (output was %s)", output
            )
    except (CompileError, LinkError):
        using_openmp = False
    except (OSError, subprocess.CalledProcessError) as exc:
        # The probe is best-effort: a test binary that cannot be executed
        # (or exits non-zero) means "no OpenMP", not a failed build.
        log.warning("Unable to run OpenMP test program (%s)", exc)
        using_openmp = False
    finally:
        os.chdir(start_dir)
        # Do not leave the scratch directory behind.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    if using_openmp:
        log.warning("Using OpenMP to compile parallel extensions")
        return compile_flags, link_flags

    log.warning(
        "Unable to compile OpenMP test program so Cython\n"
        "extensions will be compiled without parallel support"
    )
    return [], []
def check_CPP14_flag(compile_flags):
    """Return True if a small C++ test file compiles with ``compile_flags``.

    The test source uses a default member initializer and is compiled (not
    linked) inside a temporary directory that is removed afterwards.

    Parameters
    ----------
    compile_flags : list of str
        Extra arguments appended to the compiler command line.

    Returns
    -------
    bool
        True when compilation succeeds, False on ``CompileError``.
    """
    ccompiler = new_compiler()
    customize_compiler(ccompiler)

    tmp_dir = tempfile.mkdtemp()
    start_dir = os.path.abspath(".")

    CPPCODE = dedent("""\
        #include <vector>
        struct node {
            std::vector<int> vic;
            bool visited = false;
        };
        int main() {
            return 0;
        }"""
    )

    os.chdir(tmp_dir)
    try:
        with open("test_cpp14.cpp", "w") as f:
            f.write(CPPCODE)

        os.mkdir("objects")

        # Compile only; compiler diagnostics are discarded.
        with stdchannel_redirected(sys.stderr, os.devnull):
            ccompiler.compile(
                ["test_cpp14.cpp"], output_dir="objects", extra_postargs=compile_flags
            )
        return True
    except CompileError:
        return False
    finally:
        os.chdir(start_dir)
        # Do not leave the scratch directory behind.
        shutil.rmtree(tmp_dir, ignore_errors=True)
def check_CPP14_flags(possible_compile_flags):
    """Return the first flag accepted by ``check_CPP14_flag``.

    Candidates are probed lazily, in order, each as a single-element flag
    list. When none is accepted a warning is logged and ``[]`` is returned.
    """
    accepted = (
        candidate
        for candidate in possible_compile_flags
        if check_CPP14_flag([candidate])
    )
    # Taking the first item from the generator stops probing at the first hit.
    for candidate in accepted:
        return candidate

    log.warning(
        "Your compiler seems to be too old to support C++14. "
        "yt may not be able to compile. Please use a newer version."
    )
    return []
def check_for_pyembree(std_libs):
    """Collect the Cython sources and aliases needed for the Embree extensions.

    Parameters
    ----------
    std_libs : list of str
        Libraries to link in addition to the embree library.

    Returns
    -------
    tuple of (list, dict)
        ``(embree_libs, embree_aliases)``. Both are empty when ``pyembree``
        cannot be imported or when ``read_embree_location()`` reports failure.
    """
    embree_libs = []
    embree_aliases = {}

    try:
        importlib_resources.files("pyembree")
    except ImportError:
        return embree_libs, embree_aliases

    embree_location = read_embree_location()
    if embree_location is False:
        # read_embree_location() returns False (after logging a warning) when
        # it hits an OSError; passing that to os.path.abspath() would raise
        # TypeError, so build without the embree extensions instead.
        return embree_libs, embree_aliases

    embree_prefix = os.path.abspath(embree_location)
    embree_inc_dir = os.path.join(embree_prefix, "include")
    embree_lib_dir = os.path.join(embree_prefix, "lib")

    if _platform == "darwin":
        embree_lib_name = "embree.2"
    else:
        embree_lib_name = "embree"

    embree_aliases["EMBREE_INC_DIR"] = ["yt/utilities/lib/", embree_inc_dir]
    embree_aliases["EMBREE_LIB_DIR"] = [embree_lib_dir]
    embree_aliases["EMBREE_LIBS"] = std_libs + [embree_lib_name]
    embree_libs += ["yt/utilities/lib/embree_mesh/*.pyx"]

    if in_conda_env():
        # Also search the include/ and lib/ directories two levels above the
        # interpreter executable.
        conda_basedir = os.path.dirname(os.path.dirname(sys.executable))
        embree_aliases["EMBREE_INC_DIR"].append(os.path.join(conda_basedir, "include"))
        embree_aliases["EMBREE_LIB_DIR"].append(os.path.join(conda_basedir, "lib"))

    return embree_libs, embree_aliases
def in_conda_env():
    """Return True if ``sys.version`` mentions a conda-style distribution."""
    for marker in ("Anaconda", "Continuum", "conda-forge"):
        if marker in sys.version:
            return True
    return False
def read_embree_location():
    """
    Attempts to locate the embree installation. First, we check for an
    EMBREE_DIR environment variable. If one is not defined, we look for
    an embree.cfg file in the current working directory. Finally, if that
    is not present, we default to /usr/local. If embree is installed in a
    non-standard location and none of the above are set, the compile will
    not succeed.

    The candidate location is checked by compiling a tiny program that
    includes ``embree2/rtcore.h``; a failed compile only produces warnings.

    Returns
    -------
    str or False
        The embree prefix as configured (not made absolute), or False when an
        ``OSError`` occurs while running the check (e.g. the compiler cannot
        be launched).
    """
    rd = os.environ.get("EMBREE_DIR")
    if rd is None:
        try:
            with open("embree.cfg") as cfg:
                rd = cfg.read().strip()
        except IOError:
            rd = "/usr/local"

    fail_msg = (
        "I attempted to find Embree headers in %s. \n"
        "If this is not correct, please set your correct embree location \n"
        "using EMBREE_DIR environment variable or your embree.cfg file. \n"
        "Please see http://yt-project.org/docs/dev/visualizing/unstructured_mesh_rendering.html "
        "for more information. \n" % rd
    )

    # Resolved before the try block so the OSError handler can always name it.
    compiler = os.getenv("CXX", "c++").split()

    # Make the include path absolute *before* changing directory; a relative
    # prefix would otherwise be looked up inside the temporary directory.
    include_dir = os.path.join(os.path.abspath(rd), "include")

    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()

    try:
        os.chdir(tmpdir)

        # Attempt to compile a test script.
        filename = r"test.cpp"
        CCODE = dedent("""\
            #include "embree2/rtcore.h"
            int main() {
                return 0;
            }"""
        )
        with open(filename, "wt", 1) as source_file:
            source_file.write(CCODE)

        p = Popen(
            compiler + ["-I%s/" % include_dir, filename],
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
        )
        output, err = p.communicate()
        exit_code = p.returncode

        if exit_code != 0:
            log.warning(
                "Pyembree is installed, but I could not compile Embree test code."
            )
            log.warning("The error message was: ")
            log.warning(err.decode(errors="replace"))
            log.warning(fail_msg)

    except OSError:
        log.warning(
            "read_embree_location() could not find your C compiler. "
            "Attempted to use '%s'.",
            compiler,
        )
        return False

    finally:
        os.chdir(curdir)
        shutil.rmtree(tmpdir, ignore_errors=True)

    return rd
def get_cpu_count():
    """Return the number of parallel build workers to use.

    The count is ``os.cpu_count()`` capped by the ``MAX_BUILD_CORES``
    environment variable when that is set. On Windows 0 is always returned.

    Returns
    -------
    int
        A non-negative worker count.

    Raises
    ------
    ValueError
        If ``MAX_BUILD_CORES`` is set but is not an integer.
    """
    if platform.system() == "Windows":
        return 0

    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1

    requested = os.getenv("MAX_BUILD_CORES")
    if requested is None:
        return cpu_count

    try:
        user_max_cores = int(requested)
    except ValueError as e:
        raise ValueError(
            "MAX_BUILD_CORES must be set to an integer. "
            + "See above for original error."
        ) from e

    # Never hand a negative worker count to callers.
    return max(0, min(cpu_count, user_max_cores))
def install_ccompiler():
    """Replace ``CCompiler.compile`` with the ``_compile`` function below."""

    def _compile(
        self,
        sources,
        output_dir=None,
        macros=None,
        include_dirs=None,
        debug=0,
        extra_preargs=None,
        extra_postargs=None,
        depends=None,
    ):
        """Function to monkey-patch distutils.ccompiler.CCompiler"""
        setup = self._setup_compile(
            output_dir, macros, include_dirs, sources, depends, extra_postargs
        )
        macros, objects, extra_postargs, pp_opts, build = setup
        cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

        for target in objects:
            # Objects without an entry in the build map are skipped.
            if target not in build:
                continue
            source, extension = build[target]
            self._compile(target, source, extension, cc_args, extra_postargs, pp_opts)

        # Return *all* object filenames, not just the ones we just built.
        return objects

    CCompiler.compile = _compile
def get_python_include_dirs():
    """Extracted from distutils.command.build_ext.build_ext.finalize_options(),
    https://github.com/python/cpython/blob/812245ecce2d8344c3748228047bab456816180a/Lib/distutils/command/build_ext.py#L148-L167

    Returns the list of Python header directories: the ``include`` directory
    under ``sys.exec_prefix`` when it differs from ``sys.base_exec_prefix``,
    then the generic include path(s), then the platform-specific one(s) when
    they differ from the generic ones.
    """
    generic_inc = sysconfig.get_python_inc()
    platform_inc = sysconfig.get_python_inc(plat_specific=1)

    if sys.exec_prefix != sys.base_exec_prefix:
        prefix_dirs = [os.path.join(sys.exec_prefix, 'include')]
    else:
        prefix_dirs = []

    if platform_inc != generic_inc:
        platform_dirs = platform_inc.split(os.path.pathsep)
    else:
        platform_dirs = []

    return [*prefix_dirs, *generic_inc.split(os.path.pathsep), *platform_dirs]
# (name, value) macro pairs that build_ext.finalize_options() (see
# create_build_ext) adds to the command's ``define`` list.
# NOTE(review): the names look like NumPy C-API version selectors; confirm
# their exact meaning against the NumPy documentation before changing them.
NUMPY_MACROS = [
    ("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION"),
    ("NPY_TARGET_VERSION", "NPY_1_21_API_VERSION"),
]
def create_build_ext(lib_exts, cythonize_aliases):
    """Create the setuptools command classes used to build and package yt.

    Parameters
    ----------
    lib_exts
        Extension specifications passed to ``Cython.Build.cythonize``.
    cythonize_aliases
        Passed to ``cythonize`` as its ``aliases`` argument.

    Returns
    -------
    tuple
        The ``(build_ext, sdist, bdist_wheel)`` command classes.
    """

    class build_ext(_build_ext):
        """build_ext that cythonizes lazily and builds extensions in threads."""

        # subclass setuptools extension builder to avoid importing cython and numpy
        # at top level in setup.py.
        def finalize_options(self):
            from Cython.Build import cythonize

            self.distribution.ext_modules[:] = cythonize(
                lib_exts,
                aliases=cythonize_aliases,
                compiler_directives={"language_level": 3},
                nthreads=get_cpu_count(),
            )
            _build_ext.finalize_options(self)
            # NOTE(review): presumably a legacy guard read by numpy at import
            # time; confirm it is still needed.
            if isinstance(__builtins__, dict):
                __builtins__["__NUMPY_SETUP__"] = False
            else:
                __builtins__.__NUMPY_SETUP__ = False
            import numpy

            self.include_dirs.append(numpy.get_include())
            self.include_dirs.append(ewah_bool_utils.get_include())

            # Work on a copy: appending to (or storing) NUMPY_MACROS itself
            # would mutate the module-level list and accumulate duplicate
            # entries every time finalize_options() runs.
            define_macros = list(NUMPY_MACROS)

            if USE_PY_LIMITED_API:
                define_macros.append(("Py_LIMITED_API", ABI3_TARGET_HEX))
                for ext in self.extensions:
                    ext.py_limited_api = True

            if self.define is None:
                self.define = define_macros
            else:
                self.define.extend(define_macros)

        def build_extensions(self):
            self.check_extensions_list(self.extensions)

            ncpus = get_cpu_count()
            if ncpus > 0:
                with ThreadPoolExecutor(ncpus) as executor:
                    results = {
                        executor.submit(self.build_extension, extension): extension
                        for extension in self.extensions
                    }
                # Calling result() re-raises any exception from a worker.
                for result in results:
                    result.result()
            else:
                # get_cpu_count() returned 0: build serially.
                super().build_extensions()

        def build_extension(self, extension):
            try:
                super().build_extension(extension)
            except CompileError as exc:
                print(f"While building '{extension.name}' following error was raised:\n {exc}")
                raise

    class sdist(_sdist):
        """sdist that runs cythonize before creating the source distribution."""

        # subclass setuptools source distribution builder to ensure cython
        # generated C files are included in source distribution.
        def run(self):
            # Make sure the compiled Cython files in the distribution are up-to-date
            from Cython.Build import cythonize

            cythonize(
                lib_exts,
                aliases=cythonize_aliases,
                compiler_directives={"language_level": 3},
                nthreads=get_cpu_count(),
            )
            _sdist.run(self)

    class bdist_wheel(_bdist_wheel):
        """bdist_wheel that tags CPython limited-API wheels as ``abi3``."""

        def get_tag(self):
            python, abi, plat = super().get_tag()

            if python.startswith("cp") and USE_PY_LIMITED_API:
                return f"cp{ABI3_TARGET_VERSION}", "abi3", plat

            return python, abi, plat

    return build_ext, sdist, bdist_wheel