Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
yt-project
GitHub Repository: yt-project/yt
Path: blob/main/setupext.py
1397 views
1
import contextlib
2
import glob
3
import logging
4
import os
5
import platform
6
import shutil
7
import subprocess
8
import sys
9
import tempfile
10
from textwrap import dedent
11
from concurrent.futures import ThreadPoolExecutor
12
from distutils import sysconfig
13
from distutils.ccompiler import CCompiler, new_compiler
14
from distutils.sysconfig import customize_compiler
15
from subprocess import PIPE, Popen
16
from sys import platform as _platform
17
import ewah_bool_utils
18
from setuptools.command.build_ext import build_ext as _build_ext
19
from setuptools.command.sdist import sdist as _sdist
20
from setuptools.errors import CompileError, LinkError
21
import importlib.resources as importlib_resources
22
23
log = logging.getLogger("setupext")


# Build extensions against the CPython stable ABI ("abi3") only when the user
# opts in with YT_LIMITED_API=1, the interpreter is 3.11 or newer, and the
# interpreter is not a GIL-disabled build. The ``and`` chain short-circuits, so
# the config-var lookup only happens when the first two conditions hold.
USE_PY_LIMITED_API = bool(
    os.environ.get("YT_LIMITED_API", "0") == "1"
    and sys.version_info[:2] >= (3, 11)
    and not sysconfig.get_config_var("Py_GIL_DISABLED")
)
# "<major><minor>" of the running interpreter (e.g. "311"), used for the
# ``cpXY`` wheel tag in get_setup_options().
ABI3_TARGET_VERSION = f"{sys.version_info.major}{sys.version_info.minor}"
# sys.hexversion with the micro and serial fields masked out, used as the
# value of the Py_LIMITED_API macro.
ABI3_TARGET_HEX = f"{sys.hexversion & 0xFFFF00F0:#x}"
32
33
34
@contextlib.contextmanager
def stdchannel_redirected(stdchannel, dest_filename):
    """
    A context manager to temporarily redirect stdout or stderr

    The redirection happens at the file-descriptor level (``os.dup2``), so it
    also silences output written directly to the descriptor, for example by a
    compiler subprocess. On exit the original descriptor is restored even if
    the body raised.

    e.g.:

        with stdchannel_redirected(sys.stderr, os.devnull):
            if compiler.has_function('clock_gettime', libraries=['rt']):
                libraries.append('rt')

    Parameters
    ----------
    stdchannel : file object
        An open stream with a ``fileno()`` (typically ``sys.stdout`` or
        ``sys.stderr``).
    dest_filename : str
        Path that receives the output while the context is active.

    Code adapted from https://stackoverflow.com/a/17752455/1382869
    """
    # Initialise both handles so the ``finally`` clause is safe even when
    # os.dup() or open() raises before they are assigned.
    oldstdchannel = None
    dest_file = None
    try:
        # Push pending Python-level output to the original destination first.
        stdchannel.flush()
        oldstdchannel = os.dup(stdchannel.fileno())
        dest_file = open(dest_filename, "w")
        os.dup2(dest_file.fileno(), stdchannel.fileno())

        yield
    finally:
        if oldstdchannel is not None:
            # Anything buffered while redirected belongs in dest_filename.
            with contextlib.suppress(OSError, ValueError):
                stdchannel.flush()
            os.dup2(oldstdchannel, stdchannel.fileno())
            # The saved duplicate is no longer needed; closing it avoids an
            # fd leak on every use of the context manager.
            os.close(oldstdchannel)
        if dest_file is not None:
            dest_file.close()
59
60
61
def _openmp_output_is_valid(output):
    """Return True if the OpenMP test program's output lines look correct.

    The program prints one ``nthreads=<n>`` line per thread, so the output is
    accepted only when its first line carries a parseable thread count that
    equals the number of lines. Any mismatch is logged as a warning.
    """
    first_line = output[0] if output else ""
    if "nthreads=" not in first_line:
        log.warning(
            "Unexpected output from test OpenMP program (output was %s)", output
        )
        return False

    try:
        nthreads = int(first_line.strip().split("=")[1])
    except ValueError:
        log.warning(
            "Unexpected output from test OpenMP program (output was %s)", output
        )
        return False

    if len(output) != nthreads:
        log.warning(
            "Unexpected number of lines from output of test "
            "OpenMP program (output was %s)",
            output,
        )
        return False
    return True


