Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/data.py
781 views
1
"Manages data input pipeline - folders -> transform -> batch input. Includes support for classification, segmentation and bounding boxes"
2
from numbers import Integral
3
from ..torch_core import *
4
from .image import *
5
from .transform import *
6
from ..data_block import *
7
from ..basic_data import *
8
from ..layers import *
9
from .learner import *
10
from torchvision import transforms as tvt
11
12
# Names exported by `from <this module> import *`.
__all__ = [
    'get_image_files',
    'denormalize',
    'get_annotations',
    'ImageDataBunch',
    'ImageList',
    'normalize',
    'normalize_funcs',
    'resize_to',
    'channel_view',
    'mnist_stats',
    'cifar_stats',
    'imagenet_stats',
    'imagenet_stats_inception',
    'download_images',
    'verify_images',
    'bb_pad_collate',
    'ImageImageList',
    'PointsLabelList',
    'ObjectCategoryList',
    'ObjectItemList',
    'SegmentationLabelList',
    'SegmentationItemList',
    'PointsItemList',
]
17
18
# File suffixes (as listed in `mimetypes.types_map`) whose MIME type starts with 'image/'.
image_extensions = {ext for ext, mime in mimetypes.types_map.items() if mime.startswith('image/')}
19
20
def get_image_files(c:PathOrStr, check_ext:bool=True, recurse=False)->FilePathList:
    """Return the files found in `c` by `get_files`, restricted to image suffixes unless `check_ext` is False.

    - `check_ext`: when True only suffixes in `image_extensions` are kept; when False no suffix filter is passed.
    - `recurse`: forwarded unchanged to `get_files`.
    """
    allowed = image_extensions if check_ext else None
    return get_files(c, extensions=allowed, recurse=recurse)
23
24
def get_annotations(fname, prefix=None):
    """Open a COCO style json in `fname` and return the lists of filenames (with maybe `prefix`) and labelled bboxes.

    - `fname`: path of the json file; it must hold the keys 'categories', 'annotations' and 'images'.
    - `prefix`: string prepended to every file name (nothing is prepended when it is None).

    Returns `(filenames, labels)` where `labels[k]` is `[bboxes, category_names]` for `filenames[k]`.
    Each stored bbox is `[bb[1], bb[0], bb[3]+bb[1], bb[2]+bb[0]]` of the json `bbox` entry, i.e. a COCO
    `[x, y, width, height]` box becomes `[top, left, bottom, right]`.
    Images that have no annotation are left out.
    """
    # Use a context manager so the file handle is closed as soon as parsing is done.
    with open(fname) as f: annot_dict = json.load(f)
    prefix = '' if prefix is None else prefix
    classes = {o['id']: o['name'] for o in annot_dict['categories']}
    id2bboxes, id2cats = collections.defaultdict(list), collections.defaultdict(list)
    for o in annot_dict['annotations']:
        bb = o['bbox']
        id2bboxes[o['image_id']].append([bb[1], bb[0], bb[3]+bb[1], bb[2]+bb[0]])
        id2cats[o['image_id']].append(classes[o['category_id']])
    # Plain membership test on purpose: indexing the defaultdict would create empty entries.
    id2images = {o['id']: prefix + o['file_name'] for o in annot_dict['images'] if o['id'] in id2bboxes}
    ids = list(id2images)
    return [id2images[k] for k in ids], [[id2bboxes[k], id2cats[k]] for k in ids]
40
41
def bb_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:
    """Collate `samples` of labelled bboxes into one batch, padding with `pad_idx`.

    When the target of the first sample is a plain `int` the work is handed to `data_collate`.
    Otherwise every target is padded to the largest number of labels in the batch: real boxes and
    labels are written at the end of each row, so the padding (zero boxes, `pad_idx` labels) sits in front.
    Returns `(images, (bboxes, labels))`.
    """
    if isinstance(samples[0][1], int): return data_collate(samples)
    n_samples = len(samples)
    longest = max(len(sample[1].data[1]) for sample in samples)
    bboxes = torch.zeros(n_samples, longest, 4)
    labels = torch.zeros(n_samples, longest).long() + pad_idx
    # `[None]` adds a leading axis so the images can be concatenated along it
    imgs = [sample[0].data[None] for sample in samples]
    for row, sample in enumerate(samples):
        bbs, lbls = sample[1].data
        if bbs.nelement() != 0:
            n_lbls = len(lbls)
            bboxes[row, -n_lbls:] = bbs
            labels[row, -n_lbls:] = tensor(lbls)
    return torch.cat(imgs, 0), (bboxes, labels)
