Path: blob/master/src/sync_batchnorm/comm.py
809 views
"""
-*- coding: utf-8 -*-
File   : comm.py
Author : Jiayuan Mao
Email  : [email protected]
Date   : 27/01/2018

This file is part of Synchronized-BatchNorm-PyTorch.
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
Distributed under MIT License.

MIT License

Copyright (c) 2018 Jiayuan MAO

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']

# Private marker for an empty FutureResult slot. A dedicated object (rather
# than None) lets None itself travel through the pipe as a legitimate result.
_EMPTY = object()


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    One thread deposits a value with :meth:`put`; another blocks in
    :meth:`get` until that value arrives. Every ``get`` clears the slot, so
    the same object can carry one value per round.
    """

    def __init__(self):
        self._result = _EMPTY
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake the thread waiting in :meth:`get`.

        Args:
            result: any value, including ``None``.

        Raises:
            AssertionError: if the previously stored result was not fetched.
        """
        with self._lock:
            assert self._result is _EMPTY, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return it and clear the slot.

        Returns:
            the value passed to the matching :meth:`put` call.
        """
        with self._lock:
            # wait_for re-evaluates the predicate after every wakeup, so this
            # only returns once a put() has actually filled the slot.
            self._cond.wait_for(lambda: self._result is not _EMPTY)
            res = self._result
            self._result = _EMPTY
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication.

    Fields (from the namedtuple base):
        identifier: the id this slave was registered under.
        queue: the queue shared with the master.
        result: the :class:`FutureResult` the master answers through.
    """

    def run_slave(self, msg):
        """Send ``msg`` to the master and block until the master replies.

        After the reply arrives, a ``True`` acknowledgement is pushed onto the
        shared queue; ``SyncMaster.run_master`` waits for one per slave.

        Args:
            msg: the message to hand to the master callback.

        Returns:
            the reply the master computed for this slave.
        """
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message passed
      back to each slave device.
    """

    def __init__(self, master_callback):
        """Create a master with no registered slaves.

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; the queue, locks and registry are
        # runtime state rebuilt from scratch by __setstate__.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """Register a slave device.

        The first registration after a ``run_master`` call discards all
        previously registered slaves and starts a fresh registry.

        Args:
            identifier: an identifier, usually the device id.

        Returns:
            a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """Main entry for the master device in each forward pass.

        Messages are first collected from every device (the master's own message
        first, then one per registered slave, in arrival order). The callback then
        computes the message to send back to each device, including the master.

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
                message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns:
            the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        intermediates.extend(self._queue.get() for _ in range(self.nr_slaves))

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Fetch outside the assert: under ``python -O`` the whole assert
            # statement is stripped, which would leave the slaves' True
            # acknowledgements in the queue to be misread as messages later.
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of slave devices currently in the registry."""
        return len(self._registry)