def check_for_openmp():
    """Returns OpenMP compiler and linker flags if local setup supports
    OpenMP or [], [] otherwise

    A small C program that requests two OpenMP threads is compiled, linked
    and executed inside a temporary directory, which is removed afterwards.
    OpenMP counts as usable only if all three steps succeed and the output
    passes ``_openmp_output_is_valid``. Any failure, including a test binary
    that cannot be run or that exits non-zero, disables OpenMP instead of
    aborting the build.

    Code adapted from astropy_helpers, originally written by Tom
    Robitaille and Curtis McCully.

    Returns
    -------
    (compile_flags, link_flags) : tuple of two lists of str
    """
    ccompiler = new_compiler()
    customize_compiler(ccompiler)

    # Create a temporary directory (removed again in the ``finally`` below)
    tmp_dir = tempfile.mkdtemp()
    start_dir = os.path.abspath(".")

    CCODE = dedent("""\
        #include <omp.h>
        #include <stdio.h>
        int main() {
            omp_set_num_threads(2);
            #pragma omp parallel
            printf("nthreads=%d\\n", omp_get_num_threads());
            return 0;
        }"""
    )

    # TODO: test more known compilers:
    # MinGW, AppleClang with libomp, MSVC, ICC, XL, PGI, ...
    if os.name == "nt":
        # TODO: make this work with mingw
        # AFAICS there's no easy way to get the compiler distutils
        # will be using until compilation actually happens
        compile_flags = ["-openmp"]
        # NOTE(review): a single empty-string linker argument is passed on
        # Windows; confirm this is intended rather than an empty list.
        link_flags = [""]
    else:
        compile_flags = ["-fopenmp"]
        link_flags = ["-fopenmp"]

    using_openmp = False
    try:
        os.chdir(tmp_dir)

        with open("test_openmp.c", "w") as f:
            f.write(CCODE)

        os.mkdir("objects")

        # Compile, link, and run test program
        with stdchannel_redirected(sys.stderr, os.devnull):
            ccompiler.compile(
                ["test_openmp.c"], output_dir="objects", extra_postargs=compile_flags
            )
            ccompiler.link_executable(
                glob.glob(os.path.join("objects", "*")),
                "test_openmp",
                extra_postargs=link_flags,
            )
            output = (
                subprocess.check_output("./test_openmp")
                .decode(sys.stdout.encoding or "utf-8")
                .splitlines()
            )

        using_openmp = _openmp_output_is_valid(output)

    except (CompileError, LinkError):
        using_openmp = False
    except (OSError, subprocess.CalledProcessError) as exc:
        # The program was built but could not be executed or it failed;
        # building without OpenMP is preferable to aborting the install.
        log.warning("Could not run test OpenMP program (%s)", exc)
        using_openmp = False
    finally:
        os.chdir(start_dir)
        shutil.rmtree(tmp_dir, ignore_errors=True)

    if using_openmp:
        log.warning("Using OpenMP to compile parallel extensions")
        return compile_flags, link_flags

    log.warning(
        "Unable to compile OpenMP test program so Cython\n"
        "extensions will be compiled without parallel support"
    )
    return [], []
157
158
159
def check_CPP14_flag(compile_flags):
    """Return True if a small C++ probe program compiles with *compile_flags*.

    The probe is written into a temporary directory, which is removed again
    afterwards. Only compilation is attempted (no linking). Compiler
    diagnostics are sent to ``os.devnull``.

    Parameters
    ----------
    compile_flags : list of str
        Extra arguments appended to the compiler command line.

    Returns
    -------
    bool
        True on successful compilation, False on ``CompileError``.
    """
    ccompiler = new_compiler()
    customize_compiler(ccompiler)

    # Create a temporary directory (removed again in the ``finally`` below)
    tmp_dir = tempfile.mkdtemp()
    start_dir = os.path.abspath(".")

    # Note: This code requires C++14 functionalities (also required to compile yt)
    # It compiles on gcc 4.7.4 (together with the entirety of yt) with the flag "-std=gnu++0x".
    # It does not compile on gcc 4.6.4 (neither does yt).
    # NOTE(review): default member initialisers are a C++11 feature, so this
    # probe may also accept C++11-only compilers; confirm that is acceptable.
    CPPCODE = dedent("""\
        #include <vector>

        struct node {
            std::vector<int> vic;
            bool visited = false;
        };

        int main() {
            return 0;
        }"""
    )

    try:
        os.chdir(tmp_dir)
        with open("test_cpp14.cpp", "w") as f:
            f.write(CPPCODE)

        os.mkdir("objects")

        # Compile the test program; stderr is silenced to keep build logs clean
        with stdchannel_redirected(sys.stderr, os.devnull):
            ccompiler.compile(
                ["test_cpp14.cpp"], output_dir="objects", extra_postargs=compile_flags
            )
        return True
    except CompileError:
        return False
    finally:
        os.chdir(start_dir)
        shutil.rmtree(tmp_dir, ignore_errors=True)
