GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_misc.py
"""
this code is borrowed from https://github.com/NVlabs/stylegan2-ada-pytorch with a few modifications

Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
"""

import warnings
import torch
import numpy as np
#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device("cpu")
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
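#----------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the
# upstream NVlabs file). Because the cache key covers the value's bytes,
# shape, dtype, device, and memory format, repeated calls with identical
# arguments return the very same tensor object.

def _demo_constant():
    a = constant(1.0, shape=[3, 3])  # scalar broadcast to a 3x3 tensor
    b = constant(1.0, shape=[3, 3])  # identical key, served from _constant_cache
    assert a is b                    # same object; no fresh CPU=>GPU copy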
#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
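#----------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the
# upstream file). On PyTorch >= 1.8 this simply exercises torch.nan_to_num;
# the pre-1.8 fallback above reproduces its defaults: NaN -> 0 via the
# unsqueeze/nansum trick, +/-Inf clamped to the dtype's finite range.

def _demo_nan_to_num():
    x = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.0])
    y = nan_to_num(x)
    assert y[0] == 0.0              # NaN replaced by 0
    assert torch.isfinite(y).all()  # infinities clamped to finfo.max / finfo.min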
#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
        return self
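#----------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the
# upstream file). torch.jit.trace() warns when torch.as_tensor() turns a
# Python scalar into a traced constant; wrapping the comparison in
# suppress_tracer_warnings() silences that known TracerWarning, which is
# exactly how assert_shape() uses it below.

def _demo_suppress_tracer_warnings(size, ref_size):
    with suppress_tracer_warnings():
        return torch.equal(torch.as_tensor(size), torch.as_tensor(ref_size))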
#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}")
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f"Wrong size for dimension {idx}")
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f"Wrong size for dimension {idx}: expected {ref_size}")
        elif size != ref_size:
            raise AssertionError(f"Wrong size for dimension {idx}: got {size}, expected {ref_size}")
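#----------------------------------------------------------------------------
# Illustrative usage sketch (added for this write-up; not part of the
# upstream file). A None entry leaves that dimension unconstrained, so any
# batch size passes; a mismatched integer raises an AssertionError.

def _demo_assert_shape():
    x = torch.zeros(8, 3, 32, 32)
    assert_shape(x, [None, 3, 32, 32])      # OK: batch dimension may vary
    try:
        assert_shape(x, [None, 1, 32, 32])  # channel count mismatch
    except AssertionError as err:
        print(err)  # -> Wrong size for dimension 1: got 3, expected 1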