55
56
def normalize(x:TensorImage, mean,std:Tensor)->TensorImage:
    "Subtract `mean` from `x` then divide by `std`; both stats get two trailing unit axes so they broadcast over them."
    centered = x - mean[..., None, None]
    return centered / std[..., None, None]
59
60
def denormalize(x:TensorImage, mean,std:Tensor, do_x:bool=True)->TensorImage:
    """Undo `normalize` on `x` using `mean` and `std`.

    `x` is always moved to the CPU. With `do_x` False it is returned as is; otherwise it is cast to
    float, multiplied by `std` and shifted by `mean` (both broadcast over two trailing axes).
    """
    x = x.cpu()
    if not do_x: return x
    return x.float()*std[..., None, None] + mean[..., None, None]
63
64
def _normalize_batch(b:Tuple[Tensor,Tensor], mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Tensor,Tensor]:
    """Normalize a batch `b` = (`x`,`y`) with `mean` and `std`.

    The stats are first moved to the device of `x`. `x` is normalized when `do_x` is set;
    `y` is normalized only when `do_y` is set and `y` has exactly 4 axes.
    """
    inputs, targets = b
    device = inputs.device
    mean, std = mean.to(device), std.to(device)
    if do_x:
        inputs = normalize(inputs, mean, std)
    if do_y and len(targets.shape) == 4:
        targets = normalize(targets, mean, std)
    return inputs, targets
71
72
def normalize_funcs(mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Callable,Callable]:
    """Create the normalize/denormalize pair of functions for `mean` and `std`.

    Returns `(norm, denorm)`: `norm` is `_normalize_batch` and `denorm` is `denormalize`, each with the
    stats (converted by `tensor`) and the relevant flags already bound. `do_y` only affects `norm`.
    """
    mean, std = tensor(mean), tensor(std)
    norm = partial(_normalize_batch, mean=mean, std=std, do_x=do_x, do_y=do_y)
    denorm = partial(denormalize, mean=mean, std=std, do_x=do_x)
    return (norm, denorm)
77
78
# Ready-made stats, each a pair of 3-element lists in the order that
# `normalize_funcs(*stats)` unpacks them: first the means, then the stds.
cifar_stats = (
    [0.491, 0.482, 0.447],
    [0.247, 0.243, 0.261],
)
imagenet_stats = (
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)
imagenet_stats_inception = (
    [0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5],
)
mnist_stats = (
    [0.15, 0.15, 0.15],
    [0.15, 0.15, 0.15],
)
82
83
def channel_view(x:Tensor)->Tensor:
    "Swap axes 0 and 1 of `x`, then flatten everything after the new first axis into a single axis."
    n_channels = x.shape[1]
    swapped = x.transpose(0, 1).contiguous()
    return swapped.view(n_channels, -1)
