Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/interpret.py
781 views
1
from ..torch_core import *
2
from ..basic_data import *
3
from ..basic_train import *
4
from .image import *
5
from ..train import Interpretation
6
from textwrap import wrap
7
8
# Public API of this module: the interpretation classes defined below.
__all__ = [
    'SegmentationInterpretation',
    'ObjectDetectionInterpretation',
]
9
10
class SegmentationInterpretation(Interpretation):
    "Interpretation methods for segmentation models."
    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor,
                 ds_type:DatasetType=DatasetType.Valid):
        super(SegmentationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)
        # Per-pixel predicted class index: argmax over dim 1 of `self.preds`.
        self.pred_class = self.preds.argmax(dim=1)
        # Lookups between class names (`self.data.classes`) and their integer indices.
        self.c2i = {c:i for i,c in enumerate(self.data.classes)}
        self.i2c = {i:c for c,i in self.c2i.items()}

    def top_losses(self, sizes:Tuple, k:int=None, largest=True):
        """Reduce the flattened losses to a single mean loss per image and return its top-`k`.

        `sizes`: shape of the per-image loss block; `prod(sizes)` consecutive losses are averaged
        per image (presumably the (H, W) of each mask -- confirm against the caller).
        `k`: number of images to return; defaults to all of them.
        Returns the `(values, indices)` result of `Tensor.topk`.
        """
        losses = self.losses.view(-1, int(np.prod(sizes))).mean(-1)
        return losses.topk(ifnone(k, len(losses)), largest=largest)

    def _interp_show(self, ims:ImageSegment, classes:Collection=None, sz:int=20, cmap='tab20',
                     title_suffix:str=None):
        """Show `ims` with a color mapping, next to a grid that labels each color with its class.

        If `classes` (class names) is given, only pixels of those classes are drawn; all other
        pixels are set to NaN so `imshow` leaves them blank.
        """
        _,axes = plt.subplots(1,2,figsize=(sz,sz))
        # Work in float so masked-out pixels can be NaN. (`np.float` was removed in NumPy 1.24;
        # the builtin `float` is the same float64 dtype.) `astype` already returns a copy.
        np_im = to_np(ims.data).astype(float)
        # tab20 - qualitative colormaps support a max of 20 distinct colors;
        # if len(classes) > 20, close idxs map to the same color.
        # image
        if classes is not None:
            class_idxs = [self.c2i[c] for c in classes]
            np_im[~np.isin(np_im, class_idxs)] = np.nan
        axes[0].imshow(np_im[0], cmap=cmap)

        # labels: the class indices present in the image, laid out on an n x n grid padded with
        # NaN. Both imshow calls span the same min/max index, so cell colors match the image.
        np_im_labels = list(np.unique(np_im[~np.isnan(np_im)]))
        c = len(np_im_labels); n = math.ceil(np.sqrt(c))
        label_im = np.array(np_im_labels + [np.nan]*(n**2-c)).reshape(n,n)
        axes[1].imshow(label_im, cmap=cmap)
        for i,lbl in enumerate([self.i2c[int(idx)] for idx in np_im_labels]):
            div,mod = divmod(i,n)
            # str() so non-string class names (e.g. ints) don't break len()/wrap().
            lbl = str(lbl)
            lbl = "\n".join(wrap(lbl,10)) if len(lbl) > 10 else lbl
            axes[1].text(mod, div, f"{lbl}", ha='center', color='white', fontdict={'size':sz})

        if title_suffix:
            axes[0].set_title(f"{title_suffix}_imsegment")
            axes[1].set_title(f"{title_suffix}_labels")

    def show_xyz(self, i, classes:list=None, sz=10):
        """Show item `i` of `self.ds` via `show_xys`, then its true and predicted masks with color
        mappings; if `classes` is given, only those classes are plotted."""
        x,y = self.ds[i]
        self.ds.show_xys([x],[y], figsize=(sz/2,sz/2))
        self._interp_show(ImageSegment(self.y_true[i]), classes, sz=sz, title_suffix='true')
        # argmax dropped dim 1, so add a leading dim back (presumably to match y_true[i]'s layout).
        self._interp_show(ImageSegment(self.pred_class[i][None,:]), classes, sz=sz, title_suffix='pred')

    def _generate_confusion(self):
        """Average and per-image confusion: intersection of pixels given a true label.

        Entry [c_j, c_i] is the fraction of pixels of true class c_j that were predicted as c_i,
        so each true-label row sums to 1. Sets and returns `(self.mean_cm, self.single_img_cm)`:
        a `(c, c)` matrix averaged over the images that contain the true class, and an
        `(n_images, c, c)` array of per-image matrices (NaN rows where the true class is absent).
        """
        single_img_confusion = []
        mean_confusion = []
        n = self.pred_class.shape[0]
        for c_j in range(self.data.c):
            true_binary = self.y_true.squeeze(1) == c_j
            total_true = true_binary.view(n,-1).sum(dim=1).float()
            for c_i in range(self.data.c):
                pred_binary = self.pred_class == c_i
                total_intersect = (true_binary & pred_binary).view(n,-1).sum(dim=1).float()
                # 0/0 -> NaN for images with no pixels of class c_j; excluded from the mean below.
                p_given_t = total_intersect / total_true
                p_given_t_mean = p_given_t[~torch.isnan(p_given_t)].mean()
                single_img_confusion.append(p_given_t)
                mean_confusion.append(p_given_t_mean)
        self.single_img_cm = to_np(torch.stack(single_img_confusion).permute(1,0).view(-1, self.data.c, self.data.c))
        # torch.stack (not torch.tensor on a list of tensors) keeps dtype/device and avoids copies.
        self.mean_cm = to_np(torch.stack(mean_confusion).view(self.data.c, self.data.c))
        return self.mean_cm, self.single_img_cm

    def _plot_intersect_cm(self, cm, title="Intersection with Predict given True"):
        """Plot a `(c, c)` confusion matrix generated by `_generate_confusion` (`self.mean_cm` or
        one image's slice of `self.single_img_cm`).

        Also displays the diagonal as an HTML table of per-class scores sorted best-first, and
        returns that table as a DataFrame with columns `label` and `score`.
        """
        from IPython.display import display, HTML
        fig,ax = plt.subplots(1,1,figsize=(10,10))
        im = ax.imshow(cm, cmap="Blues")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(f"{title}")
        ax.set_xticks(range(self.data.c))
        ax.set_yticks(range(self.data.c))
        ax.set_xticklabels(self.data.classes, rotation='vertical')
        ax.set_yticklabels(self.data.classes)
        fig.colorbar(im)

        # Build from columns so `score` stays numeric (transposing a row list made it object dtype).
        df = (pd.DataFrame({'label': list(self.data.classes), 'score': cm.diagonal()})
              .sort_values('score', ascending=False))
        # None = no truncation; the old `-1` spelling is deprecated/rejected by modern pandas.
        with pd.option_context('display.max_colwidth', None):
            display(HTML(df.to_html(index=False)))
        return df
98
99
100
101
class ObjectDetectionInterpretation(Interpretation):
    "Interpretation methods for object detection models (not implemented yet)."
    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):
        # Placeholder: construction always fails until object-detection interpretation exists.
        raise NotImplementedError("ObjectDetectionInterpretation is not implemented yet.")
106
107