Path: src/utils/custom_ops.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import glob
import hashlib
import importlib
import os
import re
import shutil
import uuid

import torch
import torch.utils.cpp_extension
from torch.utils.file_baton import FileBaton

#----------------------------------------------------------------------------
# Global options.

verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------

def _get_mangled_gpu_name():
    name = torch.cuda.get_device_name().lower()
    out = []
    for c in name:
        if re.match('[a-z0-9_-]+', c):
            out.append(c)
        else:
            out.append('-')
    return ''.join(out)

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
        # around the *.cu dependency bug in ninja config.
        #
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):

            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())

            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir) # atomic
                except OSError:
                    # Source directory already exists, delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir): raise

            # Compile.
            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
                verbose=verbose_build, sources=cached_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------
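# Illustrative usage sketch (not part of the original module). The plugin name
# "bias_act_plugin", the source/header file names, and the import path below
# are hypothetical; substitute the C++/CUDA sources of the extension you
# actually want to build. Extra keyword arguments are forwarded unchanged to
# torch.utils.cpp_extension.load(), e.g. extra_cuda_cflags.
#
#   import os
#   from src.utils import custom_ops   # adjust to how this repo is packaged
#
#   plugin = custom_ops.get_plugin(
#       module_name='bias_act_plugin',                 # hypothetical name
#       sources=['bias_act.cpp', 'bias_act.cu'],       # hypothetical files
#       headers=['bias_act.h'],                        # hypothetical file
#       source_dir=os.path.dirname(__file__),
#       extra_cuda_cflags=['--use_fast_math'],
#   )
#   # The returned object is the imported extension module; call whatever
#   # functions the compiled sources expose, e.g. plugin.bias_act(...).
#
#----------------------------------------------------------------------------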