86
87
class ImageDataBunch(DataBunch):
    "DataBunch suitable for computer vision."
    _square_show = True

    @classmethod
    def create_from_ll(cls, lls:LabelLists, bs:int=64, val_bs:int=None, ds_tfms:Optional[TfmList]=None,
                num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None,
                test:Optional[PathOrStr]=None, collate_fn:Callable=data_collate, size:int=None, no_check:bool=False,
                resize_method:ResizeMethod=None, mult:int=None, padding_mode:str='reflection',
                mode:str='bilinear', tfm_y:bool=False)->'ImageDataBunch':
        """Create an `ImageDataBunch` from `LabelLists` `lls` with potential `ds_tfms`.

        The transform-related arguments go to `lls.transform`, `test` (when given) to `lls.add_test_folder`,
        and the loader-related arguments to `lls.databunch`, whose result is returned.
        """
        lls = lls.transform(tfms=ds_tfms, size=size, resize_method=resize_method, mult=mult, padding_mode=padding_mode,
                            mode=mode, tfm_y=tfm_y)
        if test is not None: lls.add_test_folder(test)
        return lls.databunch(bs=bs, val_bs=val_bs, dl_tfms=dl_tfms, num_workers=num_workers, collate_fn=collate_fn,
                             device=device, no_check=no_check)

    @classmethod
    def from_folder(cls, path:PathOrStr, train:PathOrStr='train', valid:PathOrStr='valid',
                    valid_pct=None, seed:int=None, classes:Collection=None, **kwargs:Any)->'ImageDataBunch':
        """Create from imagenet style dataset in `path` with `train`,`valid`,`test` subfolders (or provide `valid_pct`).

        Without `valid_pct` the split follows the `train`/`valid` folders; with it the split is random
        (using `seed`). Labels come from the folder names; `kwargs` go to `create_from_ll`.
        """
        path = Path(path)
        il = ImageList.from_folder(path)
        if valid_pct is None: src = il.split_by_folder(train=train, valid=valid)
        else: src = il.split_by_rand_pct(valid_pct, seed)
        src = src.label_from_folder(classes=classes)
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_df(cls, path:PathOrStr, df:pd.DataFrame, folder:PathOrStr=None, label_delim:str=None, valid_pct:float=0.2,
                seed:int=None, fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, suffix:str='', **kwargs:Any)->'ImageDataBunch':
        "Create from a `DataFrame` `df`: file names in `fn_col`, labels in `label_col`, random `valid_pct` split."
        src = (ImageList.from_df(df, path=path, folder=folder, suffix=suffix, cols=fn_col)
                .split_by_rand_pct(valid_pct, seed)
                .label_from_df(label_delim=label_delim, cols=label_col))
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_csv(cls, path:PathOrStr, folder:PathOrStr=None, label_delim:str=None, csv_labels:PathOrStr='labels.csv',
                 valid_pct:float=0.2, seed:int=None, fn_col:int=0, label_col:int=1, suffix:str='', delimiter:str=None,
                 header:Optional[Union[int,str]]='infer', **kwargs:Any)->'ImageDataBunch':
        "Create from a csv file in `path/csv_labels`, read with `header` and `delimiter`, then passed to `from_df`."
        path = Path(path)
        df = pd.read_csv(path/csv_labels, header=header, delimiter=delimiter)
        return cls.from_df(path, df, folder=folder, label_delim=label_delim, valid_pct=valid_pct, seed=seed,
                fn_col=fn_col, label_col=label_col, suffix=suffix, **kwargs)

    @classmethod
    def from_lists(cls, path:PathOrStr, fnames:FilePathList, labels:Collection[str], valid_pct:float=0.2, seed:int=None,
                   item_cls:Callable=None, **kwargs):
        "Create from list of `fnames` in `path`, labelled by the matching entry of `labels` (`item_cls` defaults to `ImageList`)."
        item_cls = ifnone(item_cls, ImageList)
        fname2label = dict(zip(fnames, labels))
        src = (item_cls(fnames, path=path).split_by_rand_pct(valid_pct, seed)
                                .label_from_func(lambda x:fname2label[x]))
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_name_func(cls, path:PathOrStr, fnames:FilePathList, label_func:Callable, valid_pct:float=0.2, seed:int=None,
                       **kwargs):
        "Create from list of `fnames` in `path` with `label_func`."
        src = ImageList(fnames, path=path).split_by_rand_pct(valid_pct, seed)
        return cls.create_from_ll(src.label_from_func(label_func), **kwargs)

    @classmethod
    def from_name_re(cls, path:PathOrStr, fnames:FilePathList, pat:str, valid_pct:float=0.2, **kwargs):
        """Create from list of `fnames` in `path` with re expression `pat`.

        The label of a file is group 1 of the first match of `pat` in its name (posix form for `Path`s);
        an `AssertionError` is raised when `pat` does not match.
        """
        pat = re.compile(pat)
        def _get_label(fn):
            if isinstance(fn, Path): fn = fn.as_posix()
            res = pat.search(str(fn))
            assert res,f'Failed to find "{pat}" in "{fn}"'
            return res.group(1)
        return cls.from_name_func(path, fnames, _get_label, valid_pct=valid_pct, **kwargs)

    @staticmethod
    def single_from_classes(path:Union[Path, str], classes:Collection[str], ds_tfms:TfmList=None, **kwargs):
        "Create an empty `ImageDataBunch` in `path` with `classes`. Typically used for inference. Deprecated."
        warn("This method is deprecated and will be removed in a future version, use `load_learner` after\n"
             "`Learner.export()`", DeprecationWarning)
        sd = ImageList([], path=path, ignore_empty=True).split_none()
        return sd.label_const(0, label_cls=CategoryList, classes=classes).transform(ds_tfms, **kwargs).databunch()

    def batch_stats(self, funcs:Collection[Callable]=None, ds_type:DatasetType=DatasetType.Train)->List[Tensor]:
        """Grab a batch of data and call each reduction function in `funcs` per channel.

        `funcs` defaults to `[torch.mean, torch.std]`. Returns one result per function, in order,
        each computed as `func(channel_view(x), 1)` on the non-denormalized inputs of one batch.
        """
        funcs = ifnone(funcs, [torch.mean,torch.std])
        x = self.one_batch(ds_type=ds_type, denorm=False)[0].cpu()
        # The channel-first view is the same for every function, so build it once.
        per_channel = channel_view(x)
        return [func(per_channel, 1) for func in funcs]

    def normalize(self, stats:Collection[Tensor]=None, do_x:bool=True, do_y:bool=False)->'ImageDataBunch':
        """Add normalize transform using `stats` (defaults to `DataBunch.batch_stats`).

        Stores `stats`, `norm` and `denorm` on the instance, registers `norm` with `add_tfm` and
        returns `self` so calls can be chained. Raises `Exception` if it was already called.
        """
        if getattr(self,'norm',False): raise Exception('Can not call normalize twice')
        self.stats = self.batch_stats() if stats is None else stats
        self.norm,self.denorm = normalize_funcs(*self.stats, do_x=do_x, do_y=do_y)
        self.add_tfm(self.norm)
        return self
