# Path: blob/master/src/utils/style_ops/dnnlib/util.py
# (page metadata: 809 views)
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Miscellaneous utility classes and functions."""

import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid

from typing import Any, List, Optional, Tuple, Union

try:
    from distutils.util import strtobool
except ImportError:
    # distutils was deprecated by PEP 632 and removed in Python 3.12. Provide an
    # equivalent implementation so this module still imports on newer Pythons.
    def strtobool(val: str) -> int:
        """Convert a string representation of truth to 1 (true) or 0 (false).

        True values are 'y', 'yes', 't', 'true', 'on', '1'; false values are
        'n', 'no', 'f', 'false', 'off', '0' (case-insensitive).

        Raises:
            ValueError: if ``val`` is anything else.
        """
        val = val.lower()
        if val in ('y', 'yes', 't', 'true', 'on', '1'):
            return 1
        elif val in ('n', 'no', 'f', 'false', 'off', '0'):
            return 0
        else:
            raise ValueError("invalid truth value %r" % (val,))


# Util classes
# ------------------------------------------------------------------------------------------


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax.

    A missing key raises ``AttributeError`` on attribute access, so ``hasattr`` and
    ``getattr(obj, name, default)`` work as expected. Item access still raises ``KeyError``.
    """

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.

    Constructing a Logger immediately installs it as both ``sys.stdout`` and
    ``sys.stderr``. ``close()``, or leaving a ``with`` block, restores the streams
    that were active at construction time.
    """

    def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
        """
        Args:
            file_name: Optional path of a file that receives a copy of all output.
            file_mode: Mode passed to ``open`` for ``file_name``.
            should_flush: Flush stdout (and the file) after every write.
        """
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Remember the original streams so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None


# Cache directories
# ------------------------------------------------------------------------------------------

_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Override the root cache directory used by ``make_cache_dir_path`` (``None`` clears it)."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Join ``paths`` onto the dnnlib cache root.

    The root is resolved in order from: ``set_cache_dir()``, ``$DNNLIB_CACHE_DIR``,
    ``$HOME/.cache/dnnlib``, ``$USERPROFILE/.cache/dnnlib``, and finally
    ``<tempdir>/.cache/dnnlib``. The directory is not created here.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    if 'HOME' in os.environ:
        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
    if 'USERPROFILE' in os.environ:
        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)

# Small util functions
# ------------------------------------------------------------------------------------------


def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds.

    The input is rounded to whole seconds. Seconds are dropped from the output once
    the value reaches one day (e.g. ``"1d 01h 01m"``).
    """
    s = int(np.rint(seconds))

    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)


def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Answers are parsed with ``strtobool``; unrecognized answers re-prompt.
    """
    while True:
        try:
            print("{0} [y/n]".format(question))
            # strtobool returns 1/0; convert so the result matches the annotation.
            return bool(strtobool(input().lower()))
        except ValueError:
            pass


def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements (1 for an empty tuple)."""
    result = 1

    for v in t:
        result *= v

    return result


_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.

    The name is taken from ``type_obj`` itself if it is a string, else from its
    ``__name__`` attribute, else from its ``name`` attribute.

    Raises:
        RuntimeError: if no type name can be inferred.
        AssertionError: if the name is not in ``_str_to_ctype`` or the sizes disagree.
    """
    type_str = None

    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype

    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]

    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)

    return my_dtype, my_ctype


def is_pickleable(obj: Any) -> bool:
    """Return True if ``obj`` can be serialized with ``pickle.dump``.

    Only ordinary exceptions count as "not pickleable". ``KeyboardInterrupt`` and
    ``SystemExit`` propagate instead of being swallowed.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        return False


# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------

def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed).

    The shorthand prefixes ``np.`` and ``tf.`` are expanded to ``numpy.`` and
    ``tensorflow.``. Candidate (module, attribute) splits are tried from the longest
    module path to the shortest. The first split that both imports and resolves wins.

    Raises:
        ImportError: if a candidate module fails to import for a reason other than
            that module itself being missing, or if no split works and no more
            specific cause is found.
        AttributeError: if a candidate module imports but lacks the attribute path.
    """

    # allow convenience shorthands, substitute them by full names
    # (the dot is escaped so that e.g. "npx.foo" or "tfa.layers" is left untouched)
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError as err:
            # "No module named '<module_name>'" just means this split is wrong;
            # anything else is a genuine error inside an existing module.
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object.

    An empty ``obj_name`` returns ``module`` itself. Raises AttributeError if any
    component of the dotted path is missing.
    """
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj


def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function.

    For functions defined in ``__main__`` the script's file stem is used as the
    module name, so the result stays meaningful outside that process.
    """
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__


# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths.

    ``ignores`` are ``fnmatch`` patterns matched against bare file and directory
    names. If ``add_base_to_relative`` is set, the relative paths are prefixed
    with the basename of ``dir_path``.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place so os.walk skips them
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result


def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        src, dst = file[0], file[1]
        target_dir_name = os.path.dirname(dst)

        # Create all intermediate-level directories. An empty dirname means dst is
        # relative to the current directory, which already exists (and
        # os.makedirs('') would raise). exist_ok avoids a check-then-create race.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src, dst)


# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    Requires a scheme and a dotted network location. ``file://`` URLs are accepted
    only when ``allow_file_urls`` is set.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        return False
    return True


def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: Remote URL, ``file://`` URL, or plain local path.
        cache_dir: Download cache directory; defaults to ``make_cache_dir_path('downloads')``.
        num_attempts: Number of download attempts; the last error is re-raised.
        verbose: Print download progress to stdout.
        return_filename: Return a local path instead of a file object (requires ``cache``).
        cache: Look up and store downloads in ``cache_dir``.

    Returns:
        A local path if ``return_filename`` is set, otherwise an open binary file
        or an ``io.BytesIO`` holding the downloaded data.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        # Small responses may be an HTML interstitial (e.g. Google Drive)
                        # rather than the payload. Decode leniently: a small *binary*
                        # payload is not valid UTF-8, and a strict decode would make
                        # every attempt fail. The text is only scanned for markers.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry against the confirmation link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # KeyboardInterrupt / SystemExit are BaseException and are never retried.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)