Path: blob/master/labml_nn/neox/utils/trainer.py
"""Trainer configurations and helpers for training and fine-tuning GPT-NeoX models."""

from typing import Optional, Set, List

import torch.nn as nn
import torch.optim
import torch.utils.data
from torch.cuda import amp
from torch.cuda.amp import GradScaler

from labml import monit, tracker
from labml.configs import BaseConfigs, option
from labml_nn.neox.utils.finetune import FineTuner


def get_trainable_params(model: nn.Module):
    """
    ### Get trainable parameters

    :param model: is the model to train
    :return: a list of parameters for training
    """

    # Get all parameters
    params = list(model.parameters())
    # Filter parameters that require gradients
    trainable_params = [p for p in params if p.requires_grad]

    #
    return trainable_params


class TrainerConf(BaseConfigs):
    """
    ## Trainer configurations
    """

    # Model to train
    model: nn.Module
    # Layers of the model
    layers: List[nn.Module]
    # Optimizer; defaults to the `'Adam'` option defined below
    optimizer: torch.optim.Optimizer = 'Adam'
    # Training data loader
    train_loader: torch.utils.data.DataLoader
    # Optional validation data loader
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    # Device to train on
    device: torch.device = torch.device('cuda:0')
    # Gradient scaler; defaults to the `'Default'` option defined below
    scaler: Optional[GradScaler] = 'Default'
    # Whether to use automatic mixed precision
    is_amp: bool = True
    # Data type of the model parameters
    dtype: torch.dtype = torch.float16

    is_clone_layers: bool = True

    # Loss function
    loss_func: nn.Module = nn.CrossEntropyLoss()
    # Number of checkpoints to save per epoch
    checkpoints_per_epoch: int = 0
    # Number of samples to generate per epoch
    samples_per_epoch: int = 0

    # Gradient clipping norm; `None` disables clipping
    grad_norm: Optional[float] = 1.0
    # Learning rate
    learning_rate: float = 3e-4
    # Maximum sequence length
    max_seq_len: int = 1024
    # Batch size
    batch_size: int = 64
    # Number of epochs
    epochs: int = 16

    # Number of available GPUs
    n_gpus: int = torch.cuda.device_count()

    filter_layers: Optional[Set] = None

    def get_loss(self, sample, dataset_split: str):
        """
        :param sample: is the sample
        :param dataset_split: train/valid
        :return: the loss, output and the target
        """
        data, target = sample

        # Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))
        # Move targets to the same device as the output
        target = target.to(output.device)
        # Calculate loss
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))

        return loss, output, target

    def train(self):
        """
        ### Train the model
        """
        for epoch in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()

    def sample(self, idx):
        # Placeholder for sampling from the model during training
        pass

    def save_checkpoint(self, idx):
        # Placeholder for saving checkpoints during training
        pass

    def get_iterators(self):
        # Training and validation data iterators
        iterators = [('train', self.train_loader)]
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))

        # Sampling and checkpointing callbacks, spread across the epoch
        if self.samples_per_epoch > 0:
            iterators.append((self.sample, [i for i in range(self.samples_per_epoch)]))

        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, [i for i in range(self.checkpoints_per_epoch)]))

        return iterators

    def train_epoch(self):
        """
        ### Train for an epoch
        """
        # Set model to training mode
        self.model.train()

        iterators = self.get_iterators()
        for split_name, sample in monit.mix(1024, *iterators):
            if split_name == 'train':
                # Set gradients to zero
                self.optimizer.zero_grad()
                tracker.add_global_step()

            with torch.set_grad_enabled(split_name == 'train'):
                if self.is_amp:
                    # Forward pass with automatic mixed precision
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    loss, output, target = self.get_loss(sample, split_name)

                # Get predictions
                pred = output.argmax(dim=-1)
                # Calculate accuracy
                accuracy = pred.eq(target).sum().item() / (target != -100).sum()

                tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if split_name == 'train':
                if self.scaler is not None:
                    # Scale the loss for mixed-precision training
                    loss = self.scaler.scale(loss)
                    # tracker.add({'loss.scaled': loss})

                # Backward pass
                with monit.section('Backward pass'):
                    loss.backward()

                # Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        self.optimizer.step()
                    else:
                        self.scaler.unscale_(self.optimizer)
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()

            tracker.save()


@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        raise NotImplementedError()


@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    return torch.optim.SGD(get_trainable_params(c.model), lr=c.learning_rate)


@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    if not c.is_amp:
        return None

    if c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import GradScalerFP16
        return GradScalerFP16()
    else:
        return GradScaler()
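
# The `@option`-decorated functions above register named alternatives for a config field:
# setting `optimizer='Adam'` or `optimizer='SGD'` selects the matching function. As an
# illustrative sketch (not part of the original module), a user could register an extra
# option the same way; the `'AdamW'` name and function below are assumptions for
# demonstration only.
@option(TrainerConf.optimizer, 'AdamW')
def adamw_optimizer(c: TrainerConf):
    # Hypothetical example option: plain AdamW over the trainable parameters
    return torch.optim.AdamW(get_trainable_params(c.model), lr=c.learning_rate)
# A config could then pick it with `{'optimizer': 'AdamW'}`, exactly like 'Adam' or 'SGD'.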


class PipelineParallelTrainerConf(TrainerConf):
    """
    ## Pipeline parallel trainer configurations
    """

    # Whether to use checkpointing in the pipeline
    is_checkpointing: bool = False
    # Number of micro-batch chunks for the pipeline
    chunks: int

    # Fine-tuning helper (see `labml_nn.neox.utils.finetune`)
    fine_tuner: FineTuner
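

# --- Usage sketch (illustrative, not part of the original module) --------------------
# A minimal example of driving `TrainerConf` through labml's experiment API, assuming the
# usual `experiment.create` / `experiment.configs` / `experiment.start` workflow and that
# `experiment.configs` accepts direct values (module, data loader, device) as overrides
# alongside option names. The tiny model, synthetic dataset and hyper-parameters below are
# placeholders chosen only to make the sketch self-contained; a real run would plug in a
# GPT-NeoX model and real data.
if __name__ == '__main__':
    from labml import experiment

    class _TinyLM(nn.Module):
        """Toy token-level model standing in for a real language model."""

        def __init__(self, n_tokens: int = 128, d_model: int = 32):
            super().__init__()
            self.embed = nn.Embedding(n_tokens, d_model)
            self.head = nn.Linear(d_model, n_tokens)

        def forward(self, x):
            return self.head(self.embed(x))

    # Synthetic (input, target) token pairs
    tokens = torch.randint(0, 128, (256, 16))
    dataset = torch.utils.data.TensorDataset(tokens, tokens)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    conf = TrainerConf()
    experiment.create(name='neox_trainer_example')
    experiment.configs(conf, {
        'model': _TinyLM().to(device),
        'train_loader': torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True),
        'optimizer': 'SGD',
        'is_amp': False,
        'device': device,
        'epochs': 1,
    })

    # Run training inside the experiment context so `tracker`/`monit` calls are recorded
    with experiment.start():
        conf.train()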