# GitHub Repository: hackassin/learnopencv
# Path: blob/master/Bag-Of-Tricks-For-Image-Classification/model/model.py
import os
import warnings
from argparse import (
    ArgumentParser,
    Namespace,
)

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from .augmentations import (
    get_test_augmentation,
    get_training_augmentation,
)
from .losses import (
    KnowledgeDistillationLoss,
    LabelSmoothingLoss,
    MixUpAugmentationLoss,
)


class LitFood101(pl.LightningModule):
    def __init__(self, model, args: Namespace):
        super().__init__()
        self.model = model
        self.args = args
        # We need to specify the number of classes here to avoid a RuntimeError,
        # see https://github.com/PyTorchLightning/pytorch-lightning/issues/3006
        # However, we will get another warning, which we handle in the training/validation/test steps below
        self.metric = pl.metrics.Accuracy(num_classes=self.args.num_classes)
        # Replace the classification head so it matches the number of classes
        dim_feats = self.model.fc.in_features  # =2048
        nb_classes = self.args.num_classes
        self.model.fc = nn.Linear(dim_feats, nb_classes)

    def forward(self, x):
        return self.model(x)

    def setup(self, stage):
        if self.args.use_smoothing:
            self.criterion = LabelSmoothingLoss(
                self.args.num_classes, self.args.smoothing,
            )
        else:
            self.criterion = nn.CrossEntropyLoss()

        if self.args.use_mixup:
            self.criterion = MixUpAugmentationLoss(self.criterion)

    def on_epoch_start(self):
        # Reset the cached batch used for mixup at the start of every epoch
        self.previous_batch = [None, None]

    def training_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        if self.args.use_mixup:
            # Mix the current batch with the cached previous one
            mixup_x, *mixup_y = self.mixup_batch(x, y, *self.previous_batch)
            logits = self(mixup_x)
            loss = self.criterion(logits, mixup_y)
        else:
            logits = self(x)
            loss = self.criterion(logits, y)
        # We ignore the warning about a mismatch between the number of predicted classes
        # and the number the Accuracy metric was initialized with
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            accuracy = self.metric(logits.argmax(dim=-1), y)
        tensorboard_logs = {"train_loss": loss, "train_acc": accuracy}
        # Cache this batch so the next step can mix with it
        self.previous_batch = [x, y]

        return {"loss": loss, "progress_bar": tensorboard_logs, "log": tensorboard_logs}

    def validation_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        val_loss = self.criterion(logits, y)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            val_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"val_loss": val_loss, "val_acc": val_accuracy}

    def test_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            test_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"test_acc": test_accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_accuracy = torch.stack([x["val_acc"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss, "val_acc": avg_accuracy}
        return {
            "avg_val_loss": avg_loss,
            "avg_val_acc": avg_accuracy,
            "log": tensorboard_logs,
        }

    def test_epoch_end(self, outputs):
        avg_accuracy = torch.stack([x["test_acc"] for x in outputs]).mean()
        return {"avg_test_acc": avg_accuracy.item()}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        if self.args.use_cosine_scheduler:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.args.max_epochs, eta_min=0.0,
            )
        else:
            # MultiStepLR keeps its default gamma=0.1, i.e. the LR is divided by 10 at each milestone
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=self.args.milestones,
            )
        return [optimizer], [scheduler]
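
    # For reference, CosineAnnealingLR anneals the learning rate per epoch t as
    #     lr_t = eta_min + (lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    # so with eta_min=0.0 it decays smoothly from the initial lr towards 0 over
    # max_epochs, while the MultiStepLR branch keeps the LR piecewise constant.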

    def train_dataloader(self):
        # ImageFolder expects one sub-folder per class inside data_root/train
        train_dataset = ImageFolder(
            os.path.join(self.args.data_root, "train"),
            transform=get_training_augmentation(),
        )

        return DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        val_dataset = ImageFolder(
            os.path.join(self.args.data_root, "test"),
            transform=get_test_augmentation(),
        )
        return DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=self.args.workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        # Testing reuses the validation dataloader
        return self.val_dataloader()

    def optimizer_step(self, epoch, batch_idx, optimizer, *args, **kwargs):
        # Learning rate warm-up: scale the LR linearly from lr / warmup up to lr
        if self.args.warmup != -1 and epoch < self.args.warmup:
            lr = self.args.lr * (epoch + 1) / self.args.warmup
            for pg in optimizer.param_groups:
                pg["lr"] = lr

        self.logger.log_metrics({"lr": optimizer.param_groups[0]["lr"]}, step=epoch)
        optimizer.step()
        optimizer.zero_grad()
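
    # A worked example of the warm-up above (assuming the defaults lr=1e-4 and
    # warmup=6 from add_model_specific_args): the LR applied during epochs 0..5
    # is lr * (epoch + 1) / 6, i.e.
    #     epoch 0 -> 1.67e-5, epoch 1 -> 3.33e-5, ..., epoch 5 -> 1.00e-4,
    # after which the learning rate follows the scheduler from configure_optimizers.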
    def mixup_batch(self, x, y, x_previous, y_previous):
        # Sample the mixing coefficient from a Beta(alpha, alpha) distribution
        lmbd = (
            np.random.beta(self.args.mixup_alpha, self.args.mixup_alpha)
            if self.args.mixup_alpha > 0
            else 1
        )
        # On the very first batch of an epoch there is no previous batch yet,
        # so we mix the batch with a copy of itself
        if x_previous is None:
            x_previous = torch.empty_like(x).copy_(x)
            y_previous = torch.empty_like(y).copy_(y)
        batch_size = x.size(0)
        index = torch.randperm(batch_size)
        # If the current batch size != previous batch size, we take only a part of the previous batch
        x_previous = x_previous[:batch_size, ...]
        y_previous = y_previous[:batch_size, ...]
        x_mixed = lmbd * x + (1 - lmbd) * x_previous[index, ...]
        y_a, y_b = y, y_previous[index]
        return x_mixed, y_a, y_b, lmbd
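    # Mixup in one formula: given two labelled examples (x_a, y_a) and (x_b, y_b)
    # and lambda ~ Beta(alpha, alpha), the model is trained on
    #     x = lambda * x_a + (1 - lambda) * x_b
    # with the loss combined as
    #     loss = lambda * criterion(logits, y_a) + (1 - lambda) * criterion(logits, y_b)
    # mixup_batch only produces the mixed images and the (y_a, y_b, lambda) tuple;
    # the loss combination itself is assumed to live in MixUpAugmentationLoss
    # (see .losses), which receives that tuple as its target in training_step.
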
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--data-root",
            default="./data",
            type=str,
            help="Path to root folder of the dataset (should include train and test folders)",
        )
        parser.add_argument(
            "-n", "--num-classes", type=int, help="Number of classes", default=21,
        )
        parser.add_argument(
            "-b",
            "--batch-size",
            default=32,
            type=int,
            metavar="N",
            help="Mini-batch size",
        )
        parser.add_argument(
            "--lr",
            "--learning-rate",
            default=1e-4,
            type=float,
            metavar="LR",
            help="Initial learning rate",
        )
        parser.add_argument(
            "--milestones",
            type=int,
            nargs="+",
            default=[15, 30],
            help="Milestones for dropping the learning rate",
        )

        parser.add_argument(
            "--warmup",
            type=int,
            default=6,
            help="Number of epochs to warm up the learning rate. -1 to turn off",
        )
        return parser


class LitFood101KD(LitFood101):
    def __init__(self, model, teacher, args):
        super().__init__(model, args)
        self.teacher = teacher
        # Replace the teacher's head to match the number of classes,
        # then load its weights from the checkpoint
        dim_feats = self.teacher.fc.in_features  # =2048
        nb_classes = self.args.num_classes
        self.teacher.fc = nn.Linear(dim_feats, nb_classes)
        teacher_checkpoint = torch.load("./teacher.ckpt")
        self.teacher.load_state_dict(teacher_checkpoint["state_dict"])

    def setup(self, stage):
        criterion = (
            LabelSmoothingLoss(self.args.num_classes, self.args.smoothing)
            if self.args.use_smoothing
            else nn.CrossEntropyLoss()
        )
        self.criterion = KnowledgeDistillationLoss(
            self.args.distill_alpha, self.args.distill_temperature, criterion=criterion,
        )
        if self.args.use_mixup:
            self.criterion = MixUpAugmentationLoss(self.criterion)
        # The teacher is only used for inference
        self.teacher.eval()
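
    # Knowledge distillation in brief: the student is trained both to match the
    # hard labels and to match the teacher's predictions softened by a
    # temperature T. A common (Hinton-style) formulation, which we assume
    # KnowledgeDistillationLoss in .losses follows, is
    #     loss = alpha * T^2 * KL(soft_teacher || soft_student) + (1 - alpha) * criterion(logits, y)
    # with T = distill_temperature and alpha = distill_alpha. The teacher's
    # logits are produced under torch.no_grad() in training_step below.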

    def training_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        # The teacher only provides soft targets, so no gradients are needed
        with torch.no_grad():
            teacher_output = self.teacher(x)

        if self.args.use_mixup:
            mixup_x, *mixup_y = self.mixup_batch(x, y, *self.previous_batch)
            logits = self(mixup_x)
            loss = self.criterion(logits, mixup_y, teacher_output)
        else:
            logits = self(x)
            loss = self.criterion(logits, y, teacher_output)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            accuracy = self.metric(logits.argmax(dim=-1), y)
        tensorboard_logs = {"train_loss": loss, "train_acc": accuracy}

        return {"loss": loss, "progress_bar": tensorboard_logs, "log": tensorboard_logs}

    def validation_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        with torch.no_grad():
            teacher_output = self.teacher(x)
        val_loss = self.criterion(logits, y, teacher_output)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            val_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"val_loss": val_loss, "val_acc": val_accuracy}
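

# A minimal usage sketch, not part of the original file: the repository's own
# training entry point presumably builds `args` on the command line via
# add_model_specific_args, so the hard-coded Namespace below only illustrates
# which attributes the module expects. Because of the relative imports at the
# top, run this with `python -m <package>.model` from the package's parent
# directory rather than executing the file directly.
if __name__ == "__main__":
    from torchvision.models import resnet50

    args = Namespace(
        data_root="./data",      # expects data_root/train and data_root/test
        num_classes=21,
        batch_size=32,
        workers=4,
        lr=1e-4,
        warmup=6,                # linear LR warm-up for the first 6 epochs
        milestones=[15, 30],
        max_epochs=40,
        use_cosine_scheduler=False,
        use_smoothing=False,
        smoothing=0.1,
        use_mixup=False,
        mixup_alpha=0.2,
    )
    module = LitFood101(resnet50(pretrained=True), args)
    trainer = pl.Trainer(max_epochs=args.max_epochs)
    trainer.fit(module)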