Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
yt-project
GitHub Repository: yt-project/yt
Path: blob/main/setupext.py
925 views
1
import contextlib
2
import glob
3
import logging
4
import os
5
import platform
6
import shutil
7
import subprocess
8
import sys
9
import tempfile
10
from textwrap import dedent
11
from concurrent.futures import ThreadPoolExecutor
12
from distutils import sysconfig
13
from distutils.ccompiler import CCompiler, new_compiler
14
from distutils.sysconfig import customize_compiler
15
from subprocess import PIPE, Popen
16
from sys import platform as _platform
17
import ewah_bool_utils
18
from setuptools.command.build_ext import build_ext as _build_ext
19
from setuptools.command.sdist import sdist as _sdist
20
from setuptools.errors import CompileError, LinkError
21
import importlib.resources as importlib_resources
22
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
23
24
log = logging.getLogger("setupext")


# Target the stable ABI (abi3) only when explicitly requested through
# YT_LIMITED_API=1, on Python >= 3.11, and never on free-threaded
# (Py_GIL_DISABLED) interpreters.
USE_PY_LIMITED_API = (
    os.getenv("YT_LIMITED_API", "0") == "1"
    and sys.version_info >= (3, 11)
    and not sysconfig.get_config_var("Py_GIL_DISABLED")
)
# "<major><minor>" of the running interpreter, e.g. "311"; used in the
# "cpXY" wheel tag.
ABI3_TARGET_VERSION = "%d%d" % sys.version_info[:2]
# sys.hexversion with the micro version and serial masked out; used as the
# Py_LIMITED_API macro value.
ABI3_TARGET_HEX = hex(sys.hexversion & 0xFFFF00F0)
33
34
35
@contextlib.contextmanager
def stdchannel_redirected(stdchannel, dest_filename):
    """
    A context manager to temporarily redirect stdout or stderr

    The redirection happens at the file-descriptor level (``os.dup2``), so
    output written by child processes or C code to the channel's fd is
    redirected too.

    e.g.:

    with stdchannel_redirected(sys.stderr, os.devnull):
        if compiler.has_function('clock_gettime', libraries=['rt']):
            libraries.append('rt')

    Parameters
    ----------
    stdchannel : file object
        Stream to redirect; must expose a real ``fileno()``.
    dest_filename : str
        Path that receives the output while the context is active; it is
        opened in ``"w"`` mode (truncated).

    Code adapted from https://stackoverflow.com/a/17752455/1382869
    """
    # Initialise before ``try`` so the ``finally`` clause can tell which
    # resources were actually acquired if os.dup()/open() raise.
    oldstdchannel = None
    dest_file = None
    try:
        oldstdchannel = os.dup(stdchannel.fileno())
        dest_file = open(dest_filename, "w")
        # Flush Python-level buffers so pending output goes to the original
        # destination, not to dest_filename.
        stdchannel.flush()
        os.dup2(dest_file.fileno(), stdchannel.fileno())

        yield
    finally:
        if oldstdchannel is not None:
            # Push anything written inside the block to dest_filename before
            # restoring the original descriptor.
            stdchannel.flush()
            os.dup2(oldstdchannel, stdchannel.fileno())
            # The duplicate is only needed for the restore; close it so that
            # repeated use does not leak file descriptors.
            os.close(oldstdchannel)
        if dest_file is not None:
            dest_file.close()
