Path: blob/master/labml_nn/experiments/nlp_classification.py
4918 views
"""1---2title: NLP classification trainer3summary: >4This is a reusable trainer for classification tasks5---67# NLP model trainer for classification8"""910from collections import Counter11from typing import Callable1213import torchtext14import torchtext.vocab15from torchtext.vocab import Vocab1617import torch18from labml import lab, tracker, monit19from labml.configs import option20from labml_nn.helpers.device import DeviceConfigs21from labml_nn.helpers.metrics import Accuracy22from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex23from labml_nn.optimizers.configs import OptimizerConfigs24from torch import nn25from torch.utils.data import DataLoader262728class NLPClassificationConfigs(TrainValidConfigs):29"""30<a id="NLPClassificationConfigs"></a>3132## Trainer configurations3334This has the basic configurations for NLP classification task training.35All the properties are configurable.36"""3738# Optimizer39optimizer: torch.optim.Adam40# Training device41device: torch.device = DeviceConfigs()4243# Autoregressive model44model: nn.Module45# Batch size46batch_size: int = 1647# Length of the sequence, or context size48seq_len: int = 51249# Vocabulary50vocab: Vocab = 'ag_news'51# Number of token in vocabulary52n_tokens: int53# Number of classes54n_classes: int = 'ag_news'55# Tokenizer56tokenizer: Callable = 'character'5758# Whether to periodically save models59is_save_models = True6061# Loss function62loss_func = nn.CrossEntropyLoss()63# Accuracy function64accuracy = Accuracy()65# Model embedding size66d_model: int = 51267# Gradient clipping68grad_norm_clip: float = 1.06970# Training data loader71train_loader: DataLoader = 'ag_news'72# Validation data loader73valid_loader: DataLoader = 'ag_news'7475# Whether to log model parameters and gradients (once per epoch).76# These are summarized stats per layer, but it could still lead77# to many indicators for very deep networks.78is_log_model_params_grads: bool = False7980# Whether to log model activations (once per epoch).81# These are summarized stats per layer, but it could still lead82# to many indicators for very deep networks.83is_log_model_activations: bool = False8485def init(self):86"""87### Initialization88"""89# Set tracker configurations90tracker.set_scalar("accuracy.*", True)91tracker.set_scalar("loss.*", True)92# Add accuracy as a state module.93# The name is probably confusing, since it's meant to store94# states between training and validation for RNNs.95# This will keep the accuracy metric stats separate for training and validation.96self.state_modules = [self.accuracy]9798def step(self, batch: any, batch_idx: BatchIndex):99"""100### Training or validation step101"""102103# Move data to the device104data, target = batch[0].to(self.device), batch[1].to(self.device)105106# Update global step (number of tokens processed) when in training mode107if self.mode.is_train:108tracker.add_global_step(data.shape[1])109110# Get model outputs.111# It's returning a tuple for states when using RNNs.112# This is not implemented yet. 😜113output, *_ = self.model(data)114115# Calculate and log loss116loss = self.loss_func(output, target)117tracker.add("loss.", loss)118119# Calculate and log accuracy120self.accuracy(output, target)121self.accuracy.track()122123# Train the model124if self.mode.is_train:125# Calculate gradients126loss.backward()127# Clip gradients128torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)129# Take optimizer step130self.optimizer.step()131# Log the model parameters and gradients on last batch of every epoch132if batch_idx.is_last and self.is_log_model_params_grads:133tracker.add('model', self.model)134# Clear the gradients135self.optimizer.zero_grad()136137# Save the tracked metrics138tracker.save()139140141@option(NLPClassificationConfigs.optimizer)142def _optimizer(c: NLPClassificationConfigs):143"""144### Default [optimizer configurations](../optimizers/configs.html)145"""146147optimizer = OptimizerConfigs()148optimizer.parameters = c.model.parameters()149optimizer.optimizer = 'Adam'150optimizer.d_model = c.d_model151152return optimizer153154155@option(NLPClassificationConfigs.tokenizer)156def basic_english():157"""158### Basic english tokenizer159160We use character level tokenizer in this experiment.161You can switch by setting,162163```164'tokenizer': 'basic_english',165```166167in the configurations dictionary when starting the experiment.168169"""170from torchtext.data import get_tokenizer171return get_tokenizer('basic_english')172173174def character_tokenizer(x: str):175"""176### Character level tokenizer177"""178return list(x)179180181@option(NLPClassificationConfigs.tokenizer)182def character():183"""184Character level tokenizer configuration185"""186return character_tokenizer187188189@option(NLPClassificationConfigs.n_tokens)190def _n_tokens(c: NLPClassificationConfigs):191"""192Get number of tokens193"""194return len(c.vocab) + 2195196197class CollateFunc:198"""199## Function to load data into batches200"""201202def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):203"""204* `tokenizer` is the tokenizer function205* `vocab` is the vocabulary206* `seq_len` is the length of the sequence207* `padding_token` is the token used for padding when the `seq_len` is larger than the text length208* `classifier_token` is the `[CLS]` token which we set at end of the input209"""210self.classifier_token = classifier_token211self.padding_token = padding_token212self.seq_len = seq_len213self.vocab = vocab214self.tokenizer = tokenizer215216def __call__(self, batch):217"""218* `batch` is the batch of data collected by the `DataLoader`219"""220221# Input data tensor, initialized with `padding_token`222data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)223# Empty labels tensor224labels = torch.zeros(len(batch), dtype=torch.long)225226# Loop through the samples227for (i, (_label, _text)) in enumerate(batch):228# Set the label229labels[i] = int(_label) - 1230# Tokenize the input text231_text = [self.vocab[token] for token in self.tokenizer(_text)]232# Truncate upto `seq_len`233_text = _text[:self.seq_len]234# Transpose and add to data235data[:len(_text), i] = data.new_tensor(_text)236237# Set the final token in the sequence to `[CLS]`238data[-1, :] = self.classifier_token239240#241return data, labels242243244@option([NLPClassificationConfigs.n_classes,245NLPClassificationConfigs.vocab,246NLPClassificationConfigs.train_loader,247NLPClassificationConfigs.valid_loader])248def ag_news(c: NLPClassificationConfigs):249"""250### AG News dataset251252This loads the AG News dataset and the set the values for253`n_classes`, `vocab`, `train_loader`, and `valid_loader`.254"""255256# Get training and validation datasets257train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))258259# Load data to memory260with monit.section('Load data'):261from labml_nn.utils import MapStyleDataset262263# Create [map-style datasets](../utils.html#map_style_dataset)264train, valid = MapStyleDataset(train), MapStyleDataset(valid)265266# Get tokenizer267tokenizer = c.tokenizer268269# Create a counter270counter = Counter()271# Collect tokens from training dataset272for (label, line) in train:273counter.update(tokenizer(line))274# Collect tokens from validation dataset275for (label, line) in valid:276counter.update(tokenizer(line))277# Create vocabulary278vocab = torchtext.vocab.vocab(counter, min_freq=1)279280# Create training data loader281train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,282collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))283# Create validation data loader284valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,285collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))286287# Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`288return 4, vocab, train_loader, valid_loader289290291