Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_ops/dnnlib/util.py
809 views
1
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
#
3
# NVIDIA CORPORATION and its licensors retain all intellectual property
4
# and proprietary rights in and to this software, related documentation
5
# and any modifications thereto. Any use, reproduction, disclosure or
6
# distribution of this software and related documentation without an express
7
# license agreement from NVIDIA CORPORATION is strictly prohibited.
8
9
"""Miscellaneous utility classes and functions."""
10
11
import ctypes
12
import fnmatch
13
import importlib
14
import inspect
15
import numpy as np
16
import os
17
import shutil
18
import sys
19
import types
20
import io
21
import pickle
22
import re
23
import requests
24
import html
25
import hashlib
26
import glob
27
import tempfile
28
import urllib
29
import urllib.request
30
import uuid
31
32
from distutils.util import strtobool
33
from typing import Any, List, Tuple, Union
34
35
36
# Util classes
37
# ------------------------------------------------------------------------------------------
38
39
40
class EasyDict(dict):
    """A ``dict`` whose items can also be read, written and deleted as attributes.

    ``d.key`` behaves like ``d["key"]``, except that a missing key raises
    ``AttributeError`` instead of ``KeyError``, so ``hasattr``/``getattr`` with a
    default work as expected.
    """

    def __getattr__(self, name: str) -> Any:
        # Python only calls __getattr__ after normal attribute lookup has failed,
        # so genuine dict methods (keys, items, ...) still take precedence.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Every attribute assignment is stored as a dict item.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the corresponding dict item (KeyError if absent).
        del self[name]
54
55
56
class Logger(object):
    """Mirror stdout/stderr output onto the original stdout and, optionally, a log file.

    Constructing an instance installs it as both ``sys.stdout`` and ``sys.stderr``, so
    anything written to stderr ends up on the original stdout as well. Calling
    ``close()`` (or leaving a ``with`` block) flushes, restores the streams it replaced
    and closes the log file.
    """

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        # Open the log file before touching sys.stdout/sys.stderr, so a failing open()
        # leaves the interpreter's streams untouched.
        self.file = open(file_name, file_mode) if file_name is not None else None
        self.should_flush = should_flush

        # Remember the streams being replaced so close() can put them back.
        self.stdout = sys.stdout
        self.stderr = sys.stderr
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write ``text`` to the log file (if any) and the original stdout, flushing if configured.

        ``bytes`` are decoded as UTF-8 first; empty writes are dropped.
        """
        if isinstance(text, bytes):
            text = text.decode()
        # Drop empty writes: works around a VSCode debugger crash on
        # sys.stdout.write('') followed by sys.stdout.flush().
        if len(text) == 0:
            return

        sinks = (self.stdout,) if self.file is None else (self.file, self.stdout)
        for sink in sinks:
            sink.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush the log file (if open) and then the original stdout."""
        sinks = (self.stdout,) if self.file is None else (self.file, self.stdout)
        for sink in sinks:
            sink.flush()

    def close(self) -> None:
        """Flush, restore the redirected streams and close the log file, if any."""
        self.flush()

        # Only restore a stream that still points at this logger: with several loggers
        # active, closing them out of order must not undo a newer redirection.
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
113
114
115
# Cache directories
116
# ------------------------------------------------------------------------------------------
117
118
# Base cache directory chosen explicitly via set_cache_dir(); None means "derive it from
# the environment" in make_cache_dir_path().
_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Override the base directory used by ``make_cache_dir_path`` (``None`` clears the override)."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    """Join ``paths`` onto the dnnlib cache base directory.

    The base is chosen in this order of precedence:
    1. the directory passed to ``set_cache_dir``;
    2. the ``DNNLIB_CACHE_DIR`` environment variable;
    3. ``$HOME/.cache/dnnlib``, then ``%USERPROFILE%/.cache/dnnlib``;
    4. ``<system temp dir>/.cache/dnnlib``.
    The directory is not created here.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)

    env = os.environ
    if 'DNNLIB_CACHE_DIR' in env:
        return os.path.join(env['DNNLIB_CACHE_DIR'], *paths)

    for home_var in ('HOME', 'USERPROFILE'):
        if home_var in env:
            return os.path.join(env[home_var], '.cache', 'dnnlib', *paths)

    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