60
61
62
def check_for_openmp():
    """Returns OpenMP compiler and linker flags if local setup supports
    OpenMP or [], [] otherwise

    A small C program is written to a temporary directory, then compiled,
    linked and executed. It sets two OpenMP threads and prints
    ``nthreads=<N>`` from a parallel region. OpenMP is considered usable only
    if the first line has that form and exactly ``N`` lines are printed (one
    per thread). The probe is best-effort: any failure to build, run or parse
    the test program disables OpenMP instead of aborting the build.

    Code adapted from astropy_helpers, originally written by Tom
    Robitaille and Curtis McCully.

    Returns
    -------
    tuple of (list, list)
        ``(compile_flags, link_flags)`` when OpenMP works, ``([], [])``
        otherwise.
    """
    ccompiler = new_compiler()
    customize_compiler(ccompiler)

    # Scratch directory for the test program; removed in the ``finally`` below.
    tmp_dir = tempfile.mkdtemp()
    start_dir = os.path.abspath(".")

    CCODE = dedent("""\
        #include <omp.h>
        #include <stdio.h>
        int main() {
            omp_set_num_threads(2);
            #pragma omp parallel
            printf("nthreads=%d\\n", omp_get_num_threads());
            return 0;
        }"""
    )

    # TODO: test more known compilers:
    # MinGW, AppleClang with libomp, MSVC, ICC, XL, PGI, ...
    if os.name == "nt":
        # TODO: make this work with mingw
        # AFAICS there's no easy way to get the compiler distutils
        # will be using until compilation actually happens
        compile_flags = ["-openmp"]
        link_flags = [""]
    else:
        compile_flags = ["-fopenmp"]
        link_flags = ["-fopenmp"]

    using_openmp = False
    try:
        os.chdir(tmp_dir)

        with open("test_openmp.c", "w") as f:
            f.write(CCODE)

        os.mkdir("objects")

        # Compile, link, and run test program (compiler noise on stderr is
        # discarded).
        with stdchannel_redirected(sys.stderr, os.devnull):
            ccompiler.compile(
                ["test_openmp.c"], output_dir="objects", extra_postargs=compile_flags
            )
            ccompiler.link_executable(
                glob.glob(os.path.join("objects", "*")),
                "test_openmp",
                extra_postargs=link_flags,
            )
            output = (
                subprocess.check_output("./test_openmp")
                .decode(sys.stdout.encoding or "utf-8")
                .splitlines()
            )

        # Guard against empty output: output[0] would raise IndexError.
        if output and "nthreads=" in output[0]:
            nthreads = int(output[0].strip().split("=")[1])
            if len(output) == nthreads:
                using_openmp = True
            else:
                log.warning(
                    "Unexpected number of lines from output of test "
                    "OpenMP program (output was %s)",
                    output,
                )
        else:
            log.warning(
                "Unexpected output from test OpenMP program (output was %s)", output
            )

    except (CompileError, LinkError):
        using_openmp = False
    except (OSError, ValueError, subprocess.CalledProcessError) as exc:
        # The test binary could not be run, exited non-zero, or printed an
        # unparsable thread count: treat OpenMP as unavailable.
        log.warning("Could not run test OpenMP program (%s)", exc)
        using_openmp = False
    finally:
        os.chdir(start_dir)
        shutil.rmtree(tmp_dir, ignore_errors=True)

    if using_openmp:
        log.warning("Using OpenMP to compile parallel extensions")
    else:
        log.warning(
            "Unable to compile OpenMP test program so Cython\n"
            "extensions will be compiled without parallel support"
        )

    if using_openmp:
        return compile_flags, link_flags
    else:
        return [], []
158
159
160
def check_CPP14_flag(compile_flags):
    """Return True if a small C++ test program compiles with *compile_flags*.

    The program is only compiled (not linked or run), in a temporary
    directory that is removed afterwards. Compiler stderr is discarded.

    Parameters
    ----------
    compile_flags : list of str
        Extra arguments appended to the compiler invocation.

    Returns
    -------
    bool
        True on successful compilation, False if ``CompileError`` is raised.
    """
    ccompiler = new_compiler()
    customize_compiler(ccompiler)

    tmp_dir = tempfile.mkdtemp()
    start_dir = os.path.abspath(".")

    # Note: This code requires C++14 functionalities (also required to compile yt)
    # It compiles on gcc 4.7.4 (together with the entirety of yt) with the flag "-std=gnu++0x".
    # It does not compile on gcc 4.6.4 (neither does yt).
    # (The in-class member initializer below is itself a C++11 feature.)
    CPPCODE = dedent("""\
        #include <vector>

        struct node {
            std::vector<int> vic;
            bool visited = false;
        };

        int main() {
            return 0;
        }"""
    )

    try:
        os.chdir(tmp_dir)
        with open("test_cpp14.cpp", "w") as f:
            f.write(CPPCODE)

        os.mkdir("objects")

        # Compile only: success of this step is the whole test.
        with stdchannel_redirected(sys.stderr, os.devnull):
            ccompiler.compile(
                ["test_cpp14.cpp"], output_dir="objects", extra_postargs=compile_flags
            )
        return True
    except CompileError:
        return False
    finally:
        os.chdir(start_dir)
        # Don't leave one scratch directory behind per probed flag.
        shutil.rmtree(tmp_dir, ignore_errors=True)
