Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/sync_batchnorm/comm.py
809 views
1
"""
2
-*- coding: utf-8 -*-
3
File : comm.py
4
Author : Jiayuan Mao
5
Email : [email protected]
6
Date : 27/01/2018
7
8
This file is part of Synchronized-BatchNorm-PyTorch.
9
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
10
Distributed under MIT License.
11
12
MIT License
13
14
Copyright (c) 2018 Jiayuan MAO
15
16
Permission is hereby granted, free of charge, to any person obtaining a copy
17
of this software and associated documentation files (the "Software"), to deal
18
in the Software without restriction, including without limitation the rights
19
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20
copies of the Software, and to permit persons to whom the Software is
21
furnished to do so, subject to the following conditions:
22
23
The above copyright notice and this permission notice shall be included in all
24
copies or substantial portions of the Software.
25
26
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32
SOFTWARE.
33
"""
34
35
import queue
36
import collections
37
import threading
38
39
# Public API of this module; the namedtuple helpers below are private implementation details.
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
40
41
42
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    One thread stores a value with `put`; another thread retrieves it with `get`,
    which blocks until a value is available. After `get` returns, the future is
    empty again and can be reused for the next exchange.
    """

    # Private marker meaning "no value stored". A dedicated object (instead of
    # ``None``) lets ``None`` itself travel through the pipe: with ``None`` as the
    # marker, a `get` issued after `put(None)` would wait forever.
    _EMPTY = object()

    def __init__(self):
        self._result = self._EMPTY
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store `result` and wake up a thread waiting in `get`.

        Raises:
            AssertionError: if the previously stored result has not been fetched yet.
        """
        with self._lock:
            assert self._result is self._EMPTY, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return it and reset the future to empty."""
        with self._lock:
            # `wait_for` re-checks the predicate after every wake-up, which a single
            # `if ...: wait()` does not; the threading docs recommend this pattern.
            self._cond.wait_for(lambda: self._result is not self._EMPTY)
            res = self._result
            self._result = self._EMPTY
            return res
63
64
65
# Per-slave entry stored by `SyncMaster`: wraps the `FutureResult` that
# `run_master` fills with that slave's reply.
_MasterRegistry = collections.namedtuple('MasterRegistry', 'result')

# Immutable base tuple for `SlavePipe`: the slave's registered identifier,
# the master's message queue, and the slave's own `FutureResult`.
_SlavePipeBase = collections.namedtuple(
    '_SlavePipeBase',
    'identifier queue result',
)
67
68
69
class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication.

    Fields (inherited from `_SlavePipeBase`):
        identifier: the key this slave was registered under.
        queue: the queue shared with the master; carries this slave's message and,
            afterwards, its acknowledgement.
        result: the `FutureResult` through which the master delivers this slave's reply.
    """

    def run_slave(self, msg):
        """Send `msg` to the master, block for the reply, then acknowledge it.

        Returns: the reply the master produced for this slave.
        """
        slave_id, channel, reply_box = self.identifier, self.queue, self.result
        channel.put((slave_id, msg))
        reply = reply_box.get()
        # Tell the master this slave has picked up its reply.
        channel.put(True)
        return reply
76
77
78
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
      call `register_slave(identifier)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
                It receives a list of ``(identifier, message)`` pairs whose first entry is the master's own
                message under identifier 0, and must return an indexable sequence of ``(identifier, reply)``
                pairs whose first entry belongs to the master (identifier 0).
        """
        self._master_callback = master_callback
        # Shared with every SlavePipe: carries (identifier, msg) pairs, then `True` acknowledgements.
        self._queue = queue.Queue()
        # identifier -> _MasterRegistry(result=FutureResult), in registration order.
        self._registry = collections.OrderedDict()
        # Set by `run_master`; makes the next `register_slave` start a fresh registry.
        self._activated = False

    def __getstate__(self):
        # Only the callback is serialized; the queue and registry are runtime-only state.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        # Rebuild a fresh, not-yet-activated master around the restored callback.
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        If `run_master` has run since the last registration, the previous registry is discarded first,
        so each replication starts from an empty set of slaves.

        Args:
            identifier: an identifier, usually the device id. Identifier 0 is reserved for the master in
                `run_master` (replies addressed to 0 are never forwarded), so slaves should use other values.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
                message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Wait until every slave has acknowledged fetching its reply. The `get()` must stay outside the
        # `assert`: under `python -O` assert statements are removed entirely, which would skip draining the
        # acknowledgements and leave stale `True` values to be mistaken for slave messages next round.
        for _ in range(self.nr_slaves):
            ack = self._queue.get()
            assert ack is True, 'Unexpected message in the acknowledgement queue.'

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of slaves registered for the current replication."""
        return len(self._registry)
159
160