Path: blob/master/src/sync_batchnorm/batchnorm_reimpl.py
"""1-*- coding: utf-8 -*-2File : batchnorm_reimpl.py3Author : Jiayuan Mao4Email : [email protected]5Date : 27/01/201867This file is part of Synchronized-BatchNorm-PyTorch.8https://github.com/vacancy/Synchronized-BatchNorm-PyTorch9Distributed under MIT License.1011MIT License1213Copyright (c) 2018 Jiayuan MAO1415Permission is hereby granted, free of charge, to any person obtaining a copy16of this software and associated documentation files (the "Software"), to deal17in the Software without restriction, including without limitation the rights18to use, copy, modify, merge, publish, distribute, sublicense, and/or sell19copies of the Software, and to permit persons to whom the Software is20furnished to do so, subject to the following conditions:2122The above copyright notice and this permission notice shall be included in all23copies or substantial portions of the Software.2425THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR26IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,27FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE28AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER29LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,30OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE31SOFTWARE.32"""3334import torch35import torch.nn as nn36import torch.nn.init as init3738__all__ = ['BatchNorm2dReimpl']394041class BatchNorm2dReimpl(nn.Module):42"""43A re-implementation of batch normalization, used for testing the numerical44stability.4546Author: acgtyrant47See also:48https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/1449"""50def __init__(self, num_features, eps=1e-5, momentum=0.1):51super().__init__()5253self.num_features = num_features54self.eps = eps55self.momentum = momentum56self.weight = nn.Parameter(torch.empty(num_features))57self.bias = nn.Parameter(torch.empty(num_features))58self.register_buffer('running_mean', torch.zeros(num_features))59self.register_buffer('running_var', torch.ones(num_features))60self.reset_parameters()6162def reset_running_stats(self):63self.running_mean.zero_()64self.running_var.fill_(1)6566def reset_parameters(self):67self.reset_running_stats()68init.uniform_(self.weight)69init.zeros_(self.bias)7071def forward(self, input_):72batchsize, channels, height, width = input_.size()73numel = batchsize * height * width74input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)75sum_ = input_.sum(1)76sum_of_square = input_.pow(2).sum(1)77mean = sum_ / numel78sumvar = sum_of_square - sum_ * mean7980self.running_mean = ((1 - self.momentum) * self.running_mean + self.momentum * mean.detach())81unbias_var = sumvar / (numel - 1)82self.running_var = ((1 - self.momentum) * self.running_var + self.momentum * unbias_var.detach())8384bias_var = sumvar / numel85inv_std = 1 / (bias_var + self.eps).pow(0.5)86output = ((input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * self.weight.unsqueeze(1) +87self.bias.unsqueeze(1))8889return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()909192