184
185
def download_image(url,dest, timeout=4):
    "Download `url` to `dest` (overwriting, no progress bar); on any error print it instead of raising (best effort)."
    # The result of `download_url` was bound to an unused local; it is simply dropped now.
    try: download_url(url, dest, overwrite=True, show_progress=False, timeout=timeout)
    except Exception as e: print(f"Error {url} {e}")
188
189
def _download_image_inner(dest, url, i, timeout=4):
    """Download `url` into `dest` under a file name made of `i` zero-padded to 8 digits plus a suffix.

    The suffix is the first `.xxx` run of word characters in `url` that is directly followed by `?`
    or the end of the string, falling back to '.jpg' when there is none.
    """
    found = re.findall(r'\.\w+?(?=(?:\?|$))', url)
    suffix = found[0] if found else '.jpg'
    target = dest/f"{i:08d}{suffix}"
    download_image(url, target, timeout=timeout)
193
194
def download_images(urls:Collection[str], dest:PathOrStr, max_pics:int=1000, max_workers:int=8, timeout=4):
    """Download images listed in text file `urls` to path `dest`, at most `max_pics`.

    - `urls`: path of a text file holding one URL per line (it is opened and read here).
    - `dest`: target folder, created if missing (its parent must already exist).
    - `max_workers`, `timeout`: forwarded to `parallel` and to each download respectively.
    """
    # Read inside a context manager so the handle of the URL list is closed right away.
    with open(urls) as f: urls = f.read().strip().split("\n")[:max_pics]
    dest = Path(dest)
    dest.mkdir(exist_ok=True)
    # NOTE(review): relies on `parallel` calling the function as func(url, index) -- confirm.
    parallel(partial(_download_image_inner, dest, timeout=timeout), urls, max_workers=max_workers)
