# Path: blob/master/src/sync_batchnorm/replicate.py
# 809 views
"""
-*- coding: utf-8 -*-
File   : replicate.py
Author : Jiayuan Mao
Email  : [email protected]
Date   : 27/01/2018

This file is part of Synchronized-BatchNorm-PyTorch.
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
Distributed under MIT License.

MIT License

Copyright (c) 2018 Jiayuan MAO

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = ['CallbackContext', 'execute_replication_callbacks', 'DataParallelWithCallback', 'patch_replication_callback']


class CallbackContext(object):
    """
    Empty attribute container shared by all replicas of one sub-module.

    Replication callbacks may attach arbitrary attributes to it; because the same instance is handed to every
    copy of a given sub-module, this is how the copies on different devices share information.
    """


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`, where `copy_id` is the
    index of the replica in `modules`.

    Note that, as all replicas are isomorphic, we assign each sub-module a context (shared among the copies of
    that sub-module on different devices). Sub-modules are matched across replicas by their position in the
    `module.modules()` traversal. Through this context, different copies can share some information.

    We guarantee that the callbacks on the master copy (the first copy) are called before the callback of any
    slave copy.

    Args:
        modules: the replicas returned by `replicate`; `modules[0]` is treated as the master copy.
            An empty sequence is a no-op.
    """
    if not modules:
        # Nothing was replicated; previously this raised IndexError on modules[0].
        return

    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    # The outer loop walks the replicas in order, so every callback on the master copy (i == 0)
    # runs before any callback on a slave copy.
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by the
    original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """
        Replicate `module` via `DataParallel.replicate`, then run the replication callbacks on the copies.

        Returns:
            The list of replicas produced by the parent implementation.
        """
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Patching is idempotent: calling this again on an object it has already patched leaves
    `data_parallel.replicate` unchanged, so the callbacks are not executed twice per replication.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

    Args:
        data_parallel: a `DataParallel` instance; its `replicate` attribute is replaced in place.

    Raises:
        AssertionError: if `data_parallel` is not a `DataParallel` instance (this is an `assert`,
            so the check is skipped under `python -O`).
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate
    if getattr(old_replicate, '_replication_callback_patched', False):
        # Already wrapped by a previous call; wrapping again would run every callback twice.
        return

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    # Marker checked above to make repeated patching a no-op.
    new_replicate._replication_callback_patched = True
    data_parallel.replicate = new_replicate