Path: blob/master/src/utils/style_misc.py
"""1this code is borrowed from https://github.com/NVlabs/stylegan2-ada-pytorch with few modifications23Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.45NVIDIA CORPORATION and its licensors retain all intellectual property6and proprietary rights in and to this software, related documentation7and any modifications thereto. Any use, reproduction, disclosure or8distribution of this software and related documentation without an express9license agreement from NVIDIA CORPORATION is strictly prohibited.10"""1112import warnings13import torch14import numpy as np151617#----------------------------------------------------------------------------18# Cached construction of constant tensors. Avoids CPU=>GPU copy when the19# same constant is used multiple times.2021_constant_cache = dict()2223def constant(value, shape=None, dtype=None, device=None, memory_format=None):24value = np.asarray(value)25if shape is not None:26shape = tuple(shape)27if dtype is None:28dtype = torch.get_default_dtype()29if device is None:30device = torch.device("cpu")31if memory_format is None:32memory_format = torch.contiguous_format3334key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)35tensor = _constant_cache.get(key, None)36if tensor is None:37tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)38if shape is not None:39tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))40tensor = tensor.contiguous(memory_format=memory_format)41_constant_cache[key] = tensor42return tensor4344#----------------------------------------------------------------------------45# Replace NaN/Inf with specified numerical values.4647try:48nan_to_num = torch.nan_to_num # 1.8.0a049except AttributeError:50def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin51assert isinstance(input, torch.Tensor)52if posinf is None:53posinf = torch.finfo(input.dtype).max54if neginf is None:55neginf = torch.finfo(input.dtype).min56assert nan == 057return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)5859#----------------------------------------------------------------------------60# Symbolic assert.6162try:63symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access64except AttributeError:65symbolic_assert = torch.Assert # 1.7.06667#----------------------------------------------------------------------------68# Context manager to suppress known warnings in torch.jit.trace().6970class suppress_tracer_warnings(warnings.catch_warnings):71def __enter__(self):72super().__enter__()73warnings.simplefilter("ignore", category=torch.jit.TracerWarning)74return self7576#----------------------------------------------------------------------------77# Assert that the shape of a tensor matches the given list of integers.78# None indicates that the size of a dimension is allowed to vary.79# Performs symbolic assertion when used in torch.jit.trace().8081def assert_shape(tensor, ref_shape):82if tensor.ndim != len(ref_shape):83raise AssertionError(f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}")84for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):85if ref_size is None:86pass87elif isinstance(ref_size, torch.Tensor):88with suppress_tracer_warnings(): # as_tensor results are registered as constants89symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f"Wrong size for dimension {idx}")90elif isinstance(size, torch.Tensor):91with suppress_tracer_warnings(): # as_tensor 
results are registered as constants92symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f"Wrong size for dimension {idx}: expected {ref_size}")93elif size != ref_size:94raise AssertionError(f"Wrong size for dimension {idx}: got {size}, expected {ref_size}")959697
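#----------------------------------------------------------------------------
# Usage sketch (not part of the original NVIDIA module): a minimal, hedged
# example of how the helpers above are typically exercised. The shapes and
# values below are illustrative assumptions only; run this file directly to
# execute the demo.

if __name__ == "__main__":
    # constant(): a second call with identical arguments returns the cached
    # tensor object instead of rebuilding it.
    ones = constant(1.0, shape=[4, 4])
    assert constant(1.0, shape=[4, 4]) is ones

    # nan_to_num(): NaN becomes 0 and +/-Inf are clamped to the dtype limits,
    # whether the native torch.nan_to_num or the fallback above is used.
    x = torch.tensor([float("nan"), float("inf"), -float("inf"), 2.0])
    print(nan_to_num(x))

    # assert_shape(): None leaves a dimension unconstrained; a mismatch raises
    # AssertionError (or asserts symbolically under torch.jit.trace()).
    assert_shape(torch.zeros(8, 3, 32, 32), [None, 3, 32, 32])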