201
202
203
def check_CPP14_flags(possible_compile_flags):
    """Pick the first compiler flag that lets the C++ probe compile.

    Each candidate is tried on its own (wrapped in a one-element list) with
    ``check_CPP14_flag``, in order.

    Returns
    -------
    The first working candidate itself (not wrapped in a list), or ``[]``
    after logging a warning when none of them works.
    """
    for candidate in possible_compile_flags:
        if check_CPP14_flag([candidate]):
            break
    else:
        log.warning(
            "Your compiler seems to be too old to support C++14. "
            "yt may not be able to compile. Please use a newer version."
        )
        return []
    return candidate
213
214
215
def check_for_pyembree(std_libs):
    """Return the Cython sources and cythonize aliases needed for Embree.

    Parameters
    ----------
    std_libs : list of str
        Libraries to link in addition to the Embree library itself.

    Returns
    -------
    tuple of (list, dict)
        ``(embree_libs, embree_aliases)``. ``embree_libs`` holds the
        ``embree_mesh`` ``.pyx`` glob. ``embree_aliases`` maps
        ``EMBREE_INC_DIR``/``EMBREE_LIB_DIR``/``EMBREE_LIBS`` to lists of
        paths or library names. Both are empty when pyembree is not
        importable, or when the Embree location could not be determined.
    """
    embree_libs = []
    embree_aliases = {}

    try:
        importlib_resources.files("pyembree")
    except ImportError:
        return embree_libs, embree_aliases

    embree_location = read_embree_location()
    if not embree_location:
        # read_embree_location() returns False when it could not run the C++
        # compiler; os.path.abspath(False) would raise an opaque TypeError.
        log.warning(
            "Could not determine the Embree installation location; "
            "Embree-based extensions will not be built."
        )
        return embree_libs, embree_aliases

    embree_prefix = os.path.abspath(embree_location)
    embree_inc_dir = os.path.join(embree_prefix, "include")
    embree_lib_dir = os.path.join(embree_prefix, "lib")

    if _platform == "darwin":
        embree_lib_name = "embree.2"
    else:
        embree_lib_name = "embree"

    embree_aliases["EMBREE_INC_DIR"] = ["yt/utilities/lib/", embree_inc_dir]
    embree_aliases["EMBREE_LIB_DIR"] = [embree_lib_dir]
    embree_aliases["EMBREE_LIBS"] = std_libs + [embree_lib_name]
    embree_libs += ["yt/utilities/lib/embree_mesh/*.pyx"]

    if in_conda_env():
        # Also search the conda environment's include/ and lib/ directories.
        conda_basedir = os.path.dirname(os.path.dirname(sys.executable))
        embree_aliases["EMBREE_INC_DIR"].append(os.path.join(conda_basedir, "include"))
        embree_aliases["EMBREE_LIB_DIR"].append(os.path.join(conda_basedir, "lib"))

    return embree_libs, embree_aliases
244
245
246
def in_conda_env():
    """Heuristically report whether the interpreter comes from conda.

    Detection is purely string-based: ``sys.version`` is searched for a
    known Anaconda / conda-forge marker.
    """
    markers = ("Anaconda", "Continuum", "conda-forge")
    for marker in markers:
        if marker in sys.version:
            return True
    return False