200
201
202
def check_CPP14_flags(possible_compile_flags):
    """Return the first candidate flag that lets the C++14 probe compile.

    Candidates are tried one at a time, in order, each wrapped in a
    single-element list for ``check_CPP14_flag``. Probing stops at the first
    success. If no candidate works, a warning is logged and ``[]`` is
    returned.

    NOTE(review): success returns the bare candidate element, while failure
    returns an empty list. Callers must handle both shapes; this is kept
    unchanged for backward compatibility.
    """
    for candidate in possible_compile_flags:
        if check_CPP14_flag([candidate]):
            break
    else:
        log.warning(
            "Your compiler seems to be too old to support C++14. "
            "yt may not be able to compile. Please use a newer version."
        )
        return []
    return candidate
212
213
214
def check_for_pyembree(std_libs):
    """Return the embree extension sources and cythonize aliases, if available.

    Parameters
    ----------
    std_libs : list of str
        Libraries every extension links against. ``EMBREE_LIBS`` is a new
        list made of these plus the embree library; *std_libs* itself is not
        modified.

    Returns
    -------
    (embree_libs, embree_aliases) : (list of str, dict)
        Glob patterns of the embree ``.pyx`` sources and the
        ``EMBREE_INC_DIR`` / ``EMBREE_LIB_DIR`` / ``EMBREE_LIBS`` aliases.
        Both are empty when pyembree cannot be imported, or when
        ``read_embree_location()`` reports that the compiler probe could not
        run.
    """
    embree_libs = []
    embree_aliases = {}

    try:
        importlib_resources.files("pyembree")
    except ImportError:
        # pyembree is not importable: skip the embree extensions.
        return embree_libs, embree_aliases

    embree_location = read_embree_location()
    if embree_location is False:
        # read_embree_location() returns False when the C++ compiler could not
        # be executed; os.path.abspath(False) would raise TypeError.
        log.warning(
            "Skipping embree extensions because the embree location "
            "could not be probed."
        )
        return embree_libs, embree_aliases

    embree_prefix = os.path.abspath(embree_location)
    embree_inc_dir = os.path.join(embree_prefix, "include")
    embree_lib_dir = os.path.join(embree_prefix, "lib")

    # Embree 2 is linked by its versioned name on macOS.
    if _platform == "darwin":
        embree_lib_name = "embree.2"
    else:
        embree_lib_name = "embree"

    embree_aliases["EMBREE_INC_DIR"] = ["yt/utilities/lib/", embree_inc_dir]
    embree_aliases["EMBREE_LIB_DIR"] = [embree_lib_dir]
    embree_aliases["EMBREE_LIBS"] = std_libs + [embree_lib_name]
    embree_libs += ["yt/utilities/lib/embree_mesh/*.pyx"]

    if in_conda_env():
        # Also search the conda environment's include/lib directories.
        conda_basedir = os.path.dirname(os.path.dirname(sys.executable))
        embree_aliases["EMBREE_INC_DIR"].append(os.path.join(conda_basedir, "include"))
        embree_aliases["EMBREE_LIB_DIR"].append(os.path.join(conda_basedir, "lib"))

    return embree_libs, embree_aliases
243
244
245
def in_conda_env():
    """Return True if the interpreter's ``sys.version`` banner mentions a conda-style distribution."""
    version_banner = sys.version
    for marker in ("Anaconda", "Continuum", "conda-forge"):
        if marker in version_banner:
            return True
    return False