200
201
def resize_to(img, targ_sz:int, use_min:bool=False):
    """Size to resize to, to hit `targ_sz` at same aspect ratio, in PIL coords (i.e w*h).

    `img.size` is read as `(w, h)`. The larger side is scaled to `targ_sz` by default, the smaller
    one when `use_min` is True; both returned sides are truncated with `int`.
    """
    w, h = img.size
    ref_side = min(w, h) if use_min else max(w, h)
    scale = targ_sz/ref_side
    return int(w*scale), int(h*scale)
207
208
def verify_image(file:Path, idx:int, delete:bool, max_size:Union[int,Tuple[int,int]]=None, dest:Path=None, n_channels:int=3,
                 interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None, resume:bool=False, **kwargs):
    """Check if the image in `file` exists, maybe resize it and copy it in `dest`.

    Steps:
    1. Open `file` with every warning turned into an error. A "Possibly corrupt EXIF data" warning is
       handled by re-saving the file in place (only when `delete` is True); any other warning is re-raised.
    2. Re-open the image; if it is larger than `max_size` or its channel count differs from `n_channels`,
       save a resized / RGB-converted copy in `dest` (suffix `ext`, format `img_format`, `kwargs` go to `save`).
       With `resume`, an already existing destination file is left alone.
    3. Any exception is printed and, when `delete` is True, `file` is removed.

    `idx` is not used in the body.
    NOTE(review): `max_size` is compared with `>` against height and width, so only an int works as
    written despite the tuple in the annotation -- confirm intended usage.
    """
    try:
        # deal with partially broken images as indicated by PIL warnings
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            try:
                with open(file, 'rb') as img_file: PIL.Image.open(img_file)
            except Warning as w:
                if "Possibly corrupt EXIF data" in str(w):
                    if delete: # green light to modify files
                        print(f"{file}: Removing corrupt EXIF data")
                        warnings.simplefilter("ignore")
                        # save EXIF-cleaned up image, which happens automatically;
                        # the context manager closes the handle once the save is done
                        with PIL.Image.open(file) as exif_img: exif_img.save(file)
                    else: # keep user's files intact
                        print(f"{file}: Not removing corrupt EXIF data, pass `delete=True` to do that")
                else: warnings.warn(w)

        # Close the source handle on every exit path (early return, error, normal end) so that file
        # descriptors are not leaked and `file.unlink()` below never runs on a file that is still open.
        with PIL.Image.open(file) as img:
            imgarr = np.array(img)
            img_channels = 1 if len(imgarr.shape) == 2 else imgarr.shape[2]
            if (max_size is not None and (img.height > max_size or img.width > max_size)) or img_channels != n_channels:
                assert isinstance(dest, Path), "You should provide `dest` Path to save resized image"
                dest_fname = dest/file.name
                if ext is not None: dest_fname=dest_fname.with_suffix(ext)
                if resume and os.path.isfile(dest_fname): return
                if max_size is not None:
                    new_sz = resize_to(img, max_size)
                    img = img.resize(new_sz, resample=interp)
                if n_channels == 3: img = img.convert("RGB")
                img.save(dest_fname, img_format, **kwargs)
    except Exception as e:
        print(f'{e}')
        if delete: file.unlink()
244
245
def verify_images(path:PathOrStr, delete:bool=True, max_workers:int=4, max_size:Union[int]=None, recurse:bool=False,
                  dest:PathOrStr='.', n_channels:int=3, interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None,
                  resume:bool=None, **kwargs):
    """Check if the images in `path` aren't broken, maybe resize them and copy it in `dest`.

    `dest` is taken relative to `path` and created if needed. Every image file found in `path`
    (sub-folders too when `recurse`) is handed to `verify_image` through `parallel` with `max_workers`;
    all remaining arguments and `kwargs` are forwarded to `verify_image`.
    """
    path = Path(path)
    if resume is None and dest == '.':
        resume = False
    dest = path/Path(dest)
    os.makedirs(dest, exist_ok=True)
    files = get_image_files(path, recurse=recurse)
    check_opts = dict(delete=delete, max_size=max_size, dest=dest, n_channels=n_channels, interp=interp,
                      ext=ext, img_format=img_format, resume=resume)
    checker = partial(verify_image, **check_opts, **kwargs)
    parallel(checker, files, max_workers=max_workers)