248
249
250
def read_embree_location():
    """
    Attempts to locate the embree installation.

    The location is taken from, in order:

    1. the EMBREE_DIR environment variable;
    2. the contents of an ``embree.cfg`` file in the current working
       directory (presumably the root of the yt source tree when building);
    3. the default ``/usr/local``.

    A small C++ program including ``embree2/rtcore.h`` is then compiled with
    ``-I<location>/include/``, using ``$CXX`` (default ``c++``), as a
    diagnostic. If that compile fails, warnings are logged but the location
    is still returned. If embree is installed in a non-standard location and
    none of the above are set, the later extension build will not succeed.

    This is called by check_for_pyembree() once pyembree is importable.

    Returns
    -------
    str or False
        The embree prefix, or False if the C++ compiler could not be run
        (``OSError``).
    """

    rd = os.environ.get("EMBREE_DIR")
    if rd is None:
        try:
            with open("embree.cfg") as cfg:
                rd = cfg.read().strip()
        except OSError:
            rd = "/usr/local"

    fail_msg = (
        "I attempted to find Embree headers in %s. \n"
        "If this is not correct, please set your correct embree location \n"
        "using EMBREE_DIR environment variable or your embree.cfg file. \n"
        "Please see http://yt-project.org/docs/dev/visualizing/unstructured_mesh_rendering.html "
        "for more information. \n" % rd
    )

    # Get compiler invocation; resolved before ``try`` so the ``except``
    # clause below can always report it.
    compiler = os.getenv("CXX", "c++")
    compiler = compiler.split(" ")

    # Create a temporary directory
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()

    try:
        os.chdir(tmpdir)

        # Attempt to compile a test script.
        filename = r"test.cpp"
        CCODE = dedent("""\
            #include "embree2/rtcore.h"
            int main() {
                return 0;
            }"""
        )
        # Close the source before invoking the compiler so it is fully
        # written and not held open (matters for rmtree on Windows).
        with open(filename, "w") as file:
            file.write(CCODE)

        p = Popen(
            compiler + ["-I%s/include/" % rd, filename],
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
        )
        output, err = p.communicate()
        exit_code = p.returncode

        if exit_code != 0:
            log.warning(
                "Pyembree is installed, but I could not compile Embree test code."
            )
            log.warning("The error message was: ")
            # communicate() returns bytes; decode for a readable log line.
            log.warning(err.decode(errors="replace"))
            log.warning(fail_msg)

    except OSError:
        log.warning(
            "read_embree_location() could not find your C compiler. "
            "Attempted to use '%s'.",
            compiler,
        )
        return False

    finally:
        os.chdir(curdir)
        shutil.rmtree(tmpdir)

    return rd
333
334
335
def get_cpu_count():
    """Return the number of worker threads to use for parallel build steps.

    On Windows this returns 0. ``build_extensions`` treats a non-positive
    value as "build serially". Elsewhere the CPU count reported by
    ``os.cpu_count()`` is used (1 if it cannot be determined), capped by the
    optional ``MAX_BUILD_CORES`` environment variable.

    Raises
    ------
    ValueError
        If ``MAX_BUILD_CORES`` is set but is not an integer.
    """
    if platform.system() == "Windows":
        return 0

    # os.cpu_count() returns None when the count is undetermined; without
    # this fallback, int(None) / min(None, ...) would raise TypeError.
    cpu_count = os.cpu_count() or 1
    try:
        user_max_cores = int(os.getenv("MAX_BUILD_CORES", cpu_count))
    except ValueError as e:
        raise ValueError(
            "MAX_BUILD_CORES must be set to an integer. "
            + "See above for original error."
        ) from e
    max_cores = min(cpu_count, user_max_cores)
    return max_cores
349
350
351
def install_ccompiler():
    """Replace ``CCompiler.compile`` with the local ``_compile`` below.

    The replacement runs the standard ``_setup_compile`` / ``_get_cc_args``
    preparation. It then calls ``self._compile`` once for every object file
    that has an entry in the ``build`` mapping, and returns the full object
    list.

    NOTE(review): this appears to mirror the stock distutils implementation;
    the reason for patching is not visible here, so confirm before removing.
    """

    def _compile(
        self,
        sources,
        output_dir=None,
        macros=None,
        include_dirs=None,
        debug=0,
        extra_preargs=None,
        extra_postargs=None,
        depends=None,
    ):
        """Compile *sources*; installed as ``CCompiler.compile``."""
        setup = self._setup_compile(
            output_dir, macros, include_dirs, sources, depends, extra_postargs
        )
        macros, objects, extra_postargs, pp_opts, build = setup
        cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

        # Only objects present in ``build`` need (re)compiling.
        for target in (o for o in objects if o in build):
            source_file, extension = build[target]
            self._compile(target, source_file, extension, cc_args, extra_postargs, pp_opts)

        # Every object filename is returned, not just the freshly built ones.
        return objects

    CCompiler.compile = _compile
380
381
382
def get_python_include_dirs():
    """Return the include directories needed to compile against Python.h.

    Mirrors distutils.command.build_ext.build_ext.finalize_options(),
    https://github.com/python/cpython/blob/812245ecce2d8344c3748228047bab456816180a/Lib/distutils/command/build_ext.py#L148-L167
    """
    py_include = sysconfig.get_python_inc()
    plat_py_include = sysconfig.get_python_inc(plat_specific=1)

    # Inside a virtualenv its own include directory comes first (Issue 16116).
    in_venv = sys.exec_prefix != sys.base_exec_prefix
    include_dirs = [os.path.join(sys.exec_prefix, 'include')] if in_venv else []

    # Python's "system" include dirs go last so that any local include dirs
    # take precedence.
    include_dirs += py_include.split(os.path.pathsep)
    if plat_py_include != py_include:
        include_dirs += plat_py_include.split(os.path.pathsep)

    return include_dirs
