Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever
GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/normalization.py
1285 views
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import re
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
11
import torch.nn.utils.spectral_norm as spectral_norm
12
13
14
def get_nonspade_norm_layer(opt, norm_type='instance'):
    """Build a function that appends a normalization layer to a given layer.

    Args:
        opt: Not read by this function; kept so existing call sites keep
            working.
        norm_type: Normalization spec. An optional ``'spectral'`` prefix
            wraps the layer with spectral normalization. The remainder
            selects the norm that follows the layer: ``'batch'``,
            ``'sync_batch'``, ``'instance'``, or ``'none'`` / empty for no
            additional normalization.

    Returns:
        A callable ``add_norm_layer(layer)`` that returns either the
        (possibly spectral-normalized) layer itself, or an ``nn.Sequential``
        of the layer followed by the selected normalization module.

    Raises:
        ValueError: Raised by the returned callable (not by this factory)
            when the norm name is not recognized.
    """

    def get_out_channel(layer):
        # Number of output channels of the previous layer; fall back to the
        # first weight dimension for layers without an `out_channels` attr.
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    def add_norm_layer(layer):
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # Without the prefix the whole string names the norm. Leaving
            # this unset made every non-spectral norm_type (including the
            # default 'instance') fail with UnboundLocalError.
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # Remove the bias of the previous layer: it is meaningless because
        # it has no effect after normalization.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
49
50
51
class InstanceNorm2d(nn.Module):
    """Parameter-free normalization by the root mean square over dims 2 and 3.

    Each slice along dims (2, 3) of the input is divided by
    ``sqrt(mean(x ** 2) + epsilon)``. The mean is not subtracted first.

    Args:
        epsilon: Constant added to the mean square before the reciprocal
            square root. It is the first positional parameter, so a single
            positional argument given to the constructor sets it.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` scaled by the reciprocal RMS over dims (2, 3)."""
        # No mean-centering here: only the second moment is used.
        mean_square = (x * x).mean(dim=(2, 3), keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.epsilon)
        return x * inv_rms
61
62
63
class SPADE(nn.Module):
    """SPADE-style conditional normalization layer.

    The input is first normalized by a parameter-free norm, then modulated
    element-wise by a scale (gamma) and a shift (beta) that convolutional
    layers predict from ``segmap``.

    Args:
        config_text: Spec of the form ``'spade<norm><k>x<k>'``, for example
            ``'spadebatch3x3'``. ``<norm>`` is one of ``'instance'``,
            ``'syncbatch'`` or ``'batch'``; the first single digit ``<k>``
            is used as the convolution kernel size.
        norm_nc: Channel count passed to the normalization layer and used as
            the output channel count of the gamma/beta convolutions.
        label_nc: Number of input channels of the first convolution applied
            to ``segmap``.

    Raises:
        ValueError: If ``config_text`` does not match the expected pattern
            or names an unrecognized normalization type.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        # Raw string: '\D' and '\d' are regex classes, not string escapes.
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        if parsed is None:
            # Previously this surfaced as AttributeError on `None.group`.
            raise ValueError(
                "SPADE config %r does not match 'spade<norm><k>x<k>'" % (config_text,)
            )
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            # NOTE(review): the channel count ends up as InstanceNorm2d's
            # `epsilon` (its first parameter). This looks unintended, but it
            # is kept as-is because changing it would alter the numerical
            # output of models already trained with this layer — confirm
            # before changing.
            self.param_free_norm = InstanceNorm2d(epsilon=norm_nc)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space, capped at a
        # hardcoded 128.
        nhidden = min(128, norm_nc)

        # Padding of ks // 2 keeps the spatial size for odd kernel sizes.
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with gamma/beta derived from ``segmap``.

        ``segmap`` is resized with nearest-neighbour interpolation to the
        spatial size of ``x`` (dims 2 onward) before the convolutions run.
        """
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on segmap
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        return normalized * gamma + beta
108
109