Path: blob/master/Bag-Of-Tricks-For-Image-Classification/model/model.py
import os
import warnings
from argparse import (
    ArgumentParser,
    Namespace,
)

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from .augmentations import (
    get_test_augmentation,
    get_training_augmentation,
)
from .losses import (
    KnowledgeDistillationLoss,
    LabelSmoothingLoss,
    MixUpAugmentationLoss,
)


class LitFood101(pl.LightningModule):
    def __init__(self, model, args: Namespace):
        super().__init__()
        self.model = model
        self.args = args
        # We need to specify a number of classes there to avoid the RuntimeError
        # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3006
        # However, we will get another warning and it should be handled in forward steps
        self.metric = pl.metrics.Accuracy(num_classes=self.args.num_classes)
        dim_feats = self.model.fc.in_features  # =2048
        nb_classes = self.args.num_classes
        self.model.fc = nn.Linear(dim_feats, nb_classes)

    def forward(self, x):
        return self.model(x)

    def setup(self, stage):
        if self.args.use_smoothing:
            self.criterion = LabelSmoothingLoss(
                self.args.num_classes, self.args.smoothing,
            )
        else:
            self.criterion = nn.CrossEntropyLoss()

        if self.args.use_mixup:
            self.criterion = MixUpAugmentationLoss(self.criterion)

    def on_epoch_start(self):
        self.previous_batch = [None, None]

    def training_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        if self.args.use_mixup:
            mixup_x, *mixup_y = self.mixup_batch(x, y, *self.previous_batch)
            logits = self(mixup_x)
            loss = self.criterion(logits, mixup_y)
        else:
            logits = self(x)
            loss = self.criterion(logits, y)
        # We ignore a warning about a mismatch between a number of predicted classes
        # and a number of initialized for Accuracy class
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            accuracy = self.metric(logits.argmax(dim=-1), y)
        tensorboard_logs = {"train_loss": loss, "train_acc": accuracy}
        self.previous_batch = [x, y]

        return {"loss": loss, "progress_bar": tensorboard_logs, "log": tensorboard_logs}

    def validation_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        val_loss = self.criterion(logits, y)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            val_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"val_loss": val_loss, "val_acc": val_accuracy}

    def test_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            test_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"test_acc": test_accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_accuracy = torch.stack([x["val_acc"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss, "val_acc": avg_accuracy}
        return {
            "avg_val_loss": avg_loss,
            "avg_val_acc": avg_accuracy,
            "log": tensorboard_logs,
        }

    def test_epoch_end(self, outputs):
        avg_accuracy = torch.stack([x["test_acc"] for x in outputs]).mean()
        return {"avg_test_acc": avg_accuracy.item()}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        if self.args.use_cosine_scheduler:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.args.max_epochs, eta_min=0.0,
            )
        else:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=self.args.milestones,
            )
        return [optimizer], [scheduler]

    def train_dataloader(self):
        train_dataset = ImageFolder(
            os.path.join(self.args.data_root, "train"),
            transform=get_training_augmentation(),
        )

        return DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        val_dataset = ImageFolder(
            os.path.join(self.args.data_root, "test"),
            transform=get_test_augmentation(),
        )
        return DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=self.args.workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return self.val_dataloader()

    def optimizer_step(self, epoch, batch_idx, optimizer, *args, **kwargs):
        # Learning Rate warm-up
        if self.args.warmup != -1 and epoch < self.args.warmup:
            lr = self.args.lr * (epoch + 1) / self.args.warmup
            for pg in optimizer.param_groups:
                pg["lr"] = lr

        self.logger.log_metrics({"lr": optimizer.param_groups[0]["lr"]}, step=epoch)
        optimizer.step()
        optimizer.zero_grad()

    def mixup_batch(self, x, y, x_previous, y_previous):
        lmbd = (
            np.random.beta(self.args.mixup_alpha, self.args.mixup_alpha)
            if self.args.mixup_alpha > 0
            else 1
        )
        if x_previous is None:
            x_previous = torch.empty_like(x).copy_(x)
            y_previous = torch.empty_like(y).copy_(y)
        batch_size = x.size(0)
        index = torch.randperm(batch_size)
        # If current batch size != previous batch size, we take only a part of the previous batch
        x_previous = x_previous[:batch_size, ...]
        y_previous = y_previous[:batch_size, ...]
        x_mixed = lmbd * x + (1 - lmbd) * x_previous[index, ...]
        y_a, y_b = y, y_previous[index]
        return x_mixed, y_a, y_b, lmbd

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--data-root",
            default="./data",
            type=str,
            help="Path to root folder of the dataset (should include train and test folders)",
        )
        parser.add_argument(
            "-n", "--num-classes", type=int, help="Number of classes", default=21,
        )
        parser.add_argument(
            "-b",
            "--batch-size",
            default=32,
            type=int,
            metavar="N",
            help="Mini-batch size",
        )
        parser.add_argument(
            "--lr",
            "--learning-rate",
            default=1e-4,
            type=float,
            metavar="LR",
            help="Initial learning rate",
        )
        parser.add_argument(
            "--milestones",
            type=int,
            nargs="+",
            default=[15, 30],
            help="Milestones for dropping the learning rate",
        )

        parser.add_argument(
            "--warmup",
            type=int,
            default=6,
            help="Number of epochs to warm up the learning rate. -1 to turn off",
        )
        return parser


class LitFood101KD(LitFood101):
    def __init__(self, model, teacher, args):
        super().__init__(model, args)
        self.teacher = teacher
        dim_feats = self.teacher.fc.in_features  # =2048
        nb_classes = self.args.num_classes
        self.teacher.fc = nn.Linear(dim_feats, nb_classes)
        teacher_checkpoint = torch.load("./teacher.ckpt")
        self.teacher.load_state_dict(teacher_checkpoint["state_dict"])

    def setup(self, stage):
        criterion = (
            LabelSmoothingLoss(self.args.num_classes, self.args.smoothing)
            if self.args.use_smoothing
            else nn.CrossEntropyLoss()
        )
        self.criterion = KnowledgeDistillationLoss(
            self.args.distill_alpha, self.args.distill_temperature, criterion=criterion,
        )
        if self.args.use_mixup:
            self.criterion = MixUpAugmentationLoss(self.criterion)
        self.teacher.eval()

    def training_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        with torch.no_grad():
            teacher_output = self.teacher(x)

        if self.args.use_mixup:
            mixup_x, *mixup_y = self.mixup_batch(x, y, *self.previous_batch)
            logits = self(mixup_x)
            loss = self.criterion(logits, mixup_y, teacher_output)
        else:
            logits = self(x)
            loss = self.criterion(logits, y, teacher_output)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            accuracy = self.metric(logits.argmax(dim=-1), y)
        tensorboard_logs = {"train_loss": loss, "train_acc": accuracy}

        return {"loss": loss, "progress_bar": tensorboard_logs, "log": tensorboard_logs}

    def validation_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        with torch.no_grad():
            teacher_output = self.teacher(x)
        val_loss = self.criterion(logits, y, teacher_output)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            val_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"val_loss": val_loss, "val_acc": val_accuracy}