247
248
249
def read_embree_location():
    """
    Attempts to locate the embree installation.

    First, we check for an EMBREE_DIR environment variable. If one is not
    defined, we read an embree.cfg file from the current working directory
    (the root yt source directory when building). Finally, if that is not
    present, we default to /usr/local.

    The chosen prefix is then sanity-checked by compiling a tiny program that
    includes ``embree2/rtcore.h`` with ``-I<prefix>/include/``, using the
    compiler named by ``$CXX`` (default ``c++``). A failed compile only logs
    warnings. If embree is installed in a non-standard location and none of
    the above are set, the later extension build will not succeed.

    This is called from check_for_pyembree() once pyembree has been found.

    Returns
    -------
    str or False
        The embree prefix, or False if the compiler could not be executed.
    """
    rd = os.environ.get("EMBREE_DIR")
    if rd is None:
        try:
            with open("embree.cfg") as cfg:
                rd = cfg.read().strip()
        except OSError:
            rd = "/usr/local"

    fail_msg = (
        "I attempted to find Embree headers in %s. \n"
        "If this is not correct, please set your correct embree location \n"
        "using EMBREE_DIR environment variable or your embree.cfg file. \n"
        "Please see http://yt-project.org/docs/dev/visualizing/unstructured_mesh_rendering.html "
        "for more information. \n" % rd
    )

    # Get compiler invocation. This is computed before the ``try`` so it is
    # always bound when the OSError handler reports it.
    compiler = os.getenv("CXX", "c++").split(" ")

    # Create a temporary directory (removed again in the ``finally`` below)
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()

    try:
        os.chdir(tmpdir)

        # Attempt to compile a test script.
        filename = r"test.cpp"
        CCODE = dedent("""\
            #include "embree2/rtcore.h"
            int main() {
                return 0;
            }"""
        )
        # Close the source before invoking the compiler so it is fully
        # written (and not locked on Windows).
        with open(filename, "w") as src:
            src.write(CCODE)

        p = Popen(
            compiler + ["-I%s/include/" % rd, filename],
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
        )
        _, err = p.communicate()

        if p.returncode != 0:
            log.warning(
                "Pyembree is installed, but I could not compile Embree test code."
            )
            log.warning("The error message was: ")
            # ``err`` is bytes; decode it so the log shows readable text.
            log.warning(err.decode("utf-8", errors="replace"))
            log.warning(fail_msg)

    except OSError:
        log.warning(
            "read_embree_location() could not find your C compiler. "
            "Attempted to use '%s'.",
            " ".join(compiler),
        )
        return False

    finally:
        os.chdir(curdir)
        shutil.rmtree(tmpdir)

    return rd
332
333
334
def get_cpu_count():
    """Return the number of worker threads to use for parallel build steps.

    Returns 0 on Windows. Elsewhere it returns the CPU count, capped by the
    ``MAX_BUILD_CORES`` environment variable and clamped to be non-negative.
    A result of 0 makes ``build_extensions`` (in ``create_build_ext``) fall
    back to a serial build.

    Raises
    ------
    ValueError
        If ``MAX_BUILD_CORES`` is set to something that is not an integer.
    """
    if platform.system() == "Windows":
        return 0

    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to a single core instead of crashing in int()/min() below.
    cpu_count = os.cpu_count() or 1
    try:
        user_max_cores = int(os.getenv("MAX_BUILD_CORES", cpu_count))
    except ValueError as e:
        raise ValueError(
            "MAX_BUILD_CORES must be set to an integer. "
            + "See above for original error."
        ) from e
    # Negative values would otherwise be passed on as a worker count.
    return max(0, min(cpu_count, user_max_cores))
348
349
350
def install_ccompiler():
    """Replace ``CCompiler.compile`` with the ``_compile`` function below.

    The replacement sets up the build with ``self._setup_compile``, derives
    the compiler arguments, compiles each object that has a recorded
    ``(source, extension)`` pair, and returns every object filename.

    NOTE(review): this looks like a serial copy of the stock distutils
    ``CCompiler.compile``; confirm whether the patch is still required.
    """

    def _compile(
        self,
        sources,
        output_dir=None,
        macros=None,
        include_dirs=None,
        debug=0,
        extra_preargs=None,
        extra_postargs=None,
        depends=None,
    ):
        """Function to monkey-patch distutils.ccompiler.CCompiler"""
        setup = self._setup_compile(
            output_dir, macros, include_dirs, sources, depends, extra_postargs
        )
        macros, objects, extra_postargs, pp_opts, build = setup
        cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

        # Objects without an entry in ``build`` need no compilation.
        pending = (obj for obj in objects if obj in build)
        for obj in pending:
            src, ext = build[obj]
            self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

        # Return *all* object filenames, not just the ones we just built.
        return objects

    CCompiler.compile = _compile
379
380
381
def get_python_include_dirs():
    """Return the include directories needed to compile against Python's C API.

    The result covers Python.h, pyconfig.h and related headers.

    Extracted from distutils.command.build_ext.build_ext.finalize_options(),
    https://github.com/python/cpython/blob/812245ecce2d8344c3748228047bab456816180a/Lib/distutils/command/build_ext.py#L148-L167
    """
    py_include = sysconfig.get_python_inc()
    plat_py_include = sysconfig.get_python_inc(plat_specific=1)

    # A virtualenv has its own include directory (Issue 16116); list it first.
    if sys.exec_prefix != sys.base_exec_prefix:
        dirs = [os.path.join(sys.exec_prefix, 'include')]
    else:
        dirs = []

    # Python's "system" include dirs go last so local ones take precedence.
    # The platform-specific dir is only added when it differs.
    candidates = [py_include]
    if plat_py_include != py_include:
        candidates.append(plat_py_include)
    for entry in candidates:
        dirs.extend(entry.split(os.path.pathsep))

    return dirs
