Path: blob/master/src/sync_batchnorm/batchnorm.py
"""1-*- coding: utf-8 -*-2File : batchnorm.py3Author : Jiayuan Mao4Email : [email protected]5Date : 27/01/201867This file is part of Synchronized-BatchNorm-PyTorch.8https://github.com/vacancy/Synchronized-BatchNorm-PyTorch9Distributed under MIT License.1011MIT License1213Copyright (c) 2018 Jiayuan MAO1415Permission is hereby granted, free of charge, to any person obtaining a copy16of this software and associated documentation files (the "Software"), to deal17in the Software without restriction, including without limitation the rights18to use, copy, modify, merge, publish, distribute, sublicense, and/or sell19copies of the Software, and to permit persons to whom the Software is20furnished to do so, subject to the following conditions:2122The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.2324THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR25IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,26FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE27AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER28LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,29OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE30SOFTWARE.31"""3233import collections34import contextlib3536import torch37import torch.nn.functional as F3839from torch.nn.modules.batchnorm import _BatchNorm4041try:42from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast43except ImportError:44ReduceAddCoalesced = Broadcast = None4546try:47from jactorch.parallel.comm import SyncMaster48from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback49except ImportError:50from .comm import SyncMaster51from .replicate import DataParallelWithCallback5253__all__ = [54'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 'patch_sync_batchnorm',55'convert_model'56]575859def _sum_ft(tensor):60"""sum over the first and last dimention"""61return tensor.sum(dim=0).sum(dim=-1)626364def _unsqueeze_ft(tensor):65"""add new dimensions at the front and the tail"""66return tensor.unsqueeze(0).unsqueeze(-1)676869_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])70_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])717273class _SynchronizedBatchNorm(_BatchNorm):74def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):75assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'7677super(_SynchronizedBatchNorm, self).__init__(num_features,78eps=eps,79momentum=momentum,80affine=affine,81track_running_stats=track_running_stats)8283if not self.track_running_stats:84import warnings85warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')8687self._sync_master = SyncMaster(self._data_parallel_master)8889self._is_parallel = False90self._parallel_id = None91self._slave_pipe = None9293def forward(self, input):94# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.95if not (self._is_parallel and self.training):96return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, self.training,97self.momentum, self.eps)9899# Resize the input to (B, C, -1).100input_shape = input.size()101input = input.view(input.size(0), self.num_features, -1)102103# Compute the 
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

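        # Since Var[x] = E[x^2] - E[x]^2, `sumvar` above equals the sum of squared
        # deviations from the mean over the combined batch. As with torch.nn.BatchNorm*,
        # the unbiased estimate (divided by size - 1) updates the running statistics
        # below, while the biased estimate (divided by size) is used to normalize the
        # current batch via the returned inverse standard deviation.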
        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'.format(input.dim()))


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'.format(input.dim()))


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))


@contextlib.contextmanager
def patch_sync_batchnorm():
    """Context manager that temporarily replaces torch.nn.BatchNorm1d/2d/3d with
    their synchronized counterparts, so that models built inside the ``with`` block
    use SynchronizedBatchNorm without any further code changes."""
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup


def convert_model(module):
    """Traverse the input module and its children recursively
    and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
    with SynchronizedBatchNorm*N*d.

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([
            torch.nn.modules.batchnorm.BatchNorm1d, torch.nn.modules.batchnorm.BatchNorm2d,
            torch.nn.modules.batchnorm.BatchNorm3d
    ], [SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
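

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an editorial addition, not part of the original
# module). It exercises the single-device fallback path of
# SynchronizedBatchNorm2d (which simply calls F.batch_norm) and shows how
# convert_model swaps the BatchNorm layers of an existing network. Because of
# the relative imports above, run it from the package root, e.g.
# `python -m sync_batchnorm.batchnorm`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    # Single-device forward pass: behaves exactly like nn.BatchNorm2d.
    bn = SynchronizedBatchNorm2d(8)
    x = torch.randn(4, 8, 16, 16)
    y = bn(x)
    print('output shape:', tuple(y.shape))

    # Replace the BatchNorm layers of a small model with synchronized ones.
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    model = convert_model(model)
    print('converted model:', model)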