Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/widgets/image_cleaner.py
781 views
1
from ..torch_core import *
2
from ..basic_train import *
3
from ..basic_data import *
4
from ..vision.data import *
5
from ..vision.transform import *
6
from ..vision.image import *
7
from ..callbacks.hooks import *
8
from ..layers import *
9
from ipywidgets import widgets, Layout
10
from IPython.display import clear_output, display
11
12
# Public names exported by `from ... import *` on this module.
__all__ = ['DatasetFormatter', 'ImageCleaner']
13
14
class DatasetFormatter():
    "Returns a dataset with the appropriate format and file indices to be displayed."
    @classmethod
    def from_toplosses(cls, learn, n_imgs=None, **kwargs):
        "Gets indices with top losses."
        # Returns `(dataset, idxs)`: the padded `Fix` dataset and its indices ordered by descending loss.
        train_ds, train_idxs = cls.get_toplosses_idxs(learn, n_imgs, **kwargs)
        return train_ds, train_idxs

    @classmethod
    def get_toplosses_idxs(cls, learn, n_imgs, **kwargs):
        "Sorts the `Fix` dataset by top losses and returns dataset and sorted indices."
        dl = learn.data.fix_dl
        # A falsy `n_imgs` (None or 0) means "rank every item of the dataset".
        if not n_imgs: n_imgs = len(dl.dataset)
        _,_,top_losses = learn.get_preds(ds_type=DatasetType.Fix, with_loss=True)
        # `topk` returns (values, indices); keep only the indices, highest loss first.
        idxs = torch.topk(top_losses, n_imgs)[1]
        return cls.padded_ds(dl.dataset, **kwargs), idxs

    @staticmethod
    def padded_ds(ll_input, size=(250, 300), resize_method=ResizeMethod.CROP, padding_mode='zeros', **kwargs):
        "For a LabelList `ll_input`, resize each image to `size` using `resize_method` and `padding_mode`."
        # Extra `kwargs` are accepted and ignored so callers can forward their whole kwargs dict.
        return ll_input.transform(tfms=crop_pad(), size=size, resize_method=resize_method, padding_mode=padding_mode)

    @classmethod
    def from_similars(cls, learn, layer_ls=(0, 7, 2), **kwargs):
        "Gets the indices for the most similar images."
        # `layer_ls` holds three successive indices into `learn.model` selecting the layer to hook.
        # The default is a tuple so no mutable default is shared between calls.
        train_ds, train_idxs = cls.get_similars_idxs(learn, layer_ls, **kwargs)
        return train_ds, train_idxs

    @classmethod
    def get_similars_idxs(cls, learn, layer_ls, **kwargs):
        "Gets the indices for the most similar images in the `Fix` dataset."
        hook = hook_output(learn.model[layer_ls[0]][layer_ls[1]][layer_ls[2]])
        dl = learn.data.fix_dl
        try:
            ds_actns = cls.get_actns(learn, hook=hook, dl=dl, **kwargs)
        finally:
            # Unregister the hook even if collection fails, so later forward passes of
            # `learn.model` no longer trigger it.
            hook.remove()
        similarities = cls.comb_similarity(ds_actns, ds_actns, **kwargs)
        idxs = cls.sort_idxs(similarities)
        # Pass the dataset (as `get_toplosses_idxs` does) rather than the DataLoader.
        return cls.padded_ds(dl.dataset, **kwargs), idxs

    @staticmethod
    def get_actns(learn, hook:Hook, dl:DataLoader, pool=AdaptiveConcatPool2d, pool_dim:int=4, **kwargs):
        "Gets activations at the layer specified by `hook`, applies `pool` of dim `pool_dim` and concatenates"
        print('Getting activations...')
        actns = []
        learn.model.eval()
        with torch.no_grad():
            for (xb,yb) in progress_bar(dl):
                # The forward pass is run only for its side effect of filling `hook.stored`.
                learn.model(xb)
                actns.append((hook.stored).cpu())
        actns = torch.cat(actns)
        # `pool` is a module class; it is instantiated with `pool_dim` before being applied.
        if pool: actns = pool(pool_dim)(actns)
        # One flattened activation row per item of the dataset.
        return actns.view(len(dl.x),-1)

    @staticmethod
    def comb_similarity(t1: torch.Tensor, t2: torch.Tensor, **kwargs):
        "Computes the similarity function between each embedding of `t1` and `t2` matrices."
        # Cosine similarity between every row of `t1` and every row of `t2`; approach from
        # https://github.com/pytorch/pytorch/issues/11202
        print('Computing similarities...')
        w1 = t1.norm(p=2, dim=1, keepdim=True)
        # Reuse the norms when comparing a matrix against itself.
        w2 = w1 if t2 is t1 else t2.norm(p=2, dim=1, keepdim=True)
        # `clamp` avoids dividing by zero for all-zero embeddings.
        t = torch.mm(t1, t2.t()) / (w1 * w2.t()).clamp(min=1e-8)
        # Keep only the strict lower triangle: drops self-similarity and counts each pair once.
        return torch.tril(t, diagonal=-1)

    @staticmethod
    def largest_indices(arr, n):
        "Returns the `n` largest indices from a numpy array `arr`."
        #https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
        flat = arr.flatten()
        # `argpartition` isolates the top `n` without a full sort; only those are then sorted, descending.
        indices = np.argpartition(flat, -n)[-n:]
        indices = indices[np.argsort(-flat[indices])]
        # Convert flat positions back to one index array per dimension of `arr`.
        return np.unravel_index(indices, arr.shape)

    @classmethod
    def sort_idxs(cls, similarities):
        "Sorts `similarities` and return the indexes in pairs ordered by highest similarity."
        # Take as many top pairs as there are rows, then flatten to [i0, j0, i1, j1, ...].
        rows, cols = cls.largest_indices(similarities, len(similarities))
        return [e for pair in zip(rows, cols) for e in pair]
