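"Helpers to train a fastai `Learner` in parallel (`nn.DataParallel`) or distributed (`DistributedDataParallel`) mode."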
from .torch_core import *
from .basic_train import Learner,LearnerCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler

from fastai.text import TextLMDataBunch

__all__ = ['DistributedRecorder', 'DistributedTrainer', 'read_metrics', 'setup_distrib']

def rnn_reset(self):
    "Call `reset` on the wrapped module if it defines one (fastai RNNs do)."
    if hasattr(self.module, 'reset'): self.module.reset()
DistributedDataParallel.reset = rnn_reset

class ParallelTrainer(LearnerCallback):
    "Wrap the model in `nn.DataParallel` for the duration of training."
    _order = -20
    def on_train_begin(self, **kwargs): self.learn.model = DataParallel(self.learn.model)
    def on_train_end(self, **kwargs): self.learn.model = self.learn.model.module

class DistributedTrainer(LearnerCallback):
    "Wrap the model in `DistributedDataParallel` and swap the dataloaders' samplers for distributed ones."
    _order = -20 # Needs to run before the recorder
    def __init__(self, learn:Learner, cuda_id:int=0):
        super().__init__(learn)
        self.cuda_id,self.train_sampler = cuda_id,None

    def _change_dl(self, dl, shuffle):
        "Replace `dl`'s sampler with an `OurDistributedSampler`, keeping the old dataloader for restoration."
        old_dl = dl
        sampler = OurDistributedSampler(dl.dataset, shuffle=shuffle)
        new_dl = dl.new(shuffle=False, sampler=sampler)
        return old_dl,new_dl,sampler

    def on_train_begin(self, **kwargs):
        # `self.model`, `self.data` and `self.recorder` delegate to `self.learn` via `LearnerCallback`.
        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)
        shuffle = self.data.train_dl.init_kwargs['shuffle'] if hasattr(self.data.train_dl, 'init_kwargs') else True
        self.old_train_dl,self.data.train_dl,self.train_sampler = self._change_dl(self.data.train_dl, shuffle)
        if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None:
            self.old_valid_dl,self.data.valid_dl,self.valid_sampler = self._change_dl(self.data.valid_dl, shuffle)
        self.rank = rank_distrib()
        self.recorder.silent = (self.rank != 0) # Only rank 0 reports progress

    def on_epoch_begin(self, epoch, **kwargs): self.train_sampler.set_epoch(epoch)

    def on_train_end(self, **kwargs):
        self.learn.model = self.learn.model.module # Unwrap the model and restore the original dataloaders
        self.learn.data.train_dl = self.old_train_dl
        if hasattr(self.learn.data, 'valid_dl') and self.learn.data.valid_dl is not None:
            self.learn.data.valid_dl = self.old_valid_dl

class DistributedRecorder(LearnerCallback):
    "Save each process's losses and metrics to `cache_dir` so they can be merged after training."
    def __init__(self, learn:Learner, cuda_id:int=0, cache_dir:PathOrStr='tmp'):
        super().__init__(learn)
        self.cuda_id,self.cache_dir = cuda_id,cache_dir

    def on_train_begin(self, **kwargs):
        os.makedirs(self.learn.path/self.cache_dir, exist_ok=True)

    def on_epoch_end(self, **kwargs): self.save_stats()
    def on_train_end(self, **kwargs): self.save_stats()

    def save_stats(self):
        cache_path,recorder = self.learn.path/self.cache_dir,self.learn.recorder
        np.save(cache_path/f'losses_{self.cuda_id}', np.array(recorder.losses))
        stats = np.array([[v] + m for v,m in zip(recorder.val_losses,recorder.metrics)])
        np.save(cache_path/f'metrics_{self.cuda_id}', stats)
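
# Each process writes its own `losses_{cuda_id}.npy` and `metrics_{cuda_id}.npy`
# under `learn.path/cache_dir`; `read_metrics` below reassembles them.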

def _learner_parallel(learn:Learner):
    "Use `nn.DataParallel` when training and remove it when done."
    if not torch.cuda.is_available(): warnings.warn('CUDA is not available; check your drivers. Training will continue on the CPU.', ResourceWarning)
    learn.callbacks.append(ParallelTrainer(learn))
    return learn

def _learner_distributed(learn:Learner, cuda_id:int, cache_dir:PathOrStr='tmp'):
    "Put `learn` on distributed training with `cuda_id`."
    learn.callbacks.append(DistributedTrainer(learn, cuda_id))
    learn.callbacks.append(DistributedRecorder(learn, cuda_id, cache_dir))
    return learn

Learner.to_distributed = _learner_distributed
Learner.to_parallel = _learner_parallel
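
# Typical usage (a sketch; the script below is illustrative, not part of this
# module). With `torch.distributed.launch` spawning one process per GPU and
# passing each a `--local_rank` argument:
#
#   # train.py
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--local_rank', type=int, default=0)
#   args = parser.parse_args()
#   setup_distrib(args.local_rank)
#   learn = ...  # build a Learner as usual
#   learn.to_distributed(args.local_rank)
#   learn.fit_one_cycle(1)
#
# launched as: python -m torch.distributed.launch --nproc_per_node=<n_gpus> train.py
# For single-process multi-GPU training, `learn.to_parallel()` is enough on its own.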

def read_metrics(cache_path:PathOrStr, n_gpus:int, reduce:bool=True):
    "Load the losses/metrics saved by `DistributedRecorder` for each of `n_gpus`, averaging them if `reduce`."
    losses,metrics = [],[]
    for i in range(n_gpus):
        losses.append(np.load(cache_path/f'losses_{i}.npy')[None])
        metrics.append(np.load(cache_path/f'metrics_{i}.npy')[None])
    if reduce:
        losses,metrics = np.concatenate(losses,0),np.concatenate(metrics,0)
        return losses.mean(0),metrics.mean(0)
    return losses,metrics
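
# e.g., to average the stats saved above across two GPUs (assuming the default
# `cache_dir='tmp'`): losses, metrics = read_metrics(learn.path/'tmp', n_gpus=2)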

def setup_distrib(gpu:Any=None):
    "Pin the current process to `gpu` and initialize the default process group if running distributed."
    if gpu is None: return gpu
    gpu = int(gpu)
    torch.cuda.set_device(gpu)
    if num_distrib() > 1:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return gpu
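
# Note: `init_method='env://'` reads MASTER_ADDR, MASTER_PORT, RANK and
# WORLD_SIZE from the environment; `torch.distributed.launch` sets these for you.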

class OurDistributedSampler(DistributedSampler):
    "A sampler for language models with the option to not shuffle."
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch) # Deterministic per epoch, and identical across processes
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else: indices = torch.arange(len(self.dataset)).tolist()

        # Add extra samples so the total is evenly divisible across replicas
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample: each rank takes every `num_replicas`-th index, starting at its own rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
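
# Worked example with `shuffle=False`, 10 samples and 4 replicas: `num_samples`
# is ceil(10/4)=3, so `total_size` is 12 and the padded index list is
# [0..9, 0, 1]. Rank 0 then draws [0, 4, 8], rank 1 [1, 5, 9], rank 2 [2, 6, 0]
# and rank 3 [3, 7, 1], so every rank sees the same number of samples.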