Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/tta.py
781 views
1
"Brings TTA (Test Time Augmentation) to the `Learner` class. Call `learner.TTA()` rather than using this module's functions directly."
2
from ..torch_core import *
3
from ..basic_train import *
4
from ..basic_train import _loss_func2activ
5
from ..basic_data import DatasetType
6
from .transform import *
7
8
__all__ = []
9
10
def _tta_only(learn:Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, scale:float=1.35) -> Iterator[List[Tensor]]:
    """Computes the outputs for several augmented inputs for TTA.

    Lazily yields the first element returned by `get_preds` (the predictions) for
    `ds_type`, once per augmented pass, for 8 passes in total.

    Each pass reuses the training transforms, minus `crop_pad`, `flip_lr`,
    `dihedral` and `zoom`. It then applies a deterministic `zoom(scale=scale)` and
    `crop_pad` at a fixed (row_pct, col_pct) position. Passes 4-7 also apply a
    forced `flip_lr`.

    `activ` defaults to `_loss_func2activ(learn.loss_func)`. The dataset's original
    `tfms` are restored in a `finally`, which runs when the generator is exhausted,
    closed or raises.
    """
    dl = learn.dl(ds_type)
    ds = dl.dataset
    old = ds.tfms
    activ = ifnone(activ, _loss_func2activ(learn.loss_func))
    # The training dataset's `tfms` may be None (no transforms were set). Treat that
    # as "no extra augmentation" instead of failing while iterating over None.
    train_tfms = ifnone(learn.data.train_ds.tfms, [])
    # Exclude the geometric train transforms. Each pass below supplies its own
    # deterministic zoom/crop_pad (and optional flip_lr) instead.
    augm_tfm = [o for o in train_tfms if o.tfm not in (crop_pad, flip_lr, dihedral, zoom)]
    try:
        pbar = master_bar(range(8))
        for i in pbar:
            # Bits 0 and 1 of `i` choose the crop position: row_pct and col_pct are
            # each 0 or 1, presumably the four corners. Bit 2 toggles a horizontal
            # flip. 4 positions x 2 flip states = 8 passes.
            row = 1 if i&1 else 0
            col = 1 if i&2 else 0
            flip = i&4
            d = {'row_pct':row, 'col_pct':col, 'is_random':False}
            tfm = [*augm_tfm, zoom(scale=scale, **d), crop_pad(**d)]
            if flip: tfm.append(flip_lr(p=1.))
            # Temporarily swap the dataset's transforms; undone in `finally` below.
            ds.tfms = tfm
            yield get_preds(learn.model, dl, pbar=pbar, activ=activ)[0]
    finally:
        ds.tfms = old
30
31
# Expose `_tta_only` as the `Learner.tta_only` method.
setattr(Learner, 'tta_only', _tta_only)
32
33
def _TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False) -> Tensors:
    """Applies TTA to predict on `ds_type` dataset.

    Combines the plain predictions from `learn.get_preds` with the mean of the
    augmented passes yielded by `learn.tta_only`:
    `plain * beta + augmented_mean * (1 - beta)`.

    Returns:
        `(plain, augmented_mean, targets)` if `beta` is None. No blending is done
        and `with_loss` is ignored.
        `(blended, targets, loss)` if `with_loss` is True. The loss is computed
        with `NoneReduceOnCPU(learn.loss_func)`.
        `(blended, targets)` otherwise.
    """
    plain, targets = learn.get_preds(ds_type, activ=activ)
    augmented = list(learn.tta_only(ds_type=ds_type, activ=activ, scale=scale))
    augmented_mean = torch.stack(augmented).mean(0)
    if beta is None:
        return plain, augmented_mean, targets
    blended = plain*beta + augmented_mean*(1-beta)
    if not with_loss:
        return blended, targets
    # NOTE(review): `tta_only` applies `activ` (or the activation derived from the
    # loss function) to its outputs, so this loss is evaluated on activated
    # predictions. If `learn.loss_func` expects raw model outputs (e.g. logits),
    # the value may be misleading. Confirm.
    with NoneReduceOnCPU(learn.loss_func) as loss_fn:
        loss = loss_fn(blended, targets)
    return blended, targets, loss
45
46
# Expose `_TTA` as the public `Learner.TTA` method.
setattr(Learner, 'TTA', _TTA)
47
48