96
97
class ImageCleaner():
    "Displays images for relabeling or deletion and saves changes in `path` as 'cleaned.csv'."
    def __init__(self, dataset, fns_idxs, path, batch_size:int=5, duplicates=False):
        # `dataset` must expose `.classes`, `.x.items`, `.x[i]` and `.y[i]` (all used below).
        # `fns_idxs` are indices into `dataset`; in `duplicates` mode they are consumed in pairs.
        self._all_images,self._batch = [],[]
        self._path = Path(path)
        # Duplicate mode always shows exactly one pair per batch.
        self._batch_size = 2 if duplicates else batch_size
        self._duplicates = duplicates
        self._labels = dataset.classes
        # Entries are (image data, file path, human-readable label).
        self._all_images = self.create_image_list(dataset, fns_idxs)
        # Maps each file path, exactly as stored in `dataset.x.items`, to its current label.
        self._csv_dict = {dataset.x.items[i]: dataset.y[i] for i in range(len(dataset))}
        self._deleted_fns = []
        self._skipped = 0
        self.render()

    @classmethod
    def make_img_widget(cls, img, layout=None, format='jpg'):
        "Returns an image widget displaying the encoded image data `img` (e.g. output of `_repr_jpeg_()`)."
        # Build a fresh Layout per call; a `Layout()` default would be one widget shared by every image.
        if layout is None: layout = Layout()
        return widgets.Image(value=img, format=format, layout=layout)

    @classmethod
    def make_button_widget(cls, label, file_path=None, handler=None, style=None, layout=None):
        "Return a Button widget with specified `handler`."
        if layout is None: layout = Layout(width='auto')
        btn = widgets.Button(description=label, layout=layout)
        if handler is not None: btn.on_click(handler)
        if style is not None: btn.button_style = style
        # Extra attributes read back by `on_delete` and `next_batch`.
        btn.file_path = file_path
        btn.flagged_for_delete = False
        return btn

    @classmethod
    def make_dropdown_widget(cls, description='Description', options=('Label 1', 'Label 2'), value='Label 1',
                             file_path=None, layout=None, handler=None):
        "Return a Dropdown widget with specified `handler`."
        if layout is None: layout = Layout()
        dd = widgets.Dropdown(description=description, options=options, value=value, layout=layout)
        # `relabel` reads `file_path` back from the dropdown that fired the change.
        if file_path is not None: dd.file_path = file_path
        if handler is not None: dd.observe(handler, names=['value'])
        return dd

    @classmethod
    def make_horizontal_box(cls, children, layout=None):
        "Make a horizontal box with `children` and `layout`."
        if layout is None: layout = Layout()
        return widgets.HBox(children, layout=layout)

    @classmethod
    def make_vertical_box(cls, children, layout=None, duplicates=False):
        "Make a vertical box with `children` and `layout`."
        if layout is None: layout = Layout()
        if not duplicates: return widgets.VBox(children, layout=layout)
        # In duplicates mode only the first and third children (image and delete button) are shown.
        else: return widgets.VBox([children[0], children[2]], layout=layout)

    def create_image_list(self, dataset, fns_idxs):
        "Create a list of images, filenames and labels but first removing files that are not supposed to be displayed."
        items = dataset.x.items
        def _entry(i): return (dataset.x[i]._repr_jpeg_(), items[i], self._labels[dataset.y[i].data])
        if self._duplicates:
            # Keep a pair only if both of its files still exist on disk.
            chunked_idxs = [chunk for chunk in chunks(fns_idxs, 2)
                            if Path(items[chunk[0]]).is_file() and Path(items[chunk[1]]).is_file()]
            return [_entry(i) for chunk in chunked_idxs for i in chunk]
        return [_entry(i) for i in fns_idxs if Path(items[i]).is_file()]

    def relabel(self, change):
        "Record the label picked in a dropdown as the new CSV label of that dropdown's file."
        # Use `file_path` unchanged so it matches the `_csv_dict` key exactly; wrapping it in
        # `Path` would add a second entry whenever the dataset items are plain strings.
        self._csv_dict[change.owner.file_path] = change.new

    def next_batch(self, _):
        "Handler for 'Next Batch' button click. Drops flagged images from the CSV and renders the next batch."
        for _img_widget, delete_btn, _fp in self._batch:
            if delete_btn.flagged_for_delete:
                fp = delete_btn.file_path
                self.delete_image(fp)
                self._deleted_fns.append(fp)
        self._all_images = self._all_images[self._batch_size:]
        self.empty_batch()
        self.render()

    def on_delete(self, btn):
        "Flag this image as delete or keep."
        btn.button_style = "" if btn.flagged_for_delete else "danger"
        btn.flagged_for_delete = not btn.flagged_for_delete

    def empty_batch(self): self._batch[:] = []

    def delete_image(self, file_path):
        "Remove `file_path` from the CSV mapping (the file on disk is not touched); no-op if already removed."
        self._csv_dict.pop(file_path, None)

    def empty(self):
        "True when no images remain to be displayed."
        return not self._all_images

    def get_widgets(self, duplicates):
        "Create and format widget set."
        # Named `boxes` so the `widgets` module is not shadowed inside this method.
        boxes = []
        for (img,fp,human_readable_label) in self._all_images[:self._batch_size]:
            img_widget = self.make_img_widget(img, layout=Layout(height='250px', width='300px'))
            dropdown = self.make_dropdown_widget(description='', options=self._labels, value=human_readable_label,
                                                 file_path=fp, handler=self.relabel, layout=Layout(width='auto'))
            delete_btn = self.make_button_widget('Delete', file_path=fp, handler=self.on_delete)
            boxes.append(self.make_vertical_box([img_widget, dropdown, delete_btn],
                                                layout=Layout(width='auto', height='300px',
                                                              overflow_x="hidden"), duplicates=duplicates))
            self._batch.append((img_widget, delete_btn, fp))
        return boxes

    def batch_contains_deleted(self):
        "Check if current batch contains already deleted images."
        if not self._duplicates: return False
        # Slicing keeps this safe when fewer than `_batch_size` images remain.
        return any(fp in self._deleted_fns for _, fp, _ in self._all_images[:self._batch_size])

    def write_csv(self):
        "Write the current name/label mapping to `<path>/cleaned.csv` and return that path."
        csv_path = self._path/'cleaned.csv'
        # `newline=''` is required by the csv module so row terminators are not translated twice.
        with open(csv_path, 'w', newline='') as f:
            csv_writer = csv.writer(f)
            csv_writer.writerow(['name','label'])
            for fn, label in self._csv_dict.items():
                # Names are stored relative to `path`, where the CSV itself is written.
                csv_writer.writerow([os.path.relpath(fn, self._path), label])
        return csv_path

    def render(self):
        "Re-render Jupyter cell for batch of images."
        clear_output()
        self.write_csv()
        # Skip pairs in which an image was already deleted. Count each skip before checking for
        # the end, so the final message includes it; loop rather than recurse so long runs of
        # skipped pairs cannot exhaust the stack.
        while not self.empty() and self.batch_contains_deleted():
            self._skipped += 1
            self._all_images = self._all_images[self._batch_size:]
        if self.empty() and self._skipped>0:
            return display(f'No images to show :). {self._skipped} pairs were '
                           f'skipped since at least one of the images was deleted by the user.')
        elif self.empty():
            return display('No images to show :)')
        display(self.make_horizontal_box(self.get_widgets(self._duplicates)))
        display(self.make_button_widget('Next Batch', handler=self.next_batch, style="primary"))
235
236