# GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
# Path: blob/master/labml_nn/helpers/trainer.py

import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import torch.optim
import torch.utils.data
from labml import tracker, logger, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn
from .device import DeviceConfigs
from .metrics import StateModule


class TrainingLoopIterator(Collection):
    """
    Iteration counter used by ``TrainingLoop``. With a fixed ``step`` it counts
    from ``start`` in increments of ``step`` up to ``total``; with ``step=None``
    it runs ``total`` iterations and yields the current global step from
    ``labml.tracker``.
    """

    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        return False


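def _example_training_loop_iterator():
    # Illustrative sketch (not part of the original module): with a fixed
    # `step`, the iterator counts from `start` in increments of `step` and
    # stops once it reaches `total`. The `Collection` interface (`__len__`,
    # `__contains__`) is what lets `monit.loop` wrap it in `TrainingLoop` below.
    it = TrainingLoopIterator(start=0, total=10, step=2)
    assert list(it) == [0, 2, 4, 6, 8]
    assert len(it) == 5

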
class TrainingLoop:
    """
    The training loop. Iterating it yields the global step, saves tracker
    stats at the configured intervals, and optionally delays ``SIGINT``
    (Ctrl-C) until the current iteration finishes.
    """

    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 log_new_line_interval: int,
                 log_write_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            pass
        return self

    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass a second interrupt through without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"


class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your configurations
    if they involve a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
            Defaults to ``1``.
        log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
            Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an iteration is complete.
            Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)


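def _example_training_loop_usage():
    # Illustrative sketch (not part of the original module): extending
    # `TrainingLoopConfigs` and iterating `conf.training_loop` as in the class
    # docstring. In a real experiment `conf` would first be passed through
    # `labml.experiment.configs` so the `training_loop` option gets computed;
    # the interval values below are arbitrary.
    class _LoopConfigs(TrainingLoopConfigs):
        loop_count: int = 1_000
        loop_step: int = 100
        log_write_interval: int = 100

    conf = _LoopConfigs()
    for step in conf.training_loop:
        ...  # one chunk of training ending at global step `step`

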
class ModeState:
    """
    Keeps the current mode flags (``is_train``, ``is_optimize``) and a stack of
    previous values so that ``Mode`` can roll them back.
    """

    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, Any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_optimize=is_optimize)


class Mode:
    """
    Context manager that temporarily updates a ``ModeState`` and restores the
    previous flags on exit.
    """

    def __init__(self, mode: ModeState, **kwargs: Any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)


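def _example_mode_rollback():
    # Illustrative sketch (not part of the original module): `update()` returns
    # a `Mode` context manager that flips the given flags on entry and restores
    # the previous values on exit, so train/valid sections can be nested safely.
    mode = ModeState()
    assert not mode.is_train

    with mode.update(is_train=True):
        assert mode.is_train
        with mode.update(is_optimize=True):
            assert mode.is_train and mode.is_optimize
        assert not mode.is_optimize

    assert not mode.is_train

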
class Trainer:
    """
    Runs one training or validation segment: iterates a data loader, calls
    ``step`` for each batch, and manages per-epoch state modules.
    """

    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[Any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


class BatchIndex:
    """
    Tracks the position within an epoch (``idx`` of ``total`` batches) and the
    inner iteration (``iteration`` of ``total_iterations``).
    """

    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # // is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0


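def _example_batch_index_inner_iterations():
    # Illustrative sketch (not part of the original module): with 10 batches
    # and 2 inner iterations, `iteration_completed` becomes true after the
    # first 5 batches, so a `Trainer.__call__` processes half the epoch,
    # returns, and picks up where it left off on the next call.
    bi = BatchIndex(total=10, total_iterations=2)
    bi.reset(total=10, total_iterations=2)

    first_half = 0
    while not bi.iteration_completed:
        bi.step()
        first_half += 1
    bi.step_inner()

    assert first_half == 5
    assert not bi.completed  # the epoch only finishes after the second half

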
class TrainValidConfigs(TrainingLoopConfigs):
    r"""
    This is a configurable module that you can extend for experiments that involve
    training and validation datasets (i.e. most DL experiments).

    Arguments:
        epochs (int): Number of epochs to train on. Defaults to ``10``.
        train_loader (torch.utils.data.DataLoader): Training data loader.
        valid_loader (torch.utils.data.DataLoader): Validation data loader.
        inner_iterations (int): Number of times to switch between training and validation
            within an epoch. Defaults to ``1``.

    You can override the ``init`` and ``step`` functions. There is also a ``sample`` function
    that you can override to generate samples every time it switches between training and validation.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass


@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs


class SimpleTrainValidConfigs(TrainValidConfigs):
    r"""
    This is a configurable module that works for many standard DL experiments.

    Arguments:
        model: A PyTorch model.
        optimizer: A PyTorch optimizer to update the model.
        device: The device to train the model on. This defaults to a configurable device.
        loss_func: A function to calculate the loss. This should accept ``model_output, target`` as
            arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer step.
            Defaults to ``1``.
        log_save_batches (int): How often to call :func:`labml.tracker.save`. Defaults to ``1``.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with monit.section("model"):
            output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    from .optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


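def _example_simple_train_valid_experiment():
    """
    Illustrative sketch (not part of the original module): wiring a toy
    experiment on top of `SimpleTrainValidConfigs` with random data and a
    linear model. The class, option and experiment names are hypothetical,
    and it assumes the usual labml experiment API (`experiment.create`,
    `experiment.configs`, `experiment.start`).
    """
    from labml import experiment
    from torch.utils.data import DataLoader, TensorDataset

    class ExampleConfigs(SimpleTrainValidConfigs):
        epochs: int = 2

    @option(ExampleConfigs.model)
    def _example_model(c: ExampleConfigs):
        return nn.Linear(16, 2).to(c.device)

    @option(ExampleConfigs.loss_func)
    def _example_loss(c: ExampleConfigs):
        return nn.CrossEntropyLoss()

    @option(ExampleConfigs.train_loader)
    def _example_train_loader(c: ExampleConfigs):
        data, target = torch.randn(256, 16), torch.randint(0, 2, (256,))
        return DataLoader(TensorDataset(data, target), batch_size=32, shuffle=True)

    @option(ExampleConfigs.valid_loader)
    def _example_valid_loader(c: ExampleConfigs):
        data, target = torch.randn(64, 16), torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(data, target), batch_size=32)

    conf = ExampleConfigs()
    experiment.create(name='simple_train_valid_example')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    with experiment.start():
        conf.run()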