135
# Small util functions
136
# ------------------------------------------------------------------------------------------
137
138
139
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration in seconds as a compact human-readable string.

    The value is first rounded to a whole number of seconds with ``np.rint``
    (round-half-to-even). Output shapes:
    under a minute ``"42s"``; under an hour ``"3m 07s"``;
    under a day ``"2h 05m 09s"``; otherwise ``"1d 04h 30m"`` (seconds dropped).
    """
    total = int(np.rint(seconds))
    if total < 60:
        return "{0}s".format(total)

    minutes, secs = divmod(total, 60)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(minutes, secs)

    hours, minutes = divmod(minutes, 60)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(hours, minutes, secs)

    days, hours = divmod(hours, 24)
    return "{0}d {1:02}h {2:02}m".format(days, hours, minutes)
151
152
153
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Accepted answers (case-insensitive, surrounding whitespace ignored) are the same
    spellings ``distutils.util.strtobool`` accepted: ``y``, ``yes``, ``t``, ``true``,
    ``on``, ``1`` for yes and ``n``, ``no``, ``f``, ``false``, ``off``, ``0`` for no.
    Anything else re-prompts.

    Returns:
        ``True`` for an affirmative answer, ``False`` for a negative one.
    """
    # Parse locally instead of via distutils.util.strtobool: distutils is deprecated and
    # removed in Python 3.12, and strtobool returns the ints 1/0 rather than bools.
    affirmative = {"y", "yes", "t", "true", "on", "1"}
    negative = {"n", "no", "f", "false", "off", "0"}
    while True:
        print("{0} [y/n]".format(question))
        answer = input().strip().lower()
        if answer in affirmative:
            return True
        if answer in negative:
            return False
161
162
163
def tuple_product(t: Tuple) -> Any:
    """Multiply all elements of ``t`` together; an empty tuple yields ``1``."""
    product = 1
    for factor in t:
        product *= factor
    return product
171
172
173
# Numpy-style type names supported by get_dtype_and_ctype(), mapped to the ctypes type
# of the same width and signedness.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double,
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Map a type name to a matching ``(numpy dtype, ctypes type)`` pair of equal byte size.

    ``type_obj`` may be a type-name string (e.g. ``"float32"``), an object with a
    ``__name__`` (e.g. ``np.float32``), or an object with a ``name`` attribute.

    Raises:
        RuntimeError: if no type name can be derived from ``type_obj``.
        AssertionError: if the name is not in the supported table, or the two sizes differ.
    """
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype

    c_type = _str_to_ctype[type_str]
    np_dtype = np.dtype(type_str)
    # Sanity check that both representations really occupy the same number of bytes.
    assert np_dtype.itemsize == ctypes.sizeof(c_type)

    return np_dtype, c_type
208
209
210
def is_pickleable(obj: Any) -> bool:
    """Return ``True`` if ``obj`` can be serialized with ``pickle.dump``, else ``False``.

    The object is pickled into a throwaway in-memory buffer; nothing is written to disk.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Pickling failures surface as PicklingError, TypeError, AttributeError, ...;
        # a bare except would also swallow KeyboardInterrupt/SystemExit.
        return False
217
218
219
# Functionality to import modules/objects by name, and call functions by name
220
# ------------------------------------------------------------------------------------------
221
222
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed).

    The dotted ``obj_name`` is split into (module path, attribute path) in every possible
    way, longest module path first. The first split whose module imports and whose
    attribute path resolves is returned. The prefixes ``np.`` and ``tf.`` are expanded to
    ``numpy.`` and ``tensorflow.``.

    Raises:
        ImportError: re-raised when a candidate module exists but fails to import for a
            reason other than "No module named ...", or ``ImportError(obj_name)`` when no
            split works.
        AttributeError: when a candidate module imports but lacks the requested attribute.
    """
    # allow convenience shorthands, substitute them by full names; the dot is escaped so
    # that names like "np_utils.x" or "tfx.y" are not rewritten
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module path first
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # any failure just means "try the next split"; Ctrl-C still propagates
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)  # may raise ImportError
        except ImportError as err:
            # "No module named <candidate>" is expected for most splits; any other
            # ImportError means the module exists but is broken, so surface it
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
261
262
263
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Follow the dotted attribute path ``obj_name`` starting at ``module``.

    An empty ``obj_name`` returns ``module`` itself; otherwise each dot-separated part is
    looked up with ``getattr`` (raising ``AttributeError`` if missing) and the final
    object is returned.
    """
    current = module
    if obj_name:
        for attr in obj_name.split("."):
            current = getattr(current, attr)
    return current
271
272
273
def get_obj_by_name(name: str) -> Any:
    """Resolve a fully-qualified dotted name (e.g. ``"os.path.join"``) to the object it names."""
    return get_obj_from_module(*get_module_from_obj_name(name))
277
278
279
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Resolve the callable named by the dotted path ``func_name`` and call it.

    All other positional and keyword arguments are forwarded to the callable, and its
    return value is returned. Asserts that ``func_name`` is given and that the resolved
    object is callable.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    return target(*args, **kwargs)
285
286
287
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Resolve the class named by the dotted path ``class_name`` and instantiate it.

    This is ``call_func_by_name`` with the name passed as ``func_name``; all remaining
    arguments are forwarded to the constructor.
    """
    constructor_name = class_name
    return call_func_by_name(*args, func_name=constructor_name, **kwargs)
290
291
292
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory of the source file of the module that defines ``obj_name``."""
    module = get_module_from_obj_name(obj_name)[0]
    module_file = inspect.getfile(module)
    return os.path.dirname(module_file)
