Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FBAMatting/networks/layers_WS.py
3119 views
1
import torch
2
import torch.nn as nn
3
from torch.nn import functional as F
4
5
6
class Conv2d(nn.Conv2d):
    """2-D convolution that weight-standardizes its kernels on every call.

    Before convolving, each output filter (one slice along dim 0 of the
    4-D ``self.weight``) is shifted to zero mean and divided by its standard
    deviation. The stored parameter itself is never modified; only the
    tensor handed to ``F.conv2d`` is standardized. Constructor arguments
    are forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the standardized version of ``self.weight``.

        Args:
            x: input passed straight to ``F.conv2d``; presumably
                (N, in_channels, H, W) — TODO confirm against callers.

        Returns:
            The result of ``F.conv2d`` using this layer's bias, stride,
            padding, dilation and groups.
        """
        raw = self.weight

        # Per-filter mean over dims 1, 2, 3, reduced one axis at a time
        # (kept in this order so the floating-point result is unchanged).
        filter_mean = raw.mean(dim=1, keepdim=True)
        filter_mean = filter_mean.mean(dim=2, keepdim=True)
        filter_mean = filter_mean.mean(dim=3, keepdim=True)
        centered = raw - filter_mean

        # Per-filter standard deviation from torch.var's default (unbiased)
        # estimator. The 1e-12 keeps the sqrt argument strictly positive,
        # and the extra 1e-5 keeps the divisor away from zero.
        # NOTE(review): a filter with a single element (fan-in of 1) gives
        # an unbiased variance of NaN.
        flat = centered.view(centered.size(0), -1)
        variance = torch.var(flat, dim=1)
        scale = torch.sqrt(variance + 1e-12).view(-1, 1, 1, 1) + 1e-5

        standardized = centered / scale.expand_as(centered)
        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
49
50
51
def BatchNorm2d(num_features, num_groups=32):
    """Build the normalization layer used where BatchNorm2d would be.

    Despite its name, this factory returns ``nn.GroupNorm``, not
    ``nn.BatchNorm2d``. GroupNorm computes its statistics per sample, so
    they do not depend on the batch size. All other GroupNorm arguments
    (``eps``, ``affine``) keep PyTorch's defaults.

    Args:
        num_features: number of channels ``C`` of the input to normalize.
        num_groups: number of channel groups. The default of 32 matches the
            previously hard-coded value, so existing callers are unaffected.

    Returns:
        An ``nn.GroupNorm(num_groups, num_features)`` module.

    Raises:
        ValueError: raised by ``nn.GroupNorm`` when ``num_features`` is not
            divisible by ``num_groups``.
    """
    return nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
53
54