Path: blob/master/labml_nn/helpers/trainer.py
import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import torch.optim
import torch.utils.data
from labml import tracker, logger, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn
from .device import DeviceConfigs
from .metrics import StateModule


class TrainingLoopIterator(Collection):
    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        return False
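

# A small illustrative sketch (not part of the original module) of the two
# iterator modes: with a fixed ``step`` the iterator yields
# ``start, start + step, ...`` strictly below ``total``; with ``step=None``
# it counts ``total`` iterations and yields the current global step from
# ``labml.tracker`` instead of the raw counter.
def _training_loop_iterator_sketch():
    it = TrainingLoopIterator(start=0, total=6, step=2)
    assert list(it) == [0, 2, 4]  # yields start, start + step, ... < total
    assert len(it) == 3  # (total - start) // step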


class TrainingLoop:
    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 log_new_line_interval: int,
                 log_write_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            pass
        return self

    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass second interrupt without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"


class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your configurations
    if it involves a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
            Defaults to ``1``.
        log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
            Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an
            iteration is complete. Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)


class ModeState:
    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, Any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_optimize=is_optimize)


class Mode:
    def __init__(self, mode: ModeState, **kwargs: Any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)
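

# A minimal usage sketch (illustrative; not part of the original module):
# ``Mode`` is a context manager that sets flags on a shared ``ModeState``
# and rolls them back on exit, so nested phases can temporarily switch modes.
def _mode_sketch():
    mode = ModeState()
    assert not mode.is_train
    with mode.update(is_train=True):
        assert mode.is_train  # flag is set for the duration of the block
    assert not mode.is_train  # rolled back when the context exits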


class Trainer:
    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[Any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


class BatchIndex:
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # // is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0
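

# An illustrative sketch (not part of the original module) of how ``BatchIndex``
# splits an epoch into inner iterations: with 10 batches and 2 inner
# iterations, the first inner iteration completes after batch 5 and the
# second after batch 10.
def _batch_index_sketch():
    bi = BatchIndex(10, 2)
    bi.reset(10, 2)
    while not bi.iteration_completed:
        bi.step()
    assert bi.idx == 5  # first half of the epoch consumed
    bi.step_inner()
    while not bi.iteration_completed:
        bi.step()
    bi.step_inner()
    assert bi.completed  # both inner iterations done; the epoch is over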


class TrainValidConfigs(TrainingLoopConfigs):
    r"""
    This is a configurable module that you can extend for experiments that involve
    training and validation datasets (i.e. most DL experiments).

    Arguments:
        epochs (int): Number of epochs to train on. Defaults to ``10``.
        train_loader (torch.utils.data.DataLoader): Training data loader.
        valid_loader (torch.utils.data.DataLoader): Validation data loader.
        inner_iterations (int): Number of times to switch between training and validation
            within an epoch. Defaults to ``1``.

    You can override the ``init`` and ``step`` functions. There is also a ``sample``
    function that you can override to generate samples every time it switches between
    training and validation.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass


@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs


class SimpleTrainValidConfigs(TrainValidConfigs):
    r"""
    This is a configurable module that works for many standard DL experiments.

    Arguments:
        model: A PyTorch model.
        optimizer: A PyTorch optimizer to update the model.
        device: The device to train the model on. This defaults to a configurable device.
        loss_func: A function to calculate the loss. This should accept
            ``model_output, target`` as arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer step.
            Defaults to ``1``.
        log_save_batches (int): How often to call :func:`labml.tracker.save`.
            Defaults to ``1``.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with monit.section("model"):
            output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    from .optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
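

# A minimal end-to-end sketch of wiring up ``SimpleTrainValidConfigs``.
# Illustrative only: the experiment name and the ``'optimizer.optimizer'``
# override key are assumptions for this example; ``model``, ``loss_func``,
# ``train_loader`` and ``valid_loader`` must be supplied by the subclass
# (via ``@option`` functions or direct assignment) and are omitted here.
def _simple_train_valid_sketch():
    from labml import experiment

    class Configs(SimpleTrainValidConfigs):
        pass

    conf = Configs()
    experiment.create(name='sketch')  # hypothetical experiment name
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    with experiment.start():
        conf.run()  # runs the training loop defined above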