296
297
298
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.

    The object qualifies when it is callable and its defining module (``obj.__module__``,
    looked up in ``sys.modules``) binds ``obj.__name__`` to this very object. Other named
    callables that pass this check, such as module-level classes, are accepted too.
    Objects without ``__name__``/``__module__``, or whose module is not loaded, yield False.
    """
    if not callable(obj):
        return False
    name = getattr(obj, "__name__", None)
    module = sys.modules.get(getattr(obj, "__module__", None))
    if name is None or module is None:
        return False
    # Compare identity, not just name presence: a nested or re-bound function that merely
    # shares its name with a module-level attribute cannot be found by its qualified name.
    return vars(module).get(name) is obj
301
302
303
def get_top_level_function_name(obj: Any) -> str:
    """Return ``"<module>.<name>"`` for a top-level function.

    For functions defined in the running script (``__module__ == '__main__'``), the module
    part is the script's file name without directory or extension.
    """
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        script_path = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_path))[0]
    return ".".join((module_name, obj.__name__))
310
311
312
# File system helpers
313
# ------------------------------------------------------------------------------------------
314
315
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Walk ``dir_path`` recursively and list its files, skipping ignored names.

    Args:
        dir_path: Existing directory to walk (asserted).
        ignores: ``fnmatch`` patterns; matching directories are not descended into and
            matching file names are skipped. ``None`` means no patterns.
        add_base_to_relative: Prefix each relative path with the basename of ``dir_path``.

    Returns:
        ``(absolute_path, relative_path)`` tuples in ``os.walk`` (top-down) order, where
        the relative path is relative to ``dir_path``.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = [] if ignores is None else ignores

    def _ignored(name: str) -> bool:
        # True if the bare file/directory name matches any ignore pattern.
        return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)

    pairs = []
    for root, subdirs, filenames in os.walk(dir_path, topdown=True):
        # Prune in place: os.walk only descends into what remains in this list.
        subdirs[:] = [name for name in subdirs if not _ignored(name)]

        for name in filenames:
            if _ignored(name):
                continue
            abs_path = os.path.join(root, name)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            pairs.append((abs_path, rel_path))

    return pairs
346
347
348
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories.

    Each destination's parent directory (including intermediate levels) is created when
    missing. File contents are copied with ``shutil.copyfile``, which overwrites an
    existing destination file.
    """
    for entry in files:
        src, dst = entry[0], entry[1]
        target_dir_name = os.path.dirname(dst)

        # A bare destination filename has no directory part (dirname == ""), and
        # os.makedirs("") raises; exist_ok also avoids a check-then-create race.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src, dst)
359
360
361
# URL helpers
362
# ------------------------------------------------------------------------------------------
363
364
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    ``obj`` must be a ``str`` containing ``"://"``. Both its parsed form and the form
    obtained by resolving ``"/"`` against it must have a scheme and a network location
    containing a dot, so e.g. ``http://localhost/x`` is rejected. With
    ``allow_file_urls=True``, any string starting with ``file://`` is accepted as-is.
    Malformed URLs (where parsing raises ``ValueError``) yield ``False``.
    """
    # requests.compat.urlparse/urljoin are re-exports of these urllib.parse functions.
    from urllib.parse import urljoin, urlparse

    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = urlparse(urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except ValueError:
        # e.g. "Invalid IPv6 URL" for an unbalanced '[' in the network location
        return False
    return True
380
381
382
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Local paths (anything not starting with ``<scheme>://``) and ``file://`` URLs are
    opened directly. Other URLs are first looked up in the download cache, keyed by the
    MD5 of the URL. On a miss they are downloaded with up to ``num_attempts`` tries and
    optionally stored in the cache.

    Args:
        url: Local path, ``file://`` URL, or remote URL.
        cache_dir: Cache directory; defaults to ``make_cache_dir_path('downloads')``.
        num_attempts: Number of download attempts before giving up (must be >= 1).
        verbose: Print download progress to stdout.
        return_filename: Return a path instead of an open file (requires ``cache=True``).
        cache: Read from and write to the download cache.

    Returns:
        A filename if ``return_filename``; otherwise an open binary file for local/cached
        data, or an ``io.BytesIO`` holding freshly downloaded data.

    Raises:
        The error from the last download attempt once all attempts are exhausted.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs.  This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid.  Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        # Escape the directory so glob metacharacters in it (e.g. '[') match literally.
        cache_files = glob.glob(os.path.join(glob.escape(cache_dir), url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        # Small responses may be Google Drive interstitial pages. The text is
                        # only scanned for markers, so decode leniently: a genuine small
                        # binary payload is not valid UTF-8 and must not fail the download.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry against the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                            raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # KeyboardInterrupt/SystemExit are not Exceptions and propagate immediately.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)  # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
478
479