405
406
407
# Preprocessor macros (name, value) added to every extension's ``define`` list
# by the build_ext command created in create_build_ext().
NUMPY_MACROS = [
    ("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION"),
    # keep in sync with runtime requirements (pyproject.toml)
    ("NPY_TARGET_VERSION", "NPY_1_21_API_VERSION"),
]
412
413
414
def create_build_ext(lib_exts, cythonize_aliases):
    """Create setuptools command classes bound to the given Cython extensions.

    Parameters
    ----------
    lib_exts : list
        Extension definitions passed to ``Cython.Build.cythonize``.
    cythonize_aliases : dict
        Alias mapping forwarded to ``cythonize(aliases=...)``.

    Returns
    -------
    tuple
        ``(build_ext, sdist, bdist_wheel)`` command classes.
    """

    class build_ext(_build_ext):
        # subclass setuptools extension builder to avoid importing cython and numpy
        # at top level in setup.py. See http://stackoverflow.com/a/21621689/1382869
        # NOTE: this is likely not necessary anymore since
        # pyproject.toml was introduced in the project

        def finalize_options(self):
            from Cython.Build import cythonize

            # Override the list of extension modules
            self.distribution.ext_modules[:] = cythonize(
                lib_exts,
                aliases=cythonize_aliases,
                compiler_directives={"language_level": 3},
                nthreads=get_cpu_count(),
            )
            _build_ext.finalize_options(self)
            # Prevent numpy from thinking it is still in its setup process
            # see http://stackoverflow.com/a/21621493/1382869
            if isinstance(__builtins__, dict):
                # sometimes this is a dict so we need to check for that
                # https://docs.python.org/3/library/builtins.html
                __builtins__["__NUMPY_SETUP__"] = False
            else:
                __builtins__.__NUMPY_SETUP__ = False
            import numpy

            self.include_dirs.append(numpy.get_include())
            self.include_dirs.append(ewah_bool_utils.get_include())

            # Work on a copy: appending to (or handing out) the shared
            # module-level NUMPY_MACROS list would leak Py_LIMITED_API into it
            # and duplicate entries if finalize_options runs more than once.
            define_macros = list(NUMPY_MACROS)
            if USE_PY_LIMITED_API:
                define_macros.append(("Py_LIMITED_API", ABI3_TARGET_HEX))
                for ext in self.extensions:
                    ext.py_limited_api = True

            if self.define is None:
                self.define = define_macros
            else:
                self.define.extend(define_macros)

        def build_extensions(self):
            self.check_extensions_list(self.extensions)

            ncpus = get_cpu_count()
            if ncpus > 0:
                with ThreadPoolExecutor(ncpus) as executor:
                    futures = [
                        executor.submit(self.build_extension, extension)
                        for extension in self.extensions
                    ]
                    # .result() re-raises any exception from a worker thread.
                    for future in futures:
                        future.result()
            else:
                super().build_extensions()

        def build_extension(self, extension):
            try:
                super().build_extension(extension)
            except CompileError as exc:
                print(f"While building '{extension.name}' following error was raised:\n {exc}")
                raise

    class sdist(_sdist):
        # subclass setuptools source distribution builder to ensure cython
        # generated C files are included in source distribution.
        # See http://stackoverflow.com/a/18418524/1382869
        def run(self):
            # Make sure the compiled Cython files in the distribution are up-to-date
            from Cython.Build import cythonize

            cythonize(
                lib_exts,
                aliases=cythonize_aliases,
                compiler_directives={"language_level": 3},
                nthreads=get_cpu_count(),
            )
            _sdist.run(self)

    class bdist_wheel(_bdist_wheel):
        def get_tag(self):
            python, abi, plat = super().get_tag()

            # Stable-ABI builds are tagged cpXY-abi3 so one wheel serves
            # later CPython versions too.
            if python.startswith("cp") and USE_PY_LIMITED_API:
                return f"cp{ABI3_TARGET_VERSION}", "abi3", plat

            return python, abi, plat

    return build_ext, sdist, bdist_wheel
504
505