Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callbacks/mixup.py
781 views
1
"Implements [mixup](https://arxiv.org/abs/1710.09412) training method"
2
from ..torch_core import *
3
from ..callback import *
4
from ..basic_train import Learner, LearnerCallback
5
6
class MixUpCallback(LearnerCallback):
    "Callback that creates the mixed-up input and target."

    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
        # alpha:   both parameters of the symmetric Beta(alpha, alpha) the per-sample mixing weights are drawn from.
        # stack_x: if True, feed the model `[x, x_shuffled, lambd]` instead of the interpolated input
        #          (presumably for models that perform the mixing themselves — confirm against the model).
        # stack_y: if True, the target becomes `[y, y_shuffled, lambd]` stacked column-wise and the loss is
        #          wrapped in `MixUpLoss`; if False, the targets themselves are interpolated.
        super().__init__(learn)
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y

    def on_train_begin(self, **kwargs):
        "Wrap the learner's loss in `MixUpLoss` so it can consume stacked targets."
        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        bs = last_target.size(0)
        # One weight per sample, folded to max(l, 1-l) so the unshuffled sample always has weight >= 0.5.
        lambd = np.random.beta(self.alpha, self.alpha, bs)
        lambd = np.maximum(lambd, 1-lambd)
        # Follow the input's float dtype (e.g. half precision) but never an integer dtype: casting the
        # weights to int would truncate them to 0 (e.g. integer inputs used with `stack_x`).
        dtype = last_input.dtype if last_input.is_floating_point() else torch.float32
        lambd = torch.as_tensor(lambd, dtype=dtype, device=last_input.device)
        shuffle = torch.randperm(bs).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            new_input = [last_input, x1, lambd]
        else:
            # Shape (bs, 1, ..., 1) so the weights broadcast over every non-batch dimension of the input.
            out_shape = [bs] + [1] * (x1.dim() - 1)
            new_input = last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape)
        if self.stack_y:
            # Columns: original target, shuffled target, weight of the original — decoded by `MixUpLoss`.
            new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), lambd[:,None].float()], 1)
        else:
            # Same broadcasting for targets of any rank (1D, 2D one-hot/multi-label, ...).
            t_shape = [bs] + [1] * (last_target.dim() - 1)
            lambd = lambd.view(t_shape).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}

    def on_train_end(self, **kwargs):
        "Restore the loss function that `on_train_begin` wrapped, if it was wrapped."
        # Guard: if training stopped before `on_train_begin` ran, the loss was never wrapped and has no `get_old`.
        if self.stack_y and isinstance(self.learn.loss_func, MixUpLoss):
            self.learn.loss_func = self.learn.loss_func.get_old()
38
39
40
class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."

    def __init__(self, crit, reduction='mean'):
        # crit:      the original loss, either an object with a `reduction` attribute or a function
        #            accepting a `reduction=` keyword.
        # reduction: how per-sample losses are combined in `forward`: 'mean', 'sum', or anything else
        #            to return them unreduced.
        super().__init__()
        if hasattr(crit, 'reduction'):
            # Switch the loss object to per-sample output in place; remember the setting for `get_old`.
            self.crit = crit
            self.old_red = crit.reduction
            self.crit.reduction = 'none'
        else:
            # Plain function: assumed to accept `reduction=` (like `F.cross_entropy`) — confirm for custom losses.
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction

    def forward(self, output, target):
        "Compute the loss of `output`, decoding a stacked `(y, y_shuffled, lambd)` target when `target` is 2D."
        if target.dim() == 2:
            y_a, y_b, lambd = target[:,0].long(), target[:,1].long(), target[:,2]
            d = self.crit(output, y_a) * lambd + self.crit(output, y_b) * (1-lambd)
        else: d = self.crit(output, target)
        # Reduce exactly once, here: reducing inside the branch above would make 'sum'/'none' behave like 'mean'.
        if self.reduction == 'mean': return d.mean()
        if self.reduction == 'sum': return d.sum()
        return d

    def get_old(self):
        "Return the original loss, restoring its `reduction` attribute if it was changed in place."
        if hasattr(self, 'old_crit'): return self.old_crit
        if hasattr(self, 'old_red'): self.crit.reduction = self.old_red
        return self.crit
68
69