404
405
406
# C preprocessor macros, as (name, value) tuples, that create_build_ext adds to
# every extension build. Built from an insertion-ordered mapping, so the list
# order matches the order written here.
NUMPY_MACROS = list(
    {
        # Build against the NumPy C API with pre-1.7 deprecated symbols hidden.
        "NPY_NO_DEPRECATED_API": "NPY_1_7_API_VERSION",
        # keep in sync with runtime requirements (pyproject.toml)
        "NPY_TARGET_VERSION": "NPY_1_21_API_VERSION",
    }.items()
)
411
412
413
def create_build_ext(lib_exts, cythonize_aliases):
    """Create the ``build_ext`` and ``sdist`` setuptools command classes.

    Parameters
    ----------
    lib_exts : list
        Extension definitions handed to ``Cython.Build.cythonize``.
    cythonize_aliases : dict
        Mapping forwarded as ``cythonize(aliases=...)``.

    Returns
    -------
    tuple
        ``(build_ext, sdist)`` subclasses of the setuptools commands.
    """

    class build_ext(_build_ext):
        # subclass setuptools extension builder to avoid importing cython and numpy
        # at top level in setup.py. See http://stackoverflow.com/a/21621689/1382869
        # NOTE: this is likely not necessary anymore since
        # pyproject.toml was introduced in the project

        def finalize_options(self):
            from Cython.Build import cythonize

            # Override the list of extension modules
            self.distribution.ext_modules[:] = cythonize(
                lib_exts,
                aliases=cythonize_aliases,
                compiler_directives={"language_level": 3},
                nthreads=get_cpu_count(),
            )
            _build_ext.finalize_options(self)
            # Prevent numpy from thinking it is still in its setup process
            # see http://stackoverflow.com/a/21621493/1382869
            if isinstance(__builtins__, dict):
                # sometimes this is a dict so we need to check for that
                # https://docs.python.org/3/library/builtins.html
                __builtins__["__NUMPY_SETUP__"] = False
            else:
                __builtins__.__NUMPY_SETUP__ = False
            import numpy

            self.include_dirs.append(numpy.get_include())
            self.include_dirs.append(ewah_bool_utils.get_include())

            # Work on a copy: appending to NUMPY_MACROS directly would mutate
            # the module-level constant (and alias it into ``self.define``),
            # so Py_LIMITED_API would accumulate across command instances.
            define_macros = list(NUMPY_MACROS)
            if USE_PY_LIMITED_API:
                define_macros.append(("Py_LIMITED_API", ABI3_TARGET_HEX))
                for ext in self.extensions:
                    ext.py_limited_api = True

            if self.define is None:
                self.define = define_macros
            else:
                self.define.extend(define_macros)

        def build_extensions(self):
            self.check_extensions_list(self.extensions)

            ncpus = get_cpu_count()
            if ncpus > 0:
                with ThreadPoolExecutor(ncpus) as executor:
                    futures = [
                        executor.submit(self.build_extension, extension)
                        for extension in self.extensions
                    ]
                    # .result() re-raises any exception from a worker thread.
                    for future in futures:
                        future.result()
            else:
                super().build_extensions()

        def build_extension(self, extension):
            try:
                super().build_extension(extension)
            except CompileError as exc:
                # Name the failing extension, which is otherwise hard to tell
                # apart when builds run in parallel.
                print(f"While building '{extension.name}' following error was raised:\n {exc}")
                raise

    class sdist(_sdist):
        # subclass setuptools source distribution builder to ensure cython
        # generated C files are included in source distribution.
        # See http://stackoverflow.com/a/18418524/1382869
        def run(self):
            # Make sure the compiled Cython files in the distribution are up-to-date
            from Cython.Build import cythonize

            cythonize(
                lib_exts,
                aliases=cythonize_aliases,
                compiler_directives={"language_level": 3},
                nthreads=get_cpu_count(),
            )
            _sdist.run(self)

    return build_ext, sdist
494
495
def get_setup_options():
    """Return extra ``setup(options=...)`` entries.

    When stable-ABI builds are enabled, ``bdist_wheel`` is told to tag the
    wheel as ``cp<major><minor>`` limited-API. Otherwise an empty dict is
    returned. A fresh dict is built on every call.
    """
    options = {}
    if USE_PY_LIMITED_API:
        options["bdist_wheel"] = {"py_limited_api": f"cp{ABI3_TARGET_VERSION}"}
    return options
500
501