GitHub Repository: hackassin/learnopencv
Path: blob/master/Bag-Of-Tricks-For-Image-Classification/model/losses.py
# MIT License
# Copyright (c) 2018 Haitong Li


import torch
import torch.nn as nn
import torch.nn.functional as F


# Based on https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py
class KnowledgeDistillationLoss(nn.Module):
    """Weighted sum of a soft-target KL term and a hard-label criterion.

    alpha balances the two terms, T is the softmax temperature, and
    criterion is the hard-label loss (e.g. cross-entropy).
    """

    def __init__(self, alpha, T, criterion):
        super().__init__()
        self.criterion = criterion
        self.KLDivLoss = nn.KLDivLoss(reduction="batchmean")
        self.alpha = alpha
        self.T = T

    def forward(self, input, target, teacher_target):
        # KL divergence between the temperature-softened student and teacher
        # distributions, scaled by T^2 to keep gradient magnitudes comparable,
        # plus the usual hard-label loss weighted by (1 - alpha).
        loss = self.KLDivLoss(
            F.log_softmax(input / self.T, dim=1),
            F.softmax(teacher_target / self.T, dim=1),
        ) * (self.alpha * self.T * self.T) + self.criterion(input, target) * (
            1.0 - self.alpha
        )
        return loss


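# Usage sketch (illustrative, not part of the original file; `student_model`,
# `teacher_model`, `images`, and `labels` are assumed to exist):
#
#     kd_loss = KnowledgeDistillationLoss(
#         alpha=0.9, T=4.0, criterion=nn.CrossEntropyLoss()
#     )
#     with torch.no_grad():
#         teacher_logits = teacher_model(images)
#     loss = kd_loss(student_model(images), labels, teacher_logits)

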
class MixUpAugmentationLoss(nn.Module):
    """Criterion wrapper for mixup training: the target is either a plain
    tensor (validation) or a (target_a, target_b, lambda) tuple (training)."""

    def __init__(self, criterion):
        super().__init__()
        self.criterion = criterion

    def forward(self, input, target, *args):
        # Validation step: plain labels, no mixing
        if isinstance(target, torch.Tensor):
            return self.criterion(input, target, *args)
        # Training step: convex combination of the losses on both targets
        target_a, target_b, lmbd = target
        return lmbd * self.criterion(input, target_a, *args) + (
            1 - lmbd
        ) * self.criterion(input, target_b, *args)


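# Usage sketch (illustrative, not part of the original file): one possible way
# to build the mixed batch and tuple target consumed above. `model`, `images`,
# and `labels` are assumed to exist; the Beta(0.2, 0.2) mixing prior is a
# common choice, not something fixed by this module.
#
#     lmbd = float(torch.distributions.Beta(0.2, 0.2).sample())
#     index = torch.randperm(images.size(0))
#     mixed = lmbd * images + (1 - lmbd) * images[index]
#     criterion = MixUpAugmentationLoss(nn.CrossEntropyLoss())
#     loss = criterion(model(mixed), (labels, labels[index], lmbd))

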
# Based on https://github.com/pytorch/pytorch/issues/7455
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a smoothed target distribution: the true class
    gets probability 1 - smoothing, the rest share smoothing uniformly."""

    def __init__(self, n_classes, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = n_classes
        self.dim = dim

    def forward(self, output, target, *args):
        output = output.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Create a matrix of shape (batch_size, n_classes)
            true_dist = torch.zeros_like(output)
            # Initialize all elements with smoothing / (n_classes - 1)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            # Fill the correct class for each sample with 1 - smoothing
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        # Mean over the batch of the per-sample cross-entropy
        return torch.mean(torch.sum(-true_dist * output, dim=self.dim))
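

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original file): exercise each
    # loss on random tensors to check that the forward passes run end to end.
    torch.manual_seed(0)
    batch, n_classes = 4, 10
    logits = torch.randn(batch, n_classes)
    labels = torch.randint(0, n_classes, (batch,))
    ce = nn.CrossEntropyLoss()

    kd = KnowledgeDistillationLoss(alpha=0.9, T=4.0, criterion=ce)
    print("KD loss:", kd(logits, labels, torch.randn(batch, n_classes)).item())

    mixup = MixUpAugmentationLoss(criterion=ce)
    labels_b = torch.randint(0, n_classes, (batch,))
    print("MixUp loss:", mixup(logits, (labels, labels_b, 0.7)).item())

    smooth = LabelSmoothingLoss(n_classes=n_classes, smoothing=0.1)
    print("Label smoothing loss:", smooth(logits, labels).item())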