Path: blob/master/FBAMatting/networks/layers_WS.py
3119 views
import torch
import torch.nn as nn
from torch.nn import functional as F


class Conv2d(nn.Conv2d):
    """2-D convolution whose kernel is weight-standardized on every forward.

    Drop-in replacement for ``nn.Conv2d``. The stored ``weight`` parameter is
    left untouched; at call time each output filter is shifted to zero mean
    and divided by its standard deviation (both taken over the filter's
    ``in_channels // groups * kH * kW`` elements) before ``F.conv2d`` runs.
    The bias, if present, is applied unchanged.

    ``padding_mode`` is not exposed; the forward pass always uses the
    zero-padding behaviour of ``F.conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        # Forward by keyword so a change in nn.Conv2d's positional order
        # cannot silently mis-assign arguments.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the standardized kernel.

        Args:
            x: input tensor accepted by ``F.conv2d`` for this layer's
                configuration.

        Returns:
            The convolution output, as produced by ``F.conv2d``.
        """
        weight = self.weight
        # Centre each output filter: mean over (in_ch/groups, kH, kW).
        weight = weight - weight.mean(dim=(1, 2, 3), keepdim=True)
        # Per-filter std. torch.var is unbiased by default. The 1e-12 inside
        # the sqrt keeps the gradient finite for a constant filter, and the
        # 1e-5 outside guards the division. Both constants are kept as-is so
        # existing checkpoints produce identical results.
        flat = weight.reshape(weight.size(0), -1)
        std = torch.sqrt(torch.var(flat, dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5
        # std has shape (out_ch, 1, 1, 1) and broadcasts over the filter.
        weight = weight / std
        return F.conv2d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups,
        )


def BatchNorm2d(num_features, num_groups=32):
    """Build the normalization layer used in place of batch normalization.

    Despite the name, this returns ``nn.GroupNorm``, which pairs with the
    weight-standardized ``Conv2d`` above and does not depend on batch size.

    Args:
        num_features: number of channels to normalize.
        num_groups: number of channel groups (default 32, the original
            hard-coded value). Must divide ``num_features``.

    Returns:
        An ``nn.GroupNorm`` instance.

    Raises:
        ValueError: raised by ``nn.GroupNorm`` when ``num_features`` is not
            divisible by ``num_groups``.
    """
    return nn.GroupNorm(num_channels=num_features, num_groups=num_groups)