257
258
class ImageList(ItemList):
    "`ItemList` suitable for computer vision."
    _bunch = ImageDataBunch
    _square_show = True
    _square_show_res = True

    def __init__(self, *args, convert_mode='RGB', after_open:Callable=None, **kwargs):
        "Store `convert_mode` and `after_open` (both used by `open`) and start with `c=3` and an empty size cache."
        super().__init__(*args, **kwargs)
        self.convert_mode = convert_mode
        self.after_open = after_open
        self.copy_new += ['convert_mode', 'after_open']
        self.c = 3
        self.sizes = {}

    def open(self, fn):
        "Open image in `fn`, subclass and overwrite for custom behavior."
        return open_image(fn, convert_mode=self.convert_mode, after_open=self.after_open)

    def get(self, i):
        "Open item `i` and remember its `size` in `self.sizes[i]`."
        img = self.open(super().get(i))
        self.sizes[i] = img.size
        return img

    @classmethod
    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=None, **kwargs)->ItemList:
        "Get the list of files in `path` whose suffix is in `extensions` (defaults to `image_extensions`)."
        wanted = ifnone(extensions, image_extensions)
        return super().from_folder(path=path, extensions=wanted, **kwargs)

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr=None, suffix:str='', **kwargs)->'ItemList':
        "Get the filenames in `cols` of `df` with `folder` in front of them, `suffix` at the end."
        suffix = suffix or ''
        res = super().from_df(df, path=path, cols=cols, **kwargs)
        sep = os.path.sep
        prefix = f'{res.path}{sep}'
        if folder is not None:
            prefix = prefix + f'{folder}{sep}'
        names = res.items.astype(str)
        # element-wise string concatenation: prefix + name + suffix
        res.items = np.char.add(np.char.add(prefix, names), suffix)
        return res

    @classmethod
    def from_csv(cls, path:PathOrStr, csv_name:str, header:str='infer', delimiter:str=None, **kwargs)->'ItemList':
        "Get the filenames in `path/csv_name` opened with `header`."
        path = Path(path)
        frame = pd.read_csv(path/csv_name, header=header, delimiter=delimiter)
        return cls.from_df(frame, path=path, **kwargs)

    def reconstruct(self, t:Tensor):
        "Build an `Image` from `t` cast to float and clamped to [0,1]."
        return Image(t.float().clamp(min=0,max=1))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a figure of `figsize`."
        side = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(side, side, imgsize=imgsize, figsize=figsize)
        flat_axs = axs.flatten()
        for x,y,ax in zip(xs, ys, flat_axs):
            x.show(ax=ax, y=y, **kwargs)
        # hide the cells of the square grid that received no image
        for ax in flat_axs[len(xs):]:
            ax.axis('off')
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`."
        if not self._square_show_res:
            # two columns per input: target on the left, prediction on the right
            title = 'Ground truth/Predictions'
            axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
            for i,(x,y,z) in enumerate(zip(xs,ys,zs)):
                x.show(ax=axs[i,0], y=y, **kwargs)
                x.show(ax=axs[i,1], y=z, **kwargs)
            return
        # square grid, target and prediction written in the title of each image
        title = 'Ground truth\nPredictions'
        side = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(side, side, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=12)
        flat_axs = axs.flatten()
        for x,y,z,ax in zip(xs,ys,zs,flat_axs):
            x.show(ax=ax, title=f'{str(y)}\n{str(z)}', **kwargs)
        for ax in flat_axs[len(xs):]:
            ax.axis('off')
324
325
class ObjectCategoryProcessor(MultiCategoryProcessor):
    "`PreProcessor` for labelled bounding boxes."
    def __init__(self, ds:ItemList, pad_idx:int=0):
        "Keep `pad_idx` and register it in `state_attrs`."
        super().__init__(ds)
        self.pad_idx = pad_idx
        self.state_attrs.append('pad_idx')

    def process(self, ds:ItemList):
        "Copy `pad_idx` onto `ds`, then run the parent processing."
        ds.pad_idx = self.pad_idx
        super().process(ds)

    def process_one(self,item):
        "Keep `item[0]` untouched and map each entry of `item[1]` through `c2i` (None when unknown)."
        first, names = item[0], item[1]
        return [first, [self.c2i.get(name, None) for name in names]]

    def generate_classes(self, items):
        "Generate classes from unique `items` and add `background`."
        found = super().generate_classes([o[1] for o in items])
        return ['background'] + list(found)
343
344
def _get_size(xs,i):
    "Return `xs.sizes[i]`; when no size is recorded yet, access `xs[i]` first so that it gets filled in."
    if xs.sizes.get(i, None) is None:
        # Item hasn't been accessed yet, so its size is unknown: getting it is expected to record it
        xs[i]
    return xs.sizes[i]
351
352
class ObjectCategoryList(MultiCategoryList):
    "`ItemList` for labelled bounding boxes."
    _processor = ObjectCategoryProcessor

    def get(self, i):
        "Build the `ImageBBox` of item `i` from the size of input `i` and the stored item."
        size = _get_size(self.x, i)
        return ImageBBox.create(*size, *self.items[i], classes=self.classes, pad_idx=self.pad_idx)

    def analyze_pred(self, pred):
        "Predictions are returned unchanged."
        return pred

    def reconstruct(self, t, x):
        "Rebuild an `ImageBBox` from `t` = (bboxes, labels), dropping everything before the first non-pad label (None if all are padding)."
        bboxes, labels = t
        non_pad = (labels - self.pad_idx).nonzero()
        if len(non_pad) == 0: return
        start = non_pad.min()
        bboxes, labels = bboxes[start:], labels[start:]
        return ImageBBox.create(*x.size, bboxes, labels=labels, classes=self.classes, scale=False)
367
368
class ObjectItemList(ImageList):
    "`ImageList` for object detection: labels are an `ObjectCategoryList`, results are not shown in a square grid."
    _label_cls = ObjectCategoryList
    _square_show_res = False
371
372
class SegmentationProcessor(PreProcessor):
    "`PreProcessor` that stores the classes for segmentation."
    def __init__(self, ds:ItemList):
        "Remember the classes of `ds`."
        self.classes = ds.classes

    def process(self, ds:ItemList):
        "Give `ds` the stored classes and set `ds.c` to their number."
        ds.classes = self.classes
        ds.c = len(self.classes)
376
377
class SegmentationLabelList(ImageList):
    "`ItemList` for segmentation masks."
    _processor = SegmentationProcessor

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        "Store `classes` (carried over to copies through `copy_new`) and use `CrossEntropyFlat(axis=1)` as loss."
        super().__init__(items, **kwargs)
        self.copy_new.append('classes')
        self.classes = classes
        self.loss_func = CrossEntropyFlat(axis=1)

    def open(self, fn):
        "Open `fn` with `open_mask`."
        return open_mask(fn)

    def analyze_pred(self, pred, thresh:float=0.5):
        "Take the argmax of `pred` over axis 0 and add a leading unit axis (`thresh` is not used)."
        return pred.argmax(dim=0)[None]

    def reconstruct(self, t:Tensor):
        "Wrap `t` in an `ImageSegment`."
        return ImageSegment(t)
388
389
class SegmentationItemList(ImageList):
    "`ImageList` for segmentation tasks: labels are a `SegmentationLabelList`, results are not shown in a square grid."
    _label_cls = SegmentationLabelList
    _square_show_res = False
392
393
class PointsProcessor(PreProcessor):
    "`PreProcessor` that stores the number of targets for point regression."
    def __init__(self, ds:ItemList):
        "Set `c` to the number of values in the flattened first item of `ds`."
        first = ds.items[0]
        self.c = len(first.reshape(-1))

    def process(self, ds:ItemList):
        "Copy the stored `c` onto `ds`."
        ds.c = self.c
397
398
class PointsLabelList(ItemList):
    "`ItemList` for points."
    _processor = PointsProcessor

    def __init__(self, items:Iterator, **kwargs):
        "Use `MSELossFlat` as loss function."
        super().__init__(items, **kwargs)
        self.loss_func = MSELossFlat()

    def get(self, i):
        "Wrap item `i` in scaled `ImagePoints`, using the size of the matching input."
        pts = super().get(i)
        flow = FlowField(_get_size(self.x, i), pts)
        return ImagePoints(flow, scale=True)

    def analyze_pred(self, pred, thresh:float=0.5):
        "Reshape `pred` to rows of 2 values (`thresh` is not used)."
        return pred.view(-1,2)

    def reconstruct(self, t, x):
        "Build unscaled `ImagePoints` from `t` on an image of the size of `x`."
        return ImagePoints(FlowField(x.size, t), scale=False)
411
412
class PointsItemList(ImageList):
    "`ImageList` for `Image` to `ImagePoints` tasks: labels are a `PointsLabelList`, results are not shown in a square grid."
    _label_cls = PointsLabelList
    _square_show_res = False
415
416
class ImageImageList(ImageList):
    "`ItemList` suitable for `Image` to `Image` tasks."
    _label_cls = ImageList
    _square_show = False
    _square_show_res = False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys`(targets) on a figure of `figsize`, one row per pair."
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for row, (inp, targ) in enumerate(zip(xs, ys)):
            inp.show(ax=axs[row,0], **kwargs)
            targ.show(ax=axs[row,1], **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`."
        title = 'Input / Prediction / Target'
        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
        for row, (inp, targ, pred) in enumerate(zip(xs, ys, zs)):
            # columns follow the title: input (0), prediction (1), target (2)
            inp.show(ax=axs[row,0], **kwargs)
            targ.show(ax=axs[row,2], **kwargs)
            pred.show(ax=axs[row,1], **kwargs)
436
437
438
def _ll_pre_transform(self, train_tfm:List[Callable], valid_tfm:List[Callable]):
    "Set `compose(train_tfm)` / `compose(valid_tfm)` as the `after_open` hook of `train.x` / `valid.x`, then return `self`."
    train_x = self.train.x
    train_x.after_open = compose(train_tfm)
    valid_x = self.valid.x
    valid_x.after_open = compose(valid_tfm)
    return self
443
444
def _db_pre_transform(self, train_tfm:List[Callable], valid_tfm:List[Callable]):
    "Set `compose(train_tfm)` / `compose(valid_tfm)` as the `after_open` hook of `train_ds.x` / `valid_ds.x`, then return `self`."
    train_x = self.train_ds.x
    train_x.after_open = compose(train_tfm)
    valid_x = self.valid_ds.x
    valid_x.after_open = compose(valid_tfm)
    return self
449
450
def _presize(self, size:int, val_xtra_size:int=32, scale:Tuple[float]=(0.08, 1.0), ratio:Tuple[float]=(0.75, 4./3.),
             interpolation:int=2):
    """Register pre-transforms that bring images to `size` through `self.pre_transform`.

    Training uses `RandomResizedCrop(size, scale, ratio, interpolation)`; validation resizes to
    `size+val_xtra_size` and then center-crops to `size`.
    """
    train_tfm = tvt.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)
    valid_tfms = [tvt.Resize(size+val_xtra_size), tvt.CenterCrop(size)]
    return self.pre_transform(train_tfm, valid_tfms)
456
457
# Attach the helpers defined above as methods of `LabelLists` ...
LabelLists.pre_transform = _ll_pre_transform
LabelLists.presize = _presize
# ... and of `DataBunch` (`presize` is shared, it only relies on `self.pre_transform`).
DataBunch.pre_transform = _db_pre_transform
DataBunch.presize = _presize
461
462
463