Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/widgets/class_confusion.py
781 views
import math
import re
from itertools import permutations

import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from ..tabular import TabularDataBunch
from ..train import ClassificationInterpretation

class ClassConfusion():
    "Plot the most confused datapoints and statistics for the models misses."
    # Tabular models: one tab per variable; each tab plots that variable for every top-loss row
    # ('Original') and for the rows of each class pair taken from `classlist`.
    # Image models: one tab per confused pair returned by `interp.most_confused()` that is also in
    # the requested pair list, showing up to `k` images (`k` is read interactively via `input()`).

    def __init__(self, interp:'ClassificationInterpretation', classlist:list,
                 is_ordered:bool=False, cut_off:int=100, varlist:list=None,
                 figsize:tuple=(8,8)):
        # interp: interpretation object of a trained learner.
        # classlist: class names. With `is_ordered=True` it is used as-is as the list of class
        #   pairs; otherwise every ordered pair (`permutations(classlist, 2)`) is examined.
        # cut_off: categorical variables with more distinct values than this are not plotted.
        # varlist: tabular variables to show; defaults to every gathered column.
        # figsize: matplotlib figure size used for each tab.
        # The constructor gathers the losses and renders the widget immediately.
        self.interp = interp
        self._is_tab = isinstance(interp.learn.data, TabularDataBunch)
        # Statistics of the fitted `Normalize` proc, used to show continuous values on their
        # original scale. They stay None when no such proc exists (values are shown as stored).
        self.means, self.stds = None, None
        if self._is_tab and interp.learn.data.train_ds.x.cont_names != []:
            for i, proc in enumerate(interp.learn.data.procs):
                if "Normalize" in str(proc):
                    # Presumably the fitted proc sits at the same index in the training processor.
                    fitted = interp.learn.data.train_ds.x.processor[0].procs[i]
                    self.means, self.stds = fitted.means, fitted.stds
        self.is_ordered = is_ordered
        self.cut_off = cut_off
        self.figsize = figsize
        self.varlist = varlist
        self.classl = classlist
        self._show_losses(classlist)

    def _show_losses(self, classl:list, **kwargs):
        "Checks if the model is for Tabular or Images and gathers top losses"
        # `classl` and `kwargs` are unused; later steps read `self.classl`.
        _, self.tl_idx = self.interp.top_losses(len(self.interp.losses))
        if self._is_tab:
            self._tab_losses()
        else:
            self._create_tabs()

    def _create_tabs(self):
        "Creates one tab per tabular variable, or per confused class pair for images"
        self.lis = self.classl if self.is_ordered else list(permutations(self.classl, 2))
        if self._is_tab:
            self._boxes = len(self.df_list)
            self._cols = math.ceil(math.sqrt(self._boxes))
            self._rows = math.ceil(self._boxes/self._cols)
            # The last column of each frame is its label column, not a variable.
            self.tbnames = list(self.df_list[0].columns)[:-1] if self.varlist is None else self.varlist
        else:
            vals = self.interp.most_confused()
            # Set lookup instead of a nested scan over all requested pairs.
            wanted = {tuple(pair) for pair in self.lis}
            self._ranges, self.tbnames, self._pairs = [], [], []
            self._boxes = int(input('Please enter a value for `k`, or the top images you will see: '))
            for x in vals:
                if tuple(x[0:2]) in wanted:
                    self._ranges.append(x[2])
                    self.tbnames.append(f'{x[0]} | {x[1]}')
                    self._pairs.append((x[0], x[1]))
        items = [widgets.Output() for _ in self.tbnames]
        self.tabs = widgets.Tab()
        self.tabs.children = items
        for i, name in enumerate(self.tbnames):
            self.tabs.set_title(i, name)
        self._populate_tabs()

    def _populate_tabs(self):
        "Adds relevant graphs to each tab"
        with tqdm(total=len(self.tbnames)) as pbar:
            for i, tab in enumerate(self.tbnames):
                # Plots made inside the Output context are captured by that tab.
                with self.tabs.children[i]:
                    if self._is_tab:
                        self._plot_tab(tab)
                    else:
                        self._plot_imgs(tab, i)
                pbar.update(1)
        # NOTE(review): `display` is not imported in this module; it is expected to be the
        # IPython/Jupyter builtin. Confirm before using outside a notebook.
        display(self.tabs)

    def _plot_tab(self, tab:str):
        "Generates graphs"
        # One subplot per frame in `self.df_list`. squeeze=False keeps `ax` 2-D even for a
        # single subplot, so the flattened `axes[j]` always works.
        if self._boxes is not None:
            fig, ax = plt.subplots(self._boxes, 1, figsize=self.figsize, squeeze=False)
        else:
            fig, ax = plt.subplots(self._rows, self._cols, figsize=self.figsize, squeeze=False)
        axes = ax.ravel()
        for j, x in enumerate(self.df_list):
            title = f'{x.columns[-1]} {tab} distribution'
            if tab in self.cat_names:
                vals = x[tab].value_counts()
                n_values = len(vals)  # number of distinct categories
                if n_values < 10:
                    vals.plot(kind='bar', title=title, ax=axes[j], rot=0, width=.75)
                elif n_values > self.cut_off:
                    print(f'Number of values is above {self.cut_off}')
                else:
                    vals.plot(kind='barh', title=title, ax=axes[j], width=.75)
            else:
                vals = x[tab]
                axs = vals.plot(kind='hist', ax=axes[j], title=title)
                axs.set_ylabel('Frequency')
                if len(set(vals)) > 1:
                    vals.plot(kind='kde', ax=axs, title=title, secondary_y=True)
                else:
                    print('Less than two unique values, cannot graph the KDE')
        # Lay out this figure before showing it (tight_layout also spaces the subplot titles).
        fig.tight_layout()
        plt.show()

    @staticmethod
    def _img_grid(n:int) -> tuple:
        "Return (rows, cols) for `n` images; at least 2x2 so the axes grid is always 2-D."
        if n < 4:
            return 2, 2
        cols = math.ceil(math.sqrt(n))
        return math.ceil(n/cols), cols

    def _plot_imgs(self, tab:str, i:int, **kwargs):
        "Plots the most confused images"
        # `tab` is the tab title; the (actual, predicted) pair comes from `self._pairs[i]`
        # (set in `_create_tabs`), so class names containing spaces still match.
        classes_gnd = self.interp.data.classes
        actual, pred = (str(c) for c in self._pairs[i])
        # Show at most `k` images, and never more than this pair has.
        n_show = min(self._ranges[i], self._boxes)
        rows, cols = self._img_grid(n_show)
        fig, ax = plt.subplots(rows, cols, figsize=self.figsize, squeeze=False)
        for axi in ax.ravel():
            axi.set_axis_off()
        ds = self.interp.data.dl(self.interp.ds_type).dataset
        shown = 0
        for idx in self.tl_idx:
            if shown >= n_show:
                break
            img, cl = ds[idx]
            if str(cl) != actual or str(classes_gnd[self.interp.pred_class[idx]]) != pred:
                continue
            fn = str(ds.x.items[idx])
            # Title with the trailing "<name>_<number>..." part of the path; fall back to the
            # whole path when the file name does not follow that pattern.
            match = re.search(r'([^/*]+)_\d+.*$', fn)
            cell = ax[shown // cols, shown % cols]
            img.show(ax=cell)
            cell.set_title(match.group(0) if match else fn)
            shown += 1
        fig.tight_layout()
        plt.show()

    @staticmethod
    def _decode_row(da) -> list:
        "Decoded categorical values followed by the continuous values of tabular item `da`."
        # `da.names` lists the categorical names first, then the continuous ones.
        # Surrounding whitespace is stripped from category strings. (The previous `[1:]` dropped
        # the first character of every value, presumably to remove a leading space.)
        cat_vals = [str(da.classes[n][c]).strip() for c, n in zip(da.cats, da.names[:len(da.cats)])]
        cont_vals = [float(c) for c, _ in zip(da.conts, da.names[len(da.cats):])]
        return cat_vals + cont_vals

    def _losses_frame(self, rows:list, names:list, cont_names:list, label:str) -> pd.DataFrame:
        "DataFrame of `rows` with float (de-normalized if possible) continuous columns plus a `label` column."
        f = pd.DataFrame(rows, columns=names)
        means, stds = getattr(self, 'means', None), getattr(self, 'stds', None)
        for var in cont_names:
            # Cast explicitly so even an empty frame has numeric columns to plot.
            col = f[var].astype(float)
            if means is not None and stds is not None:
                col = col * stds[var] + means[var]
            f[var] = col
        # The constant last column names the group ('Original' or the class pair).
        f[label] = label
        return f

    def _tab_losses(self, **kwargs):
        "Gathers dataframes of the combinations data"
        classes = self.interp.data.classes
        cat_names = self.interp.data.x.cat_names
        cont_names = self.interp.data.x.cont_names
        comb = self.classl if self.is_ordered else list(permutations(self.classl, 2))
        ds = self.interp.data.dl(self.interp.ds_type).dataset
        # Decode every top-loss item once, together with its predicted and actual class names,
        # instead of re-reading the whole dataset for every class pair.
        names, records = None, []
        for idx in self.tl_idx:
            da, cl = ds[idx]
            if names is None:
                names = list(da.names)
            records.append((self._decode_row(da), classes[self.interp.pred_class[idx]], classes[int(cl)]))
        if names is None:
            # No items: assumes the column order is categorical names then continuous names.
            names = list(cat_names) + list(cont_names)
        self.df_list = [self._losses_frame([row for row, _, _ in records], names, cont_names, 'Original')]
        for pair in comb:
            # NOTE(review): pair[0] is matched against the *predicted* class and pair[1] against
            # the *actual* class. The image path instead compares the first element with the
            # actual label. Kept unchanged; confirm the intended order.
            rows = [row for row, p, a in records if p == pair[0] and a == pair[1]]
            self.df_list.append(self._losses_frame(rows, names, cont_names, str(pair)))
        self.cat_names = cat_names
        self._create_tabs()